# coding=utf-8
# Copyright 2021 The IDEA Authors. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch.nn import functional as F


class FocalLoss(torch.nn.Module):
    """Multi-class focal loss.

    The log-probabilities are scaled by ``(1 - p) ** gamma`` before being
    passed to ``F.nll_loss``, so confidently predicted classes contribute
    less to the loss.

    Args:
        gamma: Exponent of the ``(1 - p)`` modulating factor.
        weight: Optional per-class weight forwarded to ``F.nll_loss``.
        ignore_index: Target value forwarded to ``F.nll_loss`` as
            ``ignore_index``.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: Logits of shape ``[N, C]``.
            target: Class indices of shape ``[N]``.

        Returns:
            The loss reduced with ``F.nll_loss``'s default reduction.
        """
        log_probs = F.log_softmax(input, dim=1)
        probs = torch.exp(log_probs)
        # Down-weight each class log-probability by (1 - p) ** gamma.
        modulated = (1 - probs) ** self.gamma * log_probs
        return F.nll_loss(
            modulated,
            target,
            self.weight,
            ignore_index=self.ignore_index,
        )


# Label-smoothed cross entropy, used to reduce overfitting.
class LabelSmoothingCorrectionCrossEntropy(torch.nn.Module):
    """Label-smoothed cross entropy plus a task-specific correction term.

    The result is::

        eps / C * smoothing + (1 - eps) * nll + correction

    where ``smoothing`` is the negated sum of all log-probabilities,
    ``nll`` is ``F.nll_loss`` on the log-probabilities, and ``correction``
    is a non-differentiable offset derived from the predicted and target
    labels (see ``forward``).

    Args:
        eps: Smoothing factor; also scales the correction term.
        reduction: ``'mean'``, ``'sum'`` or any other value accepted by
            ``F.nll_loss`` (any other value leaves the smoothing term
            unreduced per sample).
        ignore_index: Target value forwarded to ``F.nll_loss`` as
            ``ignore_index``.
    """

    # Task-specific magnitudes of the correction term; their derivation is
    # not documented in this file.
    _BONUS = 0.5945275813408382
    _PENALTY = 1 / 0.32447699714575207

    def __init__(self, eps=0.1, reduction='mean', ignore_index=-100):
        super(LabelSmoothingCorrectionCrossEntropy, self).__init__()
        self.eps = eps
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, output, target):
        """Compute the smoothed loss with the correction term.

        Args:
            output: Logits; softmax is taken over the last dimension and
                argmax over dimension 1, so ``[N, C]`` is expected.
            target: Class indices with one entry per sample.

        Returns:
            A tensor: scalar for ``'mean'``/``'sum'``, otherwise whatever
            shape the unreduced terms broadcast to.
        """
        num_classes = output.size()[-1]
        log_preds = F.log_softmax(output, dim=-1)

        # Smoothing term: negated log-probabilities summed over classes.
        # NOTE: this term does not exclude samples equal to ignore_index.
        if self.reduction == 'sum':
            smoothing = -log_preds.sum()
        else:
            smoothing = -log_preds.sum(dim=-1)
            if self.reduction == 'mean':
                smoothing = smoothing.mean()

        # Task-specific correction, evaluated for EVERY sample. The previous
        # loop ran over range(num_classes) while indexing per-sample
        # tensors, which raised IndexError for batches smaller than the
        # number of classes and skipped samples beyond it.
        labels_hat = torch.argmax(output, dim=1)
        label_sum = labels_hat + target
        label_gap = (labels_hat - target).abs()
        # sum == 0: no correction; sum == 1 with |diff| == 1: no correction;
        # sum == 1 otherwise: bonus; any other sum: penalty.
        bonus_mask = (label_sum == 1) & (label_gap != 1)
        penalty_mask = (label_sum != 0) & (label_sum != 1)
        # NOTE(review): targets equal to ignore_index still take part in
        # the correction, as they did before — confirm this is intended.
        num_samples = max(label_sum.numel(), 1)  # avoid 0/0 on empty batch
        correction = self.eps * (
            penalty_mask.sum() * self._PENALTY
            - bonus_mask.sum() * self._BONUS
        ) / num_samples
        correction = correction.to(log_preds.dtype)

        nll = F.nll_loss(
            log_preds,
            target,
            reduction=self.reduction,
            ignore_index=self.ignore_index,
        )
        return (
            smoothing * self.eps / num_classes
            + (1 - self.eps) * nll
            + correction
        )