Source code for mmaction.models.necks.tpn
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model.weight_init import constant_init, normal_init, xavier_init
from mmaction.registry import MODELS
from mmaction.utils import ConfigType, OptConfigType, SampleList
class DownSample(nn.Module):
"""DownSample modules.
It uses convolution and maxpooling to downsample the input feature,
and specifies downsample position to determine `pool-conv` or `conv-pool`.
Args:
in_channels (int): Channel number of input features.
out_channels (int): Channel number of output feature.
kernel_size (int or Tuple[int]): Same as :class:`ConvModule`.
Defaults to ``(3, 1, 1)``.
stride (int or Tuple[int]): Same as :class:`ConvModule`.
Defaults to ``(1, 1, 1)``.
padding (int or Tuple[int]): Same as :class:`ConvModule`.
Defaults to ``(1, 0, 0)``.
groups (int): Same as :class:`ConvModule`. Defaults to 1.
bias (bool or str): Same as :class:`ConvModule`. Defaults to False.
conv_cfg (dict or ConfigDict): Same as :class:`ConvModule`.
Defaults to ``dict(type='Conv3d')``.
norm_cfg (dict or ConfigDict, optional): Same as :class:`ConvModule`.
Defaults to None.
act_cfg (dict or ConfigDict, optional): Same as :class:`ConvModule`.
Defaults to None.
downsample_position (str): Type of downsample position. Options are
``before`` and ``after``. Defaults to ``after``.
downsample_scale (int or Tuple[int]): downsample scale for maxpooling.
It will be used for kernel size and stride of maxpooling.
Defaults to ``(1, 2, 2)``.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int]] = (3, 1, 1),
stride: Union[int, Tuple[int]] = (1, 1, 1),
padding: Union[int, Tuple[int]] = (1, 0, 0),
groups: int = 1,
bias: Union[bool, str] = False,
conv_cfg: ConfigType = dict(type='Conv3d'),
norm_cfg: OptConfigType = None,
act_cfg: OptConfigType = None,
downsample_position: str = 'after',
downsample_scale: Union[int, Tuple[int]] = (1, 2, 2)
) -> None:
super().__init__()
self.conv = ConvModule(
in_channels,
out_channels,
kernel_size,
stride,
padding,
groups=groups,
bias=bias,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
assert downsample_position in ['before', 'after']
self.downsample_position = downsample_position
self.pool = nn.MaxPool3d(
downsample_scale, downsample_scale, (0, 0, 0), ceil_mode=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
if self.downsample_position == 'before':
x = self.pool(x)
x = self.conv(x)
else:
x = self.conv(x)
x = self.pool(x)
return x
class LevelFusion(nn.Module):
"""Level Fusion module.
This module is used to aggregate the hierarchical features dynamic in
visual tempos and consistent in spatial semantics. The top/bottom features
for top-down/bottom-up flow would be combined to achieve two additional
options, namely 'Cascade Flow' or 'Parallel Flow'. While applying a
bottom-up flow after a top-down flow will lead to the cascade flow,
applying them simultaneously will result in the parallel flow.
Args:
in_channels (Tuple[int]): Channel numbers of input features tuple.
mid_channels (Tuple[int]): Channel numbers of middle features tuple.
out_channels (int): Channel numbers of output features.
downsample_scales (Tuple[int | Tuple[int]]): downsample scales for
each :class:`DownSample` module.
Defaults to ``((1, 1, 1), (1, 1, 1))``.
"""
def __init__(
self,
in_channels: Tuple[int],
mid_channels: Tuple[int],
out_channels: int,
downsample_scales: Tuple[int, Tuple[int]] = ((1, 1, 1), (1, 1, 1))
) -> None:
super().__init__()
num_stages = len(in_channels)
self.downsamples = nn.ModuleList()
for i in range(num_stages):
downsample = DownSample(
in_channels[i],
mid_channels[i],
kernel_size=(1, 1, 1),
stride=(1, 1, 1),
bias=False,
padding=(0, 0, 0),
groups=32,
norm_cfg=dict(type='BN3d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True),
downsample_position='before',
downsample_scale=downsample_scales[i])
self.downsamples.append(downsample)
self.fusion_conv = ConvModule(
sum(mid_channels),
out_channels,
1,
stride=1,
padding=0,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True))
def forward(self, x: Tuple[torch.Tensor]) -> torch.Tensor:
"""Defines the computation performed at every call."""
out = [self.downsamples[i](feature) for i, feature in enumerate(x)]
out = torch.cat(out, 1)
out = self.fusion_conv(out)
return out
class SpatialModulation(nn.Module):
"""Spatial Semantic Modulation.
This module is used to align spatial semantics of features in the
multi-depth pyramid. For each but the top-level feature, a stack
of convolutions with level-specific stride are applied to it, matching
its spatial shape and receptive field with the top one.
Args:
in_channels (Tuple[int]): Channel numbers of input features tuple.
out_channels (int): Channel numbers of output features tuple.
"""
def __init__(self, in_channels: Tuple[int], out_channels: int) -> None:
super().__init__()
self.spatial_modulation = nn.ModuleList()
for channel in in_channels:
downsample_scale = out_channels // channel
downsample_factor = int(np.log2(downsample_scale))
op = nn.ModuleList()
if downsample_factor < 1:
op = nn.Identity()
else:
for factor in range(downsample_factor):
in_factor = 2**factor
out_factor = 2**(factor + 1)
op.append(
ConvModule(
channel * in_factor,
channel * out_factor, (1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', requires_grad=True),
act_cfg=dict(type='ReLU', inplace=True)))
self.spatial_modulation.append(op)
def forward(self, x: Tuple[torch.Tensor]) -> list:
"""Defines the computation performed at every call."""
out = []
for i, _ in enumerate(x):
if isinstance(self.spatial_modulation[i], nn.ModuleList):
out_ = x[i]
for op in self.spatial_modulation[i]:
out_ = op(out_)
out.append(out_)
else:
out.append(self.spatial_modulation[i](x[i]))
return out
class AuxHead(nn.Module):
"""Auxiliary Head.
This auxiliary head is appended to receive stronger supervision,
leading to enhanced semantics.
Args:
in_channels (int): Channel number of input features.
out_channels (int): Channel number of output features.
loss_weight (float): weight of loss for the auxiliary head.
Defaults to 0.5.
loss_cls (dict or ConfigDict): Config for building loss.
Defaults to ``dict(type='CrossEntropyLoss')``.
"""
def __init__(
self,
in_channels: int,
out_channels: int,
loss_weight: float = 0.5,
loss_cls: ConfigType = dict(type='CrossEntropyLoss')
) -> None:
super().__init__()
self.conv = ConvModule(
in_channels,
in_channels * 2, (1, 3, 3),
stride=(1, 2, 2),
padding=(0, 1, 1),
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', requires_grad=True))
self.avg_pool = nn.AdaptiveAvgPool3d((1, 1, 1))
self.loss_weight = loss_weight
self.dropout = nn.Dropout(p=0.5)
self.fc = nn.Linear(in_channels * 2, out_channels)
self.loss_cls = MODELS.build(loss_cls)
def init_weights(self) -> None:
"""Initiate the parameters from scratch."""
for m in self.modules():
if isinstance(m, nn.Linear):
normal_init(m, std=0.01)
if isinstance(m, nn.Conv3d):
xavier_init(m, distribution='uniform')
if isinstance(m, nn.BatchNorm3d):
constant_init(m, 1)
def loss(self, x: torch.Tensor,
data_samples: Optional[SampleList]) -> dict:
"""Calculate auxiliary loss."""
x = self(x)
labels = [x.gt_label for x in data_samples]
labels = torch.stack(labels).to(x.device)
labels = labels.squeeze()
if labels.shape == torch.Size([]):
labels = labels.unsqueeze(0)
losses = dict()
losses['loss_aux'] = self.loss_weight * self.loss_cls(x, labels)
return losses
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Auxiliary head forward function."""
x = self.conv(x)
x = self.avg_pool(x).squeeze(-1).squeeze(-1).squeeze(-1)
x = self.dropout(x)
x = self.fc(x)
return x
class TemporalModulation(nn.Module):
"""Temporal Rate Modulation.
The module is used to equip TPN with a similar flexibility for temporal
tempo modulation as in the input-level frame pyramid.
Args:
in_channels (int): Channel number of input features.
out_channels (int): Channel number of output features.
downsample_scale (int): Downsample scale for maxpooling. Defaults to 8.
"""
def __init__(self,
in_channels: int,
out_channels: int,
downsample_scale: int = 8) -> None:
super().__init__()
self.conv = ConvModule(
in_channels,
out_channels, (3, 1, 1),
stride=(1, 1, 1),
padding=(1, 0, 0),
bias=False,
groups=32,
conv_cfg=dict(type='Conv3d'),
act_cfg=None)
self.pool = nn.MaxPool3d((downsample_scale, 1, 1),
(downsample_scale, 1, 1), (0, 0, 0),
ceil_mode=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Defines the computation performed at every call."""
x = self.conv(x)
x = self.pool(x)
return x
[docs]@MODELS.register_module()
class TPN(nn.Module):
"""TPN neck.
This module is proposed in `Temporal Pyramid Network for Action Recognition
<https://arxiv.org/pdf/2004.03548.pdf>`_
Args:
in_channels (Tuple[int]): Channel numbers of input features tuple.
out_channels (int): Channel number of output feature.
spatial_modulation_cfg (dict or ConfigDict, optional): Config for
spatial modulation layers. Required keys are ``in_channels`` and
``out_channels``. Defaults to None.
temporal_modulation_cfg (dict or ConfigDict, optional): Config for
temporal modulation layers. Defaults to None.
upsample_cfg (dict or ConfigDict, optional): Config for upsample
layers. The keys are same as that in :class:``nn.Upsample``.
Defaults to None.
downsample_cfg (dict or ConfigDict, optional): Config for downsample
layers. Defaults to None.
level_fusion_cfg (dict or ConfigDict, optional): Config for level
fusion layers.
Required keys are ``in_channels``, ``mid_channels``,
``out_channels``. Defaults to None.
aux_head_cfg (dict or ConfigDict, optional): Config for aux head
layers. Required keys are ``out_channels``. Defaults to None.
flow_type (str): Flow type to combine the features. Options are
``cascade`` and ``parallel``. Defaults to ``cascade``.
"""
def __init__(self,
in_channels: Tuple[int],
out_channels: int,
spatial_modulation_cfg: OptConfigType = None,
temporal_modulation_cfg: OptConfigType = None,
upsample_cfg: OptConfigType = None,
downsample_cfg: OptConfigType = None,
level_fusion_cfg: OptConfigType = None,
aux_head_cfg: OptConfigType = None,
flow_type: str = 'cascade') -> None:
super().__init__()
assert isinstance(in_channels, tuple)
assert isinstance(out_channels, int)
self.in_channels = in_channels
self.out_channels = out_channels
self.num_tpn_stages = len(in_channels)
assert spatial_modulation_cfg is None or isinstance(
spatial_modulation_cfg, dict)
assert temporal_modulation_cfg is None or isinstance(
temporal_modulation_cfg, dict)
assert upsample_cfg is None or isinstance(upsample_cfg, dict)
assert downsample_cfg is None or isinstance(downsample_cfg, dict)
assert aux_head_cfg is None or isinstance(aux_head_cfg, dict)
assert level_fusion_cfg is None or isinstance(level_fusion_cfg, dict)
if flow_type not in ['cascade', 'parallel']:
raise ValueError(
f"flow type in TPN should be 'cascade' or 'parallel', "
f'but got {flow_type} instead.')
self.flow_type = flow_type
self.temporal_modulation_ops = nn.ModuleList()
self.upsample_ops = nn.ModuleList()
self.downsample_ops = nn.ModuleList()
self.level_fusion_1 = LevelFusion(**level_fusion_cfg)
self.spatial_modulation = SpatialModulation(**spatial_modulation_cfg)
for i in range(self.num_tpn_stages):
if temporal_modulation_cfg is not None:
downsample_scale = temporal_modulation_cfg[
'downsample_scales'][i]
temporal_modulation = TemporalModulation(
in_channels[-1], out_channels, downsample_scale)
self.temporal_modulation_ops.append(temporal_modulation)
if i < self.num_tpn_stages - 1:
if upsample_cfg is not None:
upsample = nn.Upsample(**upsample_cfg)
self.upsample_ops.append(upsample)
if downsample_cfg is not None:
downsample = DownSample(out_channels, out_channels,
**downsample_cfg)
self.downsample_ops.append(downsample)
out_dims = level_fusion_cfg['out_channels']
# two pyramids
self.level_fusion_2 = LevelFusion(**level_fusion_cfg)
self.pyramid_fusion = ConvModule(
out_dims * 2,
2048,
1,
stride=1,
padding=0,
bias=False,
conv_cfg=dict(type='Conv3d'),
norm_cfg=dict(type='BN3d', requires_grad=True))
if aux_head_cfg is not None:
self.aux_head = AuxHead(self.in_channels[-2], **aux_head_cfg)
else:
self.aux_head = None
[docs] def init_weights(self) -> None:
"""Default init_weights for conv(msra) and norm in ConvModule."""
for m in self.modules():
if isinstance(m, nn.Conv3d):
xavier_init(m, distribution='uniform')
if isinstance(m, nn.BatchNorm3d):
constant_init(m, 1)
if self.aux_head is not None:
self.aux_head.init_weights()
[docs] def forward(self,
x: Tuple[torch.Tensor],
data_samples: Optional[SampleList] = None) -> tuple:
"""Defines the computation performed at every call."""
loss_aux = dict()
# Calculate auxiliary loss if `self.aux_head`
# and `data_samples` are not None.
if self.aux_head is not None and data_samples is not None:
loss_aux = self.aux_head.loss(x[-2], data_samples)
# Spatial Modulation
spatial_modulation_outs = self.spatial_modulation(x)
# Temporal Modulation
temporal_modulation_outs = []
for i, temporal_modulation in enumerate(self.temporal_modulation_ops):
temporal_modulation_outs.append(
temporal_modulation(spatial_modulation_outs[i]))
outs = [out.clone() for out in temporal_modulation_outs]
if len(self.upsample_ops) != 0:
for i in range(self.num_tpn_stages - 1, 0, -1):
outs[i - 1] = outs[i - 1] + self.upsample_ops[i - 1](outs[i])
# Get top-down outs
top_down_outs = self.level_fusion_1(outs)
# Build bottom-up flow using downsample operation
if self.flow_type == 'parallel':
outs = [out.clone() for out in temporal_modulation_outs]
if len(self.downsample_ops) != 0:
for i in range(self.num_tpn_stages - 1):
outs[i + 1] = outs[i + 1] + self.downsample_ops[i](outs[i])
# Get bottom-up outs
botton_up_outs = self.level_fusion_2(outs)
# fuse two pyramid outs
outs = self.pyramid_fusion(
torch.cat([top_down_outs, botton_up_outs], 1))
return outs, loss_aux