#! /usr/bin/python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
import lossfunction.softmax as softmax
import lossfunction.angleproto as angleproto
class SoftmaxProto(nn.Module):

    def __init__(self, nOut, nClasses):
        super(SoftmaxProto, self).__init__()

        self.test_normalize = True

        # Classification branch: softmax (cross-entropy) loss over nClasses speakers.
        self.softmax = softmax.Softmax(nOut, nClasses)
        # Metric branch: angular prototypical loss over utterance pairs.
        self.angleproto = angleproto.AngleProto()

        print('Initialised SoftmaxPrototypical Loss')
    def forward(self, x, label=None):
        # If the embeddings arrive flattened, regroup them into pairs of two per speaker.
        if x.size()[1] != 2:
            # 2 is nPerSpeaker (two utterances per speaker)
            x = x.reshape(-1, 2, x.size()[-1])
        assert x.size()[1] == 2

        # Softmax (cross-entropy) loss over all embeddings; each label is repeated
        # twice because every speaker contributes two utterances.
        nlossS, prec1 = self.softmax(x.reshape(-1, x.size()[-1]), label.repeat_interleave(2))

        # Angular prototypical loss on the paired embeddings of shape (batch, 2, nOut).
        nlossP, _ = self.angleproto(x, None)

        # print("lossP:", nlossP, "nlossS:", nlossS)
        # Observed magnitudes: lossP: 0.6678, nlossS: 13.6913
        # return nlossS + nlossP, prec1
        return nlossS + nlossP
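

# A minimal usage sketch, not part of the original file: it assumes the
# lossfunction package (softmax.Softmax, angleproto.AngleProto) is importable
# and that both sub-losses follow the (loss, precision) return convention used
# in forward() above. The sizes below (4 speakers, 2 utterances each, 512-dim
# embeddings, 100 classes) are illustrative only.
if __name__ == '__main__':
    nOut, nClasses, nSpeakers = 512, 100, 4

    criterion = SoftmaxProto(nOut, nClasses)

    # Two embeddings per speaker: shape (nSpeakers, 2, nOut).
    embeddings = torch.randn(nSpeakers, 2, nOut, requires_grad=True)
    labels = torch.randint(0, nClasses, (nSpeakers,))

    loss = criterion(embeddings, labels)
    loss.backward()
    print('combined loss:', loss.item())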