# video-dubbing/TTS/tests/vocoder_tests/test_vocoder_parallel_wavegan_discriminator.py
# Uploader: artificialguybr -- "Upload 650 files" (commit 45ee559)
# File size: 1.26 kB; virus scan: none detected
import numpy as np
import torch
from TTS.vocoder.models.parallel_wavegan_discriminator import (
ParallelWaveganDiscriminator,
ResidualParallelWaveganDiscriminator,
)
def test_pwgan_disciminator():
    """Check that ParallelWaveganDiscriminator keeps the input tensor shape.

    Builds a 10-layer discriminator, feeds it a random batch of shape
    ``(4, 1, 64 * 256)`` and asserts that the output has that same shape.
    ``remove_weight_norm()`` is then called on the model so that any
    exception it raises fails the test.
    """
    model = ParallelWaveganDiscriminator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=10,
        conv_channels=64,
        dilation_factor=1,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
        bias=True,
    )
    expected_shape = (4, 1, 64 * 256)
    dummy_x = torch.rand(expected_shape)
    output = model(dummy_x)
    # torch.Size is a tuple subclass, so ``==`` against a tuple already yields
    # a single bool; wrapping it in np.all() added nothing.
    assert output.shape == expected_shape, f"unexpected output shape: {tuple(output.shape)}"
    model.remove_weight_norm()
def test_redisual_pwgan_disciminator():
    """Check that ResidualParallelWaveganDiscriminator keeps the input shape.

    Builds a discriminator with 30 layers in 3 stacks, feeds it a random
    batch of shape ``(4, 1, 64 * 256)`` and asserts that the output has that
    same shape. ``remove_weight_norm()`` is then called on the model so that
    any exception it raises fails the test.
    """
    model = ResidualParallelWaveganDiscriminator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_layers=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        dropout=0.0,
        bias=True,
        nonlinear_activation="LeakyReLU",
        nonlinear_activation_params={"negative_slope": 0.2},
    )
    expected_shape = (4, 1, 64 * 256)
    dummy_x = torch.rand(expected_shape)
    output = model(dummy_x)
    # torch.Size is a tuple subclass, so ``==`` against a tuple already yields
    # a single bool; wrapping it in np.all() added nothing.
    assert output.shape == expected_shape, f"unexpected output shape: {tuple(output.shape)}"
    model.remove_weight_norm()