File size: 4,900 Bytes
59b7eeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import copy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from common.audio import stft
from torch.nn.utils import weight_norm, spectral_norm
from torch.nn import Conv1d
from einops import rearrange

class SpecDiscriminator(nn.Module):
    """Multi-resolution spectrogram discriminator.

    One ``NLayerSpecDiscriminator`` is built per STFT configuration in
    ``stft_params``. ``forward`` computes one spectrogram per configuration with
    ``common.audio.stft`` and runs the matching sub-discriminator on it.

    Args:
        stft_params: Dict with keys ``'fft_sizes'``, ``'hop_sizes'`` and
            ``'win_lengths'`` (lists indexed in parallel; one sub-discriminator
            is created per entry of ``'fft_sizes'``), plus ``'window'``, the
            name of a ``torch`` window function such as ``'hann_window'``.
            Defaults to FFT sizes 1024/2048/512, hops 120/240/50 and window
            lengths 600/1200/240 with a Hann window.
        in_channels, out_channels, kernel_sizes, channels,
        max_downsample_channels, downsample_scales: Forwarded unchanged to
            every ``NLayerSpecDiscriminator``. ``forward`` feeds each one a
            single-channel spectrogram, so ``in_channels`` is expected to be 1.
        use_weight_norm: If True, wrap every conv layer with
            ``torch.nn.utils.weight_norm``.
    """

    # Layer types affected by weight norm and by ``reset_parameters``.
    _CONV_TYPES = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)

    def __init__(self,
                 stft_params=None,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(7, 3),
                 channels=32,
                 max_downsample_channels=512,
                 downsample_scales=(2, 2, 2),
                 use_weight_norm=True,
                 ):
        super().__init__()

        if stft_params is None:
            stft_params = {
                'fft_sizes': [1024, 2048, 512],
                'hop_sizes': [120, 240, 50],
                'win_lengths': [600, 1200, 240],
                'window': 'hann_window'
            }

        self.stft_params = stft_params

        self.model = nn.ModuleDict()
        num_resolutions = len(stft_params['fft_sizes'])
        for i in range(num_resolutions):
            self.model[f"disc_{i}"] = NLayerSpecDiscriminator(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_sizes=kernel_sizes,
                channels=channels,
                max_downsample_channels=max_downsample_channels,
                downsample_scales=downsample_scales,
            )

        # Initialise weights BEFORE applying weight norm. torch.nn.utils.weight_norm
        # replaces the ``weight`` Parameter with ``weight_g``/``weight_v`` and
        # recomputes ``weight`` from them before every forward, so writing to
        # ``m.weight`` after weight norm is applied would be silently discarded.
        self.reset_parameters()
        if use_weight_norm:
            self.apply_weight_norm()

    def forward(self, x):
        """Run every sub-discriminator on its own STFT resolution.

        Args:
            x: Waveform batch; dim 1 is squeezed away before the STFT
                (presumably shaped [B, 1, T] — confirm against callers).

        Returns:
            A list with one entry per STFT resolution, in ``self.model`` order;
            each entry is the list of per-layer outputs of that
            ``NLayerSpecDiscriminator``.
        """
        params = self.stft_params
        window_fn = getattr(torch, params['window'])
        x = x.squeeze(1)

        results = []
        for i, disc in enumerate(self.model.values()):
            win_length = params['win_lengths'][i]
            # Build the window on the input's device so GPU inputs do not
            # receive a CPU window.
            window = window_fn(win_length, device=x.device)
            spec = stft(x, params['fft_sizes'][i], params['hop_sizes'][i],
                        win_length, window=window)  # [B, T, F]
            spec = spec.transpose(1, 2).unsqueeze(1)  # [B, 1, F, T]
            results.append(disc(spec))
        return results

    def remove_weight_norm(self):
        """Remove weight norm from every submodule that has it."""
        def _remove_weight_norm(m):
            try:
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return
        self.apply(_remove_weight_norm)

    def apply_weight_norm(self):
        """Wrap every 1-D/2-D (transposed) conv submodule with weight norm."""
        conv_types = self._CONV_TYPES

        def _apply_weight_norm(m):
            if isinstance(m, conv_types):
                torch.nn.utils.weight_norm(m)
        self.apply(_apply_weight_norm)

    def reset_parameters(self):
        """Re-initialise conv weights from N(0, 0.02).

        Writes to ``m.weight`` directly, so it must be called before weight
        norm is applied. Once weight norm is active, ``weight`` is recomputed
        from ``weight_g``/``weight_v`` and this write would be overwritten.
        """
        conv_types = self._CONV_TYPES

        def _reset_parameters(m):
            if isinstance(m, conv_types):
                m.weight.data.normal_(0.0, 0.02)
        self.apply(_reset_parameters)


class NLayerSpecDiscriminator(nn.Module):
    """Stack of 2-D convolutions applied to a single spectrogram.

    Layers (each stored in ``self.model`` under ``layer_<k>``):

    * ``layer_0``: conv with ``kernel_sizes[0]`` and stride 2, then LeakyReLU(0.2).
    * ``layer_1`` .. ``layer_<len(downsample_scales)>``: one conv per scale ``s``
      in ``downsample_scales`` (kernel ``2*s + 1``, stride ``s``, padding ``s``).
      Each multiplies the channel count by ``s``, capped at
      ``max_downsample_channels``, and is followed by LeakyReLU(0.2).
    * next layer: stride-1 conv with ``kernel_sizes[1]`` doubling the channels
      (same cap), then LeakyReLU(0.2).
    * last layer: stride-1 conv with ``kernel_sizes[1]`` projecting to
      ``out_channels`` (no activation).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels of the final conv's output.
        kernel_sizes: ``(first, last)`` kernel sizes. Both must be odd because
            the layers use symmetric padding ``k // 2``.
        channels: Channels produced by ``layer_0``.
        max_downsample_channels: Upper bound on hidden-layer channel counts.
        downsample_scales: Stride of each intermediate downsampling conv.

    Raises:
        ValueError: If either entry of ``kernel_sizes`` is not odd.
    """

    def __init__(self,
                 in_channels=1,
                 out_channels=1,
                 kernel_sizes=(5, 3),
                 channels=32,
                 max_downsample_channels=512,
                 downsample_scales=(2, 2, 2)):
        super().__init__()

        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        first_kernel, last_kernel = kernel_sizes[0], kernel_sizes[1]
        if first_kernel % 2 != 1 or last_kernel % 2 != 1:
            raise ValueError(
                f"kernel_sizes must both be odd, got {tuple(kernel_sizes)!r}")

        model = nn.ModuleDict()

        model["layer_0"] = nn.Sequential(
            nn.Conv2d(in_channels, channels,
                      kernel_size=first_kernel,
                      stride=2,
                      padding=first_kernel // 2),
            nn.LeakyReLU(0.2, True),
        )

        in_chs = channels
        for idx, scale in enumerate(downsample_scales, start=1):
            out_chs = min(in_chs * scale, max_downsample_channels)
            model[f"layer_{idx}"] = nn.Sequential(
                nn.Conv2d(
                    in_chs,
                    out_chs,
                    kernel_size=scale * 2 + 1,
                    stride=scale,
                    padding=scale,
                ),
                nn.LeakyReLU(0.2, True),
            )
            in_chs = out_chs

        num_down = len(downsample_scales)
        out_chs = min(in_chs * 2, max_downsample_channels)
        model[f"layer_{num_down + 1}"] = nn.Sequential(
            nn.Conv2d(in_chs, out_chs, kernel_size=last_kernel,
                      padding=last_kernel // 2),
            nn.LeakyReLU(0.2, True),
        )

        model[f"layer_{num_down + 2}"] = nn.Conv2d(
            out_chs, out_channels, kernel_size=last_kernel,
            padding=last_kernel // 2)

        self.model = model

    def forward(self, x):
        """Apply every layer in order and collect each layer's output.

        Returns:
            A list with one tensor per layer. The last entry is the output of
            the final conv.
        """
        results = []
        for layer in self.model.values():
            x = layer(x)
            results.append(x)
        return results