Shortcuts

Source code for mmaction.models.losses.nll_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn.functional as F

from mmaction.registry import MODELS
from .base import BaseWeightedLoss


[docs]@MODELS.register_module() class NLLLoss(BaseWeightedLoss): """NLL Loss. It will calculate NLL loss given cls_score and label. """ def _forward(self, cls_score, label, **kwargs): """Forward function. Args: cls_score (torch.Tensor): The class score. label (torch.Tensor): The ground truth label. kwargs: Any keyword argument to be used to calculate nll loss. Returns: torch.Tensor: The returned nll loss. """ loss_cls = F.nll_loss(cls_score, label, **kwargs) return loss_cls