# Source code for mmaction.models.backbones.resnet_omni
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule
from mmengine.runner import CheckpointLoader
from mmaction.registry import MODELS
from mmaction.utils import OptConfigType
def batch_norm(inputs: torch.Tensor,
               module: nn.modules.batchnorm._BatchNorm,
               training: Optional[bool] = None) -> torch.Tensor:
    """Applies Batch Normalization for each channel across a batch of data
    using params from the given batch normalization module.

    This allows reusing a 3D module's affine parameters and running
    statistics for a 2D (image) forward pass.

    Args:
        inputs (Tensor): The input data.
        module (nn.modules.batchnorm._BatchNorm): A batch normalization
            module. Will use params from this batch normalization module
            to do the operation.
        training (bool, optional): If True, apply the train mode batch
            normalization. Defaults to None and will use the training mode
            of the module.

    Returns:
        Tensor: The normalized tensor.
    """
    if training is None:
        training = module.training
    # In train mode the batch statistics are used and `None` is passed for
    # the running stats, so the module's own buffers are left untouched;
    # in eval mode the module's stored running statistics are used.
    return F.batch_norm(
        input=inputs,
        running_mean=None if training else module.running_mean,
        running_var=None if training else module.running_var,
        weight=module.weight,
        bias=module.bias,
        training=training,
        momentum=module.momentum,
        eps=module.eps)
class BottleNeck(BaseModule):
    """Building block for Omni-ResNet.

    The block runs a (t, 1, 1) temporal conv, a (1, 3, 3) spatial conv and
    a 1x1x1 expansion conv (expansion factor 4), each followed by batch
    norm, with a residual connection. A projection shortcut is created
    when the channel count or spatial stride changes.

    Args:
        inplanes (int): Number of channels for the input in first conv layer.
        planes (int): Number of channels for the input in second conv layer.
        temporal_kernel (int): Temporal kernel in the conv layer. Should be
            either 1 or 3. Defaults to 3.
        spatial_stride (int): Spatial stride in the conv layer. Defaults to 1.
        init_cfg (dict or ConfigDict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 inplanes: int,
                 planes: int,
                 temporal_kernel: int = 3,
                 spatial_stride: int = 1,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(BottleNeck, self).__init__(init_cfg=init_cfg)
        assert temporal_kernel in [1, 3]
        # Temporal-only conv; padding keeps the temporal length unchanged.
        self.conv1 = nn.Conv3d(
            inplanes,
            planes,
            kernel_size=(temporal_kernel, 1, 1),
            padding=(temporal_kernel // 2, 0, 0),
            bias=False)
        # Spatial-only conv; any spatial downsampling happens here.
        self.conv2 = nn.Conv3d(
            planes,
            planes,
            stride=(1, spatial_stride, spatial_stride),
            kernel_size=(1, 3, 3),
            padding=(0, 1, 1),
            bias=False)
        # 1x1x1 expansion conv (ResNet bottleneck expansion factor of 4).
        self.conv3 = nn.Conv3d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(planes, momentum=0.01)
        self.bn2 = nn.BatchNorm3d(planes, momentum=0.01)
        self.bn3 = nn.BatchNorm3d(planes * 4, momentum=0.01)
        # Projection shortcut, only created when the residual shape differs
        # from the input shape; `forward` tests for it with `hasattr`.
        if inplanes != planes * 4 or spatial_stride != 1:
            downsample = [
                nn.Conv3d(
                    inplanes,
                    planes * 4,
                    kernel_size=1,
                    stride=(1, spatial_stride, spatial_stride),
                    bias=False),
                nn.BatchNorm3d(planes * 4, momentum=0.01)
            ]
            self.downsample = nn.Sequential(*downsample)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Accept both 3D (BCTHW for videos) and 2D (BCHW for images) tensors.
        """
        if x.ndim == 4:
            return self.forward_2d(x)
        # Forward call for 3D tensors.
        out = self.conv1(x)
        out = self.bn1(out).relu_()
        out = self.conv2(out)
        out = self.bn2(out).relu_()
        out = self.conv3(out)
        out = self.bn3(out)
        if hasattr(self, 'downsample'):
            x = self.downsample(x)
        # In-place add + ReLU for the residual connection.
        return out.add_(x).relu_()

    def forward_2d(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call for 2D tensors.

        Reuses the 3D weights: summing (or squeezing) the temporal
        dimension of each kernel yields the equivalent 2D kernel.
        """
        # conv1 may have temporal extent 3, so sum over the temporal dim
        # to collapse it into a 2D kernel.
        out = F.conv2d(x, self.conv1.weight.sum(2))
        out = batch_norm(out, self.bn1).relu_()
        # conv2/conv3 have temporal extent 1; squeeze is exact.
        out = F.conv2d(
            out,
            self.conv2.weight.squeeze(2),
            stride=self.conv2.stride[-1],
            padding=1)
        out = batch_norm(out, self.bn2).relu_()
        out = F.conv2d(out, self.conv3.weight.squeeze(2))
        out = batch_norm(out, self.bn3)
        if hasattr(self, 'downsample'):
            x = F.conv2d(
                x,
                self.downsample[0].weight.squeeze(2),
                stride=self.downsample[0].stride[-1])
            x = batch_norm(x, self.downsample[1])
        return out.add_(x).relu_()
@MODELS.register_module()
class OmniResNet(BaseModel):
    """Omni-ResNet that accepts both image and video inputs.

    Args:
        layers (List[int], optional): number of layers in each residual
            stage. Defaults to None, which means [3, 4, 6, 3] (the
            ResNet-50 layout).
        pretrain_2d (str, optional): path to the 2D pretraining checkpoints.
            Defaults to None.
        init_cfg (dict or ConfigDict, optional): The Config for initialization.
            Defaults to None.
    """

    def __init__(self,
                 layers: Optional[List[int]] = None,
                 pretrain_2d: Optional[str] = None,
                 init_cfg: OptConfigType = None) -> None:
        super(OmniResNet, self).__init__(init_cfg=init_cfg)
        # Use a None sentinel instead of a mutable default argument.
        if layers is None:
            layers = [3, 4, 6, 3]
        self.inplanes = 64
        # Stem uses a (1, 7, 7) kernel so images and videos share weights.
        self.conv1 = nn.Conv3d(
            3,
            self.inplanes,
            kernel_size=(1, 7, 7),
            stride=(1, 2, 2),
            padding=(0, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(self.inplanes, momentum=0.01)
        self.pool3d = nn.MaxPool3d((1, 3, 3), (1, 2, 2), (0, 1, 1))
        self.pool2d = nn.MaxPool2d(3, 2, 1)
        # The first two stages are spatial-only (temporal kernel 1); the
        # last two stages use temporal kernel 3.
        self.temporal_kernel = 1
        self.layer1 = self._make_layer(64, layers[0])
        self.layer2 = self._make_layer(128, layers[1], stride=2)
        self.temporal_kernel = 3
        self.layer3 = self._make_layer(256, layers[2], stride=2)
        self.layer4 = self._make_layer(512, layers[3], stride=2)
        if pretrain_2d is not None:
            self.init_from_2d(pretrain_2d)

    def _make_layer(self,
                    planes: int,
                    num_blocks: int,
                    stride: int = 1) -> nn.Module:
        """Stack ``num_blocks`` BottleNecks; only the first block may
        change channels or apply the spatial stride."""
        layers = [
            BottleNeck(
                self.inplanes,
                planes,
                spatial_stride=stride,
                temporal_kernel=self.temporal_kernel)
        ]
        self.inplanes = planes * 4
        for _ in range(1, num_blocks):
            layers.append(
                BottleNeck(
                    self.inplanes,
                    planes,
                    temporal_kernel=self.temporal_kernel))
        return nn.Sequential(*layers)

    def init_from_2d(self, pretrain: str) -> None:
        """Inflate 2D pretrained weights into the 3D model.

        Each matching 2D conv kernel is replicated ``t`` times along the
        temporal dimension and divided by ``t``, so the inflated 3D conv
        reproduces the 2D response on temporally-constant input.

        Args:
            pretrain (str): path to the 2D pretraining checkpoint.
        """
        # NOTE(review): assumes load_checkpoint yields a flat state dict;
        # confirm behavior for checkpoints wrapped in a 'state_dict' key.
        param2d = CheckpointLoader.load_checkpoint(
            pretrain, map_location='cpu')
        param3d = self.state_dict()
        for key in param3d:
            if key in param2d:
                weight = param2d[key]
                if weight.ndim == 4:
                    # Conv weight: inflate (O, I, H, W) -> (O, I, t, H, W).
                    t = param3d[key].shape[2]
                    weight = weight.unsqueeze(2)
                    weight = weight.expand(-1, -1, t, -1, -1)
                    weight = weight / t
                param3d[key] = weight
        self.load_state_dict(param3d)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Accept both 3D (BCTHW for videos) and 2D (BCHW for images) tensors.
        """
        if x.ndim == 4:
            return self.forward_2d(x)
        # Forward call for 3D tensors.
        x = self.conv1(x)
        x = self.bn1(x).relu_()
        x = self.pool3d(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x

    def forward_2d(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call for 2D tensors.

        The stem kernel has temporal extent 1, so squeezing the temporal
        dimension gives the exact 2D equivalent.
        """
        x = F.conv2d(
            x,
            self.conv1.weight.squeeze(2),
            stride=self.conv1.stride[-1],
            padding=self.conv1.padding[-1])
        x = batch_norm(x, self.bn1).relu_()
        x = self.pool2d(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        return x