Source code for mmaction.models.common.tam
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F


class TAM(nn.Module):
    """Temporal Adaptive Module (TAM) for TANet.

    This module is proposed in `TAM: TEMPORAL ADAPTIVE MODULE FOR VIDEO
    RECOGNITION <https://arxiv.org/pdf/2005.06803>`_.
    Args:
        in_channels (int): Channel num of input features.
        num_segments (int): Number of frame segments.
        alpha (int): ``alpha`` in the paper, the ratio of the intermediate
            channel number to the initial channel number in the global
            branch. Defaults to 2.
        adaptive_kernel_size (int): ``K`` in the paper, the size of the
            adaptive kernel in the global branch. Defaults to 3.
        beta (int): ``beta`` in the paper, set to control the model
            complexity in the local branch. Defaults to 4.
        conv1d_kernel_size (int): Size of the convolution kernel of Conv1d in
            the local branch. Defaults to 3.
        adaptive_convolution_stride (int): The first dimension of strides in
            the adaptive convolution of ``Temporal Adaptive Aggregation``.
            Defaults to 1.
        adaptive_convolution_padding (int): The first dimension of paddings
            in the adaptive convolution of ``Temporal Adaptive Aggregation``.
            Defaults to 1.
        init_std (float): Std value for initialization of `nn.Linear`.
            Defaults to 0.001.
    """

    def __init__(self,
                 in_channels: int,
                 num_segments: int,
                 alpha: int = 2,
                 adaptive_kernel_size: int = 3,
                 beta: int = 4,
                 conv1d_kernel_size: int = 3,
                 adaptive_convolution_stride: int = 1,
                 adaptive_convolution_padding: int = 1,
                 init_std: float = 0.001) -> None:
        super().__init__()
        assert beta > 0 and alpha > 0
        self.in_channels = in_channels
        self.num_segments = num_segments
        self.alpha = alpha
        self.adaptive_kernel_size = adaptive_kernel_size
        self.beta = beta
        self.conv1d_kernel_size = conv1d_kernel_size
        self.adaptive_convolution_stride = adaptive_convolution_stride
        self.adaptive_convolution_padding = adaptive_convolution_padding
        self.init_std = init_std
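        # Global branch ``G``: maps the temporal descriptor of each channel
        # (length ``num_segments``) to a softmax-normalized adaptive
        # aggregation kernel of size ``adaptive_kernel_size``.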
        self.G = nn.Sequential(
            nn.Linear(num_segments, num_segments * alpha, bias=False),
            nn.BatchNorm1d(num_segments * alpha), nn.ReLU(inplace=True),
            nn.Linear(num_segments * alpha, adaptive_kernel_size, bias=False),
            nn.Softmax(-1))
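        # Local branch ``L``: a Conv1d bottleneck (reduction ratio ``beta``)
        # that produces sigmoid attention weights over the temporal
        # dimension for every channel.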
        self.L = nn.Sequential(
            nn.Conv1d(
                in_channels,
                in_channels // beta,
                conv1d_kernel_size,
                stride=1,
                padding=conv1d_kernel_size // 2,
                bias=False), nn.BatchNorm1d(in_channels // beta),
            nn.ReLU(inplace=True),
            nn.Conv1d(in_channels // beta, in_channels, 1, bias=False),
            nn.Sigmoid())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        # [n, c, h, w]
        n, c, h, w = x.size()
        num_segments = self.num_segments
        num_batches = n // num_segments
        assert c == self.in_channels
        # [num_batches, c, num_segments, h, w]
        x = x.view(num_batches, num_segments, c, h, w)
        x = x.permute(0, 2, 1, 3, 4).contiguous()
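        # Squeeze spatial information: global average pooling over (h, w)
        # leaves a descriptor of length ``num_segments`` for every channel.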
        # [num_batches * c, num_segments, 1, 1]
        theta_out = F.adaptive_avg_pool2d(
            x.view(-1, num_segments, h, w), (1, 1))
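        # Global branch: generate a location-invariant adaptive kernel for
        # every channel of every clip from its temporal descriptor.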
        # [num_batches * c, 1, adaptive_kernel_size, 1]
        conv_kernel = self.G(theta_out.view(-1, num_segments)).view(
            num_batches * c, 1, -1, 1)
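        # Local branch: location-sensitive importance weights in (0, 1),
        # used to excite the frame-wise features.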
        # [num_batches, c, num_segments, 1, 1]
        local_activation = self.L(theta_out.view(-1, c, num_segments)).view(
            num_batches, c, num_segments, 1, 1)
        # [num_batches, c, num_segments, h, w]
        new_x = x * local_activation
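        # Temporal adaptive aggregation: with ``groups=num_batches * c``,
        # this depthwise conv2d convolves each channel of each clip with its
        # own adaptive kernel along the temporal dimension only.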
        # [1, num_batches * c, num_segments, h * w]
        y = F.conv2d(
            new_x.view(1, num_batches * c, num_segments, h * w),
            conv_kernel,
            bias=None,
            stride=(self.adaptive_convolution_stride, 1),
            padding=(self.adaptive_convolution_padding, 0),
            groups=num_batches * c)
        # [n, c, h, w]
        y = y.view(num_batches, c, num_segments, h, w)
        y = y.permute(0, 2, 1, 3, 4).contiguous().view(n, c, h, w)
        return y
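

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module).
    # TAM expects frames stacked along the batch axis, i.e. an input of
    # shape [n, c, h, w] with n = num_batches * num_segments; the sizes
    # below are arbitrary values chosen for the demo.
    tam = TAM(in_channels=64, num_segments=8)
    x = torch.randn(2 * 8, 64, 14, 14)  # 2 clips of 8 segments each
    y = tam(x)
    assert y.shape == x.shape  # TAM preserves the input shape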