Dropout methodologies
- class dlordinal.dropout.HybridDropout(p: float = 0.5, beta: float = 0.1)
Implements a hybrid dropout methodology by Bérchez-Moreno et al. [1] that mixes standard dropout with ordinal dropout. The ordinal dropout is based on the correlation between the activation values of each neuron and the target labels of the dataset.
To use this module, you must wrap your model with the HybridDropoutContainer module.
- Parameters:
p (float) – Probability of an element to be zeroed. Default: 0.5
beta (float) – Weight of the ordinal dropout. Default: 0.1
batch_targets (torch.Tensor) – Targets of the batch. Default: None
- Raises:
ValueError – If p is not a probability.
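The constructor contract above can be illustrated with a plain-Python class. This is a hypothetical stand-in: `HybridDropoutSketch` is not part of dlordinal. It only mirrors the documented parameters, the `batch_targets` default of `None`, and the `ValueError` for an invalid `p`. Treating the closed interval [0, 1] as valid is an assumption.

```python
class HybridDropoutSketch:
    """Hypothetical stand-in for the documented HybridDropout constructor."""

    def __init__(self, p: float = 0.5, beta: float = 0.1):
        # Assumption: any value in the closed interval [0, 1] counts as a probability.
        if not 0.0 <= p <= 1.0:
            raise ValueError(f"p must be a probability in [0, 1], got {p}")
        self.p = p          # probability of zeroing an element
        self.beta = beta    # weight of the ordinal dropout component
        # Targets stay unset until the wrapping container provides them.
        self.batch_targets = None


layer = HybridDropoutSketch()
print(layer.p, layer.beta, layer.batch_targets)  # 0.5 0.1 None

try:
    HybridDropoutSketch(p=1.5)
except ValueError as err:
    print(err)
```

With this design, an invalid `p` is rejected when the layer is built, which is earlier than the first forward pass.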
- forward(x)
Forward pass of the HybridDropout module. Dropout is applied only during training. The module first computes the correlation between the activation values of each neuron and the target labels. It then computes the ordinal probabilities and, from them, the dropout mask.
- Parameters:
x (torch.Tensor) – Input tensor
- Raises:
ValueError – If there are NaN values in the tensor.
ValueError – If the batch targets have not been set.
- Returns:
Output tensor
- Return type:
torch.Tensor
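The three steps described for `forward` (correlation, ordinal probabilities, mask) can be sketched in PyTorch. This is an illustration under stated assumptions and not the dlordinal implementation. The following choices were all made here for illustration:

- a Pearson correlation computed per neuron over the batch;
- an ordinal drop probability of `p * (1 - |r|)`, so that neurons whose activations track the ordinal label are dropped less;
- a linear mix `(1 - beta) * p + beta * q`;
- inverted-dropout rescaling.

The function name `hybrid_dropout_sketch` is hypothetical.

```python
import torch


def hybrid_dropout_sketch(x, targets, p=0.5, beta=0.1, training=True):
    """Hypothetical hybrid dropout for a (batch_size, num_features) input."""
    if not training:
        # Dropout is only applied during training.
        return x
    if targets is None:
        raise ValueError("The batch targets have not been set.")
    if torch.isnan(x).any():
        raise ValueError("The input tensor contains NaN values.")

    # 1) Pearson correlation between each neuron's activations and the labels.
    y = targets.to(x.dtype)
    xc = x - x.mean(dim=0, keepdim=True)          # (batch_size, num_features)
    yc = (y - y.mean()).unsqueeze(1)              # (batch_size, 1)
    cov = (xc * yc).sum(dim=0)                    # (num_features,)
    denom = xc.norm(dim=0) * yc.norm() + 1e-8     # epsilon avoids division by zero
    r = cov / denom

    # 2) Assumed ordinal drop probability: strongly correlated neurons are dropped less,
    #    then mixed with the standard dropout probability p using the weight beta.
    q = p * (1.0 - r.abs())
    drop_prob = (1.0 - beta) * p + beta * q

    # 3) Sample a Bernoulli keep-mask per element and rescale (inverted dropout).
    keep_prob = (1.0 - drop_prob).clamp(min=1e-8)
    mask = torch.bernoulli(keep_prob.expand_as(x).contiguous())
    return x * mask / keep_prob


torch.manual_seed(0)
x = torch.randn(8, 4)
targets = torch.tensor([0, 1, 2, 3, 0, 1, 2, 3])
out = hybrid_dropout_sketch(x, targets, p=0.5, beta=0.1)
print(out.shape)  # torch.Size([8, 4])
```

Under these assumptions, `beta = 0` reduces the sketch to standard dropout with rate `p`. Larger `beta` values protect label-correlated neurons more.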
- class dlordinal.dropout.HybridDropoutContainer(model)
Container for the HybridDropout module. This container is used to set the targets of the batch in the HybridDropout module.
- Parameters:
model (torch.nn.Module) – Model to be wrapped.
- forward(x)
Forward pass of the HybridDropoutContainer module.
- Parameters:
x (torch.Tensor) – Input tensor
- Returns:
Output tensor
- Return type:
torch.Tensor
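The container's role can be sketched with standard `torch.nn` tools. It wraps a model, forwards the input unchanged, and hands the batch targets to every dropout layer inside the model. The names `TargetAwareDropout` and `TargetContainerSketch` are hypothetical. The placeholder layer only records targets and applies ordinary dropout; it is not dlordinal's `HybridDropout`. Walking the wrapped model with `nn.Module.modules()` is an assumption about how the targets are propagated.

```python
import torch
from torch import nn


class TargetAwareDropout(nn.Dropout):
    """Hypothetical placeholder for a layer that needs the batch targets."""

    def __init__(self, p=0.5):
        super().__init__(p)
        self.batch_targets = None


class TargetContainerSketch(nn.Module):
    """Hypothetical container in the spirit of HybridDropoutContainer."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        # The container does not change the computation; it delegates to the model.
        return self.model(x)

    def set_targets(self, targets):
        # Hand the same targets to every target-aware layer, however deeply nested.
        for module in self.model.modules():
            if isinstance(module, TargetAwareDropout):
                module.batch_targets = targets


net = nn.Sequential(nn.Linear(10, 16), TargetAwareDropout(), nn.Linear(16, 3))
container = TargetContainerSketch(net)
container.set_targets(torch.tensor([0, 2, 1, 1]))
print(container(torch.randn(4, 10)).shape)  # torch.Size([4, 3])
```

Because the targets change every batch, a training loop built this way would call `set_targets` with the current batch's labels before each forward pass.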
- set_targets(targets)
Set the targets of the batch in the HybridDropout module.
- Parameters:
targets (torch.Tensor) – Targets of the batch. Must be a 1D tensor of shape (batch_size,) containing integer class indices for each sample in the batch. One-hot or soft label tensors of shape (batch_size, num_classes) are not supported.
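Since `set_targets` accepts only a 1D tensor of integer class indices, one-hot or soft labels must be converted first. Here is a minimal sketch. It assumes the class index is the position of the largest entry in each row. `to_class_indices` is a hypothetical helper, not part of dlordinal.

```python
import torch


def to_class_indices(targets: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return a 1D int64 tensor of class indices."""
    if targets.dim() == 2:
        # One-hot or soft labels of shape (batch_size, num_classes): take the argmax.
        targets = targets.argmax(dim=1)
    if targets.dim() != 1:
        raise ValueError(f"Expected shape (batch_size,), got {tuple(targets.shape)}")
    if targets.is_floating_point():
        raise ValueError("Expected integer class indices, got floating-point values")
    return targets.long()


one_hot = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
soft = torch.tensor([[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]])
print(to_class_indices(one_hot))  # tensor([1, 2])
print(to_class_indices(soft))     # tensor([0, 2])
```

Taking the argmax discards the label uncertainty carried by soft labels. This loss is acceptable here only because the documented input format is hard class indices.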
Example
>>> import torch
>>> from torch import nn
>>> from torchvision.models import resnet18
>>> from dlordinal.dropout import HybridDropout, HybridDropoutContainer
>>> model = resnet18(weights='IMAGENET1K_V1')
>>> model.fc = nn.Sequential(
...     nn.Linear(model.fc.in_features, 256),
...     HybridDropout(),
...     nn.Linear(256, num_classes),
... )
>>> batch_targets = torch.tensor([...])
>>> model = HybridDropoutContainer(model)
>>> model.set_targets(batch_targets)