Shortcuts

Source code for mmaction.models.heads.feature_head

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from mmaction.registry import MODELS
from .base import BaseHead


[docs]@MODELS.register_module() class FeatureHead(BaseHead): """General head for feature extraction. Args: spatial_type (str, optional): Pooling type in spatial dimension. Default: 'avg'. If set to None, means keeping spatial dimension, and for GCN backbone, keeping last two dimension(T, V). temporal_type (str, optional): Pooling type in temporal dimension. Default: 'avg'. If set to None, meanse keeping temporal dimnsion, and for GCN backbone, keeping dimesion M. Please note that the channel order would keep same with the output of backbone, [N, T, C, H, W] for 2D recognizer, and [N, M, C, T, V] for GCN recognizer. backbone_name (str, optional): Backbone name to specifying special operations.Currently supports: `'tsm'`, `'slowfast'`, and `'gcn'`. Defaults to None, means take the input as normal feature. num_segments (int, optional): Number of frame segments for TSM backbone. Defaults to None. kwargs (dict, optional): Any keyword argument to be used to initialize the head. """ def __init__(self, spatial_type: str = 'avg', temporal_type: str = 'avg', backbone_name: Optional[str] = None, num_segments: Optional[str] = None, **kwargs) -> None: super().__init__(None, None, **kwargs) self.temporal_type = temporal_type self.backbone_name = backbone_name self.num_segments = num_segments if spatial_type == 'avg': self.pool2d = torch.mean elif spatial_type == 'max': self.pool2d = torch.max elif spatial_type is None: self.pool2d = lambda x, dim: x else: raise NotImplementedError( f'Unsupported spatial_type {spatial_type}') if temporal_type == 'avg': self.pool1d = torch.mean elif temporal_type == 'max': self.pool1d = torch.max elif temporal_type is None: self.pool1d = lambda x, dim: x else: raise NotImplementedError( f'Unsupported temporal_type {temporal_type}')
[docs] def forward(self, x: Tensor, num_segs: Optional[int] = None, **kwargs) -> Tensor: """Defines the computation performed at every call. Args: x (Tensor): The input data. num_segs (int): For 2D backbone. Number of segments into which a video is divided. Defaults to None. Returns: Tensor: The output features after pooling. """ if isinstance(x, Tensor): n_dims = x.ndim elif isinstance(x, tuple): n_dims = x[0].ndim assert self.backbone_name == 'slowfast', \ 'Only support SlowFast backbone to input tuple' else: raise NotImplementedError(f'Unsupported feature type: {type(x)}') # For 2D backbone with spatial dimension if n_dims == 4: assert num_segs is not None if self.backbone_name == 'tsm': assert self.num_segments is not None, \ 'Please Specify num_segments for TSM' num_segs = self.num_segments # [N, T, channels, H, W] x = x.view((-1, num_segs) + x.shape[1:]) feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=1) elif n_dims == 5: if self.backbone_name == 'slowfast': x_slow, x_fast = x assert self.temporal_type is not None, \ 'slowfast backbone has to pool temporal dimension' x_fast = self.pool1d(self.pool2d(x_fast, dim=[-2, -1]), dim=2) x_slow = self.pool1d(self.pool2d(x_slow, dim=[-2, -1]), dim=2) feat = torch.cat((x_slow, x_fast), dim=1) # For GCN-based backbone elif self.backbone_name == 'gcn': # N, M, C, T, V feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=1) # For 3D backbone with spatial dimension else: # [N, channels, T, H, W] feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=2) # For backbone output feature without spatial and temporal dimension elif n_dims == 2: # [N, channels] feat = x return feat
[docs] def predict_by_feat(self, feats: Union[Tensor, Tuple[Tensor]], data_samples) -> Tensor: """Integrate multi-view features into one tensor. Args: feats (torch.Tensor | tuple[torch.Tensor]): Features from upstream network. data_samples (list[:obj:`ActionDataSample`]): The batch data samples. Returns: Tensor: The integrated multi-view features. """ num_segs = feats.shape[0] // len(data_samples) feats = self.average_clip(feats, num_segs=num_segs) return feats