video-dubbing / TTS /tests /vocoder_tests /test_vocoder_parallel_wavegan_generator.py
artificialguybr's picture
Upload 650 files
45ee559
raw
history blame
No virus
773 Bytes
import numpy as np
import torch
from TTS.vocoder.models.parallel_wavegan_generator import ParallelWaveganGenerator
def test_pwgan_generator():
    """Check ParallelWaveganGenerator output lengths in forward and inference modes.

    The expected lengths below encode that each conditioning frame yields
    ``prod(upsample_factors)`` output samples (4 * 4 * 4 * 4 = 256 here).
    """
    upsample_factors = [4, 4, 4, 4]
    # Output samples per conditioning frame (hop length) implied by the factors.
    hop_length = int(np.prod(upsample_factors))
    batch_size, num_mels, num_frames = 2, 80, 5

    model = ParallelWaveganGenerator(
        in_channels=1,
        out_channels=1,
        kernel_size=3,
        num_res_blocks=30,
        stacks=3,
        res_channels=64,
        gate_channels=128,
        skip_channels=64,
        aux_channels=num_mels,
        dropout=0.0,
        bias=True,
        use_weight_norm=True,
        upsample_factors=upsample_factors,
    )
    dummy_c = torch.rand((batch_size, num_mels, num_frames))

    # Training-mode forward pass: one hop of samples per conditioning frame.
    output = model(dummy_c)
    expected = (batch_size, 1, num_frames * hop_length)
    # torch.Size compares equal to a plain tuple; a direct comparison is
    # clearer than wrapping the resulting bool in np.all().
    assert tuple(output.shape) == expected, output.shape

    model.remove_weight_norm()
    output = model.inference(dummy_c)
    # NOTE(review): this expectation assumes inference() pads the conditioning
    # with 4 extra frames in total (presumably 2 per side) — confirm against
    # ParallelWaveganGenerator.inference.
    expected = (batch_size, 1, (num_frames + 4) * hop_length)
    assert tuple(output.shape) == expected, output.shape