#!/usr/bin/env python3

# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)

"""Test code for StyleMelGAN modules."""

import logging

import numpy as np
import pytest
import torch

from parallel_wavegan.losses import DiscriminatorAdversarialLoss
from parallel_wavegan.losses import GeneratorAdversarialLoss
from parallel_wavegan.losses import MultiResolutionSTFTLoss
from parallel_wavegan.models import StyleMelGANDiscriminator
from parallel_wavegan.models import StyleMelGANGenerator

from test_parallel_wavegan import make_mutli_reso_stft_loss_args

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_style_melgan_generator_args(**kwargs):
    """Make default arguments for StyleMelGANGenerator, overridable via kwargs."""
    defaults = dict(
        in_channels=128,
        aux_channels=80,
        channels=64,
        out_channels=1,
        kernel_size=9,
        dilation=2,
        bias=True,
        noise_upsample_scales=[11, 2, 2, 2],
        noise_upsample_activation="LeakyReLU",
        noise_upsample_activation_params={"negative_slope": 0.2},
        upsample_scales=[2, 2, 2, 2, 2, 2, 2, 2, 1],
        upsample_mode="nearest",
        gated_function="softmax",
        use_weight_norm=True,
    )
    defaults.update(kwargs)
    return defaults
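

# Note: with the defaults above, a single latent frame is expanded by
# prod(noise_upsample_scales) = 11 * 2 * 2 * 2 = 88 and then by
# prod(upsample_scales) = 2**8 * 1 = 256, i.e. 88 * 256 = 22528 waveform samples.
# The tests below size their dummy batches from these two products.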


def make_style_melgan_discriminator_args(**kwargs):
    """Make default arguments for StyleMelGANDiscriminator, overridable via kwargs."""
    defaults = dict(
        repeats=2,
        window_sizes=[512, 1024, 2048, 4096],
        pqmf_params=[
            [1, None, None, None],
            [2, 62, 0.26700, 9.0],
            [4, 62, 0.14200, 9.0],
            [8, 62, 0.07949, 9.0],
        ],
        discriminator_params={
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 16,
            "max_downsample_channels": 32,
            "bias": True,
            "downsample_scales": [4, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.2},
            "pad": "ReflectionPad1d",
            "pad_params": {},
        },
        use_weight_norm=True,
    )
    defaults.update(kwargs)
    return defaults


@pytest.mark.parametrize("dict_d", [{}])
def test_style_melgan_discriminator(dict_d):
    """Check discriminator forward pass and adversarial loss computation."""
    batch_size = 4
    batch_length = 2 ** 14
    args_d = make_style_melgan_discriminator_args(**dict_d)
    y = torch.randn(batch_size, 1, batch_length)
    model_d = StyleMelGANDiscriminator(**args_d)
    gen_adv_criterion = GeneratorAdversarialLoss()
    outs = model_d(y)
    gen_adv_criterion(outs)


@pytest.mark.parametrize("dict_g", [{}])
def test_style_melgan_generator(dict_g):
    """Check generator forward pass and inference on dummy inputs."""
    args_g = make_style_melgan_generator_args(**dict_g)
    batch_size = 4
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
    z = torch.randn(batch_size, args_g["in_channels"], 1)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = StyleMelGANGenerator(**args_g)
    model_g(c, z)

    # inference
    c = torch.randn(
        512,
        args_g["aux_channels"],
    )
    y = model_g.inference(c)
    print(y.shape)


@pytest.mark.parametrize("loss_type", ["mse", "hinge"])
@pytest.mark.parametrize("dict_g, dict_d, dict_loss", [({}, {}, {})])
def test_style_melgan_trainable(dict_g, dict_d, dict_loss, loss_type):
    """Check that both generator and discriminator can take an optimizer step."""
    # setup
    args_g = make_style_melgan_generator_args(**dict_g)
    args_d = make_style_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    batch_size = 4
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = StyleMelGANGenerator(**args_g)
    model_d = StyleMelGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss(loss_type=loss_type)
    dis_adv_criterion = DiscriminatorAdversarialLoss(loss_type=loss_type)
    optimizer_g = torch.optim.Adam(model_g.parameters())
    optimizer_d = torch.optim.Adam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
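

# One way to run this module directly (assumes pytest is installed and that the
# file providing make_mutli_reso_stft_loss_args, test_parallel_wavegan.py, lives
# alongside it so the import above resolves; adjust the path to your layout):
#
#     pytest -q tests/test_style_melgan.py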