Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
from src.loss.adversarial import AdversarialLoss | |
################################################################################ | |
# Cross-entropy loss | |
################################################################################ | |
class CELoss(AdversarialLoss):
    """
    Cross-entropy loss between categorical (class) distributions.

    Wraps an unreduced ``nn.CrossEntropyLoss``. The ``targeted`` and
    ``reduction`` arguments are forwarded unchanged to the
    ``AdversarialLoss`` constructor.
    """

    def __init__(self,
                 targeted: bool = True,
                 reduction: str = 'none',
                 ):
        super().__init__(targeted, reduction)

        # The wrapped criterion is always built with reduction='none',
        # independent of the `reduction` argument passed to the base class.
        self.loss = nn.CrossEntropyLoss(reduction='none')

    def _compute_loss(self, y_pred: torch.Tensor, y_true: torch.Tensor):
        """
        Compute the unreduced cross-entropy of ``y_pred`` against ``y_true``.

        :param y_pred: predictions; must have at least 2 dimensions and a
                       final dimension of size >= 2
        :param y_true: targets on the same device as ``y_pred``; when it has
                       2 or more dimensions it is converted to indices by
                       ``argmax`` over its last dimension (presumably one-hot
                       or probability vectors -- confirm against callers)
        :return: output of the wrapped criterion, multiplied by -1 when
                 ``self.targeted`` is falsy
        """
        assert y_pred.device == y_true.device
        assert y_pred.ndim >= 2
        assert y_pred.shape[-1] >= 2

        # Multi-dimensional targets are reduced to index form first.
        labels = y_true.argmax(dim=-1) if y_true.ndim >= 2 else y_true

        ce = self.loss(y_pred, labels)

        if self.targeted:
            return ce

        # Untargeted: flip the sign in place.
        return ce.mul_(-1)