Source code for dlordinal.losses.ordinal_ecoc_distance_loss

from typing import Optional

import numpy as np
import torch
from torch import Tensor


class OrdinalECOCDistanceLoss(torch.nn.Module):
    """Ordinal ECOC distance loss from :footcite:t:`barbero2023error` for use
    with :class:`dlordinal.wrappers.OBDECOCModel`.

    Computes the squared-error loss between the output of the model (class
    threshold probabilities) and the ideal output vector of the target class.

    Parameters
    ----------
    num_classes : int
        Number of classes.
    weights : Optional[torch.Tensor]
        Optional weighting for each class. Must be of shape (num_classes,) if
        provided. Array-likes (lists, NumPy arrays) are converted to a
        ``float32`` tensor; tensors are used as given.

    Attributes
    ----------
    target_class : torch.Tensor
        A tensor of shape (num_classes, num_classes-1) containing the ideal
        output vectors for each class. Row ``k`` has ones in its first ``k``
        positions and zeros elsewhere.
    weights : Optional[torch.Tensor]
        A tensor of shape (num_classes,) containing the class-specific
        weights, or ``None`` when no weighting is used.

    Raises
    ------
    ValueError
        If ``weights`` is provided and its shape is not (num_classes,).
    """

    target_class: Tensor
    weights: Optional[Tensor]

    def __init__(self, num_classes: int, weights: Optional[Tensor] = None) -> None:
        super().__init__()

        # Start from all ones and zero every entry with column >= row, so
        # that row k keeps exactly k leading ones (row 0 is all zeros).
        ideal = np.ones((num_classes, num_classes - 1), dtype=np.float32)
        ideal[np.triu_indices(num_classes, 0, num_classes - 1)] = 0.0
        self.register_buffer(
            "target_class", torch.tensor(ideal, dtype=torch.float32)
        )

        if weights is not None:
            if not isinstance(weights, Tensor):
                # Accept lists / NumPy arrays instead of failing inside
                # register_buffer with an opaque TypeError.
                weights = torch.as_tensor(weights, dtype=torch.float32)
            if tuple(weights.shape) != (num_classes,):
                # A wrongly shaped weight tensor would silently broadcast in
                # forward() and yield a meaningless loss value.
                raise ValueError(
                    f"weights must have shape ({num_classes},), "
                    f"got {tuple(weights.shape)}"
                )

        # With weights the per-element errors are needed so that each sample
        # can be scaled individually; without weights a plain sum suffices.
        self.mse = torch.nn.MSELoss(reduction="sum" if weights is None else "none")

        # register_buffer accepts None, so ``self.weights`` is always defined
        # and follows the module across devices when it is a tensor.
        self.register_buffer("weights", weights)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Compute the loss for a batch.

        Parameters
        ----------
        input : torch.Tensor
            Predicted probabilities for each class threshold, with shape
            (batch_size, num_classes - 1).
        target : torch.Tensor
            Ground truth labels of shape (batch_size,). The labels are integer
            class indices in the range [0, num_classes-1].

        Returns
        -------
        loss : torch.Tensor
            A scalar tensor representing the computed loss. If `weights` is
            None, the loss is the sum of the squared errors between `input`
            and the ideal vector of each sample's class. If `weights` is
            provided, each sample's summed squared error is multiplied by the
            weight of its class before the batch is summed.
        """
        # Look up the ideal threshold vector of every sample's class.
        target_vector = self.target_class[target]

        if self.weights is None:
            return self.mse(input, target_vector)

        per_sample = self.mse(input, target_vector).sum(dim=1)
        return (per_sample * self.weights[target]).sum()