maxmax20160403's picture
final ver
c24b656
raw
history blame contribute delete
No virus
2.03 kB
import torch
import torchaudio
import typing as T
class MelspecDiscriminator(torch.nn.Module):
    """Mel spectrogram (frequency domain) discriminator.

    A waveform is mapped to a log-magnitude mel spectrogram, passed through a
    stack of 2D convolution blocks (Conv2d -> BatchNorm2d -> GLU), then
    projected to a single channel and averaged over its last two axes.

    Args:
        sample_rate: sample rate (Hz) of the input audio. It configures the mel
            filterbank and sizes the analysis window (25 ms) and hop (10 ms).
            Defaults to 48 kHz.

    Raises:
        ValueError: if the 25 ms window does not fit in the fixed 2048-point
            FFT, or the 10 ms hop rounds down to zero samples.
    """

    def __init__(self, sample_rate: int = 48000) -> None:
        super().__init__()
        # kept under its original public name for code that reads it
        self.SAMPLE_RATE = sample_rate

        n_fft = 2048
        win_length = int(0.025 * sample_rate)  # 25 ms analysis window
        hop_length = int(0.010 * sample_rate)  # 10 ms hop
        # fail early with a clear message instead of inside the STFT
        if hop_length < 1 or win_length > n_fft:
            raise ValueError(
                f"unsupported sample_rate={sample_rate}: the 25 ms window "
                f"({win_length} samples) must fit in n_fft={n_fft} and the "
                f"10 ms hop ({hop_length} samples) must be at least 1 sample"
            )

        # mel filterbank transform (power=1: magnitude rather than power)
        self._melspec = torchaudio.transforms.MelSpectrogram(
            sample_rate=sample_rate,
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            n_mels=128,
            power=1,
        )

        # time-frequency 2D convolutions; stride 1 on the mel axis and 2 on
        # the frame axis. Each conv emits 64 channels and the GLU gates them
        # down to 32, so blocks exchange 32 channels.
        kernel_sizes = [(7, 7), (4, 4), (4, 4), (4, 4)]
        strides = [(1, 2), (1, 2), (1, 2), (1, 2)]
        self._convs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_channels=1 if i == 0 else 32,
                        out_channels=64,
                        kernel_size=k,
                        stride=s,
                        padding=(1, 2),
                        # a conv bias would be cancelled by the BatchNorm below
                        bias=False,
                    ),
                    torch.nn.BatchNorm2d(num_features=64),
                    torch.nn.GLU(dim=1),  # gate over channels: 64 -> 32
                )
                for i, (k, s) in enumerate(zip(kernel_sizes, strides))
            ]
        )

        # output adversarial projection down to a single channel
        self._postnet = torch.nn.Conv2d(
            in_channels=32,
            out_channels=1,
            kernel_size=(15, 3),
            stride=(1, 2),
        )

    def forward(
        self, x: torch.Tensor
    ) -> T.List[T.Tuple[T.List[torch.Tensor], torch.Tensor]]:
        """Score a batch of waveforms.

        Args:
            x: waveforms of shape ``(batch, 1, time)``; a ``(batch, time)``
                tensor is also accepted and treated as single-channel.

        Returns:
            A one-element list ``[(features, score)]``. ``features`` holds the
            output of every conv block in order; ``score`` is the postnet
            output averaged over its last two axes, of shape ``(batch, 1)``.
        """
        # log-scale mel spectrogram; the small offset avoids log(0)
        x = torch.log(self._melspec(x) + 1e-5)

        # a (batch, time) waveform yields (batch, n_mels, frames); add the
        # channel axis that Conv2d/BatchNorm2d require
        if x.dim() == 3:
            x = x.unsqueeze(1)

        # hidden layers, keeping every intermediate feature map
        features: T.List[torch.Tensor] = []
        for block in self._convs:
            x = block(x)
            features.append(x)

        # output projection followed by global average pooling
        x = self._postnet(x)
        x = x.mean(dim=[-2, -1])
        return [(features, x)]