File size: 5,509 Bytes
b2458f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f584511
b2458f3
f584511
 
 
b2458f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f584511
 
 
b2458f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import logging
import re

import gradio as gr
import numpy
import torch

import utils
from infer import infer, get_net_g

# Cap chatty third-party loggers at WARNING so the application's own
# messages are not buried.
for _noisy_name in ("numba", "markdown_it", "urllib3", "matplotlib"):
    logging.getLogger(_noisy_name).setLevel(logging.WARNING)

logging.basicConfig(
    format="| %(name)s | %(levelname)s | %(message)s",
    level=logging.INFO,
)

logger = logging.getLogger(__name__)

# Hyper-parameters and generator network; both are assigned by main()
# before the UI starts serving requests.
hps = None
net_g = None

# Inference settings shared by the functions below.
sampling_rate = 22050
model_path = "models/G_1000.pth"
device = "cpu" if not torch.cuda.is_available() else "cuda"


def split_sentence(sentence: str):
    """Split mixed Chinese/English text into consecutive single-language runs.

    Each character falls into one of three classes:

    * ``'EN'``  - ASCII letters (``a-z``, ``A-Z``);
    * ``'ZH'``  - any non-ASCII character;
    * neutral   - every other ASCII character (digits, spaces, punctuation,
      newlines).

    Neutral characters never start a new run: they stay attached to the run
    in progress, and leading neutral characters join the first run.

    Args:
        sentence: Text to split.

    Returns:
        A list of ``(text, language)`` tuples in input order, where
        ``language`` is ``'ZH'`` or ``'EN'``.  Concatenating the ``text``
        parts reproduces ``sentence``.  An empty input yields ``[]``.  Text
        made only of neutral characters is returned as a single ``'ZH'`` run
        so that callers always receive a usable language tag.
    """
    if not sentence:
        return []

    result = []
    current_language = None  # language of the run in progress, once known
    run_start = 0

    for idx, char in enumerate(sentence):
        if not char.isascii():
            char_language = 'ZH'
        elif char.isalpha():
            # An ASCII character that is alphabetic is exactly [a-zA-Z].
            char_language = 'EN'
        else:
            # Neutral character: belongs to whichever run is in progress.
            continue

        if char_language == current_language:
            continue

        # A character of the other language closes the run in progress.
        # Nothing is emitted for the very first language-bearing character,
        # so leading neutral characters end up inside the first run.
        if current_language is not None:
            result.append((sentence[run_start:idx], current_language))
            run_start = idx
        current_language = char_language

    # Emit the final run; fall back to 'ZH' when no character carried a
    # language (previously this produced an empty language tag).
    result.append((sentence[run_start:], current_language or 'ZH'))
    return result


def tts_fn(
        text: str,
        speaker,
        sdp_ratio,
        noise_scale,
        noise_scale_w,
        length_scale,
        language,
):
    """Synthesize speech for ``text`` and return a status message plus audio.

    The text is split into single-language runs with ``split_sentence``; each
    run is synthesized separately with ``infer`` and followed by half a
    second of silence.  The combined waveform is peak-normalized and
    converted to 16-bit PCM.

    Args:
        text: Text to synthesize; may mix Chinese and English.
        speaker: Speaker identifier, forwarded to ``infer`` as ``sid``.
        sdp_ratio: Forwarded to ``infer`` unchanged.
        noise_scale: Forwarded to ``infer`` unchanged.
        noise_scale_w: Forwarded to ``infer`` unchanged.
        length_scale: Forwarded to ``infer`` unchanged.
        language: UI language choice.  ``'普通话'`` selects ``'ZH'``; any other
            value selects ``'SH'``.  The selection applies to runs tagged
            ``'ZH'``; other runs keep their own tag.

    Returns:
        ``(status, audio)`` where ``audio`` is ``(sampling_rate, int16 array)``
        on success, or ``None`` when there is nothing to synthesize.
    """
    # Empty or whitespace-only input would otherwise reach the normalization
    # step with no samples (ValueError on .max()) or be sent to the model
    # without a usable language tag.
    if not text or not text.strip():
        return "Error: input text is empty", None

    language = 'ZH' if language == '普通话' else 'SH'

    silence = numpy.zeros(sampling_rate // 2, dtype=numpy.int16)
    # Collect the pieces and concatenate once instead of re-copying the
    # growing buffer on every iteration.  The empty float32 seed keeps the
    # dtype promotion the same as the incremental concatenation it replaces.
    pieces = [numpy.array([], dtype=numpy.float32)]
    for sentence, sentence_language in split_sentence(text):
        # NOTE(review): infer is assumed to return a 1-D waveform array
        # compatible with numpy.concatenate - confirm against infer.py.
        sub_audio_data = infer(
            sentence,
            sdp_ratio,
            noise_scale,
            noise_scale_w,
            length_scale,
            sid=speaker,
            language=language if sentence_language == "ZH" else sentence_language,
            hps=hps,
            net_g=net_g,
            device=device)
        pieces.append(sub_audio_data)
        pieces.append(silence)
    audio_data = numpy.concatenate(pieces)

    # Peak-normalize to the int16 range.  Skip the division for an all-zero
    # signal, which would otherwise produce NaNs.
    peak = numpy.abs(audio_data).max() if audio_data.size else 0
    if peak > 0:
        audio_data = audio_data / peak
    audio_data = (audio_data * 32767).astype(numpy.int16)

    return "Success", (sampling_rate, audio_data)


def main() -> None:
    """Load the model and configuration, build the Gradio UI, and serve it.

    Assigns the module-level ``hps`` and ``net_g`` globals that ``tts_fn``
    reads, then launches the web app.
    """
    # NOTE(review): logging.basicConfig was already called at import time, so
    # per the logging documentation this call has no effect (no force=True)
    # and the root level stays at INFO.
    logging.basicConfig(level=logging.DEBUG)

    global hps
    hps = utils.get_hparams_from_file("configs/config.json")

    global net_g
    net_g = get_net_g(model_path=model_path, device=device, hps=hps)

    # Speaker names offered in the UI come from the config's spk2id mapping.
    speaker_ids = hps.data.spk2id
    speakers = list(speaker_ids.keys())
    # "上海话" = Shanghainese, "普通话" = Mandarin; tts_fn compares against
    # these exact strings.
    languages = ["上海话", "普通话"]
    with gr.Blocks() as app:
        # Page title.
        with gr.Row():
            with gr.Column():
                gr.Markdown('# Bert-VITS2-Shanghainese')
        with gr.Row():
            # Left column: input text and synthesis parameter sliders.
            with gr.Column():
                text = gr.TextArea(
                    label="输入文本内容",
                    value="\n".join([
                        "站一个制高点看上海,",
                        "Looking at Shanghai from a commanding height,",
                        "上海的弄堂是壮观的景象。",
                        "The alleys in Shanghai are a great sight.",
                        "它是这城市背景一样的东西。",
                        "It is something with the same background as this city."
                    ]),
                )
                sdp_ratio = gr.Slider(minimum=0, maximum=1, value=0.2, step=0.1, label="SDP/DP混合比")
                noise_scale = gr.Slider(minimum=0.1, maximum=2, value=0.6, step=0.1, label="感情")
                noise_scale_w = gr.Slider(minimum=0.1, maximum=2, value=0.8, step=0.1, label="音素长度")
                length_scale = gr.Slider(minimum=0.1, maximum=2, value=1.0, step=0.1, label="语速")
            # Right column: speaker/language selection, submit button, outputs.
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        # Defaults to the first speaker in the config.
                        speaker = gr.Dropdown(choices=speakers, value=speakers[0], label="选择说话人")
                    with gr.Column():
                        language = gr.Dropdown(choices=languages, value=languages[0], label="选择语言")
                submit_btn = gr.Button("生成音频", variant="primary")
                text_output = gr.Textbox(label="状态")
                audio_output = gr.Audio(label="音频")
        # Footer crediting the upstream project.
        with gr.Row():
            with gr.Column():
                gr.Markdown('### Thanks <a href="https://github.com/fishaudio/Bert-VITS2" target="_blank">github@Bert-VITS2</a>')
        # The inputs list must stay in the same order as tts_fn's parameters;
        # the outputs map to tts_fn's (status, audio) return value.
        submit_btn.click(
            tts_fn,
            inputs=[
                text,
                speaker,
                sdp_ratio,
                noise_scale,
                noise_scale_w,
                length_scale,
                language,
            ],
            outputs=[text_output, audio_output],
        )

    # Serve on 0.0.0.0:7860 with Gradio share links disabled.
    app.launch(share=False, server_name="0.0.0.0", server_port=7860)


# Run the app only when this file is executed as a script.
if __name__ == "__main__":
    main()