import torch
from torch import nn


def get_loss(name):
    """Return a margin-loss module selected by name.

    Args:
        name: ``"cosface"`` or ``"arcface"``.

    Returns:
        A freshly constructed ``CosFace`` or ``ArcFace`` module using its
        default ``s`` and ``m``.

    Raises:
        ValueError: If ``name`` is not one of the supported names.
    """
    if name == "cosface":
        return CosFace()
    if name == "arcface":
        return ArcFace()
    raise ValueError(
        "unknown loss name {!r}; expected 'cosface' or 'arcface'".format(name)
    )


def _margin_matrix(cosine, label, margin):
    """Build the per-row margin matrix for the rows that carry a label.

    Rows whose label equals ``-1`` are treated as "no target class" and are
    skipped.

    Args:
        cosine: 2-D tensor; dimension 1 is indexed by the values in ``label``.
        label: 1-D integer tensor with one entry per row of ``cosine``;
            ``-1`` marks rows to skip.
        margin: Scalar written at each selected row's target column.

    Returns:
        ``(index, m_hot)`` where ``index`` holds the positions of the rows
        with ``label != -1`` and ``m_hot`` has shape
        ``(len(index), cosine.size(1))``, zero everywhere except ``margin``
        at each row's target column.
    """
    index = torch.where(label != -1)[0]
    m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
    m_hot.scatter_(1, label[index, None], margin)
    return index, m_hot


class CosFace(nn.Module):
    """Additive cosine margin: ``s * (cos(theta) - m)`` on the target class."""

    def __init__(self, s=64.0, m=0.40):
        """Store the scale ``s`` and the cosine margin ``m``."""
        super(CosFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, cosine, label):
        """Subtract the margin from target-class cosines, then scale.

        Args:
            cosine: 2-D tensor of cosine values (rows x classes).
                NOTE: the margin is subtracted from this tensor in place.
            label: 1-D integer tensor of target columns; ``-1`` marks rows
                that receive no margin.

        Returns:
            A new tensor equal to the margin-adjusted ``cosine`` times ``s``.
        """
        index, m_hot = _margin_matrix(cosine, label, self.m)
        # In-place update kept on purpose: callers may rely on `cosine`
        # being modified, as in the original implementation.
        cosine[index] -= m_hot
        ret = cosine * self.s
        return ret


class ArcFace(nn.Module):
    """Additive angular margin: ``s * cos(theta + m)`` on the target class."""

    def __init__(self, s=64.0, m=0.5):
        """Store the scale ``s`` and the angular margin ``m``."""
        super(ArcFace, self).__init__()
        self.s = s
        self.m = m

    def forward(self, cosine: torch.Tensor, label):
        """Add the angular margin to target-class angles, then scale.

        The whole computation is done in place on ``cosine``, and that same
        tensor object is returned.

        Args:
            cosine: 2-D tensor of cosine values (rows x classes).
            label: 1-D integer tensor of target columns; ``-1`` marks rows
                that receive no margin.

        Returns:
            ``cosine`` itself, now holding ``s * cos(acos(cosine) + margin)``.
        """
        index, m_hot = _margin_matrix(cosine, label, self.m)
        # acos is only defined on [-1, 1]; values that drift just outside that
        # range through floating-point rounding would otherwise become NaN.
        # For in-range inputs the clamp is a no-op.
        cosine.clamp_(-1.0, 1.0)
        cosine.acos_()
        cosine[index] += m_hot
        cosine.cos_().mul_(self.s)
        return cosine