# Tacotron2 / app.py
# akhaliq's picture
# akhaliq HF staff
# Create app.py
# cc372f2
# raw
# history blame
# 1.4 kB
import torch
import torchaudio
import gradio as gr
# All models and tensors in this script are placed on this device.
device = "cpu"

# Pretrained torchaudio pipeline bundle. Only its text processor and Tacotron2
# model are taken from it here; the waveform is produced by the separately
# loaded WaveGlow model further down in this file.
bundle = torchaudio.pipelines.TACOTRON2_WAVERNN_PHONE_LJSPEECH
tacotron2 = bundle.get_tacotron2()
tacotron2 = tacotron2.to(device)
processor = bundle.get_text_processor()
# Workaround for a checkpoint whose tensors are mapped on GPU: build the bare
# WaveGlow architecture (pretrained=False) and load the weights separately
# with an explicit map_location, so they land on `device`.
# https://stackoverflow.com/a/61840832
WAVEGLOW_CHECKPOINT_URL = "https://api.ngc.nvidia.com/v2/models/nvidia/waveglowpyt_fp32/versions/1/files/nvidia_waveglowpyt_fp32_20190306.pth"  # noqa: E501

waveglow = torch.hub.load(
    "NVIDIA/DeepLearningExamples:torchhub",
    "nvidia_waveglow",
    model_math="fp32",
    pretrained=False,
)
checkpoint = torch.hub.load_state_dict_from_url(
    WAVEGLOW_CHECKPOINT_URL,
    progress=False,
    map_location=device,
)

# Drop the "module." fragment from every parameter name so the keys match the
# freshly built model (presumably left over from a DataParallel wrapper —
# TODO confirm).
state_dict = {}
for key, value in checkpoint["state_dict"].items():
    state_dict[key.replace("module.", "")] = value
waveglow.load_state_dict(state_dict)

# Strip weight normalization, move to the target device, and switch to
# evaluation mode for inference.
waveglow = waveglow.remove_weightnorm(waveglow).to(device)
waveglow.eval()
def inference(text: str) -> str:
    """Synthesize speech for ``text`` and return the path of the WAV file.

    The text is run through the text processor and Tacotron2 to obtain a
    spectrogram, WaveGlow turns the spectrogram into waveforms, and the first
    waveform is written to disk.

    Args:
        text: The sentence to synthesize.

    Returns:
        Path of the written WAV file, relative to the current working
        directory.
    """
    # The file is saved to exactly the path that is returned, so the caller
    # receives a path that really points at the generated audio.
    # NOTE(review): the filename is fixed, so concurrent requests would
    # overwrite each other's output — consider a per-request temp file.
    output_path = "output_waveglow.wav"

    with torch.inference_mode():
        processed, lengths = processor(text)
        processed = processed.to(device)
        lengths = lengths.to(device)
        spec, _, _ = tacotron2.infer(processed, lengths)

    with torch.no_grad():
        waveforms = waveglow.infer(spec)

    # Keep only the first waveform; the 0:1 slice preserves the leading
    # dimension. 22050 is presumably the models' native sample rate — confirm.
    torchaudio.save(output_path, waveforms[0:1].cpu(), sample_rate=22050)
    return output_path
# Web UI: one text input, one audio output served from the file path that
# `inference` returns.
# NOTE(review): `gr.outputs` is the older Gradio component namespace — confirm
# it exists in the pinned Gradio version.
demo = gr.Interface(
    fn=inference,
    inputs="text",
    outputs=gr.outputs.Audio(type="file"),
)
demo.launch(debug=True)