from torch import nn

from TTS.vocoder.models.melgan_discriminator import MelganDiscriminator


class MelganMultiscaleDiscriminator(nn.Module):
    """Stack of MelGAN discriminators applied to progressively pooled inputs.

    Sub-discriminator ``i`` receives the input after ``i`` applications of a
    single shared ``nn.AvgPool1d`` layer. Each scale therefore sees a
    smoother (and, with ``pooling_stride > 1``, shorter) version of the
    signal than the previous one.

    Args:
        in_channels: Input channels, forwarded to every sub-discriminator.
        out_channels: Output channels of every sub-discriminator.
        num_scales: Number of sub-discriminators, i.e. number of scales.
        kernel_sizes: Forwarded unchanged to each ``MelganDiscriminator``.
        base_channels: Forwarded unchanged to each ``MelganDiscriminator``.
        max_channels: Forwarded unchanged to each ``MelganDiscriminator``.
        downsample_factors: Forwarded unchanged to each ``MelganDiscriminator``.
        pooling_kernel_size: ``kernel_size`` of the between-scale average pooling.
        pooling_stride: ``stride`` of the between-scale average pooling.
        pooling_padding: ``padding`` of the between-scale average pooling.
        groups_denominator: Forwarded unchanged to each ``MelganDiscriminator``.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=1,
        num_scales=3,
        kernel_sizes=(5, 3),
        base_channels=16,
        max_channels=1024,
        downsample_factors=(4, 4, 4),
        pooling_kernel_size=4,
        pooling_stride=2,
        pooling_padding=2,
        groups_denominator=4,
    ):
        super().__init__()

        # Every scale gets its own, independently parameterised discriminator
        # built from the same hyper-parameters.
        self.discriminators = nn.ModuleList(
            [
                MelganDiscriminator(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_sizes=kernel_sizes,
                    base_channels=base_channels,
                    max_channels=max_channels,
                    downsample_factors=downsample_factors,
                    groups_denominator=groups_denominator,
                )
                for _ in range(num_scales)
            ]
        )

        # count_include_pad=False keeps the zero padding from dragging the
        # averages at the signal edges towards zero.
        self.pooling = nn.AvgPool1d(
            kernel_size=pooling_kernel_size,
            stride=pooling_stride,
            padding=pooling_padding,
            count_include_pad=False,
        )

    def forward(self, x):
        """Run every sub-discriminator on its own scale of ``x``.

        Args:
            x: Input batch. Presumably (batch, in_channels, time); TODO
                confirm against ``MelganDiscriminator``.

        Returns:
            tuple[list, list]: ``(scores, feats)``. Each list holds one entry
            per scale, ordered from the original resolution to the most
            pooled one, exactly as returned by each sub-discriminator.
        """
        scores = []
        feats = []
        for idx, disc in enumerate(self.discriminators):
            # Pool only when another scale will consume the result. Pooling
            # after the final discriminator would be computed and discarded.
            if idx > 0:
                x = self.pooling(x)
            score, feat = disc(x)
            scores.append(score)
            feats.append(feat)
        return scores, feats