import torch
import torch.nn as nn
[docs]
class HybridDropoutContainer(nn.Module):
    """Wrapper that distributes batch targets to the ``HybridDropout`` layers of a model.

    ``HybridDropout`` layers need the labels of the current batch to compute
    their ordinal dropout probabilities. This container exposes
    ``set_targets`` to push those labels to every such layer inside the
    wrapped model. Its forward pass delegates to the wrapped model unchanged.

    Parameters
    ----------
    model : torch.nn.Module
        Model to be wrapped.
    """

    def __init__(self, model):
        super().__init__()
        # Registered as a submodule, so parameters/buffers of ``model`` are
        # visible through the container (``.parameters()``, ``.to()``, ...).
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x``.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor, passed to the wrapped model as-is.

        Returns
        -------
        torch.Tensor
            Whatever the wrapped model returns for ``x``.
        """
        wrapped_model = self.model
        return wrapped_model(x)

    def set_targets(self, targets):
        """Store the batch targets on every ``HybridDropout`` layer of the wrapped model.

        Call this before each training forward pass so the dropout layers see
        the labels that belong to the batch being processed.

        Parameters
        ----------
        targets : torch.Tensor
            Targets of the batch. Must be a 1D tensor of shape (``batch_size``,)
            containing integer class indices for each sample in the batch.
            One-hot or soft label tensors of shape (``batch_size``,
            ``num_classes``) are not supported.

        Example
        -------
        >>> import torch
        >>> import torch.nn as nn
        >>> from torchvision.models import resnet18
        >>> from dlordinal.dropout import HybridDropout, HybridDropoutContainer
        >>> num_classes = 5
        >>> model = resnet18(weights='IMAGENET1K_V1')
        >>> model.fc = nn.Sequential(
        ...     nn.Linear(model.fc.in_features, 256),
        ...     HybridDropout(),
        ...     nn.Linear(256, num_classes),
        ... )
        >>> model = HybridDropoutContainer(model)
        >>> batch_targets = torch.tensor([0, 3, 1, 4])
        >>> model.set_targets(batch_targets)
        """
        # ``modules()`` walks the wrapped model recursively (including itself),
        # so nested HybridDropout layers are reached as well.
        hybrid_layers = [
            layer
            for layer in self.model.modules()
            if isinstance(layer, HybridDropout)
        ]
        for layer in hybrid_layers:
            layer.batch_targets = targets
[docs]
class HybridDropout(nn.Module):
    """Hybrid dropout by :footcite:t:`berchez2024fusion`, which mixes a standard
    dropout with an ordinal dropout.

    The ordinal component uses the Pearson correlation between each neuron's
    activations and the target labels of the current batch. Correlations are
    rescaled from ``[-1, 1]`` to ``[0, 1]``, and a neuron is dropped with
    ordinal probability ``1 - scaled_correlation``. The final per-neuron drop
    probability is ``beta * ordinal_prob + (1 - beta) * p``.

    To use this module, you must wrap your model with the
    ``HybridDropoutContainer`` module. Call its ``set_targets()`` before every
    training forward pass.

    Parameters
    ----------
    p : float
        Probability of an element to be zeroed by the standard dropout
        component. Default: 0.5
    beta : float
        Weight of the ordinal dropout. Default: 0.1

    Attributes
    ----------
    batch_targets : torch.Tensor or None
        Targets of the current batch, 1D of shape (``batch_size``,). Set via
        ``HybridDropoutContainer.set_targets()``. Default: None

    Raises
    ------
    ValueError
        If ``p`` is not a probability.
    """

    def __init__(self, p: float = 0.5, beta: float = 0.1):
        super().__init__()
        if p < 0 or p > 1:
            raise ValueError("p must be a probability")
        self.p = p
        self.beta = beta
        # Filled in by HybridDropoutContainer.set_targets() before training steps.
        self.batch_targets = None

    @staticmethod
    def _feature_target_correlation(x, targets):
        """Pearson correlation between every column of ``x`` and ``targets``.

        This is a vectorised equivalent of calling ``torch.corrcoef`` once per
        column. It returns a 1D float32 tensor of length ``x.shape[1]`` on
        ``x``'s device, clipped to ``[-1, 1]``. Columns or targets with zero
        variance give NaN.
        """
        feats = x.detach().float()
        tgt = targets.detach().to(device=feats.device, dtype=feats.dtype)
        feats_centered = feats - feats.mean(dim=0, keepdim=True)
        tgt_centered = tgt - tgt.mean()
        # The normalisation factor (N or N - 1) cancels in the ratio, so plain
        # sums are enough.
        cov = (feats_centered * tgt_centered.unsqueeze(1)).sum(dim=0)
        norm = torch.sqrt(feats_centered.pow(2).sum(dim=0) * tgt_centered.pow(2).sum())
        return (cov / norm).clamp(-1.0, 1.0)

    def forward(self, x):
        """Apply hybrid dropout during training; identity in evaluation mode.

        In training mode, the Pearson correlation between every neuron's
        activations and the batch targets is computed. It is rescaled to
        ``[0, 1]`` and turned into an ordinal drop probability. That
        probability is blended with ``p`` through ``beta`` to sample one
        per-neuron mask, which is shared by all samples in the batch. Kept
        activations are rescaled by the inverse of the fraction of neurons
        actually kept.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (``batch_size``, ``num_features``).

        Raises
        ------
        ValueError
            If there are NaN values in the tensor.
        ValueError
            If ``x`` is not a 2D tensor.
        ValueError
            If the batch targets have not been set, are not 1D, or their length
            does not match the batch size of ``x``.

        Returns
        -------
        torch.Tensor
            Output tensor with the same shape as ``x``.
        """
        if not self.training:
            return x

        if torch.isnan(x).any():
            raise ValueError("Nan values in the tensor")
        if x.ndim != 2:
            raise ValueError(
                "HybridDropout expects a 2D input of shape"
                f" [batch_size, num_features] but got {x.ndim}D tensor"
            )

        targets = getattr(self, "batch_targets", None)
        if targets is None:
            raise ValueError(
                "Batch targets have not been set. Use"
                " HybridDropoutContainer.set_targets() to set the targets."
            )
        targets = torch.as_tensor(targets)
        if targets.ndim != 1:
            raise ValueError(
                "Targets must be a 1D tensor"
                f" but got {targets.ndim}D tensor."
                " If you are using one-hot encoding or soft labels"
                " (shape [batch_size, num_classes]),"
                " please convert them to class indices (shape [batch_size])"
            )
        if targets.shape[0] != x.shape[0]:
            raise ValueError(
                f"Number of targets ({targets.shape[0]}) does not match the"
                f" batch size of the input ({x.shape[0]})"
            )

        # Pearson correlation of each neuron with the targets, moved to the CPU
        # so the mask is sampled from the same (CPU) generator as before.
        correlations = self._feature_target_correlation(x, targets).cpu()
        # Scale correlations from [-1, 1] to [0, 1]. Undefined correlations
        # (NaN, e.g. a neuron constant over the batch) become 0.
        correlations = torch.nan_to_num((1 + correlations) / 2)
        ordinal_prob = 1 - correlations
        # First summand: ordinal dropout; second summand: standard dropout.
        probabilities = (self.beta * ordinal_prob) + ((1 - self.beta) * self.p)
        keep_mask = torch.empty(x.shape[1]).uniform_(0, 1) >= probabilities

        # Normalisation by the proportion of neurons actually kept.
        keep_prob = int(torch.count_nonzero(keep_mask)) / keep_mask.shape[0]
        scaling_factor = 1.0 / (keep_prob + 1e-9)
        # The (num_features,) mask broadcasts over the batch dimension.
        return x.mul(keep_mask.to(x.device)) * scaling_factor