|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
class RDMCrossEntropyLoss(nn.CrossEntropyLoss):
    """Weighted cross-entropy criterion for the RDM-derived loss.

    Multiplies per-element weights onto the token log-probabilities and
    reports both a full-sequence average and a label-mask-restricted
    average, plus a perplexity derived from the latter.
    """

    def __init__(self, ignore_index: int):
        # BUG FIX: the original never called super().__init__(), so the
        # nn.Module machinery (_parameters, hook dicts, ...) was left
        # uninitialized and invoking the module raised AttributeError.
        # Forwarding ignore_index keeps the parent state consistent and
        # still exposes it as self.ignore_index, as before.
        super().__init__(ignore_index=ignore_index)

    def forward(self,
                scores: torch.Tensor,
                target: torch.Tensor,
                label_mask: torch.Tensor,
                weights: torch.Tensor,
                ) -> dict:
        """
        Computes the RDM-derived loss (weighted cross-entropy).

        Args:
            scores: unnormalized logits; log-softmax is taken over the
                last dimension.
            target: gold indices; entries equal to ``self.ignore_index``
                are excluded from the full-sequence normalizer.
            label_mask: mask (broadcastable to the weighted log-probs)
                selecting positions that contribute to the reported loss.
            weights: per-element weights multiplied onto the log-probs.
                NOTE(review): the log-probs are not negated and no gather
                over ``target`` happens here, so ``weights`` presumably
                carries both the sign convention and the target selection
                — confirm against the caller.

        Returns:
            dict of detached tensors with keys ``'ppl'``,
            ``'fullseq_loss'`` and ``'weight_diff_loss'``.
        """
        # Count of non-ignored target tokens: normalizer for the
        # full-sequence loss.
        sample_size = target.ne(self.ignore_index).float().sum()

        lprobs = F.log_softmax(scores, dim=-1)

        # Element-wise weighted log-probabilities.
        loss = lprobs * weights
        fullseq_loss = loss.sum() / sample_size

        # Restrict to the masked positions and renormalize by the number
        # of masked elements.
        label_mask = label_mask.float()
        sample_size = label_mask.sum()
        loss = (loss * label_mask).sum() / sample_size

        # exp(mean loss); this is a conventional perplexity only if `loss`
        # is a mean negative log-likelihood, which depends on the sign
        # baked into `weights` (see NOTE above).
        ppl = torch.exp(loss)

        logging_output = {
            'ppl': ppl.data,
            'fullseq_loss': fullseq_loss.data,
            'weight_diff_loss': loss.data
        }

        return logging_output