import torch
from torch.nn import Module
[docs]
class StickBreakingLayer(Module):
"""Base class to implement the stick breaking layer from :footcite:t:`liu2020unimodal`.
Parameters
----------
input_shape: int
Input shape, which refers to the number of neurons in the last fully connected layer
num_classes: int
Number of classes
"""
def __init__(self, input_shape: int, num_classes: int) -> None:
super().__init__()
self.fcn1 = torch.nn.Linear(input_shape, num_classes)
self.fcn2 = torch.nn.Sigmoid()
[docs]
def get_stick_logits(self, x: torch.Tensor):
"""
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
logits : torch.Tensor
Logits of the stick breaking layer
"""
# Clamps all elements in input into the range [ min, max ]. Letting min_value
# and max_value be min and max, respectively
x = torch.clamp(x, 0.1, 0.9)
comp = 1.0 - x
# cumprod is the cumulative product of the elements of the input tensor in
# the given dimension dim.
cumprod = torch.cumprod(comp, axis=1)
logits = torch.log(x * cumprod)
return logits
[docs]
def forward(self, x) -> torch.Tensor:
"""
Parameters
----------
x : torch.Tensor
Input tensor
Returns
-------
logits : torch.Tensor
Logits of the stick breaking layer
"""
x = self.fcn1(x)
x = self.fcn2(x)
logits = self.get_stick_logits(x)
return logits