# virtex/modules/label_smoothing.py
import torch
from torch import nn
from torch.nn import functional as F


class CrossEntropyLossWithLabelSmoothing(nn.Module):
r"""
PyTorch :class:`~torch.nn.CrossEntropyLoss` with label smoothing. Quoting
documentation from original PyTorch module:
It is useful when training a classification problem with ``C`` classes.
The ``inputs`` is expected to contain raw, unnormalized scores for each class.
``inputs`` has to be a Tensor of size either ``(N, C)``. This criterion
expects a class index in the range ``[0, C - 1]`` as the ``targets`` for each
value of a 1D tensor of size ``minibatch``; if ``ignore_index`` is specified,
this criterion also accepts this class index (this index may not necessarily be
in the class range).
Parameters
----------
smoothing: float, optional (default = 0.1)
Label smoothing value. It sets target weights as ``(1 - smoothing)``
and all other weights as ``smoothing / (C - 1)``. Setting this to
zero will default to vanilla cross entropy loss.
"""

    def __init__(self, smoothing: float = 0.0, ignore_index: int = -100):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if self.smoothing == 0.0:
            # Use PyTorch cross entropy when smoothing is 0. This is slightly
            # faster than the manual computation below.
            return F.cross_entropy(
                inputs, targets, ignore_index=self.ignore_index, reduction="mean"
            )

        # Remove entries matching ``ignore_index`` so they do not contribute
        # to the loss (this also covers the default ``ignore_index = -100``).
        keep = targets != self.ignore_index
        _targets = targets[keep]
        _inputs = inputs[keep]

        # shape: (batch_size, num_classes)
        logprobs = F.log_softmax(_inputs, dim=-1)

        # Smoothed target distribution: weight ``1 - smoothing`` on the true
        # class and ``smoothing / (C - 1)`` spread over the other classes.
        # shape: (batch_size, num_classes)
        weights = torch.ones_like(_inputs) * self.smoothing / (_inputs.size(-1) - 1.0)
        weights.scatter_(-1, _targets.unsqueeze(-1), 1.0 - self.smoothing)

        # Cross entropy against the smoothed target distribution.
        # shape: (batch_size, )
        loss = (-weights * logprobs).sum(dim=-1)
        return loss.mean()
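

# A minimal usage sketch, not part of the original module: the tensor shapes
# and label values below are illustrative assumptions. With ``smoothing=0.0``
# the loss should match ``F.cross_entropy`` exactly, and entries equal to
# ``ignore_index`` (default -100) should be excluded from the average.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, num_classes = 4, 10
    logits = torch.randn(batch_size, num_classes)
    labels = torch.tensor([1, 3, -100, 7])  # the -100 entry is ignored

    criterion = CrossEntropyLossWithLabelSmoothing(smoothing=0.1)
    print("smoothed loss:", criterion(logits, labels).item())

    # Sanity check: zero smoothing falls back to vanilla cross entropy.
    vanilla = CrossEntropyLossWithLabelSmoothing(smoothing=0.0)
    expected = F.cross_entropy(logits, labels, ignore_index=-100)
    assert torch.isclose(vanilla(logits, labels), expected)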