Plachta committed on
Commit
405566d
1 Parent(s): cf8f15a

Added unconditional generation

Browse files
Files changed (1) hide show
  1. app.py +43 -37
app.py CHANGED
@@ -180,29 +180,49 @@ def make_npz_prompt(name, uploaded_audio, recorded_audio, transcript_content):
180
  def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
181
  if len(text) > 150:
182
  return "Rejected, Text too long (should be less than 150 characters)", None
183
- audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
184
- sr, wav_pr = audio_prompt
185
- if len(wav_pr) / sr > 15:
186
- return "Rejected, Audio too long (should be less than 15 seconds)", None
187
- if not isinstance(wav_pr, torch.FloatTensor):
188
- wav_pr = torch.FloatTensor(wav_pr)
189
- if wav_pr.abs().max() > 1:
190
- wav_pr /= wav_pr.abs().max()
191
- if wav_pr.size(-1) == 2:
192
- wav_pr = wav_pr[:, 0]
193
- if wav_pr.ndim == 1:
194
- wav_pr = wav_pr.unsqueeze(0)
195
- assert wav_pr.ndim and wav_pr.size(0) == 1
196
-
197
- if transcript_content == "":
198
- lang_pr, text_pr = transcribe_one(wav_pr, sr)
199
- lang_token = lang2token[lang_pr]
200
- text_pr = lang_token + text_pr + lang_token
201
  else:
202
- lang_pr = langid.classify(str(transcript_content))[0]
203
- text_pr = transcript_content.replace("\n", "")
204
- lang_token = lang2token[lang_pr]
205
- text_pr = lang_token + text_pr + lang_token
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
 
207
  if language == 'auto-detect':
208
  lang_token = lang2token[langid.classify(text)[0]]
@@ -212,13 +232,6 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
212
  text = text.replace("\n", "")
213
  text = lang_token + text + lang_token
214
 
215
- if lang_pr not in ['ja', 'zh', 'en']:
216
- return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
217
-
218
- # tokenize audio
219
- encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
220
- audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
221
-
222
  # tokenize text
223
  logging.info(f"synthesize text: {text}")
224
  phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
@@ -228,14 +241,7 @@ def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt,
228
  ]
229
  )
230
 
231
- enroll_x_lens = None
232
- if text_pr:
233
- text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
234
- text_prompts, enroll_x_lens = text_collater(
235
- [
236
- text_prompts
237
- ]
238
- )
239
  text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
240
  text_tokens_lens += enroll_x_lens
241
  lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]
 
180
  def infer_from_audio(text, language, accent, audio_prompt, record_audio_prompt, transcript_content):
181
  if len(text) > 150:
182
  return "Rejected, Text too long (should be less than 150 characters)", None
183
+ if audio_prompt is None and record_audio_prompt is None:
184
+ audio_prompts = torch.zeros([1, 0, NUM_QUANTIZERS]).type(torch.int32).to(device)
185
+ text_prompts = torch.zeros([1, 0]).type(torch.int32)
186
+ lang_pr = language if language != 'mix' else 'en'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  else:
188
+ audio_prompt = audio_prompt if audio_prompt is not None else record_audio_prompt
189
+ sr, wav_pr = audio_prompt
190
+ if len(wav_pr) / sr > 15:
191
+ return "Rejected, Audio too long (should be less than 15 seconds)", None
192
+ if not isinstance(wav_pr, torch.FloatTensor):
193
+ wav_pr = torch.FloatTensor(wav_pr)
194
+ if wav_pr.abs().max() > 1:
195
+ wav_pr /= wav_pr.abs().max()
196
+ if wav_pr.size(-1) == 2:
197
+ wav_pr = wav_pr[:, 0]
198
+ if wav_pr.ndim == 1:
199
+ wav_pr = wav_pr.unsqueeze(0)
200
+ assert wav_pr.ndim and wav_pr.size(0) == 1
201
+
202
+ if transcript_content == "":
203
+ lang_pr, text_pr = transcribe_one(wav_pr, sr)
204
+ lang_token = lang2token[lang_pr]
205
+ text_pr = lang_token + text_pr + lang_token
206
+ else:
207
+ lang_pr = langid.classify(str(transcript_content))[0]
208
+ text_pr = transcript_content.replace("\n", "")
209
+ if lang_pr not in ['ja', 'zh', 'en']:
210
+ return f"Reference audio must be a speech of one of model-supported languages, got {lang_pr} instead", None
211
+ lang_token = lang2token[lang_pr]
212
+ text_pr = lang_token + text_pr + lang_token
213
+
214
+ # tokenize audio
215
+ encoded_frames = tokenize_audio(audio_tokenizer, (wav_pr, sr))
216
+ audio_prompts = encoded_frames[0][0].transpose(2, 1).to(device)
217
+
218
+ enroll_x_lens = None
219
+ if text_pr:
220
+ text_prompts, _ = text_tokenizer.tokenize(text=f"{text_pr}".strip())
221
+ text_prompts, enroll_x_lens = text_collater(
222
+ [
223
+ text_prompts
224
+ ]
225
+ )
226
 
227
  if language == 'auto-detect':
228
  lang_token = lang2token[langid.classify(text)[0]]
 
232
  text = text.replace("\n", "")
233
  text = lang_token + text + lang_token
234
 
 
 
 
 
 
 
 
235
  # tokenize text
236
  logging.info(f"synthesize text: {text}")
237
  phone_tokens, langs = text_tokenizer.tokenize(text=f"_{text}".strip())
 
241
  ]
242
  )
243
 
244
+
 
 
 
 
 
 
 
245
  text_tokens = torch.cat([text_prompts, text_tokens], dim=-1)
246
  text_tokens_lens += enroll_x_lens
247
  lang = lang if accent == "no-accent" else token2lang[langdropdown2token[accent]]