import torch
from torch import nn
class CORNLoss(nn.Module):
    """Rank-consistent ordinal regression (CORN) loss from :footcite:t:`shi2023corn`.

    See the reference implementation `here <https://github.com/Raschka-research-group/coral-pytorch/blob/313482f86f50b58d8beb9fb54652e943b06745ef/coral_pytorch/losses.py#L87-L153>`__.

    Parameters
    ----------
    num_classes : int
        The number of classes (J).

    Note
    ----
    CORN loss expects the output of your network to be of dimension J-1 because class 0
    is predicted implicitly based on the probabilities of subsequent classes.
    CORN loss does not support probabilistic targets.

    Example
    -------
    >>> import torch
    >>> from dlordinal.losses import CORNLoss
    >>> NUM_CLASSES = 5
    >>> loss_fn = CORNLoss(num_classes=NUM_CLASSES)
    >>> y_pred = torch.randn(3, NUM_CLASSES - 1)
    >>> y_true = torch.tensor([0, 3, 1])
    >>> loss = loss_fn(y_pred, y_true)
    >>> print(loss)
    """

    def __init__(self, num_classes: int):
        super().__init__()
        self.num_classes = num_classes
        self.log_sigmoid = torch.nn.LogSigmoid()

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the CORN loss between predicted logits and true labels.

        Parameters
        ----------
        y_pred : torch.Tensor
            A tensor of shape (N, J - 1) containing predicted logits, where N is the
            batch size and J is the number of classes. Column ``i`` is the logit of
            the binary task "label > i", trained only on samples with
            "label > i - 1". These are raw network outputs, before any sigmoid is
            applied. Columns beyond J - 1, if any, are ignored.
        y_true : torch.Tensor
            A tensor of shape (N,) with integer class indices (for categorical targets).

        Returns
        -------
        torch.Tensor
            A scalar tensor: the summed binary cross-entropy over every
            (sample, task) pair that takes part in training, divided by the number
            of such pairs. If no pair takes part (e.g. an empty batch, or
            ``num_classes == 1``), a zero tensor is returned instead of dividing
            by zero.
        """
        num_tasks = max(self.num_classes - 1, 0)
        thresholds = torch.arange(num_tasks, device=y_true.device)
        labels = y_true.unsqueeze(1)  # (N, 1), broadcast against the task thresholds

        # Task i only sees samples whose label exceeds i - 1 (the conditional
        # training subset); its binary target is whether the label exceeds i.
        mask = labels > thresholds - 1
        targets = (labels > thresholds).to(y_pred.dtype)

        logits = y_pred[:, :num_tasks]
        log_sig = self.log_sigmoid(logits)
        # log(1 - sigmoid(x)) == log_sigmoid(x) - x, evaluated stably.
        per_task = log_sig * targets + (log_sig - logits) * (1 - targets)

        # Select rather than multiply by the mask so values at excluded
        # positions never enter the loss value.
        total = -per_task[mask].sum()
        # clamp avoids 0/0 when nothing is selected; total is then 0 anyway.
        num_examples = mask.sum().clamp(min=1)
        return total / num_examples