# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
from tqdm import tqdm
from collections import OrderedDict

from models.tts.base.tts_inferece import TTSInference
from models.tts.fastspeech2.fs2_dataset import FS2TestDataset, FS2TestCollator
from utils.util import load_config
from utils.io import save_audio
from models.tts.fastspeech2.fs2 import FastSpeech2
from models.vocoders.vocoder_inference import synthesis
from pathlib import Path
from processors.phone_extractor import phoneExtractor
from text.text_token_collation import phoneIDCollation
import numpy as np
import json


class FastSpeech2Inference(TTSInference):
    """Inference driver for FastSpeech2.

    Runs the acoustic model, takes its ``"postnet_output"`` and hands it to
    ``synthesis`` together with a vocoder checkpoint found in
    ``args.vocoder_dir``.
    """

    def __init__(self, args, cfg):
        TTSInference.__init__(self, args, cfg)
        self.args = args
        self.cfg = cfg
        self.infer_type = args.mode

    def _build_model(self):
        """Instantiate ``FastSpeech2`` from ``self.cfg``, store and return it."""
        self.model = FastSpeech2(self.cfg)
        return self.model

    def load_model(self, state_dict):
        """Load ``state_dict["model"]`` into ``self.model``.

        A leading ``"module."`` on a parameter name is stripped first
        (presumably left by a DataParallel-style wrapper -- TODO confirm).
        """
        prefix = "module."
        clean_dict = OrderedDict(
            (name[len(prefix):] if name.startswith(prefix) else name, weight)
            for name, weight in state_dict["model"].items()
        )
        self.model.load_state_dict(clean_dict)

    def _build_test_dataset(self):
        """Return the dataset class and collator class used for testing."""
        return FS2TestDataset, FS2TestCollator

    @staticmethod
    def _parse_vocoder(vocoder_dir):
        r"""Parse vocoder config.

        Args:
            vocoder_dir: directory holding ``*.pt`` checkpoints and an
                ``args.json`` config.

        Returns:
            ``(vocoder_cfg, ckpt_path)`` where ``ckpt_path`` is the checkpoint
            with the largest step number parsed from its file name.

        Raises:
            FileNotFoundError: if ``vocoder_dir`` contains no ``*.pt`` file.
        """
        vocoder_dir = os.path.abspath(vocoder_dir)
        ckpt_list = list(Path(vocoder_dir).glob("*.pt"))
        if not ckpt_list:
            raise FileNotFoundError(
                f"No vocoder checkpoint (*.pt) found in {vocoder_dir}"
            )

        def step_of(ckpt):
            # last step (different from the base *int(x.stem)*): the number
            # is the text after the last "-" in the second-to-last
            # "_"-separated field of the file stem.
            return int(ckpt.stem.split("_")[-2].split("-")[-1])

        # max() returns the first maximal element, the same one a stable
        # descending sort would place at index 0.
        ckpt_path = str(max(ckpt_list, key=step_of))
        vocoder_cfg = load_config(
            os.path.join(vocoder_dir, "args.json"), lowercase=True
        )
        return vocoder_cfg, ckpt_path

    @torch.inference_mode()
    def inference_for_batches(self):
        """Synthesize a ``.wav`` for every entry of the test metadata.

        Each predicted sequence is trimmed to its reported length and written
        to ``<output_dir>/<Uid>.pt``; all of them are then vocoded in one
        ``synthesis`` call, saved as ``<Uid>.wav``, and the intermediate
        ``.pt`` files are removed.
        """
        out_dir = self.args.output_dir
        metadata = self.test_dataset.metadata

        # Wrap the loader (not the enumerate object) so tqdm can read its
        # length and show a total.
        for i, batch in enumerate(tqdm(self.test_dataloader)):
            y_pred, mel_lens, _ = self._inference_each_batch(batch)
            # Metadata index of the first item of this batch; relies on the
            # loader yielding items in metadata order.
            offset = i * self.test_batch_size
            # Iterating a tensor walks its first dimension, one item at a time.
            for j, (pred, pred_len) in enumerate(zip(y_pred, mel_lens)):
                pred = pred[: pred_len.item()].detach().cpu()
                uid = metadata[offset + j]["Uid"]
                torch.save(pred, os.path.join(out_dir, f"{uid}.pt"))

        vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
        res = synthesis(
            cfg=vocoder_cfg,
            vocoder_weight_file=vocoder_ckpt,
            n_samples=None,
            pred=[
                torch.load(os.path.join(out_dir, f"{item['Uid']}.pt")).numpy()
                for item in metadata
            ],
        )

        for item, wav in zip(metadata, res):
            uid = item["Uid"]
            save_audio(
                os.path.join(out_dir, f"{uid}.wav"),
                wav.numpy(),
                self.cfg.preprocess.sample_rate,
                add_silence=True,
                turn_up=True,
            )
            # The intermediate prediction is no longer needed once vocoded.
            os.remove(os.path.join(out_dir, f"{uid}.pt"))

    def _inference_each_batch(self, batch_data):
        """Run the model on one batch.

        Every value of ``batch_data`` is moved (in place, in the dict) to the
        accelerator device before the forward pass.

        Returns:
            ``(postnet_output, mel_lens_on_cpu, 0)``.
        """
        device = self.accelerator.device
        for key, value in batch_data.items():
            batch_data[key] = value.to(device)

        output = self.model(
            batch_data,
            p_control=self.args.pitch_control,
            e_control=self.args.energy_control,
            d_control=self.args.duration_control,
        )
        pred_res = output["postnet_output"]
        mel_lens = output["mel_lens"].cpu()
        return pred_res, mel_lens, 0

    def inference_for_single_utterance(self):
        """Synthesize audio for ``self.args.text``.

        Returns:
            The first element of what ``synthesis`` returns for the single
            predicted sequence.

        Raises:
            FileNotFoundError: if a phone symbol file is required (phone
                extractor is not ``"lexicon"``) but does not exist.
        """
        text = self.args.text

        # get phone symbol file
        phone_symbol_file = None
        if self.cfg.preprocess.phone_extractor != "lexicon":
            phone_symbol_file = os.path.join(
                self.exp_dir, self.cfg.preprocess.symbols_dict
            )
            # Explicit raise instead of assert: asserts vanish under -O.
            if not os.path.exists(phone_symbol_file):
                raise FileNotFoundError(
                    f"Phone symbol file not found: {phone_symbol_file}"
                )

        # convert text to phone sequence
        phone_extractor = phoneExtractor(self.cfg)
        phone_seq = phone_extractor.extract_phone(text)  # phone_seq: list

        # convert phone sequence to phone id sequence
        phon_id_collator = phoneIDCollation(
            self.cfg, symbols_dict_file=phone_symbol_file
        )
        phone_seq = ["{"] + phone_seq + ["}"]
        phone_id_seq = phon_id_collator.get_phone_id_sequence(self.cfg, phone_seq)

        # convert phone id sequence to a tensor
        phone_id_seq = torch.from_numpy(np.array(phone_id_seq))

        # get speaker id if multi-speaker training and use speaker id
        if self.cfg.preprocess.use_spkid and self.cfg.train.multi_speaker_training:
            spk2id_file = os.path.join(self.exp_dir, self.cfg.preprocess.spk2id)
            with open(spk2id_file, "r", encoding="utf-8") as f:
                spk2id = json.load(f)
            speaker_id = spk2id[self.args.speaker_name]
            speaker_id = torch.from_numpy(np.array([speaker_id], dtype=np.int32))
        else:
            # No speaker conditioning: pass an empty tensor.
            speaker_id = torch.Tensor(0).view(-1)

        with torch.no_grad():
            # Add a leading batch dimension of size 1.
            data = {
                "texts": phone_id_seq.to(self.device).unsqueeze(0),
                "text_len": torch.LongTensor([phone_id_seq.size(0)]).to(
                    self.device
                ),
                "spk_id": speaker_id.to(self.device),
            }
            output = self.model(
                data,
                p_control=self.args.pitch_control,
                e_control=self.args.energy_control,
                d_control=self.args.duration_control,
            )
            pred_res = output["postnet_output"]

            vocoder_cfg, vocoder_ckpt = self._parse_vocoder(self.args.vocoder_dir)
            audio = synthesis(
                cfg=vocoder_cfg,
                vocoder_weight_file=vocoder_ckpt,
                n_samples=None,
                pred=pred_res,
            )
        return audio[0]