# Source code for mmaction.models.losses.base
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
import torch.nn as nn
class BaseWeightedLoss(nn.Module, metaclass=ABCMeta):
    """Base class for weighted losses.

    Subclasses should override the ``_forward()`` method, which returns the
    plain loss without the loss weight applied. ``forward()`` calls
    ``_forward()`` and multiplies the result by ``loss_weight``.

    Args:
        loss_weight (float): Scalar factor multiplied on the loss.
            Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    @abstractmethod
    def _forward(self, *args, **kwargs):
        """Compute the unweighted loss; must be implemented by subclasses."""

    def forward(self, *args, **kwargs):
        """Compute the loss via ``_forward()`` and apply ``loss_weight``.

        Args:
            *args: The positional arguments forwarded to ``_forward()``.
            **kwargs: The keyword arguments forwarded to ``_forward()``.

        Returns:
            The value returned by ``_forward()`` multiplied by
            ``loss_weight``. If ``_forward()`` returns a dict, the same dict
            is returned with every entry whose key contains ``'loss'``
            multiplied by ``loss_weight``; other entries are left unchanged.
        """
        ret = self._forward(*args, **kwargs)
        # The multiplication is deliberately out-of-place. An in-place ``*=``
        # would modify the object produced by ``_forward()``, which may alias
        # an input, be shared between several dict entries, or be a tensor
        # that autograd does not allow to be modified in place.
        if isinstance(ret, dict):
            for key in ret:
                if 'loss' in key:
                    # Rebinding an existing key does not resize the dict, so
                    # it is safe while iterating over it.
                    ret[key] = ret[key] * self.loss_weight
            return ret
        return ret * self.loss_weight