# Source code for dlordinal.output_layers.copoc

from typing import Callable

import torch
from torch import Tensor
from torch.nn import Module


def _abs(x: Tensor) -> Tensor:
    """Default ``phi`` for :class:`COPOC`: element-wise :math:`\\phi(x)=|x|`."""
    # Module-level (not a lambda) so COPOC instances stay picklable.
    return torch.abs(x)


def _negative_abs(x: Tensor) -> Tensor:
    """Default ``psi`` for :class:`COPOC`: element-wise :math:`\\psi(x)=-|x|`."""
    # Module-level (not a lambda) so COPOC instances stay picklable.
    return -torch.abs(x)


class COPOC(Module):
    """Implements the Conformal Predictions for OC (COPOC) output layer(s) from
    :footcite:t:`dey2023conformal`, which enforce unimodality in the output
    probabilities in a non-parametric way.

    The first logit is kept as is. Every other logit is mapped through ``phi``
    to a non-negative increment. The cumulative sum of the resulting vector is
    therefore non-decreasing along the class dimension. Applying ``psi``, which
    rises and then falls, turns this sum into unimodal logits.

    Parameters
    ----------
    phi: Callable[[Tensor], Tensor]
        Transformation applied to all logits except the first one. It must
        return non-negative values for the output to be unimodal. Default is
        the absolute value function :math:`\\phi(x)=|x|`.
    psi: Callable[[Tensor], Tensor]
        Function applied element-wise to the cumulative sums. For unimodal
        output it must be non-decreasing up to some point and non-increasing
        after it. Symmetric functions that decrease in :math:`|x|` satisfy
        this. Default is the negative absolute value function
        :math:`\\psi(x)=-|x|`, which peaks at 0.

    Example
    -------
    >>> import torch
    >>> from dlordinal.output_layers import COPOC
    >>> inp = torch.randn(10, 5)
    >>> fc = torch.nn.Linear(5, 5)
    >>> copoc = COPOC()
    >>> output = torch.nn.functional.softmax(copoc(fc(inp)), dim=1)
    >>> output.shape
    torch.Size([10, 5])
    """

    def __init__(
        self,
        phi: Callable[[Tensor], Tensor] = _abs,
        psi: Callable[[Tensor], Tensor] = _negative_abs,
    ) -> None:
        super().__init__()
        self.phi = phi
        self.psi = psi

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x: torch.Tensor
            Input tensor of shape (batch_size, num_classes). The transformation
            is applied along dimension 1, so extra trailing dimensions
            (batch_size, num_classes, ...) are handled the same way.

        Returns
        -------
        logits : torch.Tensor
            Unimodal logits with the same shape as ``x``. Apply a softmax over
            dimension 1 to obtain unimodal class probabilities.
        """
        # Step 1: eta = f(x; theta) is the input tensor itself (raw logits).
        # Step 2: v_1 = eta_1 stays unconstrained, and v_k = phi(eta_k) >= 0
        # for k >= 2.
        head = x[:, :1]
        increments = self.phi(x[:, 1:])
        v = torch.cat([head, increments], dim=1)

        # Step 3: r_k = r_{k-1} + v_k with r_1 = v_1. r is non-decreasing
        # along the class dimension because every increment is non-negative.
        r = torch.cumsum(v, dim=1)

        # Step 4: z_k = psi(r_k). Since psi rises then falls, z is unimodal
        # over k.
        # Step 5 (left to the caller): p_k = softmax(z)_k.
        return self.psi(r)