import os
import random
import unittest
from copy import deepcopy

import torch

from tests import get_tests_output_path
from TTS.tts.configs.overflow_config import OverflowConfig
from TTS.tts.layers.overflow.common_layers import Encoder, Outputnet, OverflowUtils
from TTS.tts.layers.overflow.decoder import Decoder
from TTS.tts.layers.overflow.neural_hmm import EmissionModel, NeuralHMM, TransitionModel
from TTS.tts.models.overflow import Overflow
from TTS.tts.utils.helpers import sequence_mask
from TTS.utils.audio import AudioProcessor

torch.manual_seed(1)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

config_global = OverflowConfig(num_chars=24)
ap = AudioProcessor.init_from_config(config_global)

config_path = os.path.join(get_tests_output_path(), "test_model_config.json")
output_path = os.path.join(get_tests_output_path(), "train_outputs")
parameter_path = os.path.join(get_tests_output_path(), "lj_parameters.pt")

torch.save({"mean": -5.5138, "std": 2.0636, "init_transition_prob": 0.3212}, parameter_path)
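# NOTE: the values above are stand-ins for dataset-level mel statistics (mean, std,
# initial transition probability) that the model reads from
# `mel_statistics_parameter_path` for its flat-start initialization.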


def _create_inputs(batch_size=8):
    """Return random padded text and mel batches with descending, masked lengths."""
    max_len_t, max_len_m = random.randint(25, 50), random.randint(50, 80)
    input_dummy = torch.randint(0, 24, (batch_size, max_len_t)).long().to(device)
    input_lengths = torch.randint(20, max_len_t, (batch_size,)).long().to(device).sort(descending=True)[0]
    input_lengths[0] = max_len_t  # guarantee at least one full-length sample
    input_dummy = input_dummy * sequence_mask(input_lengths)
    mel_spec = torch.randn(batch_size, max_len_m, config_global.audio["num_mels"]).to(device)
    mel_lengths = torch.randint(40, max_len_m, (batch_size,)).long().to(device).sort(descending=True)[0]
    mel_lengths[0] = max_len_m
    mel_spec = mel_spec * sequence_mask(mel_lengths).unsqueeze(2)
    return input_dummy, input_lengths, mel_spec, mel_lengths


def get_model(config=None):
    if config is None:
        config = config_global
    config.mel_statistics_parameter_path = parameter_path
    model = Overflow(config)
    model = model.to(device)
    return model


def reset_all_weights(model):
    """
    refs:
    - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
    - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
    - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def weight_reset(m):
        # call reset_parameters() on every submodule that defines it
        reset_parameters = getattr(m, "reset_parameters", None)
        if callable(reset_parameters):
            m.reset_parameters()

    model.apply(fn=weight_reset)


class TestOverflow(unittest.TestCase):
    def test_forward(self):
        model = get_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        outputs = model(input_dummy, input_lengths, mel_spec, mel_lengths)
        self.assertEqual(outputs["log_probs"].shape, (input_dummy.shape[0],))
        self.assertEqual(model.state_per_phone * max(input_lengths), outputs["alignments"].shape[2])

    def test_inference(self):
        model = get_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        output_dict = model.inference(input_dummy)
        self.assertEqual(output_dict["model_outputs"].shape[2], config_global.out_channels)

    def test_init_from_config(self):
        config = deepcopy(config_global)
        config.mel_statistics_parameter_path = parameter_path
        config.prenet_dim = 256
        model = Overflow.init_from_config(config)
        self.assertEqual(model.prenet_dim, config.prenet_dim)


class TestOverflowEncoder(unittest.TestCase):
    @staticmethod
    def get_encoder(state_per_phone):
        config = deepcopy(config_global)
        config.state_per_phone = state_per_phone
        config.num_chars = 24
        return Encoder(config.num_chars, config.state_per_phone, config.prenet_dim, config.encoder_n_convolutions).to(
            device
        )

    def test_forward_with_state_per_phone_multiplication(self):
        for s_p_p in [1, 2, 3]:
            input_dummy, input_lengths, _, _ = _create_inputs()
            model = self.get_encoder(s_p_p)
            x, x_len = model(input_dummy, input_lengths)
            self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p)

    def test_inference_with_state_per_phone_multiplication(self):
        for s_p_p in [1, 2, 3]:
            input_dummy, input_lengths, _, _ = _create_inputs()
            model = self.get_encoder(s_p_p)
            x, x_len = model.inference(input_dummy, input_lengths)
            self.assertEqual(x.shape[1], input_dummy.shape[1] * s_p_p)


class TestOverflowUtils(unittest.TestCase):
    def test_logsumexp(self):
        a = torch.randn(10)
        self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all())

        a = torch.zeros(10)
        self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all())

        a = torch.ones(10)
        self.assertTrue(torch.eq(torch.logsumexp(a, dim=0), OverflowUtils.logsumexp(a, dim=0)).all())


class TestOverflowDecoder(unittest.TestCase):
    @staticmethod
    def _get_decoder(num_flow_blocks_dec=None, hidden_channels_dec=None, reset_weights=True):
        config = deepcopy(config_global)
        config.num_flow_blocks_dec = (
            num_flow_blocks_dec if num_flow_blocks_dec is not None else config.num_flow_blocks_dec
        )
        config.hidden_channels_dec = (
            hidden_channels_dec if hidden_channels_dec is not None else config.hidden_channels_dec
        )
        config.dropout_p_dec = 0.0
        decoder = Decoder(
            config.out_channels,
            config.hidden_channels_dec,
            config.kernel_size_dec,
            config.dilation_rate,
            config.num_flow_blocks_dec,
            config.num_block_layers,
            config.dropout_p_dec,
            config.num_splits,
            config.num_squeeze,
            config.sigmoid_scale,
            config.c_in_channels,
        ).to(device)
        if reset_weights:
            reset_all_weights(decoder)
        return decoder

    def test_decoder_forward_backward(self):
        for num_flow_blocks_dec in [8, None]:
            for hidden_channels_dec in [100, None]:
                decoder = self._get_decoder(num_flow_blocks_dec, hidden_channels_dec)
                _, _, mel_spec, mel_lengths = _create_inputs()
                z, z_len, _ = decoder(mel_spec.transpose(1, 2), mel_lengths)
                mel_spec_, mel_lengths_, _ = decoder(z, z_len, reverse=True)
                # the flow decoder must reconstruct the (masked) input when run in reverse
                mask = sequence_mask(z_len).unsqueeze(1)
                mel_spec = mel_spec[:, : z.shape[2], :].transpose(1, 2) * mask
                z = z * mask
                self.assertTrue(
                    torch.isclose(mel_spec, mel_spec_, atol=1e-2).all(),
                    f"num_flow_blocks_dec={num_flow_blocks_dec}, hidden_channels_dec={hidden_channels_dec}",
                )


class TestNeuralHMM(unittest.TestCase):
    @staticmethod
    def _get_neural_hmm(deterministic_transition=None):
        config = deepcopy(config_global)
        neural_hmm = NeuralHMM(
            config.out_channels,
            config.ar_order,
            config.deterministic_transition if deterministic_transition is None else deterministic_transition,
            config.encoder_in_out_features,
            config.prenet_type,
            config.prenet_dim,
            config.prenet_n_layers,
            config.prenet_dropout,
            config.prenet_dropout_at_inference,
            config.memory_rnn_dim,
            config.outputnet_size,
            config.flat_start_params,
            config.std_floor,
        ).to(device)
        return neural_hmm

    @staticmethod
    def _get_emission_model():
        return EmissionModel().to(device)

    @staticmethod
    def _get_transition_model():
        return TransitionModel().to(device)

    @staticmethod
    def _get_embedded_input():
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
            input_dummy
        )
        return input_dummy, input_lengths, mel_spec, mel_lengths

    def test_neural_hmm_forward(self):
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        self.assertEqual(log_prob.shape, (input_dummy.shape[0],))
        self.assertEqual(log_alpha_scaled.shape, transition_matrix.shape)

    def test_mask_lengths(self):
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        neural_hmm = self._get_neural_hmm()
        log_prob, log_alpha_scaled, transition_matrix, means = neural_hmm(
            input_dummy, input_lengths, mel_spec.transpose(1, 2), mel_lengths
        )
        log_c = torch.randn(mel_spec.shape[0], mel_spec.shape[1], device=device)
        log_c, log_alpha_scaled = neural_hmm._mask_lengths(mel_lengths, log_c, log_alpha_scaled)
        assertions = []
        for i in range(mel_spec.shape[0]):
            assertions.append(log_c[i, mel_lengths[i] :].sum() == 0.0)
        self.assertTrue(all(assertions), "Incorrect masking of log_c")
        assertions = []
        for i in range(mel_spec.shape[0]):
            assertions.append(log_alpha_scaled[i, mel_lengths[i] :, : input_lengths[i]].sum() == 0.0)
        self.assertTrue(all(assertions), "Incorrect masking of log_alpha_scaled")

    def test_process_ar_timestep(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

        h_post_prenet, c_post_prenet = model._init_lstm_states(
            input_dummy.shape[0], config_global.memory_rnn_dim, mel_spec
        )
        h_post_prenet, c_post_prenet = model._process_ar_timestep(
            1,
            mel_spec,
            h_post_prenet,
            c_post_prenet,
        )

        self.assertEqual(h_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))
        self.assertEqual(c_post_prenet.shape, (input_dummy.shape[0], config_global.memory_rnn_dim))

    def test_add_go_token(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

        out = model._add_go_token(mel_spec)
        self.assertEqual(out.shape, mel_spec.shape)
        self.assertTrue((out[:, 1:] == mel_spec[:, :-1]).all(), "Go token not prepended properly")

    def test_forward_algorithm_variables(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()

        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )

        self.assertEqual(log_c.shape, (mel_spec.shape[0], mel_spec.shape[1]))
        self.assertEqual(
            log_alpha_scaled.shape,
            (
                mel_spec.shape[0],
                mel_spec.shape[1],
                input_dummy.shape[1] * config_global.state_per_phone,
            ),
        )
        self.assertEqual(
            transition_matrix.shape,
            (mel_spec.shape[0], mel_spec.shape[1], input_dummy.shape[1] * config_global.state_per_phone),
        )

    def test_get_absorption_state_scaling_factor(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        input_lengths = input_lengths * config_global.state_per_phone
        (
            log_c,
            log_alpha_scaled,
            transition_matrix,
            _,
        ) = model._initialize_forward_algorithm_variables(
            mel_spec, input_dummy.shape[1] * config_global.state_per_phone
        )
        log_alpha_scaled = torch.rand_like(log_alpha_scaled).clamp(1e-3)
        transition_matrix = torch.randn_like(transition_matrix).sigmoid().log()
        sum_final_log_c = model.get_absorption_state_scaling_factor(
            mel_lengths, log_alpha_scaled, input_lengths, transition_matrix
        )

        # recompute the scaling factor sample by sample and compare against the batched implementation
        text_mask = ~sequence_mask(input_lengths)
        transition_prob_mask = ~model.get_mask_for_last_item(input_lengths, device=input_lengths.device)

        outputs = []

        for i in range(input_dummy.shape[0]):
            last_log_alpha_scaled = log_alpha_scaled[i, mel_lengths[i] - 1].masked_fill(text_mask[i], -float("inf"))
            log_last_transition_probability = OverflowUtils.log_clamped(
                torch.sigmoid(transition_matrix[i, mel_lengths[i] - 1])
            ).masked_fill(transition_prob_mask[i], -float("inf"))
            outputs.append(last_log_alpha_scaled + log_last_transition_probability)

        sum_final_log_c_computed = torch.logsumexp(torch.stack(outputs), dim=1)

        self.assertTrue(torch.isclose(sum_final_log_c_computed, sum_final_log_c).all())

    def test_inference(self):
        model = self._get_neural_hmm()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        for temp in [0.334, 0.667, 1.0]:
            outputs = model.inference(
                input_dummy, input_lengths, temp, config_global.max_sampling_time, config_global.duration_threshold
            )
            self.assertEqual(outputs["hmm_outputs"].shape[-1], outputs["input_parameters"][0][0][0].shape[-1])
            self.assertEqual(
                outputs["output_parameters"][0][0][0].shape[-1], outputs["input_parameters"][0][0][0].shape[-1]
            )
            self.assertEqual(len(outputs["alignments"]), input_dummy.shape[0])

    def test_emission_model(self):
        model = self._get_emission_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        x_t = torch.randn(input_dummy.shape[0], config_global.out_channels).to(device)
        means = torch.randn(input_dummy.shape[0], input_dummy.shape[1], config_global.out_channels).to(device)
        std = torch.rand_like(means).to(device).clamp_(1e-3)
        out = model(x_t, means, std, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_dummy.shape[1]))

        # sampling with temperature 0 must return the means exactly
        for temp in [0, 0.334, 0.667]:
            out = model.sample(means, std, temp)
            self.assertEqual(out.shape, means.shape)
            if temp == 0:
                self.assertTrue(torch.isclose(out, means).all())

    def test_transition_model(self):
        model = self._get_transition_model()
        input_dummy, input_lengths, mel_spec, mel_lengths = self._get_embedded_input()
        prev_t_log_scaled_alph = torch.randn(input_dummy.shape[0], input_lengths.max()).to(device)
        transition_vector = torch.randn(input_lengths.max()).to(device)
        out = model(prev_t_log_scaled_alph, transition_vector, input_lengths)
        self.assertEqual(out.shape, (input_dummy.shape[0], input_lengths.max()))


class TestOverflowOutputNet(unittest.TestCase):
    @staticmethod
    def _get_outputnet():
        config = deepcopy(config_global)
        outputnet = Outputnet(
            config.encoder_in_out_features,
            config.memory_rnn_dim,
            config.out_channels,
            config.outputnet_size,
            config.flat_start_params,
            config.std_floor,
        ).to(device)
        return outputnet

    @staticmethod
    def _get_embedded_input():
        input_dummy, input_lengths, mel_spec, mel_lengths = _create_inputs()
        input_dummy = torch.nn.Embedding(config_global.num_chars, config_global.encoder_in_out_features).to(device)(
            input_dummy
        )
        one_timestep_frame = torch.randn(input_dummy.shape[0], config_global.memory_rnn_dim).to(device)
        return input_dummy, one_timestep_frame

    def test_outputnet_forward_with_flat_start(self):
        model = self._get_outputnet()
        input_dummy, one_timestep_frame = self._get_embedded_input()
        mean, std, transition_vector = model(one_timestep_frame, input_dummy)
        self.assertTrue(torch.isclose(mean, torch.tensor(model.flat_start_params["mean"] * 1.0)).all())
        self.assertTrue(torch.isclose(std, torch.tensor(model.flat_start_params["std"] * 1.0)).all())
        self.assertTrue(
            torch.isclose(
                transition_vector.sigmoid(), torch.tensor(model.flat_start_params["transition_p"] * 1.0)
            ).all()
        )
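

if __name__ == "__main__":
    # Convenience entry point: lets the suite run directly via `python <this file>`
    # in addition to pytest/unittest discovery.
    unittest.main()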