lafi23333 committed
Commit 222a360
Parent: 794efa0

Upload app.py

Files changed (1): app.py +485 -0
app.py ADDED
@@ -0,0 +1,485 @@
+ # flake8: noqa: E402
+ import os
+ import logging
+ import re_matching
+ from tools.sentence import split_by_language
+ 
+ logging.getLogger("numba").setLevel(logging.WARNING)
+ logging.getLogger("markdown_it").setLevel(logging.WARNING)
+ logging.getLogger("urllib3").setLevel(logging.WARNING)
+ logging.getLogger("matplotlib").setLevel(logging.WARNING)
+ 
+ logging.basicConfig(
+     level=logging.INFO, format="| %(name)s | %(levelname)s | %(message)s"
+ )
+ 
+ logger = logging.getLogger(__name__)
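+ 
+ # NOTE: torch and the model-side modules are imported only after logging is
+ # configured (hence the flake8 E402 waiver above), presumably so that their
+ # import-time log output picks up the format configured here.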
+ import torch
+ import utils
+ from infer import infer, latest_version, get_net_g, infer_multilang
+ import gradio as gr
+ import webbrowser
+ import numpy as np
+ from config import config
+ from tools.translate import translate
+ import librosa
+ 
+ net_g = None
+ 
+ device = config.webui_config.device
+ if device == "mps":
+     # let ops unsupported on Apple's MPS backend fall back to the CPU
+     os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ 
+ 
+ def generate_audio(
+     slices,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     speaker,
+     language,
+     reference_audio,
+     emotion,
+     skip_start=False,
+     skip_end=False,
+ ):
+     audio_list = []
+     # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
+     with torch.no_grad():
+         for idx, piece in enumerate(slices):
+             skip_start = (idx != 0) and skip_start
+             skip_end = (idx != len(slices) - 1) and skip_end
+             audio = infer(
+                 piece,
+                 reference_audio=reference_audio,
+                 emotion=emotion,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+                 sid=speaker,
+                 language=language,
+                 hps=hps,
+                 net_g=net_g,
+                 device=device,
+                 skip_start=skip_start,
+                 skip_end=skip_end,
+             )
+             audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+             audio_list.append(audio16bit)
+             # audio_list.append(silence)  # append the silence to the list
+     return audio_list
+ 
+ 
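+ # A minimal usage sketch (hypothetical values; assumes `hps` and `net_g` have
+ # been initialized as in the __main__ block below):
+ #
+ #     segments = generate_audio(
+ #         ["第一句。", "第二句。"],
+ #         sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, length_scale=1.0,
+ #         speaker=speakers[0], language="ZH",
+ #         reference_audio=None, emotion=0,
+ #     )
+ #     wav = np.concatenate(segments)  # int16 waveform at hps.data.sampling_rate
+ 
+ 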
+ def generate_audio_multilang(
+     slices,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     speaker,
+     language,
+     reference_audio,
+     emotion,
+     skip_start=False,
+     skip_end=False,
+ ):
+     audio_list = []
+     # silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
+     with torch.no_grad():
+         for idx, piece in enumerate(slices):
+             skip_start = (idx != 0) and skip_start
+             skip_end = (idx != len(slices) - 1) and skip_end
+             audio = infer_multilang(
+                 piece,
+                 reference_audio=reference_audio,
+                 emotion=emotion,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+                 sid=speaker,
+                 language=language[idx],
+                 hps=hps,
+                 net_g=net_g,
+                 device=device,
+                 skip_start=skip_start,
+                 skip_end=skip_end,
+             )
+             audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+             audio_list.append(audio16bit)
+             # audio_list.append(silence)  # append the silence to the list
+     return audio_list
+ 
+ 
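+ # Note: unlike generate_audio, the `language` argument above is a list parallel
+ # to `slices`: each element is itself a list of per-sub-sentence language tags,
+ # as assembled by tts_fn below.
+ 
+ 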
+ def tts_split(
+     text: str,
+     speaker,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     language,
+     cut_by_sent,
+     interval_between_para,
+     interval_between_sent,
+     reference_audio,
+     emotion,
+ ):
+     if language == "mix":
+         return ("invalid", None)
+     while "\n\n" in text:  # collapse blank lines between paragraphs
+         text = text.replace("\n\n", "\n")
+     para_list = re_matching.cut_para(text)
+     audio_list = []
+     if not cut_by_sent:
+         for idx, p in enumerate(para_list):
+             skip_start = idx != 0
+             skip_end = idx != len(para_list) - 1
+             audio = infer(
+                 p,
+                 reference_audio=reference_audio,
+                 emotion=emotion,
+                 sdp_ratio=sdp_ratio,
+                 noise_scale=noise_scale,
+                 noise_scale_w=noise_scale_w,
+                 length_scale=length_scale,
+                 sid=speaker,
+                 language=language,
+                 hps=hps,
+                 net_g=net_g,
+                 device=device,
+                 skip_start=skip_start,
+                 skip_end=skip_end,
+             )
+             audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+             audio_list.append(audio16bit)
+             # NOTE: assumes 44.1 kHz model output (cf. hps.data.sampling_rate)
+             silence = np.zeros(int(44100 * interval_between_para), dtype=np.int16)
+             audio_list.append(silence)
+     else:
+         for idx, p in enumerate(para_list):
+             skip_start = idx != 0
+             skip_end = idx != len(para_list) - 1
+             audio_list_sent = []
+             sent_list = re_matching.cut_sent(p)
+             # use a separate index so the paragraph index above is not clobbered
+             for sidx, s in enumerate(sent_list):
+                 skip_start = (sidx != 0) and skip_start
+                 skip_end = (sidx != len(sent_list) - 1) and skip_end
+                 audio = infer(
+                     s,
+                     reference_audio=reference_audio,
+                     emotion=emotion,
+                     sdp_ratio=sdp_ratio,
+                     noise_scale=noise_scale,
+                     noise_scale_w=noise_scale_w,
+                     length_scale=length_scale,
+                     sid=speaker,
+                     language=language,
+                     hps=hps,
+                     net_g=net_g,
+                     device=device,
+                     skip_start=skip_start,
+                     skip_end=skip_end,
+                 )
+                 audio_list_sent.append(audio)
+                 silence = np.zeros(int(44100 * interval_between_sent))
+                 audio_list_sent.append(silence)
+             if (interval_between_para - interval_between_sent) > 0:
+                 silence = np.zeros(
+                     int(44100 * (interval_between_para - interval_between_sent))
+                 )
+                 audio_list_sent.append(silence)
+             audio16bit = gr.processing_utils.convert_to_16_bit_wav(
+                 np.concatenate(audio_list_sent)
+             )  # volume-normalize over the complete sentence group
+             audio_list.append(audio16bit)
+     audio_concat = np.concatenate(audio_list)
+     return ("Success", (44100, audio_concat))
+ 
+ 
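+ # The pause lengths above are counted in samples: at the assumed 44.1 kHz rate,
+ # a 0.5 s sentence pause, for example, becomes int(44100 * 0.5) = 22050 zeros.
+ 
+ 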
+ def tts_fn(
+     text: str,
+     speaker,
+     sdp_ratio,
+     noise_scale,
+     noise_scale_w,
+     length_scale,
+     language,
+     reference_audio,
+     emotion,
+ ):
+     audio_list = []
+     if language == "mix":
+         bool_valid, str_valid = re_matching.validate_text(text)
+         if not bool_valid:
+             return str_valid, (
+                 hps.data.sampling_rate,
+                 np.zeros(hps.data.sampling_rate // 2),
+             )
+         result = []
+         for piece in re_matching.text_matching(text):
+             _speaker = piece.pop()
+             temp_content = []
+             temp_lang = []
+             for lang, content in piece:
+                 if "|" in content:
+                     temp = []
+                     temp_ = []
+                     for i in content.split("|"):
+                         if i != "":
+                             temp.append([i])
+                             temp_.append([lang])
+                         else:
+                             temp.append([])
+                             temp_.append([])
+                     temp_content += temp
+                     temp_lang += temp_
+                 else:
+                     if len(temp_content) == 0:
+                         temp_content.append([])
+                         temp_lang.append([])
+                     temp_content[-1].append(content)
+                     temp_lang[-1].append(lang)
+             for i, j in zip(temp_lang, temp_content):
+                 result.append([*zip(i, j), _speaker])
+         for i, one in enumerate(result):
+             skip_start = i != 0
+             skip_end = i != len(result) - 1
+             _speaker = one.pop()
+             idx = 0
+             while idx < len(one):
+                 text_to_generate = []
+                 lang_to_generate = []
+                 while True:
+                     lang, content = one[idx]
+                     temp_text = [content]
+                     if len(text_to_generate) > 0:
+                         text_to_generate[-1] += [temp_text.pop(0)]
+                         lang_to_generate[-1] += [lang]
+                     if len(temp_text) > 0:
+                         text_to_generate += [[i] for i in temp_text]
+                         lang_to_generate += [[lang]] * len(temp_text)
+                     if idx + 1 < len(one):
+                         idx += 1
+                     else:
+                         break
+                 skip_start = (idx != 0) and skip_start
+                 skip_end = (idx != len(one) - 1) and skip_end
+                 print(text_to_generate, lang_to_generate)
+                 audio_list.extend(
+                     generate_audio_multilang(
+                         text_to_generate,
+                         sdp_ratio,
+                         noise_scale,
+                         noise_scale_w,
+                         length_scale,
+                         _speaker,
+                         lang_to_generate,
+                         reference_audio,
+                         emotion,
+                         skip_start,
+                         skip_end,
+                     )
+                 )
+                 idx += 1
+     elif language.lower() == "auto":
+         for idx, piece in enumerate(text.split("|")):
+             if piece == "":
+                 continue
+             skip_start = idx != 0
+             skip_end = idx != len(text.split("|")) - 1
+             sentences_list = split_by_language(
+                 piece, target_languages=["zh", "ja", "en"]
+             )
+             idx = 0
+             while idx < len(sentences_list):
+                 text_to_generate = []
+                 lang_to_generate = []
+                 while True:
+                     content, lang = sentences_list[idx]
+                     temp_text = [content]
+                     lang = lang.upper()
+                     if lang == "JA":
+                         lang = "JP"  # the model uses "JP" as its Japanese tag
+                     if len(text_to_generate) > 0:
+                         text_to_generate[-1] += [temp_text.pop(0)]
+                         lang_to_generate[-1] += [lang]
+                     if len(temp_text) > 0:
+                         text_to_generate += [[i] for i in temp_text]
+                         lang_to_generate += [[lang]] * len(temp_text)
+                     if idx + 1 < len(sentences_list):
+                         idx += 1
+                     else:
+                         break
+                 skip_start = (idx != 0) and skip_start
+                 skip_end = (idx != len(sentences_list) - 1) and skip_end
+                 print(text_to_generate, lang_to_generate)
+                 audio_list.extend(
+                     generate_audio_multilang(
+                         text_to_generate,
+                         sdp_ratio,
+                         noise_scale,
+                         noise_scale_w,
+                         length_scale,
+                         speaker,
+                         lang_to_generate,
+                         reference_audio,
+                         emotion,
+                         skip_start,
+                         skip_end,
+                     )
+                 )
+                 idx += 1
+     else:
+         audio_list.extend(
+             generate_audio(
+                 text.split("|"),
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 speaker,
+                 language,
+                 reference_audio,
+                 emotion,
+             )
+         )
+ 
+     audio_concat = np.concatenate(audio_list)
+     return "Success", (hps.data.sampling_rate, audio_concat)
+ 
+ 
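+ # A parsing sketch for "mix" mode (hypothetical input): the line
+ #     [gongzi]<zh>你好|再见<jp>こんにちは
+ # yields the pieces [('zh', '你好|再见'), ('jp', 'こんにちは'), 'gongzi'], and the
+ # '|' split then produces roughly
+ #     result = [[('zh', '你好'), 'gongzi'],
+ #               [('zh', '再见'), ('jp', 'こんにちは'), 'gongzi']]
+ # so each chunk is rendered by generate_audio_multilang with its own language list.
+ 
+ 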
+ if __name__ == "__main__":
+     if config.webui_config.debug:
+         logger.info("Enabling DEBUG-level logging")
+         # basicConfig is a no-op once handlers exist, so raise the root level directly
+         logging.getLogger().setLevel(logging.DEBUG)
+     hps = utils.get_hparams_from_file(config.webui_config.config_path)
+     # If config.json does not specify a version, default to the latest version
+     version = hps.version if hasattr(hps, "version") else latest_version
+     net_g = get_net_g(
+         model_path=config.webui_config.model, version=version, device=device, hps=hps
+     )
+     speaker_ids = hps.data.spk2id
+     speakers = list(speaker_ids.keys())
+     languages = ["ZH", "JP", "EN", "mix", "auto"]
+     with gr.Blocks() as app:
+         with gr.Row():
+             with gr.Column():
+                 text = gr.TextArea(
+                     label="Input text",
+                     placeholder="""
+ If the language is set to 'mix', the input must follow this format, otherwise an error is raised.
+ Format example (zh = Chinese, jp = Japanese, case-insensitive; example speaker: gongzi):
+ [speaker1]<zh>你好,こんにちは! <jp>こんにちは,世界。
+ [speaker2]<zh>你好吗?<jp>元気ですか?
+ [speaker3]<zh>谢谢。<jp>どういたしまして。
+ ...
+ In every language mode, '|' can also be used to split long passages into separately generated sentences.
+ """,
+                 )
+                 trans = gr.Button("Translate Chinese to Japanese", variant="primary")
+                 speaker = gr.Dropdown(
+                     choices=speakers, value=speakers[0], label="Speaker"
+                 )
+                 emotion = gr.Slider(
+                     minimum=0, maximum=9, value=0, step=1, label="Emotion"
+                 )
+                 sdp_ratio = gr.Slider(
+                     minimum=0, maximum=1, value=0.2, step=0.1, label="SDP Ratio"
+                 )
+                 noise_scale = gr.Slider(
+                     minimum=0.1, maximum=2, value=0.6, step=0.1, label="Noise"
+                 )
+                 noise_scale_w = gr.Slider(
+                     minimum=0.1, maximum=2, value=0.8, step=0.1, label="Noise_W"
+                 )
+                 length_scale = gr.Slider(
+                     minimum=0.1, maximum=2, value=1.0, step=0.1, label="Length"
+                 )
+                 language = gr.Dropdown(
+                     choices=languages, value=languages[0], label="Language"
+                 )
+                 btn = gr.Button("Generate audio!", variant="primary")
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         interval_between_sent = gr.Slider(
+                             minimum=0,
+                             maximum=5,
+                             value=0.2,
+                             step=0.1,
+                             label="Pause between sentences (s); only applies when sentence splitting is checked",
+                         )
+                         interval_between_para = gr.Slider(
+                             minimum=0,
+                             maximum=10,
+                             value=1,
+                             step=0.1,
+                             label="Pause between paragraphs (s); must exceed the sentence pause to take effect",
+                         )
+                         opt_cut_by_sent = gr.Checkbox(
+                             label="Split by sentence: split the text by sentence on top of paragraph splitting"
+                         )
+                         slicer = gr.Button("Split and generate", variant="primary")
+                 text_output = gr.Textbox(label="Status")
+                 audio_output = gr.Audio(label="Output audio")
+                 # explain_image = gr.Image(
+                 #     label="Parameter explanation",
+                 #     show_label=True,
+                 #     show_share_button=False,
+                 #     show_download_button=False,
+                 #     value=os.path.abspath("./img/参数说明.png"),
+                 # )
+                 reference_text = gr.Markdown(
+                     value="## Emotion reference audio (WAV): an emotional reference for the generated speech."
+                 )
+                 reference_audio = gr.Audio(
+                     label="Emotion reference audio (WAV)", type="filepath"
+                 )
+         btn.click(
+             tts_fn,
+             inputs=[
+                 text,
+                 speaker,
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 language,
+                 reference_audio,
+                 emotion,
+             ],
+             outputs=[text_output, audio_output],
+         )
+ 
+         trans.click(
+             translate,
+             inputs=[text],
+             outputs=[text],
+         )
+         slicer.click(
+             tts_split,
+             inputs=[
+                 text,
+                 speaker,
+                 sdp_ratio,
+                 noise_scale,
+                 noise_scale_w,
+                 length_scale,
+                 language,
+                 opt_cut_by_sent,
+                 interval_between_para,
+                 interval_between_sent,
+                 reference_audio,
+                 emotion,
+             ],
+             outputs=[text_output, audio_output],
+         )
+ 
+         reference_audio.upload(
+             # librosa.load returns (audio, sr); [::-1] flips it to the (sr, audio)
+             # tuple gr.Audio expects (sr is keyword-only in recent librosa)
+             lambda x: librosa.load(x, sr=16000)[::-1],
+             inputs=[reference_audio],
+             outputs=[reference_audio],
+         )
+     print("The inference page is up!")
+     webbrowser.open(f"http://127.0.0.1:{config.webui_config.port}")
+     app.launch(share=config.webui_config.share, server_port=config.webui_config.port)