# Source code for mmaction.models.backbones.rgbposeconv3d
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Union
import torch
import torch.nn as nn
from mmengine.logging import MMLogger, print_log
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init, kaiming_init
from mmengine.runner.checkpoint import load_checkpoint
from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
from mmaction.registry import MODELS
from .resnet3d_slowfast import ResNet3dPathway
@MODELS.register_module()
class RGBPoseConv3D(BaseModule):
    """RGBPoseConv3D backbone.

    A two-stream 3D CNN made of an RGB pathway and a pose (heatmap) pathway,
    both built from :class:`ResNet3dPathway`. After intermediate stages the
    two pathways exchange features through lateral connections, and the
    lateral outputs are concatenated along dim 1 onto the other pathway's
    features.

    Args:
        pretrained (str, optional): The file path to a pretrained model.
            Defaults to None.
        speed_ratio (int): Speed ratio indicating the ratio between time
            dimension of the two pathways, corresponding to the
            :math:`\\alpha` in the paper. Only injected into pathway configs
            whose ``lateral`` flag is truthy. Defaults to 4.
        channel_ratio (int): Channel reduction ratio for the lateral
            connections, corresponding to :math:`\\beta` in the paper. Only
            injected into pathway configs whose ``lateral`` flag is truthy.
            Defaults to 4.
        rgb_detach (bool): Whether to detach the pose features before they
            enter the RGB pathway's lateral layers, so that gradients from
            the RGB branch do not flow back into the pose path.
            Defaults to False.
        pose_detach (bool): Whether to detach the RGB features before they
            enter the pose pathway's lateral layers, so that gradients from
            the pose branch do not flow back into the RGB path.
            Defaults to False.
        rgb_drop_path (float): In training mode, the probability of zeroing
            the pose-to-RGB lateral features for a whole forward pass (one
            random draw per call, shared by the batch). Must be in [0, 1].
            Defaults to 0.
        pose_drop_path (float): In training mode, the probability of zeroing
            the RGB-to-pose lateral features for a whole forward pass. Must
            be in [0, 1]. Defaults to 0.
        rgb_pathway (dict): Configuration of rgb branch. The dict is copied
            and never modified in place. Defaults to
            ``dict(num_stages=4, lateral=True, lateral_infl=1,
            lateral_activate=(0, 0, 1, 1), fusion_kernel=7, base_channels=64,
            conv1_kernel=(1, 7, 7), inflate=(0, 0, 1, 1), with_pool2=False)``.
        pose_pathway (dict): Configuration of pose branch. The dict is copied
            and never modified in place. Defaults to
            ``dict(num_stages=3, stage_blocks=(4, 6, 3), lateral=True,
            lateral_inv=True, lateral_infl=16, lateral_activate=(0, 1, 1),
            fusion_kernel=7, in_channels=17, base_channels=32,
            out_indices=(2, ), conv1_kernel=(1, 7, 7), conv1_stride_s=1,
            conv1_stride_t=1, pool1_stride_s=1, pool1_stride_t=1,
            inflate=(0, 1, 1), spatial_strides=(2, 2, 2),
            temporal_strides=(1, 1, 1), dilations=(1, 1, 1),
            with_pool2=False)``.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 pretrained: Optional[str] = None,
                 speed_ratio: int = 4,
                 channel_ratio: int = 4,
                 rgb_detach: bool = False,
                 pose_detach: bool = False,
                 rgb_drop_path: float = 0,
                 pose_drop_path: float = 0,
                 rgb_pathway: Dict = dict(
                     num_stages=4,
                     lateral=True,
                     lateral_infl=1,
                     lateral_activate=(0, 0, 1, 1),
                     fusion_kernel=7,
                     base_channels=64,
                     conv1_kernel=(1, 7, 7),
                     inflate=(0, 0, 1, 1),
                     with_pool2=False),
                 pose_pathway: Dict = dict(
                     num_stages=3,
                     stage_blocks=(4, 6, 3),
                     lateral=True,
                     lateral_inv=True,
                     lateral_infl=16,
                     lateral_activate=(0, 1, 1),
                     fusion_kernel=7,
                     in_channels=17,
                     base_channels=32,
                     out_indices=(2, ),
                     conv1_kernel=(1, 7, 7),
                     conv1_stride_s=1,
                     conv1_stride_t=1,
                     pool1_stride_s=1,
                     pool1_stride_t=1,
                     inflate=(0, 1, 1),
                     spatial_strides=(2, 2, 2),
                     temporal_strides=(1, 1, 1),
                     dilations=(1, 1, 1),
                     with_pool2=False),
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None) -> None:
        super().__init__(init_cfg=init_cfg)
        # Validate cheap arguments before building the pathways.
        assert 0 <= rgb_drop_path <= 1, \
            f'rgb_drop_path must be in [0, 1], got {rgb_drop_path}'
        assert 0 <= pose_drop_path <= 1, \
            f'pose_drop_path must be in [0, 1], got {pose_drop_path}'

        self.pretrained = pretrained
        self.speed_ratio = speed_ratio
        self.channel_ratio = channel_ratio

        # Work on shallow copies: the defaults above are shared dict objects
        # (evaluated once at definition time), and callers' config dicts
        # must not be modified as a side effect of building the model.
        rgb_pathway = dict(rgb_pathway)
        pose_pathway = dict(pose_pathway)
        for pathway_cfg in (rgb_pathway, pose_pathway):
            if pathway_cfg.get('lateral', False):
                pathway_cfg['speed_ratio'] = speed_ratio
                pathway_cfg['channel_ratio'] = channel_ratio

        self.rgb_path = ResNet3dPathway(**rgb_pathway)
        self.pose_path = ResNet3dPathway(**pose_pathway)

        self.rgb_detach = rgb_detach
        self.pose_detach = pose_detach
        self.rgb_drop_path = rgb_drop_path
        self.pose_drop_path = pose_drop_path

    def init_weights(self) -> None:
        """Initiate the parameters either from existing checkpoint or from
        scratch.

        Every ``nn.Conv3d`` is first Kaiming-initialized and every
        ``_BatchNorm`` layer is constant-initialized with value 1. Then, if
        ``pretrained`` is a path, the whole backbone is loaded from it with
        ``strict=True``; if it is None, each pathway runs its own
        ``init_weights``.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                kaiming_init(m)
            elif isinstance(m, _BatchNorm):
                constant_init(m, 1)

        if isinstance(self.pretrained, str):
            logger = MMLogger.get_current_instance()
            msg = f'load model from: {self.pretrained}'
            print_log(msg, logger=logger)
            load_checkpoint(self, self.pretrained, strict=True, logger=logger)
        elif self.pretrained is None:
            # Init two branch separately.
            self.rgb_path.init_weights()
            self.pose_path.init_weights()
        else:
            raise TypeError('pretrained must be a str or None')

    def _fuse_laterals(self, x_rgb: torch.Tensor, x_pose: torch.Tensor,
                       rgb_lateral_name: str, pose_lateral_name: str,
                       rgb_drop_path: Union[bool, torch.Tensor],
                       pose_drop_path: Union[bool, torch.Tensor]) -> tuple:
        """Exchange features between the two pathways after a stage.

        Both lateral features are computed from the pre-fusion inputs. Each
        is then concatenated along dim 1 onto the other pathway's features.
        A pathway that has no lateral layer with the given name receives
        nothing.

        Args:
            x_rgb (torch.Tensor): Current RGB-pathway features.
            x_pose (torch.Tensor): Current pose-pathway features.
            rgb_lateral_name (str): Attribute name of the RGB pathway's
                lateral layer that consumes pose features.
            pose_lateral_name (str): Attribute name of the pose pathway's
                lateral layer that consumes RGB features.
            rgb_drop_path (bool or torch.Tensor): If truthy, the pose-to-RGB
                lateral features are replaced by zeros.
            pose_drop_path (bool or torch.Tensor): If truthy, the RGB-to-pose
                lateral features are replaced by zeros.

        Returns:
            tuple[torch.Tensor]: The fused ``(x_rgb, x_pose)``.
        """
        x_pose_lateral = None
        if hasattr(self.rgb_path, rgb_lateral_name):
            feat = x_pose.detach() if self.rgb_detach else x_pose
            x_pose_lateral = getattr(self.rgb_path, rgb_lateral_name)(feat)
            if rgb_drop_path:
                x_pose_lateral = x_pose_lateral.new_zeros(
                    x_pose_lateral.shape)

        x_rgb_lateral = None
        if hasattr(self.pose_path, pose_lateral_name):
            feat = x_rgb.detach() if self.pose_detach else x_rgb
            x_rgb_lateral = getattr(self.pose_path, pose_lateral_name)(feat)
            if pose_drop_path:
                x_rgb_lateral = x_rgb_lateral.new_zeros(x_rgb_lateral.shape)

        # Concatenate only after both laterals are computed, so neither one
        # sees the other pathway's already-fused features.
        if x_pose_lateral is not None:
            x_rgb = torch.cat((x_rgb, x_pose_lateral), dim=1)
        if x_rgb_lateral is not None:
            x_pose = torch.cat((x_pose, x_rgb_lateral), dim=1)
        return x_rgb, x_pose

    def forward(self, imgs: torch.Tensor,
                heatmap_imgs: torch.Tensor) -> tuple:
        """Defines the computation performed at every call.

        Args:
            imgs (torch.Tensor): RGB clip, fed to ``rgb_path.conv1``.
            heatmap_imgs (torch.Tensor): Pose heatmap volume, fed to
                ``pose_path.conv1``. Presumably ``(N, 17, T, H, W)`` for the
                default pose config (``in_channels=17``); TODO confirm.

        Returns:
            tuple[torch.Tensor]: ``(x_rgb, x_pose)``, the outputs of the last
            RGB stage (``layer4``) and the last pose stage (``layer3``).
        """
        if self.training:
            # One draw per forward pass; it applies to the whole batch.
            rgb_drop_path = torch.rand(1) < self.rgb_drop_path
            pose_drop_path = torch.rand(1) < self.pose_drop_path
        else:
            rgb_drop_path, pose_drop_path = False, False

        # Stems. We assume base_channel for RGB and Pose are 64 and 32.
        x_rgb = self.rgb_path.conv1(imgs)
        x_rgb = self.rgb_path.maxpool(x_rgb)
        # e.g. N x 64 x 8 x 56 x 56 for the default RGB config
        # (the exact shape depends on the input size).
        x_pose = self.pose_path.conv1(heatmap_imgs)
        x_pose = self.pose_path.maxpool(x_pose)

        # RGB layer1-2 / pose layer1, then the first exchange.
        x_rgb = self.rgb_path.layer1(x_rgb)
        x_rgb = self.rgb_path.layer2(x_rgb)
        x_pose = self.pose_path.layer1(x_pose)
        x_rgb, x_pose = self._fuse_laterals(x_rgb, x_pose, 'layer2_lateral',
                                            'layer1_lateral', rgb_drop_path,
                                            pose_drop_path)

        # RGB layer3 / pose layer2, then the second exchange.
        x_rgb = self.rgb_path.layer3(x_rgb)
        x_pose = self.pose_path.layer2(x_pose)
        x_rgb, x_pose = self._fuse_laterals(x_rgb, x_pose, 'layer3_lateral',
                                            'layer2_lateral', rgb_drop_path,
                                            pose_drop_path)

        # Final stages, no further exchange.
        x_rgb = self.rgb_path.layer4(x_rgb)
        x_pose = self.pose_path.layer3(x_pose)
        return x_rgb, x_pose