# Source code for mmaction.models.localizers.tcanet

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN, MultiheadAttention
from mmengine.model import BaseModel
from torch import Tensor, nn

from mmaction.registry import MODELS
from mmaction.utils import OptConfigType
from .utils import (batch_iou, bbox_se_transform_batch, bbox_se_transform_inv,
                    bbox_xw_transform_batch, bbox_xw_transform_inv,
                    post_processing)


class LGTE(BaseModel):
    """Local-Global Temporal Encoder (LGTE).

    One encoder layer made of masked multi-head self-attention and a
    feed-forward network, each followed by layer normalization. The
    attention mask restricts the first ``num_heads // 2`` heads to a local
    temporal window; the mask of the remaining heads is all ``False``.

    Args:
        input_dim (int): Input feature dimension.
        dropout (float): The dropout rate for the residual branch of
            self-attention and ffn.
        temporal_dim (int): Total frames selected for each video.
            Defaults to 100.
        window_size (int): The window size for Local Temporal Encoder.
            Defaults to 9.
        num_heads (int): Number of attention heads. Defaults to 8.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 input_dim: int,
                 dropout: float,
                 temporal_dim: int = 100,
                 window_size: int = 9,
                 num_heads: int = 8,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super().__init__(init_cfg)

        self.atten = MultiheadAttention(
            embed_dims=input_dim,
            num_heads=num_heads,
            proj_drop=dropout,
            attn_drop=0.1)
        self.ffn = FFN(
            embed_dims=input_dim, feedforward_channels=256, ffn_drop=dropout)

        layer_norm_cfg = dict(type='LN', eps=1e-6)
        self.norm1 = build_norm_layer(layer_norm_cfg, input_dim)[1]
        self.norm2 = build_norm_layer(layer_norm_cfg, input_dim)[1]

        # Stored as a buffer so the mask follows the module across devices.
        self.register_buffer(
            'mask', self._mask_matrix(num_heads, temporal_dim, window_size))

    def forward(self, x: Tensor) -> Tensor:
        """Forward call for LGTE.

        Args:
            x (torch.Tensor): The input tensor with shape (B, C, L)

        Returns:
            torch.Tensor: The encoded tensor, permuted back to (B, C, L).
        """
        # (B, C, L) -> (L, B, C) for the attention / ffn layers.
        seq_first = x.permute(2, 0, 1)
        seq_len, batch_size = seq_first.shape[:2]

        # One copy of the per-head mask per batch element, flattened to
        # (B * num_heads, L, L). This reshape requires L == temporal_dim.
        attn_mask = self.mask.repeat(batch_size, 1, 1, 1)
        attn_mask = attn_mask.reshape(-1, seq_len, seq_len)

        encoded = self.atten(seq_first, attn_mask=attn_mask)
        encoded = self.norm2(self.ffn(self.norm1(encoded)))
        return encoded.permute(1, 2, 0)

    @staticmethod
    def _mask_matrix(num_heads: int, temporal_dim: int,
                     window_size: int) -> Tensor:
        """Build the boolean attention mask of shape (1, H, T, T).

        For the first ``num_heads // 2`` heads, entry ``[j, k]`` is ``True``
        when ``|k - j| > window_size / 2``; every entry of the remaining
        heads is ``False``.
        """
        positions = torch.arange(temporal_dim)
        # Pairwise temporal distance between every query/key position.
        distance = (positions.unsqueeze(0) - positions.unsqueeze(1)).abs()
        outside_window = distance > window_size / 2

        mask = torch.zeros(
            num_heads, temporal_dim, temporal_dim, dtype=torch.bool)
        mask[:num_heads // 2] = outside_window
        return mask.unsqueeze(0)


def StartEndRegressor(sample_num: int, feat_dim: int) -> nn.Module:
    """Build the start/end branch of the Temporal Boundary Regressor.

    Two stride-2 grouped convolutions (each followed by ReLU) shrink the
    temporal axis, then a grouped convolution with kernel size
    ``sample_num // 4`` produces 2 output channels and the result is
    flattened.

    Args:
        sample_num (int): number of samples for the start & end.
        feat_dim (int): feature dimension.

    Returns:
        A pytorch module that works as the start and end regressor. The input
        of the module should have a shape of (B, feat_dim * 2, sample_num).
    """
    hidden_dim = 128
    width = hidden_dim * 2

    layers = []
    # Two identical down-sampling stages; only the input width differs.
    for in_channels in (feat_dim * 2, width):
        layers.append(
            nn.Conv1d(
                in_channels,
                width,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=8))
        layers.append(nn.ReLU(inplace=True))

    # groups=2 gives each of the two outputs its own half of the channels.
    layers.append(nn.Conv1d(width, 2, kernel_size=sample_num // 4, groups=2))
    layers.append(nn.Flatten())
    return nn.Sequential(*layers)


def CenterWidthRegressor(temporal_len: int, feat_dim: int) -> nn.Module:
    """Build the center/width branch of the Temporal Boundary Regressor.

    Two stride-2 grouped convolutions shrink the temporal axis, a grouped
    convolution with kernel size ``temporal_len // 4`` follows, and a final
    1x1 convolution emits 3 channels. ``TBR.forward`` in this file uses the
    first two channels as the center/width regression and the third as an
    IoU prediction.

    Args:
        temporal_len (int): temporal dimension of the inputs.
        feat_dim (int): feature dimension.

    Returns:
        A pytorch module that works as the center and width regressor. The
        input of the module should have a shape of
        (B, feat_dim, temporal_len).
    """
    hidden_dim = 512

    def grouped_conv(in_channels: int, **conv_kwargs) -> nn.Conv1d:
        # Every hidden convolution shares the output width and group count.
        return nn.Conv1d(in_channels, hidden_dim, groups=4, **conv_kwargs)

    downsample = dict(kernel_size=3, padding=1, stride=2)
    layers = [
        grouped_conv(feat_dim, **downsample),
        nn.ReLU(inplace=True),
        grouped_conv(hidden_dim, **downsample),
        nn.ReLU(inplace=True),
        grouped_conv(hidden_dim, kernel_size=temporal_len // 4),
        nn.ReLU(inplace=True),
        nn.Conv1d(hidden_dim, 3, kernel_size=1),
    ]
    return nn.Sequential(*layers)


class TemporalTransform:
    """Temporal Transform to sample temporal features.

    For every proposal three fixed-length feature sequences are resampled
    from the input features with an affine grid: one for the region before
    the start, one for the proposal itself and one for the region after the
    end.

    Args:
        prop_boundary_ratio (float): Length of the start/end regions as a
            ratio of the proposal length.
        action_sample_num (int): Number of samples taken inside the proposal.
        se_sample_num (int): Number of samples taken in each of the start and
            end regions.
        temporal_interval (int): Stored on the instance; not used by the
            sampling code in this class.
    """

    def __init__(self, prop_boundary_ratio: float, action_sample_num: int,
                 se_sample_num: int, temporal_interval: int):
        super(TemporalTransform, self).__init__()
        self.temporal_interval = temporal_interval
        self.prop_boundary_ratio = prop_boundary_ratio
        self.action_sample_num = action_sample_num
        self.se_sample_num = se_sample_num

    def __call__(self, segments: Tensor, features: Tensor) -> tuple:
        """Sample start, action and end features for each proposal.

        Args:
            segments (torch.Tensor): Proposals of shape (N, 2) holding
                ``(start, end)``; values are clamped to [0, 1] before
                sampling.
            features (torch.Tensor): Features of shape (N, C, T), one row
                per proposal.

        Returns:
            tuple: ``(starts_feature, actions_feature, ends_feature)`` with
            shapes (N, C, se_sample_num), (N, C, action_sample_num) and
            (N, C, se_sample_num).
        """
        starts, ends = segments[:, 0], segments[:, 1]
        margin = self.prop_boundary_ratio * (ends - starts)

        # Regions just before the start and just after the end.
        starts_segments = torch.stack([starts - margin, starts], dim=1)
        ends_segments = torch.stack([ends, ends + margin], dim=1)

        starts_feature = self._sample_one_temporal(starts_segments,
                                                   self.se_sample_num,
                                                   features)
        ends_feature = self._sample_one_temporal(ends_segments,
                                                 self.se_sample_num, features)
        actions_feature = self._sample_one_temporal(segments,
                                                    self.action_sample_num,
                                                    features)
        return starts_feature, actions_feature, ends_feature

    def _sample_one_temporal(self, segments: Tensor, out_len: int,
                             features: Tensor) -> Tensor:
        """Resample ``features`` inside ``segments`` to ``out_len`` steps.

        Args:
            segments (torch.Tensor): (N, 2) segment boundaries.
            out_len (int): Number of output samples along the temporal axis.
            features (torch.Tensor): (N, C, T) features.

        Returns:
            torch.Tensor: Sampled features of shape (N, C, out_len).
        """
        # Clamp to [0, 1], then map to the [-1, 1] range used by the grid.
        segments = segments.clamp(0, 1) * 2 - 1

        # 2x3 affine matrix per proposal: scale and shift along the temporal
        # (x) axis, identity along the dummy height axis.
        theta = segments.new_zeros((features.size(0), 2, 3))
        theta[:, 1, 1] = 1.0
        theta[:, 0, 0] = (segments[:, 1] - segments[:, 0]) / 2.0
        theta[:, 0, 2] = (segments[:, 1] + segments[:, 0]) / 2.0

        size = torch.Size((*features.shape[:2], 1, out_len))
        # align_corners=False is PyTorch's default; passing it explicitly
        # keeps the numerics unchanged and avoids the per-call UserWarning.
        grid = F.affine_grid(theta, size, align_corners=False)
        stn_feature = F.grid_sample(
            features.unsqueeze(2), grid, align_corners=False)
        stn_feature = stn_feature.view(*features.shape[:2], out_len)
        return stn_feature


class TBR(BaseModel):
    """Temporal Boundary Regressor (TBR).

    Refines proposals with two regressors: a start/end regressor fed with
    the sampled start and end features, and a center/width regressor fed
    with the concatenated start, action and end features. The refined
    proposal is the average of both predictions.

    Args:
        se_sample_num (int): Number of samples for the start & end regions.
        action_sample_num (int): Number of samples inside the proposal.
        temporal_dim (int): Passed to ``TemporalTransform`` as its
            ``temporal_interval`` argument.
        prop_boundary_ratio (float): Ratio of the proposal length used for
            the start/end regions. Defaults to 0.5.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 se_sample_num: int,
                 action_sample_num: int,
                 temporal_dim: int,
                 prop_boundary_ratio: float = 0.5,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super().__init__(init_cfg)

        hidden_dim = 512

        self.reg1se = StartEndRegressor(se_sample_num, hidden_dim)
        # The center/width branch sees start + action + end samples.
        self.reg1xw = CenterWidthRegressor(
            se_sample_num * 2 + action_sample_num, hidden_dim)
        self.ttn = TemporalTransform(prop_boundary_ratio, action_sample_num,
                                     se_sample_num, temporal_dim)

    def forward(self, proposals: Tensor, features: Tensor, gt_boxes: Tensor,
                iou_thres: float, training: bool) -> tuple:
        """Refine proposals and, when training, compute the losses.

        Args:
            proposals (torch.Tensor): Proposals; only the first two columns
                are used.
            features (torch.Tensor): Per-proposal features given to the
                temporal transform.
            gt_boxes (torch.Tensor): Ground-truth boxes matched to the
                proposals; only read when ``training`` is true.
            iou_thres (float): IoU threshold used by both losses.
            training (bool): Whether to compute the losses.

        Returns:
            tuple: ``(preds_iou1, proposals2, rloss1, iloss1)``. Both losses
            are ``0`` when ``training`` is false.
        """
        segments = proposals[:, :2]
        start_feat, action_feat, end_feat = self.ttn(segments, features)

        se_offsets = self.reg1se(torch.cat([start_feat, end_feat], dim=1))

        xw_input = torch.cat([start_feat, action_feat, end_feat], dim=2)
        xw_output = self.reg1xw(xw_input).squeeze(2)
        # Channel 2 is the IoU prediction, channels 0-1 the xw regression.
        preds_iou1 = xw_output[:, 2].sigmoid()
        xw_offsets = xw_output[:, :2]

        # The center/width decoding is the same in both modes.
        proposals2xw = bbox_xw_transform_inv(segments, xw_offsets, 0.1, 0.2)
        if not training:
            proposals2se = bbox_se_transform_inv(segments, se_offsets, 0.2)
            rloss1 = iloss1 = 0
        else:
            proposals2se = bbox_se_transform_inv(segments, se_offsets, 1.0)

            iou1 = batch_iou(segments, gt_boxes)
            targets1se = bbox_se_transform_batch(segments, gt_boxes)
            targets1xw = bbox_xw_transform_batch(segments, gt_boxes)
            rloss1se = self.regress_loss(se_offsets, targets1se, iou1,
                                         iou_thres)
            rloss1xw = self.regress_loss(xw_offsets, targets1xw, iou1,
                                         iou_thres)
            rloss1 = rloss1se + rloss1xw
            iloss1 = self.iou_loss(preds_iou1, iou1, iou_thres=iou_thres)

        proposals2 = torch.clamp((proposals2se + proposals2xw) / 2.0, min=0.)
        return preds_iou1, proposals2, rloss1, iloss1

    def regress_loss(self, regression, targets, iou_with_gt, iou_thres):
        """Smooth-L1 regression loss over proposals with IoU >= threshold.

        The summed loss is divided by the number of selected proposals when
        at least one is selected; otherwise the (zero-weighted) sum is
        returned as is.
        """
        positive = (iou_with_gt >= iou_thres).float().unsqueeze(1)
        per_element = F.smooth_l1_loss(regression, targets, reduction='none')
        weighted_sum = (positive * per_element).sum()

        num_positive = positive.sum()
        if num_positive > 0:
            return weighted_sum / num_positive
        return weighted_sum

    def iou_loss(self, preds_iou, match_iou, iou_thres):
        """Smooth-L1 loss between predicted and matched IoU.

        Proposals are split into high (IoU > ``iou_thres``), medium
        (0.3 < IoU <= ``iou_thres``) and low (IoU <= 0.3) groups. All high
        proposals are kept; medium and low ones are randomly sub-sampled
        with a keep ratio of ``min(num_high / num_group, 1)``.
        """
        preds_iou = preds_iou.view(-1)
        high = (match_iou > iou_thres).float()
        medium = ((match_iou <= iou_thres) & (match_iou > 0.3)).float()
        low = (match_iou <= 0.3).float()

        num_high = high.sum()
        bs, device = high.size()[0], high.device

        def subsample(group_mask):
            # Keep each member of the group when its random draw exceeds
            # 1 - keep_ratio; entries outside the group are multiplied by 0.
            keep_ratio = min(num_high / group_mask.sum(), 1)
            draws = torch.rand(bs, device=device) * group_mask
            return (draws > (1. - keep_ratio)).float()

        # Medium is drawn before low, keeping the random stream order.
        sampled_medium = subsample(medium)
        sampled_low = subsample(low)

        iou_weights = high + sampled_medium + sampled_low
        per_element = F.smooth_l1_loss(preds_iou, match_iou, reduction='none')
        weighted_sum = (per_element * iou_weights).sum()

        total_weight = iou_weights.sum()
        if total_weight > 0:
            return weighted_sum / total_weight
        return weighted_sum


@MODELS.register_module()
class TCANet(BaseModel):
    """Temporal Context Aggregation Network.

    Please refer `Temporal Context Aggregation Network for Temporal Action
    Proposal Refinement <https://arxiv.org/abs/2103.13141>`_.
    Code Reference:
    https://github.com/qinzhi-0110/Temporal-Context-Aggregation-Network-Pytorch

    Args:
        feat_dim (int): Channel number of the input features.
            Defaults to 2304.
        se_sample_num (int): Number of samples for the start & end regions,
            forwarded to each ``TBR``. Defaults to 32.
        action_sample_num (int): Number of samples inside a proposal,
            forwarded to each ``TBR``. Defaults to 64.
        temporal_dim (int): Temporal length, forwarded to ``TBR`` and
            ``LGTE``. Defaults to 100.
        window_size (int): Window size forwarded to ``LGTE``. Defaults to 9.
        lgte_num (int): Number of stacked ``LGTE`` layers. Defaults to 2.
        soft_nms_alpha (float): Forwarded to ``post_processing``.
            Defaults to 0.4.
        soft_nms_low_threshold (float): Forwarded to ``post_processing``.
            Defaults to 0.0.
        soft_nms_high_threshold (float): Forwarded to ``post_processing``.
            Defaults to 0.0.
        post_process_top_k (int): Forwarded to ``post_processing``.
            Defaults to 100.
        feature_extraction_interval (int): Forwarded to ``post_processing``.
            Defaults to 16.
        init_cfg (dict or ConfigDict, optional): The Config for
            initialization. Defaults to None.
    """

    def __init__(self,
                 feat_dim: int = 2304,
                 se_sample_num: int = 32,
                 action_sample_num: int = 64,
                 temporal_dim: int = 100,
                 window_size: int = 9,
                 lgte_num: int = 2,
                 soft_nms_alpha: float = 0.4,
                 soft_nms_low_threshold: float = 0.0,
                 soft_nms_high_threshold: float = 0.0,
                 post_process_top_k: int = 100,
                 feature_extraction_interval: int = 16,
                 init_cfg: OptConfigType = None,
                 **kwargs) -> None:
        super(TCANet, self).__init__(init_cfg)

        self.soft_nms_alpha = soft_nms_alpha
        self.soft_nms_low_threshold = soft_nms_low_threshold
        self.soft_nms_high_threshold = soft_nms_high_threshold
        self.feature_extraction_interval = feature_extraction_interval
        self.post_process_top_k = post_process_top_k

        hidden_dim = 512
        self.x_1d_b_f = nn.Sequential(
            nn.Conv1d(
                feat_dim, hidden_dim, kernel_size=3, padding=1, groups=4),
            nn.ReLU(inplace=True),
            nn.Conv1d(
                hidden_dim, hidden_dim, kernel_size=3, padding=1, groups=4),
            nn.ReLU(inplace=True),
        )

        # Three cascaded boundary regressors, registered as tbr1..tbr3.
        for i in 1, 2, 3:
            tbr = TBR(
                se_sample_num=se_sample_num,
                action_sample_num=action_sample_num,
                temporal_dim=temporal_dim,
                init_cfg=init_cfg,
                **kwargs)
            setattr(self, f'tbr{i}', tbr)

        self.lgtes = nn.ModuleList([
            LGTE(
                input_dim=hidden_dim,
                dropout=0.1,
                temporal_dim=temporal_dim,
                window_size=window_size,
                init_cfg=init_cfg,
                **kwargs) for i in range(lgte_num)
        ])

    def forward(self, inputs, data_samples=None, mode='tensor', **kwargs):
        """The unified entry for a forward process in both training and test.

        The method should accept three modes:

        - ``tensor``: Forward the whole network and return tensor or tuple of
          tensor without any post-processing, same as a common nn.Module.
        - ``predict``: Forward and return the predictions, which are fully
          processed to a list of :obj:`ActionDataSample`.
        - ``loss``: Forward and return a dict of losses according to the given
          inputs and data samples.

        Note that this method handles neither back propagation nor optimizer
        updating, which are done in the :meth:`train_step`.

        Args:
            inputs (Tensor): The input tensor with shape (N, C, ...) in
                general. A sequence of tensors is stacked into one tensor.
            data_samples (List[:obj:`ActionDataSample`], optional): The
                annotation data of every samples. Defaults to None.
            mode (str): Return what kind of value. Defaults to ``tensor``.

        Returns:
            The return type depends on ``mode``.

            - If ``mode="tensor"``, return a tensor or a tuple of tensor.
            - If ``mode="predict"``, return a list of ``ActionDataSample``.
            - If ``mode="loss"``, return a dict of tensor.

        Raises:
            RuntimeError: If ``mode`` is not one of the three modes above.
        """
        # Check the argument ``inputs``; the builtin ``input`` is never a
        # Tensor, so testing it would re-stack tensors on every call.
        if not isinstance(inputs, Tensor):
            inputs = torch.stack(inputs)
        if mode == 'tensor':
            return self._forward(inputs, **kwargs)
        if mode == 'predict':
            return self.predict(inputs, data_samples, **kwargs)
        elif mode == 'loss':
            return self.loss(inputs, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}". '
                               'Only supports loss, predict and tensor mode')

    def _forward(self, x):
        """Define the computation performed at every call.

        Args:
            x (torch.Tensor): The input data.

        Returns:
            torch.Tensor: The output of the module.
        """
        x = self.x_1d_b_f(x)
        for layer in self.lgtes:
            x = layer(x)
        return x

    @staticmethod
    def _pad_proposals(batch_data_samples, device):
        """Stack per-sample proposals into a zero-padded (B, P, 3) tensor.

        ``P`` is the largest proposal count in the batch; samples with fewer
        proposals are padded with all-zero rows.
        """
        proposals_ = [
            sample.proposals['proposals'] for sample in batch_data_samples
        ]
        batch_size = len(proposals_)
        proposals_num = max([_.shape[0] for _ in proposals_])

        proposals = torch.zeros((batch_size, proposals_num, 3), device=device)
        for i, proposal in enumerate(proposals_):
            proposals[i, :proposal.shape[0]] = proposal
        return proposals

    def loss(self, batch_inputs, batch_data_samples, **kwargs):
        """Compute the regression and IoU losses of the three TBR stages.

        Args:
            batch_inputs (torch.Tensor): Input features of the batch.
            batch_data_samples (list): Data samples providing
                ``proposals['proposals']`` and ``gt_instances['gt_bbox']``.

        Returns:
            dict: Losses ``rloss1``-``rloss3`` and ``iloss1``-``iloss3``.
        """
        features = self._forward(batch_inputs)
        proposals = self._pad_proposals(batch_data_samples, features.device)
        batch_size, proposals_num = proposals.shape[:2]

        gt_boxes_ = [
            sample.gt_instances['gt_bbox'] for sample in batch_data_samples
        ]
        gt_boxes = torch.zeros((batch_size, proposals_num, 2),
                               device=features.device)
        for i, gt_box in enumerate(gt_boxes_):
            L = gt_box.shape[0]
            if L <= proposals_num:
                gt_boxes[i, :L] = gt_box
            else:
                # More boxes than slots: keep a random subset.
                random_index = torch.randperm(L)[:proposals_num]
                gt_boxes[i] = gt_box[random_index]

        # Column 2 is replaced with the batch index so that each proposal
        # can look up its own feature map after flattening.
        for i in range(batch_size):
            proposals[i, :, 2] = i
        proposals = proposals.view(batch_size * proposals_num, 3)
        # Padded rows have start + end == 0 and are dropped.
        proposals_select = proposals[:, 0:2].sum(dim=1) > 0
        proposals = proposals[proposals_select, :]

        features = features[proposals[:, 2].long()]

        gt_boxes = gt_boxes.view(batch_size * proposals_num, 2)
        gt_boxes = gt_boxes[proposals_select, :]

        _, proposals1, rloss1, iloss1 = self.tbr1(proposals, features,
                                                  gt_boxes, 0.5, True)
        _, proposals2, rloss2, iloss2 = self.tbr2(proposals1, features,
                                                  gt_boxes, 0.6, True)
        _, _, rloss3, iloss3 = self.tbr3(proposals2, features, gt_boxes, 0.7,
                                         True)

        loss_dict = dict(
            rloss1=rloss1,
            rloss2=rloss2,
            rloss3=rloss3,
            iloss1=iloss1,
            iloss2=iloss2,
            iloss3=iloss3)
        return loss_dict

    def predict(self, batch_inputs, batch_data_samples, **kwargs):
        """Refine the proposals and run the post-processing.

        Args:
            batch_inputs (torch.Tensor): Input features of the batch.
            batch_data_samples (list): Data samples providing
                ``proposals['proposals']`` and ``metainfo``.

        Returns:
            list[dict]: A single dict with ``video_name`` and
            ``proposal_list``. Only the meta information of the first data
            sample is used.
        """
        features = self._forward(batch_inputs)
        proposals = self._pad_proposals(batch_data_samples, features.device)
        batch_size, proposals_num = proposals.shape[:2]

        # Clone before column 2 is overwritten below: a plain slice is a
        # view, so the scores would otherwise turn into batch indices.
        scores = proposals[:, :, 2].clone()
        for i in range(batch_size):
            proposals[i, :, 2] = i

        proposals = proposals.view(batch_size * proposals_num, 3)
        # Padded rows have start + end == 0 and are dropped.
        proposals_select = proposals[:, 0:2].sum(dim=1) > 0
        proposals = proposals[proposals_select, :]
        scores = scores.view(-1)[proposals_select]

        features = features[proposals[:, 2].long()]

        preds_iou1, proposals1 = self.tbr1(proposals, features, None, 0.5,
                                           False)[:2]
        preds_iou2, proposals2 = self.tbr2(proposals1, features, None, 0.6,
                                           False)[:2]
        preds_iou3, proposals3 = self.tbr3(proposals2, features, None, 0.7,
                                           False)[:2]

        # Each stage contributes its proposals, scored by the original score
        # multiplied by that stage's predicted IoU.
        all_proposals = [
            torch.cat([stage_proposals, (scores * stage_iou).view(-1, 1)],
                      dim=1)
            for stage_proposals, stage_iou in ((proposals1, preds_iou1),
                                               (proposals2, preds_iou2),
                                               (proposals3, preds_iou3))
        ]
        all_proposals = torch.cat(all_proposals, dim=0).cpu().numpy()

        video_info = batch_data_samples[0].metainfo
        proposal_list = post_processing(all_proposals, video_info,
                                        self.soft_nms_alpha,
                                        self.soft_nms_low_threshold,
                                        self.soft_nms_high_threshold,
                                        self.post_process_top_k,
                                        self.feature_extraction_interval)
        output = [
            dict(
                video_name=video_info['video_name'],
                proposal_list=proposal_list)
        ]
        return output