# (Scraped page banner, not part of the module) Spaces: Runtime error
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class CrossEntropyLossWithLabelSmoothing(nn.Module):
    r"""
    PyTorch :class:`~torch.nn.CrossEntropyLoss` with label smoothing. Quoting
    documentation from original PyTorch module:

    It is useful when training a classification problem with ``C`` classes.
    The ``inputs`` is expected to contain raw, unnormalized scores for each class.
    ``inputs`` has to be a Tensor of size ``(N, C)``. This criterion expects a
    class index in the range ``[0, C - 1]`` as the ``targets`` for each value
    of a 1D tensor of size ``minibatch``; if ``ignore_index`` is specified,
    this criterion also accepts this class index (this index may not
    necessarily be in the class range).

    Parameters
    ----------
    smoothing: float, optional (default = 0.0)
        Label smoothing value. It sets target weights as ``(1 - smoothing)``
        and all other weights as ``smoothing / (C - 1)``. Setting this to
        zero will default to vanilla cross entropy loss.
    ignore_index: int, optional (default = -100)
        Target value whose entries are excluded from the loss. It may lie
        outside the class range (the default is negative).
    """

    def __init__(self, smoothing: float = 0.0, ignore_index: int = -100):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, inputs, targets):
        r"""
        Compute the mean (label-smoothed) cross entropy over non-ignored entries.

        Parameters
        ----------
        inputs: torch.Tensor
            Unnormalized class scores of shape ``(N, C)``.
        targets: torch.Tensor
            Integer class indices of shape ``(N, )``; entries equal to
            ``ignore_index`` do not contribute to the loss.

        Returns
        -------
        torch.Tensor
            Scalar tensor holding the loss averaged over the kept entries.
        """
        if self.smoothing == 0.0:
            # Use PyTorch cross entropy when smoothing is 0. This is slightly
            # faster than what we are doing manually below.
            return F.cross_entropy(
                inputs, targets, ignore_index=self.ignore_index, reduction="mean"
            )

        # Remove entries matching ``ignore_index``. This is done regardless of
        # the sign of ``ignore_index``: the default (-100) is negative, and
        # filtering only for non-negative values left ``_inputs`` / ``_targets``
        # undefined on this path.
        keep = targets != self.ignore_index
        _targets = targets[keep]
        _inputs = inputs[keep]

        # shape: (num_kept, num_classes)
        logprobs = F.log_softmax(_inputs, dim=-1)

        # Every class starts with the "other class" weight; the target column
        # is then overwritten with ``1 - smoothing``.
        # shape: (num_kept, num_classes)
        weights = (
            torch.ones_like(_inputs) * self.smoothing / (_inputs.size(-1) - 1.0)
        )
        weights.scatter_(-1, _targets.unsqueeze(-1), (1.0 - self.smoothing))

        # shape: (num_kept, )
        loss = (-weights * logprobs).sum(dim=-1)
        return loss.mean()