File size: 446 Bytes
cb9e677 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from typing import Optional
import torch
from torch.nn import functional as F
def compute_loss_with_mask(
    logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
) -> torch.Tensor:
    """Return the cross-entropy loss, optionally averaged over a mask.

    Args:
        logits: Unnormalized class scores in a layout accepted by
            ``F.cross_entropy`` (e.g. ``(N, C)``) — exact shape is up to the
            caller; TODO confirm against call sites.
        target: Target class indices matching ``logits``.
        target_mask: Optional per-position weights. When ``None`` the plain
            mean cross-entropy is returned. Otherwise it is multiplied
            element-wise with the unreduced (``reduction="none"``) loss, so it
            must broadcast against that loss, and the weighted sum is divided
            by the mask's sum. Bool / 0-1 masks select positions; other values
            act as weights.

    Returns:
        A scalar loss tensor. If the mask sums to zero (nothing selected) the
        result is ``0`` instead of ``NaN``.

    Note:
        Positions whose target equals ``F.cross_entropy``'s ``ignore_index``
        (default -100) receive zero loss from ``reduction="none"`` but still
        count toward the denominator whenever their mask entry is non-zero.
    """
    if target_mask is None:
        return F.cross_entropy(logits, target, reduction="mean")

    mb_loss = F.cross_entropy(logits, target, reduction="none")
    numerator = torch.sum(mb_loss * target_mask)
    denominator = torch.sum(target_mask)
    # An all-zero mask would give 0 / 0 = NaN and poison the optimizer step.
    # Swap a zero denominator for 1 (the numerator is then 0 for a non-negative
    # mask) with a tensor op rather than a Python `if`, which would force a
    # device->host sync. Applying torch.where to the quotient itself is avoided
    # because the unselected NaN branch still produces NaN gradients.
    safe_denominator = torch.where(
        denominator != 0, denominator, torch.ones_like(denominator)
    )
    return numerator / safe_denominator
|