Source code for dlordinal.losses.sord

from collections import Counter
from typing import Optional

import numpy as np
import torch
import torch.nn.functional as F
from torch import Tensor, nn


class SORDLoss(nn.Module):
    """
    Implements the SORD (Softmax-based Ordinal Regression Distribution) Loss from
    :footcite:t:`diaz2019soft`.

    SORD Loss generates a smooth, ordinally-weighted target distribution
    ('softmax_targets') and applies a soft-target cross-entropy to the model's
    prediction. The target distribution is based on the distance from the true
    target and can be further customized using proximity measures.

    This loss belongs to the family of ordinal losses designed to penalize
    errors based on the severity of the ordinal distance.

    Parameters
    ----------
    alpha : float
        Scaling factor controlling the 'smoothness' of the softmax target
        distribution. A higher alpha results in a sharper distribution.
    num_classes : int
        The total number of ordinal classes (C).
    train_targets : torch.Tensor
        The target labels from the training dataset, required to compute class
        counts and initialize the proximity matrix (prox_mat). Only read when
        ``prox`` is True.
    prox : bool, default=False
        If True, enables the use of class-frequency-based proximity matrices
        (prox_mat) instead of simple L1 distance.
    ftype : str, default="max"
        Defines the function used to convert the proximity matrix into the
        final penalty (phi). Only used if ``prox`` is True. Options include:
        "max", "norm_max", "log", "norm_log", "division", "norm_division".
    weight : Optional[torch.Tensor], default=None
        Optional class weights of shape [num_classes] to handle class
        imbalance.
    use_logits : bool, default=True
        If True, applies F.log_softmax to the input for numerical stability.
        If False, assumes input is probabilities and applies
        log(input + 1e-9).

    Attributes
    ----------
    prox_mat : torch.Tensor
        The precomputed proximity matrix based on training set class
        frequencies. Registered as a non-persistent buffer, and only when
        ``prox`` is True.
    norm_prox_mat : torch.Tensor
        The L1-normalized (along dim 0) version of ``prox_mat``. Registered as
        a non-persistent buffer, and only when ``prox`` is True.
    """

    # Penalty functions accepted by ``ftype`` when ``prox`` is True.
    _FTYPES = ("max", "norm_max", "log", "norm_log", "division", "norm_division")

    def __init__(
        self,
        alpha: float,
        num_classes: int,
        train_targets: Tensor,
        prox: bool = False,
        ftype: str = "max",
        weight: Optional[torch.Tensor] = None,
        use_logits: bool = True,
    ):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        self.prox = prox
        self.ftype = ftype
        self.use_logits = use_logits
        self.train_targets = train_targets
        self.weight = weight

        # Initialize proximity matrices if needed.
        if self.prox:
            self.class_counts_dict = self._create_classcounts_dict(train_targets)
            prox_mat = create_prox_mat(self.class_counts_dict, inv=False)
            # Register directly: assigning the attribute first would make
            # ``register_buffer`` reject the name, which is why the matrices
            # previously never became buffers. ``persistent=False`` keeps the
            # state_dict layout unchanged while still letting the matrices
            # follow ``module.to(...)``.
            self.register_buffer("prox_mat", prox_mat, persistent=False)
            self.register_buffer(
                "norm_prox_mat",
                F.normalize(prox_mat, p=1, dim=0),
                persistent=False,
            )

    def _create_classcounts_dict(self, targets):
        """
        Count how often each class index in ``range(num_classes)`` occurs.

        Parameters
        ----------
        targets : array-like or torch.Tensor
            Training labels. Tensors are detached and moved to CPU before the
            conversion to NumPy; any shape is flattened.

        Returns
        -------
        dict
            Mapping ``class index -> count``; classes that never occur map
            to 0.
        """
        if isinstance(targets, Tensor):
            # np.asarray cannot convert accelerator or grad-tracking tensors.
            targets = targets.detach().cpu()
        class_counts = Counter(np.asarray(targets).ravel().tolist())
        return {i: class_counts.get(i, 0) for i in range(self.num_classes)}

    def _proximity_penalty(self, target: Tensor, device) -> Tensor:
        """
        Build the penalty matrix ``phi`` from the proximity matrices.

        Parameters
        ----------
        target : torch.Tensor
            Class indices used to select rows of the proximity matrix.
        device : torch.device
            Device the proximity matrices are moved to before indexing.

        Returns
        -------
        torch.Tensor
            One penalty row per entry of ``target``.

        Raises
        ------
        ValueError
            If ``ftype`` is not one of the supported options.
        """
        if self.ftype not in self._FTYPES:
            raise ValueError(
                f"Unknown ftype {self.ftype!r}; expected one of {self._FTYPES}."
            )

        # Keep the matrices on the input's device (no-op when already there).
        self.prox_mat = self.prox_mat.to(device)
        self.norm_prox_mat = self.norm_prox_mat.to(device)

        normalized = self.ftype.startswith("norm_")
        matrix = self.norm_prox_mat if normalized else self.prox_mat
        kind = self.ftype[len("norm_"):] if normalized else self.ftype

        rows = matrix[target]
        if kind == "max":
            return torch.max(matrix) - rows
        if kind == "log":
            return -torch.log(rows)
        # kind == "division"
        return 1.0 / rows

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Calculates the SORD loss between the model's prediction and the
        ordinal target distribution.

        Parameters
        ----------
        input : torch.Tensor
            The model's output (logits or probabilities) with shape
            [Batch, C].
        target : torch.Tensor
            The true ordinal labels with shape [Batch].

        Returns
        -------
        torch.Tensor
            The scalar mean value of the SORD loss.

        Raises
        ------
        ValueError
            If ``prox`` is True and ``ftype`` is not a supported option.
        """
        if self.use_logits:
            input_logprob = F.log_softmax(input, dim=1)
        else:
            input_logprob = torch.log(input + 1e-9)

        if self.prox:
            phi = self._proximity_penalty(target, input.device)
        else:
            # Plain L1 distance between every class index and the true label.
            classes = torch.arange(self.num_classes, device=input.device)
            phi = torch.abs(classes.view(1, -1) - target.double().view(-1, 1))

        softmax_targets = F.softmax(-self.alpha * phi, dim=1)
        per_sample_loss = -torch.sum(softmax_targets * input_logprob, dim=1)  # [batch]

        # Class weight
        if self.weight is not None:
            # Move the weights first so indexing works when ``target`` lives
            # on a different device than the stored weights.
            sample_weights = self.weight.to(input.device)[target]  # [batch]
            per_sample_loss = per_sample_loss * sample_weights

        return per_sample_loss.mean()
def create_prox_mat(dist_dict, inv=True):
    """
    Creates a proximity matrix based on class frequency distributions.

    Entry ``[i, j]`` is derived from half the count of class ``i`` plus the
    counts of every class strictly after ``i`` up to and including ``j`` (or,
    below the diagonal, from ``j`` up to but excluding ``i``), divided by the
    total count. Rows and columns follow the labels in ascending integer
    order, so for labels ``0..C-1`` the position equals the label.

    Parameters
    ----------
    dist_dict : dict
        A dictionary containing class indices as keys and their
        counts/frequencies in the training set as values. Keys must be
        convertible with ``int`` (used for ordering).
    inv : bool, default=True
        If True, the matrix values are calculated as the inverse of the
        negative logarithm of the normalized cumulative count. If False, the
        values are the negative logarithm of the normalized cumulative count.

    Returns
    -------
    torch.Tensor
        The proximity matrix of shape [num_classes, num_classes]
        (float64).
    """
    labels = sorted(dist_dict, key=int)
    counts = [dist_dict[label] for label in labels]
    num_labels = len(labels)
    denominator = sum(dist_dict.values())
    prox_mat = np.zeros([num_labels, num_labels])

    def _to_proximity(numerator):
        neg_log = -np.log(numerator / denominator)
        return neg_log ** -1 if inv else neg_log

    for row in range(num_labels):
        half_own = counts[row] / 2
        prox_mat[row][row] = _to_proximity(half_own)

        # Walk away from the diagonal in each direction with a running sum,
        # adding counts in the same order a per-cell summation would, so the
        # values are unchanged while the work drops from cubic to quadratic.
        running = half_own
        for col in range(row + 1, num_labels):  # above the diagonal
            running += counts[col]
            prox_mat[row][col] = _to_proximity(running)

        running = half_own
        for col in range(row - 1, -1, -1):  # under the diagonal
            running += counts[col]
            prox_mat[row][col] = _to_proximity(running)

    return torch.tensor(prox_mat)