Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
# @Time : 2022/2/17 6:05 PM
# @Author : JianingWang
# @File : loss
import torch | |
from torch import nn | |
import torch.nn.functional as F | |
class FocalLoss(nn.Module):
    """Multi-class focal loss.

    Per element, computes ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the
    softmax probability assigned to the target class. The per-element values are
    then reduced by ``F.nll_loss``, so ``weight``, ``ignore_index`` and
    ``reduction`` follow ``nll_loss`` semantics. With ``gamma=0`` the result is
    identical to cross-entropy.

    Args:
        gamma: focusing exponent applied to ``(1 - p_t)``; larger values
            down-weight examples whose target probability is already high.
        weight: optional per-class rescaling weights of length C, as a tensor or
            a plain sequence. It is converted to the input's device and dtype on
            every call, so it does not need to be moved by hand.
        ignore_index: target value whose elements do not contribute to the loss.
        reduction: ``"mean"`` (default), ``"sum"`` or ``"none"``; passed to
            ``F.nll_loss``.
    """

    def __init__(self, gamma=2, weight=None, ignore_index=-100, reduction="mean"):
        super().__init__()
        self.gamma = gamma
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, input, target):
        """Compute the focal loss.

        Args:
            input: unnormalized logits, shape [N, C] (class dimension is dim 1).
            target: integer class indices, shape [N].

        Returns:
            The reduced loss (a scalar for "mean"/"sum", shape [N] for "none").
        """
        logpt = F.log_softmax(input, dim=1)
        pt = torch.exp(logpt)
        # Modulate every class's log-probability; nll_loss then selects the
        # target class and negates it, giving -(1 - p_t)^gamma * log(p_t).
        focal_logpt = (1 - pt) ** self.gamma * logpt

        weight = self.weight
        if weight is not None:
            # nll_loss requires the weight on the same device and with the same
            # dtype as its input; convert here instead of failing at call time.
            weight = torch.as_tensor(weight, dtype=logpt.dtype, device=logpt.device)

        # getattr keeps instances pickled before `reduction` existed working.
        reduction = getattr(self, "reduction", "mean")
        return F.nll_loss(
            focal_logpt,
            target,
            weight=weight,
            ignore_index=self.ignore_index,
            reduction=reduction,
        )