import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from utils.acc import accuracy
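
# AAM-Softmax (additive angular margin softmax, the ArcFace loss of
# Deng et al., 2019): embeddings and class weights are L2-normalised so
# their dot product equals cos(theta); the target-class logit is replaced
# by cos(theta + m) and all logits are scaled by s before cross-entropy.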
class AamSoftmax(nn.Module):
    def __init__(self, nOut, nClasses, margin=0.2, scale=30, easy_margin=False, **kwargs):
        super(AamSoftmax, self).__init__()
        self.test_normalize = True

        self.m = margin
        self.s = scale
        self.in_feats = nOut
        self.weight = torch.nn.Parameter(torch.FloatTensor(nClasses, nOut), requires_grad=True)
        self.ce = nn.CrossEntropyLoss()
        nn.init.xavier_normal_(self.weight, gain=1)

        self.easy_margin = easy_margin
        self.cos_m = math.cos(self.m)
        self.sin_m = math.sin(self.m)

        # Thresholds that keep cos(theta + m) monotonically decreasing for
        # theta in [0, pi]: once theta + m would exceed pi, fall back to the
        # linear penalty cos(theta) - m*sin(m) instead of cos(theta + m).
        self.th = math.cos(math.pi - self.m)
        self.mm = math.sin(math.pi - self.m) * self.m

        print('Initialised AAMSoftmax margin %.3f scale %.3f' % (self.m, self.s))
    def forward(self, x, label):
        assert x.size()[0] == label.size()[0]
        assert x.size()[1] == self.in_feats

        # cos(theta): cosine similarity between normalised embeddings
        # and normalised class weight vectors
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # sin(theta), clamped for numerical safety
        sine = torch.sqrt((1.0 - torch.mul(cosine, cosine)).clamp(0, 1))
        # cos(theta + m) via the angle-addition identity
        phi = cosine * self.cos_m - sine * self.sin_m

        if self.easy_margin:
            # apply the margin only where cos(theta) > 0
            phi = torch.where(cosine > 0, phi, cosine)
        else:
            # keep the target logit monotonic in theta: beyond the
            # threshold, use the linear fallback cos(theta) - m*sin(m)
            phi = torch.where((cosine - self.th) > 0, phi, cosine - self.mm)

        # substitute cos(theta + m) for the target-class logit only
        one_hot = torch.zeros_like(cosine)
        one_hot.scatter_(1, label.view(-1, 1), 1)
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        output = output * self.s

        loss = self.ce(output, label)
        prec1 = accuracy(output.detach(), label.detach(), topk=(1,))[0]
        return loss, prec1
if __name__ == "__main__":
    x = torch.randn(32, 512)
    y = torch.randint(1000, size=(32,))
    print(x.shape, y.shape)

    criterion = AamSoftmax(512, 1000)
    nloss, prec1 = criterion(x, y)
    print(nloss, prec1)
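
    # A minimal sketch of a single optimisation step over this loss head
    # alone (illustrative only; in practice the embeddings x would come
    # from an upstream encoder trained jointly with this head).
    optimizer = torch.optim.SGD(criterion.parameters(), lr=0.1)
    optimizer.zero_grad()
    nloss.backward()
    optimizer.step()
    print('loss after one step:', criterion(x, y)[0].item())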