import spaces
import gradio as gr
import torch

from transformers import VitsModel, VitsTokenizer, set_seed


title = """
<div style="text-align: center; max-width: 700px; margin: 0 auto;">
    <div
        style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;"
    > <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
        VITS TTS Demo
    </h1> </div>
</div>
 """

description = """
VITS is an end-to-end speech synthesis model that predicts a speech waveform conditioned on an input text sequence. It is a conditional variational autoencoder (VAE) composed of a posterior encoder, a decoder, and a conditional prior.

This demo showcases the official VITS checkpoints, trained on the [LJSpeech](https://huggingface.co/kakao-enterprise/vits-ljs) and [VCTK](https://huggingface.co/kakao-enterprise/vits-vctk) datasets.
"""

article = "Model by Jaehyeon Kim et al. from Kakao Enterprise. Code and demo by 🤗 Hugging Face."

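# Load the single-speaker LJ Speech checkpoint and the multi-speaker VCTK checkpoint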
ljs_model = VitsModel.from_pretrained("kakao-enterprise/vits-ljs")
ljs_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-ljs")

vctk_model = VitsModel.from_pretrained("kakao-enterprise/vits-vctk")
vctk_tokenizer = VitsTokenizer.from_pretrained("kakao-enterprise/vits-vctk")

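# Run inference on GPU when one is available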
device = "cuda" if torch.cuda.is_available() else "cpu"
ljs_model.to(device)
vctk_model.to(device)


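# Single-speaker synthesis with the LJ Speech model. VITS samples from a stochastic
# duration predictor, so the seed is fixed to make the output reproducible.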
@spaces.GPU
def ljs_forward(text, speaking_rate=1.0):
    inputs = ljs_tokenizer(text, return_tensors="pt")

    ljs_model.speaking_rate = speaking_rate
    set_seed(555)
    with torch.no_grad():
        outputs = ljs_model(**inputs)[0]

    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))


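# Multi-speaker synthesis with the VCTK model; the 1-based speaker id from the UI
# is mapped to the model's 0-based speaker embedding index.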
@spaces.GPU
def vctk_forward(text, speaking_rate=1.0, speaker_id=1):
    inputs = vctk_tokenizer(text, return_tensors="pt")

    vctk_model.speaking_rate = speaking_rate
    set_seed(555)
    with torch.no_grad():
        outputs = vctk_model(**inputs, speaker_id=speaker_id - 1)[0]

    waveform = outputs[0].cpu().float().numpy()
    return gr.make_waveform((22050, waveform))


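# Interface for the LJ Speech (single-speaker) tab: text input plus a speaking-rate slider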
ljs_inference = gr.Interface(
    fn=ljs_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
    ],
    outputs=gr.Audio(),
)

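# Interface for the VCTK (multi-speaker) tab: same controls plus a speaker id slider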
vctk_inference = gr.Interface(
    fn=vctk_forward,
    inputs=[
        gr.Textbox(
            value="Hey, it's Hugging Face on the phone",
            max_lines=1,
            label="Input text",
        ),
        gr.Slider(
            0.5,
            1.5,
            value=1,
            step=0.1,
            label="Speaking rate",
        ),
        gr.Slider(
            1,
            vctk_model.config.num_speakers,
            value=1,
            step=1,
            label="Speaker id",
            info=f"The VCTK model is trained on {vctk_model.config.num_speakers} speakers. You can prompt the model using one of these speaker ids.",
        ),
    ],
    outputs=gr.Audio(),
)

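# Assemble the demo: title, description, one tab per checkpoint, and attribution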
demo = gr.Blocks()

with demo:
    gr.Markdown(title)
    gr.Markdown(description)
    gr.TabbedInterface([ljs_inference, vctk_inference], ["LJ Speech", "VCTK"])
    gr.Markdown(article)

demo.queue(max_size=10)
demo.launch()