Source code for dlordinal.losses.emd_loss

import torch
import torch.nn.functional as F
from torch import nn


[docs] class EMDLoss(nn.Module): """ Computes the squared Earth Mover's Distance (EMD) loss, also known as the Ranked Probability Score (RPS), for ordinal classification tasks. This implementation follows the formulation presented by :footcite:t:`houys16semd`. The squared EMD loss is equivalent to the RPS described in :footcite:t:`epstein1969scoring`. It serves as a proper scoring rule for ordinal outcomes, encouraging probabilistic predictions that are both accurate and calibrated. Errors farther from the true class are penalised more heavily, reflecting the ordinal structure of the target variable. Parameters ---------- num_classes : int The number of ordinal classes (denoted as J). Examples -------- >>> import torch >>> from dlordinal.losses import EMDLoss >>> loss_fn = EMDLoss(num_classes=5) >>> y_pred = torch.randn(8, 5) # Predicted logits >>> y_true = torch.tensor([0, 1, 2, 3, 4, 3, 1, 0]) # Class indices >>> loss = loss_fn(y_pred, y_true) """ def __init__(self, num_classes: int): super(EMDLoss, self).__init__() self.num_classes = num_classes
[docs] def forward(self, y_pred, y_true): """ Computes the squared Earth Mover's Distance (Ranked Probability Score) between predictions and targets. Parameters ---------- y_pred : torch.Tensor The model predictions. Shape: ``(batch_size, num_classes)``. y_true : torch.Tensor Ground truth labels. Shape: - ``(batch_size,)`` if labels are class indices. - ``(batch_size, num_classes)`` if already one-hot encoded. Returns ------- torch.Tensor A scalar tensor representing the mean squared EMD loss over the batch. """ # One-hot encode true labels if integer labels are provided if y_true.dim() == 1: y_true = F.one_hot(y_true, num_classes=self.num_classes) # Convert logits to probabilities y_pred_proba = torch.nn.functional.softmax(y_pred, dim=1) # Compute the CDFs pred_cdf = torch.cumsum(y_pred_proba, dim=1) true_cdf = torch.cumsum(y_true, dim=1) # Compute the squared EMD emd = torch.sum((pred_cdf - true_cdf) ** 2, dim=1) return emd.mean()