import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.utils import batched_bincount
class GeneralCrossEntropy(nn.Module):
    """Dispatches to one of the cross-entropy variants below by `weight_type`."""

    def __init__(self, weight_type: str, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.weight_type = weight_type
        self.beta = beta
        if weight_type == "seq_cbce":
            # Sequential class-balanced CE only makes sense for sequence inputs.
            assert is_sequential
            self.loss_func = SeqCBCrossEntropy(beta=beta)
        elif weight_type == "cbce":
            self.loss_func = CBCrossEntropy(beta=beta, is_sequential=is_sequential)
        elif weight_type == "wce":
            self.loss_func = WeightedCrossEntropy(is_sequential=is_sequential)
        elif weight_type == "ce":
            self.loss_func = CrossEntropy(is_sequential=is_sequential)
        else:
            raise NotImplementedError(f"Unknown weight_type: {weight_type}")

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        return self.loss_func(preds, labels, pad_mask)
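
# Usage sketch (a minimal illustration, not from the original source): every
# branch above ends in F.nll_loss, so `preds` must be log-probabilities, e.g.
# produced with F.log_softmax. Shapes follow the SeqCBCrossEntropy docstring
# below.
#
#   criterion = GeneralCrossEntropy(weight_type="seq_cbce", beta=0.99)
#   preds = F.log_softmax(logits, dim=-1)      # [batch, max_seq_len, num_classes]
#   loss = criterion(preds, labels, pad_mask)  # labels/pad_mask: [batch, max_seq_len]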
class SeqCBCrossEntropy(nn.Module):
    def __init__(self, beta: float = 0.99):
        super().__init__()
        self.beta = beta

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor):
        """
        Sequential Class-Balanced Cross Entropy Loss (our proposal).

        Parameters
        ----------
        preds: torch.Tensor [batch_size, max_seq_length, num_classes]
            Log-probabilities (e.g. from log_softmax), as required by F.nll_loss.
        labels: torch.Tensor [batch_size, max_seq_length]
        pad_mask: torch.Tensor [batch_size, max_seq_length]

        Returns
        -------
        loss: torch.Tensor [1]
        """
        seq_length_batch = pad_mask.sum(-1)               # [batch_size]
        seq_length_list = torch.unique(seq_length_batch)  # [num_unique_seq_length]
        batch_size = preds.size(0)
        loss = 0
        for seq_length in seq_length_list:
            # Group together all sequences that share the same true length.
            extracted_batch = (seq_length_batch == seq_length)  # [batch_size]
            extracted_preds = preds[extracted_batch]    # [num_extracted, max_seq_length, num_classes]
            extracted_labels = labels[extracted_batch]  # [num_extracted, max_seq_length]
            extracted_batch_size = extracted_labels.size(0)
            # Per-timestep label counts within the group: [seq_length, num_classes].
            bincount = batched_bincount(extracted_labels.T, 1, extracted_preds.size(-1))
            # Class-balanced weights (Cui et al., 2019): (1 - beta) / (1 - beta^n).
            weight = (1 - self.beta) / (1 - self.beta ** bincount + 1e-8)
            for seq_no in range(seq_length.item()):
                # Weight each group's per-timestep loss by its share of the batch.
                loss += (extracted_batch_size / batch_size) * F.nll_loss(
                    extracted_preds[:, seq_no], extracted_labels[:, seq_no], weight=weight[seq_no])
        return loss
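
# `batched_bincount` is imported from utils.utils and not shown here. Based on
# how it is called above (labels transposed to [seq_length, num_extracted],
# dim=1, num_classes), it is assumed to count labels independently per row,
# roughly like this hypothetical sketch:
#
#   def batched_bincount(x, dim, num_classes):  # x: [rows, cols]; dim assumed to be 1
#       counts = torch.zeros(x.size(0), num_classes, dtype=torch.long, device=x.device)
#       counts.scatter_add_(1, x, torch.ones_like(x))
#       return counts  # [rows, num_classes]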
class CBCrossEntropy(nn.Module):
    """Class-balanced cross entropy with one weight vector for the whole batch."""

    def __init__(self, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.beta = beta
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1).bool()
            preds = preds.view(-1, preds.size(-1))
            # minlength keeps the weight vector at num_classes entries even if
            # some classes are absent from the batch.
            bincount = labels.view(-1)[mask].bincount(minlength=preds.size(-1))
            weight = (1 - self.beta) / (1 - self.beta ** bincount + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            bincount = labels.view(-1).bincount(minlength=preds.size(-1))
            weight = (1 - self.beta) / (1 - self.beta ** bincount + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss
class WeightedCrossEntropy(nn.Module):
    def __init__(self, is_sequential: bool = True, norm: str = "min"):
        super().__init__()
        self.is_sequential = is_sequential
        if norm == "min":
            self.norm = torch.min
        elif norm == "max":
            self.norm = torch.max
        else:
            raise NotImplementedError(f"Unknown norm: {norm}")

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1).bool()
            preds = preds.view(-1, preds.size(-1))
            bincount = labels.view(-1)[mask].bincount(minlength=preds.size(-1))
            # Inverse-frequency weights, normalized by the min (or max) class count.
            weight = self.norm(bincount) / (bincount + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            bincount = labels.view(-1).bincount(minlength=preds.size(-1))
            weight = self.norm(bincount) / (bincount + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss
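
# Note on `norm` (an observation about F.nll_loss, not from the source):
# "min" and "max" produce weight vectors that differ only by a global scale
# (min(count)/count vs. max(count)/count). With the default reduction="mean",
# F.nll_loss divides by the summed target weights, so that scale cancels and
# both choices yield the same loss; they differ only under reduction="sum"
# or "none".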
class CrossEntropy(nn.Module):
    """Plain (unweighted) cross entropy over valid positions."""

    def __init__(self, is_sequential: bool = True):
        super().__init__()
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            mask = pad_mask.view(-1).bool()
            preds = preds.view(-1, preds.size(-1))
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask])
        else:
            loss = F.nll_loss(preds, labels.squeeze(-1))
        return loss
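
if __name__ == "__main__":
    # Smoke test with made-up shapes (a hypothetical example; the "seq_cbce"
    # variant additionally requires utils.utils.batched_bincount to be importable).
    torch.manual_seed(0)
    batch_size, max_seq_len, num_classes = 4, 5, 3
    preds = F.log_softmax(torch.randn(batch_size, max_seq_len, num_classes), dim=-1)
    labels = torch.randint(0, num_classes, (batch_size, max_seq_len))
    lengths = torch.tensor([5, 3, 4, 5])
    pad_mask = torch.arange(max_seq_len).expand(batch_size, -1) < lengths.unsqueeze(1)
    for weight_type in ("seq_cbce", "cbce", "wce", "ce"):
        criterion = GeneralCrossEntropy(weight_type=weight_type)
        print(weight_type, criterion(preds, labels, pad_mask).item())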