#!/usr/bin/env python3
# Copyright 2021 Tomoki Hayashi
# MIT License (https://opensource.org/licenses/MIT)
"""Test code for StyleMelGAN modules."""
import logging
import numpy as np
import pytest
import torch

from parallel_wavegan.losses import DiscriminatorAdversarialLoss
from parallel_wavegan.losses import GeneratorAdversarialLoss
from parallel_wavegan.losses import MultiResolutionSTFTLoss
from parallel_wavegan.models import StyleMelGANDiscriminator
from parallel_wavegan.models import StyleMelGANGenerator

from test_parallel_wavegan import make_mutli_reso_stft_loss_args

logging.basicConfig(
    level=logging.DEBUG,
    format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
)


def make_style_melgan_generator_args(**kwargs):
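    """Make default arguments for StyleMelGANGenerator (override via kwargs)."""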
    defaults = dict(
        in_channels=128,
        aux_channels=80,
        channels=64,
        out_channels=1,
        kernel_size=9,
        dilation=2,
        bias=True,
        noise_upsample_scales=[11, 2, 2, 2],
        noise_upsample_activation="LeakyReLU",
        noise_upsample_activation_params={"negative_slope": 0.2},
        upsample_scales=[2, 2, 2, 2, 2, 2, 2, 2, 1],
        upsample_mode="nearest",
        gated_function="softmax",
        use_weight_norm=True,
    )
    defaults.update(kwargs)
    return defaults


def make_style_melgan_discriminator_args(**kwargs):
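    """Make default arguments for StyleMelGANDiscriminator (override via kwargs)."""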
    defaults = dict(
        repeats=2,
        window_sizes=[512, 1024, 2048, 4096],
        pqmf_params=[
            [1, None, None, None],
            [2, 62, 0.26700, 9.0],
            [4, 62, 0.14200, 9.0],
            [8, 62, 0.07949, 9.0],
        ],
        discriminator_params={
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 16,
            "max_downsample_channels": 32,
            "bias": True,
            "downsample_scales": [4, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.2},
            "pad": "ReflectionPad1d",
            "pad_params": {},
        },
        use_weight_norm=True,
    )
    defaults.update(kwargs)
    return defaults


@pytest.mark.parametrize(
"dict_d",
[
{"repeats": 1},
{"repeats": 4},
],
)
def test_style_melgan_discriminator(dict_d):
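    """Check StyleMelGANDiscriminator forward pass with adversarial loss."""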
    batch_size = 4
    batch_length = 2 ** 14
    args_d = make_style_melgan_discriminator_args(**dict_d)
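    # dummy input waveform with shape (batch_size, 1, batch_length)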
    y = torch.randn(batch_size, 1, batch_length)
    model_d = StyleMelGANDiscriminator(**args_d)
    gen_adv_criterion = GeneratorAdversarialLoss()
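    # the discriminator is expected to return a list of outputs (one per
    # analyzed window), which the adversarial criterion accepts directly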
    outs = model_d(y)
    gen_adv_criterion(outs)


@pytest.mark.parametrize(
"dict_g",
[
{},
{"noise_upsample_scales": [4, 4, 4]},
],
)
def test_style_melgan_generator(dict_g):
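    """Check StyleMelGANGenerator forward and inference."""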
    args_g = make_style_melgan_generator_args(**dict_g)
    batch_size = 4
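    # total upsampling factor: product of noise and aux upsample scales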
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
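    # z: global noise (B, in_channels, 1); c: auxiliary features (B, aux_channels, T')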
    z = torch.randn(batch_size, args_g["in_channels"], 1)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = StyleMelGANGenerator(**args_g)
    model_g(c, z)

    # inference with a batch-free input of shape (T', aux_channels)
    c = torch.randn(
        512,
        args_g["aux_channels"],
    )
    y = model_g.inference(c)
    print(y.shape)


@pytest.mark.parametrize(
"dict_g, dict_d, dict_loss, loss_type",
[
({}, {}, {}, "mse"),
({}, {}, {}, "hinge"),
({"noise_upsample_scales": [4, 4, 4]}, {}, {}, "mse"),
({"gated_function": "sigmoid"}, {}, {}, "mse"),
],
)
def test_style_melgan_trainable(dict_g, dict_d, dict_loss, loss_type):
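    """Check that StyleMelGAN generator and discriminator are trainable."""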
    # setup
    args_g = make_style_melgan_generator_args(**dict_g)
    args_d = make_style_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    batch_size = 4
    batch_length = np.prod(args_g["noise_upsample_scales"]) * np.prod(
        args_g["upsample_scales"]
    )
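    # dummy target waveform (B, 1, T) and matching auxiliary features (B, aux_channels, T')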
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size,
        args_g["aux_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = StyleMelGANGenerator(**args_g)
    model_d = StyleMelGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    gen_adv_criterion = GeneratorAdversarialLoss(loss_type=loss_type)
    dis_adv_criterion = DiscriminatorAdversarialLoss(loss_type=loss_type)
    optimizer_g = torch.optim.Adam(model_g.parameters())
    optimizer_d = torch.optim.Adam(model_d.parameters())

    # check generator trainable
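    # z is omitted here; the generator presumably samples its own noise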
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    adv_loss = gen_adv_criterion(p_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    loss_g = adv_loss + aux_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    p = model_d(y)
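    # detach y_hat so the discriminator update does not backpropagate into the generator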
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()