#! /usr/bin/python
# -*- encoding: utf-8 -*-

import torch
import torch.nn as nn
import lossfunction.softmax as softmax
import lossfunction.angleproto as angleproto

class SoftmaxProto(nn.Module):
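    """Combined loss: softmax classification loss plus angular prototypical
    loss over pairs of utterance embeddings (nPerSpeaker = 2)."""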

    def __init__(self, nOut, nClasses):
        super(SoftmaxProto, self).__init__()

        self.test_normalize = True

        self.softmax = softmax.Softmax(nOut, nClasses)
        self.angleproto = angleproto.AngleProto()

        print('Initialised SoftmaxPrototypical Loss')

    def forward(self, x, label=None):

        if x.size()[1] != 2:
            # Regroup flattened embeddings into pairs of utterances per speaker
            # (2 is nPerSpeaker).
            x = x.reshape(-1, 2, x.size()[-1])

        assert x.size()[1] == 2

        # Softmax classification loss over all utterances; each speaker label is
        # repeated for its 2 utterances.
        nlossS, prec1 = self.softmax(x.reshape(-1, x.size()[-1]), label.repeat_interleave(2))

        # Angular prototypical loss over the paired embeddings (no labels needed).
        nlossP, _ = self.angleproto(x, None)
        # print("lossP:", nlossP, "nlossS:", nlossS)
        # lossP:0.6678 nlossS:13.6913

        # Note: prec1 (softmax top-1 precision) is computed but not returned here.
        # return nlossS + nlossP, prec1
        return nlossS + nlossP
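

# Usage sketch (not part of the original file): a minimal check of the loss,
# assuming lossfunction.softmax.Softmax and lossfunction.angleproto.AngleProto
# from this repo are importable and return (loss, precision) as used above.
# The sizes below (nOut=512, nClasses=1000, 8 speakers) are illustrative only.
if __name__ == "__main__":
    criterion = SoftmaxProto(nOut=512, nClasses=1000)
    x = torch.randn(8, 2, 512)               # 8 speakers, 2 utterances each, 512-dim embeddings
    label = torch.randint(0, 1000, (8,))     # one class label per speaker
    loss = criterion(x, label)
    print('loss:', loss.item())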