# Spaces: Runtime error
#!/usr/bin/env python3
# -*- coding: utf-8 -*-

# Copyright 2019 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
import logging

import numpy as np
import pytest
import torch

from parallel_wavegan.losses import DiscriminatorAdversarialLoss
from parallel_wavegan.losses import FeatureMatchLoss
from parallel_wavegan.losses import GeneratorAdversarialLoss
from parallel_wavegan.losses import MultiResolutionSTFTLoss
from parallel_wavegan.models import MelGANGenerator
from parallel_wavegan.models import MelGANMultiScaleDiscriminator
from parallel_wavegan.models import ParallelWaveGANDiscriminator
from parallel_wavegan.models import ResidualParallelWaveGANDiscriminator
from parallel_wavegan.optimizers import RAdam

from test_parallel_wavegan import make_discriminator_args
from test_parallel_wavegan import make_mutli_reso_stft_loss_args
from test_parallel_wavegan import make_residual_discriminator_args
# Configure the root logger when this module is imported: DEBUG level and
# above, each record prefixed with timestamp, module name and line number.
logging.basicConfig(
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    level=logging.DEBUG,
)
def make_melgan_generator_args(**kwargs):
    """Return keyword arguments for constructing a ``MelGANGenerator``.

    A fixed set of default values is built on every call, then any keyword
    given by the caller overrides a default (or is added as a new key).

    Args:
        **kwargs: Values that take precedence over the defaults.

    Returns:
        dict: A freshly built argument dict, safe for the caller to mutate.
    """
    base = {
        "in_channels": 80,
        "out_channels": 1,
        "kernel_size": 7,
        "channels": 512,
        "bias": True,
        "upsample_scales": [8, 8, 2, 2],
        "stack_kernel_size": 3,
        "stacks": 3,
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.2},
        "pad": "ReflectionPad1d",
        "pad_params": {},
        "use_final_nonlinear_activation": True,
        "use_weight_norm": True,
        "use_causal_conv": False,
    }
    # Later unpacking wins, so caller-supplied values replace the defaults.
    return {**base, **kwargs}
def make_melgan_discriminator_args(**kwargs):
    """Return keyword arguments for constructing a ``MelGANMultiScaleDiscriminator``.

    A fixed set of default values is built on every call, then any keyword
    given by the caller overrides a default (or is added as a new key).

    Args:
        **kwargs: Values that take precedence over the defaults.

    Returns:
        dict: A freshly built argument dict, safe for the caller to mutate.
    """
    base = {
        "in_channels": 1,
        "out_channels": 1,
        "scales": 3,
        "downsample_pooling": "AvgPool1d",
        # follow the official implementation setting
        "downsample_pooling_params": {
            "kernel_size": 4,
            "stride": 2,
            "padding": 1,
            "count_include_pad": False,
        },
        "kernel_sizes": [5, 3],
        "channels": 16,
        "max_downsample_channels": 1024,
        "bias": True,
        "downsample_scales": [4, 4, 4, 4],
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.2},
        "pad": "ReflectionPad1d",
        "pad_params": {},
        "use_weight_norm": True,
    }
    # Later unpacking wins, so caller-supplied values replace the defaults.
    return {**base, **kwargs}
@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
        ({"kernel_size": 3}, {}, {}),
        ({"upsample_scales": [4, 4, 4, 4]}, {}, {}),
        ({"bias": False}, {}, {}),
        ({"use_weight_norm": False}, {}, {}),
        ({"use_causal_conv": True}, {}, {}),
    ],
)
def test_melgan_trainable(dict_g, dict_d, dict_loss):
    """Run one generator step and one discriminator step without errors.

    Pairs a ``MelGANGenerator`` with a ``ParallelWaveGANDiscriminator``.
    The generator is updated with the adversarial loss plus the
    multi-resolution STFT loss (spectral convergence + magnitude). The
    discriminator is updated with real/fake adversarial losses on detached
    generator output.

    Args:
        dict_g: Overrides passed to ``make_melgan_generator_args``.
        dict_d: Overrides passed to ``make_discriminator_args``.
        dict_loss: Overrides passed to ``make_mutli_reso_stft_loss_args``.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning length chosen so the upsampled output matches batch_length.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = ParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach so the discriminator loss does not backpropagate into the generator.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
        ({"kernel_size": 3}, {}, {}),
        ({"use_causal_conv": True}, {}, {}),
    ],
)
def test_melgan_trainable_with_residual_discriminator(dict_g, dict_d, dict_loss):
    """Run one generator step and one discriminator step without errors.

    Pairs a ``MelGANGenerator`` with a ``ResidualParallelWaveGANDiscriminator``.
    The generator is updated with the adversarial loss plus the
    multi-resolution STFT loss (spectral convergence + magnitude). The
    discriminator is updated with real/fake adversarial losses on detached
    generator output.

    Args:
        dict_g: Overrides passed to ``make_melgan_generator_args``.
        dict_d: Overrides passed to ``make_residual_discriminator_args``.
        dict_loss: Overrides passed to ``make_mutli_reso_stft_loss_args``.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning length chosen so the upsampled output matches batch_length.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach so the discriminator loss does not backpropagate into the generator.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
@pytest.mark.parametrize(
    "dict_g, dict_d, dict_loss",
    [
        ({}, {}, {}),
        ({"kernel_size": 3}, {}, {}),
        ({"use_causal_conv": True}, {}, {}),
    ],
)
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    """Run one generator step and one discriminator step without errors.

    Pairs a ``MelGANGenerator`` with a ``MelGANMultiScaleDiscriminator``.
    The generator is updated with three terms:
    - the adversarial loss;
    - the multi-resolution STFT loss;
    - a feature-matching loss between discriminator outputs on fake and real audio.

    The discriminator is updated with real/fake adversarial losses on detached
    generator output.

    Args:
        dict_g: Overrides passed to ``make_melgan_generator_args``.
        dict_d: Overrides passed to ``make_melgan_discriminator_args``.
        dict_loss: Overrides passed to ``make_mutli_reso_stft_loss_args``.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # Conditioning length chosen so the upsampled output matches batch_length.
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss()
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    # Real-audio outputs are only a target for feature matching; no graph needed.
    with torch.no_grad():
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
    # Detach so the discriminator loss does not backpropagate into the generator.
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
@pytest.mark.parametrize(
    "dict_g",
    [
        {"use_causal_conv": True},
        {"kernel_size": 3, "use_causal_conv": True},
        {"kernel_size": 5, "use_causal_conv": True},
        {"kernel_size": 3, "stack_kernel_size": 5, "use_causal_conv": True},
    ],
)
def test_causal_melgan(dict_g):
    """Check that a causal MelGAN generator ignores future conditioning frames.

    The second half of the conditioning sequence is replaced with fresh noise.
    The waveform samples produced from the untouched first half must then be
    exactly equal between the original and the perturbed run.

    Args:
        dict_g: Overrides passed to ``make_melgan_generator_args``; expected to
            enable ``use_causal_conv``, otherwise the causality check fails.
    """
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    upsampling_factor = np.prod(args_g["upsample_scales"])
    c = torch.randn(
        batch_size, args_g["in_channels"], batch_length // upsampling_factor
    )
    model_g = MelGANGenerator(**args_g)

    # Perturb only the "future" (second) half of the conditioning frames.
    half = c.size(-1) // 2
    c_ = c.clone()
    c_[..., half:] = torch.randn(c[..., half:].shape)
    # check not equal
    assert not np.array_equal(c.numpy(), c_.numpy()), "Must be different."

    # check causality; no gradients are needed for this comparison
    with torch.no_grad():
        y = model_g(c)
        y_ = model_g(c_)
    assert y.size(2) == c.size(2) * upsampling_factor
    np.testing.assert_array_equal(
        y[..., : half * upsampling_factor].cpu().numpy(),
        y_[..., : half * upsampling_factor].cpu().numpy(),
    )