# route-explainer/models/loss_functions.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.utils import batched_bincount


class GeneralCrossEntropy(nn.Module):
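    """Common entry point that dispatches to one of the loss variants below.

    weight_type selects "seq_cbce", "cbce", "wce", or "ce". All variants are
    built on F.nll_loss, so `preds` are expected to be log-probabilities
    (e.g., the output of log_softmax), not raw logits.
    """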
    def __init__(self, weight_type: str, beta: float = 0.99, is_sequential: bool = True):
super().__init__()
self.weight_type = weight_type
self.beta = beta
if weight_type == "seq_cbce":
            assert is_sequential, "seq_cbce assumes sequential (padded) inputs"
self.loss_func = SeqCBCrossEntropy(beta=beta)
elif weight_type == "cbce":
self.loss_func = CBCrossEntropy(beta=beta, is_sequential=is_sequential)
elif weight_type == "wce":
self.loss_func = WeightedCrossEntropy(is_sequential=is_sequential)
elif weight_type == "ce":
self.loss_func = CrossEntropy(is_sequential=is_sequential)
        else:
            raise NotImplementedError(f"unknown weight_type: {weight_type}")

def forward(self,
preds: torch.Tensor,
labels: torch.Tensor,
pad_mask: torch.Tensor = None):
return self.loss_func(preds, labels, pad_mask)


class SeqCBCrossEntropy(nn.Module):
    def __init__(self, beta: float = 0.99):
        super().__init__()
        self.beta = beta

def forward(self,
preds: torch.Tensor,
labels: torch.Tensor,
pad_mask: torch.Tensor):
"""
Sequential Class-alanced Cross Entropy Loss (Our proposal)
Parameters
-----------
preds: torch.Tensor [batch_size, max_seq_length, num_classes]
labels: torch.Tensor [batch_size, max_seq_length]
pad_mask: torch.Tensor [batch_size, max_seq_length]
Returns
-------
loss: torch.Tensor [1]
"""
        seq_length_batch = pad_mask.sum(-1)               # [batch_size]
        seq_length_list = torch.unique(seq_length_batch)  # [num_unique_seq_length]
        batch_size = preds.size(0)
        loss = 0
        # Group samples by their unpadded length, then accumulate a
        # class-balanced loss per step within each group.
        for seq_length in seq_length_list:
            extracted_batch = (seq_length_batch == seq_length)  # [batch_size]
            extracted_preds = preds[extracted_batch]    # [num_extracted, max_seq_length, num_classes]
            extracted_labels = labels[extracted_batch]  # [num_extracted, max_seq_length]
            extracted_batch_size = extracted_labels.size(0)
            # Per-step label counts: [max_seq_length, num_classes]
            counts = batched_bincount(extracted_labels.T, 1, extracted_preds.size(-1))
            # Class-balanced weights (Cui et al., CVPR 2019): (1 - beta) / (1 - beta^n)
            weight = (1 - self.beta) / (1 - self.beta**counts + 1e-8)
            for seq_no in range(seq_length.item()):
                # Each group's step loss is weighted by the group's share of the batch.
                loss += (extracted_batch_size / batch_size) * F.nll_loss(
                    extracted_preds[:, seq_no], extracted_labels[:, seq_no], weight=weight[seq_no]
                )
        return loss
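
# For reference: a minimal sketch of what `batched_bincount` (imported from
# utils.utils) is assumed to do, inferred from its call site above (input
# [max_seq_length, num_extracted], dim=1, max_value=num_classes). The real
# helper may differ; this sketch stays commented out so it does not shadow
# the import.
#
#     def batched_bincount(x, dim, max_value):
#         counts = torch.zeros(*x.shape[:dim], max_value, dtype=torch.long, device=x.device)
#         counts.scatter_add_(dim, x, torch.ones_like(x))
#         return counts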


class CBCrossEntropy(nn.Module):
    def __init__(self, beta: float = 0.99, is_sequential: bool = True):
        super().__init__()
        self.beta = beta
        self.is_sequential = is_sequential

    def forward(self,
                preds: torch.Tensor,
                labels: torch.Tensor,
                pad_mask: torch.Tensor = None):
        if self.is_sequential:
            # Flatten the sequence dimension and drop padded steps.
            mask = pad_mask.view(-1)
            preds = preds.view(-1, preds.size(-1))
            # minlength keeps one weight per class even if a class is absent
            # from the batch (F.nll_loss needs num_classes weights).
            counts = labels.view(-1)[mask].bincount(minlength=preds.size(-1))
            weight = (1 - self.beta) / (1 - self.beta**counts + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            counts = labels.view(-1).bincount(minlength=preds.size(-1))
            weight = (1 - self.beta) / (1 - self.beta**counts + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss
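
# Numeric sanity check of the class-balanced weighting with beta = 0.99
# (approximate values):
#   n = 1   -> (1 - 0.99) / (1 - 0.99**1)   = 1.0
#   n = 10  -> (1 - 0.99) / (1 - 0.99**10)  ~ 0.105
#   n = 100 -> (1 - 0.99) / (1 - 0.99**100) ~ 0.016
# Rare classes are up-weighted, and the weights saturate as counts grow
# (the "effective number of samples" effect).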


class WeightedCrossEntropy(nn.Module):
    def __init__(self, is_sequential: bool = True, norm: str = "min"):
        super().__init__()
        self.is_sequential = is_sequential
        if norm == "min":
            self.norm = torch.min
        elif norm == "max":
            self.norm = torch.max
        else:
            raise NotImplementedError(f"unknown norm: {norm}")

def forward(self,
preds: torch.Tensor,
labels: torch.Tensor,
pad_mask: torch.Tensor = None):
        if self.is_sequential:
            # Flatten the sequence dimension and drop padded steps.
            mask = pad_mask.view(-1)
            preds = preds.view(-1, preds.size(-1))
            counts = labels.view(-1)[mask].bincount()
            # Inverse-frequency weights, normalized by the min (or max) count.
            weight = self.norm(counts) / (counts + 1e-8)
            loss = F.nll_loss(preds[mask], labels.view(-1)[mask], weight=weight)
        else:
            counts = labels.view(-1).bincount()
            weight = self.norm(counts) / (counts + 1e-8)
            loss = F.nll_loss(preds, labels.squeeze(-1), weight=weight)
        return loss


class CrossEntropy(nn.Module):
    def __init__(self, is_sequential: bool = True):
        super().__init__()
        self.is_sequential = is_sequential

def forward(self,
preds: torch.Tensor,
labels: torch.Tensor,
pad_mask: torch.Tensor = None):
if self.is_sequential:
mask = pad_mask.view(-1)
preds = preds.view(-1, preds.size(-1))
loss = F.nll_loss(preds[mask], labels.view(-1)[mask])
else:
loss = F.nll_loss(preds, labels.squeeze(-1))
return loss
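

# A minimal, self-contained usage sketch (shapes and values are assumptions,
# not part of the original module). `preds` must hold log-probabilities
# because every loss above is built on F.nll_loss. The "seq_cbce" variant is
# omitted since it also exercises utils.utils.batched_bincount; note that
# running this file directly still requires utils.utils to be importable.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch_size, max_seq_length, num_classes = 4, 6, 3
    preds = F.log_softmax(torch.randn(batch_size, max_seq_length, num_classes), dim=-1)
    # Deterministic labels so every class occurs among the valid steps.
    labels = (torch.arange(batch_size * max_seq_length) % num_classes).view(batch_size, max_seq_length)
    # True marks valid (non-padded) steps; per-sample lengths 6, 4, 5, 3.
    lengths = torch.tensor([6, 4, 5, 3])
    pad_mask = torch.arange(max_seq_length).unsqueeze(0) < lengths.unsqueeze(1)
    for weight_type in ("ce", "wce", "cbce"):
        criterion = GeneralCrossEntropy(weight_type=weight_type)
        print(weight_type, criterion(preds, labels, pad_mask).item())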