import random import numpy as np import torch from TTS.vocoder.configs import WavernnConfig from TTS.vocoder.models.wavernn import Wavernn, WavernnArgs def test_wavernn(): config = WavernnConfig() config.model_args = WavernnArgs( rnn_dims=512, fc_dims=512, mode="mold", mulaw=False, pad=2, use_aux_net=True, use_upsample_net=True, upsample_factors=[4, 8, 8], feat_dims=80, compute_dims=128, res_out_dims=128, num_res_blocks=10, ) config.audio.hop_length = 256 config.audio.sample_rate = 2048 dummy_x = torch.rand((2, 1280)) dummy_m = torch.rand((2, 80, 9)) y_size = random.randrange(20, 60) dummy_y = torch.rand((80, y_size)) # mode: mold model = Wavernn(config) output = model(dummy_x, dummy_m) assert np.all(output.shape == (2, 1280, 30)), output.shape # mode: gauss config.model_args.mode = "gauss" model = Wavernn(config) output = model(dummy_x, dummy_m) assert np.all(output.shape == (2, 1280, 2)), output.shape # mode: quantized config.model_args.mode = 4 model = Wavernn(config) output = model(dummy_x, dummy_m) assert np.all(output.shape == (2, 1280, 2**4)), output.shape output = model.inference(dummy_y, True, 5500, 550) assert np.all(output.shape == (256 * (y_size - 1),))