"""Gradio demo for zero-shot TTS with Amphion's NaturalSpeech2.

Clones the timbre of an uploaded reference speech and synthesizes the
given text with a configurable number of diffusion steps.
"""

import os

import gradio as gr
import numpy as np
import soundfile as sf
import torch
import torchaudio
from encodec import EncodecModel
from encodec.utils import convert_audio

from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
from utils.util import load_config


def build_codec(device):
    # EnCodec 24 kHz codec at 12 kbps, used both to tokenize the reference
    # speech and to decode the generated latents back into a waveform.
    encodec_model = EncodecModel.encodec_model_24khz()
    encodec_model = encodec_model.to(device=device)
    encodec_model.set_target_bandwidth(12.0)
    return encodec_model


def build_model(cfg, device):
    model = NaturalSpeech2(cfg.model)
    model.load_state_dict(
        torch.load(
            "ckpts/ns2/pytorch_model.bin",
            map_location="cpu",
        )
    )
    model = model.to(device=device)
    return model


def ns2_inference(
    prompt_audio_path,
    text,
    diffusion_steps=100,
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.environ["WORK_DIR"] = "./"
    cfg = load_config("egs/tts/NaturalSpeech2/exp_config.json")

    model = build_model(cfg, device)
    codec = build_codec(device)

    # Load the reference speech and resample it to the codec's format.
    ref_wav, sr = torchaudio.load(prompt_audio_path)
    ref_wav = convert_audio(ref_wav, sr, codec.sample_rate, codec.channels)
    ref_wav = ref_wav.unsqueeze(0).to(device=device)

    # Encode the reference into discrete codec tokens that serve as the
    # timbre prompt.
    with torch.no_grad():
        encoded_frames = codec.encode(ref_wav)
    ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
    ref_mask = torch.ones(ref_code.shape[0], ref_code.shape[-1]).to(ref_code.device)

    # Build the phone vocabulary: CMUdict phones plus silence markers and
    # sequence-boundary tokens.
    symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
    phone2id = {s: i for i, s in enumerate(symbols)}

    # Grapheme-to-phoneme conversion of the input text.
    lexicon = read_lexicon(cfg.preprocess.lexicon_path)
    phone_seq = preprocess_english(text, lexicon)
    phone_id = np.array(
        [
            *map(
                phone2id.get,
                phone_seq.replace("{", "").replace("}", "").split(),
            )
        ]
    )
    phone_id = torch.from_numpy(phone_id).unsqueeze(0).to(device=device)

    # Diffusion sampling in the codec latent space, then decoding the
    # sampled latents to a waveform.
    x0, prior_out = model.inference(ref_code, phone_id, ref_mask, diffusion_steps)
    rec_wav = codec.decoder(x0)

    os.makedirs("result", exist_ok=True)
    result_file = "result/{}_zero_shot_result.wav".format(
        os.path.splitext(os.path.basename(prompt_audio_path))[0]
    )
    sf.write(
        result_file,
        rec_wav[0, 0].detach().cpu().numpy(),
        samplerate=24000,
    )
    return result_file


demo_inputs = [
    gr.Audio(
        sources=["upload", "microphone"],
        label="Upload a reference speech whose timbre you want to clone",
        type="filepath",
    ),
    gr.Textbox(
        value="Amphion is a toolkit that can speak, make sounds, and sing.",
        label="Text you want to generate",
        type="text",
    ),
    gr.Slider(
        10,
        1000,
        value=200,
        step=1,
        label="Diffusion Inference Steps",
        info="More steps improve synthesis quality but slow down inference",
    ),
]

demo_outputs = gr.Audio(label="Generated Speech")

demo = gr.Interface(
    fn=ns2_inference,
    inputs=demo_inputs,
    outputs=demo_outputs,
    title="Amphion Zero-Shot TTS NaturalSpeech2",
)

if __name__ == "__main__":
    demo.launch()
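
# The synthesis function can also be called directly, bypassing the UI
# (a sketch; the reference path below is a hypothetical example file):
#
#   result_path = ns2_inference(
#       prompt_audio_path="examples/reference.wav",
#       text="Amphion is a toolkit that can speak, make sounds, and sing.",
#       diffusion_steps=200,
#   )
#   # -> "result/reference_zero_shot_result.wav"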