File size: 2,030 Bytes
9791162
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import torch
import torchaudio
import typing as T


class MelspecDiscriminator(torch.nn.Module):
    """Mel spectrogram (frequency domain) discriminator.

    The input waveform is converted to a log-magnitude mel spectrogram,
    passed through a stack of strided 2D convolutions (each followed by
    batch normalization and a gated linear unit), and finally projected to
    a single channel that is averaged into one score per batch item.

    The intermediate activations of every convolution stage are returned
    alongside the score.
    """

    def __init__(self) -> None:
        super().__init__()
        self.SAMPLE_RATE = 48000

        # mel filterbank transform: 25 ms analysis window, 10 ms hop,
        # power=1 selects a magnitude (not power) spectrogram
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.SAMPLE_RATE,
            n_fft=2048,
            win_length=int(0.025 * self.SAMPLE_RATE),
            hop_length=int(0.010 * self.SAMPLE_RATE),
            n_mels=128,
            power=1,
        )

        # time-frequency 2D convolutions
        # GLU(dim=1) halves the channel axis, so each stage emits
        # `conv_channels` feature maps and hands `gated_channels` onward.
        conv_channels = 64
        gated_channels = conv_channels // 2
        kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)]
        # stride 1 on the first spatial axis, 2 on the second; with the
        # MelSpectrogram output layout (..., n_mels, frames) this only
        # downsamples along the frame axis
        strides = [(1, 2), (1, 2), (1, 2), (1, 2)]
        self._convs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_channels=1 if i == 0 else gated_channels,
                        out_channels=conv_channels,
                        kernel_size=k,
                        stride=s,
                        padding=(1, 2),
                        # no conv bias; the following BatchNorm2d has its
                        # own learnable shift
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(num_features=conv_channels),
                    torch.nn.GLU(dim=1),
                )
                for i, (k, s) in enumerate(zip(kernel_sizes, strides))
            ]
        )

        # output adversarial projection down to a single channel
        self._postnet = torch.nn.Conv2d(
            in_channels=gated_channels,
            out_channels=1,
            kernel_size=(15, 3),
            stride=(1, 2),
        )

    def forward(
        self, x: torch.Tensor
    ) -> T.List[T.Tuple[T.List[torch.Tensor], torch.Tensor]]:
        """Score a batch of waveforms.

        Args:
            x: Input waveform tensor. NOTE(review): the first convolution
                takes one input channel and BatchNorm2d needs 4D input, so
                this is presumably shaped (batch, 1, samples) - confirm
                against the caller.

        Returns:
            A one-element list containing a ``(feature_maps, score)``
            tuple. ``feature_maps`` holds the output of every convolution
            stage in order; ``score`` is the output projection averaged
            over its last two axes.
        """
        # apply the log-scale mel spectrogram transform
        # (the small offset keeps log() finite on silent bins)
        x = torch.log(self._melspec(x) + 1e-5)

        # compute hidden layers and collect the per-stage feature maps
        feature_maps = []
        for stage in self._convs:
            x = stage(x)
            feature_maps.append(x)

        # apply the output projection and global average pooling
        score = self._postnet(x).mean(dim=[-2, -1])

        # wrapped in a list: callers receive a sequence of
        # (feature_maps, score) pairs, here with exactly one entry
        return [(feature_maps, score)]