# Source code for mmaction.models.localizers.tcanet
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.model import BaseModel
from torch import Tensor, nn
from mmaction.registry import MODELS
from mmaction.utils import OptConfigType
from .utils import (batch_iou, bbox_se_transform_batch, bbox_se_transform_inv,
                    bbox_xw_transform_batch, bbox_xw_transform_inv,
                    post_processing)
class LGTE(BaseModel):
    """Local-Global Temporal Encoder (LGTE).

    One encoder layer: multi-head self-attention followed by a LayerNorm,
    then an FFN followed by a second LayerNorm. The first
    ``num_heads // 2`` attention heads are masked so that every position
    only attends to positions at most ``window_size / 2`` steps away
    (local heads); the remaining heads are unmasked (global heads).

    Args:
        input_dim (int): Input feature dimension.
        dropout (float): the dropout rate for the residual branch of
            self-attention and ffn.
        temporal_dim (int): Total frames selected for each video. The
            attention mask is precomputed for this length, so inputs must
            have exactly this temporal size. Defaults to 100.
        window_size (int): the window size for Local Temporal Encoder.
            Defaults to 9.
        num_heads (int): Number of attention heads. The first
            ``num_heads // 2`` heads are local, the rest global.
            Defaults to 8.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """
    def __init__(self,
                 input_dim: int,
                 dropout: float,
                 temporal_dim: int = 100,
                 window_size: int = 9,
                 num_heads: int = 8,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(LGTE, self).__init__(init_cfg)
        self.atten = MultiheadAttention(
            embed_dims=input_dim,
            num_heads=num_heads,
            proj_drop=dropout,
            attn_drop=0.1)
        self.ffn = FFN(
            embed_dims=input_dim, feedforward_channels=256, ffn_drop=dropout)
        norm_cfg = dict(type='LN', eps=1e-6)
        self.norm1 = build_norm_layer(norm_cfg, input_dim)[1]
        self.norm2 = build_norm_layer(norm_cfg, input_dim)[1]
        # Fixed (1, num_heads, T, T) boolean mask, registered as a buffer so
        # it moves with the module between devices.
        mask = self._mask_matrix(num_heads, temporal_dim, window_size)
        self.register_buffer('mask', mask)

    def forward(self, x: Tensor) -> Tensor:
        """Forward call for LGTE.

        Args:
            x (torch.Tensor): The input tensor with shape (B, C, L), where
                ``L`` must equal ``temporal_dim``.

        Returns:
            torch.Tensor: The encoded tensor with shape (B, C, L).
        """
        # (B, C, L) -> (L, B, C): sequence-first layout for the attention.
        x = x.permute(2, 0, 1)
        L = x.shape[0]
        # (1, H, L, L) -> (B, H, L, L) -> (B * H, L, L): one mask per
        # (sample, head) pair.
        mask = self.mask.repeat(x.size(1), 1, 1, 1)
        x = self.atten(x, attn_mask=mask.reshape(-1, L, L))
        x = self.norm1(x)
        x = self.ffn(x)
        x = self.norm2(x)
        # (L, B, C) -> (B, C, L)
        x = x.permute(1, 2, 0)
        return x

    @staticmethod
    def _mask_matrix(num_heads: int, temporal_dim: int,
                     window_size: int) -> Tensor:
        """Build the attention mask of shape (1, num_heads, T, T).

        ``mask[0, h, j, k]`` is True when head ``h`` is a local head
        (``h < num_heads // 2``) and ``|k - j| > window_size / 2``; it is
        False everywhere for the global heads.
        """
        index = torch.arange(temporal_dim)
        # distance[j, k] == |k - j|, computed by broadcasting instead of a
        # per-head, per-row Python loop.
        outside_window = (index.unsqueeze(0) -
                          index.unsqueeze(1)).abs() > window_size / 2
        mask = torch.zeros(
            num_heads, temporal_dim, temporal_dim, dtype=torch.bool)
        mask[:num_heads // 2] = outside_window
        return mask.unsqueeze(0)
def StartEndRegressor(sample_num: int, feat_dim: int) -> nn.Module:
    """Build the start/end branch of the Temporal Boundary Regressor.

    The network applies two strided, grouped 1-D convolutions (each halving
    the temporal length) with ReLU, then one grouped convolution whose
    kernel spans the remaining ``sample_num // 4`` steps, and flattens.

    Args:
        sample_num (int): number of samples for the start & end. Expected to
            be a multiple of 4 so the last convolution reduces time to 1.
        feat_dim (int): feature dimension.

    Returns:
        A pytorch module that works as the start and end regressor. The input
        of the module should have a shape of (B, feat_dim * 2, sample_num);
        for ``sample_num`` a multiple of 4 the output is (B, 2).
    """
    hidden_dim = 128
    mid_channels = hidden_dim * 2

    layers = []
    # Two downsampling stages: (feat_dim*2 -> mid) then (mid -> mid).
    for in_channels in (feat_dim * 2, mid_channels):
        layers += [
            nn.Conv1d(
                in_channels,
                mid_channels,
                kernel_size=3,
                padding=1,
                groups=8,
                stride=2),
            nn.ReLU(inplace=True),
        ]
    # Collapse the remaining temporal axis into two outputs (start, end).
    layers += [
        nn.Conv1d(mid_channels, 2, kernel_size=sample_num // 4, groups=2),
        nn.Flatten(),
    ]
    return nn.Sequential(*layers)
def CenterWidthRegressor(temporal_len: int, feat_dim: int) -> nn.Module:
    """Center and Width Regressor in the Temporal Boundary Regressor.

    Two strided, grouped 1-D convolutions (each halving the temporal length)
    are followed by a grouped convolution spanning the remaining
    ``temporal_len // 4`` steps and a final 1x1 convolution with 3 outputs.

    Args:
        temporal_len (int): temporal dimension of the inputs. Expected to be
            a multiple of 4 so the third convolution reduces time to 1.
        feat_dim (int): feature dimension.

    Returns:
        A pytorch module that works as the center and width regressor. The
        input of the module should have a shape of
        (B, feat_dim, temporal_len). Unlike the start/end regressor the
        output is not flattened: for ``temporal_len`` a multiple of 4 it has
        a shape of (B, 3, 1).
    """
    hidden_dim = 512
    regressor = nn.Sequential(
        nn.Conv1d(
            feat_dim, hidden_dim, kernel_size=3, padding=1, groups=4,
            stride=2), nn.ReLU(inplace=True),
        nn.Conv1d(
            hidden_dim,
            hidden_dim,
            kernel_size=3,
            padding=1,
            groups=4,
            stride=2), nn.ReLU(inplace=True),
        nn.Conv1d(
            hidden_dim, hidden_dim, kernel_size=temporal_len // 4, groups=4),
        nn.ReLU(inplace=True), nn.Conv1d(hidden_dim, 3, kernel_size=1))
    return regressor
class TemporalTransform:
    """Temporal Transform to sample temporal features.

    Resamples fixed-length feature sequences from temporal segments with an
    affine sampling grid (a 1-D spatial-transformer). Segment bounds are
    clamped to [0, 1], i.e. they are expressed relative to the full
    feature length.

    Args:
        prop_boundary_ratio (float): Length of the start and end context
            regions, relative to the segment length.
        action_sample_num (int): Number of samples taken inside a segment.
        se_sample_num (int): Number of samples taken in each of the start
            and end context regions.
        temporal_interval (int): Stored on the instance; not used by the
            sampling code of this class.
    """
    def __init__(self, prop_boundary_ratio: float, action_sample_num: int,
                 se_sample_num: int, temporal_interval: int):
        super(TemporalTransform, self).__init__()
        self.temporal_interval = temporal_interval
        self.prop_boundary_ratio = prop_boundary_ratio
        self.action_sample_num = action_sample_num
        self.se_sample_num = se_sample_num

    def __call__(self, segments: Tensor, features: Tensor) -> tuple:
        """Sample start-context, inside and end-context features.

        Args:
            segments (Tensor): (N, 2) segments as (start, end).
            features (Tensor): (N, C, T) features, one row per segment.

        Returns:
            tuple: ``(starts_feature, actions_feature, ends_feature)`` with
            shapes (N, C, se_sample_num), (N, C, action_sample_num) and
            (N, C, se_sample_num).
        """
        s_len = segments[:, 1] - segments[:, 0]
        # Context region just before the segment start.
        starts_segments = [
            segments[:, 0] - self.prop_boundary_ratio * s_len, segments[:, 0]
        ]
        starts_segments = torch.stack(starts_segments, dim=1)
        # Context region just after the segment end.
        ends_segments = [
            segments[:, 1], segments[:, 1] + self.prop_boundary_ratio * s_len
        ]
        ends_segments = torch.stack(ends_segments, dim=1)
        starts_feature = self._sample_one_temporal(starts_segments,
                                                   self.se_sample_num,
                                                   features)
        ends_feature = self._sample_one_temporal(ends_segments,
                                                 self.se_sample_num, features)
        actions_feature = self._sample_one_temporal(segments,
                                                    self.action_sample_num,
                                                    features)
        return starts_feature, actions_feature, ends_feature

    def _sample_one_temporal(self, segments: Tensor, out_len: int,
                             features: Tensor) -> Tensor:
        """Linearly resample ``out_len`` steps of ``features`` per segment."""
        # Map [0, 1] bounds to the [-1, 1] coordinates used by affine_grid.
        segments = segments.clamp(0, 1) * 2 - 1
        theta = segments.new_zeros((features.size(0), 2, 3))
        # Identity along the (singleton) height axis.
        theta[:, 1, 1] = 1.0
        # x_src = scale * x_out + shift maps [-1, 1] onto [start, end].
        theta[:, 0, 0] = (segments[:, 1] - segments[:, 0]) / 2.0
        theta[:, 0, 2] = (segments[:, 1] + segments[:, 0]) / 2.0
        size = torch.Size((*features.shape[:2], 1, out_len))
        # align_corners=False is the default since PyTorch 1.3; passing it
        # explicitly keeps the same sampling without the per-call warning.
        grid = F.affine_grid(theta, size, align_corners=False)
        stn_feature = F.grid_sample(
            features.unsqueeze(2), grid, align_corners=False)
        stn_feature = stn_feature.view(*features.shape[:2], out_len)
        return stn_feature
class TBR(BaseModel):
    """Temporal Boundary Regressor (TBR).

    Refines proposals from features resampled around each proposal by
    ``TemporalTransform``. It uses two regressions whose decoded proposals
    are averaged: a start/end regression (``reg1se``) and a center/width
    regression (``reg1xw``). Channel 2 of the center/width output is used
    as a predicted IoU.

    Args:
        se_sample_num (int): Number of samples in each start/end context
            region.
        action_sample_num (int): Number of samples inside the proposal.
        temporal_dim (int): Passed to ``TemporalTransform`` as its
            ``temporal_interval``.
        prop_boundary_ratio (float): Length of the start/end context regions
            relative to the proposal length. Defaults to 0.5.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """
    def __init__(self,
                 se_sample_num: int,
                 action_sample_num: int,
                 temporal_dim: int,
                 prop_boundary_ratio: float = 0.5,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(TBR, self).__init__(init_cfg)
        # Channel count the regressors expect from the features passed to
        # forward.
        hidden_dim = 512
        self.reg1se = StartEndRegressor(se_sample_num, hidden_dim)
        # Start, inside and end samples are concatenated along time.
        temporal_len = se_sample_num * 2 + action_sample_num
        self.reg1xw = CenterWidthRegressor(temporal_len, hidden_dim)
        self.ttn = TemporalTransform(prop_boundary_ratio, action_sample_num,
                                     se_sample_num, temporal_dim)
    def forward(self, proposals: Tensor, features: Tensor, gt_boxes: Tensor,
                iou_thres: float, training: bool) -> tuple:
        """Regress refined proposals (and losses when training).

        Args:
            proposals (Tensor): Proposals whose first two columns are
                (start, end); presumably normalized to [0, 1], since the
                sampler clamps to that range — TODO confirm.
            features (Tensor): Features resampled by ``self.ttn``; assumed
                (N, 512, T) with one row per proposal — TODO confirm.
            gt_boxes (Tensor): Ground-truth (start, end) per proposal. Only
                read when ``training`` is True; ``None`` is fine otherwise.
            iou_thres (float): IoU threshold used by both losses.
            training (bool): Whether to compute the losses.

        Returns:
            tuple: ``(preds_iou, refined_proposals, reg_loss, iou_loss)``;
            both losses are ``0`` when ``training`` is False.
        """
        proposals1 = proposals[:, :2]
        starts_feat1, actions_feat1, ends_feat1 = self.ttn(
            proposals1, features)
        # Start/end context features are stacked along channels.
        reg1se = self.reg1se(torch.cat([starts_feat1, ends_feat1], dim=1))
        # All three sampled regions are concatenated along time.
        features1xw = torch.cat([starts_feat1, actions_feat1, ends_feat1],
                                dim=2)
        # Drop the length-1 temporal axis; channels 0-1 are the
        # center/width regression, channel 2 the IoU logit.
        reg1xw = self.reg1xw(features1xw).squeeze(2)
        preds_iou1 = reg1xw[:, 2].sigmoid()
        reg1xw = reg1xw[:, :2]
        # NOTE(review): the start/end decode uses 1.0 during training but
        # 0.2 at inference, while the center/width decode uses the same
        # values in both branches — confirm this asymmetry is intended.
        if training:
            proposals2xw = bbox_xw_transform_inv(proposals1, reg1xw, 0.1, 0.2)
            proposals2se = bbox_se_transform_inv(proposals1, reg1se, 1.0)
            iou1 = batch_iou(proposals1, gt_boxes)
            targets1se = bbox_se_transform_batch(proposals1, gt_boxes)
            targets1xw = bbox_xw_transform_batch(proposals1, gt_boxes)
            rloss1se = self.regress_loss(reg1se, targets1se, iou1, iou_thres)
            rloss1xw = self.regress_loss(reg1xw, targets1xw, iou1, iou_thres)
            rloss1 = rloss1se + rloss1xw
            iloss1 = self.iou_loss(preds_iou1, iou1, iou_thres=iou_thres)
        else:
            proposals2xw = bbox_xw_transform_inv(proposals1, reg1xw, 0.1, 0.2)
            proposals2se = bbox_se_transform_inv(proposals1, reg1se, 0.2)
            rloss1 = iloss1 = 0
        # Average the two decoded proposals and keep them non-negative.
        proposals2 = (proposals2se + proposals2xw) / 2.0
        proposals2 = torch.clamp(proposals2, min=0.)
        return preds_iou1, proposals2, rloss1, iloss1
    def regress_loss(self, regression, targets, iou_with_gt, iou_thres):
        """Smooth-L1 regression loss averaged over proposals whose IoU with
        the ground truth is ``>= iou_thres`` (0 if there are none)."""
        weight = (iou_with_gt >= iou_thres).float().unsqueeze(1)
        reg_loss = F.smooth_l1_loss(regression, targets, reduction='none')
        if weight.sum() > 0:
            reg_loss = (weight * reg_loss).sum() / weight.sum()
        else:
            reg_loss = (weight * reg_loss).sum()
        return reg_loss
    def iou_loss(self, preds_iou, match_iou, iou_thres):
        """Smooth-L1 IoU loss with random rebalancing of IoU buckets.

        Proposals fall into three buckets: high (IoU > ``iou_thres``),
        medium (0.3 < IoU <= ``iou_thres``) and low (IoU <= 0.3). All high
        proposals are kept. Medium and low proposals are each kept with
        probability ``min(num_high / num_bucket, 1)``, so each bucket
        contributes about as many samples as the high bucket.
        """
        preds_iou = preds_iou.view(-1)
        u_hmask = (match_iou > iou_thres).float()
        u_mmask = ((match_iou <= iou_thres) & (match_iou > 0.3)).float()
        u_lmask = (match_iou <= 0.3).float()
        num_h, num_m, num_l = u_hmask.sum(), u_mmask.sum(), u_lmask.sum()
        bs, device = u_hmask.size()[0], u_hmask.device
        # If a bucket is empty the ratio is inf (-> min gives 1) or NaN
        # (0/0); in both cases the comparison below keeps nothing from
        # that bucket.
        r_m = min(num_h / num_m, 1)
        u_smmask = torch.rand(bs, device=device) * u_mmask
        u_smmask = (u_smmask > (1. - r_m)).float()
        r_l = min(num_h / num_l, 1)
        u_slmask = torch.rand(bs, device=device) * u_lmask
        u_slmask = (u_slmask > (1. - r_l)).float()
        iou_weights = u_hmask + u_smmask + u_slmask
        iou_loss = F.smooth_l1_loss(preds_iou, match_iou, reduction='none')
        if iou_weights.sum() > 0:
            iou_loss = (iou_loss * iou_weights).sum() / iou_weights.sum()
        else:
            iou_loss = (iou_loss * iou_weights).sum()
        return iou_loss
@MODELS.register_module()
class TCANet(BaseModel):
    """Temporal Context Aggregation Network.

    Please refer `Temporal Context Aggregation Network for Temporal Action
    Proposal Refinement <https://arxiv.org/abs/2103.13141>`_.

    Code Reference:
    https://github.com/qinzhi-0110/Temporal-Context-Aggregation-Network-Pytorch

    Args:
        feat_dim (int): Channel dimension of the input features.
            Defaults to 2304.
        se_sample_num (int): Samples per start/end context region in each
            TBR. Defaults to 32.
        action_sample_num (int): Samples inside each proposal in each TBR.
            Defaults to 64.
        temporal_dim (int): Temporal length of the input features.
            Defaults to 100.
        window_size (int): Local attention window of the LGTE layers.
            Defaults to 9.
        lgte_num (int): Number of stacked LGTE layers. Defaults to 2.
        soft_nms_alpha (float): Forwarded to ``post_processing``.
            Defaults to 0.4.
        soft_nms_low_threshold (float): Forwarded to ``post_processing``.
            Defaults to 0.0.
        soft_nms_high_threshold (float): Forwarded to ``post_processing``.
            Defaults to 0.0.
        post_process_top_k (int): Forwarded to ``post_processing``.
            Defaults to 100.
        feature_extraction_interval (int): Forwarded to
            ``post_processing``. Defaults to 16.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
        **kwargs: Forwarded to every ``TBR`` and ``LGTE`` sub-module.
    """

    def __init__(self,
                 feat_dim: int = 2304,
                 se_sample_num: int = 32,
                 action_sample_num: int = 64,
                 temporal_dim: int = 100,
                 window_size: int = 9,
                 lgte_num: int = 2,
                 soft_nms_alpha: float = 0.4,
                 soft_nms_low_threshold: float = 0.0,
                 soft_nms_high_threshold: float = 0.0,
                 post_process_top_k: int = 100,
                 feature_extraction_interval: int = 16,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(TCANet, self).__init__(init_cfg)

        self.soft_nms_alpha = soft_nms_alpha
        self.soft_nms_low_threshold = soft_nms_low_threshold
        self.soft_nms_high_threshold = soft_nms_high_threshold
        self.feature_extraction_interval = feature_extraction_interval
        self.post_process_top_k = post_process_top_k

        hidden_dim = 512
        # Project the input features to hidden_dim channels.
        self.x_1d_b_f = nn.Sequential(
            nn.Conv1d(
                feat_dim, hidden_dim, kernel_size=3, padding=1, groups=4),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=4),
            nn.ReLU(inplace=True),
        )

        # Three cascaded regressors registered as tbr1, tbr2, tbr3; the
        # attribute names determine the state-dict keys.
        for i in (1, 2, 3):
            tbr = TBR(
                se_sample_num=se_sample_num,
                action_sample_num=action_sample_num,
                temporal_dim=temporal_dim,
                init_cfg=init_cfg,
                **kwargs)
            setattr(self, f'tbr{i}', tbr)

        self.lgtes = nn.ModuleList([
            LGTE(
                input_dim=hidden_dim,
                dropout=0.1,
                temporal_dim=temporal_dim,
                window_size=window_size,
                init_cfg=init_cfg,
                **kwargs) for _ in range(lgte_num)
        ])

    def forward(self, inputs, data_samples=None, mode='tensor', **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
        tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
        processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the given
        inputs and data samples.

        Note that this method doesn't handle neither back propagation nor
        optimizer updating, which are done in the :meth:`train_step`.

        Args:
            inputs (Tensor or list[Tensor]): The input tensor with shape
                (N, C, ...) in general, or a list of per-sample tensors that
                will be stacked.
            data_samples (List[:obj:`ActionDataSample`], optional): The
                annotation data of every samples. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``tensor``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three modes above.
        """
        # Check ``inputs`` (not the builtin ``input``): an already-batched
        # Tensor must not be passed through torch.stack.
        if not isinstance(inputs, Tensor):
            inputs = torch.stack(inputs)
        if mode == 'tensor':
            return self._forward(inputs, **kwargs)
        if mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        if mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        raise RuntimeError(f'Invalid mode "{mode}". '
                           'Only supports loss, predict and tensor mode')

    def _forward(self, x):
        """Define the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        x = self.x_1d_b_f(x)
        for layer in self.lgtes:
            x = layer(x)
        return x

    @staticmethod
    def _pad_proposals(data_samples, device):
        """Stack per-sample proposals into one zero-padded tensor.

        Each ``sample.proposals['proposals']`` (3 columns) is copied into the
        leading rows of a ``(batch_size, max_num_proposals, 3)`` tensor.
        Unused rows stay zero so they can be filtered out later.
        """
        proposals_ = [
            sample.proposals['proposals'] for sample in data_samples
        ]
        proposals_num = max(p.shape[0] for p in proposals_)
        proposals = torch.zeros((len(proposals_), proposals_num, 3),
                                device=device)
        for i, proposal in enumerate(proposals_):
            proposals[i, :proposal.shape[0]] = proposal
        return proposals

    def loss(self, batch_inputs, batch_data_samples, **kwargs):
        """Compute the regression and IoU losses of the three TBR stages.

        Returns:
            dict: ``rloss1``-``rloss3`` and ``iloss1``-``iloss3``.
        """
        features = self._forward(batch_inputs)
        proposals = self._pad_proposals(batch_data_samples, features.device)
        batch_size, proposals_num = proposals.shape[:2]

        gt_boxes_ = [
            sample.gt_instances['gt_bbox'] for sample in batch_data_samples
        ]
        gt_boxes = torch.zeros((batch_size, proposals_num, 2),
                               device=features.device)
        for i, gt_box in enumerate(gt_boxes_):
            num_gt = gt_box.shape[0]
            if num_gt <= proposals_num:
                gt_boxes[i, :num_gt] = gt_box
            else:
                # More boxes than slots: keep a random subset.
                random_index = torch.randperm(num_gt)[:proposals_num]
                gt_boxes[i] = gt_box[random_index]

        # Column 2 now holds the batch index used to gather features.
        for i in range(batch_size):
            proposals[i, :, 2] = i
        proposals = proposals.view(batch_size * proposals_num, 3)
        # Drop the zero padding rows.
        proposals_select = proposals[:, 0:2].sum(dim=1) > 0
        proposals = proposals[proposals_select, :]

        features = features[proposals[:, 2].long()]
        gt_boxes = gt_boxes.view(batch_size * proposals_num, 2)
        gt_boxes = gt_boxes[proposals_select, :]

        # Cascade: each stage refines the previous stage's proposals with a
        # stricter IoU threshold.
        _, proposals1, rloss1, iloss1 = self.tbr1(proposals, features,
                                                  gt_boxes, 0.5, True)
        _, proposals2, rloss2, iloss2 = self.tbr2(proposals1, features,
                                                  gt_boxes, 0.6, True)
        _, _, rloss3, iloss3 = self.tbr3(proposals2, features, gt_boxes, 0.7,
                                         True)

        loss_dict = dict(
            rloss1=rloss1,
            rloss2=rloss2,
            rloss3=rloss3,
            iloss1=iloss1,
            iloss2=iloss2,
            iloss3=iloss3)
        return loss_dict

    def predict(self, batch_inputs, batch_data_samples, **kwargs):
        """Refine proposals with the three TBR stages and post-process them.

        Returns:
            list[dict]: One dict with ``video_name`` and ``proposal_list``.
        """
        features = self._forward(batch_inputs)
        proposals = self._pad_proposals(batch_data_samples, features.device)
        batch_size, proposals_num = proposals.shape[:2]

        # Clone: a plain slice is a view, and the loop below overwrites
        # column 2 with the batch index, which would clobber the scores.
        scores = proposals[:, :, 2].clone()
        for i in range(batch_size):
            proposals[i, :, 2] = i
        proposals = proposals.view(batch_size * proposals_num, 3)
        # Drop the zero padding rows.
        proposals_select = proposals[:, 0:2].sum(dim=1) > 0
        proposals = proposals[proposals_select, :]
        scores = scores.view(-1)[proposals_select]

        features = features[proposals[:, 2].long()]

        preds_iou1, proposals1 = self.tbr1(proposals, features, None, 0.5,
                                           False)[:2]
        preds_iou2, proposals2 = self.tbr2(proposals1, features, None, 0.6,
                                           False)[:2]
        preds_iou3, proposals3 = self.tbr3(proposals2, features, None, 0.7,
                                           False)[:2]

        # The score of each refined proposal is the original proposal score
        # times the IoU predicted by the stage that produced it.
        refined = ((proposals1, preds_iou1), (proposals2, preds_iou2),
                   (proposals3, preds_iou3))
        all_proposals = torch.cat([
            torch.cat([stage_proposals, (scores * stage_iou).view(-1, 1)],
                      dim=1) for stage_proposals, stage_iou in refined
        ], dim=0).cpu().numpy()

        # NOTE(review): only the first sample's metainfo is used and all
        # proposals are post-processed together, so this assumes one video
        # per batch.
        video_info = batch_data_samples[0].metainfo
        proposal_list = post_processing(all_proposals, video_info,
                                        self.soft_nms_alpha,
                                        self.soft_nms_low_threshold,
                                        self.soft_nms_high_threshold,
                                        self.post_process_top_k,
                                        self.feature_extraction_interval)
        output = [
            dict(
                video_name=video_info['video_name'],
                proposal_list=proposal_list)
        ]
        return output