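# Spectrogram discriminator: an LSGAN critic over 2D spectrogram patches whose
# intermediate activations additionally provide a feature-matching loss.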
import torch
import torch.nn as nn


def weights_init_D(m):
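    """Kaiming-normal init for conv layers; unit weight / zero bias for batch norm."""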
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
    elif classname.find('BatchNorm') != -1:
        nn.init.constant_(m.weight, 1)
        nn.init.constant_(m.bias, 0)


class SpectrogramDiscriminator(torch.nn.Module):
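    """LSGAN-style critic on spectrograms with one entry point per player:
    calc_discriminator_loss for the critic update and calc_generator_feedback
    (adversarial + feature-matching loss) for the generator update."""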
    def __init__(self):
        super().__init__()
        self.D = DiscriminatorNet()
        self.D.apply(weights_init_D)

    def _generator_feedback(self, data_generated, data_real):
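        """Generator-side loss: L1 feature matching over the critic's activations
        plus an LSGAN term that drives fake scores toward 1. The critic's
        parameters are frozen so only the generator receives gradients."""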
        for p in self.D.parameters():
            p.requires_grad = False  # freeze critic

        score_fake, fmap_fake = self.D(data_generated)
        _, fmap_real = self.D(data_real)

        feature_matching_loss = 0.0
        for feat_fake, feat_real in zip(fmap_fake, fmap_real):
            feature_matching_loss += nn.functional.l1_loss(feat_fake, feat_real.detach())

        # LSGAN generator term: drive the critic's score on fakes toward 1 ("real")
        adversarial_loss = nn.functional.mse_loss(score_fake, torch.ones_like(score_fake), reduction="mean")

        return feature_matching_loss + adversarial_loss

    def _discriminator_loss(self, data_generated, data_real):
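        """LSGAN critic loss: fake scores are driven toward 0, real scores toward 1."""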
        for p in self.D.parameters():
            p.requires_grad = True  # unfreeze critic
        self.D.train()

        score_fake, _ = self.D(data_generated)
        score_real, _ = self.D(data_real)

        # LSGAN critic terms: fake scores toward 0, real scores toward 1
        loss_fake = nn.functional.mse_loss(score_fake, torch.zeros_like(score_fake), reduction="mean")
        loss_real = nn.functional.mse_loss(score_real, torch.ones_like(score_real), reduction="mean")
        discr_loss = loss_fake + loss_real

        return discr_loss

    def calc_discriminator_loss(self, data_generated, data_real):
        # detach so the critic update does not backpropagate into the generator
        return self._discriminator_loss(data_generated.detach(), data_real)

    def calc_generator_feedback(self, data_generated, data_real):
        return self._generator_feedback(data_generated, data_real)


class DiscriminatorNet(nn.Module):
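    """Fully convolutional critic over (B, 1, T, F) spectrograms: the strided
    convolutions downsample the frequency axis, and a linear head maps the
    flattened output map to a single realness score per sample."""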
    def __init__(self):
        super().__init__()
        self.filters = nn.ModuleList([
            nn.utils.weight_norm(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
            nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
            nn.utils.weight_norm(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
        ])

        self.out = nn.utils.weight_norm(nn.Conv2d(32, 1, 3, 1, 1))

        self.fc = nn.Linear(900, 1)  # This has to be changed every time the window length changes; it would be nice if this could be done dynamically.
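        # A dynamic alternative (untested sketch, assuming PyTorch >= 1.8):
        # nn.LazyLinear(1) infers its input size on the first forward pass,
        # which would remove the hard-coded 900.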

    def forward(self, y):
        feature_maps = [y]  # include the input; used for feature matching
        for conv in self.filters:
            y = conv(y)
            feature_maps.append(y)  # collect pre-activation features
            y = nn.functional.leaky_relu(y, 0.1)
        y = self.out(y)
        feature_maps.append(y)
        y = torch.flatten(y, start_dim=1)  # (B, 1, T, F') -> (B, T * F')
        y = self.fc(y)

        return y, feature_maps


if __name__ == '__main__':
    d = SpectrogramDiscriminator()
    fake = torch.randn([2, 100, 72])  # [Batch, Sequence Length, Spectrogram Buckets]
    real = torch.randn([2, 100, 72])  # [Batch, Sequence Length, Spectrogram Buckets]

    critic_loss = d.calc_discriminator_loss(fake.unsqueeze(1), real.unsqueeze(1))
    generator_loss = d.calc_generator_feedback(fake.unsqueeze(1), real.unsqueeze(1))
    print(critic_loss)
    print(generator_loss)
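
    # A hypothetical training-step sketch (not part of this file; optimizer
    # names are assumptions): with separate optimizers optim_d and optim_g and
    # a generated spectrogram fake, the two losses would typically be used as
    #
    #   optim_d.zero_grad()
    #   d.calc_discriminator_loss(fake.unsqueeze(1), real.unsqueeze(1)).backward()
    #   optim_d.step()
    #
    #   optim_g.zero_grad()
    #   d.calc_generator_feedback(fake.unsqueeze(1), real.unsqueeze(1)).backward()
    #   optim_g.step()
    #
    # calc_discriminator_loss detaches the generated input internally, so the
    # critic step does not backpropagate into the generator.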