# GroTTS / app.py
# wietsedv's picture
# initial version
# 38a727d
# raw history blame
# No virus
# 3.05 kB
import gradio as gr
import time
import urllib.request
from pathlib import Path
import os
import torch
import scipy.io.wavfile
from espnet2.bin.tts_inference import Text2Speech
from espnet2.utils.types import str_or_none
def _download_if_missing(url, destination):
    """Download *url* to *destination* unless that file already exists.

    The payload is first written to a sibling ``<name>.part`` file and only
    moved into place once the transfer has finished, so an interrupted
    download can never be mistaken for a complete cached file on the next
    start-up.
    """
    if destination.exists():
        return
    partial = destination.with_name(destination.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        os.replace(partial, destination)
    finally:
        # Only true when the download (or the rename) failed half-way.
        if partial.exists():
            partial.unlink()


def load_model(model_tag, vocoder_tag):
    """Create a CPU ``Text2Speech`` pipeline for one acoustic model + vocoder.

    Args:
        model_tag: Value handed to ``ModelDownloader.download_and_unpack``;
            its result supplies the base keyword arguments for
            ``Text2Speech``.
        vocoder_tag: Hugging Face repository id (``"user/repo"``) from which
            ``config.yml`` and ``checkpoint-50000steps.pkl`` are fetched.
            The files are cached under ``~/.cache/parallel_wavegan/<tag>``.

    Returns:
        The constructed ``Text2Speech`` instance.
    """
    from espnet_model_zoo.downloader import ModelDownloader

    # Model
    kwargs = ModelDownloader().download_and_unpack(model_tag)

    # Vocoder
    download_dir = Path(os.path.expanduser("~/.cache/parallel_wavegan"))
    vocoder_dir = download_dir / vocoder_tag
    os.makedirs(vocoder_dir, exist_ok=True)

    base_url = f"https://huggingface.co/{vocoder_tag}/resolve/main"
    for key, filename in (
        ("vocoder_config", "config.yml"),
        ("vocoder_file", "checkpoint-50000steps.pkl"),
    ):
        kwargs[key] = vocoder_dir / filename
        _download_if_missing(f"{base_url}/{filename}", kwargs[key])

    return Text2Speech(
        **kwargs,
        device="cpu",
        threshold=0.5,
        minlenratio=0.0,
        maxlenratio=10.0,
        use_att_constraint=True,
        backward_window=1,
        forward_window=4,
    )
# One synthesiser per language, all created eagerly at import time.

# Gronings: Tacotron 2 archive + Parallel WaveGAN vocoder hosted on Hugging Face.
gos_text2speech = load_model(
    model_tag="https://huggingface.co/wietsedv/tacotron2-gronings/resolve/main/tts_ljspeech_finetune_tacotron2.v5_train.loss.ave.zip",
    vocoder_tag="wietsedv/parallelwavegan-gronings",
)

# Dutch: same layout as above, different repositories.
nld_text2speech = load_model(
    model_tag="https://huggingface.co/wietsedv/tacotron2-dutch/resolve/main/tts_ljspeech_finetune_tacotron2.v5_train.loss.ave.zip",
    vocoder_tag="wietsedv/parallelwavegan-dutch",
)

# English: resolved through ``Text2Speech.from_pretrained`` instead of
# ``load_model``; the decoding options mirror the ones used there.
eng_text2speech = Text2Speech.from_pretrained(
    model_tag="kan-bayashi/ljspeech_tacotron2",
    vocoder_tag="parallel_wavegan/ljspeech_parallel_wavegan.v3",
    device="cpu",
    threshold=0.5,
    minlenratio=0.0,
    maxlenratio=10.0,
    use_att_constraint=True,
    backward_window=1,
    forward_window=4,
)
def inference(text, lang):
    """Synthesise *text* in the requested language and write it to ``out.wav``.

    Args:
        text: The sentence(s) to speak.
        lang: One of ``"gronings"``, ``"dutch"`` or ``"english"``.

    Returns:
        A pair ``("out.wav", "out.wav")`` - the same path twice, one entry
        for each of the two outputs the interface declares.

    Raises:
        ValueError: If *lang* is not one of the supported languages.
            Previously such a value fell through every branch and the
            function returned ``out.wav`` without writing it, i.e. a stale
            file from an earlier request or a path that does not exist.
    """
    if lang == "gronings":
        text2speech = gos_text2speech
    elif lang == "dutch":
        text2speech = nld_text2speech
    elif lang == "english":
        text2speech = eng_text2speech
    else:
        raise ValueError(f"unsupported language: {lang!r}")

    with torch.no_grad():
        wav = text2speech(text)["wav"]

    # Flatten to 1-D and move to host memory before handing the samples to
    # scipy; ``fs`` is passed as the rate argument of the WAV writer.
    scipy.io.wavfile.write("out.wav", text2speech.fs, wav.view(-1).cpu().numpy())
    return "out.wav", "out.wav"
# Page title shown by the interface.
title = "GroTTS"

# Pre-filled example rows: [input text, language].
examples = [
    ["Ze gingen mit klas noar Waddendiek. Over en deur bragel lopen.", "gronings"],
]

# Input widgets: free text plus a language selector defaulting to Gronings.
text_input = gr.inputs.Textbox(label="input text", lines=3)
language_input = gr.inputs.Radio(
    choices=["gronings", "dutch", "english"],
    type="value",
    default="gronings",
    label="language",
)

# Output widgets: an audio component and a file component, matching the two
# values returned by ``inference``.
audio_output = gr.outputs.Audio(type="file", label="Output")
file_output = gr.outputs.File()

demo = gr.Interface(
    fn=inference,
    inputs=[text_input, language_input],
    outputs=[audio_output, file_output],
    title=title,
    examples=examples,
)
demo.launch(enable_queue=True, debug=True)