File size: 1,945 Bytes
3aa4060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
# Adapted from https://github.com/ubisoft/ubisoft-laforge-daft-exprt (Apache License, Version 2.0).
# Implements the gradient reversal layer from Y. Ganin and V. Lempitsky,
# "Unsupervised Domain Adaptation by Backpropagation", ICML 2015.

import torch
import torch.nn as nn

from torch.autograd import Function
from torch.nn.utils import weight_norm


class GradientReversalFunction(Function):
    ''' Autograd function behind the gradient reversal layer.
        Forward returns a copy of the input unchanged.
        Backward returns the incoming gradient multiplied by -lambda_.
    '''

    @staticmethod
    def forward(ctx, x, lambda_):
        ''' Copy x and remember lambda_ for the backward pass. '''
        # lambda_ is kept as a plain attribute on the context (not a saved tensor)
        ctx.lambda_ = lambda_
        identity = x.clone()
        return identity

    @staticmethod
    def backward(ctx, grads):
        ''' Return (-lambda_ * grads, None): one entry per forward() argument. '''
        # build the scale from the incoming gradient so it shares its dtype and device
        scale = grads.new_tensor(ctx.lambda_)
        reversed_grads = -scale * grads
        # lambda_ itself receives no gradient
        return reversed_grads, None


class GradientReversal(torch.nn.Module):
    ''' Gradient Reversal Layer
            Y. Ganin, V. Lempitsky,
            "Unsupervised Domain Adaptation by Backpropagation",
            in ICML, 2015.
        Module wrapper around GradientReversalFunction:
        the forward pass is the identity function,
        the backward pass multiplies upstream gradients by -lambda (i.e. gradients are reversed).

        lambda_reversal -- scale applied to the reversed gradient (stored as self.lambda_)
    '''

    def __init__(self, lambda_reversal=1):
        super().__init__()
        # plain attribute: not a parameter or buffer, so it is not part of the state dict
        self.lambda_ = lambda_reversal

    def forward(self, x):
        ''' Pass x through the reversal function using the stored scale. '''
        scale = self.lambda_
        return GradientReversalFunction.apply(x, scale)


class SpeakerClassifier(nn.Module):
    ''' Convolutional classifier placed behind a gradient reversal layer.
        The input first goes through GradientReversal, then through three
        weight-normalised Conv1d layers (ReLU between them), and the result
        is averaged over the last (length) axis.

        idim            -- number of input channels
        odim            -- number of output channels (presumably the number of speakers)
        hidden_dim      -- channels of the two hidden convolutions (default 1024)
        kernel_size     -- odd kernel size shared by all convolutions (default 5)
        lambda_reversal -- scale of the reversed gradient (default 1)

        Raises ValueError if kernel_size is not a positive odd integer.
    '''

    def __init__(self, idim, odim, hidden_dim=1024, kernel_size=5, lambda_reversal=1):
        super(SpeakerClassifier, self).__init__()
        # only odd kernels keep the sequence length unchanged with symmetric padding
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                'kernel_size must be a positive odd integer, got {}'.format(kernel_size)
            )
        padding = kernel_size // 2
        # layer order (and therefore the state-dict keys) matches the fixed-size original
        self.classifier = nn.Sequential(
            GradientReversal(lambda_reversal=lambda_reversal),
            weight_norm(nn.Conv1d(idim, hidden_dim, kernel_size=kernel_size, padding=padding)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(hidden_dim, hidden_dim, kernel_size=kernel_size, padding=padding)),
            nn.ReLU(),
            weight_norm(nn.Conv1d(hidden_dim, odim, kernel_size=kernel_size, padding=padding))
        )

    def forward(self, x):
        ''' Forward function of Speaker Classifier:
            x = (B, idim, len)
            returns (B, odim)
        '''
        # pass through classifier
        outputs = self.classifier(x)  # (B, odim, len)
        # average over the length axis
        outputs = torch.mean(outputs, dim=-1)  # (B, odim)
        return outputs