tts-vie / TTS /tests /tts_tests /test_overflow.py
Nông Văn Thắng
main
33acd27
import os
import random
import unittest
from copy import deepcopy
import torch
from tests import get_tests_output_path
from TTS.tts.configs.overflow_config import OverflowConfig
from TTS.tts.layers.overflow.common_layers import Encoder, Outputnet, OverflowUtils
from TTS.tts.layers.overflow.decoder import Decoder
from TTS.tts.layers.overflow.neural_hmm import EmissionModel, NeuralHMM, TransitionModel
from TTS.tts.models.overflow import Overflow
from TTS.tts.utils.helpers import sequence_mask
from TTS.utils.audio import AudioProcessor
# pylint: disable=unused-variable
# Seed torch's RNG so the torch-generated tensors are reproducible.
# NOTE: the stdlib `random` module (used by `_create_inputs` for the padded lengths) is not seeded here.
torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
# Run on the first CUDA device when CUDA is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Shared configuration with a 24-symbol vocabulary; several helpers below deepcopy it before changing fields.
config_global = OverflowConfig(num_chars=24)
ap = AudioProcessor.init_from_config(config_global)
# File locations inside the directory returned by `get_tests_output_path()`.
config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")
# Written at import time; `get_model` assigns this path to `config.mel_statistics_parameter_path`.
torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
def _create_inputs(batch_size=8):
    """Create a random, padded batch of token ids and mel frames on `device`.

    The padded text length is drawn from [25, 50] and the padded mel length
    from [50, 80]. Per-item lengths are sorted in descending order and the
    first item is forced to the padded length. Both tensors are multiplied by
    `sequence_mask(lengths)` (presumably zeroing the positions past each
    item's length — confirm against `sequence_mask`).

    Args:
        batch_size: number of items in the batch.

    Returns:
        tuple: `(input_dummy, input_lengths, mel_spec, mel_lengths)` where
        `input_dummy` is `(batch_size, max_len_t)` and `mel_spec` is
        `(batch_size, max_len_m, num_mels)`.
    """
    max_len_t = random.randint(25, 50)
    max_len_m = random.randint(50, 80)

    # Text side: token ids in [0, 24) plus descending lengths.
    input_dummy = torch.randint(0, 24, (batch_size, max_len_t)).long().to(device)
    raw_text_lengths = torch.randint(20, max_len_t, (batch_size,)).long().to(device)
    input_lengths = raw_text_lengths.sort(descending=True).values
    input_lengths[0] = max_len_t
    text_mask = sequence_mask(input_lengths)
    input_dummy = input_dummy * text_mask

    # Mel side: Gaussian frames plus descending lengths.
    num_mels = config_global.audio["num_mels"]
    mel_spec = torch.randn(batch_size, max_len_m, num_mels).to(device)
    raw_mel_lengths = torch.randint(40, max_len_m, (batch_size,)).long().to(device)
    mel_lengths = raw_mel_lengths.sort(descending=True).values
    mel_lengths[0] = max_len_m
    frame_mask = sequence_mask(mel_lengths).unsqueeze(2)
    mel_spec = mel_spec * frame_mask

    return input_dummy, input_lengths, mel_spec, mel_lengths
def get_model(config=None):
    """Instantiate an `Overflow` model on `device`.

    Args:
        config: configuration to build from; `config_global` is used when
            None. The object is modified in place: its
            `mel_statistics_parameter_path` is set to `parameter_path`.

    Returns:
        The model after `.to(device)`.
    """
    cfg = config_global if config is None else config
    # Assigned on the object itself (no copy), exactly like the caller passed it.
    cfg.mel_statistics_parameter_path = parameter_path
    overflow_model = Overflow(cfg)
    return overflow_model.to(device)
def reset_all_weights(model):
    """Re-initialise every submodule of `model` that exposes `reset_parameters`.

    refs:
    - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
    - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
    - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def _reinitialise(module):
        # Modules without a callable `reset_parameters` are left untouched.
        candidate = getattr(module, "reset_parameters", None)
        if not callable(candidate):
            return
        module.reset_parameters()

    # `apply` visits every submodule recursively, see:
    # https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    model.apply(fn=_reinitialise)
class TestOverflow(unittest.TestCase):
    """Checks on the full `Overflow` model: forward pass, inference and construction from a config."""

    def test_forward(self):
        """Forward returns one log-probability per batch item and alignments sized by the number of states."""
        model = get_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        outputs = model(input_dummy, input_lengths, mel_spec, mel_lengths)
        self.assertEqual(outputs["log_probs"].shape, (input_dummy.shape[0],))
        # The last alignment dimension must equal states-per-phone times the longest text in the batch.
        self.assertEqual(model.state_per_phone * max(input_lengths), outputs["alignments"].shape[2])

    def test_inference(self):
        """Inference output has `config_global.out_channels` entries along dimension 2."""
        model = get_model()
        # Only the token ids are needed for inference; the other three return values are discarded.
        input_dummy, *_ = _create_inputs()
        output_dict = model.inference(input_dummy)
        self.assertEqual(output_dict["model_outputs"].shape[2], config_global.out_channels)

    def test_init_from_config(self):
        """`Overflow.init_from_config` exposes the config's `prenet_dim` on the model."""
        config = deepcopy(config_global)
        config.mel_statistics_parameter_path = parameter_path
        config.prenet_dim = 256
        # Build from the customised copy. Passing `config_global` here would leave the
        # copy unused and make the assertion depend on the global's own `prenet_dim`.
        model = Overflow.init_from_config(config)
        self.assertEqual(model.prenet_dim, config.prenet_dim)
class TestOverflowEncoder(unittest.TestCase):
    """Checks that the encoder output length scales with the number of states per phone."""

    @staticmethod
    def get_encoder(state_per_phone):
        """Return an `Encoder` on `device` built from a copy of the shared config.

        Args:
            state_per_phone: value written to `config.state_per_phone` before construction.
        """
        cfg = deepcopy(config_global)
        cfg.state_per_phone = state_per_phone
        cfg.num_chars = 24
        encoder = Encoder(cfg.num_chars, cfg.state_per_phone, cfg.prenet_dim, cfg.encoder_n_convolutions)
        return encoder.to(device)

    def test_forward_with_state_per_phone_multiplication(self):
        """Forward output dimension 1 equals the padded text length times states per phone."""
        for states in (1, 2, 3):
            tokens, token_lengths, _, _ = _create_inputs()
            encoder = self.get_encoder(states)
            encoded, _ = encoder(tokens, token_lengths)
            expected_len = tokens.shape[1] * states
            self.assertEqual(encoded.shape[1], expected_len)

    def test_inference_with_state_per_phone_multiplication(self):
        """Inference output dimension 1 equals the padded text length times states per phone."""
        for states in (1, 2, 3):
            tokens, token_lengths, _, _ = _create_inputs()
            encoder = self.get_encoder(states)
            encoded, _ = encoder.inference(tokens, token_lengths)
            expected_len = tokens.shape[1] * states
            self.assertEqual(encoded.shape[1], expected_len)
class TestOverflowUtils(unittest.TestCase):
    """Checks for the helper functions in `OverflowUtils`."""

    def test_logsumexp(self):
        """`OverflowUtils.logsumexp` matches `torch.logsumexp` element-for-element."""
        cases = (
            torch.randn(10),  # random numbers
            torch.zeros(10),  # all zeros
            torch.ones(10),  # all ones
        )
        for values in cases:
            reference = torch.logsumexp(values, dim=0)
            candidate = OverflowUtils.logsumexp(values, dim=0)
            self.assertTrue(torch.eq(reference, candidate).all())
class TestOverflowDecoder(unittest.TestCase):
    """Checks that the flow decoder is invertible: reverse(forward(x)) is close to x."""

    @staticmethod
    def _get_decoder(num_flow_blocks_dec=None, hidden_channels_dec=None, reset_weights=True):
        """Build a `Decoder` on `device` from a copy of the shared config.

        Args:
            num_flow_blocks_dec: overrides the config value when not None.
            hidden_channels_dec: overrides the config value when not None.
            reset_weights: when true, pass the decoder through `reset_all_weights`.
        """
        cfg = deepcopy(config_global)
        if num_flow_blocks_dec is not None:
            cfg.num_flow_blocks_dec = num_flow_blocks_dec
        if hidden_channels_dec is not None:
            cfg.hidden_channels_dec = hidden_channels_dec
        cfg.dropout_p_dec = 0.0  # turn off dropout to check invertibility

        decoder_args = (
            cfg.out_channels,
            cfg.hidden_channels_dec,
            cfg.kernel_size_dec,
            cfg.dilation_rate,
            cfg.num_flow_blocks_dec,
            cfg.num_block_layers,
            cfg.dropout_p_dec,
            cfg.num_splits,
            cfg.num_squeeze,
            cfg.sigmoid_scale,
            cfg.c_in_channels,
        )
        flow = Decoder(*decoder_args).to(device)
        if reset_weights:
            reset_all_weights(flow)
        return flow

    def test_decoder_forward_backward(self):
        """Running the decoder forward and then in reverse reproduces the masked input within 1e-2."""
        for num_flow_blocks_dec in (8, None):
            for hidden_channels_dec in (100, None):
                flow = self._get_decoder(num_flow_blocks_dec, hidden_channels_dec)
                _, _, mel_spec, mel_lengths = _create_inputs()

                z, z_len, _ = flow(mel_spec.transpose(1, 2), mel_lengths)
                reconstructed, _, _ = flow(z, z_len, reverse=True)

                # Compare only the frames the decoder kept, with padding masked out.
                mask = sequence_mask(z_len).unsqueeze(1)
                kept_frames = z.shape[2]
                mel_spec = mel_spec[:, :kept_frames, :].transpose(1, 2) * mask
                z = z * mask
                matches = torch.isclose(mel_spec, reconstructed, atol=1e-2).all()
                self.assertTrue(
                    matches,
                    f"num_flow_blocks_dec={num_flow_blocks_dec}, hidden_channels_dec={hidden_channels_dec}",
                )
class TestNeuralHMM(unittest.TestCase):
    """Unit tests for `NeuralHMM` and for the `EmissionModel` / `TransitionModel` layers."""

    @staticmethod
    def _get_neural_hmm(deterministic_transition=None):
        """Build a `NeuralHMM` on `device` from a copy of the shared config.

        Args:
            deterministic_transition: overrides `config.deterministic_transition` when not None.
        """
        config = deepcopy(config_global)
        neural_hmm = NeuralHMM(
            config.out_channels,
            config.ar_order,
            config.deterministic_transition if deterministic_transition is None else deterministic_transition,
            config.encoder_in_out_features,
            config.prenet_type,
            config.prenet_dim,
            config.prenet_n_layers,
            config.prenet_dropout,
            config.prenet_dropout_at_inference,
            config.memory_rnn_dim,
            config.outputnet_size,
            config.flat_start_params,
            config.std_floor,
        ).to(device)
        return neural_hmm

    @staticmethod
    def _get_emission_model():
        """Return a default-constructed `EmissionModel` on `device`."""
        return EmissionModel().to(device)

    @staticmethod
    def _get_transition_model():
        """Return a default-constructed `TransitionModel` on `device`."""
        return TransitionModel().to(device)

    @staticmethod
    def _get_embedded_input():
        """Return `_create_inputs()` with the token ids replaced by embeddings.

        The ids go through a freshly initialised `torch.nn.Embedding` with
        `encoder_in_out_features` output features; the other three values
        are returned unchanged.
        """
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
            input_dummy
        )
        return input_dummy, input_lengths, mel_spec, mel_lengths

    def test_neural_hmm_forward(self):
        """Forward returns one log-probability per item; `log_alpha_scaled` and the transition matrix share a shape."""
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        self.assertEqual(log_prob.shape, (input_dummy.shape[0],))
        self.assertEqual(log_alpha_scaled.shape, transition_matrix.shape)

    def test_mask_lengths(self):
        """`_mask_lengths` leaves only zeros past each item's mel length."""
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        log_c = torch.randn(mel_spec.shape[0], mel_spec.shape[1], device=device)
        log_c, log_alpha_scaled = neural_hmm._mask_lengths(  # pylint: disable=protected-access
            mel_lengths, log_c, log_alpha_scaled
        )
        batch_size = mel_spec.shape[0]
        # Everything at or beyond an item's mel length must sum to exactly zero.
        self.assertTrue(
            all(log_c[i, mel_lengths[i] :].sum() == 0.0 for i in range(batch_size)),
            "Incorrect masking",
        )
        self.assertTrue(
            all(
                log_alpha_scaled[i, mel_lengths[i] :, : input_lengths[i]].sum() == 0.0 for i in range(batch_size)
            ),
            "Incorrect masking",
        )

    def test_process_ar_timestep(self):
        """`_process_ar_timestep` returns LSTM states of shape `(batch, memory_rnn_dim)`."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        h_post_prenet, c_post_prenet = model._init_lstm_states(  # pylint: disable=protected-access
            input_dummy.shape[0], config_global.memory_rnn_dim, mel_spec
        )
        h_post_prenet, c_post_prenet = model._process_ar_timestep(  # pylint: disable=protected-access
            1,
            mel_spec,
            h_post_prenet,
            c_post_prenet,
        )
        self.assertEqual(h_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))
        self.assertEqual(c_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))

    def test_add_go_token(self):
        """`_add_go_token` keeps the shape and shifts the input by one step along dimension 1."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        out = model._add_go_token(mel_spec)  # pylint: disable=protected-access
        self.assertEqual(out.shape, mel_spec.shape)
        self.assertTrue((out[:, 1:] == mel_spec[:, :-1]).all(), "Go token not appended properly")

    def test_forward_algorithm_variables(self):
        """The forward-algorithm buffers are allocated with the expected shapes."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(  # pylint: disable=protected-access
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )
        self.assertEqual(log_c.shape, (mel_spec.shape[0], mel_spec.shape[1]))
        self.assertEqual(
            log_alpha_scaled.shape,
            (
                mel_spec.shape[0],
                mel_spec.shape[1],
                input_dummy.shape[1] * config_global.state_per_phone,
            ),
        )
        self.assertEqual(
            transition_matrix.shape,
            (mel_spec.shape[0], mel_spec.shape[1], input_dummy.shape[1] * config_global.state_per_phone),
        )

    def test_get_absorption_state_scaling_factor(self):
        """`get_absorption_state_scaling_factor` agrees with a per-item reference computed in this test."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        input_lengths = input_lengths * config_global.state_per_phone
        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(  # pylint: disable=protected-access
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )
        # Replace the freshly allocated buffers with random contents of the same shape.
        log_alpha_scaled = torch.rand_like(log_alpha_scaled).clamp(1e-3)
        transition_matrix = torch.randn_like(transition_matrix).sigmoid().log()
        sum_final_log_c = model.get_absorption_state_scaling_factor(
            mel_lengths, log_alpha_scaled, input_lengths, transition_matrix
        )

        # Reference: for each item take the last valid frame, fill masked entries with -inf, then logsumexp.
        text_mask = ~sequence_mask(input_lengths)
        transition_prob_mask = ~model.get_mask_for_last_item(input_lengths, device=input_lengths.device)
        outputs = []
        for i in range(input_dummy.shape[0]):
            last_log_alpha_scaled = log_alpha_scaled[i, mel_lengths[i] - 1].masked_fill(text_mask[i], -float("inf"))
            log_last_transition_probability = OverflowUtils.log_clamped(
                torch.sigmoid(transition_matrix[i, mel_lengths[i] - 1])
            ).masked_fill(transition_prob_mask[i], -float("inf"))
            outputs.append(last_log_alpha_scaled + log_last_transition_probability)
        sum_final_log_c_computed = torch.logsumexp(torch.stack(outputs), dim=1)
        self.assertTrue(torch.isclose(sum_final_log_c_computed, sum_final_log_c).all())

    def test_inference(self):
        """Inference output sizes are consistent for several sampling temperatures."""
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        for temp in [0.334, 0.667, 1.0]:
            outputs = model.inference(
                input_dummy, input_lengths, temp, config_global.max_sampling_time, config_global.duration_threshold
            )
            self.assertEqual(outputs["hmm_outputs"].shape[-1], outputs["input_parameters"][0][0][0].shape[-1])
            self.assertEqual(
                outputs["output_parameters"][0][0][0].shape[-1], outputs["input_parameters"][0][0][0].shape[-1]
            )
            self.assertEqual(len(outputs["alignments"]), input_dummy.shape[0])

    def test_emission_model(self):
        """The emission model scores every state and samples with the shape of the means."""
        model = self._get_emission_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        x_t = torch.randn(input_dummy.shape[0], config_global.out_channels).to(device)
        means = torch.randn(input_dummy.shape[0], input_dummy.shape[1], config_global.out_channels).to(device)
        std = torch.rand_like(means).to(device).clamp_(1e-3)  # std should be positive
        out = model(x_t, means, std, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_dummy.shape[1]))

        # testing sampling
        for temp in [0, 0.334, 0.667]:
            # Use the loop's temperature; a hard-coded 0 would never sample at the non-zero temperatures.
            out = model.sample(means, std, temp)
            self.assertEqual(out.shape, means.shape)
            if temp == 0:
                # At temperature 0 the sample is expected to match the means.
                self.assertTrue(torch.isclose(out, means).all())

    def test_transition_model(self):
        """The transition model returns a `(batch, max_input_length)` tensor."""
        model = self._get_transition_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        prev_t_log_scaled_alph = torch.randn(input_dummy.shape[0], input_lengths.max()).to(device)
        transition_vector = torch.randn(input_lengths.max()).to(device)
        out = model(prev_t_log_scaled_alph, transition_vector, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_lengths.max()))
class TestOverflowOutputNet(unittest.TestCase):
    """Checks that a freshly built `Outputnet` emits its flat-start parameters."""

    @staticmethod
    def _get_outputnet():
        """Build an `Outputnet` on `device` from a copy of the shared config."""
        cfg = deepcopy(config_global)
        return Outputnet(
            cfg.encoder_in_out_features,
            cfg.memory_rnn_dim,
            cfg.out_channels,
            cfg.outputnet_size,
            cfg.flat_start_params,
            cfg.std_floor,
        ).to(device)

    @staticmethod
    def _get_embedded_input():
        """Return embedded token ids and one random frame of width `memory_rnn_dim`."""
        token_ids, _, _, _ = _create_inputs()
        embedding = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)
        embedded = embedding(token_ids)
        batch_size = embedded.shape[0]
        one_timestep_frame = torch.randn(batch_size, config_global.memory_rnn_dim).to(device)
        return embedded, one_timestep_frame

    def test_outputnet_forward_with_flat_start(self):
        """Mean, std and sigmoid(transition vector) all equal the configured flat-start values."""
        outputnet = self._get_outputnet()
        embedded, one_timestep_frame = self._get_embedded_input()
        mean, std, transition_vector = outputnet(one_timestep_frame, embedded)

        checks = (
            (mean, "mean"),
            (std, "std"),
            (transition_vector.sigmoid(), "transition_p"),
        )
        for actual, key in checks:
            expected = torch.tensor(outputnet.flat_start_params[key] * 1.0)
            self.assertTrue(torch.isclose(actual, expected).all())