# IE101TW/loss/label_smoothing.py
# DeepLearning101 - "Upload 6 files" (fdc4786), 842 bytes
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy loss with label smoothing.

    The loss combines two terms built from ``log_softmax(output, dim=-1)``:

    * a uniform term, ``-sum(log_probs) / C`` taken over the last
      dimension (``C`` is its size), weighted by ``eps``;
    * the negative log-likelihood of ``target``, weighted by ``1 - eps``.

    Positions whose target equals ``ignore_index`` contribute to neither
    term. With ``reduction="mean"`` both terms are averaged over the
    non-ignored positions only, matching how ``F.nll_loss`` normalizes.

    Args:
        eps: Weight of the uniform term; ``1 - eps`` weights the NLL term.
        reduction: ``"mean"`` or ``"sum"`` reduce both terms accordingly.
            Any other value leaves the uniform term per-position and is
            forwarded unchanged to ``F.nll_loss`` (e.g. ``"none"``).
        ignore_index: Target value marking positions to exclude.
    """

    def __init__(self, eps=0.1, reduction="mean", ignore_index=-100):
        super(LabelSmoothingCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """Return the label-smoothed loss for ``output`` against ``target``.

        Args:
            output: Unnormalized scores with classes on the last dimension.
            target: Class indices; entries equal to ``ignore_index`` are
                skipped.

        NOTE(review): the ignore mask is built from ``target`` and applied
        to the per-position uniform term, so ``target`` is assumed to have
        the shape of ``output`` without its last dimension (e.g. ``(N, C)``
        scores with ``(N,)`` targets) -- confirm against callers.
        """
        num_classes = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)

        # Per-position uniform term. Ignored positions are zeroed so they
        # add nothing to the loss or its gradient, mirroring what
        # F.nll_loss does for the NLL term.
        keep = target != self.ignore_index
        uniform = (-log_preds.sum(dim=-1)).masked_fill(~keep, 0.0)

        if self.reduction == "sum":
            uniform = uniform.sum()
        elif self.reduction == "mean":
            # Divide by the number of non-ignored positions (not the batch
            # size) so both terms share the same normalization.
            uniform = uniform.sum() / keep.sum()

        nll = F.nll_loss(
            log_preds,
            target,
            reduction=self.reduction,
            ignore_index=self.ignore_index,
        )
        return uniform * self.eps / num_classes + (1 - self.eps) * nll