Spaces:
Runtime error
Runtime error
| #! /usr/bin/python | |
| # -*- encoding: utf-8 -*- | |
| import torch | |
| import torch.nn as nn | |
| import lossfunction.softmax as softmax | |
| import lossfunction.angleproto as angleproto | |
class SoftmaxProto(nn.Module):
    """Sum of a softmax classification loss and an angular prototypical loss.

    Embeddings are grouped as ``nPerSpeaker`` (= 2) consecutive utterances per
    speaker. The softmax term scores every utterance against its speaker
    label; the angular prototypical term is computed on the grouped
    embeddings without labels.
    """

    # Utterances per speaker this loss expects ("nPerSpeaker").
    nPerSpeaker = 2

    def __init__(self, nOut, nClasses):
        """Build the two component losses.

        Args:
            nOut: forwarded to ``softmax.Softmax`` (presumably the embedding
                dimension -- TODO confirm against that class).
            nClasses: forwarded to ``softmax.Softmax`` (presumably the number
                of speaker classes).
        """
        super(SoftmaxProto, self).__init__()
        self.test_normalize = True
        self.softmax = softmax.Softmax(nOut, nClasses)
        self.angleproto = angleproto.AngleProto()
        print('Initialised SoftmaxPrototypical Loss')

    def forward(self, x, label=None):
        """Return ``softmax_loss + angleproto_loss`` for a batch of embeddings.

        Args:
            x: embedding tensor. A 3-D tensor shaped ``(N, 2, D)`` is used
                as-is; any other shape is regrouped row-major into
                ``(-1, 2, D)``, i.e. consecutive embeddings are treated as
                the two utterances of one speaker.
            label: one class label per speaker (per row of the grouped
                tensor). Required despite the ``None`` default.

        Returns:
            The sum of the two losses. The softmax term's second output
            (unpacked as ``prec1`` in earlier versions) is deliberately not
            returned.

        Raises:
            ValueError: if ``label`` is None.
            RuntimeError: from ``reshape`` if ``x`` cannot be split into
                groups of ``nPerSpeaker`` embeddings.
        """
        if label is None:
            raise ValueError('SoftmaxProto.forward requires speaker labels')

        n_per = self.nPerSpeaker
        # Require 3-D input: a 2-D (N, 2) tensor would otherwise pass the
        # size check and be treated as already grouped.
        if x.dim() != 3 or x.size(1) != n_per:
            x = x.reshape(-1, n_per, x.size(-1))

        # Flatten to one row per utterance; repeat each speaker label once
        # per utterance so labels line up with the row-major flattening.
        nlossS, _ = self.softmax(x.reshape(-1, x.size(-1)),
                                 label.repeat_interleave(n_per))
        nlossP, _ = self.angleproto(x, None)

        # Development note: observed magnitudes were lossP ~ 0.6678 and
        # nlossS ~ 13.6913; the two terms are nevertheless summed unweighted.
        return nlossS + nlossP