import torch.nn as nn import torch.nn.functional as F import torch class FocalLoss(nn.Module): def __init__(self, alpha=1, gamma=0, size_average=True, ignore_index=255): super(FocalLoss, self).__init__() self.alpha = alpha self.gamma = gamma self.ignore_index = ignore_index self.size_average = size_average def forward(self, inputs, targets): ce_loss = F.cross_entropy( inputs, targets, reduction='none', ignore_index=self.ignore_index) pt = torch.exp(-ce_loss) focal_loss = self.alpha * (1-pt)**self.gamma * ce_loss if self.size_average: return focal_loss.mean() else: return focal_loss.sum()