""" Taken from ESPNet """ from abc import ABC import torch from Layers.Conformer import Conformer from Layers.DurationPredictor import DurationPredictor from Layers.LengthRegulator import LengthRegulator from Layers.PostNet import PostNet from Layers.VariancePredictor import VariancePredictor from TrainingInterfaces.Text_to_Spectrogram.FastSpeech2.FastSpeech2Loss import FastSpeech2Loss from Utility.SoftDTW.sdtw_cuda_loss import SoftDTW from Utility.utils import initialize from Utility.utils import make_non_pad_mask from Utility.utils import make_pad_mask class FastSpeech2(torch.nn.Module, ABC): """ FastSpeech 2 module. This is a module of FastSpeech 2 described in FastSpeech 2: Fast and High-Quality End-to-End Text to Speech. Instead of quantized pitch and energy, we use token-averaged value introduced in FastPitch: Parallel Text-to-speech with Pitch Prediction. The encoder and decoder are Conformers instead of regular Transformers. https://arxiv.org/abs/2006.04558 https://arxiv.org/abs/2006.06873 https://arxiv.org/pdf/2005.08100 """ def __init__(self, # network structure related idim=66, odim=80, adim=384, aheads=4, elayers=6, eunits=1536, dlayers=6, dunits=1536, postnet_layers=5, postnet_chans=256, postnet_filts=5, positionwise_layer_type="conv1d", positionwise_conv_kernel_size=1, use_scaled_pos_enc=True, use_batch_norm=True, encoder_normalize_before=True, decoder_normalize_before=True, encoder_concat_after=False, decoder_concat_after=False, reduction_factor=1, # encoder / decoder use_macaron_style_in_conformer=True, use_cnn_in_conformer=True, conformer_enc_kernel_size=7, conformer_dec_kernel_size=31, # duration predictor duration_predictor_layers=2, duration_predictor_chans=256, duration_predictor_kernel_size=3, # energy predictor energy_predictor_layers=2, energy_predictor_chans=256, energy_predictor_kernel_size=3, energy_predictor_dropout=0.5, energy_embed_kernel_size=1, energy_embed_dropout=0.0, stop_gradient_from_energy_predictor=False, # pitch predictor pitch_predictor_layers=5, pitch_predictor_chans=256, pitch_predictor_kernel_size=5, pitch_predictor_dropout=0.5, pitch_embed_kernel_size=1, pitch_embed_dropout=0.0, stop_gradient_from_pitch_predictor=True, # training related transformer_enc_dropout_rate=0.2, transformer_enc_positional_dropout_rate=0.2, transformer_enc_attn_dropout_rate=0.2, transformer_dec_dropout_rate=0.2, transformer_dec_positional_dropout_rate=0.2, transformer_dec_attn_dropout_rate=0.2, duration_predictor_dropout_rate=0.2, postnet_dropout_rate=0.5, init_type="xavier_uniform", init_enc_alpha=1.0, init_dec_alpha=1.0, use_masking=False, use_weighted_masking=True, # additional features use_dtw_loss=False, utt_embed_dim=704, connect_utt_emb_at_encoder_out=True, lang_embs=100): super().__init__() # store hyperparameters self.idim = idim self.odim = odim self.use_dtw_loss = use_dtw_loss self.eos = 1 self.reduction_factor = reduction_factor self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor self.use_scaled_pos_enc = use_scaled_pos_enc self.multilingual_model = lang_embs is not None self.multispeaker_model = utt_embed_dim is not None # define encoder embed = torch.nn.Sequential(torch.nn.Linear(idim, 100), torch.nn.Tanh(), torch.nn.Linear(100, adim)) self.encoder = Conformer(idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=embed, dropout_rate=transformer_enc_dropout_rate, 
        # define duration predictor
        self.duration_predictor = DurationPredictor(idim=adim,
                                                    n_layers=duration_predictor_layers,
                                                    n_chans=duration_predictor_chans,
                                                    kernel_size=duration_predictor_kernel_size,
                                                    dropout_rate=duration_predictor_dropout_rate)

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(idim=adim,
                                                 n_layers=pitch_predictor_layers,
                                                 n_chans=pitch_predictor_chans,
                                                 kernel_size=pitch_predictor_kernel_size,
                                                 dropout_rate=pitch_predictor_dropout)

        # continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1,
                                                               out_channels=adim,
                                                               kernel_size=pitch_embed_kernel_size,
                                                               padding=(pitch_embed_kernel_size - 1) // 2),
                                               torch.nn.Dropout(pitch_embed_dropout))

        # define energy predictor
        self.energy_predictor = VariancePredictor(idim=adim,
                                                  n_layers=energy_predictor_layers,
                                                  n_chans=energy_predictor_chans,
                                                  kernel_size=energy_predictor_kernel_size,
                                                  dropout_rate=energy_predictor_dropout)

        # continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(torch.nn.Conv1d(in_channels=1,
                                                                out_channels=adim,
                                                                kernel_size=energy_embed_kernel_size,
                                                                padding=(energy_embed_kernel_size - 1) // 2),
                                                torch.nn.Dropout(energy_embed_dropout))

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        self.decoder = Conformer(idim=0,
                                 attention_dim=adim,
                                 attention_heads=aheads,
                                 linear_units=dunits,
                                 num_blocks=dlayers,
                                 input_layer=None,
                                 dropout_rate=transformer_dec_dropout_rate,
                                 positional_dropout_rate=transformer_dec_positional_dropout_rate,
                                 attention_dropout_rate=transformer_dec_attn_dropout_rate,
                                 normalize_before=decoder_normalize_before,
                                 concat_after=decoder_concat_after,
                                 positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                                 macaron_style=use_macaron_style_in_conformer,
                                 use_cnn_module=use_cnn_in_conformer,
                                 cnn_module_kernel=conformer_dec_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = PostNet(idim=idim,
                               odim=odim,
                               n_layers=postnet_layers,
                               n_chans=postnet_chans,
                               n_filts=postnet_filts,
                               use_batch_norm=use_batch_norm,
                               dropout_rate=postnet_dropout_rate)

        # initialize parameters
        self._reset_parameters(init_type=init_type,
                               init_enc_alpha=init_enc_alpha,
                               init_dec_alpha=init_dec_alpha)

        # define criterions
        self.criterion = FastSpeech2Loss(use_masking=use_masking, use_weighted_masking=use_weighted_masking)
        self.dtw_criterion = SoftDTW(use_cuda=True, gamma=0.1)
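    # The gold_pitch / gold_energy targets consumed below are token-averaged
    # in the FastPitch style: frame-level values averaged over each token's
    # duration during preprocessing. A minimal sketch of that averaging,
    # assuming 1D tensors frame_pitch and durations whose sum matches the
    # number of frames (illustrative only, not part of this module):
    #
    #     starts = durations.cumsum(dim=0) - durations
    #     token_pitch = torch.stack([frame_pitch[s:s + d].mean() if d > 0
    #                                else frame_pitch.new_zeros(())
    #                                for s, d in zip(starts, durations)])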
    def forward(self,
                text_tensors,
                text_lengths,
                gold_speech,
                speech_lengths,
                gold_durations,
                gold_pitch,
                gold_energy,
                utterance_embedding,
                return_mels=False,
                lang_ids=None):
        """
        Calculate forward propagation.

        Args:
            return_mels: whether to return the predicted spectrogram
            text_tensors (Tensor): Batch of padded text feature vectors (B, Tmax, idim).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            gold_speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            gold_durations (LongTensor): Batch of padded durations (B, Tmax + 1).
            gold_pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
            gold_energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
            utterance_embedding (Tensor): Batch of utterance embeddings (B, utt_embed_dim).
            lang_ids (LongTensor, optional): Batch of language IDs.

        Returns:
            Tensor: Loss scalar value (and the predicted spectrograms if return_mels is True).
        """
        # Texts include EOS token from the teacher model already in this version

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(text_tensors,
                                                                        text_lengths,
                                                                        gold_speech,
                                                                        speech_lengths,
                                                                        gold_durations,
                                                                        gold_pitch,
                                                                        gold_energy,
                                                                        utterance_embedding=utterance_embedding,
                                                                        is_inference=False,
                                                                        lang_ids=lang_ids)

        # modify mod part of groundtruth (speaking pace)
        if self.reduction_factor > 1:
            speech_lengths = speech_lengths.new([olen - olen % self.reduction_factor for olen in speech_lengths])

        # calculate loss
        l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(after_outs=after_outs,
                                                                         before_outs=before_outs,
                                                                         d_outs=d_outs,
                                                                         p_outs=p_outs,
                                                                         e_outs=e_outs,
                                                                         ys=gold_speech,
                                                                         ds=gold_durations,
                                                                         ps=gold_pitch,
                                                                         es=gold_energy,
                                                                         ilens=text_lengths,
                                                                         olens=speech_lengths)
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        if self.use_dtw_loss:
            # print("Regular Loss: {}".format(loss))
            dtw_loss = self.dtw_criterion(after_outs, gold_speech).mean() / 2000.0  # division to balance orders of magnitude
            # print("DTW Loss: {}".format(dtw_loss))
            loss = loss + dtw_loss

        if return_mels:
            return loss, after_outs
        return loss

    def _forward(self,
                 text_tensors,
                 text_lens,
                 gold_speech=None,
                 speech_lens=None,
                 gold_durations=None,
                 gold_pitch=None,
                 gold_energy=None,
                 is_inference=False,
                 alpha=1.0,
                 utterance_embedding=None,
                 lang_ids=None):
        if not self.multilingual_model:
            lang_ids = None

        if not self.multispeaker_model:
            utterance_embedding = None

        # forward encoder
        text_masks = self._source_mask(text_lens)
        encoded_texts, _ = self.encoder(text_tensors, text_masks, utterance_embedding=utterance_embedding, lang_ids=lang_ids)  # (B, Tmax, adim)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(text_lens, device=text_lens.device)

        if self.stop_gradient_from_pitch_predictor:
            pitch_predictions = self.pitch_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
        else:
            pitch_predictions = self.pitch_predictor(encoded_texts, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            energy_predictions = self.energy_predictor(encoded_texts.detach(), d_masks.unsqueeze(-1))
        else:
            energy_predictions = self.energy_predictor(encoded_texts, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(encoded_texts, d_masks)  # (B, Tmax)
            # use prediction in inference
            p_embs = self.pitch_embed(pitch_predictions.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy_predictions.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, d_outs, alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(encoded_texts, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(gold_pitch.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(gold_energy.transpose(1, 2)).transpose(1, 2)
            encoded_texts = encoded_texts + e_embs + p_embs
            encoded_texts = self.length_regulator(encoded_texts, gold_durations)  # (B, Lmax, adim)

        # forward decoder
        if speech_lens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = speech_lens.new([olen // self.reduction_factor for olen in speech_lens])
            else:
                olens_in = speech_lens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(encoded_texts, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        after_outs = before_outs + self.postnet(before_outs.transpose(1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs, pitch_predictions, energy_predictions
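    # For intuition: the length regulation step in _forward is conceptually a
    # repeat_interleave along the time axis, upsampling each token's encoding
    # to its number of frames (illustrative sketch with assumed shapes, not
    # how LengthRegulator is actually implemented):
    #
    #     phones = torch.randn(4, 384)             # (Tmax, adim)
    #     durations = torch.tensor([2, 3, 1, 2])   # frames per phone
    #     frames = torch.repeat_interleave(phones, durations, dim=0)  # (8, adim)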
    def batch_inference(self, texts, text_lens, utt_emb):
        # pass the utterance embedding through, otherwise the parameter is unused
        _, after_outs, d_outs, _, _ = self._forward(texts,
                                                    text_lens,
                                                    None,
                                                    is_inference=True,
                                                    alpha=1.0,
                                                    utterance_embedding=utt_emb)
        return after_outs, d_outs

    def inference(self,
                  text,
                  speech=None,
                  durations=None,
                  pitch=None,
                  energy=None,
                  alpha=1.0,
                  use_teacher_forcing=False,
                  utterance_embedding=None,
                  return_duration_pitch_energy=False,
                  lang_id=None):
        """
        Generate the sequence of features given the sequence of input text vectors.

        Args:
            text (Tensor): Input sequence of text feature vectors (T, idim).
            speech (Tensor, optional): Feature sequence to extract style (N, odim).
            durations (LongTensor, optional): Groundtruth of duration (T + 1,).
            pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1).
            energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1).
            alpha (float, optional): Alpha to control the speed.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.
            return_duration_pitch_energy: whether to return the list of predicted durations for nicer plotting

        Returns:
            Tensor: Output sequence of features (L, odim).
        """
        self.eval()
        x, y = text, speech
        d, p, e = durations, pitch, energy

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        if y is not None:
            ys = y.unsqueeze(0)
        if lang_id is not None:
            lang_id = lang_id.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
            before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = \
                self._forward(xs,
                              ilens,
                              ys,
                              gold_durations=ds,
                              gold_pitch=ps,
                              gold_energy=es,
                              utterance_embedding=utterance_embedding.unsqueeze(0),
                              lang_ids=lang_id)  # (1, L, odim)
        else:
            before_outs, after_outs, d_outs, pitch_predictions, energy_predictions = \
                self._forward(xs,
                              ilens,
                              ys,
                              is_inference=True,
                              alpha=alpha,
                              utterance_embedding=utterance_embedding.unsqueeze(0),
                              lang_ids=lang_id)  # (1, L, odim)
        self.train()
        if return_duration_pitch_energy:
            return after_outs[0], d_outs[0], pitch_predictions[0], energy_predictions[0]
        return after_outs[0]

    def _source_mask(self, ilens):
        """
        Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
        """
        x_masks = make_non_pad_mask(ilens, device=ilens.device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type, init_enc_alpha, init_dec_alpha):
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)
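

if __name__ == "__main__":
    # Minimal smoke-test sketch (an assumption, not part of the original
    # training pipeline): instantiate the model with default hyperparameters
    # and run a single inference pass on random inputs. Shapes follow the
    # docstrings above: 23 input vectors of dimension idim=66, a 704-dim
    # utterance embedding, and a single language ID. This assumes the
    # imported Layers modules behave as their call sites here suggest.
    dummy_text = torch.randn(23, 66)
    dummy_utt_emb = torch.randn(704)
    dummy_lang_id = torch.LongTensor([0])
    model = FastSpeech2()
    mel = model.inference(dummy_text,
                          utterance_embedding=dummy_utt_emb,
                          lang_id=dummy_lang_id)
    print(mel.shape)  # expected: (L, 80), with L determined by the predicted durations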