# Source code for dlordinal.losses.wkloss

from typing import Optional

import torch
import torch.nn as nn


[docs] class WKLoss(nn.Module): """ Implements Weighted Kappa Loss, introduced by :footcite:t:`deLaTorre2018kappa` and modified by :footcite:t:`vargas2020clm`. Weighted Kappa is widely used in ordinal classification problems. In its original proposal, the loss values lie in :math:`[-\\infty, \\log 2]`, whereas in the version proposed by :footcite:t:`vargas2020clm` the range is :math:`[0, 2]`. Following the definition of :footcite:t:`vargas2020clm`, the loss is computed as follows: .. math:: \\mathcal{L}(X, \\mathbf{y}) = \\frac{\\sum\\limits_{i=1}^J \\sum\\limits_{j=1}^J \\omega_{i,j} \\sum\\limits_{k=1}^N q_{k,i} ~ p_{y_k,j}} {\\frac{1}{N}\\sum\\limits_{i=1}^J \\sum\\limits_{j=1}^J \\omega_{i,j} \\left( \\sum\\limits_{k=1}^N q_{k,i} \\right) \\left( \\sum\\limits_{k=1}^N p_{y_k, j} \\right)} where :math:`q_{k,j}` denotes the normalised predicted probability, computed as: .. math:: q_{k,j} = \\frac{\\text{P}(\\text{y} = j ~|~ \\mathbf{x}_k)} {\\sum\\limits_{i=1}^J \\text{P}(\\text{y} = i ~|~ \\mathbf{x}_k)}, :math:`p_{y_k,j}` is the :math:`j`-th element of the one-hot encoded true label for sample :math:`k`, and :math:`\\omega` is the penalisation matrix, defined either linearly or quadratically. Its elements are: - Linear: :math:`\\omega_{i,j} = \\frac{|i - j|}{J - 1}` - Quadratic: :math:`\\omega_{i,j} = \\frac{(i - j)^2}{(J - 1)^2}` When considering the original definition of Weighted Kappa, the loss can be defined as follows: .. math:: \\mathcal{L}(X, \\mathbf{y}) = \\log\\left( \\frac{\\sum\\limits_{i=1}^J \\sum\\limits_{j=1}^J \\omega_{i,j} \\sum\\limits_{k=1}^N q_{k,i} ~ p_{y_k,j}} {\\frac{1}{N}\\sum\\limits_{i=1}^J \\sum\\limits_{j=1}^J \\omega_{i,j} \\left( \\sum\\limits_{k=1}^N q_{k,i} \\right) \\left( \\sum\\limits_{k=1}^N p_{y_k, j} \\right)} \\right) The parameter `use_logarithm` can be set to `True` to use this version of the loss. The numerical instability caused by the logarithm is mitigated by adding a small value `epsilon` to the denominator. 
Parameters ---------- num_classes : int The number of unique classes in your dataset. penalization_type : str, default='quadratic' The penalization method for calculating the Kappa statistics. Valid options are ``['linear', 'quadratic']``. Defaults to 'quadratic'. epsilon : float, default=1e-10 Small value added to the denominator division by zero. weight : Optional[torch.Tensor], default=None Class weights to apply during loss computation. Should be a tensor of size `(num_classes,)`. If `None`, equal weight is given to all classes. use_logits : bool, default=False If `True`, the `input` is treated as logits. If `False`, `input` is treated as probabilities. The behavior of the `input` affects its expected format (logits vs. probabilities). use_logarithm : bool, default=False If `True`, the logarithm of the Weighted Kappa is computed, following the original definition by :footcite:t:`deLaTorre2018kappa`. Example ------- >>> import torch >>> from dlordinal.losses import WKLoss >>> num_classes = 5 >>> input = torch.randn(3, num_classes) # Predicted logits for 3 samples >>> target = torch.randint(0, num_classes, (3,)) # Ground truth class indices >>> loss_fn = WKLoss(num_classes) >>> loss = loss_fn(input, target) >>> print(loss) """ num_classes: int penalization_type: str weight: Optional[torch.Tensor] epsilon: float use_logits: bool def __init__( self, num_classes: int, penalization_type: str = "quadratic", weight: Optional[torch.Tensor] = None, epsilon: Optional[float] = 1e-10, use_logits=False, use_logarithm=False, ): super(WKLoss, self).__init__() self.num_classes = num_classes self.penalization_type = penalization_type self.epsilon = epsilon self.weight = weight self.use_logits = use_logits self.use_logarithm = use_logarithm self.first_forward_ = True def _initialize(self, input, target): # Define error weights matrix repeat_op = ( torch.arange(self.num_classes, device=input.device) .unsqueeze(1) .expand(self.num_classes, self.num_classes) ) if 
self.penalization_type == "linear": self.weights_ = torch.abs(repeat_op - repeat_op.T) / (self.num_classes - 1) elif self.penalization_type == "quadratic": self.weights_ = torch.square((repeat_op - repeat_op.T)) / ( (self.num_classes - 1) ** 2 ) else: raise ValueError( f"Invalid penalization_type '{self.penalization_type}'." " Expected one of ['linear', 'quadratic']." ) # Apply class weight if self.weight is not None: # Repeat weight num_classes times in columns tiled_weight = self.weight.repeat((self.num_classes, 1)).to(input.device) self.weights_ *= tiled_weight
[docs] def forward(self, input, target): """ Forward pass for the Weighted Kappa loss. This method computes the Weighted Kappa loss between the predicted and true labels. The loss is based on the weighted disagreement between predictions and true labels, normalised by the expected disagreement under independence. Parameters ---------- input : torch.Tensor The model predictions. Shape: ``(batch_size, num_classes)``. If ``use_logits=True``, these should be raw logits (unnormalised scores). If ``use_logits=False``, these should be probabilities (rows summing to 1). target : torch.Tensor Ground truth labels. Shape: - ``(batch_size,)`` if labels are class indices. - ``(batch_size, num_classes)`` if already one-hot encoded. The tensor will be converted to float internally. Returns ------- loss : torch.Tensor A scalar tensor representing the weighted disagreement between predictions and true labels, normalised by the expected disagreement. """ num_classes = self.num_classes # Convert to onehot if integer labels are provided if target.dim() == 1: y = torch.eye(num_classes).to(target.device) target = y[target] target = target.float() if self.first_forward_: if not self.use_logits and not torch.allclose( input.sum(dim=1), torch.tensor(1.0, device=input.device) ): raise ValueError( "When passing use_logits=False, the input" " should be probabilities, not logits." ) elif self.use_logits and torch.allclose( input.sum(dim=1), torch.tensor(1.0, device=input.device) ): raise ValueError( "When passing use_logits=True, the input" " should be logits, not probabilities." 
) self._initialize(input, target) self.first_forward_ = False if self.use_logits: input = torch.nn.functional.softmax(input, dim=1) hist_rater_a = torch.sum(input, 0) hist_rater_b = torch.sum(target, 0) conf_mat = torch.matmul(input.T, target) bsize = input.size(0) nom = torch.sum(self.weights_ * conf_mat) expected_probs = torch.matmul( torch.reshape(hist_rater_a, [num_classes, 1]), torch.reshape(hist_rater_b, [1, num_classes]), ) denom = torch.sum(self.weights_ * expected_probs / bsize) ret = nom / (denom + self.epsilon) if self.use_logarithm: return torch.log(ret + self.epsilon) return ret