Spaces:
Runtime error
Runtime error
| import torch | |
| import torch.nn as nn | |
| import lossfunction.aamsoftmax as aamsoftmax | |
| import lossfunction.angleproto as angleproto | |
class AamSoftmaxProto(nn.Module):
    """Loss that sums an AamSoftmax head loss and an AngleProto loss.

    Both components come from the project's ``lossfunction`` package. The head
    is ``aamsoftmax.AamSoftmax`` (presumably an additive-angular-margin softmax,
    per its name). The second term is ``angleproto.AngleProto``. The returned
    loss is the unweighted sum of the two.

    Args:
        nOut: Forwarded unchanged to ``aamsoftmax.AamSoftmax``
            (presumably the embedding size).
        nClasses: Forwarded unchanged to ``aamsoftmax.AamSoftmax``
            (presumably the number of training classes).
        margin: Forwarded unchanged to ``aamsoftmax.AamSoftmax``.
        scale: Forwarded unchanged to ``aamsoftmax.AamSoftmax``.
    """

    def __init__(self, nOut, nClasses, margin, scale):
        super(AamSoftmaxProto, self).__init__()
        # Not read inside this class; presumably consumed by evaluation code
        # elsewhere (e.g. to normalize embeddings at test time) -- confirm.
        self.test_normalize = True
        self.aamsoftmax = aamsoftmax.AamSoftmax(nOut, nClasses, margin, scale)
        self.angleproto = angleproto.AngleProto()
        print('Initialised AamSoftmaxPrototypical Loss')

    def forward(self, x, label=None):
        """Compute the combined loss for a batch of paired embeddings.

        Args:
            x: Tensor whose second dimension must be exactly 2. Assumed to be
                ``(batch, 2, nOut)``, i.e. two embeddings per sample -- TODO
                confirm against the data loader.
            label: Tensor with one label per entry along ``x``'s first
                dimension. Each label is repeated ``x.size(1)`` times so both
                embeddings of a sample share it. Required despite the ``None``
                default, which is kept only for interface compatibility.

        Returns:
            Tuple ``(loss, prec1)``. ``loss`` is the AamSoftmax loss plus the
            AngleProto loss. ``prec1`` is the second value returned by the
            AamSoftmax head (presumably top-1 precision, per its name).

        Raises:
            ValueError: If ``x`` has fewer than 2 dims, if ``x.size(1) != 2``,
                or if ``label`` is ``None``.
        """
        # Explicit checks instead of `assert`: asserts vanish under `python -O`,
        # and a mismatched shape would otherwise fail obscurely inside the head.
        if x.dim() < 2 or x.size(1) != 2:
            raise ValueError(
                f"AamSoftmaxProto expects x.size(1) == 2, got shape {tuple(x.shape)}"
            )
        if label is None:
            raise ValueError("AamSoftmaxProto.forward requires a `label` tensor")

        n_per_sample = x.size(1)
        # Flatten to one embedding per row. For a 3-D (batch, 2, nOut) input the
        # row-major order keeps each sample's embeddings adjacent, matching
        # repeat_interleave's label layout [l0, l0, l1, l1, ...].
        nlossS, prec1 = self.aamsoftmax(
            x.reshape(-1, x.size(-1)), label.repeat_interleave(n_per_sample)
        )
        # AngleProto receives the un-flattened tensor and no labels.
        nlossP, _ = self.angleproto(x, None)
        # One value pair logged during development: lossP ~0.6678 vs
        # nlossS ~13.6913, so the AamSoftmax term dominated the unweighted sum.
        return nlossS + nlossP, prec1