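# Scraped Hugging Face page header (not part of the program):
# author: w11wo
# commit: b17c40c "fixed inference fn"
# page links: raw / history / blame
# file size: 3.28 kB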
import os
import sys
import json
from subprocess import call
import torch
import gradio as gr
from scipy.io.wavfile import write
from huggingface_hub import hf_hub_url, cached_download
import nltk
from nltk.tokenize import word_tokenize
# Fetch the NLTK "punkt" resource at start-up (presumably what `word_tokenize`
# relies on -- confirm against the NLTK version in use).
nltk.download("punkt")

# Hub access token; a missing HF_TOKEN variable raises KeyError right here.
AUTH_TOKEN = os.environ["HF_TOKEN"]

# download models
# Grad-TTS checkpoint: resolve its Hub URL, then fetch it through the local
# cache using the token above.
url = hf_hub_url(filename="grad_1000.pt", repo_id="bookbot/grad-tts-en-ft-mixed")
grad_tts_model_path = cached_download(url, use_auth_token=AUTH_TOKEN)

# HiFi-GAN checkpoint: saved as "hifigan.pt" in the current working directory.
_hifigan_release_url = (
    "https://github.com/AK391/Speech-Backbones/releases/download/v1/hifigan.pt"
)
torch.hub.download_url_to_file(_hifigan_release_url, "hifigan.pt")
# build MAS
# Compile the extension under Grad-TTS/model/monotonic_align in place.
# The build runs with `cwd=` instead of chdir-ing the whole process there and
# back, so the interpreter's working directory never changes and nothing
# depends on a relative "../../../" to get home again.
current = os.getcwd()
mas_dir = os.path.join(current, "Grad-TTS", "model", "monotonic_align")
# An argument list (no shell) plus sys.executable guarantees the extension is
# built by the same interpreter that is about to import it.
mas_build_status = call(
    [sys.executable, "setup.py", "build_ext", "--inplace"],
    cwd=mas_dir,
)
if mas_build_status != 0:
    # Still best-effort, as before, but no longer silent: a failed build is
    # reported so a later import error can be traced back to it.
    print(
        f"warning: monotonic_align build exited with status {mas_build_status}",
        file=sys.stderr,
    )
# Put the Grad-TTS directory (relative to the working directory) on the import
# path. The imports that follow are presumably resolved from it, so they must
# stay below this line.
sys.path.append("Grad-TTS/")
import params
from model import GradTTS
from text import text_to_sequence, cmudict
from text.symbols import symbols
from utils import intersperse
# Same again for the hifi-gan subdirectory: `env` and `models` are presumably
# found there, so this append has to precede the two imports below.
sys.path.append("Grad-TTS/hifi-gan/")
from env import AttrDict
from models import Generator as HiFiGAN
# Speaker count handed to the GradTTS constructor (an int, not a collection).
SPEAKERS = 247


def _keep_storage(storage, location):
    """`map_location` hook for `torch.load`: hand back the first argument as is.

    `torch.load` calls the hook as (storage, location); returning the storage
    untouched is what the torch docs describe for loading onto the CPU.
    """
    return storage


# `params` attributes passed positionally to GradTTS, in constructor order.
_GRAD_TTS_PARAM_NAMES = (
    "spk_emb_dim",
    "n_enc_channels",
    "filter_channels",
    "filter_channels_dp",
    "n_heads",
    "n_enc_layers",
    "enc_kernel",
    "enc_dropout",
    "window_size",
    "n_feats",
    "dec_dim",
    "beta_min",
    "beta_max",
)

# load models
# Vocabulary size is len(symbols) + 1: `inference` uses the id len(symbols)
# as the token it intersperses between symbols.
generator = GradTTS(
    len(symbols) + 1,
    SPEAKERS,
    *(getattr(params, name) for name in _GRAD_TTS_PARAM_NAMES),
    pe_scale=1000,
)
generator.load_state_dict(
    torch.load(grad_tts_model_path, map_location=_keep_storage)
)
generator.eval()

# Pronunciation dictionary passed to `text_to_sequence` at inference time.
cmu = cmudict.CMUDict("./Grad-TTS/resources/cmu_dictionary_id_en")

# HiFi-GAN: built from its JSON config, then loaded from the "generator"
# entry of the downloaded checkpoint.
with open("./Grad-TTS/checkpts/hifigan-config.json") as config_file:
    h = AttrDict(json.load(config_file))
hifigan = HiFiGAN(h)
hifigan.load_state_dict(
    torch.load("./hifigan.pt", map_location=_keep_storage)["generator"]
)
hifigan.eval()
hifigan.remove_weight_norm()
def inference(text, n_timesteps):
    """Synthesize `text` to speech and return the path of the written WAV file.

    Args:
        text: Input text. It is tokenized with `word_tokenize`, re-joined with
            single spaces and converted to symbol ids with the `cmu` dictionary.
        n_timesteps: Forwarded to `generator.forward` as `n_timesteps` after
            being coerced to `int`.

    Returns:
        The string "./out.wav". The file is overwritten on every call.
    """
    # UI sliders may deliver a float; presumably the model expects an integer
    # step count -- TODO confirm against the Grad-TTS decoder.
    n_timesteps = int(n_timesteps)

    text = " ".join(word_tokenize(text))
    # The id len(symbols) is interleaved between the symbol ids.
    token_ids = intersperse(text_to_sequence(text, dictionary=cmu), len(symbols))
    x = torch.LongTensor(token_ids)[None]  # add a leading batch dimension
    x_lengths = torch.LongTensor([x.shape[-1]])

    # SPEAKERS is an int count. The previous `len(SPEAKERS)` raised TypeError
    # on every call; compare the count directly instead.
    spk = torch.LongTensor([0]) if SPEAKERS > 1 else None

    # Pure inference: keep both models out of autograd tracking.
    with torch.no_grad():
        _, y_dec, _ = generator.forward(
            x,
            x_lengths,
            n_timesteps=n_timesteps,
            temperature=1.5,
            stoc=False,
            spk=spk,
            length_scale=1.0,
        )
        audio = hifigan.forward(y_dec).cpu().squeeze().clamp(-1, 1).detach().numpy()

    # Written at a 22050 Hz sample rate.
    write("out.wav", 22050, audio)
    return "./out.wav"
# Gradio UI: a text box and a timestep slider in, an audio file out.
inputs = [
    gr.inputs.Textbox(lines=5, label="Input Text"),
    gr.inputs.Slider(minimum=0, maximum=100, step=10, label="Timesteps"),
]
outputs = gr.outputs.Audio(type="file", label="Output Audio")
title = "Bookbot Grad-TTS Weildan Demo 🐨"

# Example sentences offered in the UI. Every entry needs its trailing comma:
# without it Python silently concatenates adjacent string literals, which
# previously merged the last two sentences into a single example.
utterances = [
    "Selamat pagi! Selamat datang di Jakarta!",
    "Kak, harga nasi gorengnya berapa ya?",
    "Bapak bilang, Malik hebat. Bisa bersih bersih seperti Bapak.",
    "Here are the match lineups for the Colombia Haiti match.",
]
# One timestep value per utterance: 50, 60, 70, ...
timesteps = [50 + 10 * index for index in range(len(utterances))]
# Each example row is [text, timesteps], matching the order of `inputs`.
examples = [[utterance, steps] for utterance, steps in zip(utterances, timesteps)]

gr.Interface(inference, inputs, outputs, title=title, examples=examples).launch()