# video-dubbing/TTS/tests/vocoder_tests/test_vocoder_wavernn.py
# Repository page metadata: artificialguybr, "Upload 650 files", commit 45ee559, 1.39 kB.
import random
import numpy as np
import torch
from TTS.vocoder.configs import WavernnConfig
from TTS.vocoder.models.wavernn import Wavernn, WavernnArgs
def test_wavernn():
    """Check Wavernn output shapes in the "mold", "gauss" and 4-bit modes.

    A fresh model is built from the same config for each mode and run on one
    shared batch of random inputs; the forward output shape is asserted after
    each run. The model built for the last mode is additionally run through
    ``inference`` on a single random feature matrix, and the shape of the
    returned signal is asserted.
    """
    batch, n_samples = 2, 1280

    config = WavernnConfig()
    config.model_args = WavernnArgs(
        rnn_dims=512,
        fc_dims=512,
        mode="mold",
        mulaw=False,
        pad=2,
        use_aux_net=True,
        use_upsample_net=True,
        upsample_factors=[4, 8, 8],
        feat_dims=80,
        compute_dims=128,
        res_out_dims=128,
        num_res_blocks=10,
    )
    config.audio.hop_length = 256
    config.audio.sample_rate = 2048

    # Random inputs shared by every forward pass below.
    waveform = torch.rand((batch, n_samples))
    feats = torch.rand((batch, 80, 9))
    # The frame count for the inference input is drawn at random on each run.
    n_frames = random.randrange(20, 60)
    single_feats = torch.rand((80, n_frames))

    def build_and_forward():
        # Instantiate a model from the current config and run one forward pass.
        net = Wavernn(config)
        return net, net(waveform, feats)

    # mode: mold (the mode the args above were created with)
    _, out = build_and_forward()
    assert np.all(out.shape == (batch, n_samples, 30)), out.shape

    # mode: gauss
    config.model_args.mode = "gauss"
    _, out = build_and_forward()
    assert np.all(out.shape == (batch, n_samples, 2)), out.shape

    # mode: quantized
    config.model_args.mode = 4
    net, out = build_and_forward()
    assert np.all(out.shape == (batch, n_samples, 2**4)), out.shape

    # NOTE(review): 256 matches the hop_length configured above, so the
    # expected length is presumably hop_length * (frames - 1) -- confirm
    # against Wavernn.inference.
    generated = net.inference(single_feats, True, 5500, 550)
    assert np.all(generated.shape == (256 * (n_frames - 1),))