from typing import List

import torch
import torch.distributions as tdist
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

from TTS.tts.layers.overflow.common_layers import Outputnet, OverflowUtils
from TTS.tts.layers.tacotron.common_layers import Prenet
from TTS.tts.utils.helpers import sequence_mask


class NeuralHMM(nn.Module):
    """Autoregressive left-to-right HMM model primarily used in "Neural HMMs are all you need (for high-quality attention-free TTS)"

    Paper::
        https://arxiv.org/abs/2108.13320

    Paper abstract::
        Neural sequence-to-sequence TTS has achieved significantly better output quality than statistical speech synthesis using
        HMMs. However, neural TTS is generally not probabilistic and uses non-monotonic attention. Attention failures increase
        training time and can make synthesis babble incoherently. This paper describes how the old and new paradigms can be
        combined to obtain the advantages of both worlds, by replacing attention in neural TTS with an autoregressive left-right
        no-skip hidden Markov model defined by a neural network. Based on this proposal, we modify Tacotron 2 to obtain an
        HMM-based neural TTS model with monotonic alignment, trained to maximise the full sequence likelihood without
        approximation. We also describe how to combine ideas from classical and contemporary TTS for best results. The resulting
        example system is smaller and simpler than Tacotron 2, and learns to speak with fewer iterations and less data, whilst
        achieving comparable naturalness prior to the post-net. Our approach also allows easy control over speaking rate.

    Args:
        frame_channels (int): Output dimension to generate.
        ar_order (int): Autoregressive order of the model. In ablations of Neural HMM, higher autoregressive orders were found to give more variation but hurt the naturalness of the synthesised audio.
        deterministic_transition (bool): Deterministic duration generation based on duration quantiles, as defined in "S. Ronanki, O. Watts, S. King, and G. E. Henter, 'Median-based generation of synthetic speech durations using a non-parametric approach,' in Proc. SLT, 2016.". Defaults to True.
        encoder_dim (int): Channels of encoder input and character embedding tensors. Defaults to 512.
        prenet_type (str): `original` or `bn`. `original` sets the default Prenet and `bn` uses the Batch Normalization version of the Prenet.
        prenet_dim (int): Dimension of the Prenet.
        prenet_n_layers (int): Number of layers in the Prenet.
        prenet_dropout (float): Dropout probability of the Prenet.
        prenet_dropout_at_inference (bool): If True, dropout is applied at inference time.
        memory_rnn_dim (int): Size of the memory RNN that processes the Prenet output.
        outputnet_size (List[int]): Size of the output network inside the neural HMM.
        flat_start_params (dict): Parameters for the flat-start initialization of the neural HMM.
        std_floor (float): Floor value for the standard deviation of the neural HMM. Prevents the model from cheating by putting a point mass on a single datapoint and getting infinite likelihood.
        use_grad_checkpointing (bool, optional): Use gradient checkpointing to save memory. Defaults to True.
    """

    def __init__(
        self,
        frame_channels: int,
        ar_order: int,
        deterministic_transition: bool,
        encoder_dim: int,
        prenet_type: str,
        prenet_dim: int,
        prenet_n_layers: int,
        prenet_dropout: float,
        prenet_dropout_at_inference: bool,
        memory_rnn_dim: int,
        outputnet_size: List[int],
        flat_start_params: dict,
        std_floor: float,
        use_grad_checkpointing: bool = True,
    ):
        super().__init__()
        self.frame_channels = frame_channels
        self.ar_order = ar_order
        self.deterministic_transition = deterministic_transition
        self.prenet_dim = prenet_dim
        self.memory_rnn_dim = memory_rnn_dim
        self.use_grad_checkpointing = use_grad_checkpointing

        self.transition_model = TransitionModel()
        self.emission_model = EmissionModel()

        assert ar_order > 0, f"AR order must be greater than 0 provided {ar_order}"

        self.ar_order = ar_order
        self.prenet = Prenet(
            in_features=frame_channels * ar_order,
            prenet_type=prenet_type,
            prenet_dropout=prenet_dropout,
            dropout_at_inference=prenet_dropout_at_inference,
            out_features=[self.prenet_dim for _ in range(prenet_n_layers)],
            bias=False,
        )
        self.memory_rnn = nn.LSTMCell(input_size=prenet_dim, hidden_size=memory_rnn_dim)
        self.output_net = Outputnet(
            encoder_dim, memory_rnn_dim, frame_channels, outputnet_size, flat_start_params, std_floor
        )
        self.register_buffer("go_tokens", torch.zeros(ar_order, 1))

    def forward(self, inputs, inputs_len, mels, mel_lens):
        r"""HMM forward algorithm for training. Uses the log-space version of the Rabiner (1989) forward algorithm.

        Args:
            inputs (torch.FloatTensor): Encoder outputs
            inputs_len (torch.LongTensor): Encoder output lengths
            mels (torch.FloatTensor): Mel inputs
            mel_lens (torch.LongTensor): Length of mel inputs

        Shapes:
            - inputs: (B, T, D_out_enc)
            - inputs_len: (B)
            - mels: (B, D_mel, T_mel)
            - mel_lens: (B)

        Returns:
            log_prob (torch.FloatTensor): Log probability of the sequence
        """
        # Get dimensions of inputs
        batch_size, N, _ = inputs.shape
        T_max = torch.max(mel_lens)
        mels = mels.permute(0, 2, 1)

        # Initialize forward algorithm
        log_state_priors = self._initialize_log_state_priors(inputs)
        log_c, log_alpha_scaled, transition_matrix, means = self._initialize_forward_algorithm_variables(mels, N)

        # Initialize autoregression elements
        ar_inputs = self._add_go_token(mels)
        h_memory, c_memory = self._init_lstm_states(batch_size, self.memory_rnn_dim, mels)
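
        # Note on the recursion below (standard scaled forward algorithm in log space):
        #   log alpha'_t(j) = log b_j(o_t) + logsumexp_i( log alpha_{t-1}(i) + log a_{ij} )
        #   log c_t         = logsumexp_j( log alpha'_t(j) )
        #   log alpha_t(j)  = log alpha'_t(j) - log c_t
        # where b_j is the emission density of state j and a_{ij} the left-to-right, no-skip
        # transition probability; summing log c_t over t recovers the sequence log-likelihood.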
        for t in range(T_max):
            # Process autoregression
            h_memory, c_memory = self._process_ar_timestep(t, ar_inputs, h_memory, c_memory)
            # Get mean, std and transition vector from the decoder for this timestep
            # Note: gradient checkpointing currently does not work with multiple GPUs inside a loop
            if self.use_grad_checkpointing and self.training:
                mean, std, transition_vector = checkpoint(self.output_net, h_memory, inputs)
            else:
                mean, std, transition_vector = self.output_net(h_memory, inputs)

            if t == 0:
                log_alpha_temp = log_state_priors + self.emission_model(mels[:, 0], mean, std, inputs_len)
            else:
                log_alpha_temp = self.emission_model(mels[:, t], mean, std, inputs_len) + self.transition_model(
                    log_alpha_scaled[:, t - 1, :], transition_vector, inputs_len
                )
            log_c[:, t] = torch.logsumexp(log_alpha_temp, dim=1)
            log_alpha_scaled[:, t, :] = log_alpha_temp - log_c[:, t].unsqueeze(1)
            transition_matrix[:, t] = transition_vector  # needed for absorption state calculation

            # Save for plotting
            means.append(mean.detach())

        log_c, log_alpha_scaled = self._mask_lengths(mel_lens, log_c, log_alpha_scaled)

        sum_final_log_c = self.get_absorption_state_scaling_factor(
            mel_lens, log_alpha_scaled, inputs_len, transition_matrix
        )

        log_probs = torch.sum(log_c, dim=1) + sum_final_log_c

        return log_probs, log_alpha_scaled, transition_matrix, means

    @staticmethod
    def _mask_lengths(mel_lens, log_c, log_alpha_scaled):
        """
        Mask the forward variables beyond each sequence's length so that the padded
        timesteps do not contribute to the loss calculation.

        Args:
            mel_lens (torch.IntTensor): lengths of the mel inputs (batch)
            log_c (torch.FloatTensor): scaling constants (batch, T)
            log_alpha_scaled (torch.FloatTensor): forward probabilities (batch, T, N)

        Returns:
            log_c (torch.FloatTensor): masked scaling constants (batch, T)
            log_alpha_scaled (torch.FloatTensor): masked forward probabilities (batch, T, N)
        """
        mask_log_c = sequence_mask(mel_lens)
        log_c = log_c * mask_log_c
        mask_log_alpha_scaled = mask_log_c.unsqueeze(2)
        log_alpha_scaled = log_alpha_scaled * mask_log_alpha_scaled
        return log_c, log_alpha_scaled

    def _process_ar_timestep(
        self,
        t,
        ar_inputs,
        h_memory,
        c_memory,
    ):
        """
        Process autoregression for a single timestep:
            1. Take the autoregressive frame(s) at timestep t.
            2. Perform data dropout if applied (we did not use it).
            3. Run the autoregressive frame through the prenet (has dropout).
            4. Run the prenet output through the memory (post-prenet) RNN.

        Args:
            t (int): mel-spec timestep
            ar_inputs (torch.FloatTensor): go-token-prepended mel-spectrograms
                - shape: (b, T_out, D_out)
            h_memory (torch.FloatTensor): previous timestep RNN hidden state
                - shape: (b, memory_rnn_dim)
            c_memory (torch.FloatTensor): previous timestep RNN cell state
                - shape: (b, memory_rnn_dim)

        Returns:
            h_memory (torch.FloatTensor): RNN hidden state of the current timestep
            c_memory (torch.FloatTensor): RNN cell state of the current timestep
        """
        prenet_input = ar_inputs[:, t : t + self.ar_order].flatten(1)
        memory_inputs = self.prenet(prenet_input)
        h_memory, c_memory = self.memory_rnn(memory_inputs, (h_memory, c_memory))
        return h_memory, c_memory

    def _add_go_token(self, mel_inputs):
        """Prepend the go token to create the autoregressive input.

        Args:
            mel_inputs (torch.FloatTensor): (batch_size, T, n_mel_channels)

        Returns:
            ar_inputs (torch.FloatTensor): (batch_size, T, n_mel_channels)
        """
        batch_size, T, _ = mel_inputs.shape
        go_tokens = self.go_tokens.unsqueeze(0).expand(batch_size, self.ar_order, self.frame_channels)
        ar_inputs = torch.cat((go_tokens, mel_inputs), dim=1)[:, :T]
        return ar_inputs

    @staticmethod
    def _initialize_forward_algorithm_variables(mel_inputs, N):
        r"""Initialize placeholders for the forward algorithm variables. For numerical
        stability we work with log_alpha_scaled and the scaling constants log_c.

        Args:
            mel_inputs (torch.FloatTensor): (b, T_max, frame_channels)
            N (int): number of states

        Returns:
            log_c (torch.FloatTensor): scaling constants (b, T_max)
            log_alpha_scaled (torch.FloatTensor): scaled forward probabilities (b, T_max, N)
            transition_matrix (torch.FloatTensor): transition vectors per timestep (b, T_max, N)
            means (list): empty list to collect the predicted means for plotting
        """
        b, T_max, _ = mel_inputs.shape
        log_alpha_scaled = mel_inputs.new_zeros((b, T_max, N))
        log_c = mel_inputs.new_zeros(b, T_max)
        transition_matrix = mel_inputs.new_zeros((b, T_max, N))

        # Saving for plotting later, will not have gradient tapes
        means = []
        return log_c, log_alpha_scaled, transition_matrix, means

    @staticmethod
    def _init_lstm_states(batch_size, hidden_state_dim, device_tensor):
        r"""
        Initialize the hidden and cell states for the LSTM cell.

        Args:
            batch_size (int): batch size
            hidden_state_dim (int): dimension of h and c
            device_tensor (torch.FloatTensor): reference tensor for device and dtype

        Returns:
            (torch.FloatTensor): hidden state for the LSTM, shape (batch_size, hidden_state_dim)
            (torch.FloatTensor): cell state for the LSTM, shape (batch_size, hidden_state_dim)
        """
        return (
            device_tensor.new_zeros(batch_size, hidden_state_dim),
            device_tensor.new_zeros(batch_size, hidden_state_dim),
        )

    def get_absorption_state_scaling_factor(self, mels_len, log_alpha_scaled, inputs_len, transition_vector):
        """Returns the final scaling factor of the absorption state.

        Args:
            mels_len (torch.IntTensor): lengths of the mel inputs, used to pick the last
                timestep of log_alpha_scaled for each sequence
            log_alpha_scaled (torch.FloatTensor): scaled forward (state) probabilities
            inputs_len (torch.IntTensor): lengths of the state sequences, used to mask out
                padded states. When a state sequence is shorter than the maximum number of
                states in the batch, log_alpha beyond that length is filled with -inf, so we
                mask those values and only consider the states valid for that length.
            transition_vector (torch.FloatTensor): transition vector for each state per timestep

        Shapes:
            - mels_len: (batch_size)
            - log_alpha_scaled: (batch_size, T, N)
            - inputs_len: (batch_size)
            - transition_vector: (batch_size, T, N)

        Returns:
            sum_final_log_c (torch.FloatTensor): (batch_size)
        """
        N = torch.max(inputs_len)
        max_inputs_len = log_alpha_scaled.shape[2]
        state_lengths_mask = sequence_mask(inputs_len, max_len=max_inputs_len)

        last_log_alpha_scaled_index = (
            (mels_len - 1).unsqueeze(-1).expand(-1, N).unsqueeze(1)
        )  # Batch X Hidden State Size
        last_log_alpha_scaled = torch.gather(log_alpha_scaled, 1, last_log_alpha_scaled_index).squeeze(1)
        last_log_alpha_scaled = last_log_alpha_scaled.masked_fill(~state_lengths_mask, -float("inf"))

        last_transition_vector = torch.gather(transition_vector, 1, last_log_alpha_scaled_index).squeeze(1)
        last_transition_probability = torch.sigmoid(last_transition_vector)
        log_probability_of_transitioning = OverflowUtils.log_clamped(last_transition_probability)

        last_transition_probability_index = self.get_mask_for_last_item(inputs_len, inputs_len.device)
        log_probability_of_transitioning = log_probability_of_transitioning.masked_fill(
            ~last_transition_probability_index, -float("inf")
        )
        final_log_c = last_log_alpha_scaled + log_probability_of_transitioning

        # If the mel is shorter than the number of states, the gather above selects -inf values,
        # which leads to NaN gradients. Ideally the dataset should be cleaned so this does not
        # happen; as a workaround we clamp to the smallest finite value of the dtype.
        final_log_c = final_log_c.clamp(min=torch.finfo(final_log_c.dtype).min)

        sum_final_log_c = torch.logsumexp(final_log_c, dim=1)
        return sum_final_log_c

    @staticmethod
    def get_mask_for_last_item(lengths, device, out_tensor=None):
        """Returns a mask that is True only at the last (n-1) item of each sequence.

        Args:
            lengths (torch.IntTensor): lengths in a batch
            device (str): device to create the mask on
            out_tensor (torch.Tensor, optional): reuses the memory of a specific tensor.
                Defaults to None.

        Returns:
            - Shape: :math:`(b, max_len)`
        """
        max_len = torch.max(lengths).item()
        ids = (
            torch.arange(0, max_len, device=device) if out_tensor is None else torch.arange(0, max_len, out=out_tensor)
        )
        mask = ids == lengths.unsqueeze(1) - 1
        return mask

    def inference(
        self,
        inputs: torch.FloatTensor,
        input_lens: torch.LongTensor,
        sampling_temp: float,
        max_sampling_time: int,
        duration_threshold: float,
    ):
        """Inference from the autoregressive neural HMM.

        Args:
            inputs (torch.FloatTensor): input states
                - shape: :math:`(b, T, d)`
            input_lens (torch.LongTensor): input state lengths
                - shape: :math:`(b)`
            sampling_temp (float): sampling temperature
            max_sampling_time (int): maximum number of sampling timesteps
            duration_threshold (float): duration threshold to switch to the next state
                - Use this to change the speaking rate of the synthesised audio
        """
        b = inputs.shape[0]
        outputs = {
            "hmm_outputs": [],
            "hmm_outputs_len": [],
            "alignments": [],
            "input_parameters": [],
            "output_parameters": [],
        }
        for i in range(b):
            neural_hmm_outputs, states_travelled, input_parameters, output_parameters = self.sample(
                inputs[i : i + 1], input_lens[i], sampling_temp, max_sampling_time, duration_threshold
            )

            outputs["hmm_outputs"].append(neural_hmm_outputs)
            outputs["hmm_outputs_len"].append(neural_hmm_outputs.shape[0])
            outputs["alignments"].append(states_travelled)
            outputs["input_parameters"].append(input_parameters)
            outputs["output_parameters"].append(output_parameters)

        outputs["hmm_outputs"] = nn.utils.rnn.pad_sequence(outputs["hmm_outputs"], batch_first=True)
        outputs["hmm_outputs_len"] = torch.tensor(
            outputs["hmm_outputs_len"], dtype=input_lens.dtype, device=input_lens.device
        )
        return outputs

    def sample(self, inputs, input_lens, sampling_temp, max_sampling_time, duration_threshold):
        """Samples an output from the parameter models.

        Args:
            inputs (torch.FloatTensor): input states
                - shape: :math:`(1, T, d)`
            input_lens (torch.LongTensor): input state lengths
                - shape: :math:`(1)`
            sampling_temp (float): sampling temperature
            max_sampling_time (int): max sampling time
            duration_threshold (float): duration threshold to switch to the next state

        Returns:
            outputs (torch.FloatTensor): output observations
                - Shape: :math:`(T, output_dim)`
            states_travelled (list[int]): hidden states travelled
                - Shape: :math:`(T)`
            input_parameters (list[torch.FloatTensor]): input parameters
            output_parameters (list[torch.FloatTensor]): output parameters
        """
        states_travelled, outputs, t = [], [], 0

        # Sample initial state
        current_state = 0
        states_travelled.append(current_state)

        # Prepare autoregression
        prenet_input = self.go_tokens.unsqueeze(0).expand(1, self.ar_order, self.frame_channels)
        h_memory, c_memory = self._init_lstm_states(1, self.memory_rnn_dim, prenet_input)

        input_parameter_values = []
        output_parameter_values = []
        quantile = 1
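
        # Duration control (summary of the loop below): at each step we sample a frame from the
        # current state's emission distribution. With deterministic_transition=True the cumulative
        # staying probability ("quantile") acts as a survival function, and we move to the next
        # state once it drops below duration_threshold (quantile-based durations); otherwise the
        # stay/leave decision is sampled stochastically from the transition probability.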
        while True:
            memory_input = self.prenet(prenet_input.flatten(1).unsqueeze(0))
            # Batch size is 1 while sampling
            h_memory, c_memory = self.memory_rnn(memory_input.squeeze(0), (h_memory, c_memory))

            z_t = inputs[:, current_state].unsqueeze(0)  # Add fake time dimension
            mean, std, transition_vector = self.output_net(h_memory, z_t)

            transition_probability = torch.sigmoid(transition_vector.flatten())
            staying_probability = torch.sigmoid(-transition_vector.flatten())

            # Save for plotting
            input_parameter_values.append([prenet_input, current_state])
            output_parameter_values.append([mean, std, transition_probability])

            x_t = self.emission_model.sample(mean, std, sampling_temp=sampling_temp)

            # Prepare autoregressive input for the next iteration
            prenet_input = torch.cat((prenet_input, x_t), dim=1)[:, 1:]
            outputs.append(x_t.flatten())

            transition_matrix = torch.cat((staying_probability, transition_probability))
            quantile *= staying_probability
            if not self.deterministic_transition:
                switch = transition_matrix.multinomial(1)[0].item()
            else:
                switch = quantile < duration_threshold

            if switch:
                current_state += 1
                quantile = 1

            states_travelled.append(current_state)

            if (current_state == input_lens) or (max_sampling_time and t == max_sampling_time - 1):
                break

            t += 1

        return (
            torch.stack(outputs, dim=0),
            F.one_hot(input_lens.new_tensor(states_travelled)),
            input_parameter_values,
            output_parameter_values,
        )

    @staticmethod
    def _initialize_log_state_priors(text_embeddings):
        """Creates the log pi (initial state priors) for the forward algorithm.

        Args:
            text_embeddings (torch.FloatTensor): used to create log pi
                on the current device

        Shapes:
            - text_embeddings: (B, T, D_out_enc)
        """
        N = text_embeddings.shape[1]
        log_state_priors = text_embeddings.new_full([N], -float("inf"))
        log_state_priors[0] = 0.0
        return log_state_priors


class TransitionModel(nn.Module):
    """Transition model of the HMM. It represents the probability of transitioning
    from the current state to all other states."""

    def forward(self, log_alpha_scaled, transition_vector, inputs_len):  # pylint: disable=no-self-use
        r"""
        Combines the previous timestep's forward probabilities with the transition
        probabilities in log space.

        Args:
            log_alpha_scaled (torch.Tensor): previous timestep's scaled forward probabilities
                (in the log domain)
                - shape: (batch size, N)
            transition_vector (torch.tensor): transition vector for each state
                - shape: (batch, N)
            inputs_len (int tensor): lengths of the state sequences in a batch
                - shape: (batch)

        Returns:
            out (torch.FloatTensor): log probability of transitioning to each state
        """
        transition_p = torch.sigmoid(transition_vector)
        staying_p = torch.sigmoid(-transition_vector)

        log_staying_probability = OverflowUtils.log_clamped(staying_p)
        log_transition_probability = OverflowUtils.log_clamped(transition_p)

        staying = log_alpha_scaled + log_staying_probability
        leaving = log_alpha_scaled + log_transition_probability
        leaving = leaving.roll(1, dims=1)
        leaving[:, 0] = -float("inf")
        inputs_len_mask = sequence_mask(inputs_len)
        out = OverflowUtils.logsumexp(torch.stack((staying, leaving), dim=2), dim=2)
        out = out.masked_fill(~inputs_len_mask, -float("inf"))  # There are no states to contribute to the loss
        return out


class EmissionModel(nn.Module):
    """Emission model of the HMM. It represents the probability of
    emitting an observation given the current state."""

    def __init__(self) -> None:
        super().__init__()
        self.distribution_function: tdist.Distribution = tdist.normal.Normal

    def sample(self, means, stds, sampling_temp):
        return self.distribution_function(means, stds * sampling_temp).sample() if sampling_temp > 0 else means

    def forward(self, x_t, means, stds, state_lengths):
        r"""Calculates the log probability of the given data (x_t)
        being observed from states with the given means and stds.

        Args:
            x_t (float tensor): observation at the current time step
                - shape: (batch, feature_dim)
            means (float tensor): means of the distributions of the hidden states
                - shape: (batch, hidden_state, feature_dim)
            stds (float tensor): standard deviations of the distributions of the hidden states
                - shape: (batch, hidden_state, feature_dim)
            state_lengths (int tensor): lengths of the state sequences in a batch
                - shape: (batch)

        Returns:
            out (float tensor): observation log likelihoods, expressing the probability
                of an observation being generated from state i
                - shape: (batch, hidden_state)
        """
        emission_dists = self.distribution_function(means, stds)
        out = emission_dists.log_prob(x_t.unsqueeze(1))
        state_lengths_mask = sequence_mask(state_lengths).unsqueeze(2)
        out = torch.sum(out * state_lengths_mask, dim=2)
        return out
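

# A minimal smoke-test sketch of how this module might be exercised on random tensors.
# The hyperparameter values below (prenet size, outputnet_size, flat_start_params keys, etc.) are
# illustrative assumptions for demonstration only, not necessarily the library defaults; in
# practice this layer is built and trained through the higher-level TTS model classes and configs.
if __name__ == "__main__":
    torch.manual_seed(0)

    model = NeuralHMM(
        frame_channels=80,
        ar_order=1,
        deterministic_transition=True,
        encoder_dim=512,
        prenet_type="original",
        prenet_dim=256,
        prenet_n_layers=2,
        prenet_dropout=0.5,
        prenet_dropout_at_inference=False,
        memory_rnn_dim=1024,
        outputnet_size=[256, 256],
        flat_start_params={"mean": 0.0, "std": 1.0, "transition_p": 0.14},  # assumed keys/values
        std_floor=0.01,
        use_grad_checkpointing=False,
    )
    model.eval()

    batch_size, n_states, n_frames = 2, 11, 25
    encoder_outputs = torch.randn(batch_size, n_states, 512)  # (B, T_states, encoder_dim)
    state_lens = torch.tensor([11, 9])                        # (B)
    mels = torch.randn(batch_size, 80, n_frames)              # (B, frame_channels, T_mel)
    mel_lens = torch.tensor([25, 20])                         # (B)

    with torch.no_grad():
        log_probs, log_alpha_scaled, transition_matrix, means = model(
            encoder_outputs, state_lens, mels, mel_lens
        )
    print(log_probs.shape)         # per-utterance sequence log-likelihood, shape (B,)
    print(log_alpha_scaled.shape)  # scaled forward probabilities, shape (B, T_mel, T_states)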