import torch
import torch.nn as nn
import torch.nn.functional as F


class LabelSmoothingCrossEntropy(nn.Module):
    """ NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.1):
        """
        Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothingCrossEntropy, self).__init__()
        assert smoothing < 1.0
        self.smoothing = smoothing
        self.confidence = 1. - smoothing

    def forward(self, x, target):
        logprobs = F.log_softmax(x, dim=-1)
        # Negative log-likelihood of the true class for each sample.
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        # Cross entropy against a uniform distribution over the classes.
        smooth_loss = -logprobs.mean(dim=-1)
        # Blend the two terms according to the smoothing factor.
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()


class SoftTargetCrossEntropy(nn.Module):
    """ Cross entropy with soft (probability-distribution) targets. """

    def __init__(self):
        super(SoftTargetCrossEntropy, self).__init__()

    def forward(self, x, target):
        loss = torch.sum(-target * F.log_softmax(x, dim=-1), dim=-1)
        return loss.mean()
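

# Minimal usage sketch (illustrative only, not part of the module above):
# the batch size, class count, and smoothing value below are assumptions
# chosen for demonstration.
if __name__ == "__main__":
    logits = torch.randn(8, 10)                # batch of 8 samples, 10 classes
    hard_targets = torch.randint(0, 10, (8,))  # integer class indices

    criterion = LabelSmoothingCrossEntropy(smoothing=0.1)
    print("label-smoothed CE:", criterion(logits, hard_targets).item())

    # SoftTargetCrossEntropy expects a probability distribution per sample,
    # e.g. the mixed targets produced by mixup/cutmix-style augmentation.
    soft_targets = F.one_hot(hard_targets, num_classes=10).float()
    soft_criterion = SoftTargetCrossEntropy()
    print("soft-target CE:", soft_criterion(logits, soft_targets).item())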