Source code for mmaction.models.common.transformer

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from einops import rearrange
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, build_dropout
from mmengine.model import BaseModule
from mmengine.model.weight_init import constant_init
from mmengine.utils import digit_version

from mmaction.registry import MODELS


@MODELS.register_module()
class DividedTemporalAttentionWithNorm(BaseModule):
    """Temporal Attention in Divided Space Time Attention.

    Args:
        embed_dims (int): Dimensions of embedding.
        num_heads (int): Number of parallel attention heads in
            TransformerCoder.
        num_frames (int): Number of frames in the video.
        attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
            0..
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Defaults to 0..
        dropout_layer (dict): The dropout_layer used when adding the shortcut.
            Defaults to `dict(type='DropPath', drop_prob=0.1)`.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            `dict(type='LN')`.
        init_cfg (dict | None): The Config for initialization. Defaults to
            None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 num_frames,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='DropPath', drop_prob=0.1),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_frames = num_frames
        self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            kwargs.pop('batch_first', None)
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)
        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()
        self.temporal_fc = nn.Linear(self.embed_dims, self.embed_dims)

        self.init_weights()

    def init_weights(self):
        """Initialize weights."""
        constant_init(self.temporal_fc, val=0, bias=0)

    def forward(self, query, key=None, value=None, residual=None, **kwargs):
        """Defines the computation performed at every call."""
        assert residual is None, (
            'Always adding the shortcut in the forward function')

        init_cls_token = query[:, 0, :].unsqueeze(1)
        identity = query_t = query[:, 1:, :]

        # query_t [batch_size, num_patches * num_frames, embed_dims]
        b, pt, m = query_t.size()
        p, t = pt // self.num_frames, self.num_frames

        # res_temporal [batch_size * num_patches, num_frames, embed_dims]
        query_t = self.norm(query_t.reshape(b * p, t, m)).permute(1, 0, 2)
        res_temporal = self.attn(query_t, query_t, query_t)[0].permute(1, 0, 2)
        res_temporal = self.dropout_layer(
            self.proj_drop(res_temporal.contiguous()))
        res_temporal = self.temporal_fc(res_temporal)

        # res_temporal [batch_size, num_patches * num_frames, embed_dims]
        res_temporal = res_temporal.reshape(b, p * t, m)

        # ret_value [batch_size, num_patches * num_frames + 1, embed_dims]
        new_query_t = identity + res_temporal
        new_query = torch.cat((init_cls_token, new_query_t), 1)
        return new_query
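
A minimal usage sketch. The sizes below are illustrative assumptions; the only contract implied by the shape comments above is that `query` holds one class token followed by `num_patches * num_frames` patch tokens:

    import torch

    from mmaction.models.common.transformer import \
        DividedTemporalAttentionWithNorm

    # Illustrative sizes only (e.g. a ViT-B style backbone on 8 frames).
    batch_size, num_patches, num_frames, embed_dims = 2, 196, 8, 768
    temporal_attn = DividedTemporalAttentionWithNorm(
        embed_dims=embed_dims, num_heads=12, num_frames=num_frames)

    # tokens [batch_size, 1 + num_patches * num_frames, embed_dims]
    tokens = torch.randn(batch_size, 1 + num_patches * num_frames, embed_dims)
    out = temporal_attn(tokens)
    assert out.shape == tokens.shape  # the shortcut is added inside forward()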
@MODELS.register_module()
class DividedSpatialAttentionWithNorm(BaseModule):
    """Spatial Attention in Divided Space Time Attention.

    Args:
        embed_dims (int): Dimensions of embedding.
        num_heads (int): Number of parallel attention heads in
            TransformerCoder.
        num_frames (int): Number of frames in the video.
        attn_drop (float): A Dropout layer on attn_output_weights. Defaults to
            0..
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Defaults to 0..
        dropout_layer (dict): The dropout_layer used when adding the shortcut.
            Defaults to `dict(type='DropPath', drop_prob=0.1)`.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            `dict(type='LN')`.
        init_cfg (dict | None): The Config for initialization. Defaults to
            None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 num_frames,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='DropPath', drop_prob=0.1),
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg)
        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.num_frames = num_frames
        self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]
        if digit_version(torch.__version__) < digit_version('1.9.0'):
            kwargs.pop('batch_first', None)
        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)
        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

        self.init_weights()

    def init_weights(self):
        """init DividedSpatialAttentionWithNorm by default."""
        pass

    def forward(self, query, key=None, value=None, residual=None, **kwargs):
        """Defines the computation performed at every call."""
        assert residual is None, (
            'Always adding the shortcut in the forward function')

        identity = query
        init_cls_token = query[:, 0, :].unsqueeze(1)
        query_s = query[:, 1:, :]

        # query_s [batch_size, num_patches * num_frames, embed_dims]
        b, pt, m = query_s.size()
        p, t = pt // self.num_frames, self.num_frames

        # cls_token [batch_size * num_frames, 1, embed_dims]
        cls_token = init_cls_token.repeat(1, t, 1).reshape(b * t,
                                                           m).unsqueeze(1)

        # query_s [batch_size * num_frames, num_patches + 1, embed_dims]
        query_s = rearrange(query_s, 'b (p t) m -> (b t) p m', p=p, t=t)
        query_s = torch.cat((cls_token, query_s), 1)

        # res_spatial [batch_size * num_frames, num_patches + 1, embed_dims]
        query_s = self.norm(query_s).permute(1, 0, 2)
        res_spatial = self.attn(query_s, query_s, query_s)[0].permute(1, 0, 2)
        res_spatial = self.dropout_layer(
            self.proj_drop(res_spatial.contiguous()))

        # cls_token [batch_size, 1, embed_dims]
        cls_token = res_spatial[:, 0, :].reshape(b, t, m)
        cls_token = torch.mean(cls_token, 1, True)

        # res_spatial [batch_size, num_patches * num_frames, embed_dims]
        res_spatial = rearrange(
            res_spatial[:, 1:, :], '(b t) p m -> b (p t) m', p=p, t=t)
        res_spatial = torch.cat((cls_token, res_spatial), 1)

        new_query = identity + res_spatial
        return new_query
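
The spatial module shares the same token layout as the temporal one, so the two can be chained directly. In divided space-time attention (TimeSformer style), the temporal block runs before the spatial block. A sketch with the same illustrative sizes as above:

    import torch

    from mmaction.models.common.transformer import (
        DividedSpatialAttentionWithNorm, DividedTemporalAttentionWithNorm)

    batch_size, num_patches, num_frames, embed_dims = 2, 196, 8, 768
    temporal_attn = DividedTemporalAttentionWithNorm(embed_dims, 12, num_frames)
    spatial_attn = DividedSpatialAttentionWithNorm(embed_dims, 12, num_frames)

    tokens = torch.randn(batch_size, 1 + num_patches * num_frames, embed_dims)
    # Temporal attention first, then spatial; both keep the token layout
    # intact, and the spatial block averages the per-frame class tokens back
    # into a single class token.
    out = spatial_attn(temporal_attn(tokens))
    assert out.shape == tokens.shape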
@MODELS.register_module()
class FFNWithNorm(FFN):
    """FFN with pre normalization layer.

    FFNWithNorm is implemented to be compatible with `BaseTransformerLayer`
    when using `DividedTemporalAttentionWithNorm` and
    `DividedSpatialAttentionWithNorm`.

    FFNWithNorm has one main difference from FFN:

    - It applies one normalization layer before forwarding the input data to
      the feed-forward networks.

    Args:
        embed_dims (int): Dimensions of embedding. Defaults to 256.
        feedforward_channels (int): Hidden dimension of FFNs. Defaults to
            1024.
        num_fcs (int, optional): Number of fully-connected layers in FFNs.
            Defaults to 2.
        act_cfg (dict): Config for activation layers.
            Defaults to `dict(type='ReLU')`.
        ffn_drop (float, optional): Probability of an element to be zeroed
            in FFN. Defaults to 0..
        add_residual (bool, optional): Whether to add the residual connection.
            Defaults to `True`.
        dropout_layer (dict | None): The dropout_layer used when adding the
            shortcut. Defaults to None.
        init_cfg (dict): The Config for initialization. Defaults to None.
        norm_cfg (dict): Config dict for normalization layer. Defaults to
            `dict(type='LN')`.
    """

    def __init__(self, *args, norm_cfg=dict(type='LN'), **kwargs):
        super().__init__(*args, **kwargs)

        self.norm = build_norm_layer(norm_cfg, self.embed_dims)[1]

    def forward(self, x, residual=None):
        """Defines the computation performed at every call."""
        assert residual is None, ('Cannot apply pre-norm with FFNWithNorm')
        return super().forward(self.norm(x), x)
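
The three registered names above are meant to be referenced from a `BaseTransformerLayer` config, which is how the TimeSformer backbone wires them together. A hedged sketch of that wiring follows; the dimensions, `eps`, and hidden sizes are illustrative assumptions, not values taken from any shipped config:

    # Illustrative layer config; check the TimeSformer backbone and configs
    # in mmaction2 for the exact values used in practice.
    embed_dims, num_heads, num_frames = 768, 12, 8
    transformer_layer_cfg = dict(
        type='BaseTransformerLayer',
        attn_cfgs=[
            dict(
                type='DividedTemporalAttentionWithNorm',
                embed_dims=embed_dims,
                num_heads=num_heads,
                num_frames=num_frames,
                norm_cfg=dict(type='LN', eps=1e-6)),
            dict(
                type='DividedSpatialAttentionWithNorm',
                embed_dims=embed_dims,
                num_heads=num_heads,
                num_frames=num_frames,
                norm_cfg=dict(type='LN', eps=1e-6)),
        ],
        ffn_cfgs=dict(
            type='FFNWithNorm',
            embed_dims=embed_dims,
            feedforward_channels=embed_dims * 4,
            act_cfg=dict(type='GELU'),
            norm_cfg=dict(type='LN', eps=1e-6)),
        # Temporal attention, spatial attention, then the pre-norm FFN.
        operation_order=('self_attn', 'self_attn', 'ffn'))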