|
from typing import Optional |
|
from torch import nn, Tensor |
|
import torch |
|
import torch.nn.functional as F |
|
from ._functional import label_smoothed_nll_loss |
|
|
|
__all__ = ["SoftCrossEntropyLoss"] |
|
|
|
|
|
class SoftCrossEntropyLoss(nn.Module):
    """Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing.

    Logits are converted to log-probabilities with ``log_softmax`` and the
    smoothed negative log-likelihood is computed by
    :func:`label_smoothed_nll_loss`.

    Args:
        reduction: Reduction mode forwarded to ``label_smoothed_nll_loss``
            (``"mean"`` by default).
        smooth_factor: Factor to smooth target (e.g. if smooth_factor=0.1
            then [1, 0, 0] -> [0.9, 0.05, 0.05]).
        ignore_index: Label value excluded from the loss; the -100 default
            mirrors ``torch.nn.CrossEntropyLoss``.
        dim: Index of the class (channel) dimension in ``y_pred``.

    Shape
        - **y_pred** - torch.Tensor of shape (N, C, H, W)
        - **y_true** - torch.Tensor of shape (N, H, W)

    Reference
        https://github.com/BloodAxe/pytorch-toolbelt
    """

    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(
        self,
        reduction: str = "mean",
        smooth_factor: Optional[float] = None,
        ignore_index: Optional[int] = -100,
        dim: int = 1,
    ):
        super().__init__()
        # Plain attribute storage; all real work happens in forward().
        self.reduction = reduction
        self.smooth_factor = smooth_factor
        self.ignore_index = ignore_index
        self.dim = dim

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the label-smoothed cross-entropy of ``y_pred`` vs ``y_true``."""
        # Turn raw logits into log-probabilities along the class dimension,
        # then delegate smoothing, masking and reduction to the shared helper.
        log_probabilities = F.log_softmax(y_pred, dim=self.dim)
        return label_smoothed_nll_loss(
            log_probabilities,
            y_true,
            epsilon=self.smooth_factor,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )
|
|