|
from typing import Optional |
|
|
|
import torch |
|
from torch.nn.modules.loss import _WeightedLoss |
|
|
|
|
|
class MultiLabelNCELoss(_WeightedLoss):
    """Multi-label, NCE-style contrastive loss over the last dimension.

    Each positive class ``i`` of a sample (a nonzero entry of ``target``) is
    normalized only against that sample's negative classes, never against the
    other positives::

        log p_i = x_i - log(exp(x_i) + sum_{j in negatives} exp(x_j))

    The per-sample loss is ``-sum_i log p_i`` taken over the positives. With
    exactly one positive per row, this equals ordinary softmax cross-entropy.
    A row with no positives contributes zero loss.

    Args:
        weight: Accepted for compatibility with ``_WeightedLoss`` and stored
            by the base class. It is NOT applied in ``forward``.
        size_average: Deprecated PyTorch argument, forwarded to the base class.
        reduction: One of:
            ``"mean"``: sum of the per-sample losses divided by
            ``input.size(0)``.
            ``"sum"``: sum of the per-sample losses.
            ``"none"``: per-sample losses with shape ``input.shape[:-1]``.
    """

    __constants__ = ["reduction"]

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        size_average=None,
        reduction: Optional[str] = "mean",
    ) -> None:
        super(MultiLabelNCELoss, self).__init__(weight, size_average, None, reduction)

    def forward(
        self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100
    ) -> torch.Tensor:
        """Compute the loss.

        Args:
            input: Unnormalized scores, with classes along the last dimension.
            target: Same shape as ``input``. Any nonzero entry marks a positive
                class (it is converted with ``target.bool()``).
                NOTE(review): a ``-100`` entry would therefore count as a
                positive. ``ignore_index`` is not consulted; confirm whether
                callers rely on it.
            ignore_index: Currently unused. Kept for signature compatibility.

        Returns:
            A scalar tensor for ``"mean"``/``"sum"``, or per-sample losses of
            shape ``input.shape[:-1]`` for ``"none"``.

        Raises:
            ValueError: If ``self.reduction`` is not ``"mean"``, ``"sum"`` or
                ``"none"``.
        """
        if self.reduction not in ("mean", "sum", "none"):
            raise ValueError(f"{self.reduction} is not a valid value for reduction")

        positive = target.bool()

        # Sum of the raw scores of the positive classes, per sample.
        gold_scores_sum = input.masked_fill(~positive, 0).sum(-1)

        # log-sum-exp over the negatives only. Positives are masked to -inf
        # so they do not compete with each other in the normalizer.
        neg_logits = input.masked_fill(positive, float("-inf"))
        neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True)

        # For each positive i: log(exp(x_i) + sum_neg exp(x_j)).
        # These terms are then summed over the positives of each sample.
        norm_term = (
            torch.logaddexp(input, neg_log_sum_exp)
            .masked_fill(~positive, 0)
            .sum(-1)
        )

        # Negative log-likelihood of the positives, one value per sample.
        per_sample_loss = norm_term - gold_scores_sum

        if self.reduction == "none":
            return per_sample_loss

        loss = per_sample_loss.sum()
        if self.reduction == "mean":
            # Normalize by the leading (batch) dimension, not by numel.
            loss = loss / input.size(0)
        return loss
|
|