Source code for mmaction.models.losses.ohem_hinge_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch


class OHEMHingeLoss(torch.autograd.Function):
    """This class is the core implementation for the completeness loss in
    the paper.

    It computes class-wise hinge loss and performs online hard example
    mining (OHEM).
    """
    @staticmethod
    def forward(ctx, pred, labels, is_positive, ohem_ratio, group_size):
        """Calculate OHEM hinge loss.

        Args:
            pred (torch.Tensor): Predicted completeness score.
            labels (torch.Tensor): Groundtruth class label (1-indexed).
            is_positive (int): Set to 1 when proposals are positive and
                set to -1 when proposals are incomplete.
            ohem_ratio (float): Ratio of hard examples to keep.
            group_size (int): Number of proposals sampled per video.

        Returns:
            torch.Tensor: Class-wise hinge loss after hard example mining.
        """
        num_samples = pred.size(0)
        if num_samples != len(labels):
            raise ValueError(f'Number of samples should be equal to that '
                             f'of labels, but got {num_samples} samples and '
                             f'{len(labels)} labels.')

        losses = torch.zeros(num_samples, device=pred.device)
        slopes = torch.zeros(num_samples, device=pred.device)
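        # Hinge loss on the ground-truth class score: positives
        # (is_positive=1) are penalized when the score falls below 1,
        # incomplete proposals (is_positive=-1) when it rises above -1.
        # ``slopes`` records the subgradient d(loss)/d(pred) per sample.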
        for i in range(num_samples):
            losses[i] = max(0, 1 - is_positive * pred[i, labels[i] - 1])
            slopes[i] = -is_positive if losses[i] != 0 else 0
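
        # Online hard example mining: reshape into per-video groups and
        # keep only the ``keep_length`` largest losses within each group.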
        losses = losses.view(-1, group_size).contiguous()
        sorted_losses, indices = torch.sort(losses, dim=1, descending=True)
        keep_length = int(group_size * ohem_ratio)
        loss = torch.zeros(1, device=pred.device)
        for i in range(losses.size(0)):
            loss += sorted_losses[i, :keep_length].sum()
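        # Cache everything ``backward`` needs: which samples were kept,
        # their subgradients, and the layout of ``pred``.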
        ctx.loss_index = indices[:, :keep_length]
        ctx.labels = labels
        ctx.slopes = slopes
        ctx.shape = pred.size()
        ctx.group_size = group_size
        ctx.num_groups = losses.size(0)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """Defines a formula for differentiating the operation with backward
        mode automatic differentiation."""
        labels = ctx.labels
        slopes = ctx.slopes

        grad_in = torch.zeros(ctx.shape, device=ctx.slopes.device)
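        # Gradient flows only to the hard examples kept in ``forward``,
        # and only at their ground-truth class entry; everything else
        # stays zero.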
        for group in range(ctx.num_groups):
            for idx in ctx.loss_index[group]:
                loc = idx + group * ctx.group_size
                grad_in[loc, labels[loc] - 1] = (
                    slopes[loc] * grad_output.data[0])
        return grad_in, None, None, None, None
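
A minimal usage sketch, not part of the module above: as a custom autograd.Function, OHEMHingeLoss is invoked through its .apply interface. All concrete values below (the shapes, the 0.5 OHEM ratio, the group size of 4) are illustrative assumptions.

# Illustrative sketch; values are assumptions, not from the module above.
num_videos, group_size, num_classes = 2, 4, 10
pred = torch.randn(num_videos * group_size, num_classes, requires_grad=True)
# Labels are 1-indexed, matching the ``labels[i] - 1`` lookup above.
labels = torch.randint(1, num_classes + 1, (num_videos * group_size, ))
# Positive proposals (is_positive=1); keep the hardest 50% of each group.
loss = OHEMHingeLoss.apply(pred, labels, 1, 0.5, group_size)
loss.backward()
assert pred.grad.shape == pred.shape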