# -*- coding: utf-8 -*-
# @Time    : 2022/2/17 6:05 PM
# @Author  : JianingWang
# @File    : loss
import torch
from torch import nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    """Multi-class focal loss implementation.

    Each sample's negative log-likelihood is scaled by ``(1 - p_t) ** gamma``,
    where ``p_t`` is the softmax probability assigned to its target class, so
    confidently-classified samples contribute less to the loss.

    Args:
        gamma: Focusing exponent. ``gamma=0`` reduces to plain cross-entropy.
        weight: Optional per-class rescaling weights (length C), given as a
            tensor or any sequence accepted by ``torch.as_tensor``.
        ignore_index: Target value whose samples are excluded from the loss.
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``; forwarded
            unchanged to ``F.nll_loss``.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        # Kept as a plain attribute (not a registered buffer) so the module's
        # state_dict layout is unchanged; it is converted to the logits'
        # device/dtype inside forward() instead.
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Unnormalized logits of shape [N, C]; classes on dim 1.
            target: Integer class indices of shape [N].

        Returns:
            A scalar tensor for ``"mean"``/``"sum"`` reduction, otherwise the
            per-sample losses.
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        # Modulate every class's log-prob; nll_loss then selects (and
        # negates) the entry of the target class for each sample.
        focal_logpt = (1 - pt) ** self.gamma * logpt

        weight = self.weight
        if weight is not None:
            # A plain attribute does not follow module.to(device)/.half(), and
            # nll_loss requires the weight to match the input's device and
            # dtype; this also lets callers pass a Python list.
            weight = torch.as_tensor(
                weight, dtype=focal_logpt.dtype, device=focal_logpt.device
            )

        return F.nll_loss(
            focal_logpt,
            target,
            weight,
            ignore_index=self.ignore_index,
            reduction=self.reduction,
        )