import argparse
import os

import numpy as np
import soundfile as sf
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
from utils.util import load_config


class NS2Inference:
    def __init__(self, args, cfg):
        self.cfg = cfg
        self.args = args

        self.model = self.build_model()
        self.codec = self.build_codec()

        # Phone vocabulary: CMUdict symbols plus pause/silence tokens and
        # BOS/EOS markers.
        self.symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        self.phone2id = {s: i for i, s in enumerate(self.symbols)}
        self.id2phone = {i: s for s, i in self.phone2id.items()}

    def build_model(self):
        model = NaturalSpeech2(self.cfg.model)
        model.load_state_dict(
            torch.load(
                os.path.join(self.args.checkpoint_path, "pytorch_model.bin"),
                map_location="cpu",
            )
        )
        model = model.to(self.args.device)
        model.eval()
        return model

    def build_codec(self):
        # EnCodec 24 kHz codec; the 12 kbps target bandwidth selects the
        # number of residual codebooks used per frame.
        encodec_model = EncodecModel.encodec_model_24khz()
        encodec_model = encodec_model.to(device=self.args.device)
        encodec_model.set_target_bandwidth(12.0)
        return encodec_model

    def get_ref_code(self):
        """Encode the reference audio into EnCodec codes and build its mask."""
        ref_wav_path = self.args.ref_audio
        ref_wav, sr = torchaudio.load(ref_wav_path)
        ref_wav = convert_audio(
            ref_wav, sr, self.codec.sample_rate, self.codec.channels
        )
        ref_wav = ref_wav.unsqueeze(0).to(device=self.args.device)

        with torch.no_grad():
            encoded_frames = self.codec.encode(ref_wav)
        # Each encoded frame is a (codes, scale) tuple; concatenate the codes
        # along the time axis -> (batch, n_q, time).
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)

        ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(
            ref_code.device
        )

        return ref_code, ref_mask

    def inference(self):
        ref_code, ref_mask = self.get_ref_code()

        # Grapheme-to-phoneme conversion via the lexicon, then map each phone
        # to its id (phones outside the vocabulary map to None).
        lexicon = read_lexicon(self.cfg.preprocess.lexicon_path)
        phone_seq = preprocess_english(self.args.text, lexicon)
        print(phone_seq)

        phone_id = np.array(
            [
                *map(
                    self.phone2id.get,
                    phone_seq.replace("{", "").replace("}", "").split(),
                )
            ]
        )
        phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=self.args.device)
        print(phone_id)

        with torch.no_grad():
            # Diffusion sampling conditioned on the reference codes.
            x0, prior_out = self.model.inference(
                ref_code, phone_id, ref_mask, self.args.inference_step
            )
            print(prior_out["dur_pred"])
            print(prior_out["dur_pred_round"])
            print(torch.sum(prior_out["dur_pred_round"]))

            # vq.decode expects (n_q, batch, time); latent_ref is only needed
            # to reconstruct the reference for debugging.
            latent_ref = self.codec.quantizer.vq.decode(ref_code.transpose(0, 1))
            rec_wav = self.codec.decoder(x0)
            # ref_wav = self.codec.decoder(latent_ref)

        os.makedirs(self.args.output_dir, exist_ok=True)

        sf.write(
            "{}/{}.wav".format(
                self.args.output_dir, self.args.text.replace(" ", "_", 100)
            ),
            rec_wav[0, 0].detach().cpu().numpy(),
            samplerate=24000,
        )


def add_arguments(parser: argparse.ArgumentParser):
    # --text, --checkpoint_path, and --output_dir are consumed above but are
    # expected to be registered by the surrounding CLI entry point.
    parser.add_argument(
        "--ref_audio",
        type=str,
        default="",
        help="Reference audio path",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    parser.add_argument(
        "--inference_step",
        type=int,
        default=200,
        help="Total inference steps for the diffusion model",
    )
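
# A minimal sketch of how this class might be driven directly, assuming the
# flags below. Note: --config, --text, --checkpoint_path, and --output_dir are
# assumptions for illustration; in the repo they are registered by the shared
# CLI entry point rather than by add_arguments above.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    add_arguments(parser)
    parser.add_argument("--config", type=str, default="", help="Config path (assumed flag)")
    parser.add_argument("--text", type=str, default="", help="Text to synthesize (assumed flag)")
    parser.add_argument("--checkpoint_path", type=str, default="", help="Checkpoint dir (assumed flag)")
    parser.add_argument("--output_dir", type=str, default="output", help="Output dir (assumed flag)")
    args = parser.parse_args()

    cfg = load_config(args.config)
    NS2Inference(args, cfg).inference()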