| from typing import Optional | |
| from torch import nn, Tensor | |
| import torch | |
| import torch.nn.functional as F | |
| from ._functional import label_smoothed_nll_loss | |
| __all__ = ["SoftCrossEntropyLoss"] | |
| class SoftCrossEntropyLoss(nn.Module): | |
| __constants__ = ["reduction", "ignore_index", "smooth_factor"] | |
| def __init__( | |
| self, | |
| reduction: str = "mean", | |
| smooth_factor: Optional[float] = None, | |
| ignore_index: Optional[int] = -100, | |
| dim: int = 1, | |
| ): | |
| """Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing | |
| Args: | |
| smooth_factor: Factor to smooth target (e.g. if smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05]) | |
| Shape | |
| - **y_pred** - torch.Tensor of shape (N, C, H, W) | |
| - **y_true** - torch.Tensor of shape (N, H, W) | |
| Reference | |
| https://github.com/BloodAxe/pytorch-toolbelt | |
| """ | |
| super().__init__() | |
| self.smooth_factor = smooth_factor | |
| self.ignore_index = ignore_index | |
| self.reduction = reduction | |
| self.dim = dim | |
| def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor: | |
| log_prob = F.log_softmax(y_pred, dim=self.dim) | |
| return label_smoothed_nll_loss( | |
| log_prob, | |
| y_true, | |
| epsilon=self.smooth_factor, | |
| ignore_index=self.ignore_index, | |
| reduction=self.reduction, | |
| dim=self.dim, | |
| ) | |