XzJosh committed
Commit 9d3080c
1 parent: 9fa3841

Upload app.py

Files changed (1)
app.py +363 -0
app.py ADDED
@@ -0,0 +1,363 @@
+ import os
+
+ gpt_path = os.environ.get(
+     "gpt_path", "models/Carol/Carol-e15.ckpt"
+ )
+ sovits_path = os.environ.get("sovits_path", "models/Carol/Carol_e40_s2160.pth")
+ cnhubert_base_path = os.environ.get(
+     "cnhubert_base_path", "pretrained_models/chinese-hubert-base"
+ )
+ bert_path = os.environ.get(
+     "bert_path", "pretrained_models/chinese-roberta-wwm-ext-large"
+ )
+ infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
+ infer_ttswebui = int(infer_ttswebui)
+ if "_CUDA_VISIBLE_DEVICES" in os.environ:
+     os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
+ is_half = eval(os.environ.get("is_half", "True"))
+ import gradio as gr
+ from transformers import AutoModelForMaskedLM, AutoTokenizer
+ import numpy as np
+ import librosa, torch
+ from feature_extractor import cnhubert
+ cnhubert.cnhubert_base_path = cnhubert_base_path
+
+ from module.models import SynthesizerTrn
+ from AR.models.t2s_lightning_module import Text2SemanticLightningModule
+ from text import cleaned_text_to_sequence
+ from text.cleaner import clean_text
+ from time import time as ttime
+ from module.mel_processing import spectrogram_torch
+ from my_utils import load_audio
+
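+ # Chinese RoBERTa model used to extract phone-level text features for Chinese text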
+ device = "cuda"
+ tokenizer = AutoTokenizer.from_pretrained(bert_path)
+ bert_model = AutoModelForMaskedLM.from_pretrained(bert_path)
+ if is_half == True:
+     bert_model = bert_model.half().to(device)
+ else:
+     bert_model = bert_model.to(device)
+
+
+ # bert_model=bert_model.to(device)
+ def get_bert_feature(text, word2ph):
+     with torch.no_grad():
+         inputs = tokenizer(text, return_tensors="pt")
+         for i in inputs:
+             inputs[i] = inputs[i].to(device)  # inputs are integer token ids, so precision does not matter; it follows bert_model
+         res = bert_model(**inputs, output_hidden_states=True)
+         res = torch.cat(res["hidden_states"][-3:-2], -1)[0].cpu()[1:-1]
+     assert len(word2ph) == len(text)
+     phone_level_feature = []
+     for i in range(len(word2ph)):
+         repeat_feature = res[i].repeat(word2ph[i], 1)
+         phone_level_feature.append(repeat_feature)
+     phone_level_feature = torch.cat(phone_level_feature, dim=0)
+     # if(is_half==True):phone_level_feature=phone_level_feature.half()
+     return phone_level_feature.T
+
+
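+ # Load the SoVITS checkpoint; its config dict is wrapped below for attribute-style access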
+ n_semantic = 1024
+
+ dict_s2 = torch.load(sovits_path, map_location="cpu")
+ hps = dict_s2["config"]
+
+ class DictToAttrRecursive(dict):
+     def __init__(self, input_dict):
+         super().__init__(input_dict)
+         for key, value in input_dict.items():
+             if isinstance(value, dict):
+                 value = DictToAttrRecursive(value)
+             self[key] = value
+             setattr(self, key, value)
+
+     def __getattr__(self, item):
+         try:
+             return self[item]
+         except KeyError:
+             raise AttributeError(f"Attribute {item} not found")
+
+     def __setattr__(self, key, value):
+         if isinstance(value, dict):
+             value = DictToAttrRecursive(value)
+         super(DictToAttrRecursive, self).__setitem__(key, value)
+         super().__setattr__(key, value)
+
+     def __delattr__(self, item):
+         try:
+             del self[item]
+         except KeyError:
+             raise AttributeError(f"Attribute {item} not found")
+
+
+ hps = DictToAttrRecursive(hps)
+
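+ # Build the HuBERT content encoder (ssl_model), the SoVITS synthesizer (vq_model) and the GPT text-to-semantic model (t2s_model)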
+ hps.model.semantic_frame_rate = "25hz"
+ dict_s1 = torch.load(gpt_path, map_location="cpu")
+ config = dict_s1["config"]
+ ssl_model = cnhubert.get_model()
+ if is_half == True:
+     ssl_model = ssl_model.half().to(device)
+ else:
+     ssl_model = ssl_model.to(device)
+
+ vq_model = SynthesizerTrn(
+     hps.data.filter_length // 2 + 1,
+     hps.train.segment_size // hps.data.hop_length,
+     n_speakers=hps.data.n_speakers,
+     **hps.model
+ )
+ if is_half == True:
+     vq_model = vq_model.half().to(device)
+ else:
+     vq_model = vq_model.to(device)
+ vq_model.eval()
+ print(vq_model.load_state_dict(dict_s2["weight"], strict=False))
+ hz = 50
+ max_sec = config["data"]["max_sec"]
+ # t2s_model = Text2SemanticLightningModule.load_from_checkpoint(checkpoint_path=gpt_path, config=config, map_location="cpu")#########todo
+ t2s_model = Text2SemanticLightningModule(config, "ojbk", is_train=False)
+ t2s_model.load_state_dict(dict_s1["weight"])
+ if is_half == True:
+     t2s_model = t2s_model.half()
+ t2s_model = t2s_model.to(device)
+ t2s_model.eval()
+ total = sum([param.nelement() for param in t2s_model.parameters()])
+ print("Number of parameters: %.2fM" % (total / 1e6))
+
+
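+ # Linear spectrogram of the reference audio, used as the timbre reference when decoding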
+ def get_spepc(hps, filename):
+     audio = load_audio(filename, int(hps.data.sampling_rate))
+     audio = torch.FloatTensor(audio)
+     audio_norm = audio
+     audio_norm = audio_norm.unsqueeze(0)
+     spec = spectrogram_torch(
+         audio_norm,
+         hps.data.filter_length,
+         hps.data.sampling_rate,
+         hps.data.hop_length,
+         hps.data.win_length,
+         center=False,
+     )
+     return spec
+
+
+ dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}  # Chinese / English / Japanese
+
+
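+ # Full inference pipeline: reference audio + prompt text + target text -> synthesized waveform, one line of text at a time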
+ def get_tts_wav(ref_wav_path, prompt_text, prompt_language, text, text_language):
+     t0 = ttime()
+     prompt_text = prompt_text.strip("\n")
+     prompt_language, text = prompt_language, text.strip("\n")
+     with torch.no_grad():
+         wav16k, sr = librosa.load(ref_wav_path, sr=16000)  # Paimon
+         wav16k = torch.from_numpy(wav16k)
+         if is_half == True:
+             wav16k = wav16k.half().to(device)
+         else:
+             wav16k = wav16k.to(device)
+         ssl_content = ssl_model.model(wav16k.unsqueeze(0))[
+             "last_hidden_state"
+         ].transpose(
+             1, 2
+         )  # .float()
+         codes = vq_model.extract_latent(ssl_content)
+         prompt_semantic = codes[0, 0]
+     t1 = ttime()
+     prompt_language = dict_language[prompt_language]
+     text_language = dict_language[text_language]
+     phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
+     phones1 = cleaned_text_to_sequence(phones1)
+     texts = text.split("\n")
+     audio_opt = []
+     zero_wav = np.zeros(
+         int(hps.data.sampling_rate * 0.3),
+         dtype=np.float16 if is_half == True else np.float32,
+     )
+     for text in texts:
+         # skip blank lines in the target text, which would otherwise cause an error
+         if (len(text.strip()) == 0):
+             continue
+         phones2, word2ph2, norm_text2 = clean_text(text, text_language)
+         phones2 = cleaned_text_to_sequence(phones2)
+         if prompt_language == "zh":
+             bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
+         else:
+             bert1 = torch.zeros(
+                 (1024, len(phones1)),
+                 dtype=torch.float16 if is_half == True else torch.float32,
+             ).to(device)
+         if text_language == "zh":
+             bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
+         else:
+             bert2 = torch.zeros((1024, len(phones2))).to(bert1)
+         bert = torch.cat([bert1, bert2], 1)
+
+         all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
+         bert = bert.to(device).unsqueeze(0)
+         all_phoneme_len = torch.tensor([all_phoneme_ids.shape[-1]]).to(device)
+         prompt = prompt_semantic.unsqueeze(0).to(device)
+         t2 = ttime()
+         with torch.no_grad():
+             # pred_semantic = t2s_model.model.infer(
+             pred_semantic, idx = t2s_model.model.infer_panel(
+                 all_phoneme_ids,
+                 all_phoneme_len,
+                 prompt,
+                 bert,
+                 # prompt_phone_len=ph_offset,
+                 top_k=config["inference"]["top_k"],
+                 early_stop_num=hz * max_sec,
+             )
+         t3 = ttime()
+         # print(pred_semantic.shape,idx)
+         pred_semantic = pred_semantic[:, -idx:].unsqueeze(
+             0
+         )  # .unsqueeze(0)  # mq needs one extra unsqueeze
+         refer = get_spepc(hps, ref_wav_path)  # .to(device)
+         if is_half == True:
+             refer = refer.half().to(device)
+         else:
+             refer = refer.to(device)
+         # audio = vq_model.decode(pred_semantic, all_phoneme_ids, refer).detach().cpu().numpy()[0, 0]
+         audio = (
+             vq_model.decode(
+                 pred_semantic, torch.LongTensor(phones2).to(device).unsqueeze(0), refer
+             )
+             .detach()
+             .cpu()
+             .numpy()[0, 0]
+         )  # try reconstructing without the prompt part
+         audio_opt.append(audio)
+         audio_opt.append(zero_wav)
+     t4 = ttime()
+     print("%.3f\t%.3f\t%.3f\t%.3f" % (t1 - t0, t2 - t1, t3 - t2, t4 - t3))
+     yield hps.data.sampling_rate, (np.concatenate(audio_opt, 0) * 32768).astype(
+         np.int16
+     )
+
+
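+ # Punctuation marks treated as sentence boundaries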
+ splits = {
+     "，",
+     "。",
+     "？",
+     "！",
+     ",",
+     ".",
+     "?",
+     "!",
+     "~",
+     ":",
+     "：",
+     "—",
+     "…",
+ }  # the ellipsis case is not handled specially
+
+
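+ # Split text into sentence-like segments at the punctuation marks above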
+ def split(todo_text):
+     todo_text = todo_text.replace("……", "。").replace("——", "，")
+     if todo_text[-1] not in splits:
+         todo_text += "。"
+     i_split_head = i_split_tail = 0
+     len_text = len(todo_text)
+     todo_texts = []
+     while 1:
+         if i_split_head >= len_text:
+             break  # the text is guaranteed to end with punctuation, so the last segment was already appended
+         if todo_text[i_split_head] in splits:
+             i_split_head += 1
+             todo_texts.append(todo_text[i_split_tail:i_split_head])
+             i_split_tail = i_split_head
+         else:
+             i_split_head += 1
+     return todo_texts
+
+
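+ # Text-splitting helpers exposed in the UI: every 5 sentences (cut1), roughly every 50 characters (cut2), or at Chinese full stops (cut3)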
+ def cut1(inp):
+     inp = inp.strip("\n")
+     inps = split(inp)
+     split_idx = list(range(0, len(inps), 5))
+     split_idx[-1] = None
+     if len(split_idx) > 1:
+         opts = []
+         for idx in range(len(split_idx) - 1):
+             opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
+     else:
+         opts = [inp]
+     return "\n".join(opts)
+
+
+ def cut2(inp):
+     inp = inp.strip("\n")
+     inps = split(inp)
+     if len(inps) < 2:
+         return inp
+     opts = []
+     summ = 0
+     tmp_str = ""
+     for i in range(len(inps)):
+         summ += len(inps[i])
+         tmp_str += inps[i]
+         if summ > 50:
+             summ = 0
+             opts.append(tmp_str)
+             tmp_str = ""
+     if tmp_str != "":
+         opts.append(tmp_str)
+     if len(opts[-1]) < 50:  # if the last chunk is too short, merge it into the previous one
+         opts[-2] = opts[-2] + opts[-1]
+         opts = opts[:-1]
+     return "\n".join(opts)
+
+
+ def cut3(inp):
+     inp = inp.strip("\n")
+     return "\n".join(["%s。" % item for item in inp.strip("。").split("。")])
+
+
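+ # Gradio UI: reference-audio inputs, synthesis controls, and the text-splitting tools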
+ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
+     gr.Markdown(
+         value="本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. <br>如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录<b>LICENSE</b>."
+     )
+     # with gr.Tabs():
+     #     with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
+     with gr.Group():
+         gr.Markdown(value="*请上传并填写参考信息")
+         with gr.Row():
+             inp_ref = gr.Audio(label="请上传参考音频", type="filepath")
+             prompt_text = gr.Textbox(label="参考音频的文本", value="")
+             prompt_language = gr.Dropdown(
+                 label="参考音频的语种", choices=["中文", "英文", "日文"], value="中文"
+             )
+         gr.Markdown(value="*请填写需要合成的目标文本")
+         with gr.Row():
+             text = gr.Textbox(label="需要合成的文本", value="")
+             text_language = gr.Dropdown(
+                 label="需要合成的语种", choices=["中文", "英文", "日文"], value="中文"
+             )
+             inference_button = gr.Button("合成语音", variant="primary")
+             output = gr.Audio(label="输出的语音")
+         inference_button.click(
+             get_tts_wav,
+             [inp_ref, prompt_text, prompt_language, text, text_language],
+             [output],
+         )
+
+         gr.Markdown(value="文本切分工具。太长的文本合成出来效果不一定好，所以太长建议先切。合成会根据文本的换行分开合成再拼起来。")
+         with gr.Row():
+             text_inp = gr.Textbox(label="需要合成的切分前文本", value="")
+             button1 = gr.Button("凑五句一切", variant="primary")
+             button2 = gr.Button("凑50字一切", variant="primary")
+             button3 = gr.Button("按中文句号。切", variant="primary")
+             text_opt = gr.Textbox(label="切分后文本", value="")
+             button1.click(cut1, [text_inp], [text_opt])
+             button2.click(cut2, [text_inp], [text_opt])
+             button3.click(cut3, [text_inp], [text_opt])
+     gr.Markdown(value="后续将支持混合语种编码文本输入。")
+
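+ # Launch the web server on the configured port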
+ app.queue(concurrency_count=511, max_size=1022).launch(
+     server_name="0.0.0.0",
+     inbrowser=True,
+     server_port=infer_ttswebui,
+     quiet=True,
+ )