import spaces
import gradio as gr
import torch
from transformers import VitsModel, VitsTokenizer, set_seed
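
# Gradio demo for the official VITS text-to-speech checkpoints from Kakao Enterprise.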
title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
<div
style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
> <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
VITS TTS Demo
</h1> </div>
</div>
"""
description = """
VITS is an end-to-end speech synthesis model that predicts a speech waveform conditional on an input text sequence. It is a conditional variational autoencoder (VAE) comprised of a posterior encoder, decoder, and conditional prior.
This demo showcases the official VITS checkpoints, trained on the [LJSpeech](https://huggingface.co/kakao-enterprise/vits-ljs) and [VCTK](https://huggingface.co/kakao-enterprise/vits-vctk) datasets.
"""
article = "Model by Jaehyeon Kim et al. from Kakao Enterprise. Code and demo by 🤗 Hugging Face."
ljs_model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
ljs_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-ljs")
vctk_model = VitsModel.from_pretrained("kakao-enterprise/vits-vctk")
vctk_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-vctk")
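
# Place both models on the GPU if one is available; the input tensors are moved
# to the same device inside the forward functions.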
device = "cuda" if torch.cuda.is_available() else "cpu"
ljs_model.to(device)
vctk_model.to(device)
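
# On ZeroGPU Spaces, the `spaces.GPU` decorator allocates a GPU for the
# duration of each call to the decorated function.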
@spaces.GPU
def ljs_forward(text, speaking_rate=1.0):
    # Tokenize the text and move the input tensors to the same device as the model
    inputs = ljs_tokenizer(text, return_tensors="pt").to(device)

    # Larger speaking rates give faster synthesised speech
    ljs_model.speaking_rate = speaking_rate
    # Fix the seed so the stochastic duration predictor is reproducible
    set_seed(555)

    with torch.no_grad():
        outputs = ljs_model(**inputs)[0]

    # Take the first batch element; the checkpoints generate audio at 22,050 Hz
    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))
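
# The VCTK checkpoint is multi-speaker: the voice is selected via `speaker_id`.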
@spaces.GPU
def vctk_forward(text, speaking_rate=1.0, speaker_id=1):
    inputs = vctk_tokenizer(text, return_tensors="pt").to(device)

    vctk_model.speaking_rate = speaking_rate
    set_seed(555)

    with torch.no_grad():
        # The slider is 1-indexed for display, but the model expects 0-indexed speaker ids
        outputs = vctk_model(**inputs, speaker_id=speaker_id - 1)[0]

    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))
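
# One gr.Interface per checkpoint; the two are combined in a TabbedInterface below.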
ljs_inference = gr.Interface(
    fn=ljs_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
    ],
    outputs=gr.Audio(),
)
vctk_inference = gr.Interface(
    fn=vctk_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
        gr.Slider(
            1,
            vctk_model.config.num_speakers,
            value=1,
            step=1,
            label="Speaker id",
            info=f"The VCTK model is trained on {vctk_model.config.num_speakers} speakers. You can prompt the model using one of these speaker ids.",
        ),
    ],
    outputs=gr.Audio(),
)
demo = gr.Blocks()

with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface([ljs_inference, vctk_inference], ["LJ Speech", "VCTK"])
    gr.Markdown(article)

demo.queue(max_size=10)
demo.launch()