from typing import Optional
from torch import nn, Tensor
import torch
import torch.nn.functional as F
from ._functional import label_smoothed_nll_loss

__all__ = ["SoftCrossEntropyLoss"]


class SoftCrossEntropyLoss(nn.Module):
    __constants__ = ["reduction", "ignore_index", "smooth_factor"]

    def __init__(
        self,
        reduction: str = "mean",
        smooth_factor: Optional[float] = None,
        ignore_index: Optional[int] = -100,
        dim: int = 1,
    ):
"""Drop-in replacement for torch.nn.CrossEntropyLoss with label_smoothing
Args:
smooth_factor: Factor to smooth target (e.g. if smooth_factor=0.1 then [1, 0, 0] -> [0.9, 0.05, 0.05])
Shape
- **y_pred** - torch.Tensor of shape (N, C, H, W)
- **y_true** - torch.Tensor of shape (N, H, W)
Reference
https://github.com/BloodAxe/pytorch-toolbelt
"""
        super().__init__()
        self.smooth_factor = smooth_factor
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.dim = dim

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        # Turn raw logits into log-probabilities over the class dimension, then apply
        # the label-smoothed negative log-likelihood loss.
        log_prob = F.log_softmax(y_pred, dim=self.dim)
        return label_smoothed_nll_loss(
            log_prob,
            y_true,
            epsilon=self.smooth_factor,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
            dim=self.dim,
        )
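

if __name__ == "__main__":
    # Usage sketch (illustrative addition, not part of the original module). Because of the
    # relative import of label_smoothed_nll_loss above, run this as part of its package,
    # e.g. `python -m <package>.<module>`; the package and module names here are assumptions.
    # With smooth_factor=0.1 and 3 classes, a one-hot target such as [1, 0, 0] is effectively
    # smoothed to [0.9, 0.05, 0.05], as described in the docstring.
    torch.manual_seed(0)
    logits = torch.randn(2, 3, 4, 4)           # (N, C, H, W) raw class scores
    target = torch.randint(0, 3, (2, 4, 4))    # (N, H, W) integer class labels
    target[0, 0, 0] = -100                     # this pixel is skipped via ignore_index

    smoothed_ce = SoftCrossEntropyLoss(smooth_factor=0.1, ignore_index=-100)
    plain_ce = nn.CrossEntropyLoss(ignore_index=-100)

    print("label-smoothed loss:", smoothed_ce(logits, target).item())
    print("plain cross-entropy:", plain_ce(logits, target).item())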