from typing import Optional | |
import torch | |
from torch.nn import functional as F | |
def compute_loss_with_mask(
    logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
) -> torch.Tensor:
    """Return the cross-entropy loss, optionally restricted/weighted by a mask.

    Args:
        logits: Unnormalised class scores, passed straight to
            ``F.cross_entropy``.
        target: Targets, passed straight to ``F.cross_entropy``.
        target_mask: Optional per-element mask/weights multiplied into the
            unreduced loss. It must broadcast against the output of
            ``F.cross_entropy(..., reduction="none")``. ``None`` means every
            element counts equally.

    Returns:
        A scalar tensor. Without a mask this is the plain mean cross-entropy.
        With a mask it is ``sum(loss * mask) / sum(mask)``. If the mask sums
        to zero the normalisation is skipped (divide by 1 instead of 0), so
        the result is the zero-valued numerator rather than NaN.
    """
    if target_mask is None:
        return F.cross_entropy(logits, target, reduction="mean")

    per_element_loss = F.cross_entropy(logits, target, reduction="none")
    masked_total = torch.sum(per_element_loss * target_mask)
    mask_total = torch.sum(target_mask)

    # An all-zero mask would make this 0 / 0 == NaN, which then poisons the
    # gradients. Swap a zero denominator for 1; torch.where keeps this free of
    # Python-level, data-dependent branching. Non-zero denominators are
    # untouched, so results for non-empty masks are unchanged.
    safe_denominator = torch.where(
        mask_total != 0, mask_total, torch.ones_like(mask_total)
    )
    return masked_total / safe_denominator