XzJosh commited on
Commit
a4f525f
1 Parent(s): 3b37c27

Upload infer.py

Browse files
Files changed (1) hide show
  1. infer.py +25 -8
infer.py CHANGED
@@ -85,22 +85,22 @@ def get_text(text, language_str, hps, device):
85
  for i in range(len(word2ph)):
86
  word2ph[i] = word2ph[i] * 2
87
  word2ph[0] += 1
88
- bert = get_bert(norm_text, word2ph, language_str, device)
89
  del word2ph
90
- assert bert.shape[-1] == len(phone), phone
91
 
92
  if language_str == "ZH":
93
- bert = bert
94
  ja_bert = torch.zeros(1024, len(phone))
95
  en_bert = torch.zeros(1024, len(phone))
96
  elif language_str == "JP":
97
  bert = torch.zeros(1024, len(phone))
98
- ja_bert = bert
99
  en_bert = torch.zeros(1024, len(phone))
100
  elif language_str == "EN":
101
  bert = torch.zeros(1024, len(phone))
102
  ja_bert = torch.zeros(1024, len(phone))
103
- en_bert = bert
104
  else:
105
  raise ValueError("language_str should be ZH, JP or EN")
106
 
@@ -125,6 +125,8 @@ def infer(
125
  hps,
126
  net_g,
127
  device,
 
 
128
  ):
129
  # 支持中日双语版本
130
  inferMap_V2 = {
@@ -172,6 +174,20 @@ def infer(
172
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
173
  text, language, hps, device
174
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with torch.no_grad():
176
  x_tst = phones.to(device).unsqueeze(0)
177
  tones = tones.to(device).unsqueeze(0)
@@ -201,10 +217,11 @@ def infer(
201
  .float()
202
  .numpy()
203
  )
204
- del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers
205
- torch.cuda.empty_cache()
 
206
  return audio
207
-
208
 
209
  def infer_multilang(
210
  text,
 
85
  for i in range(len(word2ph)):
86
  word2ph[i] = word2ph[i] * 2
87
  word2ph[0] += 1
88
+ bert_ori = get_bert(norm_text, word2ph, language_str, device)
89
  del word2ph
90
+ assert bert_ori.shape[-1] == len(phone), phone
91
 
92
  if language_str == "ZH":
93
+ bert = bert_ori
94
  ja_bert = torch.zeros(1024, len(phone))
95
  en_bert = torch.zeros(1024, len(phone))
96
  elif language_str == "JP":
97
  bert = torch.zeros(1024, len(phone))
98
+ ja_bert = bert_ori
99
  en_bert = torch.zeros(1024, len(phone))
100
  elif language_str == "EN":
101
  bert = torch.zeros(1024, len(phone))
102
  ja_bert = torch.zeros(1024, len(phone))
103
+ en_bert = bert_ori
104
  else:
105
  raise ValueError("language_str should be ZH, JP or EN")
106
 
 
125
  hps,
126
  net_g,
127
  device,
128
+ skip_start=False,
129
+ skip_end=False,
130
  ):
131
  # 支持中日双语版本
132
  inferMap_V2 = {
 
174
  bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
175
  text, language, hps, device
176
  )
177
+ if skip_start:
178
+ phones = phones[1:]
179
+ tones = tones[1:]
180
+ lang_ids = lang_ids[1:]
181
+ bert = bert[:, 1:]
182
+ ja_bert = ja_bert[:, 1:]
183
+ en_bert = en_bert[:, 1:]
184
+ if skip_end:
185
+ phones = phones[:-1]
186
+ tones = tones[:-1]
187
+ lang_ids = lang_ids[:-1]
188
+ bert = bert[:, :-1]
189
+ ja_bert = ja_bert[:, :-1]
190
+ en_bert = en_bert[:, :-1]
191
  with torch.no_grad():
192
  x_tst = phones.to(device).unsqueeze(0)
193
  tones = tones.to(device).unsqueeze(0)
 
217
  .float()
218
  .numpy()
219
  )
220
+ del x_tst, tones, lang_ids, bert, x_tst_lengths, speakers, ja_bert, en_bert
221
+ if torch.cuda.is_available():
222
+ torch.cuda.empty_cache()
223
  return audio
224
+
225
 
226
  def infer_multilang(
227
  text,