File size: 1,156 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from typing import Optional

import torch
from torch.nn.modules.loss import _WeightedLoss


class MultiLabelNCELoss(_WeightedLoss):
    """Multi-label NCE-style loss: each gold label competes only against negatives.

    For every gold position ``g`` along the last dimension of ``input`` the
    per-label term is::

        -(s_g - log(exp(s_g) + sum_{n in negatives} exp(s_n)))

    The other gold labels of the same row are deliberately excluded from the
    normaliser. Terms are summed over the last dimension.

    Args:
        weight: forwarded to ``_WeightedLoss`` for API compatibility; it is
            NOT used by ``forward``.
        size_average: deprecated legacy argument, forwarded to ``_Loss``.
        reduction: ``"none"`` returns the per-row losses (shape
            ``input.shape[:-1]``); ``"mean"`` returns the total divided by
            ``input.size(0)`` (averaged over the first dimension only, not
            per element); any other value (e.g. ``"sum"``) returns the total.
    """

    __constants__ = ["reduction"]

    def __init__(
        self,
        weight: Optional[torch.Tensor] = None,
        size_average=None,
        reduction: Optional[str] = "mean",
    ) -> None:
        super().__init__(weight, size_average, None, reduction)

    def forward(
        self, input: torch.Tensor, target: torch.Tensor, ignore_index: int = -100
    ) -> torch.Tensor:
        """Compute the loss.

        Args:
            input: unnormalized scores; candidates lie along the last dimension
                (presumably ``B x C x L`` per the original comments — confirm
                against callers).
            target: mask broadcastable to ``input``. A non-zero value marks a
                gold position and zero marks a negative. A value equal to
                ``ignore_index`` marks a position excluded from both the gold
                set and the negatives.
            ignore_index: target value for positions to skip entirely.

        Returns:
            Per-row losses for ``reduction="none"``, otherwise a scalar tensor.
        """
        ignored = target == ignore_index
        # Without excluding ignored entries, target.bool() would treat the
        # (non-zero) ignore_index as a gold label.
        gold = target.bool() & ~ignored
        negative = ~gold & ~ignored

        # Sum of the gold scores along the last dimension.
        gold_scores_sum = input.masked_fill(~gold, 0).sum(-1)

        # log(sum_n exp(s_n)) over negatives only; -inf when a row has none,
        # in which case logaddexp below degenerates to s_g (zero loss term).
        neg_logits = input.masked_fill(~negative, float("-inf"))
        neg_log_sum_exp = torch.logsumexp(neg_logits, -1, keepdim=True)

        # Per-gold normaliser log(exp(s_g) + sum_n exp(s_n)), summed over gold.
        norm_term = (
            torch.logaddexp(input, neg_log_sum_exp).masked_fill(~gold, 0).sum(-1)
        )

        per_row_loss = norm_term - gold_scores_sum
        if self.reduction == "none":
            return per_row_loss
        loss = per_row_loss.sum()
        if self.reduction == "mean":
            # Out-of-place division: safer than mutating an autograd output.
            return loss / input.size(0)
        return loss