from typing import Literal
import torch
from torch.nn import Module
from dlordinal.output_layers.utils import stable_sigmoid
class CLM(Module):
    """
    Implementation of the cumulative link models from :footcite:t:`vargas2020clm` as a
    torch layer. Different link functions can be used, including logit, probit
    and cloglog.

    The layer receives one scalar projection per sample and keeps
    ``num_classes - 1`` learnable, non-decreasing thresholds. The cumulative
    probability of class ``k`` is ``F(threshold_k - projection)``, where ``F`` is
    the chosen inverse link function. The per-class probabilities are the
    differences between consecutive cumulative probabilities; the cumulative
    probability of the last class is 1.

    Parameters
    ----------
    num_classes : int
        The number of classes. Must be at least 2.
    link_function : str
        The link function to use. Can be ``'logit'``, ``'probit'`` or ``'cloglog'``.
    min_distance : float, default=0.0
        The minimum distance between consecutive thresholds. Must be
        non-negative.
    **kwargs
        Additional keyword arguments; accepted and ignored.

    Attributes
    ----------
    num_classes : int
        The number of classes.
    link_function : str
        The link function to use. Can be ``'logit'``, ``'probit'`` or ``'cloglog'``.
    min_distance : float
        The minimum distance between consecutive thresholds.
    dist_ : torch.distributions.Normal
        The normal (0,1) distribution used to compute the probit link function.
    thresholds_b_ : torch.nn.Parameter
        The torch parameter for the first threshold.
    thresholds_a_ : torch.nn.Parameter
        The torch parameter for the alphas of the thresholds. The gap between
        consecutive thresholds is ``alpha**2 + min_distance``.

    Raises
    ------
    ValueError
        If ``num_classes`` is lower than 2, ``link_function`` is not one of the
        supported values, or ``min_distance`` is negative.

    Examples
    --------
    >>> import torch
    >>> from dlordinal.output_layers import CLM
    >>> inp = torch.randn(10, 5)
    >>> fc = torch.nn.Linear(5, 1)
    >>> clm = CLM(5, "logit")
    >>> output = clm(fc(inp))
    >>> output.shape
    torch.Size([10, 5])
    >>> bool(torch.allclose(output.sum(dim=1), torch.ones(10)))
    True
    """

    # Inverse link functions understood by ``_apply_link_function``.
    _LINK_FUNCTIONS = ("logit", "probit", "cloglog")

    def __init__(
        self,
        num_classes: int,
        link_function: Literal["logit", "probit", "cloglog"],
        min_distance: float = 0.0,
        **kwargs,
    ):
        super().__init__()

        # Two classes need one threshold; fewer classes make no sense and
        # would also make ``num_classes - 2`` (the number of alphas) negative.
        if num_classes < 2:
            raise ValueError(f"num_classes must be at least 2, got {num_classes}.")
        # Reject typos instead of silently falling back to the logit link.
        if link_function not in self._LINK_FUNCTIONS:
            raise ValueError(
                f"link_function must be one of {self._LINK_FUNCTIONS}, "
                f"got {link_function!r}."
            )
        # A negative minimum distance could make the thresholds decrease,
        # which would produce negative class "probabilities".
        if min_distance < 0:
            raise ValueError(
                f"min_distance must be non-negative, got {min_distance}."
            )

        self.num_classes = num_classes
        self.link_function = link_function
        self.min_distance = min_distance

        self.dist_ = torch.distributions.Normal(0, 1)

        # First threshold, learned directly (initialised at 0).
        self.thresholds_b_ = torch.nn.Parameter(
            data=torch.zeros(1), requires_grad=True
        )
        # One alpha per remaining threshold (initialised at 1); each gap
        # between consecutive thresholds is alpha**2 + min_distance.
        self.thresholds_a_ = torch.nn.Parameter(
            data=torch.ones(self.num_classes - 2),
            requires_grad=True,
        )

    def _convert_thresholds(self, b, a, min_distance):
        """
        Build the ordered thresholds from the learnable parameters.

        Parameters
        ----------
        b : torch.Tensor
            One-element tensor holding the first threshold.
        a : torch.Tensor
            Alphas; the gap between consecutive thresholds is
            ``a**2 + min_distance``.
        min_distance : float
            Minimum distance added to every gap.

        Returns
        -------
        torch.Tensor
            Tensor of ``len(b) + len(a)`` thresholds, non-decreasing whenever
            ``min_distance >= 0``.
        """
        gaps = a**2 + min_distance
        thresholds_param = torch.cat((b, gaps), dim=0)
        # Threshold k is b plus the first k gaps, i.e. a running sum.
        return torch.cumsum(thresholds_param, dim=0)

    def _compute_z3(self, projected: torch.Tensor, thresholds: torch.Tensor):
        """
        Compute ``thresholds[k] - projected[i]`` for every sample and threshold.

        Parameters
        ----------
        projected : torch.Tensor
            1-D tensor with one projection per sample.
        thresholds : torch.Tensor
            1-D tensor of thresholds.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(len(projected), len(thresholds))``.
        """
        # Broadcasting replaces the tile/reshape/transpose construction and
        # also works for an empty batch.
        return thresholds.unsqueeze(0) - projected.unsqueeze(1)

    def _apply_link_function(self, z3):
        """
        Apply the inverse link function element-wise to ``z3``.

        Returns
        -------
        torch.Tensor
            Cumulative probabilities with the same shape as ``z3``.
        """
        if self.link_function == "probit":
            return self.dist_.cdf(z3)
        if self.link_function == "cloglog":
            # 1 - exp(-u) == -expm1(-u); expm1 keeps precision when u is small.
            return -torch.expm1(-torch.exp(z3))
        # 'logit'
        return stable_sigmoid(z3)

    def _clm(self, projected: torch.Tensor, thresholds: torch.Tensor):
        """
        Turn projections and thresholds into per-class probabilities.

        Parameters
        ----------
        projected : torch.Tensor
            Projections; flattened so every element is one sample.
        thresholds : torch.Tensor
            1-D tensor of ``num_classes - 1`` thresholds.

        Returns
        -------
        torch.Tensor
            Tensor of shape ``(projected.numel(), num_classes)``.
        """
        projected = torch.reshape(projected, shape=(-1,))
        num_samples = projected.shape[0]

        z3 = self._compute_z3(projected, thresholds)
        cumulative = self._apply_link_function(z3)

        # The cumulative probability of the last class is always 1; build it
        # with the same dtype/device as the computed cumulative probabilities.
        last = cumulative.new_ones((num_samples, 1))
        a3 = torch.cat((cumulative, last), dim=1)
        # Per-class probability = difference of consecutive cumulatives.
        a3[:, 1:] = a3[:, 1:] - a3[:, :-1]
        return a3

    def forward(self, x):
        """
        Compute the class probabilities for a batch of projections.

        Parameters
        ----------
        x : torch.Tensor
            The input tensor with one projection per sample, e.g. of shape
            ``(N,)`` or ``(N, 1)``. It is flattened before use, so every
            element is treated as a separate sample.

        Returns
        -------
        output: Tensor
            Tensor of shape ``(x.numel(), num_classes)`` holding the
            probability of each class for each sample; each row sums to one.
        """
        thresholds = self._convert_thresholds(
            self.thresholds_b_, self.thresholds_a_, self.min_distance
        )
        return self._clm(x, thresholds)