Source code for mmaction.models.losses.hvu_loss
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
from mmengine.device import get_device
from mmaction.registry import MODELS
from .base import BaseWeightedLoss
@MODELS.register_module()
class HVULoss(BaseWeightedLoss):
"""Calculate the BCELoss for HVU.
Args:
        categories (tuple[str]): Names of tag categories; tags are organized
            in this order. Default: ('action', 'attribute', 'concept',
            'event', 'object', 'scene').
category_nums (tuple[int]): Number of tags for each category. Default:
(739, 117, 291, 69, 1678, 248).
        category_loss_weights (tuple[float]): Loss weights of categories,
            used only if `loss_type == 'individual'`. The weights are
            normalized so that they sum to 1, so any positive numbers can
            be given. Default: (1, 1, 1, 1, 1, 1).
        loss_type (str): The loss type to calculate: either one BCELoss over
            all tags, or one BCELoss for the tags of each category.
            Choices are 'individual' or 'all'. Default: 'all'.
        with_mask (bool): Some tag categories are missing in the labels of
            some video clips. If `with_mask == True`, no loss is calculated
            for these missing categories; otherwise, their tags are treated
            as negative samples. Default: False.
        reduction (str): Reduction method. Choices are 'mean' or 'sum'.
            Default: 'mean'.
loss_weight (float): The loss weight. Default: 1.0.
"""
def __init__(self,
categories=('action', 'attribute', 'concept', 'event',
'object', 'scene'),
category_nums=(739, 117, 291, 69, 1678, 248),
category_loss_weights=(1, 1, 1, 1, 1, 1),
loss_type='all',
with_mask=False,
reduction='mean',
loss_weight=1.0):
super().__init__(loss_weight)
self.categories = categories
self.category_nums = category_nums
self.category_loss_weights = category_loss_weights
assert len(self.category_nums) == len(self.category_loss_weights)
for category_loss_weight in self.category_loss_weights:
assert category_loss_weight >= 0
self.loss_type = loss_type
self.with_mask = with_mask
self.reduction = reduction
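        # Precompute the start index of each category in the flattened tag
        # vector: with the default `category_nums`, 'action' tags occupy
        # indices [0, 739), 'attribute' tags start at index 739, 'concept'
        # tags at 856, and so on.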
self.category_startidx = [0]
for i in range(len(self.category_nums) - 1):
self.category_startidx.append(self.category_startidx[-1] +
self.category_nums[i])
assert self.loss_type in ['individual', 'all']
assert self.reduction in ['mean', 'sum']
def _forward(self, cls_score, label, mask, category_mask):
"""Forward function.
Args:
cls_score (torch.Tensor): The class score.
label (torch.Tensor): The ground truth label.
mask (torch.Tensor): The mask of tags. 0 indicates that the
category of this tag is missing in the label of the video.
            category_mask (torch.Tensor): The category mask. For each sample,
                it is a tensor of length `len(self.categories)` indicating
                whether each category is labeled for this video.

        Returns:
            dict: The calculated loss. It always contains 'loss_cls'; when
                `loss_type == 'individual'`, the per-category losses and
                their weights are included as well.
"""
if self.loss_type == 'all':
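            # A single BCE over the whole flattened tag vector; `loss_cls`
            # has shape (N, sum(category_nums)) before reduction.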
loss_cls = F.binary_cross_entropy_with_logits(
cls_score, label, reduction='none')
if self.with_mask:
w_loss_cls = mask * loss_cls
w_loss_cls = torch.sum(w_loss_cls, dim=1)
if self.reduction == 'mean':
w_loss_cls = w_loss_cls / torch.sum(mask, dim=1)
w_loss_cls = torch.mean(w_loss_cls)
return dict(loss_cls=w_loss_cls)
if self.reduction == 'sum':
loss_cls = torch.sum(loss_cls, dim=-1)
return dict(loss_cls=torch.mean(loss_cls))
if self.loss_type == 'individual':
losses = {}
loss_weights = {}
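            # Compute a separate BCE loss on the slice of scores and labels
            # that belongs to each tag category.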
for name, num, start_idx in zip(self.categories,
self.category_nums,
self.category_startidx):
category_score = cls_score[:, start_idx:start_idx + num]
category_label = label[:, start_idx:start_idx + num]
category_loss = F.binary_cross_entropy_with_logits(
category_score, category_label, reduction='none')
if self.reduction == 'mean':
category_loss = torch.mean(category_loss, dim=1)
elif self.reduction == 'sum':
category_loss = torch.sum(category_loss, dim=1)
idx = self.categories.index(name)
if self.with_mask:
category_mask_i = category_mask[:, idx].reshape(-1)
                    # There should be at least one sample in the batch that
                    # has labels for this category; if not, record a zero
                    # loss with zero weight and skip it.
if torch.sum(category_mask_i) < 0.5:
losses[f'{name}_LOSS'] = torch.tensor(
.0, device=get_device())
loss_weights[f'{name}_LOSS'] = .0
continue
category_loss = torch.sum(category_loss * category_mask_i)
category_loss = category_loss / torch.sum(category_mask_i)
else:
category_loss = torch.mean(category_loss)
                # We name the loss of each category as 'LOSS' since we only
                # want to monitor them, not backpropagate through them. The
                # loss used for backward is also provided in the losses
                # dictionary as 'loss_cls'.
losses[f'{name}_LOSS'] = category_loss
loss_weights[f'{name}_LOSS'] = self.category_loss_weights[idx]
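            # Normalize the per-category weights so that they sum to 1 and
            # combine the category losses into the final `loss_cls`.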
loss_weight_sum = sum(loss_weights.values())
loss_weights = {
k: v / loss_weight_sum
for k, v in loss_weights.items()
}
loss_cls = sum([losses[k] * loss_weights[k] for k in losses])
losses['loss_cls'] = loss_cls
# We also trace the loss weights
losses.update({
k + '_weight': torch.tensor(v).to(losses[k].device)
for k, v in loss_weights.items()
})
# Note that the loss weights are just for reference.
return losses
else:
raise ValueError("loss_type should be 'all' or 'individual', "
f'but got {self.loss_type}')
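

# The block below is an illustrative usage sketch, not part of the original
# module. It assumes that `BaseWeightedLoss.forward` forwards its arguments
# to `_forward` and scales the resulting loss by `loss_weight`. In configs
# the loss is normally referenced through the registry as `type='HVULoss'`;
# the tiny category sizes here are made up to keep the tensors small.
if __name__ == '__main__':
    loss_fn = HVULoss(
        categories=('action', 'scene'),
        category_nums=(3, 4),
        category_loss_weights=(1, 1),
        loss_type='individual',
        with_mask=True)
    num_tags = sum((3, 4))
    cls_score = torch.randn(2, num_tags)  # raw logits for 2 samples
    label = torch.randint(0, 2, (2, num_tags)).float()  # multi-label targets
    mask = torch.ones(2, num_tags)  # every tag is labeled
    category_mask = torch.ones(2, 2)  # both categories labeled for each sample
    losses = loss_fn(cls_score, label, mask, category_mask)
    print({k: float(v) for k, v in losses.items()})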