import torch
import torch.nn as nn

import lossfunction.aamsoftmax as aamsoftmax
import lossfunction.angleproto as angleproto


class AamSoftmaxProto(nn.Module):
    """Combined loss: the sum of an AamSoftmax loss and an AngleProto loss.

    Both sub-losses are evaluated on the same input and their loss values
    are added together in ``forward``.
    """

    def __init__(self, nOut, nClasses, margin, scale):
        """Build the two sub-losses.

        Args:
            nOut: Forwarded unchanged to ``aamsoftmax.AamSoftmax``.
            nClasses: Forwarded unchanged to ``aamsoftmax.AamSoftmax``.
            margin: Forwarded unchanged to ``aamsoftmax.AamSoftmax``.
            scale: Forwarded unchanged to ``aamsoftmax.AamSoftmax``.
        """
        super(AamSoftmaxProto, self).__init__()

        # NOTE(review): this flag is not read anywhere in this class; it is
        # presumably consumed by the surrounding training/evaluation code --
        # confirm against the caller.
        self.test_normalize = True

        self.aamsoftmax = aamsoftmax.AamSoftmax(nOut, nClasses, margin, scale)
        self.angleproto = angleproto.AngleProto()

        print('Initialised AamSoftmaxPrototypical Loss')

    def forward(self, x, label=None):
        """Return the summed loss and the AamSoftmax accuracy value.

        Args:
            x: Tensor whose second dimension has size 2. Presumably laid out
                as (batch, 2, embedding_dim) -- TODO confirm with the caller.
            label: Tensor of labels, one entry per element of the first
                dimension of ``x`` (presumably class indices -- confirm).
                Required, even though the signature defaults it to ``None``
                for interface compatibility.

        Returns:
            A tuple ``(loss, prec1)`` where ``loss`` is the AamSoftmax loss
            plus the AngleProto loss, and ``prec1`` is the second value
            returned by the AamSoftmax module, passed through unchanged.

        Raises:
            AssertionError: If the second dimension of ``x`` is not 2.
            ValueError: If ``label`` is ``None``.
        """
        assert x.size()[1] == 2, (
            "AamSoftmaxProto expects x.size()[1] == 2, got input of size "
            "{}".format(tuple(x.size()))
        )

        # The classification term cannot be computed without labels. Fail
        # with a clear message instead of an AttributeError on NoneType.
        if label is None:
            raise ValueError(
                "AamSoftmaxProto.forward requires 'label'; got None"
            )

        # Collapse all leading dimensions so every row is classified on its
        # own; each label is repeated twice so that the flattened rows and
        # the labels stay aligned.
        flat_x = x.reshape(-1, x.size()[-1])
        flat_label = label.repeat_interleave(2)
        nlossS, prec1 = self.aamsoftmax(flat_x, flat_label)

        # AngleProto is called on the un-flattened input with no labels.
        nlossP, _ = self.angleproto(x, None)

        # Sample magnitudes recorded in the original debug comment:
        # lossP: 0.6678, nlossS: 13.6913.
        return nlossS + nlossP, prec1