Source code for mmaction.models.losses.ssn_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmaction.registry import MODELS
from .ohem_hinge_loss import OHEMHingeLoss
[docs]@MODELS.register_module()
class SSNLoss(nn.Module):
[docs] @staticmethod
def activity_loss(activity_score, labels, activity_indexer):
"""Activity Loss.
It will calculate activity loss given activity_score and label.
Args:
activity_score (torch.Tensor): Predicted activity score.
labels (torch.Tensor): Groundtruth class label.
activity_indexer (torch.Tensor): Index slices of proposals.
Returns:
torch.Tensor: Returned cross entropy loss.
"""
pred = activity_score[activity_indexer, :]
gt = labels[activity_indexer]
return F.cross_entropy(pred, gt)
[docs] @staticmethod
def completeness_loss(completeness_score,
labels,
completeness_indexer,
positive_per_video,
incomplete_per_video,
ohem_ratio=0.17):
"""Completeness Loss.
It will calculate completeness loss given completeness_score and label.
Args:
completeness_score (torch.Tensor): Predicted completeness score.
labels (torch.Tensor): Groundtruth class label.
completeness_indexer (torch.Tensor): Index slices of positive and
incomplete proposals.
positive_per_video (int): Number of positive proposals sampled
per video.
incomplete_per_video (int): Number of incomplete proposals sampled
pre video.
ohem_ratio (float): Ratio of online hard example mining.
Default: 0.17.
Returns:
torch.Tensor: Returned class-wise completeness loss.
"""
pred = completeness_score[completeness_indexer, :]
gt = labels[completeness_indexer]
pred_dim = pred.size(1)
pred = pred.view(-1, positive_per_video + incomplete_per_video,
pred_dim)
gt = gt.view(-1, positive_per_video + incomplete_per_video)
# yapf:disable
positive_pred = pred[:, :positive_per_video, :].contiguous().view(-1, pred_dim) # noqa:E501
incomplete_pred = pred[:, positive_per_video:, :].contiguous().view(-1, pred_dim) # noqa:E501
# yapf:enable
positive_loss = OHEMHingeLoss.apply(
positive_pred, gt[:, :positive_per_video].contiguous().view(-1), 1,
1.0, positive_per_video)
incomplete_loss = OHEMHingeLoss.apply(
incomplete_pred, gt[:, positive_per_video:].contiguous().view(-1),
-1, ohem_ratio, incomplete_per_video)
num_positives = positive_pred.size(0)
num_incompletes = int(incomplete_pred.size(0) * ohem_ratio)
return ((positive_loss + incomplete_loss) /
float(num_positives + num_incompletes))
[docs] @staticmethod
def classwise_regression_loss(bbox_pred, labels, bbox_targets,
regression_indexer):
"""Classwise Regression Loss.
It will calculate classwise_regression loss given
class_reg_pred and targets.
Args:
bbox_pred (torch.Tensor): Predicted interval center and span
of positive proposals.
labels (torch.Tensor): Groundtruth class label.
bbox_targets (torch.Tensor): Groundtruth center and span
of positive proposals.
regression_indexer (torch.Tensor): Index slices of
positive proposals.
Returns:
torch.Tensor: Returned class-wise regression loss.
"""
pred = bbox_pred[regression_indexer, :, :]
gt = labels[regression_indexer]
reg_target = bbox_targets[regression_indexer, :]
class_idx = gt.data - 1
classwise_pred = pred[:, class_idx, :]
classwise_reg_pred = torch.cat(
(torch.diag(classwise_pred[:, :, 0]).view(
-1, 1), torch.diag(classwise_pred[:, :, 1]).view(-1, 1)),
dim=1)
loss = F.smooth_l1_loss(
classwise_reg_pred.view(-1), reg_target.view(-1)) * 2
return loss
[docs] def forward(self, activity_score, completeness_score, bbox_pred,
proposal_type, labels, bbox_targets, train_cfg):
"""Calculate Boundary Matching Network Loss.
Args:
activity_score (torch.Tensor): Predicted activity score.
completeness_score (torch.Tensor): Predicted completeness score.
bbox_pred (torch.Tensor): Predicted interval center and span
of positive proposals.
proposal_type (torch.Tensor): Type index slices of proposals.
labels (torch.Tensor): Groundtruth class label.
bbox_targets (torch.Tensor): Groundtruth center and span
of positive proposals.
train_cfg (dict): Config for training.
Returns:
dict([torch.Tensor, torch.Tensor, torch.Tensor]):
(loss_activity, loss_completeness, loss_reg).
Loss_activity is the activity loss, loss_completeness is
the class-wise completeness loss,
loss_reg is the class-wise regression loss.
"""
self.sampler = train_cfg.ssn.sampler
self.loss_weight = train_cfg.ssn.loss_weight
losses = dict()
proposal_type = proposal_type.view(-1)
labels = labels.view(-1)
activity_indexer = ((proposal_type == 0) +
(proposal_type == 2)).nonzero().squeeze(1)
completeness_indexer = ((proposal_type == 0) +
(proposal_type == 1)).nonzero().squeeze(1)
total_ratio = (
self.sampler.positive_ratio + self.sampler.background_ratio +
self.sampler.incomplete_ratio)
positive_per_video = int(self.sampler.num_per_video *
(self.sampler.positive_ratio / total_ratio))
background_per_video = int(
self.sampler.num_per_video *
(self.sampler.background_ratio / total_ratio))
incomplete_per_video = (
self.sampler.num_per_video - positive_per_video -
background_per_video)
losses['loss_activity'] = self.activity_loss(activity_score, labels,
activity_indexer)
losses['loss_completeness'] = self.completeness_loss(
completeness_score,
labels,
completeness_indexer,
positive_per_video,
incomplete_per_video,
ohem_ratio=positive_per_video / incomplete_per_video)
losses['loss_completeness'] *= self.loss_weight.comp_loss_weight
if bbox_pred is not None:
regression_indexer = (proposal_type == 0).nonzero().squeeze(1)
bbox_targets = bbox_targets.view(-1, 2)
losses['loss_reg'] = self.classwise_regression_loss(
bbox_pred, labels, bbox_targets, regression_indexer)
losses['loss_reg'] *= self.loss_weight.reg_loss_weight
return losses