NaturalSpeech2 / app.py
yuancwang
add app
9893813
raw
history blame
No virus
3.58 kB
import gradio as gr
import argparse
import os
import torch
import soundfile as sf
import numpy as np
from models.tts.naturalspeech2.ns2 import NaturalSpeech2
from encodec import EncodecModel
from encodec.utils import convert_audio
from utils.util import load_config
from text import text_to_sequence
from text.cmudict import valid_symbols
from text.g2p import preprocess_english, read_lexicon
import torchaudio
def build_codec(device, bandwidth=12.0):
    """Load the pretrained 24 kHz EnCodec model used to encode prompts and decode latents.

    Args:
        device: Target device for the codec.
        bandwidth: Target bandwidth (kbps) passed to ``set_target_bandwidth``.
            Defaults to 12.0, the value this app has always used. It presumably
            has to match the codebook setup the NaturalSpeech2 checkpoint was
            trained with — change it only together with the checkpoint.

    Returns:
        The EnCodec model on *device*, in evaluation mode.
    """
    codec = EncodecModel.encodec_model_24khz().to(device=device)
    codec.set_target_bandwidth(bandwidth)
    # Make inference mode explicit; a no-op if the factory already did this.
    codec.eval()
    return codec
def build_model(cfg, device, ckpt_path="ckpts/ns2/pytorch_model.bin"):
    """Construct NaturalSpeech2 from ``cfg.model``, load its weights, and prepare it for inference.

    Args:
        cfg: Experiment config; only ``cfg.model`` is read here.
        device: Target device for the model.
        ckpt_path: Path to the saved ``state_dict``. Defaults to the location
            this app has always used.

    Returns:
        The model on *device*, switched to evaluation mode.
    """
    model = NaturalSpeech2(cfg.model)
    # Load onto CPU first so the checkpoint loads regardless of the device it was saved from.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)
    # nn.Module instances start in training mode; eval() turns off training-time
    # behaviour such as dropout, which must not be active during synthesis.
    return model.to(device=device).eval()
# Per-device cache of request-independent resources (model, codec, lexicon,
# phone vocabulary). Rebuilding them on every request re-read the checkpoint and
# codec weights from disk each time.
_NS2_RESOURCES = {}


def _get_ns2_resources(device):
    """Return ``(model, codec, lexicon, phone2id)`` for *device*, building them on first use."""
    key = str(device)
    if key not in _NS2_RESOURCES:
        # Set before load_config, exactly as this app has always done.
        os.environ["WORK_DIR"] = "./"
        cfg = load_config("egs/tts/NaturalSpeech2/exp_config.json")
        model = build_model(cfg, device)
        codec = build_codec(device)
        lexicon = read_lexicon(cfg.preprocess.lexicon_path)
        # Phone vocabulary: CMUdict symbols followed by extra markers. The order
        # presumably has to match the training-time vocabulary — do not reorder.
        symbols = valid_symbols + ["sp", "spn", "sil"] + ["<s>", "</s>"]
        phone2id = {s: i for i, s in enumerate(symbols)}
        _NS2_RESOURCES[key] = (model, codec, lexicon, phone2id)
    return _NS2_RESOURCES[key]


def _phones_to_ids(phone_seq, phone2id):
    """Map a whitespace-separated phone string (``{``/``}`` stripped) to vocabulary ids.

    Raises:
        ValueError: if no phones remain, or some phones are not in *phone2id*.
    """
    phones = phone_seq.replace("{", "").replace("}", "").split()
    if not phones:
        raise ValueError("The input text produced no phones to synthesize.")
    unknown = sorted({p for p in phones if p not in phone2id})
    if unknown:
        raise ValueError(f"Phones not in the model vocabulary: {unknown}")
    return [phone2id[p] for p in phones]


def ns2_inference(
    prmopt_audio_path,
    text,
    diffusion_steps=100,
):
    """Synthesize *text* in the voice of a reference recording and write it to a WAV file.

    The model, codec and lexicon are loaded on the first call and reused for
    every later call on the same device.

    Args:
        prmopt_audio_path: Path to the reference (prompt) audio. The misspelled
            name is kept for backward compatibility with existing callers.
        text: Text to synthesize; converted to phones with ``preprocess_english``
            and the lexicon named in the experiment config.
        diffusion_steps: Number of diffusion steps passed to ``model.inference``;
            coerced to ``int`` in case the UI slider delivers a float.

    Returns:
        Path of the written file: ``result/<reference stem>_zero_shot_result.wav``.

    Raises:
        ValueError: if no reference audio is given, or *text* yields no phones
            or phones outside the model vocabulary.
    """
    if not prmopt_audio_path:
        raise ValueError("Please provide a reference audio file.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model, codec, lexicon, phone2id = _get_ns2_resources(device)

    # Validate the text before doing any heavy audio work.
    phone_ids = _phones_to_ids(preprocess_english(text, lexicon), phone2id)

    ref_wav, sr = torchaudio.load(prmopt_audio_path)
    # Convert the prompt to the codec's sample rate and channel count.
    ref_wav = convert_audio(ref_wav, sr, codec.sample_rate, codec.channels)
    ref_wav = ref_wav.unsqueeze(0).to(device=device)  # add a leading batch dim

    # Pure inference end to end (codec encode, diffusion sampling, codec
    # decode): no autograd graph is needed anywhere.
    with torch.no_grad():
        encoded_frames = codec.encode(ref_wav)
        # Join the codes of all encoded frames along the last (time) axis.
        ref_code = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)
        ref_mask = torch.ones(
            ref_code.shape[0], ref_code.shape[-1], device=ref_code.device
        )
        phone_id = torch.tensor(phone_ids, dtype=torch.long, device=device).unsqueeze(0)
        x0, _ = model.inference(ref_code, phone_id, ref_mask, int(diffusion_steps))
        rec_wav = codec.decoder(x0)

    # basename/splitext work for any path separator and any extension length.
    stem = os.path.splitext(os.path.basename(prmopt_audio_path))[0]
    os.makedirs("result", exist_ok=True)
    result_file = "result/{}.wav".format(stem + "_zero_shot_result")
    sf.write(
        result_file,
        rec_wav[0, 0].detach().cpu().numpy(),
        samplerate=codec.sample_rate,
    )
    return result_file
# Gradio front end: (reference audio, target text, diffusion steps) are fed to
# ns2_inference, whose returned file path is shown as the output audio.
demo_inputs = [
    # Reference speaker prompt, delivered to ns2_inference as a file path.
    gr.Audio(
        label="Upload a reference speech you want to clone timbre",
        type="filepath",
        sources=["upload", "microphone"],
    ),
    # The sentence to be spoken in the reference voice.
    gr.Textbox(
        label="Text you want to generate",
        value="Amphion is a toolkit that can speak, make sounds, and sing.",
        type="text",
    ),
    # Number of diffusion steps: trades synthesis quality against latency.
    gr.Slider(
        minimum=10,
        maximum=1000,
        step=1,
        value=200,
        label="Diffusion Inference Steps",
        info="As the step number increases, the synthesis quality will be better while the inference speed will be lower",
    ),
]
demo_outputs = gr.Audio(label="")
demo = gr.Interface(
    title="Amphion Zero-Shot TTS NaturalSpeech2",
    fn=ns2_inference,
    outputs=demo_outputs,
    inputs=demo_inputs,
)
if __name__ == "__main__":
    # Optional launch settings. Only flags actually given are forwarded, so
    # running with no arguments is exactly demo.launch().
    parser = argparse.ArgumentParser(
        description="Amphion NaturalSpeech2 zero-shot TTS demo"
    )
    parser.add_argument("--server-name", default=None, help="Interface/host to bind.")
    parser.add_argument("--server-port", type=int, default=None, help="Port to listen on.")
    parser.add_argument(
        "--share", action="store_true", default=None, help="Create a public share link."
    )
    # parse_known_args: tolerate extra arguments a hosting launcher may append.
    args, _unknown = parser.parse_known_args()
    launch_kwargs = {k: v for k, v in vars(args).items() if v is not None}
    demo.launch(**launch_kwargs)