Shortcuts

Source code for mmaction.models.backbones.resnet2plus1d

# Copyright (c) OpenMMLab. All rights reserved.
from mmaction.registry import MODELS
from mmaction.utils import get_str_type
from .resnet3d import ResNet3d


@MODELS.register_module()
class ResNet2Plus1d(ResNet3d):
    """ResNet (2+1)d backbone.

    A :class:`ResNet3d` variant proposed in `A Closer Look at Spatiotemporal
    Convolutions for Action Recognition <https://arxiv.org/abs/1711.11248>`_.

    Every constructor argument is forwarded unchanged to ``ResNet3d``. After
    the parent is built, this class additionally requires ``pretrained2d`` to
    be ``False`` and ``conv_cfg['type']`` to resolve to ``'Conv2plus1d'``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # ``pretrained2d`` and ``conv_cfg`` are not assigned in this class;
        # presumably they are set by ``ResNet3d.__init__``, which is why the
        # checks run only after the parent constructor.
        assert self.pretrained2d is False
        assert get_str_type(self.conv_cfg['type']) == 'Conv2plus1d'

    def _freeze_stages(self):
        """Freeze the stem and the first ``self.frozen_stages`` residual layers.

        When ``frozen_stages >= 0`` the stem convolution ``conv1`` is frozen,
        followed by ``layer1`` ... ``layer{frozen_stages}`` in ascending order.
        A negative ``frozen_stages`` freezes nothing. Freezing a module means
        calling its ``eval()`` and setting ``requires_grad = False`` on every
        parameter it yields from ``parameters()``.
        """

        def _freeze(module):
            # Put the module in eval mode and stop gradients for its weights.
            module.eval()
            for weight in module.parameters():
                weight.requires_grad = False

        if self.frozen_stages >= 0:
            _freeze(self.conv1)

        # Residual stages are numbered from 1, matching the ``layer{i}`` names.
        for stage_idx in range(1, self.frozen_stages + 1):
            _freeze(getattr(self, f'layer{stage_idx}'))

    def forward(self, x):
        """Run the backbone on ``x``.

        The input passes through the stem (``conv1`` followed by ``maxpool``)
        and then through every residual layer listed in ``self.res_layers``,
        in that order.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The feature produced after the last residual layer,
            i.e. the backbone's representation of the input samples.
        """
        out = self.maxpool(self.conv1(x))
        # R(2+1)d applies no pool2: residual layers are chained directly.
        for layer_name in self.res_layers:
            out = getattr(self, layer_name)(out)
        return out