Spaces:
Sleeping
Sleeping
| import re | |
| import logging | |
| import torch | |
| import torchaudio | |
| import random | |
| import speechbrain | |
| from speechbrain.inference.interfaces import Pretrained | |
| from speechbrain.inference.text import GraphemeToPhoneme | |
| logger = logging.getLogger(__name__) | |
class TTSInferencing(Pretrained):
    """
    A ready-to-use wrapper for TTS (text -> mel_spec).

    Text is cleaned, converted to phonemes with a grapheme-to-phoneme model,
    encoded with the input text encoder and decoded autoregressively into a
    mel-spectrogram by the transformer modules listed in ``MODULES_NEEDED``.

    Arguments
    ---------
    hparams
        Hyperparameters (from HyperPyYAML)
    """

    HPARAMS_NEEDED = ["modules", "input_encoder"]
    MODULES_NEEDED = ["encoder_prenet", "pos_emb_enc",
                      "decoder_prenet", "pos_emb_dec",
                      "Seq2SeqTransformer", "mel_lin",
                      "stop_lin", "decoder_postnet"]

    # (compiled pattern, expansion) pairs used by ``custom_clean``.
    # Compiled once at class creation instead of on every call.
    _ABBREVIATIONS = [
        (re.compile(r"\b%s\." % abbreviation, re.IGNORECASE), expansion)
        for abbreviation, expansion in [
            ("mrs", "missus"),
            ("mr", "mister"),
            ("dr", "doctor"),
            ("st", "saint"),
            ("co", "company"),
            ("jr", "junior"),
            ("maj", "major"),
            ("gen", "general"),
            ("drs", "doctors"),
            ("rev", "reverend"),
            ("lt", "lieutenant"),
            ("hon", "honorable"),
            ("sgt", "sergeant"),
            ("capt", "captain"),
            ("esq", "esquire"),
            ("ltd", "limited"),
            ("col", "colonel"),
            ("ft", "fort"),
        ]
    ]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # "@@" is prepended to the lexicon before the encoder is populated,
        # so it is the first symbol added by this call.
        lexicon = ["@@"] + self.hparams.lexicon
        self.input_encoder = self.hparams.input_encoder
        self.input_encoder.update_from_iterable(lexicon, sequence_input=False)
        self.input_encoder.add_unk()

        # NOTE(review): if the base class derives from torch.nn.Module, this
        # attribute shadows the built-in ``modules()`` method -- confirm that
        # nothing relies on calling ``self.modules()``.
        self.modules = self.hparams.modules

        self.g2p = GraphemeToPhoneme.from_hparams("speechbrain/soundchoice-g2p")

    def generate_padded_phonemes(self, texts):
        """Converts a list of texts into a zero-padded batch of phoneme ids

        Arguments
        ---------
        texts: List[str]
            texts to be converted to phoneme id sequences

        Returns
        -------
        float tensor of shape (len(texts), max_phoneme_len) holding the
        encoded phoneme ids, right-padded with zeros

        Raises
        ------
        ValueError
            if ``texts`` is empty
        """
        encoded_phonemes = []
        for label in texts:
            # Clean and upper-case the text, then split it into words.
            # ``str.split()`` already discards surrounding whitespace.
            words = self.custom_clean(label).upper().split()

            # Flatten the per-word phoneme sequences, dropping the
            # whitespace symbols produced by the G2P model.
            phoneme_label = [
                phoneme
                for word_phonemes in self.g2p(words)
                for phoneme in word_phonemes
                if not phoneme.isspace()
            ]

            # Encode the phonemes with the input text encoder.
            encoded_phonemes.append(
                torch.LongTensor(
                    self.input_encoder.encode_sequence(phoneme_label)
                ).to(self.device)
            )

        if not encoded_phonemes:
            raise ValueError("texts must contain at least one string")

        # Right zero-pad all sequences to the longest one in the batch.
        max_input_len = max(len(seq) for seq in encoded_phonemes)
        phoneme_padded = torch.zeros(
            len(encoded_phonemes),
            max_input_len,
            dtype=torch.long,
            device=self.device,
        )
        for seq_idx, seq in enumerate(encoded_phonemes):
            phoneme_padded[seq_idx, : len(seq)] = seq

        # The ids are handed to the encoder modules as floats.
        return phoneme_padded.float()

    def encode_batch(self, texts, max_iter=100):
        """Computes mel-spectrograms for a list of texts

        Decoding is autoregressive: starting from a fixed start frame, the
        last predicted frame is appended to the decoder input at every step.

        The original documentation asks for texts sorted in decreasing order
        of length; this method neither checks nor reorders them.

        Arguments
        ---------
        texts: List[str]
            texts to be encoded into spectrogram
        max_iter: int
            number of decoding steps, i.e. number of generated mel frames
            (default: 100)

        Returns
        -------
        tensor of output spectrograms with shape (len(texts), 80, max_iter)
        """
        # Generate phonemes and pad the input texts.
        encoded_phoneme_padded = self.generate_padded_phonemes(texts)

        # Inference only: keep every module call out of the autograd graph.
        with torch.no_grad():
            phoneme_prenet_emb = self.modules['encoder_prenet'](encoded_phoneme_padded)
            # Positional embeddings
            phoneme_pos_emb = self.modules['pos_emb_enc'](encoded_phoneme_padded)
            # Summing up embeddings
            enc_phoneme_emb = phoneme_prenet_emb.permute(0, 2, 1) + phoneme_pos_emb
            enc_phoneme_emb = enc_phoneme_emb.to(self.device)

            # The encoder side does not change between decoding steps, so its
            # masks are computed once.
            # NOTE(review): assumes hparams.padding_mask is a pure function
            # of its inputs -- confirm.
            src_len = enc_phoneme_emb.shape[1]
            src_mask = torch.zeros(src_len, src_len, device=self.device)
            src_key_padding_mask = self.hparams.padding_mask(
                enc_phoneme_emb, pad_idx=self.hparams.blank_index
            ).to(self.device)

            # Start frame: 80 rows (presumably the number of mel channels --
            # TODO confirm), all zeros except row 1 which is set to 2.
            start_token = torch.full((80, 1), fill_value=0)
            start_token[1] = 2
            decoder_input = start_token.repeat(enc_phoneme_emb.size(0), 1, 1)
            decoder_input = decoder_input.to(self.device, non_blocking=True).float()

            stop_condition = [False] * decoder_input.size(0)

            # NOTE: early stopping is disabled. The per-sequence stop state is
            # tracked below but the loop always runs ``max_iter`` steps.
            for _ in range(max_iter):
                # Decoder prenet
                mel_prenet_emb = (
                    self.modules['decoder_prenet'](decoder_input)
                    .to(self.device)
                    .permute(0, 2, 1)
                )
                # Positional embeddings
                mel_pos_emb = self.modules['pos_emb_dec'](mel_prenet_emb).to(self.device)
                # Summing up embeddings
                dec_mel_spec = mel_prenet_emb + mel_pos_emb

                # Target mask to avoid looking ahead
                tgt_mask = self.hparams.lookahead_mask(dec_mel_spec).to(self.device)
                # Padding mask for the targets
                tgt_key_padding_mask = self.hparams.padding_mask(
                    dec_mel_spec, pad_idx=self.hparams.blank_index
                ).to(self.device)

                # Running the Seq2Seq Transformer
                decoder_outputs = self.modules['Seq2SeqTransformer'](
                    src=enc_phoneme_emb,
                    tgt=dec_mel_spec,
                    src_mask=src_mask,
                    tgt_mask=tgt_mask,
                    src_key_padding_mask=src_key_padding_mask,
                    tgt_key_padding_mask=tgt_key_padding_mask,
                )

                # Mel linear projection plus postnet residual
                mel_linears = self.modules['mel_lin'](decoder_outputs).permute(0, 2, 1)
                mel_postnet = self.modules['decoder_postnet'](mel_linears)
                mel_pred = mel_linears + mel_postnet

                # Once a sequence has been flagged as stopped it stays flagged.
                stop_token_pred = self.modules['stop_lin'](decoder_outputs).squeeze(-1)
                newly_stopped = self.check_stop_condition(stop_token_pred)
                stop_condition = [
                    True if new else old
                    for old, new in zip(stop_condition, newly_stopped)
                ]

                # Append the newest predicted frame for the next step.
                current_output = mel_pred[:, :, -1:]
                decoder_input = torch.cat([decoder_input, current_output], dim=2)

            # Drop the start frame.
            mel_outputs = decoder_input[:, :, 1:]

        return mel_outputs

    def encode_text(self, text):
        """Runs inference for a single text str"""
        return self.encode_batch([text])

    def forward(self, text_list):
        "Encodes the input texts."
        return self.encode_batch(text_list)

    def check_stop_condition(self, stop_token_pred):
        """
        Checks whether the stop token / EOS was reached for each item in the
        batch.

        Arguments
        ---------
        stop_token_pred : torch.Tensor
            stop-token logits, one row per batch item

        Returns
        -------
        stop_output : List[bool]
            True for a batch item when the stop probability of every position
            in its row exceeds 0.8
        """
        # Sigmoid turns the logits into stop probabilities.
        sigmoid_output = torch.sigmoid(stop_token_pred)
        # A position votes "stop" when its probability is greater than 0.8.
        stop_results = sigmoid_output > 0.8
        # Reduce each row on the tensor rather than element by element.
        return [bool(row.all()) for row in stop_results]

    def custom_clean(self, text):
        """
        Uses custom criteria to clean text.

        Runs of spaces are collapsed to a single space and the abbreviations
        in ``_ABBREVIATIONS`` (e.g. "mr.", "dr.") are expanded,
        case-insensitively.

        Arguments
        ---------
        text : str
            Input text to be cleaned

        Returns
        -------
        text : str
            Cleaned text
        """
        text = re.sub(" +", " ", text)
        for regex, replacement in self._ABBREVIATIONS:
            text = regex.sub(replacement, text)
        return text