# speaker_verification/lossfunction/aamsoftmaxproto.py
# Author: xiaoxuezi | commit 875baeb | 834 bytes
import torch
import torch.nn as nn
import lossfunction.aamsoftmax as aamsoftmax
import lossfunction.angleproto as angleproto
class AamSoftmaxProto(nn.Module):
    """Unweighted sum of an AAM-softmax loss and an angular prototypical loss.

    ``forward`` expects embeddings grouped in pairs, i.e. ``x`` of shape
    ``(batch, 2, feat)``, plus one class label per pair, ``label`` of shape
    ``(batch,)``. The AAM-softmax term scores every embedding individually
    (both members of a pair share the pair's label). The prototypical term
    receives the pairs unchanged.

    Args:
        nOut: forwarded unchanged to ``aamsoftmax.AamSoftmax``; presumably
            the embedding size, which should equal ``feat`` above.
        nClasses: forwarded unchanged to ``aamsoftmax.AamSoftmax``.
        margin: forwarded unchanged to ``aamsoftmax.AamSoftmax``.
        scale: forwarded unchanged to ``aamsoftmax.AamSoftmax``.
    """

    def __init__(self, nOut, nClasses, margin, scale):
        super().__init__()
        # Not read anywhere in this class.
        # NOTE(review): presumably consulted by evaluation code elsewhere
        # (e.g. to normalise embeddings before scoring) -- confirm with caller.
        self.test_normalize = True
        self.aamsoftmax = aamsoftmax.AamSoftmax(nOut, nClasses, margin, scale)
        self.angleproto = angleproto.AngleProto()
        print('Initialised AamSoftmaxPrototypical Loss')

    def forward(self, x, label=None):
        """Compute the combined loss for a batch of embedding pairs.

        Args:
            x: tensor of shape ``(batch, 2, feat)``.
            label: class-index tensor of shape ``(batch,)``. Required; the
                ``None`` default is kept only for signature compatibility.

        Returns:
            Tuple ``(loss, prec1)``. ``loss`` is the AAM-softmax loss plus
            the prototypical loss. ``prec1`` is the second output of the
            AAM-softmax term (presumably top-1 precision).

        Raises:
            ValueError: if ``x`` is not 3-D with size 2 along dim 1, or if
                ``label`` is None.
        """
        # Explicit checks instead of ``assert`` so they survive ``python -O``.
        if x.dim() != 3 or x.size(1) != 2:
            raise ValueError(
                f"expected x of shape (batch, 2, feat), got {tuple(x.shape)}"
            )
        if label is None:
            raise ValueError("AamSoftmaxProto.forward requires class labels")

        # (batch, 2, feat) -> (2*batch, feat). Row order is
        # [p0u0, p0u1, p1u0, p1u1, ...], so each label must be repeated in
        # place (repeat_interleave, not repeat) to stay aligned with its rows.
        flat = x.reshape(-1, x.size(-1))
        nlossS, prec1 = self.aamsoftmax(flat, label.repeat_interleave(2))

        # The prototypical term consumes the pairs directly and is called
        # without labels; its second output is discarded.
        nlossP, _ = self.angleproto(x, None)

        # NOTE: the terms are summed without weighting. One previously logged
        # run showed lossP ~0.67 vs nlossS ~13.69, so the softmax term dominated
        # the sum there.
        return nlossS + nlossP, prec1