XzJosh commited on
Commit
48c04e9
·
verified ·
1 Parent(s): 2fe559d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +118 -36
app.py CHANGED
@@ -1,5 +1,13 @@
1
- import os,re
2
- import gradio as gr
 
 
 
 
 
 
 
 
3
 
4
  gpt_path = os.environ.get(
5
  "gpt_path", "models/Taffy/Taffy-e5.ckpt"
@@ -49,7 +57,6 @@ else:
49
  bert_model = bert_model.to(device)
50
 
51
 
52
- # bert_model=bert_model.to(device)
53
  def get_bert_feature(text, word2ph):
54
  with torch.no_grad():
55
  inputs = tokenizer(text, return_tensors="pt")
@@ -63,15 +70,8 @@ def get_bert_feature(text, word2ph):
63
  repeat_feature = res[i].repeat(word2ph[i], 1)
64
  phone_level_feature.append(repeat_feature)
65
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
66
- # if(is_half==True):phone_level_feature=phone_level_feature.half()
67
  return phone_level_feature.T
68
 
69
-
70
- n_semantic = 1024
71
-
72
- dict_s2=torch.load(sovits_path,map_location="cpu")
73
- hps=dict_s2["config"]
74
-
75
  class DictToAttrRecursive(dict):
76
  def __init__(self, input_dict):
77
  super().__init__(input_dict)
@@ -99,12 +99,6 @@ class DictToAttrRecursive(dict):
99
  except KeyError:
100
  raise AttributeError(f"Attribute {item} not found")
101
 
102
-
103
- hps = DictToAttrRecursive(hps)
104
-
105
- hps.model.semantic_frame_rate = "25hz"
106
- dict_s1 = torch.load(gpt_path, map_location="cpu")
107
- config = dict_s1["config"]
108
  ssl_model = cnhubert.get_model()
109
  if is_half == True:
110
  ssl_model = ssl_model.half().to(device)
@@ -123,7 +117,8 @@ def change_sovits_weights(sovits_path):
123
  n_speakers=hps.data.n_speakers,
124
  **hps.model
125
  )
126
- del vq_model.enc_q
 
127
  if is_half == True:
128
  vq_model = vq_model.half().to(device)
129
  else:
@@ -165,10 +160,88 @@ def get_spepc(hps, filename):
165
  return spec
166
 
167
 
168
- dict_language = {"中文": "zh", "英文": "en", "日文": "ja"}
169
-
170
-
171
- def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
172
  ref_wav_path = text_to_audio_mappings.get(selected_text, "")
173
  if not ref_wav_path:
174
  print("Audio file not found for the selected text.")
@@ -201,28 +274,37 @@ def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language
201
  t1 = ttime()
202
  prompt_language = dict_language[prompt_language]
203
  text_language = dict_language[text_language]
204
- phones1, word2ph1, norm_text1 = clean_text(prompt_text, prompt_language)
205
- phones1 = cleaned_text_to_sequence(phones1)
206
- texts = text.split("\n")
 
 
 
 
 
 
 
 
 
207
  audio_opt = []
 
 
 
 
208
 
209
  for text in texts:
210
  # 解决输入目标文本的空行导致报错的问题
211
  if (len(text.strip()) == 0):
212
  continue
213
- phones2, word2ph2, norm_text2 = clean_text(text, text_language)
214
- phones2 = cleaned_text_to_sequence(phones2)
215
- if prompt_language == "zh":
216
- bert1 = get_bert_feature(norm_text1, word2ph1).to(device)
217
  else:
218
- bert1 = torch.zeros(
219
- (1024, len(phones1)),
220
- dtype=torch.float16 if is_half == True else torch.float32,
221
- ).to(device)
222
- if text_language == "zh":
223
- bert2 = get_bert_feature(norm_text2, word2ph2).to(device)
224
  else:
225
- bert2 = torch.zeros((1024, len(phones2))).to(bert1)
226
  bert = torch.cat([bert1, bert2], 1)
227
 
228
  all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
@@ -380,7 +462,7 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
380
  ### <center>⚠️在线端不稳定且生成速度较慢,强烈建议下载模型本地推理!\n
381
  """)
382
  # with gr.Tabs():
383
- # with gr.TabItem(i18n("伴奏人声分离&去混响&去回声")):
384
  with gr.Group():
385
  gr.Markdown(value="*参考音频选择(必选)")
386
  with gr.Row():
 
1
+ import os,re,logging
2
+ logging.getLogger("markdown_it").setLevel(logging.ERROR)
3
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
4
+ logging.getLogger("httpcore").setLevel(logging.ERROR)
5
+ logging.getLogger("httpx").setLevel(logging.ERROR)
6
+ logging.getLogger("asyncio").setLevel(logging.ERROR)
7
+
8
+ logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
9
+ logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
10
+ import pdb
11
 
12
  gpt_path = os.environ.get(
13
  "gpt_path", "models/Taffy/Taffy-e5.ckpt"
 
57
  bert_model = bert_model.to(device)
58
 
59
 
 
60
  def get_bert_feature(text, word2ph):
61
  with torch.no_grad():
62
  inputs = tokenizer(text, return_tensors="pt")
 
70
  repeat_feature = res[i].repeat(word2ph[i], 1)
71
  phone_level_feature.append(repeat_feature)
72
  phone_level_feature = torch.cat(phone_level_feature, dim=0)
 
73
  return phone_level_feature.T
74
 
 
 
 
 
 
 
75
  class DictToAttrRecursive(dict):
76
  def __init__(self, input_dict):
77
  super().__init__(input_dict)
 
99
  except KeyError:
100
  raise AttributeError(f"Attribute {item} not found")
101
 
 
 
 
 
 
 
102
  ssl_model = cnhubert.get_model()
103
  if is_half == True:
104
  ssl_model = ssl_model.half().to(device)
 
117
  n_speakers=hps.data.n_speakers,
118
  **hps.model
119
  )
120
+ if("pretrained"not in sovits_path):
121
+ del vq_model.enc_q
122
  if is_half == True:
123
  vq_model = vq_model.half().to(device)
124
  else:
 
160
  return spec
161
 
162
 
163
+ dict_language={
164
+ ("中文"):"zh",
165
+ ("英文"):"en",
166
+ ("日文"):"ja"
167
+ }
168
+
169
+
170
+ def splite_en_inf(sentence, language):
171
+ pattern = re.compile(r'[a-zA-Z. ]+')
172
+ textlist = []
173
+ langlist = []
174
+ pos = 0
175
+ for match in pattern.finditer(sentence):
176
+ start, end = match.span()
177
+ if start > pos:
178
+ textlist.append(sentence[pos:start])
179
+ langlist.append(language)
180
+ textlist.append(sentence[start:end])
181
+ langlist.append("en")
182
+ pos = end
183
+ if pos < len(sentence):
184
+ textlist.append(sentence[pos:])
185
+ langlist.append(language)
186
+
187
+ return textlist, langlist
188
+
189
+
190
+ def clean_text_inf(text, language):
191
+ phones, word2ph, norm_text = clean_text(text, language)
192
+ phones = cleaned_text_to_sequence(phones)
193
+
194
+ return phones, word2ph, norm_text
195
+ def get_bert_inf(phones, word2ph, norm_text, language):
196
+ if language == "zh":
197
+ bert = get_bert_feature(norm_text, word2ph).to(device)
198
+ else:
199
+ bert = torch.zeros(
200
+ (1024, len(phones)),
201
+ dtype=torch.float16 if is_half == True else torch.float32,
202
+ ).to(device)
203
+
204
+ return bert
205
+
206
+
207
+ def nonen_clean_text_inf(text, language):
208
+ textlist, langlist = splite_en_inf(text, language)
209
+ phones_list = []
210
+ word2ph_list = []
211
+ norm_text_list = []
212
+ for i in range(len(textlist)):
213
+ lang = langlist[i]
214
+ phones, word2ph, norm_text = clean_text_inf(textlist[i], lang)
215
+ phones_list.append(phones)
216
+ if lang == "en" or "ja":
217
+ pass
218
+ else:
219
+ word2ph_list.append(word2ph)
220
+ norm_text_list.append(norm_text)
221
+ print(word2ph_list)
222
+ phones = sum(phones_list, [])
223
+ word2ph = sum(word2ph_list, [])
224
+ norm_text = ' '.join(norm_text_list)
225
+
226
+ return phones, word2ph, norm_text
227
+
228
+
229
+ def nonen_get_bert_inf(text, language):
230
+ textlist, langlist = splite_en_inf(text, language)
231
+ print(textlist)
232
+ print(langlist)
233
+ bert_list = []
234
+ for i in range(len(textlist)):
235
+ text = textlist[i]
236
+ lang = langlist[i]
237
+ phones, word2ph, norm_text = clean_text_inf(text, lang)
238
+ bert = get_bert_inf(phones, word2ph, norm_text, lang)
239
+ bert_list.append(bert)
240
+ bert = torch.cat(bert_list, dim=1)
241
+
242
+ return bert
243
+
244
+ def get_tts_wav(selected_text, prompt_text, prompt_language, text, text_language,how_to_cut=("不切")):
245
  ref_wav_path = text_to_audio_mappings.get(selected_text, "")
246
  if not ref_wav_path:
247
  print("Audio file not found for the selected text.")
 
274
  t1 = ttime()
275
  prompt_language = dict_language[prompt_language]
276
  text_language = dict_language[text_language]
277
+
278
+ if prompt_language == "en":
279
+ phones1, word2ph1, norm_text1 = clean_text_inf(prompt_text, prompt_language)
280
+ else:
281
+ phones1, word2ph1, norm_text1 = nonen_clean_text_inf(prompt_text, prompt_language)
282
+ if(how_to_cut==("凑五句一切")):text=cut1(text)
283
+ elif(how_to_cut==("凑50字一切")):text=cut2(text)
284
+ elif(how_to_cut==("按中文句号。切")):text=cut3(text)
285
+ elif(how_to_cut==("按英文句号.切")):text=cut4(text)
286
+ text = text.replace("\n\n","\n").replace("\n\n","\n").replace("\n\n","\n")
287
+ if(text[-1]not in splits):text+="。"if text_language!="en"else "."
288
+ texts=text.split("\n")
289
  audio_opt = []
290
+ if prompt_language == "en":
291
+ bert1 = get_bert_inf(phones1, word2ph1, norm_text1, prompt_language)
292
+ else:
293
+ bert1 = nonen_get_bert_inf(prompt_text, prompt_language)
294
 
295
  for text in texts:
296
  # 解决输入目标文本的空行导致报错的问题
297
  if (len(text.strip()) == 0):
298
  continue
299
+ if text_language == "en":
300
+ phones2, word2ph2, norm_text2 = clean_text_inf(text, text_language)
 
 
301
  else:
302
+ phones2, word2ph2, norm_text2 = nonen_clean_text_inf(text, text_language)
303
+
304
+ if text_language == "en":
305
+ bert2 = get_bert_inf(phones2, word2ph2, norm_text2, text_language)
 
 
306
  else:
307
+ bert2 = nonen_get_bert_inf(text, text_language)
308
  bert = torch.cat([bert1, bert2], 1)
309
 
310
  all_phoneme_ids = torch.LongTensor(phones1 + phones2).to(device).unsqueeze(0)
 
462
  ### <center>⚠️在线端不稳定且生成速度较慢,强烈建议下载模型本地推理!\n
463
  """)
464
  # with gr.Tabs():
465
+
466
  with gr.Group():
467
  gr.Markdown(value="*参考音频选择(必选)")
468
  with gr.Row():