
Source code for mmaction.models.backbones.timesformer

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmcv.cnn.bricks.transformer import build_transformer_layer_sequence
from mmengine import ConfigDict
from mmengine.logging import MMLogger
from mmengine.model.weight_init import kaiming_init, trunc_normal_
from mmengine.runner.checkpoint import _load_checkpoint, load_state_dict
from torch.nn.modules.utils import _pair

from mmaction.registry import MODELS


class PatchEmbed(nn.Module):
    """Image to Patch Embedding.

    Args:
        img_size (int | tuple): Size of input image.
        patch_size (int): Size of one patch.
        in_channels (int): Channel num of input features. Defaults to 3.
        embed_dims (int): Dimensions of embedding. Defaults to 768.
        conv_cfg (dict | None): Config dict for convolution layer. Defaults to
            `dict(type='Conv2d')`.
    """

    def __init__(self,
                 img_size,
                 patch_size,
                 in_channels=3,
                 embed_dims=768,
                 conv_cfg=dict(type='Conv2d')):
        super().__init__()
        self.img_size = _pair(img_size)
        self.patch_size = _pair(patch_size)

        num_patches = (self.img_size[1] // self.patch_size[1]) * (
            self.img_size[0] // self.patch_size[0])
        assert num_patches * self.patch_size[0] * self.patch_size[1] == \
               self.img_size[0] * self.img_size[1], \
               'The image size H*W must be divisible by patch size'
        self.num_patches = num_patches

        # Use conv layer to embed
        self.projection = build_conv_layer(
            conv_cfg,
            in_channels,
            embed_dims,
            kernel_size=patch_size,
            stride=patch_size)

        self.init_weights()

    def init_weights(self):
        """Initialize weights."""
        # Lecun norm from ClassyVision
        kaiming_init(self.projection, mode='fan_in', nonlinearity='linear')

    def forward(self, x):
        """Defines the computation performed at every call.

        Args:
            x (Tensor): The input data.

        Returns:
            Tensor: The output of the module.
        """
        x = rearrange(x, 'b c t h w -> (b t) c h w')
        x = self.projection(x).flatten(2).transpose(1, 2)
        return x
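
A quick orientation aside (not part of the original file): the sketch below exercises the shape contract of PatchEmbed, assuming 224x224 frames with 16x16 patches so that each frame yields 196 tokens. It reuses the torch import and the PatchEmbed class defined above.

# Illustrative only: a (batch, channels, frames, H, W) clip is folded to
# (batch * frames) images and projected to per-frame patch tokens.
patch_embed = PatchEmbed(img_size=224, patch_size=16, embed_dims=768)
video = torch.randn(2, 3, 8, 224, 224)  # (batch, channels, frames, H, W)
tokens = patch_embed(video)
# (batch * frames, num_patches, embed_dims) == (16, 196, 768)
assert tokens.shape == (2 * 8, (224 // 16)**2, 768)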


@MODELS.register_module()
class TimeSformer(nn.Module):
    """TimeSformer. A PyTorch impl of `Is Space-Time Attention All You Need
    for Video Understanding? <https://arxiv.org/abs/2102.05095>`_

    Args:
        num_frames (int): Number of frames in the video.
        img_size (int | tuple): Size of input image.
        patch_size (int): Size of one patch.
        pretrained (str | None): Name of pretrained model. Defaults to None.
        embed_dims (int): Dimensions of embedding. Defaults to 768.
        num_heads (int): Number of parallel attention heads in
            TransformerCoder. Defaults to 12.
        num_transformer_layers (int): Number of transformer layers.
            Defaults to 12.
        in_channels (int): Channel num of input features. Defaults to 3.
        dropout_ratio (float): Probability of dropout layer. Defaults to 0.0.
        transformer_layers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict` | None): Config of transformer layers in
            TransformerCoder. If it is obj:`mmcv.ConfigDict`, it would be
            repeated `num_transformer_layers` times to a
            list[obj:`mmcv.ConfigDict`]. Defaults to None.
        attention_type (str): Type of attentions in TransformerCoder. Choices
            are 'divided_space_time', 'space_only' and 'joint_space_time'.
            Defaults to 'divided_space_time'.
        norm_cfg (dict): Config for norm layers. Defaults to
            `dict(type='LN', eps=1e-6)`.
    """

    supported_attention_types = [
        'divided_space_time', 'space_only', 'joint_space_time'
    ]

    def __init__(self,
                 num_frames,
                 img_size,
                 patch_size,
                 pretrained=None,
                 embed_dims=768,
                 num_heads=12,
                 num_transformer_layers=12,
                 in_channels=3,
                 dropout_ratio=0.,
                 transformer_layers=None,
                 attention_type='divided_space_time',
                 norm_cfg=dict(type='LN', eps=1e-6),
                 **kwargs):
        super().__init__(**kwargs)
        assert attention_type in self.supported_attention_types, (
            f'Unsupported Attention Type {attention_type}!')
        assert transformer_layers is None or isinstance(
            transformer_layers, (dict, list))

        self.num_frames = num_frames
        self.pretrained = pretrained
        self.embed_dims = embed_dims
        self.num_transformer_layers = num_transformer_layers
        self.attention_type = attention_type

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embed_dims=embed_dims)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
        self.drop_after_pos = nn.Dropout(p=dropout_ratio)
        if self.attention_type != 'space_only':
            self.time_embed = nn.Parameter(
                torch.zeros(1, num_frames, embed_dims))
            self.drop_after_time = nn.Dropout(p=dropout_ratio)

        self.norm = build_norm_layer(norm_cfg, embed_dims)[1]

        if transformer_layers is None:
            # stochastic depth decay rule
            dpr = np.linspace(0, 0.1, num_transformer_layers)

            if self.attention_type == 'divided_space_time':
                _transformerlayers_cfg = [
                    dict(
                        type='BaseTransformerLayer',
                        attn_cfgs=[
                            dict(
                                type='DividedTemporalAttentionWithNorm',
                                embed_dims=embed_dims,
                                num_heads=num_heads,
                                num_frames=num_frames,
                                dropout_layer=dict(
                                    type='DropPath', drop_prob=dpr[i]),
                                norm_cfg=dict(type='LN', eps=1e-6)),
                            dict(
                                type='DividedSpatialAttentionWithNorm',
                                embed_dims=embed_dims,
                                num_heads=num_heads,
                                num_frames=num_frames,
                                dropout_layer=dict(
                                    type='DropPath', drop_prob=dpr[i]),
                                norm_cfg=dict(type='LN', eps=1e-6))
                        ],
                        ffn_cfgs=dict(
                            type='FFNWithNorm',
                            embed_dims=embed_dims,
                            feedforward_channels=embed_dims * 4,
                            num_fcs=2,
                            act_cfg=dict(type='GELU'),
                            dropout_layer=dict(
                                type='DropPath', drop_prob=dpr[i]),
                            norm_cfg=dict(type='LN', eps=1e-6)),
                        operation_order=('self_attn', 'self_attn', 'ffn'))
                    for i in range(num_transformer_layers)
                ]
            else:
                # Space Only & Joint Space Time
                _transformerlayers_cfg = [
                    dict(
                        type='BaseTransformerLayer',
                        attn_cfgs=[
                            dict(
                                type='MultiheadAttention',
                                embed_dims=embed_dims,
                                num_heads=num_heads,
                                batch_first=True,
                                dropout_layer=dict(
                                    type='DropPath', drop_prob=dpr[i]))
                        ],
                        ffn_cfgs=dict(
                            type='FFN',
                            embed_dims=embed_dims,
                            feedforward_channels=embed_dims * 4,
                            num_fcs=2,
                            act_cfg=dict(type='GELU'),
                            dropout_layer=dict(
                                type='DropPath', drop_prob=dpr[i])),
                        operation_order=('norm', 'self_attn', 'norm', 'ffn'),
                        norm_cfg=dict(type='LN', eps=1e-6),
                        batch_first=True)
                    for i in range(num_transformer_layers)
                ]

            transformer_layers = ConfigDict(
                dict(
                    type='TransformerLayerSequence',
                    transformerlayers=_transformerlayers_cfg,
                    num_layers=num_transformer_layers))

        self.transformer_layers = build_transformer_layer_sequence(
            transformer_layers)
    def init_weights(self, pretrained=None):
        """Initiate the parameters either from existing checkpoint or from
        scratch."""
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)

        if pretrained:
            self.pretrained = pretrained
        if isinstance(self.pretrained, str):
            logger = MMLogger.get_current_instance()
            logger.info(f'load model from: {self.pretrained}')

            state_dict = _load_checkpoint(self.pretrained, map_location='cpu')
            if 'state_dict' in state_dict:
                state_dict = state_dict['state_dict']

            if self.attention_type == 'divided_space_time':
                # modify the key names of norm layers
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'norms' in old_key:
                        new_key = old_key.replace('norms.0',
                                                  'attentions.0.norm')
                        new_key = new_key.replace('norms.1', 'ffns.0.norm')
                        state_dict[new_key] = state_dict.pop(old_key)

                # copy the parameters of space attention to time attention
                old_state_dict_keys = list(state_dict.keys())
                for old_key in old_state_dict_keys:
                    if 'attentions.0' in old_key:
                        new_key = old_key.replace('attentions.0',
                                                  'attentions.1')
                        state_dict[new_key] = state_dict[old_key].clone()

            load_state_dict(self, state_dict, strict=False, logger=logger)
    def forward(self, x):
        """Defines the computation performed at every call."""
        # x [batch_size * num_frames, num_patches, embed_dims]
        batches = x.shape[0]
        x = self.patch_embed(x)

        # x [batch_size * num_frames, num_patches + 1, embed_dims]
        cls_tokens = self.cls_token.expand(x.size(0), -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        # Add Time Embedding
        if self.attention_type != 'space_only':
            # x [batch_size, num_patches * num_frames + 1, embed_dims]
            cls_tokens = x[:batches, 0, :].unsqueeze(1)
            x = rearrange(x[:, 1:, :], '(b t) p m -> (b p) t m', b=batches)
            x = x + self.time_embed
            x = self.drop_after_time(x)
            x = rearrange(x, '(b p) t m -> b (p t) m', b=batches)
            x = torch.cat((cls_tokens, x), dim=1)

        x = self.transformer_layers(x, None, None)

        if self.attention_type == 'space_only':
            # x [batch_size, num_patches + 1, embed_dims]
            x = x.view(-1, self.num_frames, *x.size()[-2:])
            x = torch.mean(x, 1)

        x = self.norm(x)
        # Return Class Token
        return x[:, 0]
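
A minimal usage sketch (not part of the original file): it builds the backbone through the MODELS registry and pushes a dummy clip through it. It assumes MMAction2 1.x is installed and that mmaction.utils.register_all_modules is available to register the attention and FFN modules referenced by the default layer config; the exact registration helper may differ between versions.

import torch

from mmaction.registry import MODELS
from mmaction.utils import register_all_modules

# Register MMAction2 modules and set the default scope so that types such as
# 'DividedTemporalAttentionWithNorm' can be resolved when the layers are built.
register_all_modules()

backbone = MODELS.build(
    dict(
        type='TimeSformer',
        num_frames=8,
        img_size=224,
        patch_size=16,
        attention_type='divided_space_time'))
backbone.init_weights()  # no `pretrained` given, so weights stay randomly initialized

clip = torch.randn(1, 3, 8, 224, 224)  # (batch, channels, frames, height, width)
feat = backbone(clip)
# forward() returns the class token: one 768-d feature vector per video.
assert feat.shape == (1, 768)

Building through the registry mirrors how the backbone is instantiated from a config file; constructing TimeSformer(...) directly also works once the same modules are registered.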