# Source code for mmaction.models.backbones.x3d

# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, Swish, build_activation_layer
from mmengine.logging import MMLogger
from mmengine.model.weight_init import constant_init, kaiming_init
from mmengine.runner import load_checkpoint
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm

from mmaction.registry import MODELS


class SEModule(nn.Module):
    """Squeeze-and-excitation gate built from 3D layers.

    The input is globally average-pooled to a single position, passed
    through a 1x1x1 conv -> ReLU -> 1x1x1 conv -> sigmoid stack, and the
    result is multiplied back onto the input.

    Args:
        channels (int): Number of channels of the input (and output).
        reduction (float): Multiplier applied to ``channels`` to obtain the
            hidden width, which is then rounded by :meth:`_round_width`.
    """

    def __init__(self, channels, reduction):
        super().__init__()
        hidden = self._round_width(channels, reduction)

        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        self.bottleneck = hidden
        self.fc1 = nn.Conv3d(channels, hidden, kernel_size=1, padding=0)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv3d(hidden, channels, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _round_width(width, multiplier, min_width=8, divisor=8):
        """Scale ``width`` by ``multiplier`` and snap it to ``divisor``.

        The scaled value is rounded to the nearest multiple of ``divisor``
        and clamped from below by ``min_width`` (or by ``divisor`` when
        ``min_width`` is falsy). If the result is smaller than 90% of the
        scaled value, one extra ``divisor`` is added.

        Returns:
            int: The rounded width.
        """
        scaled = width * multiplier
        lower_bound = min_width if min_width else divisor
        snapped = int(scaled + divisor / 2) // divisor * divisor
        rounded = max(lower_bound, snapped)
        if rounded < 0.9 * scaled:
            rounded += divisor
        return int(rounded)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input data.

        Returns:
            Tensor: The input multiplied by the computed gate.
        """
        squeezed = self.avg_pool(x)
        excited = self.fc2(self.relu(self.fc1(squeezed)))
        gate = self.sigmoid(excited)
        return x * gate


class BlockX3D(nn.Module):
    """BlockX3D 3d building block for X3D.

    The block is ``conv1`` (1x1x1 conv + norm + act) -> ``conv2`` (3x3x3
    channel-wise conv + norm) -> optional SE module -> inner activation ->
    ``conv3`` (1x1x1 conv + norm), followed by a residual addition and a
    final activation built from ``act_cfg``.

    Args:
        inplanes (int): Number of channels for the input in first conv3d layer.
        planes (int): Number of channels produced by some norm/conv3d layers.
        outplanes (int): Number of channels produced by final the conv3d layer.
        spatial_stride (int): Spatial stride in the conv3d layer. Default: 1.
        downsample (nn.Module | None): Downsample layer. Default: None.
        se_ratio (float | None): The reduction ratio of squeeze and excitation
            unit. If set as None, it means not using SE unit. Default: None.
        use_swish (bool): Whether to use swish as the inner activation that
            follows the 3x3x3 conv (and the SE unit, when present). If False,
            the activation described by ``act_cfg`` is used there instead.
            Default: True.
        conv_cfg (dict): Config dict for convolution layer.
            Default: ``dict(type='Conv3d')``.
        norm_cfg (dict): Config for norm layers. required keys are ``type``,
            Default: ``dict(type='BN3d')``.
        act_cfg (dict): Config dict for activation layer.
            Default: ``dict(type='ReLU')``.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 outplanes,
                 spatial_stride=1,
                 downsample=None,
                 se_ratio=None,
                 use_swish=True,
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()

        self.inplanes = inplanes
        self.planes = planes
        self.outplanes = outplanes
        self.spatial_stride = spatial_stride
        self.downsample = downsample
        self.se_ratio = se_ratio
        self.use_swish = use_swish
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.act_cfg_swish = dict(type='Swish')
        self.with_cp = with_cp

        # Point-wise conv that maps the input to the inner width.
        self.conv1 = ConvModule(
            in_channels=inplanes,
            out_channels=planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # Here we use the channel-wise conv (groups == channels). Only the
        # two spatial dimensions are strided; the first stride entry is 1.
        self.conv2 = ConvModule(
            in_channels=planes,
            out_channels=planes,
            kernel_size=3,
            stride=(1, self.spatial_stride, self.spatial_stride),
            padding=1,
            groups=planes,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        # Inner activation applied after conv2 / SE. The attribute keeps the
        # name ``swish`` for backward compatibility; it honours ``use_swish``
        # and falls back to the configured activation when swish is disabled.
        if self.use_swish:
            self.swish = Swish()
        else:
            self.swish = build_activation_layer(self.act_cfg)

        # Point-wise conv that maps back to the block output width; no
        # activation here because it is applied after the residual addition.
        self.conv3 = ConvModule(
            in_channels=planes,
            out_channels=outplanes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        if self.se_ratio is not None:
            self.se_module = SEModule(planes, self.se_ratio)

        self.relu = build_activation_layer(self.act_cfg)

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input data.

        Returns:
            Tensor: The activated sum of the residual branch and the
            (optionally downsampled) identity branch.
        """

        def _inner_forward(x):
            """Forward wrapper for utilizing checkpoint."""
            identity = x

            out = self.conv1(x)
            out = self.conv2(out)
            if self.se_ratio is not None:
                out = self.se_module(out)

            out = self.swish(out)

            out = self.conv3(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out = out + identity
            return out

        # Checkpointing is only used when gradients flow through the input.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)
        out = self.relu(out)
        return out


# We do not support initializing X3D from 2D pretrained weights.
@MODELS.register_module()
class X3D(nn.Module):
    """X3D backbone. https://arxiv.org/pdf/2004.04730.pdf.

    Args:
        gamma_w (float): Global channel width expansion factor. Default: 1.
        gamma_b (float): Bottleneck channel width expansion factor. Default: 1.
        gamma_d (float): Network depth expansion factor. Default: 1.
        pretrained (str | None): Name of pretrained model. Default: None.
        in_channels (int): Channel num of input features. Default: 3.
        num_stages (int): Resnet stages. Default: 4.
        spatial_strides (Sequence[int]):
            Spatial strides of residual blocks of each stage.
            Default: ``(2, 2, 2, 2)``.
        frozen_stages (int): Stages to be frozen (all param fixed). If set to
            -1, it means not freezing any parameters. Default: -1.
        se_style (str): The style of inserting SE modules into BlockX3D, 'half'
            denotes insert into half of the blocks, while 'all' denotes insert
            into all blocks. Default: 'half'.
        se_ratio (float | None): The reduction ratio of squeeze and excitation
            unit. If set as None, it means not using SE unit. Default: 1 / 16.
        use_swish (bool): Whether to use swish as the activation function
            before and after the 3x3x3 conv. Default: True.
        conv_cfg (dict): Config for conv layers. required keys are ``type``
            Default: ``dict(type='Conv3d')``.
        norm_cfg (dict): Config for norm layers. required keys are ``type`` and
            ``requires_grad``.
            Default: ``dict(type='BN3d', requires_grad=True)``.
        act_cfg (dict): Config dict for activation layer.
            Default: ``dict(type='ReLU', inplace=True)``.
        norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
            running stats (mean and var). Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool):
            Whether to use zero initialization for residual block,
            Default: True.
        kwargs (dict, optional): Key arguments for "make_res_layer".
    """

    def __init__(self,
                 gamma_w=1.0,
                 gamma_b=1.0,
                 gamma_d=1.0,
                 pretrained=None,
                 in_channels=3,
                 num_stages=4,
                 spatial_strides=(2, 2, 2, 2),
                 frozen_stages=-1,
                 se_style='half',
                 se_ratio=1 / 16,
                 use_swish=True,
                 conv_cfg=dict(type='Conv3d'),
                 norm_cfg=dict(type='BN3d', requires_grad=True),
                 act_cfg=dict(type='ReLU', inplace=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True,
                 **kwargs):
        super().__init__()
        self.gamma_w = gamma_w
        self.gamma_b = gamma_b
        self.gamma_d = gamma_d

        self.pretrained = pretrained
        self.in_channels = in_channels
        # Hard coded, can be changed by gamma_w
        self.base_channels = 24
        self.stage_blocks = [1, 2, 5, 3]

        # apply parameters gamma_w and gamma_d
        self.base_channels = self._round_width(self.base_channels,
                                               self.gamma_w)

        self.stage_blocks = [
            self._round_repeats(x, self.gamma_d) for x in self.stage_blocks
        ]

        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.spatial_strides = spatial_strides
        assert len(spatial_strides) == num_stages
        self.frozen_stages = frozen_stages

        self.se_style = se_style
        assert self.se_style in ['all', 'half']
        self.se_ratio = se_ratio
        assert (self.se_ratio is None) or (self.se_ratio > 0)
        self.use_swish = use_swish

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp
        self.zero_init_residual = zero_init_residual

        self.block = BlockX3D
        self.stage_blocks = self.stage_blocks[:num_stages]
        # Channel count of the feature entering the next stage; starts at
        # the stem width and is updated after every stage below.
        self.layer_inplanes = self.base_channels
        self._make_stem_layer()

        self.res_layers = []
        for i, num_blocks in enumerate(self.stage_blocks):
            spatial_stride = spatial_strides[i]
            # Stage output width doubles per stage; the inner block width
            # is that value scaled by gamma_b.
            inplanes = self.base_channels * 2**i
            planes = int(inplanes * self.gamma_b)

            res_layer = self.make_res_layer(
                self.block,
                self.layer_inplanes,
                inplanes,
                planes,
                num_blocks,
                spatial_stride=spatial_stride,
                se_style=self.se_style,
                se_ratio=self.se_ratio,
                use_swish=self.use_swish,
                norm_cfg=self.norm_cfg,
                conv_cfg=self.conv_cfg,
                act_cfg=self.act_cfg,
                with_cp=with_cp,
                **kwargs)
            self.layer_inplanes = inplanes
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self.feat_dim = self.base_channels * 2**(len(self.stage_blocks) - 1)
        self.conv5 = ConvModule(
            self.feat_dim,
            int(self.feat_dim * self.gamma_b),
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        # After conv5 the exposed feature dimension is the expanded width.
        self.feat_dim = int(self.feat_dim * self.gamma_b)

    @staticmethod
    def _round_width(width, multiplier, min_depth=8, divisor=8):
        """Round width of filters based on width multiplier.

        Returns ``width`` unchanged when ``multiplier`` is falsy. Otherwise
        the scaled width is rounded to the nearest multiple of ``divisor``,
        clamped from below by ``min_depth`` (or ``divisor`` when
        ``min_depth`` is falsy), and bumped by one ``divisor`` if the result
        is below 90% of the scaled width.
        """
        if not multiplier:
            return width

        width *= multiplier
        min_depth = min_depth or divisor
        new_filters = max(min_depth,
                          int(width + divisor / 2) // divisor * divisor)
        if new_filters < 0.9 * width:
            new_filters += divisor
        return int(new_filters)

    @staticmethod
    def _round_repeats(repeats, multiplier):
        """Round number of layers based on depth multiplier.

        Returns ``repeats`` unchanged when ``multiplier`` is falsy, otherwise
        the ceiling of ``multiplier * repeats`` as an int.
        """
        if not multiplier:
            return repeats
        return int(math.ceil(multiplier * repeats))

    # the module is parameterized with gamma_b
    # no temporal_stride
    def make_res_layer(self,
                       block,
                       layer_inplanes,
                       inplanes,
                       planes,
                       blocks,
                       spatial_stride=1,
                       se_style='half',
                       se_ratio=None,
                       use_swish=True,
                       norm_cfg=None,
                       act_cfg=None,
                       conv_cfg=None,
                       with_cp=False,
                       **kwargs):
        """Build residual layer for ResNet3D.

        Args:
            block (nn.Module): Residual module to be built.
            layer_inplanes (int): Number of channels for the input feature
                of the res layer.
            inplanes (int): Number of channels produced by every block of
                the layer, and consumed by every block after the first.
            planes (int): Inner channel width passed to each block.
            blocks (int): Number of residual blocks.
            spatial_stride (int): Spatial strides in residual and conv layers.
                Default: 1.
            se_style (str): The style of inserting SE modules into BlockX3D,
                'half' denotes insert into half of the blocks, while 'all'
                denotes insert into all blocks. Default: 'half'.
            se_ratio (float | None): The reduction ratio of squeeze and
                excitation unit. If set as None, it means not using SE unit.
                Default: None.
            use_swish (bool): Whether to use swish as the activation function
                before and after the 3x3x3 conv. Default: True.
            conv_cfg (dict | None): Config for conv layers. Default: None.
            norm_cfg (dict | None): Config for norm layers. Default: None.
            act_cfg (dict | None): Config for activate layers. Default: None.
            with_cp (bool | None): Use checkpoint or not. Using checkpoint
                will save some memory while slowing down the training speed.
                Default: False.

        Returns:
            nn.Module: A residual layer for the given config.

        Raises:
            NotImplementedError: If ``se_style`` is neither 'all' nor 'half'.
        """
        # A projection shortcut is needed whenever the first block changes
        # the spatial resolution or the channel count.
        downsample = None
        if spatial_stride != 1 or layer_inplanes != inplanes:
            downsample = ConvModule(
                layer_inplanes,
                inplanes,
                kernel_size=1,
                stride=(1, spatial_stride, spatial_stride),
                padding=0,
                bias=False,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None)

        # Use the ``se_style`` argument rather than the instance attribute so
        # that the parameter is actually honoured by callers.
        if se_style == 'all':
            use_se = [True] * blocks
        elif se_style == 'half':
            # SE goes on the even-indexed blocks (0, 2, 4, ...).
            use_se = [i % 2 == 0 for i in range(blocks)]
        else:
            raise NotImplementedError

        layers = []
        # Only the first block carries the stride, the channel change and
        # the downsample shortcut.
        layers.append(
            block(
                layer_inplanes,
                planes,
                inplanes,
                spatial_stride=spatial_stride,
                downsample=downsample,
                se_ratio=se_ratio if use_se[0] else None,
                use_swish=use_swish,
                norm_cfg=norm_cfg,
                conv_cfg=conv_cfg,
                act_cfg=act_cfg,
                with_cp=with_cp,
                **kwargs))

        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    inplanes,
                    spatial_stride=1,
                    se_ratio=se_ratio if use_se[i] else None,
                    use_swish=use_swish,
                    norm_cfg=norm_cfg,
                    conv_cfg=conv_cfg,
                    act_cfg=act_cfg,
                    with_cp=with_cp,
                    **kwargs))

        return nn.Sequential(*layers)

    def _make_stem_layer(self):
        """Construct the stem layers consists of a conv+norm+act module and a
        pooling layer."""
        # Spatial (1x3x3) conv with stride 2 in both spatial dims; no norm or
        # activation is attached to it.
        self.conv1_s = ConvModule(
            self.in_channels,
            self.base_channels,
            kernel_size=(1, 3, 3),
            stride=(1, 2, 2),
            padding=(0, 1, 1),
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=None,
            act_cfg=None)
        # Channel-wise temporal (5x1x1) conv followed by norm and activation.
        self.conv1_t = ConvModule(
            self.base_channels,
            self.base_channels,
            kernel_size=(5, 1, 1),
            stride=(1, 1, 1),
            padding=(2, 0, 0),
            groups=self.base_channels,
            bias=False,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

    def _freeze_stages(self):
        """Prevent all the parameters from being optimized before
        ``self.frozen_stages``."""
        # frozen_stages >= 0 freezes the stem; each further unit freezes one
        # more residual layer (layer1 .. layer{frozen_stages}).
        if self.frozen_stages >= 0:
            self.conv1_s.eval()
            self.conv1_t.eval()
            for param in self.conv1_s.parameters():
                param.requires_grad = False
            for param in self.conv1_t.parameters():
                param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Raises:
            TypeError: If ``self.pretrained`` is neither a str nor None.
        """
        if isinstance(self.pretrained, str):
            logger = MMLogger.get_current_instance()
            logger.info(f'load model from: {self.pretrained}')

            load_checkpoint(self, self.pretrained, strict=False, logger=logger)

        elif self.pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    kaiming_init(m)
                elif isinstance(m, _BatchNorm):
                    constant_init(m, 1)

            # Zero the last norm layer of every block so each residual
            # branch initially contributes nothing to the sum.
            if self.zero_init_residual:
                for m in self.modules():
                    if isinstance(m, BlockX3D):
                        constant_init(m.conv3.bn, 0)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The feature of the input samples extracted by
            the backbone.
        """
        x = self.conv1_s(x)
        x = self.conv1_t(x)
        for layer_name in self.res_layers:
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
        x = self.conv5(x)
        return x

    def train(self, mode=True):
        """Set the optimization status when training."""
        super().train(mode)
        # Re-apply freezing because ``super().train`` resets module modes.
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()