jvde commited on
Commit
4973571
1 Parent(s): efd1204

add application file

Browse files
Files changed (1) hide show
  1. app.py +371 -0
app.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import IPython.display as ipd
3
+ import os
4
+ import json
5
+ import math
6
+ import torch
7
+ import commons
8
+ import utils
9
+ from models import SynthesizerTrn
10
+ from text.symbols import symbols
11
+ from text import text_to_sequence
12
+ from scipy.io.wavfile import write
13
+ import gradio as gr
14
+ import numpy as np
15
+ from PIL import Image
16
+ import numpy as np
17
+ import os
18
+ from pathlib import Path
19
+
20
+
21
+
22
+
23
# Language choices offered by the UI (radio button and Add-Model dropdown).
LANGUAGES = ['EN','CN','JP']
# Speaker id passed to the synthesizer; 0 for single-speaker checkpoints.
SPEAKER_ID = 0
# Cover image path for the currently selected model.
COVER = "models/Yuuka/cover.png"
# English name of the currently selected model (key into models_info).
speaker_choice = "Yuuka"
# Chinese display name of the currently selected model.
MODEL_ZH_NAME = "早濑优香"
# Example sentence shown in the Example textbox.
EXAMPLE_TEXT = "先生。今日も全力であなたをアシストしますね。"
#USER_INPUT_TEXT = ""

# NOTE(review): these defaults point at "config2"/"parappa" while every other
# default above refers to the Yuuka model — confirm the paths are in sync.
CONFIG_PATH = "configs/config2.json"
MODEL_PATH = "models/parappa/path.pth"
33
+
34
# Build the synthesizer once at import time. This is the same sequence as
# load_model() below, which rebuilds these globals when the user switches
# models in the Settings tab.
hps = utils.get_hparams_from_file(CONFIG_PATH)
net_g = SynthesizerTrn(
    len(hps.symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    n_speakers=hps.data.n_speakers,
    **hps.model)

model = net_g.eval()
# NOTE(review): the value of utils.load_checkpoint overwrites `model`, but
# inference (tts_fn) only ever uses net_g — confirm `model` is needed at all.
model = utils.load_checkpoint(MODEL_PATH, net_g, None)
44
+
45
def load_model():
    """(Re)build the VITS synthesizer from CONFIG_PATH / MODEL_PATH.

    Rebinds the module-level globals consumed by tts_fn: ``hps`` (the
    hyperparameters parsed from CONFIG_PATH), ``net_g`` (the network, put
    into eval mode) and ``model`` (whatever utils.load_checkpoint returns).
    Called by change_dropdown after the user selects a different model.
    """
    global hps, net_g, model

    hps = utils.get_hparams_from_file(CONFIG_PATH)
    net_g = SynthesizerTrn(
        len(hps.symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model)

    # eval() disables dropout etc.; the original bound its return value to
    # `model` only to overwrite it on the next line (dead store, removed).
    net_g.eval()
    model = utils.load_checkpoint(MODEL_PATH, net_g, None)
58
+
59
def get_text(text, hps):
    """Convert raw text into a LongTensor of symbol ids for the synthesizer."""
    seq = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Interleave blank (0) tokens between symbols when the config asks for it.
        seq = commons.intersperse(seq, 0)
    return torch.LongTensor(seq)
65
+
66
def tts_fn(text, noise_scale, noise_scale_w, length_scale):
    """Synthesize `text` with the current model.

    Returns a (sample_rate, waveform) pair in the shape gr.Audio expects.
    """
    phoneme_ids = get_text(text, hps)
    with torch.no_grad():
        batch = phoneme_ids.unsqueeze(0)
        batch_lengths = torch.LongTensor([phoneme_ids.size(0)])
        speaker = torch.LongTensor([SPEAKER_ID])
        # infer() returns (audio, ...); take the first waveform of the batch.
        generated = net_g.infer(batch, batch_lengths, sid=speaker,
                                noise_scale=noise_scale,
                                noise_scale_w=noise_scale_w,
                                length_scale=length_scale)
        audio = generated[0][0, 0].data.cpu().float().numpy()
    return (22050, audio)
74
+
75
def add_model_fn(example_text, cover, speakerID, name_en, name_cn, language):
    """Register a new model: save its cover image and append an entry to
    models/model_info.json.

    speakerID, name_en and language are required; a missing one raises
    gr.Error, which Gradio surfaces as a popup. Returns "Success" otherwise.
    The checkpoint itself must then be copied into models/<name_en>/ by hand.
    """
    # Validate required fields. (The original had an unreachable
    # `return "Failed to add model"` after this raise — removed.)
    if not speakerID or not name_en or not language:
        raise gr.Error("Please fill in all required fields!")

    # Create models/<name_en>/ to hold the cover image and checkpoint.
    model_save_dir = Path("models") / name_en
    img_save_dir = model_save_dir
    model_save_dir.mkdir(parents=True, exist_ok=True)

    Model_name = name_en + ".pth"
    model_save_dir = model_save_dir / Model_name

    # Save the uploaded cover image, if one was provided.
    if cover is not None:
        img = np.array(cover)
        img = Image.fromarray(img)
        img.save(os.path.join(img_save_dir, 'cover.png'))

    # Collect the new model's metadata (Path values are serialized by
    # CustomEncoder below).
    new_model = {
        "name_en": name_en,
        "name_zh": name_cn,
        "cover": img_save_dir / "cover.png",
        "sid": speakerID,
        "example": example_text,
        "language": language,
        "type": "single",
        "model_path": model_save_dir
    }

    # Read-modify-write the registry. The write now pins encoding="utf-8"
    # and ensure_ascii=False so Chinese names are stored readably and the
    # file round-trips with the utf-8 reads used elsewhere in this app
    # (the original wrote with the platform default encoding).
    with open("models/model_info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    models_info[name_en] = new_model
    with open("models/model_info.json", "w", encoding="utf-8") as f:
        json.dump(models_info, f, cls=CustomEncoder, ensure_ascii=False)

    return "Success"
121
+
122
def clear_input_text():
    """Callback for the Clear button: blank out the TTS input textbox."""
    return ""
124
+
125
def clear_add_model_info():
    """Reset every Add-Model form field to its empty default value."""
    return ("", None, "", "", "", "")
127
+
128
def get_options():
    """Reload the model registry from disk into the global `models_info`.

    Also rebinds the global `name_en`; after iterating the whole registry it
    holds the *last* model's English name (relies on dict insertion order).
    """
    global models_info, name_en

    with open("models/model_info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    for entry in models_info.values():
        name_en = entry['name_en']
136
+
137
def reset_options():
    """Return the default (Yuuka) model name and speaker id for the dropdowns."""
    default_entry = models_info['Yuuka']
    return default_entry['name_en'], default_entry['sid']
141
+
142
def refresh_options():
    """Re-read the registry and return the selected model's name and speaker id."""
    get_options()
    current = models_info[speaker_choice]
    return current['name_en'], current['sid']
147
+
148
def change_dropdown(choice):
    """Switch the active model to `choice` and refresh the dependent widgets.

    Rebinds the COVER / MODEL_PATH / MODEL_ZH_NAME / EXAMPLE_TEXT globals,
    reloads the checkpoint via load_model(), and returns gr.update objects
    for: speaker-id dropdown, cover markdown, title markdown, language radio,
    example textbox (in the order wired up by model_choice.change).
    """
    global speaker_choice, COVER, MODEL_PATH, MODEL_ZH_NAME, EXAMPLE_TEXT
    speaker_choice = choice
    info = models_info[speaker_choice]
    COVER = str(info['cover'])
    MODEL_PATH = str(info['model_path'])
    MODEL_ZH_NAME = str(info['name_zh'])
    EXAMPLE_TEXT = str(info['example'])

    speaker_id_change = gr.update(value=str(info['sid']))

    # BUG FIX: the original mixed implicit string concatenation with a
    # conditional expression; Python joined the adjacent literals into the
    # if/else branches, so when COVER was set the speaker-name anchor and the
    # closing </div> were silently dropped from the rendered HTML.
    img_html = f'<img style="width:auto;height:512px;" src="file/{COVER}">' if COVER else ""
    cover_change = gr.update(value='<div align="center">'
                                   + img_html
                                   + f'<a><strong>{speaker_choice}</strong></a>'
                                     '</div>')
    title_change = gr.update(value=
                             '<div align="center">'
                             f'<h3><a><strong>{"语音名称: "}{MODEL_ZH_NAME}</strong></a>'
                             f'<h3><strong>{"checkpoint: "}{speaker_choice}</strong>'
                             '</div>')

    lan_change = gr.update(value=str(info['language']))

    example_change = gr.update(value=EXAMPLE_TEXT)

    # Rebuild hps/net_g from the newly selected CONFIG_PATH/MODEL_PATH.
    load_model()

    return [speaker_id_change, cover_change, title_change, lan_change, example_change]
179
+
180
class CustomEncoder(json.JSONEncoder):
    """JSON encoder that serializes pathlib.Path objects as plain strings."""

    def default(self, obj):
        # Path objects are not JSON-serializable out of the box; emit their
        # string form. Everything else falls through to the base class.
        if isinstance(obj, Path):
            return str(obj)
        return json.JSONEncoder.default(self, obj)
185
+
186
# Client-side JS attached to the Download button: locates the rendered
# <audio> element and the input textbox by their elem_ids, then triggers a
# browser download of the audio, named after the first 20 characters of the
# input text (or a random number if the textbox is empty).
# This is a str.format() template: {{ and }} are literal braces, {audio_id}
# is substituted with the per-model element-id suffix at wiring time.
# NOTE(review): exact original whitespace inside the JS is not recoverable
# from this dump; JS is whitespace-insensitive so layout below is cosmetic.
download_audio_js = """
() =>{{
    let root = document.querySelector("body > gradio-app");
    if (root.shadowRoot != null)
        root = root.shadowRoot;
    let audio = root.querySelector("#tts-audio-{audio_id}").querySelector("audio");
    let text = root.querySelector("#input-text-{audio_id}").querySelector("textarea");
    if (audio == undefined)
        return;
    text = text.value;
    if (text == undefined)
        text = Math.floor(Math.random()*100000000);
    audio = audio.src;
    let oA = document.createElement("a");
    oA.download = text.substr(0, 20)+'.wav';
    oA.href = audio;
    document.body.appendChild(oA);
    oA.click();
    oA.remove();
}}
"""
207
+
208
+
209
+
210
+
211
+
212
+
213
+
214
+
215
if __name__ == '__main__':

    # Load the model registry; after the loop `name_en` holds the *last*
    # model's English name, which is used to build elem_ids below.
    with open("models/model_info.json", "r", encoding="utf-8") as f:
        models_info = json.load(f)

    for i, model_info in models_info.items():
        name_en = model_info['name_en']

    theme = gr.themes.Base()

    with gr.Blocks(theme=theme) as interface:
        with gr.Tab("Text to Speech"):
            with gr.Column():
                # BUG FIX: the original relied on implicit literal
                # concatenation around the `if COVER else ""` conditional,
                # which pulled the closing '</div>' into the (never taken)
                # else branch; with COVER set, the div was emitted unclosed.
                cover_markdown = gr.Markdown(
                    '<div align="center">'
                    + (f'<img style="width:auto;height:512px;" src="file/{COVER}">' if COVER else "")
                    + '</div>')
                title_markdown = gr.Markdown(
                    '<div align="center">'
                    f'<h3><a><strong>{"语音名称: "}{MODEL_ZH_NAME}</strong></a>'
                    f'<h3><strong>{"checkpoint: "}{speaker_choice}</strong>'
                    '</div>')

                with gr.Row():
                    with gr.Column(scale=4):
                        input_text = gr.Textbox(
                            label="Input",
                            lines=2,
                            placeholder="Enter the text you want to process here",
                            elem_id=f"input-text-en-{name_en.replace(' ', '')}",
                            scale=2
                        )
                    with gr.Column(scale=1):
                        gen_button = gr.Button("Generate", variant="primary")
                        clear_input_button = gr.Button("Clear")

                with gr.Row():
                    with gr.Column(scale=2):
                        lan = gr.Radio(label="Language", choices=LANGUAGES, value="JP")
                        noise_scale = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Noise Scale (情感变化程度)",
                                                value=0.6)
                        noise_scale_w = gr.Slider(minimum=0.1, maximum=1.0, step=0.1, label="Noise Scale w (发音长度)",
                                                  value=0.668)
                        length_scale = gr.Slider(minimum=0.1, maximum=2.0, step=0.1, label="Length Scale (语速)",
                                                 value=1.0)

                    with gr.Column(scale=1):
                        example_text_box = gr.Textbox(label="Example:",
                                                      value=EXAMPLE_TEXT)

                output_audio = gr.Audio(label="Output", elem_id=f"tts-audio-en-{name_en.replace(' ', '')}")
                download_button = gr.Button("Download")

                gen_button.click(
                    tts_fn,
                    inputs=[input_text, noise_scale, noise_scale_w, length_scale],
                    outputs=output_audio)
                clear_input_button.click(
                    clear_input_text,
                    outputs=input_text
                )
                # NOTE(review): `_js` is the old Gradio 3.x kwarg; newer
                # releases renamed it to `js` — confirm the pinned version.
                download_button.click(None, [], [], _js=download_audio_js.format(audio_id=f"en-{name_en.replace(' ', '')}"))

        # ------------------------------------------------------------------
        with gr.Tab("AI Singer"):
            input_text_singer = gr.Textbox()  # placeholder tab, not wired up yet

        # ------------------------------------------------------------------
        with gr.Tab("TTS with ChatGPT"):
            input_text_gpt = gr.Textbox()  # placeholder tab, not wired up yet

        # ------------------------------------------------------------------
        with gr.Tab("Settings"):
            with gr.Box():
                gr.Markdown("""# Select Model""")
                with gr.Row():
                    with gr.Column(scale=5):
                        model_choice = gr.Dropdown(label="Model",
                                                   choices=[(model["name_en"]) for name, model in models_info.items()],
                                                   interactive=True,
                                                   value=models_info['Yuuka']['name_en']
                                                   )
                    with gr.Column(scale=5):
                        speaker_id_choice = gr.Dropdown(label="Speaker ID",
                                                        choices=[(str(model["sid"])) for name, model in
                                                                 models_info.items()],
                                                        interactive=True,
                                                        value=str(models_info['Yuuka']['sid'])
                                                        )

                    with gr.Column(scale=1):
                        refresh_button = gr.Button("Refresh", variant="primary")
                        reset_button = gr.Button("Reset")

                model_choice.change(fn=change_dropdown, inputs=model_choice,
                                    outputs=[speaker_id_choice, cover_markdown, title_markdown, lan, example_text_box])

                refresh_button.click(fn=refresh_options, outputs=[model_choice, speaker_id_choice])
                reset_button.click(reset_options, outputs=[model_choice, speaker_id_choice])

            with gr.Box():
                gr.Markdown("# Add Model\n"
                            "> *为必填选项\n"
                            "> 添加完成后将**checkpoints**文件放到对应生成的文件夹中"
                            )

                with gr.Row():
                    example_text = gr.Textbox(label="Example Text",
                                              lines=16,
                                              placeholder="Enter the example text here", )
                    model_cover = gr.Image(label="Cover")

                    with gr.Column():
                        model_speaker_id = gr.Textbox(label="Speaker List*",
                                                      placeholder="Single speaker model default=0")
                        model_name_en = gr.Textbox(label="name_en*")
                        model_name_cn = gr.Textbox(label="name_cn")
                        model_language = gr.Dropdown(label="Language*",
                                                     choices=LANGUAGES,
                                                     interactive=True)
                with gr.Row():
                    add_model_button = gr.Button("Add Model", variant="primary")
                    clear_add_model_button = gr.Button("Clear")
                with gr.Box():
                    with gr.Row():
                        message_box = gr.Textbox(label="Message")

                add_model_button.click(add_model_fn,
                                       inputs=[example_text, model_cover, model_speaker_id, model_name_en, model_name_cn,
                                               model_language],
                                       outputs=message_box
                                       )
                clear_add_model_button.click(clear_add_model_info,
                                             outputs=[example_text, model_cover, model_speaker_id, model_name_en,
                                                      model_name_cn, model_language]
                                             )

    # queue() serializes requests so the single model isn't hit concurrently.
    interface.queue(concurrency_count=1).launch(debug=True)
363
+
364
+
365
+
366
+
367
+
368
+
369
+
370
+
371
+