Shortcuts

Source code for mmaction.models.losses.ohem_hinge_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch


class OHEMHingeLoss(torch.autograd.Function):
    """Core implementation of the completeness loss.

    It computes a class-wise hinge loss and performs online hard example
    mining (OHEM): within every group of ``group_size`` samples only the
    largest ``int(group_size * ohem_ratio)`` losses contribute to the result
    and receive a gradient.
    """

    @staticmethod
    def forward(ctx, pred, labels, is_positive, ohem_ratio, group_size):
        """Calculate OHEM hinge loss.

        Args:
            pred (torch.Tensor): Predicted completeness score, indexed as
                ``pred[sample, class]``.
            labels (torch.Tensor): Groundtruth class label. The score column
                used for sample ``i`` is ``labels[i] - 1``.
            is_positive (int): Set to 1 when proposals are positive and set to
                -1 when proposals are incomplete.
            ohem_ratio (float): Ratio of hard examples.
            group_size (int): Number of proposals sampled per video. The
                number of samples must be divisible by it, since the losses
                are reshaped to ``(-1, group_size)``.

        Returns:
            torch.Tensor: Returned class-wise hinge loss, a tensor of
            shape ``(1,)``.

        Raises:
            ValueError: If the number of samples differs from the number of
                labels.
        """
        num_samples = pred.size(0)
        if num_samples != len(labels):
            raise ValueError(f'Number of samples should be equal to that '
                             f'of labels, but got {num_samples} samples and '
                             f'{len(labels)} labels.')

        # Column of ``pred`` holding the score of each sample's own class.
        class_index = torch.as_tensor(labels, device=pred.device).long() - 1
        sample_index = torch.arange(num_samples, device=pred.device)

        # Hinge loss max(0, 1 - y * score) for all samples at once. Only
        # strictly positive margins are written, so the remaining entries
        # (including NaN margins) stay at zero.
        margins = 1 - is_positive * pred[sample_index, class_index]
        active = margins > 0
        losses = torch.zeros(num_samples, device=pred.device)
        losses[active] = margins[active].to(losses.dtype)

        # d(loss)/d(score) is -is_positive where the hinge is active, else 0.
        slopes = torch.zeros(num_samples, device=pred.device)
        slopes[losses != 0] = -is_positive

        # OHEM: keep only the hardest ``keep_length`` samples of each group.
        losses = losses.view(-1, group_size).contiguous()
        sorted_losses, indices = torch.sort(losses, dim=1, descending=True)
        keep_length = int(group_size * ohem_ratio)
        loss = sorted_losses[:, :keep_length].sum(dim=1).sum().reshape(1)

        # State needed by ``backward``.
        ctx.loss_index = indices[:, :keep_length]
        ctx.labels = labels
        ctx.class_index = class_index
        ctx.slopes = slopes
        ctx.shape = pred.size()
        ctx.dtype = pred.dtype
        ctx.group_size = group_size
        ctx.num_groups = losses.size(0)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """Compute the gradient of the loss with respect to ``pred``.

        Only the samples kept by OHEM in ``forward`` receive a (possibly
        zero) gradient, written at the column of their own class; every
        other entry is zero.

        Args:
            grad_output (torch.Tensor): Gradient flowing into the loss
                returned by ``forward``.

        Returns:
            tuple: The gradient for ``pred`` followed by ``None`` for each of
            the four non-differentiable arguments of ``forward``.
        """
        slopes = ctx.slopes
        grad_in = torch.zeros(
            ctx.shape, dtype=ctx.dtype, device=slopes.device)

        # ``loss_index`` holds positions inside each group; add the group
        # offset to recover the row of every kept sample in ``pred``.
        group_offsets = torch.arange(
            ctx.num_groups,
            device=ctx.loss_index.device).unsqueeze(1) * ctx.group_size
        rows = (ctx.loss_index + group_offsets).reshape(-1)
        cols = ctx.class_index[rows]

        # Each sample appears at most once in ``rows``, so a plain indexed
        # assignment is sufficient (no accumulation needed).
        grad_scale = grad_output.detach().reshape(-1)[0]
        grad_in[rows, cols] = (slopes[rows] * grad_scale).to(grad_in.dtype)

        return grad_in, None, None, None, None