Shortcuts

Source code for mmaction.models.backbones.resnet_omni

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule
from mmengine.runner import CheckpointLoader

from mmaction.registry import MODELS
from mmaction.utils import OptConfigType


def batch_norm(inputs: torch.Tensor,
               module: nn.modules.batchnorm._BatchNorm,
               training: Optional[bool] = None) -> torch.Tensor:
    """Applies Batch Normalization for each channel across a batch of data
    using params from the given batch normalization module.

    Unlike calling the module itself, this never updates the module's running
    statistics: in training mode the batch statistics are used and the running
    buffers are not passed to ``F.batch_norm``.

    Args:
        inputs (Tensor): The input data.
        module (nn.modules.batchnorm._BatchNorm): a batch normalization
            module. Will use params (``weight``, ``bias``, ``eps`` and, in
            eval mode, the running statistics) from this module.
        training (bool, optional): if true, apply the train mode batch
            normalization. Defaults to None and will use the training mode of
            the module.

    Returns:
        Tensor: The normalized tensor, same shape as ``inputs``.
    """
    if training is None:
        training = module.training
    running_mean = module.running_mean
    running_var = module.running_var
    # Mirror ``_BatchNorm.forward``: a module built with
    # ``track_running_stats=False`` has no running buffers, so batch
    # statistics have to be used even in eval mode.
    use_batch_stats = training or (running_mean is None
                                   and running_var is None)
    # ``momentum=None`` is legal on the module (cumulative moving average) but
    # ``F.batch_norm`` only accepts a float. Running stats are never updated
    # here, so the value does not affect the result.
    momentum = 0.0 if module.momentum is None else module.momentum
    return F.batch_norm(
        input=inputs,
        running_mean=None if use_batch_stats else running_mean,
        running_var=None if use_batch_stats else running_var,
        weight=module.weight,
        bias=module.bias,
        training=use_batch_stats,
        momentum=momentum,
        eps=module.eps)


class BottleNeck(BaseModule):
    """Building block for Omni-ResNet.

    The block is made of 3D convolutions but also runs on 2D image batches:
    :meth:`forward_2d` turns each 3D kernel into a 2D one (summing the
    temporal taps of ``conv1``, squeezing the singleton temporal axis of the
    other convs) so images and videos share the same weights.

    Args:
        inplanes (int): Number of input channels of the block (input of the
            first conv layer).
        planes (int): Number of bottleneck channels (output of the first conv
            layer, input and output of the second). The block outputs
            ``planes * 4`` channels.
        temporal_kernel (int): Temporal kernel size of the first conv layer.
            Should be either 1 or 3. Defaults to 3.
        spatial_stride (int): Spatial stride of the second conv layer and of
            the downsample branch. Defaults to 1.
        init_cfg (dict or ConfigDict, optional): The Config for initialization.
            Defaults to None.
        **kwargs: Extra keyword arguments; accepted and ignored.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 temporal_kernel: int = 3,
                 spatial_stride: int = 1,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(BottleNeck, self).__init__(init_cfg=init_cfg)
        assert temporal_kernel in [1, 3]

        # Temporal-only conv; 'same' padding along time keeps T unchanged.
        self.conv1 = nn.Conv3d(
            inplanes,
            planes,
            kernel_size=(temporal_kernel, 1, 1),
            padding=(temporal_kernel // 2, 0, 0),
            bias=False)
        # Spatial-only conv carrying the (optional) spatial downsampling.
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            stride=(1, spatial_stride, spatial_stride),
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
            bias=False)

        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm3d(planes, momentum=0.01)
        self.bn2 = nn.BatchNorm3d(planes, momentum=0.01)
        self.bn3 = nn.BatchNorm3d(planes * 4, momentum=0.01)

        # Projection shortcut only when the identity cannot be added directly
        # (channel count or spatial size changes). Without it the attribute
        # is absent, which is what the ``hasattr`` checks below rely on.
        if inplanes != planes * 4 or spatial_stride != 1:
            downsample = [
                nn.Conv3d(
                    inplanes,
                    planes * 4,
                    kernel_size=1,
                    stride=(1, spatial_stride, spatial_stride),
                    bias=False),
                nn.BatchNorm3d(planes * 4, momentum=0.01)
            ]
            self.downsample = nn.Sequential(*downsample)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Accept both 3D (BCTHW for videos) and 2D (BCHW for images) tensors.

        Args:
            x (Tensor): A 5D video tensor or a 4D image tensor.

        Returns:
            Tensor: Output with ``planes * 4`` channels, same rank as ``x``.
        """
        if x.ndim == 4:
            return self.forward_2d(x)

        # Forward call for 3D tensors.
        out = self.conv1(x)
        out = self.bn1(out).relu_()

        out = self.conv2(out)
        out = self.bn2(out).relu_()

        out = self.conv3(out)
        out = self.bn3(out)

        if hasattr(self, 'downsample'):
            x = self.downsample(x)

        return out.add_(x).relu_()

    def forward_2d(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call for 2D tensors, reusing the 3D weights.

        Args:
            x (Tensor): A 4D image tensor (BCHW).

        Returns:
            Tensor: Output with ``planes * 4`` channels.
        """
        # Summing the temporal taps treats the image as a temporally constant
        # clip; conv1 has no spatial extent so no padding/stride is needed.
        out = F.conv2d(x, self.conv1.weight.sum(2))
        out = batch_norm(out, self.bn1).relu_()

        out = F.conv2d(
            out,
            self.conv2.weight.squeeze(2),
            stride=self.conv2.stride[-1],
            padding=self.conv2.padding[-1])
        out = batch_norm(out, self.bn2).relu_()

        out = F.conv2d(out, self.conv3.weight.squeeze(2))
        out = batch_norm(out, self.bn3)

        if hasattr(self, 'downsample'):
            x = F.conv2d(
                x,
                self.downsample[0].weight.squeeze(2),
                stride=self.downsample[0].stride[-1])
            x = batch_norm(x, self.downsample[1])

        return out.add_(x).relu_()


@MODELS.register_module()
class OmniResNet(BaseModel):
    """Omni-ResNet that accepts both image and video inputs.

    Videos (BCTHW) run through the 3D modules directly; images (BCHW) reuse
    the same weights through 2D convolutions (see :meth:`forward_2d`).

    Args:
        layers (Sequence[int]): Number of bottleneck blocks in each of the
            four residual stages. Defaults to ``(3, 4, 6, 3)``.
        pretrain_2d (str, optional): Path to a 2D pretraining checkpoint
            whose weights are inflated to 3D at construction time.
            Defaults to None.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 layers: Sequence[int] = (3, 4, 6, 3),
                 pretrain_2d: Optional[str] = None,
                 init_cfg: OptConfigType = None) -> None:
        super(OmniResNet, self).__init__(init_cfg=init_cfg)

        self.inplanes = 64
        self.conv1 = nn.Conv3d(
            3,
            self.inplanes,
            kernel_size=(1, 7, 7),
            stride=(1, 2, 2),
            padding=(0, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(self.inplanes, momentum=0.01)

        self.pool3d = nn.MaxPool3d((1, 3, 3), (1, 2, 2), (0, 1, 1))
        self.pool2d = nn.MaxPool2d(3, 2, 1)

        # ``_make_layer`` reads ``self.temporal_kernel``: the first two stages
        # are spatial-only, the last two also convolve over time.
        self.temporal_kernel = 1
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], stride=2)
        self.temporal_kernel = 3
        self.layer3 = self._make_layer(256, layers[2], stride=2)
        self.layer4 = self._make_layer(512, layers[3], stride=2)

        if pretrain_2d is not None:
            self.init_from_2d(pretrain_2d)

    def _make_layer(self,
                    planes: int,
                    num_blocks: int,
                    stride: int = 1) -> nn.Module:
        """Build one residual stage of ``num_blocks`` bottlenecks.

        Only the first block applies ``stride``; ``self.inplanes`` is updated
        to the stage's output width (``planes * 4``) for the next stage.
        """
        layers = [
            BottleNeck(
                self.inplanes,
                planes,
                spatial_stride=stride,
                temporal_kernel=self.temporal_kernel)
        ]
        self.inplanes = planes * 4
        for _ in range(1, num_blocks):
            layers.append(
                BottleNeck(
                    self.inplanes,
                    planes,
                    temporal_kernel=self.temporal_kernel))
        return nn.Sequential(*layers)

    def init_from_2d(self, pretrain: str) -> None:
        """Initialize the 3D weights from a 2D checkpoint by inflation.

        Every entry of this model's state dict whose key also appears in the
        checkpoint is copied. 4D (2D conv) weights are repeated ``t`` times
        along a new temporal axis and divided by ``t``, so the inflated
        kernel responds to a temporally constant input like the 2D kernel.

        Args:
            pretrain (str): Location of the 2D checkpoint, loaded with
                ``CheckpointLoader``. Either a plain state dict or a
                checkpoint dict holding it under ``'state_dict'``.
        """
        checkpoint = CheckpointLoader.load_checkpoint(
            pretrain, map_location='cpu')
        # Checkpoints saved by training loops wrap the weights in a
        # ``state_dict`` entry; without unwrapping, no key would match and
        # nothing would be loaded.
        param2d = checkpoint.get('state_dict', checkpoint)
        param3d = self.state_dict()
        for key, target in param3d.items():
            if key not in param2d:
                continue
            weight = param2d[key]
            if weight.ndim == 4:
                t = target.shape[2]
                weight = weight.unsqueeze(2).expand(-1, -1, t, -1, -1) / t
            param3d[key] = weight
        self.load_state_dict(param3d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Accept both 3D (BCTHW for videos) and 2D (BCHW for images) tensors.

        Args:
            x (Tensor): A 5D video tensor or a 4D image tensor.

        Returns:
            Tensor: The output feature map of the last stage.
        """
        if x.ndim == 4:
            return self.forward_2d(x)

        # Forward call for 3D tensors.
        x = self.conv1(x)
        x = self.bn1(x).relu_()
        x = self.pool3d(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_2d(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call for 2D tensors, reusing the 3D stem weights.

        Args:
            x (Tensor): A 4D image tensor (BCHW).

        Returns:
            Tensor: The output feature map of the last stage.
        """
        # conv1 has a temporal kernel of 1, so squeezing gives the 2D kernel.
        x = F.conv2d(
            x,
            self.conv1.weight.squeeze(2),
            stride=self.conv1.stride[-1],
            padding=self.conv1.padding[-1])
        x = batch_norm(x, self.bn1).relu_()
        x = self.pool2d(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x