
# Source code for mmaction.models.losses.base

# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod

import torch.nn as nn

class BaseWeightedLoss(nn.Module, metaclass=ABCMeta):
    """Base class for loss.

    All subclasses should override the ``_forward()`` method, which returns
    the plain loss without the loss weight applied. ``forward()`` calls
    ``_forward()`` and scales its result by ``loss_weight``.

    Args:
        loss_weight (float): Factor scalar multiplied on the loss.
            Default: 1.0.
    """

    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    @abstractmethod
    def _forward(self, *args, **kwargs):
        """Compute the unweighted loss; must be implemented by subclasses."""

    def forward(self, *args, **kwargs):
        """Defines the computation performed at every call.

        Args:
            *args: The positional arguments for the corresponding loss.
            **kwargs: The keyword arguments for the corresponding loss.

        Returns:
            torch.Tensor | dict: The calculated loss multiplied by
            ``loss_weight``. If ``_forward()`` returns a dict, that same dict
            is returned with every entry whose key contains the substring
            ``'loss'`` scaled by ``loss_weight``; other entries are returned
            unchanged.
        """
        ret = self._forward(*args, **kwargs)
        # Scale out-of-place. An in-place ``*=`` would overwrite the tensor
        # produced by ``_forward()``, which (a) breaks autograd when that
        # tensor was saved for backward (e.g. the output of ``exp``) and
        # (b) silently mutates the caller's tensor if ``_forward()`` returns
        # one of its inputs as-is.
        if isinstance(ret, dict):
            for k in ret:
                # Only loss terms are weighted; metrics such as accuracy
                # stored alongside them in the dict are left untouched.
                if 'loss' in k:
                    ret[k] = ret[k] * self.loss_weight
        else:
            ret = ret * self.loss_weight
        return ret