import torch

from torch.nn import Linear
from torch.nn import Sequential
from torch.nn import Tanh

from Layers.Conformer import Conformer
from Layers.DurationPredictor import DurationPredictor
from Layers.LengthRegulator import LengthRegulator
from Layers.PostNet import PostNet
from Layers.VariancePredictor import VariancePredictor
from Preprocessing.articulatory_features import get_feature_to_index_lookup
from TrainingInterfaces.Text_to_Spectrogram.ToucanTTS.Glow import Glow
from Utility.utils import make_non_pad_mask


class ToucanTTS(torch.nn.Module):

    def __init__(self,
                 # network structure related
                 input_feature_dimensions=62,
                 output_spectrogram_channels=80,
                 attention_dimension=192,
                 attention_heads=4,
                 positionwise_conv_kernel_size=1,
                 use_scaled_positional_encoding=True,
                 use_macaron_style_in_conformer=True,
                 use_cnn_in_conformer=True,

                 # encoder
                 encoder_layers=6,
                 encoder_units=1536,
                 encoder_normalize_before=True,
                 encoder_concat_after=False,
                 conformer_encoder_kernel_size=7,
                 transformer_enc_dropout_rate=0.2,
                 transformer_enc_positional_dropout_rate=0.2,
                 transformer_enc_attn_dropout_rate=0.2,

                 # decoder
                 decoder_layers=6,
                 decoder_units=1536,
                 decoder_concat_after=False,
                 conformer_decoder_kernel_size=31,
                 decoder_normalize_before=True,
                 transformer_dec_dropout_rate=0.2,
                 transformer_dec_positional_dropout_rate=0.2,
                 transformer_dec_attn_dropout_rate=0.2,

                 # duration predictor
                 duration_predictor_layers=3,
                 duration_predictor_chans=256,
                 duration_predictor_kernel_size=3,
                 duration_predictor_dropout_rate=0.2,

                 # pitch predictor
                 pitch_predictor_layers=7,  # 5 in espnet
                 pitch_predictor_chans=256,
                 pitch_predictor_kernel_size=5,
                 pitch_predictor_dropout=0.5,
                 pitch_embed_kernel_size=1,
                 pitch_embed_dropout=0.0,

                 # energy predictor
                 energy_predictor_layers=2,
                 energy_predictor_chans=256,
                 energy_predictor_kernel_size=3,
                 energy_predictor_dropout=0.5,
                 energy_embed_kernel_size=1,
                 energy_embed_dropout=0.0,

                 # additional features
                 utt_embed_dim=64,
                 detach_postflow=True,
                 lang_embs=8000,
                 weights=None):
        super().__init__()

        self.input_feature_dimensions = input_feature_dimensions
        self.output_spectrogram_channels = output_spectrogram_channels
        self.attention_dimension = attention_dimension
        self.detach_postflow = detach_postflow
        self.use_scaled_pos_enc = use_scaled_positional_encoding
        self.multilingual_model = lang_embs is not None
        self.multispeaker_model = utt_embed_dim is not None

        articulatory_feature_embedding = Sequential(Linear(input_feature_dimensions, 100),
                                                    Tanh(),
                                                    Linear(100, attention_dimension))

        self.encoder = Conformer(idim=input_feature_dimensions,
                                 attention_dim=attention_dimension,
                                 attention_heads=attention_heads,
                                 linear_units=encoder_units,
                                 num_blocks=encoder_layers,
                                 input_layer=articulatory_feature_embedding,
                                 dropout_rate=transformer_enc_dropout_rate,
                                 positional_dropout_rate=transformer_enc_positional_dropout_rate,
                                 attention_dropout_rate=transformer_enc_attn_dropout_rate,
                                 normalize_before=encoder_normalize_before,
                                 concat_after=encoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_encoder_kernel_size,
                                 zero_triu=False,
                                 utt_embed=utt_embed_dim,
                                 lang_embs=lang_embs,
                                 use_output_norm=True)

        self.duration_predictor = DurationPredictor(idim=attention_dimension,
                                                    n_layers=duration_predictor_layers,
                                                    n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size,
                                                    dropout_rate=duration_predictor_dropout_rate,
                                                    utt_embed_dim=utt_embed_dim)
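
        # the pitch and energy predictors below work at phoneme resolution: they
        # produce one token-averaged value per input vector rather than one value
        # per spectrogram frame, which keeps the curves compact and controllable.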
        self.pitch_predictor = VariancePredictor(idim=attention_dimension,
                                                 n_layers=pitch_predictor_layers,
                                                 n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size,
                                                 dropout_rate=pitch_predictor_dropout,
                                                 utt_embed_dim=utt_embed_dim)

        self.energy_predictor = VariancePredictor(idim=attention_dimension,
                                                  n_layers=energy_predictor_layers,
                                                  n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size,
                                                  dropout_rate=energy_predictor_dropout,
                                                  utt_embed_dim=utt_embed_dim)

        self.pitch_embed = Sequential(torch.nn.Conv1d(in_channels=1,
                                                      out_channels=attention_dimension,
                                                      kernel_size=pitch_embed_kernel_size,
                                                      padding=(pitch_embed_kernel_size - 1) // 2),
                                      torch.nn.Dropout(pitch_embed_dropout))

        self.energy_embed = Sequential(torch.nn.Conv1d(in_channels=1,
                                                       out_channels=attention_dimension,
                                                       kernel_size=energy_embed_kernel_size,
                                                       padding=(energy_embed_kernel_size - 1) // 2),
                                       torch.nn.Dropout(energy_embed_dropout))

        self.length_regulator = LengthRegulator()

        self.decoder = Conformer(idim=0,
                                 attention_dim=attention_dimension,
                                 attention_heads=attention_heads,
                                 linear_units=decoder_units,
                                 num_blocks=decoder_layers,
                                 input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate,
                                 positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate,
                                 normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_decoder_kernel_size,
                                 use_output_norm=False)

        self.feat_out = Linear(attention_dimension, output_spectrogram_channels)

        self.conv_postnet = PostNet(idim=0,
                                    odim=output_spectrogram_channels,
                                    n_layers=5,
                                    n_chans=256,
                                    n_filts=5,
                                    use_batch_norm=True,
                                    dropout_rate=0.5)

        self.post_flow = Glow(
            in_channels=output_spectrogram_channels,
            hidden_channels=192,  # post_glow_hidden
            kernel_size=5,  # post_glow_kernel_size
            dilation_rate=1,
            n_blocks=18,  # post_glow_n_blocks (original 12 in paper)
            n_layers=4,  # post_glow_n_block_layers (original 3 in paper)
            n_split=4,
            n_sqz=2,
            text_condition_channels=attention_dimension,
            share_cond_layers=False,  # post_share_cond_layers
            share_wn_layers=4,
            sigmoid_scale=False,
            condition_integration_projection=torch.nn.Conv1d(output_spectrogram_channels + attention_dimension, attention_dimension, 5, padding=2)
        )

        # only load a state dict if one is actually passed in, so that the model can
        # also be instantiated freshly (the previous unconditional call would crash
        # with the default weights=None)
        if weights is not None:
            self.load_state_dict(weights)
        self.eval()

    def _forward(self,
                 text_tensors,
                 text_lengths,
                 gold_durations=None,
                 gold_pitch=None,
                 gold_energy=None,
                 duration_scaling_factor=1.0,
                 utterance_embedding=None,
                 lang_ids=None,
                 pitch_variance_scale=1.0,
                 energy_variance_scale=1.0,
                 pause_duration_scaling_factor=1.0):

        if not self.multilingual_model:
            lang_ids = None

        if not self.multispeaker_model:
            utterance_embedding = None
        else:
            utterance_embedding = torch.nn.functional.normalize(utterance_embedding)

        # encoding the texts
        text_masks = make_non_pad_mask(text_lengths, device=text_lengths.device).unsqueeze(-2)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)

        # predicting pitch, energy and durations, unless gold values are provided
        pitch_predictions = self.pitch_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_pitch is None else gold_pitch
        energy_predictions = self.energy_predictor(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_energy is None else gold_energy
        predicted_durations = self.duration_predictor.inference(encoded_texts, padding_mask=None, utt_embed=utterance_embedding) if gold_durations is None else gold_durations
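
        # the next block imposes hard linguistic constraints on the raw predictions:
        # unvoiced phonemes cannot carry pitch, vectors that do not represent a
        # phoneme carry no energy, and word-boundary markers are not audible, so
        # they take up no time in the spectrogram.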
        # modifying the predictions with linguistic knowledge and control parameters
        feature_lookup = get_feature_to_index_lookup()
        for phoneme_index, phoneme_vector in enumerate(text_tensors.squeeze(0)):
            if phoneme_vector[feature_lookup["voiced"]] == 0:
                pitch_predictions[0][phoneme_index] = 0.0
            if phoneme_vector[feature_lookup["phoneme"]] == 0:
                energy_predictions[0][phoneme_index] = 0.0
            if phoneme_vector[feature_lookup["word-boundary"]] == 1:
                predicted_durations[0][phoneme_index] = 0
            if phoneme_vector[feature_lookup["silence"]] == 1 and pause_duration_scaling_factor != 1.0:
                predicted_durations[0][phoneme_index] = torch.round(predicted_durations[0][phoneme_index].float() * pause_duration_scaling_factor).long()
        if duration_scaling_factor != 1.0:
            assert duration_scaling_factor > 0
            predicted_durations = torch.round(predicted_durations.float() * duration_scaling_factor).long()
        pitch_predictions = _scale_variance(pitch_predictions, pitch_variance_scale)
        energy_predictions = _scale_variance(energy_predictions, energy_variance_scale)

        # enriching the text with pitch and energy info
        embedded_pitch_curve = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
        embedded_energy_curve = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
        enriched_encoded_texts = encoded_texts + embedded_pitch_curve + embedded_energy_curve

        # upsampling the text according to the durations
        upsampled_enriched_encoded_texts = self.length_regulator(enriched_encoded_texts, predicted_durations)

        # decoding spectrogram
        decoded_speech, _ = self.decoder(upsampled_enriched_encoded_texts, None)
        decoded_spectrogram = self.feat_out(decoded_speech).view(decoded_speech.size(0), -1, self.output_spectrogram_channels)

        # refining the spectrogram with the convolutional PostNet and the normalizing flow
        refined_spectrogram = decoded_spectrogram + self.conv_postnet(decoded_spectrogram.transpose(1, 2)).transpose(1, 2)
        refined_spectrogram = self.post_flow(tgt_mels=None,
                                             infer=True,
                                             mel_out=refined_spectrogram,
                                             encoded_texts=upsampled_enriched_encoded_texts,
                                             tgt_nonpadding=None).squeeze()

        return decoded_spectrogram.squeeze(), \
            refined_spectrogram.squeeze(), \
            predicted_durations.squeeze(), \
            pitch_predictions.squeeze(), \
            energy_predictions.squeeze()
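
    # the public forward pass below operates on a single utterance without a batch
    # axis; it adds the batch axis itself before delegating to _forward, and the
    # inference_mode decorator disables gradient tracking for the whole call.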
    @torch.inference_mode()
    def forward(self,
                text,
                durations=None,
                pitch=None,
                energy=None,
                utterance_embedding=None,
                return_duration_pitch_energy=False,
                lang_id=None,
                duration_scaling_factor=1.0,
                pitch_variance_scale=1.0,
                energy_variance_scale=1.0,
                pause_duration_scaling_factor=1.0):
        """
        Generate the sequence of spectrogram frames given the sequence of vectorized phonemes.

        Args:
            text: input sequence of vectorized phonemes
            durations: durations to be used (optional, if not provided, they will be predicted)
            pitch: token-averaged pitch curve to be used (optional, if not provided, it will be predicted)
            energy: token-averaged energy curve to be used (optional, if not provided, it will be predicted)
            return_duration_pitch_energy: whether to also return the predicted durations, pitch and energy for nicer plotting
            utterance_embedding: embedding of speaker information
            lang_id: id to be fed into the embedding layer that contains language information
            duration_scaling_factor: reasonable values are 0.8 < scale < 1.2.
                                     1.0 means no scaling happens, higher values increase durations for the whole
                                     utterance, lower values decrease durations for the whole utterance.
            pitch_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                  1.0 means no scaling happens, higher values increase the variance of the pitch curve,
                                  lower values decrease the variance of the pitch curve.
            energy_variance_scale: reasonable values are 0.6 < scale < 1.4.
                                   1.0 means no scaling happens, higher values increase the variance of the energy curve,
                                   lower values decrease the variance of the energy curve.
            pause_duration_scaling_factor: reasonable values are 0.6 < scale < 1.4.
                                           scales the durations of pauses on top of the regular duration scaling

        Returns:
            mel spectrogram
        """
        # setup batch axis
        text_length = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
        if durations is not None:
            durations = durations.unsqueeze(0).to(text.device)
        if pitch is not None:
            pitch = pitch.unsqueeze(0).to(text.device)
        if energy is not None:
            energy = energy.unsqueeze(0).to(text.device)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0).to(text.device)

        before_outs, \
            after_outs, \
            predicted_durations, \
            pitch_predictions, \
            energy_predictions = self._forward(text.unsqueeze(0),
                                               text_length,
                                               gold_durations=durations,
                                               gold_pitch=pitch,
                                               gold_energy=energy,
                                               utterance_embedding=utterance_embedding.unsqueeze(0) if utterance_embedding is not None else None,
                                               lang_ids=lang_id,
                                               duration_scaling_factor=duration_scaling_factor,
                                               pitch_variance_scale=pitch_variance_scale,
                                               energy_variance_scale=energy_variance_scale,
                                               pause_duration_scaling_factor=pause_duration_scaling_factor)
        if return_duration_pitch_energy:
            return after_outs, predicted_durations, pitch_predictions, energy_predictions
        return after_outs

    def store_inverse_all(self):
        def remove_weight_norm(m):
            try:
                if hasattr(m, 'store_inverse'):
                    m.store_inverse()
                torch.nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(remove_weight_norm)


def _scale_variance(sequence, scale):
    if scale == 1.0:
        return sequence
    average = sequence[0][sequence[0] != 0.0].mean()
    sequence = sequence - average  # center sequence around 0
    sequence = sequence * scale  # scale the variance
    sequence = sequence + average  # move center back to original with changed variance
    sequence = torch.clamp(sequence, min=0.0)  # pitch and energy cannot be negative
    return sequence
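
# A minimal usage sketch, assuming a trained checkpoint whose "model" entry matches
# this architecture's state dict; the checkpoint path, the "model" key and the two
# helper functions are hypothetical placeholders, not part of this module:
#
#     check_dict = torch.load("path/to/checkpoint.pt", map_location="cpu")
#     tts = ToucanTTS(weights=check_dict["model"])
#     tts.store_inverse_all()  # prepare the Glow layers for inference
#
#     text = vectorize_phonemes("Hello world!")  # hypothetical: (num_phonemes, 62) articulatory vectors
#     speaker = embed_speaker("reference.wav")   # hypothetical: (64,) utterance embedding
#     lang_id = torch.LongTensor([0])            # index into the language embedding table
#
#     mel = tts(text,
#               utterance_embedding=speaker,
#               lang_id=lang_id,
#               duration_scaling_factor=1.1,  # speak a little slower
#               pitch_variance_scale=1.2)     # livelier pitch contour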