Shortcuts

Source code for mmaction.models.losses.cross_entropy_loss

# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from mmaction.registry import MODELS
from .base import BaseWeightedLoss


@MODELS.register_module()
class CrossEntropyLoss(BaseWeightedLoss):
    """Cross entropy loss supporting both hard and soft labels.

    Which kind of label is used is decided purely by shape: when ``label``
    has exactly the same size as ``cls_score`` it is treated as a soft
    label, otherwise as a hard label.

    1) Hard label: an integer array whose elements lie in
       ``[0, num_classes - 1]``. Its shape is ``cls_score``'s shape with the
       ``num_classes`` dimension removed. The computation is delegated to
       :func:`torch.nn.functional.cross_entropy`.
    2) Soft label (probability distribution over classes): every element
       lies in ``[0, 1]`` and the shape must equal ``cls_score``'s shape.
       Only 2-dim soft labels are supported for now.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Defaults to 1.0.
        class_weight (list[float] | None): Per-class loss weights. ``None``
            means every class gets weight 1. Only applies to
            CrossEntropyLoss and BCELossWithLogits (should not be set when
            using other losses). Defaults to None.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 class_weight: Optional[List[float]] = None) -> None:
        super().__init__(loss_weight=loss_weight)
        # Kept as a plain tensor attribute; it is moved onto the input's
        # device inside ``_forward`` rather than at construction time.
        self.class_weight = (torch.Tensor(class_weight)
                             if class_weight is not None else None)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Compute the cross entropy loss for hard or soft labels.

        Args:
            cls_score (torch.Tensor): The class score.
            label (torch.Tensor): The ground truth label. Treated as a soft
                label when its size equals ``cls_score``'s size, otherwise
                as hard class indices.
            kwargs: Extra keyword arguments forwarded to
                ``F.cross_entropy`` on the hard-label path. Must be empty on
                the soft-label path.

        Returns:
            torch.Tensor: The returned CrossEntropy loss.
        """
        if cls_score.size() != label.size():
            # Hard label: let PyTorch do the work, injecting class weights.
            if self.class_weight is not None:
                assert 'weight' not in kwargs, \
                    "The key 'weight' already exists."
                kwargs['weight'] = self.class_weight.to(cls_score.device)
            return F.cross_entropy(cls_score, label, **kwargs)

        # Soft label: explicit -sum(p * log_softmax(x)) per sample.
        assert cls_score.dim() == 2, 'Only support 2-dim soft label'
        assert not kwargs, \
            ('For now, no extra args are supported for soft label, '
             f'but get {kwargs}')
        log_probs = F.log_softmax(cls_score, 1)

        if self.class_weight is None:
            # Default reduction: plain mean over the batch.
            return -(label * log_probs).sum(1).mean()

        # Cache the device-moved weights so later calls skip the transfer.
        self.class_weight = self.class_weight.to(cls_score.device)
        row_weight = self.class_weight.unsqueeze(0)
        per_sample = -(label * (log_probs * row_weight)).sum(1)
        # Use weighted average as pytorch CrossEntropyLoss does.
        # For more information, please visit https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html # noqa
        return per_sample.sum() / torch.sum(row_weight * label)
@MODELS.register_module()
class BCELossWithLogits(BaseWeightedLoss):
    """Binary cross entropy loss computed directly on logits.

    Thin wrapper around
    :func:`torch.nn.functional.binary_cross_entropy_with_logits` that can
    inject a fixed per-class ``weight``.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Defaults to 1.0.
        class_weight (list[float] | None): Per-class loss weights. ``None``
            means every class gets weight 1. Only applies to
            CrossEntropyLoss and BCELossWithLogits (should not be set when
            using other losses). Defaults to None.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 class_weight: Optional[List[float]] = None) -> None:
        super().__init__(loss_weight=loss_weight)
        self.class_weight = (torch.Tensor(class_weight)
                             if class_weight is not None else None)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Compute BCE-with-logits between ``cls_score`` and ``label``.

        Args:
            cls_score (torch.Tensor): The class score (logits).
            label (torch.Tensor): The ground truth label.
            kwargs: Extra keyword arguments forwarded to
                ``F.binary_cross_entropy_with_logits``. Must not contain
                ``weight`` when ``class_weight`` was configured.

        Returns:
            torch.Tensor: The returned bce loss with logits.
        """
        has_class_weight = self.class_weight is not None
        if has_class_weight:
            assert 'weight' not in kwargs, "The key 'weight' already exists."
            kwargs.update(weight=self.class_weight.to(cls_score.device))
        return F.binary_cross_entropy_with_logits(cls_score, label, **kwargs)
@MODELS.register_module()
class CBFocalLoss(BaseWeightedLoss):
    """Class Balanced Focal Loss.

    Adapted from https://github.com/abhinanda-punnakkal/BABEL/. This loss is
    used in the skeleton-based action recognition baseline for BABEL.

    Each class ``c`` with ``n_c`` training samples gets the weight
    ``(1 - beta) / (1 - beta ** n_c)``; the weights are then rescaled so
    that they sum to the number of classes.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Defaults to 1.0.
        samples_per_cls (list[int] | None): The number of samples per class.
            ``None`` is treated as an empty list. Defaults to None.
        beta (float): Hyperparameter that controls the per class loss
            weight. Defaults to 0.9999.
        gamma (float): Hyperparameter of the focal loss. A falsy value
            (e.g. 0) disables the focal modulation. Defaults to 2.0.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 samples_per_cls: Optional[List[int]] = None,
                 beta: float = 0.9999,
                 gamma: float = 2.) -> None:
        super().__init__(loss_weight=loss_weight)
        # ``None`` instead of ``[]`` avoids a shared mutable default.
        if samples_per_cls is None:
            samples_per_cls = []
        self.samples_per_cls = samples_per_cls
        self.beta = beta
        self.gamma = gamma

        effective_num = 1.0 - np.power(beta, samples_per_cls)
        weights = (1.0 - beta) / np.array(effective_num)
        # Rescale so the weights sum to the number of classes.
        weights = weights / np.sum(weights) * len(weights)
        self.weights = weights
        # NOTE(review): an empty ``samples_per_cls`` gives 0 classes, so
        # ``_forward`` cannot one-hot encode labels; a count of 0 for any
        # class divides by zero above. Callers must pass real counts.
        self.num_classes = len(weights)

    def _forward(self, cls_score: torch.Tensor, label: torch.Tensor,
                 **kwargs) -> torch.Tensor:
        """Compute the class-balanced focal loss.

        Args:
            cls_score (torch.Tensor): The class score (logits). Must match
                the shape of ``one_hot(label, num_classes)``.
            label (torch.Tensor): The ground truth class indices. The
                per-sample weight gathering below assumes a 1-D tensor.
            kwargs: Unused; accepted for interface compatibility.

        Returns:
            torch.Tensor: Scalar focal loss, summed over all entries and
            divided by the number of samples.
        """
        label_one_hot = F.one_hot(label, self.num_classes).float()

        # Each sample is weighted by the class-balanced weight of its
        # ground-truth class; the (N, 1) column broadcasts over classes.
        class_weights = torch.tensor(
            self.weights, dtype=torch.float32, device=cls_score.device)
        sample_weights = class_weights[label].unsqueeze(1)

        bce_loss = F.binary_cross_entropy_with_logits(
            input=cls_score, target=label_one_hot, reduction='none')

        modulator = 1.0
        if self.gamma:
            # (1 - p_t) ** gamma evaluated in log space. softplus(-x) equals
            # log(1 + exp(-x)) but does not overflow to inf for strongly
            # negative logits (which would wrongly zero the modulator for
            # misclassified positives).
            modulator = torch.exp(
                -self.gamma * label_one_hot * cls_score -
                self.gamma * F.softplus(-cls_score))

        weighted_loss = sample_weights * (modulator * bce_loss)
        # Each one-hot row sums to 1, so this divides by the sample count.
        return torch.sum(weighted_loss) / torch.sum(label_one_hot)