raoyonghui commited on
Commit
f7428c0
·
1 Parent(s): a8db66d

init whisper model when inference

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -18,17 +18,23 @@ from utils.util import load_config
18
  from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
19
 
20
  from transformers import SeamlessM4TFeatureExtractor
 
21
 
22
- import whisper
23
 
24
  processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
25
-
26
  device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
27
- whisper_model = whisper.load_model("turbo")
 
 
 
 
28
 
29
  def detect_speech_language(speech_file):
 
 
 
 
30
  # load audio and pad/trim it to fit 30 seconds
31
- whisper_model = whisper.load_model("turbo")
32
  audio = whisper.load_audio(speech_file)
33
  audio = whisper.pad_or_trim(audio)
34
 
@@ -46,6 +52,10 @@ def get_prompt_text(speech_16k, language):
46
  shot_prompt_text = ""
47
  short_prompt_end_ts = 0.0
48
 
 
 
 
 
49
  asr_result = whisper_model.transcribe(speech_16k, language=language)
50
  full_prompt_text = asr_result["text"] # whisper asr result
51
  #text = asr_result["segments"][0]["text"] # whisperx asr result
@@ -301,7 +311,6 @@ def load_models():
301
  def maskgct_inference(
302
  prompt_speech_path,
303
  target_text,
304
- target_language="en",
305
  target_len=None,
306
  n_timesteps=25,
307
  cfg=2.5,
@@ -320,6 +329,8 @@ def maskgct_inference(
320
  # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
321
  speech = speech[0: int(shot_prompt_end_ts * 24000)]
322
  speech_16k = speech_16k[0: int(shot_prompt_end_ts*16000)]
 
 
323
  combine_semantic_code, _ = text2semantic(
324
  device,
325
  speech_16k,
@@ -351,19 +362,19 @@ def inference(
351
  target_text,
352
  target_len,
353
  n_timesteps,
354
- target_language,
355
  ):
356
- save_path = "./output/output.wav"
 
357
  os.makedirs("./output", exist_ok=True)
358
  recovered_audio = maskgct_inference(
359
  prompt_wav,
360
  target_text,
361
- target_language,
362
  target_len=target_len,
363
  n_timesteps=int(n_timesteps),
364
  device=device,
365
  )
366
  sf.write(save_path, recovered_audio, 24000)
 
367
  return save_path
368
 
369
  # Load models once
@@ -394,7 +405,6 @@ iface = gr.Interface(
394
  gr.Slider(
395
  label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
396
  ),
397
- gr.Dropdown(label="Target Language", choices=language_list, value="en"),
398
  ],
399
  outputs=gr.Audio(label="Generated Audio"),
400
  title="MaskGCT TTS Demo",
 
18
  from models.tts.maskgct.g2p.g2p_generation import g2p, chn_eng_g2p
19
 
20
  from transformers import SeamlessM4TFeatureExtractor
21
+ import py3langid as langid
22
 
 
23
 
24
  processor = SeamlessM4TFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0")
 
25
  device = torch.device("cuda" if torch.cuda.is_available() else "CPU")
26
+ whisper_model = None
27
+ output_file_name_idx = 0
28
+
29
+ def detect_text_language(text):
30
+ return langid.classify(text)[0]
31
 
32
  def detect_speech_language(speech_file):
33
+ import whisper
34
+ global whisper_model
35
+ if whisper_model == None:
36
+ whisper_model = whisper.load_model("turbo")
37
  # load audio and pad/trim it to fit 30 seconds
 
38
  audio = whisper.load_audio(speech_file)
39
  audio = whisper.pad_or_trim(audio)
40
 
 
52
  shot_prompt_text = ""
53
  short_prompt_end_ts = 0.0
54
 
55
+ import whisper
56
+ global whisper_model
57
+ if whisper_model == None:
58
+ whisper_model = whisper.load_model("turbo")
59
  asr_result = whisper_model.transcribe(speech_16k, language=language)
60
  full_prompt_text = asr_result["text"] # whisper asr result
61
  #text = asr_result["segments"][0]["text"] # whisperx asr result
 
311
  def maskgct_inference(
312
  prompt_speech_path,
313
  target_text,
 
314
  target_len=None,
315
  n_timesteps=25,
316
  cfg=2.5,
 
329
  # use the first 4+ seconds wav as the prompt in case the prompt wav is too long
330
  speech = speech[0: int(shot_prompt_end_ts * 24000)]
331
  speech_16k = speech_16k[0: int(shot_prompt_end_ts*16000)]
332
+
333
+ target_language = detect_text_language(target_text)
334
  combine_semantic_code, _ = text2semantic(
335
  device,
336
  speech_16k,
 
362
  target_text,
363
  target_len,
364
  n_timesteps,
 
365
  ):
366
+ global output_file_name_idx
367
+ save_path = f"./output/output_{output_file_name_idx}.wav"
368
  os.makedirs("./output", exist_ok=True)
369
  recovered_audio = maskgct_inference(
370
  prompt_wav,
371
  target_text,
 
372
  target_len=target_len,
373
  n_timesteps=int(n_timesteps),
374
  device=device,
375
  )
376
  sf.write(save_path, recovered_audio, 24000)
377
+ output_file_name_idx = (output_file_name_idx + 1) % 10
378
  return save_path
379
 
380
  # Load models once
 
405
  gr.Slider(
406
  label="Number of Timesteps", minimum=15, maximum=100, value=25, step=1
407
  ),
 
408
  ],
409
  outputs=gr.Audio(label="Generated Audio"),
410
  title="MaskGCT TTS Demo",