File size: 1,687 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import logging
from typing import Any, KeysView
from torch import nn
# Only the loss factory is exported by `from <module> import *`.
__all__ = ["Losses"]
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class TokenAveragedCrossEntropyLoss(nn.Module):
    """Next-token cross-entropy averaged over all target tokens in the batch.

    Every non-ignored target token in the batch carries equal weight, so longer
    samples contribute proportionally more to the loss.

    Args:
        cfg: configuration object; stored on the module but not read here.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the token-averaged causal language-modeling loss.

        Args:
            logits: prediction scores whose last dimension is the class
                (vocabulary) axis and whose second-to-last is the sequence
                axis — presumably (batch, seq_len, vocab); confirm with callers.
            labels: target ids aligned with ``logits`` along the sequence axis.

        Returns:
            Scalar tensor produced by ``nn.CrossEntropyLoss`` (mean reduction).
        """
        num_classes = logits.size(-1)
        # The prediction at position t is scored against the token at t + 1,
        # so drop the last prediction and the first label, then flatten.
        predictions = logits[..., :-1, :].reshape(-1, num_classes)
        targets = labels[..., 1:].reshape(-1)
        return self.loss_fn(predictions, targets)
class SampleAveragedCrossEntropyLoss(nn.Module):
    """Next-token cross-entropy averaged per sample, then across the batch.

    Each sample's loss is the mean over its own non-ignored target tokens, and
    the batch loss is the mean of those per-sample losses, so every sample
    carries equal weight regardless of how many target tokens it contains.

    Args:
        cfg: configuration object; stored on the module but not read here.
    """

    def __init__(self, cfg: Any):
        super().__init__()
        self.cfg = cfg
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        """Return the sample-averaged causal language-modeling loss.

        Args:
            logits: prediction scores, assumed (batch, seq_len, vocab) — the
                last dimension is the class axis; confirm with callers.
            labels: target ids, assumed (batch, seq_len). Positions equal to
                ``self.loss_fn.ignore_index`` are excluded from the averages.

        Returns:
            Scalar tensor. A sample with no valid target tokens contributes 0
            (rather than NaN) but still counts towards the batch size.
        """
        shift_logits = logits[..., :-1, :]
        shift_labels = labels[..., 1:]
        ignore_index = self.loss_fn.ignore_index

        # One vectorized call instead of a per-sample Python loop; with
        # reduction="none", ignored positions yield a loss of exactly 0.
        token_loss = nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            ignore_index=ignore_index,
            reduction="none",
        ).view(shift_labels.shape)

        valid_tokens = (shift_labels != ignore_index).sum(dim=-1)
        # clamp(min=1) prevents 0/0 = NaN for fully-masked samples (e.g. an
        # answer cut off by truncation), which would otherwise poison the
        # whole batch loss and its gradients.
        sample_loss = token_loss.sum(dim=-1) / valid_tokens.clamp(min=1)
        return sample_loss.mean()
class Losses:
    """Losses factory.

    Maps a loss name to the ``nn.Module`` subclass that implements it.
    """

    _losses = {
        "TokenAveragedCrossEntropy": TokenAveragedCrossEntropyLoss,
        "SampleAveragedCrossEntropy": SampleAveragedCrossEntropyLoss,
    }

    @classmethod
    def names(cls) -> KeysView:
        """Return a view of all registered loss names."""
        return cls._losses.keys()

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Losses.

        Args:
            name: losses name

        Returns:
            A class to build the Losses. Unknown names fall back to
            ``TokenAveragedCrossEntropyLoss`` for backward compatibility, and
            a warning is logged so that a misspelled name is not silently
            replaced by a different objective.
        """
        try:
            return cls._losses[name]
        except KeyError:
            logger.warning(
                "Unknown loss %r; falling back to TokenAveragedCrossEntropy. "
                "Available losses: %s",
                name,
                ", ".join(cls._losses),
            )
            return TokenAveragedCrossEntropyLoss
|