File size: 3,408 Bytes
53f8a32
 
 
 
 
 
5b5d4c0
 
53f8a32
 
 
 
 
9846d74
53f8a32
 
 
98923fe
9846d74
 
 
5339f1e
9846d74
 
 
53f8a32
 
 
 
 
 
 
 
 
 
 
80ee0e5
53f8a32
 
 
0f11bd1
 
53f8a32
 
fbe7d93
53f8a32
 
 
fbe7d93
babf22d
37e87fa
53f8a32
 
 
 
fbe7d93
53f8a32
 
 
 
 
 
 
 
ef55711
53f8a32
 
 
2d5fa2d
53f8a32
2d5fa2d
53f8a32
 
 
 
fbe7d93
53f8a32
fbe7d93
53f8a32
 
 
 
 
 
 
 
 
 
 
 
 
6ebe60a
 
53f8a32
e0cceab
53f8a32
 
 
4263bcd
53f8a32
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""
Copyright 2022 Balacoon

TTS interactive demo
"""

import os
import glob
import logging
from typing import cast

import gradio as gr
from balacoon_tts import TTS
from huggingface_hub import hf_hub_download, list_repo_files

# process-wide TTS engine; stays None until a model is selected in the UI
tts = None
# local directory the model repository is mirrored into
model_repo_dir = "data"
# mirror every file of the model repository locally before the UI is built,
# so that available models can be listed straight from disk
_model_repo_id = "balacoon/tts"
for name in list_repo_files(repo_id=_model_repo_id):
    hf_hub_download(repo_id=_model_repo_id, filename=name, local_dir=model_repo_dir)


def main():
    """
    Builds the Gradio UI of the TTS demo and launches it.

    The page consists of a text box, a model dropdown (filled with the
    `*_cpu.addon` files found in `model_repo_dir`), a speaker dropdown
    (filled once a model is selected), a "Generate" button and an audio
    output. The request queue is launched with a single concurrent worker.
    """
    logging.basicConfig(level=logging.INFO)
    # longer inputs are truncated to this many characters before synthesis
    max_text_len = 1024

    with gr.Blocks() as demo:
        gr.Markdown(
            """
            <h1 align="center">Balacoon🦝 Text-to-Speech</h1>

            1. Write an utterance to generate,
            2. Select the model to synthesize with
            3. Select speaker
            4. Hit "Generate" and listen to the result!

            You can learn more about models available
            [here](https://huggingface.co/balacoon/tts).
            Visit [Balacoon website](https://balacoon.com/) for more info.
            """
        )
        with gr.Row(variant="panel"):
            text = gr.Textbox(label="Text", placeholder="Type something here...")

        with gr.Row():
            with gr.Column(variant="panel"):
                # os.listdir order is arbitrary; sort so the dropdown is stable
                model_files = sorted(
                    x for x in os.listdir(model_repo_dir) if x.endswith("_cpu.addon")
                )
                model_name = gr.Dropdown(
                    label="Model",
                    choices=model_files,
                )
            with gr.Column(variant="panel"):
                speaker = gr.Dropdown(label="Speaker", choices=[])

            def set_model(model_name_str: str):
                """
                gets value from `model_name`, loads model,
                re-initializes tts object, gets list of
                speakers that model supports and set them to `speaker`.
                If no model is selected, tts object is dropped and
                the list of speakers is emptied.
                """
                global tts
                if not model_name_str:
                    # an empty selection can't be joined into a model path
                    logging.info("model is not selected")
                    tts = None
                    return gr.Dropdown.update(choices=[], value=None, visible=True)
                model_path = os.path.join(model_repo_dir, model_name_str)
                tts = TTS(model_path)
                speakers = tts.get_speakers()
                # a model may expose no speaker names (see `synthesize_audio`),
                # so don't index into an empty list
                value = speakers[-1] if speakers else None
                return gr.Dropdown.update(
                    choices=speakers, value=value, visible=True
                )

            model_name.change(set_model, inputs=model_name, outputs=speaker)

        with gr.Row(variant="panel"):
            generate = gr.Button("Generate")
        with gr.Row(variant="panel"):
            audio = gr.Audio()

        def synthesize_audio(text_str: str, speaker_str: str = ""):
            """
            gets utterance to synthesize from `text` Textbox
            and speaker name from `speaker` dropdown list.
            speaker name might be empty for single-speaker models.
            Synthesizes the waveform and updates `audio` with it.
            Returns None (no audio) if text is empty or
            no model has been selected yet.
            """
            if not text_str:
                logging.info("text is not provided")
                return None
            # `tts` is the module-level engine assigned by `set_model`;
            # it is still None if "Generate" is hit before picking a model
            if tts is None:
                logging.info("model is not selected")
                return None
            engine = cast(TTS, tts)
            # an unset dropdown may yield None; fall back to the empty
            # speaker name, which is the declared default of this handler
            samples = engine.synthesize(text_str[:max_text_len], speaker_str or "")
            return gr.Audio.update(value=(engine.get_sampling_rate(), samples))

        generate.click(synthesize_audio, inputs=[text, speaker], outputs=audio)

    demo.queue(concurrency_count=1).launch()


# build and launch the demo only when this file is executed as a script
if __name__ == "__main__":
    main()