# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
import time

import numpy as np
import torch
from tqdm import tqdm

from models.tts.base.tts_inferece import TTSInference
from models.tts.vits.vits_dataset import VITSTestDataset, VITSTestCollator
from models.tts.vits.vits import SynthesizerTrn
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation


class VitsInference(TTSInference):
    def __init__(self, args=None, cfg=None):
        TTSInference.__init__(self, args, cfg)

    def _build_model(self):
        # Build the end-to-end VITS synthesizer from the experiment config.
        net_g = SynthesizerTrn(
            self.cfg.model.text_token_num,
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            **self.cfg.model,
        )
        return net_g

    def _build_test_dataset(self):
        # Dataset/collator pair the base class uses to build the test dataloader.
        return VITSTestDataset, VITSTestCollator

    def build_save_dir(self, dataset, speaker):
        # Output layout: <output_dir>/tts_am_step-<step>_<mode>[/data_<dataset>][/spk_<speaker>]
        save_dir = os.path.join(
            self.args.output_dir,
            "tts_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(
                save_dir,
                "spk_{}".format(speaker),
            )
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to", save_dir)
        return save_dir

    def inference_for_batches(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        ###### Construct test_batch ######
        n_batch = len(self.test_dataloader)
        now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time()))
        print(
            "Model eval time: {}, batch_size = {}, n_batch = {}".format(
                now, self.test_batch_size, n_batch
            )
        )
        self.model.eval()

        ###### Inference for each batch ######
        pred_res = []
        with torch.no_grad():
            for i, batch_data in enumerate(
                self.test_dataloader if n_batch == 1 else tqdm(self.test_dataloader)
            ):
                # Speaker IDs are only fed to the model for multi-speaker training.
                spk_id = None
                if (
                    self.cfg.preprocess.use_spkid
                    and self.cfg.train.multi_speaker_training
                ):
                    spk_id = batch_data["spk_id"]

                outputs = self.model.infer(
                    batch_data["phone_seq"],
                    batch_data["phone_len"],
                    spk_id,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )

                audios = outputs["y_hat"]
                masks = outputs["mask"]

                for idx in range(audios.size(0)):
                    # Trim padding: the mask counts valid frames, so the valid
                    # sample count is (#frames * hop_size).
                    audio = audios[idx, 0, :].data.cpu().float()
                    mask = masks[idx, :, :]
                    audio_length = (
                        mask.sum([0, 1]).long() * self.cfg.preprocess.hop_size
                    )
                    audio_length = audio_length.cpu().numpy()
                    audio = audio[:audio_length]
                    pred_res.append(audio)

        return pred_res

    def inference_for_single_utterance(
        self, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
    ):
        text = self.args.text

        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            assert os.path.exists(phone_symbol_file)

        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list

        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        # convert phone id sequence to tensor
        phone_id_seq = np.array(phone_id_seq)
        phone_id_seq = torch.from_numpy(phone_id_seq)

        # look up the speaker id for multi-speaker models
        speaker_id = None
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
with open(spk2id_file, "r") as f: spk2id = json.load(f) speaker_id = spk2id[self.args.speaker_name] speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32)) with torch.no_grad(): x_tst = phone_id_seq.to(self.device).unsqueeze(0) x_tst_lengths = torch.LongTensor([phone_id_seq.size(0)]).to(self.device) if speaker_id is not None: speaker_id = speaker_id.to(self.device) outputs = self.model.infer( x_tst, x_tst_lengths, sid=speaker_id, noise_scale=noise_scale, noise_scale_w=noise_scale_w, length_scale=length_scale, ) audio = outputs["y_hat"][0, 0].data.cpu().float().numpy() return audio