Spaces:
Paused
Paused
| from typing import List, Tuple | |
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from tqdm.auto import tqdm | |
| from TTS.tts.layers.tacotron.common_layers import Linear | |
| from TTS.tts.layers.tacotron.tacotron2 import ConvBNBlock | |
class Encoder(nn.Module):
    r"""Neural HMM Encoder

    Same as the Tacotron 2 encoder, but the bidirectional LSTM is widened by
    ``state_per_phone`` so that every input token is turned into
    ``state_per_phone`` encoder states; the output sequence length is therefore
    ``T_in * state_per_phone``.

    Args:
        num_chars (int): Number of characters (vocabulary size) in the input.
        state_per_phone (int): Number of states per phone.
        in_out_channels (int): number of input and output channels.
        n_convolutions (int): number of convolutional layers.
    """

    def __init__(self, num_chars, state_per_phone, in_out_channels=512, n_convolutions=3):
        super().__init__()
        self.state_per_phone = state_per_phone
        self.in_out_channels = in_out_channels
        self.emb = nn.Embedding(num_chars, in_out_channels)
        self.convolutions = nn.ModuleList()
        for _ in range(n_convolutions):
            self.convolutions.append(ConvBNBlock(in_out_channels, in_out_channels, 5, "relu"))
        # Each direction has (in_out_channels / 2) * state_per_phone units, so the
        # concatenated bidirectional output carries in_out_channels * state_per_phone
        # features per token; forward()/inference() reshape that into
        # state_per_phone vectors of size in_out_channels (assumes in_out_channels is even).
        self.lstm = nn.LSTM(
            in_out_channels,
            int(in_out_channels / 2) * state_per_phone,
            num_layers=1,
            batch_first=True,
            bias=True,
            bidirectional=True,
        )
        self.rnn_state = None

    def _embed_and_convolve(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Embed token ids and run them through the conv stack, returning a channels-last tensor."""
        o = self.emb(x).transpose(1, 2)  # (b, C, T) for the 1-D convolutions
        for layer in self.convolutions:
            o = layer(o)
        # NOTE(review): the reshape in the callers assumes ConvBNBlock keeps the time length T.
        return o.transpose(1, 2)

    def forward(self, x: torch.LongTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Forward pass to the encoder.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths (sorted in decreasing order,
                as required by ``pack_padded_sequence``).
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        o = nn.utils.rnn.pack_padded_sequence(o, x_len.cpu(), batch_first=True)
        self.lstm.flatten_parameters()
        o, _ = self.lstm(o)
        # total_length=T keeps the time axis at T even when the batch is padded
        # beyond its longest sequence; without it the reshape below would fail.
        o, _ = nn.utils.rnn.pad_packed_sequence(o, batch_first=True, total_length=T)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        x_len = x_len * self.state_per_phone
        return o, x_len

    def inference(self, x: torch.LongTensor, x_len: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        """Inference to the encoder.

        Runs the LSTM on the full (unpacked) sequence, so every position of ``x``
        is treated as valid input.

        Args:
            x (torch.LongTensor): input text indices.
                - shape: :math:`(b, T_{in})`
            x_len (torch.LongTensor): input text lengths.
                - shape: :math:`(b,)`

        Returns:
            Tuple[torch.FloatTensor, torch.LongTensor]: encoder outputs and output lengths.
                -shape: :math:`((b, T_{in} * states_per_phone, in_out_channels), (b,))`
        """
        b, T = x.shape
        o = self._embed_and_convolve(x)
        # flatten_parameters() is deliberately not called on this path
        # (reason not visible in this file — TODO confirm).
        o, _ = self.lstm(o)
        o = o.reshape(b, T * self.state_per_phone, self.in_out_channels)
        x_len = x_len * self.state_per_phone
        return o, x_len
class ParameterModel(nn.Module):
    r"""Feed-forward network at the core of the output net.

    A stack of ``Linear`` layers, each followed by ReLU, and a final plain
    ``nn.Linear`` projection. The projection is "flat started" at construction
    time: its weights are zeroed and its bias is filled so that the output is
    laid out as ``[mean (frame_channels) | std before softplus (frame_channels) |
    transition before sigmoid (rest)]``.

    Note: Do not put dropout layers here, the model will not converge.

    Args:
        outputnet_size (List[int]): hidden layer sizes of the parameter model.
        input_size (int): size of the input to the first hidden layer.
        output_size (int): size of the output of the final projection.
        frame_channels (int): feature dim, used to locate the mean/std slices of the bias.
        flat_start_params (dict): must provide ``"mean"``, ``"std"`` and ``"transition_p"``.
    """

    def __init__(
        self,
        outputnet_size: List[int],
        input_size: int,
        output_size: int,
        frame_channels: int,
        flat_start_params: dict,
    ):
        super().__init__()
        self.frame_channels = frame_channels
        fan_ins = [input_size] + outputnet_size[:-1]
        self.layers = nn.ModuleList(Linear(fan_in, fan_out) for fan_in, fan_out in zip(fan_ins, outputnet_size))
        self.last_layer = nn.Linear(outputnet_size[-1], output_size)
        self.flat_start_output_layer(
            flat_start_params["mean"],
            flat_start_params["std"],
            flat_start_params["transition_p"],
        )

    def flat_start_output_layer(self, mean, std, transition_p):
        """Zero the projection weights and write the flat-start values into its bias.

        ``std`` and ``transition_p`` are mapped through the inverse softplus and
        inverse sigmoid respectively, so the bias holds their pre-activation values.
        """
        n = self.frame_channels
        bias = self.last_layer.bias.data  # shares storage with the parameter
        self.last_layer.weight.data.zero_()
        bias[0:n] = mean
        bias[n : 2 * n] = OverflowUtils.inverse_softplus(std)
        bias[2 * n :] = OverflowUtils.inverse_sigmod(transition_p)

    def forward(self, x):
        """Apply ReLU(Linear) for every hidden layer, then the final projection."""
        hidden = x
        for linear in self.layers:
            hidden = F.relu(linear(hidden))
        return self.last_layer(hidden)
class Outputnet(nn.Module):
    r"""
    This network takes current state and previous observed values as input
    and returns its parameters, mean, standard deviation and probability
    of transition to the next state

    Args:
        encoder_dim (int): feature size of each hidden-state (encoder) vector.
        memory_rnn_dim (int): feature size of the autoregressive input.
        frame_channels (int): number of features per output frame.
        outputnet_size (List[int]): hidden layer sizes of the ``ParameterModel``.
        flat_start_params (dict): ``"mean"``, ``"std"`` and ``"transition_p"`` used
            to initialise the final layer of the ``ParameterModel``.
        std_floor (float): lower bound applied to the predicted standard deviations.
    """

    def __init__(
        self,
        encoder_dim: int,
        memory_rnn_dim: int,
        frame_channels: int,
        outputnet_size: List[int],
        flat_start_params: dict,
        std_floor: float = 1e-2,
    ):
        super().__init__()
        self.frame_channels = frame_channels
        self.flat_start_params = flat_start_params
        self.std_floor = std_floor

        input_size = memory_rnn_dim + encoder_dim
        # mean and std per frame channel, plus one transition value per state
        output_size = 2 * frame_channels + 1
        self.parametermodel = ParameterModel(
            outputnet_size=outputnet_size,
            input_size=input_size,
            output_size=output_size,
            flat_start_params=flat_start_params,
            frame_channels=frame_channels,
        )

    def forward(self, ar_mels, inputs):
        r"""Inputs observation and returns the means, stds and transition probability for the current state

        Args:
            ar_mels (torch.FloatTensor): autoregressive input for the current step.
                - shape: (batch, prenet_dim)
            inputs (torch.FloatTensor): hidden-state features.
                - shape: (batch, hidden_states, hidden_state_dim)

        Returns:
            means: means for the emission observation for each feature
                - shape: (batch, hidden_states, feature_size)
            stds: standard deviations for the emission observation for each feature,
                after softplus and flooring at ``std_floor``
                - shape: (batch, hidden_states, feature_size)
            transition_vectors: transition vector for the current hidden state. No
                sigmoid is applied here (the flat start initialises it via an inverse
                sigmoid, so it is presumably a logit).
                - shape: (batch, hidden_states)
        """
        batch_size, prenet_dim = ar_mels.shape[0], ar_mels.shape[1]
        num_states = inputs.shape[1]
        # The same autoregressive input is paired with every hidden state.
        ar_mels = ar_mels.unsqueeze(1).expand(batch_size, num_states, prenet_dim)
        params = self.parametermodel(torch.cat((ar_mels, inputs), dim=2))

        fc = self.frame_channels
        mean = params[:, :, 0:fc]
        std = params[:, :, fc : 2 * fc]
        transition_vector = params[:, :, 2 * fc :].squeeze(2)

        std = self._floor_std(F.softplus(std))
        return mean, std, transition_vector

    def _floor_std(self, std):
        r"""
        It clamps the standard deviation to not to go below some level
        This removes the problem when the model tries to cheat for higher likelihoods by converting
        one of the gaussians to a point mass.

        Args:
            std (float Tensor): tensor containing the standard deviations to be floored.

        Returns:
            the tensor clamped from below at ``self.std_floor``.
        """
        # A boolean mask is enough to detect flooring; no need to clone the whole tensor.
        was_floored = torch.any(std < self.std_floor)
        std = torch.clamp(std, min=self.std_floor)
        if was_floored:
            print(
                "[*] Standard deviation was floored! The model is preventing overfitting, nothing serious to worry about"
            )
        return std
class OverflowUtils:
    """Stateless helper functions used by the Overflow / Neural-HMM layers.

    Every helper is a ``@staticmethod``; they can be called on the class
    (``OverflowUtils.log_clamped(x)``) as well as on an instance.
    """

    @staticmethod
    def get_data_parameters_for_flat_start(
        data_loader: torch.utils.data.DataLoader, out_channels: int, states_per_phone: int
    ):
        """Generates data parameters for flat starting the HMM.

        Makes one pass over ``data_loader`` and accumulates global mel statistics
        and text/mel length totals.

        Args:
            data_loader (torch.utils.data.DataLoader): loader whose batches are dicts with
                ``"token_id_lengths"``, ``"mel"`` and ``"mel_lengths"`` entries;
                ``len(data_loader.dataset)`` is used as the number of utterances.
            out_channels (int): mel spectrogram channels
            states_per_phone (int): HMM states per phone

        Returns:
            Tuple: ``(data_mean, data_std, init_transition_prob)`` — the global mel mean,
            the global mel standard deviation and the initial per-state transition probability.
        """
        # State related information for transition_p
        total_state_len = 0
        total_mel_len = 0
        # Useful for data mean and std
        total_mel_sum = 0
        total_mel_sq_sum = 0
        for batch in tqdm(data_loader, leave=False):
            text_lengths = batch["token_id_lengths"]
            mels = batch["mel"]
            mel_lengths = batch["mel_lengths"]
            total_state_len += torch.sum(text_lengths)
            total_mel_len += torch.sum(mel_lengths)
            # NOTE(review): sums cover the whole padded ``mels`` tensor while the divisor
            # below uses the true frame count, so padded frames are assumed to be zero.
            total_mel_sum += torch.sum(mels)
            total_mel_sq_sum += torch.sum(torch.pow(mels, 2))

        num_values = total_mel_len * out_channels
        data_mean = total_mel_sum / num_values
        # E[x^2] - E[x]^2 can come out slightly negative through floating-point
        # cancellation; clamp at zero so the square root never yields NaN.
        data_variance = torch.clamp(total_mel_sq_sum / num_values - torch.pow(data_mean, 2), min=0.0)
        data_std = torch.sqrt(data_variance)

        num_utterances = len(data_loader.dataset)
        average_num_states = total_state_len / num_utterances
        average_mel_len = total_mel_len / num_utterances
        # total_state_len sums token counts, so this is frames per token; multiplying
        # its inverse by states_per_phone gives the per-state transition probability.
        average_duration_each_state = average_mel_len / average_num_states
        init_transition_prob = 1 / average_duration_each_state
        return data_mean, data_std, (init_transition_prob * states_per_phone)

    @staticmethod
    def update_flat_start_transition(model, transition_p):
        """Re-flat-start the output layer with mean 0, std 1 and the given transition probability."""
        model.neural_hmm.output_net.parametermodel.flat_start_output_layer(0.0, 1.0, transition_p)

    @staticmethod
    def log_clamped(x, eps=1e-04):
        """
        Avoids the log(0) problem

        Args:
            x (torch.tensor): input tensor
            eps (float, optional): lower bound. Defaults to 1e-04.

        Returns:
            torch.tensor: :math:`log(max(x, eps))`
        """
        clamped_x = torch.clamp(x, min=eps)
        return torch.log(clamped_x)

    @staticmethod
    def inverse_sigmod(x):
        r"""
        Inverse of the sigmoid function: :math:`log(x / (1 - x))`, with the log argument
        clamped by ``log_clamped``. Non-tensor inputs are converted with ``torch.tensor``.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        return OverflowUtils.log_clamped(x / (1.0 - x))

    @staticmethod
    def inverse_softplus(x):
        r"""
        Inverse of the softplus function: :math:`log(exp(x) - 1)`, with the log argument
        clamped by ``log_clamped``. Non-tensor inputs are converted with ``torch.tensor``.
        """
        if not torch.is_tensor(x):
            x = torch.tensor(x)
        naive = OverflowUtils.log_clamped(torch.exp(x) - 1.0)
        # exp(x) overflows to inf for large x (beyond ~88 in float32). softplus is
        # linear there (F.softplus returns its input above threshold=20), so its
        # inverse is simply x.
        return torch.where(x > 20.0, x.to(naive.dtype), naive)

    @staticmethod
    def logsumexp(x, dim):
        r"""
        Differentiable LogSumExp: does not create NaN gradients.
        When all the inputs along ``dim`` are -inf it returns -inf with 0 gradients.

        Args:
            x : torch.Tensor - The input tensor
            dim: int - The dimension on which the log sum exp has to be applied
        """
        m, _ = x.max(dim=dim)
        mask = m == -float("inf")
        # Temporarily replace -inf maxima by 0 so that (x - m) is not (-inf) - (-inf) = NaN.
        s = (x - m.masked_fill_(mask, 0).unsqueeze(dim=dim)).exp().sum(dim=dim)
        # log(1) = 0 for the all -inf rows, then restore -inf as their result.
        return s.masked_fill_(mask, 1).log() + m.masked_fill_(mask, -float("inf"))

    @staticmethod
    def double_pad(list_of_different_shape_tensors):
        r"""
        Pads the list of tensors in 2 dimensions.

        Each tensor is right-padded with zeros along its last dimension up to the
        largest ``len(t[0])`` in the list, then the list is stacked with
        ``pad_sequence(batch_first=True)``, which zero-pads the first dimension.
        """
        second_dim_lens = [len(a) for a in [i[0] for i in list_of_different_shape_tensors]]
        second_dim_max = max(second_dim_lens)
        padded_x = [F.pad(x, (0, second_dim_max - len(x[0]))) for x in list_of_different_shape_tensors]
        return nn.utils.rnn.pad_sequence(padded_x, batch_first=True)