File size: 1,479 Bytes
9b2107c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from torch import nn

from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator


class MelganMultiscaleDiscriminator(nn.Module):
    """Apply several ``MelganDiscriminator`` instances to successively pooled input.

    ``num_scales`` discriminators are created with identical settings. In
    ``forward`` the first one receives the input as given, and every following
    one receives the previous input after one more pass through a shared
    ``nn.AvgPool1d`` layer.

    Args:
        in_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        out_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        num_scales: Number of discriminators to create.
        kernel_sizes: Forwarded unchanged to every ``MelganDiscriminator``.
        base_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        max_channels: Forwarded unchanged to every ``MelganDiscriminator``.
        downsample_factors: Forwarded unchanged to every
            ``MelganDiscriminator``.
        pooling_kernel_size: ``kernel_size`` of the shared ``nn.AvgPool1d``.
        pooling_stride: ``stride`` of the shared ``nn.AvgPool1d``.
        pooling_padding: ``padding`` of the shared ``nn.AvgPool1d``.
        groups_denominator: Forwarded unchanged to every
            ``MelganDiscriminator``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        super().__init__()

        # One independent discriminator per scale; all share the same
        # hyper-parameters but each owns its own weights.
        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        # count_include_pad=False keeps the zero padding out of the average.
        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size,
            stride=pooling_stride,
            padding=pooling_padding,
            count_include_pad=False,
        )

    def forward(self, x):
        """Run every discriminator, pooling the input between consecutive ones.

        Args:
            x: Input handed to the first discriminator and to ``nn.AvgPool1d``.
                NOTE(review): presumably shaped (batch, in_channels, time) --
                confirm against callers.

        Returns:
            A ``(scores, feats)`` tuple of two lists. Each list holds one entry
            per discriminator, in order: the first and second element of that
            discriminator's return value respectively.
        """
        scores = []
        feats = []
        last_idx = len(self.discriminators) - 1
        for idx, disc in enumerate(self.discriminators):
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
            # Pool only when another discriminator will consume the result;
            # pooling after the final one would be discarded work.
            if idx < last_idx:
                x = self.pooling(x)
        return scores, feats