# video2music/model/loss.py
# (Residue of the web file viewer this source was copied from:
#  uploader "kjysmu", commit "add files" 4e46a55, file size 1.7 kB.)
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
# Borrowed from https://github.com/jason9693/MusicTransformer-pytorch/blob/5f183374833ff6b7e17f3a24e3594dedd93a5fe5/custom/criterion.py#L28
class SmoothCrossEntropyLoss(_Loss):
    """Cross entropy between logits and label-smoothed targets.

    Each one-hot target is blended with a uniform distribution over the
    vocabulary::

        q' = (1 - label_smoothing) * one_hot(target) + label_smoothing / vocab_size

    Positions whose target equals ``ignore_index`` contribute nothing to the
    loss and are excluded from the ``'mean'`` normaliser.

    Label smoothing is described in https://arxiv.org/abs/1512.00567
    """

    __constants__ = ['label_smoothing', 'vocab_size', 'ignore_index', 'reduction']

    def __init__(self, label_smoothing, vocab_size, ignore_index=-100, reduction='mean', is_logits=True):
        """
        Args:
            label_smoothing: smoothing weight in [0, 1]; 0 gives plain cross entropy.
            vocab_size: number of classes V (size of the last dim of ``input``).
            ignore_index: target value marking positions to skip. It may lie
                outside ``[0, vocab_size)`` (e.g. the default -100).
            reduction: 'mean', 'sum' or 'none'.
            is_logits: stored as ``input_is_logits``.
                NOTE(review): ``forward`` never consults this flag and always
                treats ``input`` as unnormalised logits -- confirm intent.
        """
        assert 0.0 <= label_smoothing <= 1.0, "label_smoothing must be in [0, 1]"
        super().__init__(reduction=reduction)
        self.label_smoothing = label_smoothing
        self.vocab_size = vocab_size
        self.ignore_index = ignore_index
        self.input_is_logits = is_logits

    def forward(self, input, target):
        """
        Args:
            input: [B * T, V] logits
            target: [B * T] class indices (entries may equal ``ignore_index``)
        Returns:
            'mean': scalar, summed loss divided by the number of non-ignored
                targets (NaN if every target is ignored, since that count is 0).
            'sum': scalar, summed loss.
            'none': [B * T] per-position loss, 0 at ignored positions.
        Raises:
            NotImplementedError: for any other ``reduction`` value.
        """
        ignored = target == self.ignore_index

        # F.one_hot rejects negative class values, so an out-of-vocabulary
        # ignore_index (such as the default -100) must be swapped for a valid
        # placeholder class first. Those rows are zeroed again below, so the
        # placeholder never influences the loss.
        safe_target = target.masked_fill(ignored, 0).long()

        q = F.one_hot(safe_target, self.vocab_size).type(torch.float32)
        u = 1.0 / self.vocab_size
        q_prime = (1.0 - self.label_smoothing) * q + self.label_smoothing * u
        # An all-zero target distribution makes the position's loss exactly 0.
        q_prime = q_prime.masked_fill(ignored.unsqueeze(-1), 0)

        ce = self.cross_entropy_with_logits(q_prime, input)

        if self.reduction == 'mean':
            # Normalise by the number of real (non-ignored) targets only.
            lengths = torch.sum(~ignored)
            return ce.sum() / lengths
        elif self.reduction == 'sum':
            return ce.sum()
        elif self.reduction == 'none':
            return ce
        else:
            raise NotImplementedError

    def cross_entropy_with_logits(self, p, q):
        """Return ``-sum(p * log_softmax(q))`` over the last dimension.

        Args:
            p: target distribution, same shape as ``q``.
            q: unnormalised logits.
        """
        # q - logsumexp(q) is log_softmax(q) computed along the last dim.
        return -torch.sum(p * (q - q.logsumexp(dim=-1, keepdim=True)), dim=-1)