import re
import logging
import random

import torch
import torchaudio

import speechbrain
from speechbrain.inference.interfaces import Pretrained
from speechbrain.inference.text import GraphemeToPhoneme

logger = logging.getLogger(__name__)

# Abbreviation -> expansion regex table for text cleaning. Compiled once at
# import time instead of on every custom_clean() call.
_ABBREVIATIONS = [
    (re.compile("\\b%s\\." % abbr, re.IGNORECASE), expansion)
    for abbr, expansion in [
        ("mrs", "missus"),
        ("mr", "mister"),
        ("dr", "doctor"),
        ("st", "saint"),
        ("co", "company"),
        ("jr", "junior"),
        ("maj", "major"),
        ("gen", "general"),
        ("drs", "doctors"),
        ("rev", "reverend"),
        ("lt", "lieutenant"),
        ("hon", "honorable"),
        ("sgt", "sergeant"),
        ("capt", "captain"),
        ("esq", "esquire"),
        ("ltd", "limited"),
        ("col", "colonel"),
        ("ft", "fort"),
    ]
]


class TTSInferencing(Pretrained):
    """
    A ready-to-use wrapper for TTS (text -> mel_spec).

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = ["modules", "input_encoder"]

    MODULES_NEEDED = [
        "encoder_prenet",
        "pos_emb_enc",
        "decoder_prenet",
        "pos_emb_dec",
        "Seq2SeqTransformer",
        "mel_lin",
        "stop_lin",
        "decoder_postnet",
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Prepend the "@@" padding/blank symbol so it receives index 0 in the
        # input encoder, matching the zero-padding used at inference time.
        lexicon = self.hparams.lexicon
        lexicon = ["@@"] + lexicon
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
        self.input_encoder.add_unk()

        self.modules = self.hparams.modules
        # Grapheme-to-phoneme front end used to phonemize the input text.
        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def generate_padded_phonemes(self, texts):
        """Converts a list of texts to a zero-padded batch of encoded phonemes.

        Arguments
        ---------
        texts : List[str]
            Texts to be converted to phoneme sequences.

        Returns
        -------
        torch.Tensor
            Float tensor of shape (batch, max_len) holding the right
            zero-padded phoneme indices for each text.
        """
        # Clean each text, upper-case it, and phonemize it word by word,
        # dropping any whitespace "phonemes" produced by the G2P model.
        phoneme_labels = []
        for label in texts:
            phoneme_label = []
            label = self.custom_clean(label).upper()
            words = label.split()
            words = [word.strip() for word in words]
            words_phonemes = self.g2p(words)
            for word_phonemes in words_phonemes:
                for phoneme in word_phonemes:
                    if not phoneme.isspace():
                        phoneme_label.append(phoneme)
            phoneme_labels.append(phoneme_label)

        # Encode the phoneme sequences with the input text encoder.
        encoded_phonemes = [
            torch.LongTensor(self.input_encoder.encode_sequence(seq)).to(
                self.device
            )
            for seq in phoneme_labels
        ]

        # Right zero-pad all sequences to the maximum input length.
        # (The original code sorted all lengths just to read the max;
        # only the maximum is needed.)
        max_input_len = max(len(seq) for seq in encoded_phonemes)
        phoneme_padded = torch.LongTensor(
            len(encoded_phonemes), max_input_len
        ).to(self.device)
        phoneme_padded.zero_()
        for seq_idx, seq in enumerate(encoded_phonemes):
            phoneme_padded[seq_idx, : len(seq)] = seq

        return phoneme_padded.to(self.device, non_blocking=True).float()

    def encode_batch(self, texts):
        """Computes mel-spectrograms for a list of texts.

        Arguments
        ---------
        texts : List[str]
            Texts to be encoded into spectrograms.

        Returns
        -------
        torch.Tensor
            Batch of predicted mel-spectrograms (the start frame is dropped).
        """
        # Phonemize, encode and zero-pad the input texts.
        encoded_phoneme_padded = self.generate_padded_phonemes(texts)
        phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)

        # Positional embeddings for the encoder input.
        phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)

        # Sum of prenet and positional embeddings.
        enc_phoneme_emb = phoneme_prenet_emb.permute(0, 2, 1) + phoneme_pos_emb
        enc_phoneme_emb = enc_phoneme_emb.to(self.device)

        with torch.no_grad():
            # Start-of-sequence frame: zeros everywhere except channel 1,
            # which is set to 2.
            # NOTE(review): assumes 80 mel bins — confirm against hparams.
            start_token = torch.full((80, 1), fill_value=0)
            start_token[1] = 2
            decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
            decoder_input = decoder_input.to(
                self.device, non_blocking=True
            ).float()

            num_itr = 0
            stop_condition = [False] * decoder_input.size(0)
            max_iter = 100

            # NOTE(review): early stopping on the predicted stop tokens is
            # deliberately disabled — decoding always runs for max_iter
            # steps. To re-enable it, use:
            #     while not all(stop_condition) and num_itr < max_iter:
            while num_itr < max_iter:
                # Decoder prenet on the frames generated so far.
                mel_prenet_emb = (
                    self.modules['decoder_prenet'](decoder_input)
                    .to(self.device)
                    .permute(0, 2, 1)
                )

                # Positional embeddings for the decoder input.
                mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(
                    self.device
                )

                # Sum of prenet and positional embeddings.
                dec_mel_spec = mel_prenet_emb + mel_pos_emb

                # Causal target mask so the decoder cannot look ahead.
                tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(
                    self.device
                )

                # All-zero source mask (no masking on the encoder side).
                src_mask = torch.zeros(
                    enc_phoneme_emb.shape[1], enc_phoneme_emb.shape[1]
                ).to(self.device)

                # Padding masks for source and target.
                src_key_padding_mask = self.hparams.padding_mask(
                    enc_phoneme_emb, pad_idx=self.hparams.blank_index
                ).to(self.device)
                tgt_key_padding_mask = self.hparams.padding_mask(
                    dec_mel_spec, pad_idx=self.hparams.blank_index
                ).to(self.device)

                # Run the Seq2Seq transformer.
                decoder_outputs = self.modules['Seq2SeqTransformer'](
                    src=enc_phoneme_emb,
                    tgt=dec_mel_spec,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )

                # Project to mel bins and add the postnet residual.
                mel_linears = self.modules['mel_lin'](decoder_outputs).permute(
                    0, 2, 1
                )
                mel_postnet = self.modules['decoder_postnet'](mel_linears)
                mel_pred = mel_linears + mel_postnet

                # Per-utterance stop-token prediction.
                stop_token_pred = self.modules['stop_lin'](
                    decoder_outputs
                ).squeeze(-1)
                stop_condition_list = self.check_stop_condition(stop_token_pred)

                # Latch the stop conditions: once True, stay True.
                stop_condition = [
                    stop_condition_list[i] or stop_condition[i]
                    for i in range(len(stop_condition))
                ]

                # Append the newest frame to the decoder input for the next
                # iteration.
                current_output = mel_pred[:, :, -1:]
                decoder_input = torch.cat(
                    [decoder_input, current_output], dim=2
                )
                num_itr = num_itr + 1

        # Drop the start-of-sequence frame.
        mel_outputs = decoder_input[:, :, 1:]
        return mel_outputs

    def encode_text(self, text):
        """Runs inference for a single text str"""
        return self.encode_batch([text])

    def forward(self, text_list):
        "Encodes the input texts."
        return self.encode_batch(text_list)

    def check_stop_condition(self, stop_token_pred):
        """Checks whether the stop token / EOS has been reached for each
        mel-spectrogram in the batch.

        Arguments
        ---------
        stop_token_pred : torch.Tensor
            Raw stop-token logits, one row per utterance in the batch.

        Returns
        -------
        List[bool]
            One flag per utterance: True when every frame's stop
            probability exceeds the threshold.
        """
        # Sigmoid turns the logits into per-frame stop probabilities.
        sigmoid_output = torch.sigmoid(stop_token_pred)
        # An utterance stops when every frame's probability exceeds 0.8.
        # (The original comment said 0.5, contradicting the code.)
        stop_results = sigmoid_output > 0.8
        stop_output = [all(result) for result in stop_results]
        return stop_output

    def custom_clean(self, text):
        """
        Uses custom criteria to clean text: collapses runs of spaces and
        expands common English abbreviations (e.g. "Dr." -> "doctor").

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """
        text = re.sub(" +", " ", text)
        for regex, replacement in _ABBREVIATIONS:
            text = re.sub(regex, replacement, text)
        return text