# olm-chat-7b / open_lm/losses.py
# Uploaded by henhenhahi111112 using huggingface_hub (commit af6e330, verified).
# File size: 681 bytes.
from typing import Optional

import torch
from torch import Tensor
from torch.nn import CrossEntropyLoss
class CrossEntropyLossWithZLoss(CrossEntropyLoss):
    """Cross-entropy loss plus an auxiliary "z-loss" penalty.

    The returned value is the regular ``CrossEntropyLoss`` result plus
    ``eps * logsumexp(input, dim=class_dim) ** 2``, where the penalty is
    reduced the same way as the cross-entropy term (``"mean"``, ``"sum"`` or
    ``"none"``). The penalty grows as the log of the softmax normalizer moves
    away from zero.

    NOTE(review): the penalty is computed for every position of ``input``;
    positions whose target equals ``ignore_index`` are *not* masked out and
    ``weight`` is not applied to it. This mirrors the historical behavior of
    the default ``"mean"`` path — confirm this is intended before changing it.

    Args:
        eps: Multiplier applied to the z-loss penalty.
        weight: Optional per-class rescaling weights, forwarded unchanged to
            ``CrossEntropyLoss``.
        size_average: Legacy reduction flag, forwarded to ``CrossEntropyLoss``.
        ignore_index: Target value ignored by the cross-entropy term.
        reduce: Legacy reduction flag, forwarded to ``CrossEntropyLoss``.
        reduction: ``"mean"``, ``"sum"`` or ``"none"``; applied to both the
            cross-entropy term and the z-loss penalty.
        label_smoothing: Label-smoothing amount, forwarded to
            ``CrossEntropyLoss``.
    """

    def __init__(
        self,
        eps: float = 1e-4,
        weight: Optional[Tensor] = None,
        size_average=None,
        ignore_index: int = -100,
        reduce=None,
        reduction: str = "mean",
        label_smoothing: float = 0.0,
    ) -> None:
        super().__init__(weight, size_average, ignore_index, reduce, reduction, label_smoothing)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        """Return cross-entropy plus ``eps`` times the squared log-partition.

        Args:
            input: Logits laid out as ``CrossEntropyLoss`` expects them:
                ``(C,)``, ``(N, C)`` or ``(N, C, d1, ...)``.
            target: Targets accepted by ``CrossEntropyLoss`` for ``input``.

        Returns:
            A scalar for ``reduction`` of ``"mean"`` or ``"sum"``; otherwise
            a tensor of per-element losses.
        """
        ce_loss = super().forward(input, target)

        # CrossEntropyLoss keeps classes on dim 1 for batched input and on
        # dim 0 for a single unbatched sample, so reduce over that axis rather
        # than assuming the classes are last.
        class_dim = 1 if input.dim() > 1 else 0
        z_loss = torch.square(torch.logsumexp(input, dim=class_dim))

        # Reduce the penalty exactly like the cross-entropy term; for "none"
        # it stays per-element so it lines up with ``ce_loss``.
        if self.reduction == "mean":
            z_loss = z_loss.mean()
        elif self.reduction == "sum":
            z_loss = z_loss.sum()

        return ce_loss + self.eps * z_loss