maxmax20160403's picture
INIT
10f957b
raw
history blame
No virus
2.19 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import weight_norm, spectral_norm
class DiscriminatorR(torch.nn.Module):
    """Single-resolution spectrogram discriminator.

    The input waveform is turned into a linear-magnitude STFT at one
    ``(n_fft, hop_length, win_length)`` resolution and passed through a
    stack of 2-D convolutions. ``forward`` returns every intermediate
    feature map plus the flattened output of the final convolution.

    Args:
        hp: Hyper-parameter object. Reads ``hp.mpd.lReLU_slope`` (leaky-ReLU
            slope) and ``hp.mrd.use_spectral_norm`` (truthy selects
            ``spectral_norm``, otherwise ``weight_norm`` wraps each conv).
        resolution: ``(n_fft, hop_length, win_length)`` used by the STFT.
    """

    def __init__(self, hp, resolution):
        super().__init__()
        self.resolution = resolution
        # NOTE(review): the slope is read from the ``mpd`` section, not ``mrd``;
        # presumably shared with the period discriminator on purpose - confirm.
        self.LRELU_SLOPE = hp.mpd.lReLU_slope
        norm_f = spectral_norm if hp.mrd.use_spectral_norm else weight_norm
        # Attribute names/order are kept stable so state_dict keys
        # (convs.N.*, conv_post.*) remain checkpoint-compatible.
        self.convs = nn.ModuleList([
            norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
            # stride (1, 2): downsample only along the last (time-frame) axis.
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
            norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
        ])
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x):
        """Score a batch of waveforms at this resolution.

        Args:
            x: Waveform tensor; presumably ``(B, 1, T)`` - ``spectrogram``
                reflect-pads the last axis and squeezes axis 1 (TODO confirm
                against callers).

        Returns:
            ``(fmap, score)`` where ``fmap`` lists each conv layer's output
            after leaky ReLU, followed by the raw ``conv_post`` output, and
            ``score`` is that ``conv_post`` output flattened from dim 1 on.
        """
        fmap = []
        # Add a singleton channel axis so the magnitude map feeds Conv2d.
        x = self.spectrogram(x).unsqueeze(1)  # [B, 1, F, TT]
        for conv in self.convs:
            x = F.leaky_relu(conv(x), self.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return fmap, torch.flatten(x, 1, -1)

    def spectrogram(self, x):
        """Return the STFT magnitude of ``x`` at this discriminator's resolution.

        ``x`` is reflect-padded by ``(n_fft - hop_length) / 2`` samples on
        each side, axis 1 is squeezed, and the STFT is taken with
        ``center=False`` and a rectangular window of ``win_length`` samples.

        Returns:
            Magnitude tensor ``[B, F, TT]`` with ``F = n_fft // 2 + 1``.
        """
        n_fft, hop_length, win_length = self.resolution
        pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (pad, pad), mode='reflect')
        x = x.squeeze(1)
        # Rectangular window: numerically identical to torch.stft's
        # window=None default (also zero-padded to n_fft when shorter), but
        # explicit so newer PyTorch does not warn about a missing window.
        window = torch.ones(win_length, dtype=x.dtype, device=x.device)
        # return_complex=True + .abs() replaces the deprecated
        # return_complex=False / torch.norm(p=2, dim=-1) pair; same magnitude.
        spec = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=window,
            center=False,
            return_complex=True,
        )  # [B, F, TT], complex
        return spec.abs()  # [B, F, TT]
class MultiResolutionDiscriminator(torch.nn.Module):
    """Bank of ``DiscriminatorR`` modules, one per STFT resolution.

    Args:
        hp: Hyper-parameter object, forwarded to every ``DiscriminatorR``.
            ``hp.mrd.resolutions`` holds the ``(n_fft, hop_length,
            win_length)`` triples, either as a Python-expression string
            (evaluated with ``eval``) or as an already-parsed sequence,
            which is used as-is.
    """

    def __init__(self, hp):
        super().__init__()
        resolutions = hp.mrd.resolutions
        if isinstance(resolutions, (str, bytes)):
            # SECURITY(review): eval() executes arbitrary code from the config
            # value. Kept for compatibility with expression-valued configs;
            # consider ast.literal_eval if configs only hold literal tuples.
            resolutions = eval(resolutions)
        self.resolutions = resolutions
        self.discriminators = nn.ModuleList(
            [DiscriminatorR(hp, resolution) for resolution in self.resolutions]
        )

    def forward(self, x):
        """Run every sub-discriminator on ``x``.

        Returns:
            A list with one ``(fmap, score)`` tuple per resolution, in the
            same order as ``self.resolutions``.
        """
        return [disc(x) for disc in self.discriminators]