Source code for dlordinal.losses.beta_loss

from typing import Dict, List, Optional, Union

import torch
from deprecated.sphinx import deprecated
from torch import Tensor
from torch.nn import CrossEntropyLoss, Module

from dlordinal.soft_labelling import get_beta_soft_labels

from .custom_targets_loss import CustomTargetsLoss


class BetaLoss(CustomTargetsLoss):
    """
    Beta-regularized loss, as proposed in :footcite:t:`vargas2022unimodal`.

    A regularization term derived from the Beta distribution penalizes the
    distance between the predicted and the true class distributions. The
    class builds on `CustomTargetsLoss`, supplying it with soft labels
    generated from Beta distributions.

    Parameters
    ----------
    base_loss : torch.nn.Module
        The underlying loss. It must accept `y_true` as a probability
        distribution (e.g., soft labels or one-hot encoded labels). It is
        evaluated between the predicted logits (`y_pred`) and the adjusted
        target labels (`y_true`).

    num_classes : int
        Number of classes.

    params_set : str or dict[int, list], default="standard"
        Parameters of the beta distributions used to generate the soft
        labels. Either one of the keys of the ``_beta_params_sets``
        dictionary, or a dictionary shaped like the items of
        ``_beta_params_sets``: keys are the number of classes and values are
        a list of lists holding the beta parameters of each class. Each
        per-class list has three entries :math:`[p,q,a]`, where :math:`p`
        and :math:`q` are the shape parameters of the beta distribution and
        :math:`a` is the scaling parameter.

        Example: ``{3: [[1, 4, 1], [4, 4, 1], [4, 1, 1]]}`` for three classes
        with the parameters :math:`[1,4,1]`, :math:`[4,4,1]` and
        :math:`[4,1,1]` for each class respectively.

    eta : float, default=1.0
        Regularization parameter balancing the base loss against the
        regularization term. A value of 0 means no regularization, while a
        value of 1 means the Beta regularization term fully influences the
        target labels.

    Example
    -------
    >>> import torch
    >>> from dlordinal.losses import BetaLoss
    >>> from torch.nn import CrossEntropyLoss
    >>> num_classes = 5
    >>> base_loss = CrossEntropyLoss()
    >>> loss = BetaLoss(base_loss, num_classes)
    >>> input = torch.randn(3, num_classes)
    >>> target = torch.randint(0, num_classes, (3,))
    >>> output = loss(input, target)
    """

    def __init__(
        self,
        base_loss: Module,
        num_classes: int,
        params_set: Union[str, Dict[int, List]] = "standard",
        eta: float = 1.0,
    ):
        # The soft-label table depends only on the constructor arguments,
        # so it is computed once here rather than on every forward pass.
        soft_labels = get_beta_soft_labels(num_classes, params_set)
        beta_probs = torch.tensor(soft_labels).to(torch.float32)

        super().__init__(base_loss=base_loss, cls_probs=beta_probs, eta=eta)

    # Re-expose the parent's forward under this class.
    # NOTE(review): assumed to be a class-level alias; the collapsed listing
    # does not show its indentation -- confirm against the repository.
    forward = CustomTargetsLoss.forward
# TODO: remove in 3.0.0
@deprecated(
    version="2.4.0",
    reason="Use BetaLoss instead with CrossEntropyLoss as base_loss. Will be removed in 3.0.0.",
    category=DeprecationWarning,
)
class BetaCrossEntropyLoss(BetaLoss):
    """
    Deprecated shortcut for ``BetaLoss`` with a ``CrossEntropyLoss`` base loss.

    The constructor builds a ``torch.nn.CrossEntropyLoss`` from the
    cross-entropy arguments it receives and hands it to ``BetaLoss`` as
    ``base_loss``, together with ``num_classes``, ``params_set`` and ``eta``.

    Parameters
    ----------
    num_classes : int
        Number of classes. Forwarded to ``BetaLoss``.

    params_set : str or dict[int, list], default="standard"
        Beta-distribution parameter set. Forwarded unchanged to ``BetaLoss``,
        so it accepts the same values as ``BetaLoss.params_set``.

    eta : float, default=1.0
        Regularization parameter. Forwarded to ``BetaLoss``.

    weight : Optional[Tensor], default=None
        Forwarded to ``torch.nn.CrossEntropyLoss``.

    size_average : default=None
        Forwarded to ``torch.nn.CrossEntropyLoss``.

    ignore_index : int, default=-100
        Forwarded to ``torch.nn.CrossEntropyLoss``.

    reduce : default=None
        Forwarded to ``torch.nn.CrossEntropyLoss``.

    reduction : str, default="mean"
        Forwarded to ``torch.nn.CrossEntropyLoss``.
    """

    def __init__(
        self,
        num_classes: int,
        # Widened from ``str`` to match BetaLoss, which receives this value
        # unchanged and also accepts a dictionary of parameters.
        params_set: Union[str, Dict[int, List]] = "standard",
        eta: float = 1.0,
        weight: Optional[Tensor] = None,
        size_average=None,
        ignore_index: int = -100,
        reduce=None,
        reduction: str = "mean",
    ):
        # All cross-entropy options are passed through untouched; this class
        # adds no behaviour of its own beyond choosing the base loss.
        base_loss = CrossEntropyLoss(
            weight=weight,
            size_average=size_average,
            ignore_index=ignore_index,
            reduce=reduce,
            reduction=reduction,
        )

        super().__init__(
            base_loss=base_loss,
            num_classes=num_classes,
            params_set=params_set,
            eta=eta,
        )