Spaces:
Running
Running
import ast

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm
class DiscriminatorR(torch.nn.Module):
    """Discriminator that scores a waveform from one magnitude spectrogram.

    The input is converted to an STFT magnitude at a single resolution and
    passed through a stack of normalized 2-D convolutions.

    Args:
        hp: Hyper-parameter object. Reads ``hp.mpd.lReLU_slope`` (leaky-ReLU
            negative slope) and ``hp.mrd.use_spectral_norm`` (truthy selects
            spectral norm, otherwise weight norm).
        resolution: ``(n_fft, hop_length, win_length)`` triple used by
            :meth:`spectrogram`.
    """

    def __init__(self, hp, resolution):
        super().__init__()
        self.resolution = resolution
        # NOTE(review): the slope is read from the `mpd` config section even
        # though this is the `mrd` discriminator - confirm this is intended.
        self.LRELU_SLOPE = hp.mpd.lReLU_slope
        norm_f = spectral_norm if hp.mrd.use_spectral_norm else weight_norm

        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """Run the discriminator on a waveform batch.

        Args:
            x: Waveform tensor; dim 1 is squeezed away in :meth:`spectrogram`,
                so presumably shaped ``[B, 1, T]`` - confirm against callers.

        Returns:
            Tuple ``(fmap, score)``. ``fmap`` lists the activation after each
            conv + leaky-ReLU followed by the raw ``conv_post`` output;
            ``score`` is the ``conv_post`` output flattened from dim 1 onward.
        """
        fmap = []

        x = self.spectrogram(x)
        x = x.unsqueeze(1)  # add the channel axis expected by Conv2d
        for conv in self.convs:
            x = conv(x)
            x = F.leaky_relu(x, self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return fmap, x

    def spectrogram(self, x):
        """Return the STFT magnitude of ``x`` at this discriminator's resolution.

        The signal is reflect-padded by ``(n_fft - hop_length) / 2`` samples
        on both sides and transformed with ``center=False``. No window tensor
        is passed to ``torch.stft``.

        Returns:
            Real-valued magnitude tensor, ``[B, F, TT]``.
        """
        n_fft, hop_length, win_length = self.resolution
        pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (pad, pad), mode='reflect')
        x = x.squeeze(1)
        # return_complex=False is deprecated for real inputs and is slated to
        # become an error; take the complex STFT and its modulus instead,
        # which equals the L2 norm over the old trailing (real, imag) axis.
        spec = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=False,
            return_complex=True,
        )  # [B, F, TT] complex
        mag = torch.abs(spec)  # [B, F, TT]

        return mag
class MultiResolutionDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorR`` modules, one per configured STFT resolution.

    Args:
        hp: Hyper-parameter object. ``hp.mrd.resolutions`` is either a string
            holding a Python literal such as
            ``"[(1024, 120, 600), (2048, 240, 1200)]"`` or an already-parsed
            sequence of ``(n_fft, hop_length, win_length)`` triples. ``hp`` is
            also forwarded unchanged to every ``DiscriminatorR``.

    Raises:
        ValueError: If the string is not a plain Python literal, or an entry
            does not have exactly three elements.
        SyntaxError: If the string is not valid Python syntax.
    """

    def __init__(self, hp):
        super().__init__()
        raw = hp.mrd.resolutions
        if isinstance(raw, str):
            # This used to be eval(), which executes arbitrary code found in
            # the config. literal_eval only accepts literals (numbers, tuples,
            # lists, ...), which is all a resolution list needs.
            raw = ast.literal_eval(raw)
        self.resolutions = [tuple(resolution) for resolution in raw]
        for resolution in self.resolutions:
            if len(resolution) != 3:
                raise ValueError(
                    "each resolution must be (n_fft, hop_length, win_length), "
                    f"got {resolution!r}"
                )
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(hp, resolution) for resolution in self.resolutions]
        )

    def forward(self, x):
        """Apply every sub-discriminator to ``x``.

        Returns:
            A list with one ``(feat, score)`` tuple per resolution, in the
            order of ``self.resolutions``.
        """
        return [disc(x) for disc in self.discriminators]  # [(feat, score), (feat, score), (feat, score)]