import torch
import torchaudio
import typing as T


class MelspecDiscriminator(torch.nn.Module):
    """Mel spectrogram (frequency domain) discriminator.

    The input waveform is converted to a log-magnitude mel spectrogram and
    passed through a stack of Conv2d -> BatchNorm2d -> GLU blocks, followed
    by a single-channel Conv2d output projection that is averaged over both
    spectrogram axes.

    ``forward`` returns a one-element list ``[(features, score)]``:
    ``features`` holds the output of every hidden conv block and ``score``
    is the pooled output projection.
    """

    # audio sample rate (Hz) the mel transform and window sizes are built for
    SAMPLE_RATE = 48000

    def __init__(self) -> None:
        super().__init__()

        # mel filterbank transform: 25 ms windows with a 10 ms hop, producing
        # a magnitude (power=1) spectrogram with 128 mel bins
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.SAMPLE_RATE,
            n_fft=2048,
            win_length=int(0.025 * self.SAMPLE_RATE),
            hop_length=int(0.010 * self.SAMPLE_RATE),
            n_mels=128,
            power=1,
        )

        # time-frequency 2D convolutions. GLU(dim=1) splits the channel axis
        # in half and gates one half with the other, so every block emits
        # conv_channels // 2 channels, which is what the next layer consumes.
        conv_channels = 64
        glu_channels = conv_channels // 2
        kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)]
        # stride 1 along the mel (height) axis, 2 along the time (width) axis
        strides = [(1, 2), (1, 2), (1, 2), (1, 2)]
        self._convs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_channels=1 if i == 0 else glu_channels,
                        out_channels=conv_channels,
                        kernel_size=k,
                        stride=s,
                        padding=(1, 2),
                        # a conv bias would be cancelled by BatchNorm2d's
                        # mean subtraction, so it is omitted
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(num_features=conv_channels),
                    torch.nn.GLU(dim=1),
                )
                for i, (k, s) in enumerate(zip(kernel_sizes, strides))
            ]
        )

        # output adversarial projection down to a single channel
        self._postnet = torch.nn.Conv2d(
            in_channels=glu_channels,
            out_channels=1,
            kernel_size=(15, 3),
            stride=(1, 2),
        )

    def forward(
        self, x: torch.Tensor
    ) -> T.List[T.Tuple[T.List[torch.Tensor], torch.Tensor]]:
        """Score a batch of waveforms.

        Args:
            x: waveform batch. Assumed to be shaped (batch, 1, samples) and
                sampled at ``SAMPLE_RATE``, so that the mel transform yields
                the 4-D (batch, 1, n_mels, frames) tensor required by the
                first Conv2d (``in_channels=1``) and the BatchNorm2d layers.

        Returns:
            A one-element list ``[(features, score)]`` where ``features`` is
            the list of hidden conv-block outputs (one per block, in order)
            and ``score`` is the output projection averaged over its last
            two axes, shaped (batch, 1).
        """
        # log-compress the mel magnitudes; the small offset avoids log(0)
        x = torch.log(self._melspec(x) + 1e-5)

        # run the conv stack, keeping every intermediate activation
        features = []
        for conv in self._convs:
            x = conv(x)
            features.append(x)

        # project to one channel, then global-average over mel and time axes
        score = self._postnet(x).mean(dim=[-2, -1])

        return [(features, score)]