import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss


# Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28
class SmoothCrossEntropyLoss(_Loss):
    """
    Cross entropy with label smoothing (https://arxiv.org/abs/1512.00567).
    """
    __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']

    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
        assert 0.0 <= label_smoothing <= 1.0
        super().__init__(reduction=reduction)

        self.label_smoothing = label_smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.input_is_logits = is_logits

    def forward(self, input, target):
        """
        Args:
            input: [B * T, V]
            target: [B * T]
        Returns:
            cross entropy: [1]
        """
        # Positions equal to ignore_index are excluded from the loss.
        mask = (target == self.ignore_index).unsqueeze(-1)
        # Smoothed target distribution: (1 - eps) on the true class, eps spread uniformly.
        q = F.one_hot(target.long(), self.vocab_size).type(torch.float32)
        u = 1.0 / self.vocab_size
        q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
        q_prime = q_prime.masked_fill(mask, 0)

        ce = self.cross_entropy_with_logits(q_prime, input)
        if self.reduction == 'mean':
            # Average over non-ignored positions only.
            lengths = torch.sum(target != self.ignore_index)
            return ce.sum() / lengths
        elif self.reduction == 'sum':
            return ce.sum()
        else:
            raise NotImplementedError

    def cross_entropy_with_logits(self, p, q):
        # -sum_v p(v) * log_softmax(q)(v), computed per position.
        return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)
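

# Minimal usage sketch, not part of the borrowed file: a toy vocab of 10 tokens
# and hand-picked targets, with token id 0 standing in for padding. Note that
# ignore_index must be a valid class index here, since F.one_hot cannot encode
# negative targets such as the default -100.
if __name__ == "__main__":
    vocab_size = 10
    pad_id = 0  # hypothetical padding id for this example
    criterion = SmoothCrossEntropyLoss(label_smoothing=0.1,
                                       vocab_size=vocab_size,
                                       ignore_index=pad_id)
    logits = torch.randn(8, vocab_size)                    # [B * T, V]
    targets = torch.tensor([1, 2, 3, 0, 4, 5, 0, 6])       # [B * T], 0 = ignored
    loss = criterion(logits, targets)
    print(loss.item())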