import torch
from torch import nn
from torch.nn import functional as F


class CrossEntropyLossWithLabelSmoothing(nn.Module):
    r"""
    PyTorch :class:`~torch.nn.CrossEntropyLoss` with label smoothing. Quoting
    documentation from the original PyTorch module:

    It is useful when training a classification problem with ``C`` classes.
    ``inputs`` is expected to contain raw, unnormalized scores for each class.

    ``inputs`` has to be a Tensor of size ``(N, C)``. This criterion expects a
    class index in the range ``[0, C - 1]`` as the ``targets`` for each value
    of a 1D tensor of size ``N``; if ``ignore_index`` is specified, this
    criterion also accepts this class index (which may lie outside the class
    range).

    Parameters
    ----------
    smoothing: float, optional (default = 0.0)
        Label smoothing value. It sets the target class weight to
        ``(1 - smoothing)`` and all other class weights to
        ``smoothing / (C - 1)``. Setting this to zero falls back to vanilla
        cross-entropy loss.
    ignore_index: int, optional (default = -100)
        Target value that is ignored and does not contribute to the loss.
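
    Examples
    --------
    An illustrative sanity check (values are arbitrary): with
    ``smoothing = 0.1`` and ``C = 5``, the target class gets weight ``0.9``
    and each of the other four classes gets ``0.1 / 4 = 0.025``.

    >>> import torch
    >>> criterion = CrossEntropyLossWithLabelSmoothing(smoothing=0.1)
    >>> inputs = torch.randn(8, 5)   # (N, C) logits
    >>> targets = torch.randint(0, 5, (8,))
    >>> criterion(inputs, targets).shape
    torch.Size([])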
    """

    def __init__(self, smoothing: float = 0.0, ignore_index: int = -100):
        super().__init__()
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        if self.smoothing == 0.0:
            # Use PyTorch cross entropy when smoothing is 0. This is slightly
            # faster than what we are doing manually below.
            return F.cross_entropy(
                inputs, targets, ignore_index=self.ignore_index, reduction="mean"
            )

        # Remove entries matching ``ignore_index``. Masking unconditionally
        # works for any index value, including the default of -100.
        keep = targets != self.ignore_index
        _targets = targets[keep]
        _inputs = inputs[keep]

        # shape: (batch_size, num_classes)
        logprobs = F.log_softmax(_inputs, dim=-1)

        # shape: (batch_size, num_classes)
        weights = torch.full_like(
            _inputs, self.smoothing / (_inputs.size(-1) - 1.0)
        )
        weights.scatter_(-1, _targets.unsqueeze(-1), 1.0 - self.smoothing)

        # shape: (batch_size, )
        loss = (-weights * logprobs).sum(dim=-1)
        return loss.mean()
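

if __name__ == "__main__":
    # Minimal smoke test, illustrative only: random logits for a batch of 4
    # examples over 10 classes, with one target set to ``ignore_index`` so it
    # is dropped from the loss.
    torch.manual_seed(0)
    logits = torch.randn(4, 10)
    labels = torch.tensor([1, 5, -100, 7])

    criterion = CrossEntropyLossWithLabelSmoothing(smoothing=0.1, ignore_index=-100)
    print("smoothed loss:", criterion(logits, labels).item())

    # With ``smoothing == 0`` the module falls back to ``F.cross_entropy``,
    # so the two should match exactly.
    vanilla = CrossEntropyLossWithLabelSmoothing(smoothing=0.0, ignore_index=-100)
    assert torch.isclose(
        vanilla(logits, labels),
        F.cross_entropy(logits, labels, ignore_index=-100),
    )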