Source code for mmaction.models.heads.feature_head
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
from torch import Tensor

from mmaction.registry import MODELS
from .base import BaseHead


@MODELS.register_module()
class FeatureHead(BaseHead):
    """General head for feature extraction.
    Args:
        spatial_type (str, optional): Pooling type in spatial dimension.
            Default: 'avg'. If set to None, means keeping spatial dimension,
            and for GCN backbone, keeping last two dimension(T, V).
        temporal_type (str, optional): Pooling type in temporal dimension.
            Default: 'avg'. If set to None, meanse keeping temporal dimnsion,
            and for GCN backbone, keeping dimesion M. Please note that the
            channel order would keep same with the output of backbone,
            [N, T, C, H, W] for 2D recognizer, and [N, M, C, T, V] for GCN
            recognizer.
        backbone_name (str, optional): Backbone name to specifying special
            operations.Currently supports: `'tsm'`, `'slowfast'`, and `'gcn'`.
            Defaults to None, means take the input as normal feature.
        num_segments (int, optional): Number of frame segments for TSM
            backbone. Defaults to None.
        kwargs (dict, optional): Any keyword argument to be used to initialize
            the head.
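
    Examples:
        >>> # A minimal sketch: pool a 3D-CNN style feature of layout
        >>> # [N, C, T, H, W]; the shapes are illustrative only.
        >>> import torch
        >>> head = FeatureHead(spatial_type='avg', temporal_type=None)
        >>> feat = torch.rand(2, 256, 8, 7, 7)
        >>> head(feat).shape
        torch.Size([2, 256, 8])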
    """
    def __init__(self,
                 spatial_type: str = 'avg',
                 temporal_type: str = 'avg',
                 backbone_name: Optional[str] = None,
                 num_segments: Optional[int] = None,
                 **kwargs) -> None:
        super().__init__(None, None, **kwargs)
        self.temporal_type = temporal_type
        self.backbone_name = backbone_name
        self.num_segments = num_segments
        if spatial_type == 'avg':
            self.pool2d = torch.mean
        elif spatial_type == 'max':
            self.pool2d = torch.max
        elif spatial_type is None:
            self.pool2d = lambda x, dim: x
        else:
            raise NotImplementedError(
                f'Unsupported spatial_type {spatial_type}')
        if temporal_type == 'avg':
            self.pool1d = torch.mean
        elif temporal_type == 'max':
            self.pool1d = torch.max
        elif temporal_type is None:
            self.pool1d = lambda x, dim: x
        else:
            raise NotImplementedError(
                f'Unsupported temporal_type {temporal_type}')

    def forward(self,
                x: Tensor,
                num_segs: Optional[int] = None,
                **kwargs) -> Tensor:
        """Defines the computation performed at every call.
        Args:
            x (Tensor): The input data.
            num_segs (int): For 2D backbone. Number of segments into which
                a video is divided. Defaults to None.
        Returns:
            Tensor: The output features after pooling.
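
        Examples:
            >>> # A minimal sketch: a 2D backbone feature of layout
            >>> # [N * num_segs, C, H, W]; the shapes are illustrative only.
            >>> import torch
            >>> head = FeatureHead(spatial_type='avg', temporal_type='avg')
            >>> x = torch.rand(2 * 8, 512, 7, 7)
            >>> head(x, num_segs=8).shape
            torch.Size([2, 512])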
        """
        if isinstance(x, Tensor):
            n_dims = x.ndim
        elif isinstance(x, tuple):
            n_dims = x[0].ndim
            assert self.backbone_name == 'slowfast', \
                'Only the SlowFast backbone supports tuple input'
        else:
            raise NotImplementedError(f'Unsupported feature type: {type(x)}')
        # For 2D backbone with spatial dimension
        if n_dims == 4:
            assert num_segs is not None
            if self.backbone_name == 'tsm':
                assert self.num_segments is not None, \
                    'Please specify num_segments for TSM'
                num_segs = self.num_segments
            # [N, T, channels, H, W]
            x = x.view((-1, num_segs) + x.shape[1:])
            feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=1)
        elif n_dims == 5:
            if self.backbone_name == 'slowfast':
                x_slow, x_fast = x
                assert self.temporal_type is not None, \
                    'slowfast backbone has to pool temporal dimension'
                x_fast = self.pool1d(self.pool2d(x_fast, dim=[-2, -1]), dim=2)
                x_slow = self.pool1d(self.pool2d(x_slow, dim=[-2, -1]), dim=2)
                feat = torch.cat((x_slow, x_fast), dim=1)
            # For GCN-based backbone
            elif self.backbone_name == 'gcn':
                # N, M, C, T, V
                feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=1)
            # For 3D backbone with spatial dimension
            else:
                # [N, channels, T, H, W]
                feat = self.pool1d(self.pool2d(x, dim=[-2, -1]), dim=2)
        # For backbone output feature without spatial and temporal dimension
        elif n_dims == 2:
            # [N, channels]
            feat = x
        return feat

    def predict_by_feat(self, feats: Union[Tensor, Tuple[Tensor]],
                        data_samples) -> Tensor:
        """Integrate multi-view features into one tensor.
        Args:
            feats (torch.Tensor | tuple[torch.Tensor]): Features from
                upstream network.
            data_samples (list[:obj:`ActionDataSample`]): The batch
                data samples.
        Returns:
            Tensor: The integrated multi-view features.
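
        Examples:
            >>> # A minimal sketch; it assumes ``average_clips='score'`` is
            >>> # handled by ``BaseHead`` and that only ``len(data_samples)``
            >>> # matters here, so dummy samples stand in for real ones.
            >>> import torch
            >>> head = FeatureHead(average_clips='score')
            >>> feats = torch.rand(10, 512)  # 5 clips for each of 2 samples
            >>> data_samples = [None] * 2
            >>> head.predict_by_feat(feats, data_samples).shape
            torch.Size([2, 512])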
        """
        num_segs = feats.shape[0] // len(data_samples)
        feats = self.average_clip(feats, num_segs=num_segs)
        return feats