Clemspace's picture
Initial model upload
cb9e677
from typing import Optional
import torch
from torch.nn import functional as F
def compute_loss_with_mask(
    logits: torch.Tensor, target: torch.Tensor, target_mask: Optional[torch.Tensor]
) -> torch.Tensor:
    """Return the cross-entropy loss, optionally weighted by a mask.

    Args:
        logits: Unnormalized class scores, as accepted by ``F.cross_entropy``.
        target: Class indices, as accepted by ``F.cross_entropy``.
        target_mask: Optional weights multiplied elementwise with the
            per-element losses. Bool or float values both work. It must
            broadcast against the unreduced loss, presumably the same shape
            as ``target`` (TODO confirm against callers). When ``None``, the
            plain mean cross-entropy is returned.

    Returns:
        A scalar tensor. With a mask, this is
        ``sum(loss * mask) / sum(mask)``. If the mask sums to zero (nothing
        selected), it returns ``0`` instead of NaN, with zero gradient.

    NOTE: ``F.cross_entropy`` assigns 0 loss to targets equal to its default
    ``ignore_index`` (-100). In the masked path, such positions still count
    toward the denominator if their mask weight is nonzero.
    """
    if target_mask is None:
        return F.cross_entropy(logits, target, reduction="mean")

    per_element_loss = F.cross_entropy(logits, target, reduction="none")
    numerator = torch.sum(per_element_loss * target_mask)
    denominator = torch.sum(target_mask)
    # A fully masked batch would otherwise yield 0 / 0 = NaN, which poisons
    # gradients. Substitute 1 only when the denominator is exactly 0: the
    # numerator is then 0 too, so the loss becomes 0. torch.where avoids a
    # device->host sync and leaves fractional mask weights undistorted.
    safe_denominator = torch.where(
        denominator == 0, torch.ones_like(denominator), denominator
    )
    return numerator / safe_denominator