from typing import Optional
import torch
import torch.nn.functional as F
from torch import Tensor, nn
class SLACELoss(nn.Module):
    """
    SLACE (Soft Labels Accumulating Cross Entropy) loss for ordinal regression.

    Implements the loss from :footcite:t:`nachmani2025slace`.

    Ordinal regression classifies objects to classes with a natural order,
    where the severity of prediction errors varies (e.g., classifying 'No
    Risk' as 'Critical Risk' is worse than 'High Risk').

    SLACE is an ordinality-aware loss designed to ensure the model's output
    is as close as possible to the correct class, considering the order of
    the labels. It provably satisfies two key properties for ordinal losses:
    **monotonicity** and **balance sensitivity**.

    The mechanism generates a smooth, ordinally-weighted target probability
    distribution (``softmax_targets``) and applies cross-entropy to an
    accumulated version of the model's predicted distribution
    (``accumulating_softmax``).

    Parameters
    ----------
    alpha : float
        Scaling factor controlling the 'smoothness' of the softmax target
        distribution. A higher alpha results in a sharper distribution.
    num_classes : int
        The total number of ordinal classes (C).
    weight : Optional[torch.Tensor], default=None
        Optional class weights of shape [num_classes] to handle class
        imbalance. Stored as a float32 buffer.
    use_logits : bool, default=True
        If True, assumes ``input`` contains logits and applies softmax
        internally. If False, assumes ``input`` is already probabilities.

    Attributes
    ----------
    prox_dom : torch.Tensor
        Precomputed ordinal dominance matrix of shape
        [num_classes, num_classes, num_classes], registered as a buffer.
        ``prox_dom[h, i, j]`` is 1.0 when class ``j`` is at least as close to
        class ``h`` as class ``i`` is (``|j - h| <= |i - h|``), else 0.0.
    weight : Optional[torch.Tensor]
        Per-class weights registered as a buffer, or None when not provided.
    """

    prox_dom: Tensor
    weight: Optional[Tensor]

    def __init__(
        self,
        alpha: float,
        num_classes: int,
        weight: Optional[torch.Tensor] = None,
        use_logits: bool = True,
    ):
        super().__init__()
        self.alpha = alpha
        self.num_classes = num_classes
        self.use_logits = use_logits

        # Always register ``weight`` as a buffer (possibly None) so it is
        # handled uniformly by ``.to()`` / ``state_dict()`` when present.
        self.register_buffer("weight", None if weight is None else weight.float())

        # Precompute prox_dom[h, i, j] = (|j - h| <= |i - h|) by broadcasting
        # three views of the class indices against each other.
        labels = torch.arange(self.num_classes)
        h = labels.view(-1, 1, 1)
        i = labels.view(1, -1, 1)
        j = labels.view(1, 1, -1)
        distance_i = torch.abs(i - h)
        distance_j = torch.abs(j - h)
        self.register_buffer("prox_dom", (distance_j <= distance_i).float())

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """
        Calculate the SLACE loss between the model's prediction and the
        ordinal target distribution.

        Parameters
        ----------
        input : torch.Tensor
            The model's output (logits or probabilities) with shape
            [Batch, num_classes].
        target : torch.Tensor
            The true ordinal labels with shape [Batch] or [Batch, 1]. Integer
            or floating-point label tensors are both accepted.

        Returns
        -------
        torch.Tensor
            The scalar mean value of the SLACE loss, computed in float64.
        """
        if self.use_logits:
            input = F.softmax(input, dim=1)

        # Flatten [Batch, 1] -> [Batch] and build the integer index once so
        # that every lookup (prox_dom and weight) uses the same long tensor.
        flat_target = target.reshape(-1).to(input.device)
        target_idx = flat_target.long()

        # phi[b, c] = |c - target_b|: ordinal distance of each class to the label.
        class_ids = torch.arange(self.num_classes, device=input.device).view(1, -1)
        phi = torch.abs(class_ids - flat_target.double().view(-1, 1))
        softmax_targets = F.softmax(-self.alpha * phi, dim=1)

        # NOTE: the paper's reference code builds ``mass_weights`` as
        # one_hot * softmax_targets + (1 - one_hot) * softmax_targets, which
        # simplifies algebraically to softmax_targets itself.
        mass_weights = softmax_targets

        # accumulating_softmax[b, i] = sum_j prox_dom[target_b, i, j] * p[b, j],
        # i.e. the predicted mass on classes at least as close to the target
        # as class i.
        accumulating_softmax = torch.matmul(
            self.prox_dom[target_idx].double(),
            input.unsqueeze(2).double(),
        ).squeeze(dim=2)

        # Small epsilon guards against log(0).
        per_sample_loss = -torch.sum(
            mass_weights * torch.log(accumulating_softmax + 1e-9), dim=1
        )  # [Batch]

        if self.weight is not None:
            sample_weights = self.weight[target_idx].to(input.device)  # [Batch]
            per_sample_loss = per_sample_loss * sample_weights  # [Batch]

        return per_sample_loss.mean()