Spaces:
Build error
Build error
import torch | |
from torch import nn | |
from torch.nn import KLDivLoss | |
from torch.nn import LogSoftmax | |
class LabelSmoothingLoss(nn.Module):
    """KL-divergence loss against a label-smoothed target distribution.

    The target class receives probability ``1 - label_smoothing`` and the
    remaining ``label_smoothing`` mass is spread uniformly over the other
    ``vocab_size - 1`` classes. With ``label_smoothing == 0.0`` the target is
    one-hot and the loss equals mean cross-entropy over non-ignored positions.

    If ``unreliable_label`` is set, smoothing is applied only to positions whose
    target equals that label; every other position uses a one-hot target.
    """

    def __init__(self, label_smoothing=0.0, unreliable_label=None, ignore_index=-100):
        """
        Args:
            label_smoothing: probability mass moved off the target class, in [0, 1].
            unreliable_label: optional class id whose targets get smoothed; when
                None, every target is smoothed.
            ignore_index: target value whose positions are excluded from the loss.
        """
        assert 0.0 <= label_smoothing <= 1.0
        super(LabelSmoothingLoss, self).__init__()
        self.ignore_index = ignore_index
        self.label_smoothing = label_smoothing
        # 'batchmean' sums the KL terms and divides by the number of rows, i.e.
        # by the number of non-ignored positions that survive masking.
        self.loss_fn = KLDivLoss(reduction='batchmean')
        self.unreliable_label = unreliable_label
        # NOTE(review): not read anywhere in this class; kept in case callers use it.
        self.max_gap = 100.
        self.log_softmax = LogSoftmax(1)

    def forward(self, output, target):
        """Compute the label-smoothed KL loss.

        Args:
            output: unnormalized logits whose last dimension is the class axis and
                whose leading dimensions match ``target`` (e.g. ``(N, C)`` with
                ``target`` of shape ``(N,)``, or ``(B, T, C)`` with ``(B, T)``).
            target: integer class labels; entries equal to ``ignore_index`` are
                skipped.

        Returns:
            Scalar loss tensor. When every target is ignored, a zero that is still
            attached to ``output``'s autograd graph (instead of NaN).
        """
        mask = target != self.ignore_index
        # Boolean indexing flattens the leading dims: (K, C) logits, (K,) labels.
        output, target = output[mask], target[mask]
        vocab_size = output.shape[-1]
        if target.numel() == 0:
            # 'batchmean' would compute 0 / 0 here and return NaN.
            return output.sum() * 0.0
        output = self.log_softmax(output)

        def smooth_prob(ls):
            """Return a (K, vocab_size) target distribution with smoothing ``ls``."""
            if vocab_size == 1:
                # No other class to move mass to; the only valid distribution is all-ones.
                return output.new_ones((target.size(0), 1))
            prob = output.new_full((target.size(0), vocab_size), ls / (vocab_size - 1))
            prob.scatter_(1, target.unsqueeze(1), 1.0 - ls)
            return prob

        if self.unreliable_label is not None:
            # Smoothed rows where the label is the unreliable one, one-hot elsewhere.
            is_unreliable = (target == self.unreliable_label).unsqueeze(1)
            model_prob = torch.where(
                is_unreliable, smooth_prob(self.label_smoothing), smooth_prob(0.0)
            )
        else:
            model_prob = smooth_prob(self.label_smoothing)
        return self.loss_fn(output, model_prob)