Source code for mmaction.models.common.conv_audio
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init, kaiming_init
from torch.nn.modules.utils import _pair

from mmaction.registry import MODELS


@MODELS.register_module()
class ConvAudio(nn.Module):
    """Conv2d module for AudioResNet backbone.

    See `Audiovisual SlowFast Networks for Video Recognition
    <https://arxiv.org/abs/2001.08740>`_ for details.

    Args:
        in_channels (int): Same as ``nn.Conv2d``.
        out_channels (int): Same as ``nn.Conv2d``.
        kernel_size (Union[int, Tuple[int]]): Same as ``nn.Conv2d``.
        op (str): Operation to merge the outputs of the frequency and time
            branches. Choices are ``sum`` and ``concat``.
            Defaults to ``concat``.
        stride (Union[int, Tuple[int]]): Same as ``nn.Conv2d``. Defaults to 1.
        padding (Union[int, Tuple[int]]): Same as ``nn.Conv2d``. Defaults to 0.
        dilation (Union[int, Tuple[int]]): Same as ``nn.Conv2d``.
            Defaults to 1.
        groups (int): Same as ``nn.Conv2d``. Defaults to 1.
        bias (Union[bool, str]): If specified as ``auto``, it will be decided
            by the ``norm_cfg``. Bias will be set as True if ``norm_cfg``
            is None, otherwise False. Defaults to False.
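
    Example:
        A minimal usage sketch (the input is a random spectrogram-like
        tensor whose ``(128, 80)`` spatial size is only an assumption for
        illustration); with ``op='concat'`` the two branch outputs are
        stacked along the channel dimension, doubling ``out_channels``.

        >>> import torch
        >>> conv = ConvAudio(in_channels=16, out_channels=16, kernel_size=3)
        >>> x = torch.rand(1, 16, 128, 80)
        >>> conv(x).shape
        torch.Size([1, 32, 128, 80])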
"""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int]],
                 op: str = 'concat',
                 stride: Union[int, Tuple[int]] = 1,
                 padding: Union[int, Tuple[int]] = 0,
                 dilation: Union[int, Tuple[int]] = 1,
                 groups: int = 1,
                 bias: Union[bool, str] = False) -> None:
        super().__init__()
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        assert op in ['concat', 'sum']
        self.op = op
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.bias = bias
        self.output_padding = (0, 0)
        self.transposed = False
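        # Branch 1: a (kernel_size[0], 1) convolution (+ BN + ReLU) that
        # mixes information along the first spatial axis of the input.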
        self.conv_1 = ConvModule(
            in_channels,
            out_channels,
            kernel_size=(kernel_size[0], 1),
            stride=stride,
            padding=(kernel_size[0] // 2, 0),
            bias=bias,
            conv_cfg=dict(type='Conv'),
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))
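        # Branch 2: a (1, kernel_size[1]) convolution (+ BN + ReLU) along the
        # second spatial axis; ``forward`` sums or concatenates the two
        # branch outputs depending on ``op``.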
        self.conv_2 = ConvModule(
            in_channels,
            out_channels,
            kernel_size=(1, kernel_size[1]),
            stride=stride,
            padding=(0, kernel_size[1] // 2),
            bias=bias,
            conv_cfg=dict(type='Conv'),
            norm_cfg=dict(type='BN'),
            act_cfg=dict(type='ReLU'))

        self.init_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        x_1 = self.conv_1(x)
        x_2 = self.conv_2(x)
        if self.op == 'concat':
            out = torch.cat([x_1, x_2], 1)
        else:
            out = x_1 + x_2
        return out

    def init_weights(self) -> None:
        """Initialize the parameters from scratch."""
        kaiming_init(self.conv_1.conv)
        kaiming_init(self.conv_2.conv)
        constant_init(self.conv_1.bn, 1, bias=0)
        constant_init(self.conv_2.bn, 1, bias=0)
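

# A minimal usage sketch (illustrative, not part of the module above): it
# builds ``ConvAudio`` through the mmengine registry and shows how ``op``
# affects the channel count. The input shape is an arbitrary
# spectrogram-like tensor chosen only for the example.
#
#   >>> from mmaction.registry import MODELS
#   >>> cfg = dict(
#   ...     type='ConvAudio', in_channels=16, out_channels=16,
#   ...     kernel_size=3, op='sum')
#   >>> layer = MODELS.build(cfg)
#   >>> layer(torch.rand(2, 16, 128, 80)).shape  # 'sum' keeps out_channels
#   torch.Size([2, 16, 128, 80])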