Source code for dlordinal.output_layers.poisson_layer

from typing import Optional

import torch
from torch.distributions import Poisson


class PoissonLayer(torch.nn.Module):
    """
    Unimodal output layer for ordinal classification based on the Poisson
    distribution. Proposed by :footcite:t:`beckham2017unimodal`.

    Learns the λ parameter of the Poisson distribution from the input
    features and uses the Poisson distribution to compute the probabilities
    of each class, ensuring that the output is unimodal and that the
    probabilities sum to 1.

    The softplus of the linear layer output is used to ensure that the λ
    parameter is positive. Additionally, its value is clamped between 1e-8
    and 1e4 to prevent numerical issues.

    The layer includes an optional learnable temperature parameter τ that
    controls the sharpness of the output distribution. Higher values of τ
    produce softer distributions, while lower values produce sharper
    distributions. τ is initialised to 1. If learn_tau is set to False, τ is
    fixed at 1 (no scaling).

    Parameters
    ----------
    in_features : int
        Size of the input feature vector (output features from the previous
        layer).
    num_classes : int
        Number of discrete output classes. Defines support of the
        distribution as {0, ..., num_classes - 1}. Must be at least 1.
    learn_tau : bool, default=True
        If True, the temperature parameter τ is learned as a model parameter.
        Otherwise, it is stored as a fixed (non-trainable) buffer.

    Attributes
    ----------
    lambda_layer : torch.nn.Linear
        Linear transformation that maps input features to a scalar rate λ
        (before softplus and clamping).
    log_tau : torch.Tensor or torch.nn.Parameter
        Log-temperature used to control the sharpness of the distribution.
        Initialised to log(1) = 0.
    num_classes : int
        Number of output classes.
    learn_tau : bool
        Whether the temperature is learnable.

    Raises
    ------
    ValueError
        If ``num_classes`` is smaller than 1.

    Example
    -------
    >>> import torch
    >>> from dlordinal.output_layers import PoissonLayer
    >>> layer = PoissonLayer(in_features=5, num_classes=3, learn_tau=True)
    >>> input = torch.randn(2, 5)
    >>> probs = layer(input)
    >>> probs.shape
    torch.Size([2, 3])
    """

    log_tau: Optional[torch.Tensor]

    def __init__(self, *, in_features: int, num_classes: int, learn_tau: bool = True):
        super().__init__()
        if num_classes < 1:
            # An empty support would make forward() return an empty,
            # meaningless distribution; fail early instead.
            raise ValueError(f"num_classes must be at least 1, got {num_classes}")

        self.num_classes = num_classes
        self.learn_tau = learn_tau

        # τ starts at 1 (log τ = 0), i.e. no rescaling of the log-probabilities.
        tau_init = torch.tensor(1.0).log()
        if learn_tau:
            self.register_parameter("log_tau", torch.nn.Parameter(tau_init))
        else:
            # A buffer follows the module across .to()/state_dict but is not
            # updated by the optimiser.
            self.register_buffer("log_tau", tau_init)

        self.lambda_layer = torch.nn.Linear(in_features, 1)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """
        Compute class probabilities using a Poisson-based discrete distribution.

        Parameters
        ----------
        input : torch.Tensor, shape (..., in_features)
            Input feature tensor. The usual case is ``(batch_size, in_features)``;
            a single unbatched sample ``(in_features,)`` or extra leading
            dimensions are also accepted.

        Returns
        -------
        torch.Tensor, shape (..., num_classes)
            Probability distribution over the discrete classes along the last
            dimension (sums to 1).
        """
        # 1. Rate λ > 0. Softplus keeps it non-negative. The clamp keeps it
        #    strictly positive, so log λ stays finite, and bounded above.
        lambda_ = torch.nn.functional.softplus(self.lambda_layer(input).squeeze(-1))
        lambda_ = lambda_.clamp(min=1e-8, max=1e4)

        # 2. Poisson log-probabilities for every class k in {0, ..., K-1}.
        #    The trailing singleton axis added to the rate broadcasts against
        #    the (num_classes,) support, giving (..., num_classes) for any
        #    number of leading dimensions.
        k = torch.arange(self.num_classes, device=lambda_.device, dtype=lambda_.dtype)
        log_probs = Poisson(rate=lambda_.unsqueeze(-1)).log_prob(k)

        # 3. Temperature scaling (controls sharpness): τ > 1 flattens, τ < 1
        #    sharpens.
        tau = torch.exp(self.log_tau)
        scaled_log_probs = log_probs / tau

        # 4. Softmax renormalises over the truncated support {0, ..., K-1},
        #    because the Poisson mass beyond K-1 is discarded.
        return torch.softmax(scaled_log_probs, dim=-1)