Approximetal committed
Commit 4508345 · verified · Parent: c7f69cf

Update gradio_mix.py

Files changed (1): gradio_mix.py (+16 -49)
gradio_mix.py CHANGED
@@ -68,15 +68,7 @@ def _pick_device():
     return "cuda" if torch.cuda.is_available() else "cpu"
 
 device = _pick_device()
-# For WhisperX ASR:
-# - On Spaces we always construct the pipeline lazily inside @spaces.GPU
-#   functions, so keep the default "cpu" here to avoid touching CUDA in
-#   the main process.
-# - Elsewhere prefer CUDA if available.
-if IS_SPACES:
-    ASR_DEVICE = "cpu"
-else:
-    ASR_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
 whisper_model, align_model = None, None
 tts_edit_model = None
 
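The comments deleted above describe the constraint behind most of this commit: on ZeroGPU Spaces the main process must never initialize CUDA, so anything CUDA-related is built lazily inside functions decorated with @spaces.GPU. A minimal sketch of that pattern, assuming the Hugging Face `spaces` package and a hypothetical `build_model` loader:

    import spaces
    import torch

    _model = None  # created lazily; nothing touches CUDA at import time

    @spaces.GPU  # ZeroGPU attaches a GPU to the process only for this call
    def infer(audio):
        global _model
        if _model is None:
            _model = build_model().to("cuda")  # hypothetical loader
        return _model(audio)
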
@@ -139,31 +131,16 @@ class UVR5:
         )
 
         uvr5_model = Inference(model_data, device)
-        # On HF Spaces with stateless GPU, we must not initialize CUDA in the
-        # main process. The heavy UVR5 loading happens lazily inside
-        # @spaces.GPU functions; this guard is kept only for the CPU path to
-        # avoid any accidental CUDA init.
-        if IS_SPACES and device == "cpu":
-            orig_is_available = torch.cuda.is_available
-            torch.cuda.is_available = lambda: False
-            try:
-                uvr5_model.load_model(model_path, 1)
-            finally:
-                torch.cuda.is_available = orig_is_available
-        else:
-            uvr5_model.load_model(model_path, 1)
+        uvr5_model.load_model(model_path, 1)
 
         self.model = uvr5_model
         self.device = device
         return self.model
 
     def denoise(self, audio_info):
-        # Prefer GPU if available; on Spaces this runs inside @spaces.GPU so
-        # CUDA can be safely initialized here.
-        device = "cuda" if torch.cuda.is_available() else "cpu"
-        model = self.load_model(device=device)
+        model = self.load_model(device="cpu")
         input_audio = load_wav(audio_info, sr=44100, channel=2)
-        output_audio = model.demix_base({0:input_audio.squeeze()}, is_match_mix=False, device=device)
+        output_audio = model.demix_base({0:input_audio.squeeze()}, is_match_mix=False, device="cpu")
         # transform = torchaudio.transforms.Resample(44100, 16000)
         # output_audio = transform(output_audio)
         return output_audio.squeeze().T.cpu().numpy(), 44100
 
@@ -450,9 +427,13 @@ class MMSAlignModel:
 class WhisperxModel:
     def __init__(self, model_name):
         # Lazily construct the WhisperX pipeline so that on Spaces we only
-        # touch CUDA inside @spaces.GPU workers.
+        # touch CUDA inside spaces.GPU workers.
         self.model_name = model_name
-        self.model = None
+        self.model = None
+        if IS_SPACES and torch.cuda.is_available():
+            self.device = "cuda"
+        else:
+            self.device = "cpu"
 
     def _ensure_model(self):
         if self.model is not None:
 
@@ -461,19 +442,11 @@ class WhisperxModel:
 
         prompt = None # "This might be a blend of Simplified Chinese and English speech, do not translate, only transcription be allowed."
 
-        # On Spaces, this will be called from within @spaces.GPU so we can
-        # safely move the ASR to CUDA if available. Locally we respect the
-        # ASR_DEVICE hint.
-        if IS_SPACES:
-            asr_device = "cuda" if torch.cuda.is_available() else "cpu"
-        else:
-            asr_device = ASR_DEVICE
-
         # Use the lighter Silero VAD backend to avoid pyannote checkpoints
         # and their PyTorch 2.6 `weights_only` pickling issues.
         self.model = load_model(
             self.model_name,
-            asr_device,
+            self.device,
             compute_type="float32",
             asr_options={
                 "suppress_numerals": False,
 
@@ -700,7 +673,7 @@ def get_transcribe_state(segments):
         "word_bounds": [f"{word['start']} {word['word']} {word['end']}" for word in segments["words"]]
     }
 
-@spaces.GPU(duration=240)
+@spaces.GPU
 @torch.no_grad()
 @torch.inference_mode()
 def transcribe(seed, audio_info):
 
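Dropping duration=240 here (and in `run` below) falls back to ZeroGPU's default GPU allocation window, 60 seconds at the time of writing, so calls that previously fit in 240 s can now hit the worker time limit. The two forms side by side:

    @spaces.GPU                # default window (60 s at the time of writing)
    def transcribe(seed, audio_info): ...

    @spaces.GPU(duration=240)  # previous form: request up to 240 s per call
    def transcribe(seed, audio_info): ...
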
@@ -719,9 +692,6 @@ def transcribe(seed, audio_info):
         state
     ]
 
-@spaces.GPU(duration=240)
-@torch.no_grad()
-@torch.inference_mode()
 def align(transcript, audio_info, state):
     lang = state["segments"]["lang"]
     # print("realign: ", transcript, state)
 
@@ -747,9 +717,6 @@ def align(transcript, audio_info, state):
     ]
 
 
-@spaces.GPU(duration=240)
-@torch.no_grad()
-@torch.inference_mode()
 def denoise(audio_info):
     # Denoiser can be relatively heavy (especially UVR5), so schedule it on
     # GPU workers when running on HF Spaces.
 
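Note that the comment kept inside `denoise` still says the denoiser is scheduled on GPU workers, but with the decorators removed it now runs undecorated in the main process; combined with the UVR5 hunk above, which pins loading and demixing to "cpu", the denoise path is CPU-only after this commit.
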
@@ -769,7 +736,7 @@ def get_output_audio(audio_tensors, sr):
     print("save result:", result.shape)
     # wavfile.write(os.path.join(TMP_PATH, "output.wav"), sr, result)
     return (int(sr), result)
-
+
 
 def get_edit_audio_part(audio_info, edit_start, edit_end):
     sr, raw_wav = audio_info
 
@@ -796,7 +763,7 @@ def replace_numbers_with_words(sentence, lang="en"):
     return re.sub(r'\b\d+\b', replace_with_words, sentence) # Regular expression that matches numbers
 
 
-@spaces.GPU(duration=240)
+@spaces.GPU
 @torch.no_grad()
 @torch.inference_mode()
 def run(seed, nfe_step, speed, cfg_strength, sway_sampling_coef, ref_ratio,
 
@@ -1069,7 +1036,7 @@ def get_app():
             )
             denoise_model_choice = gr.Radio(label="Denoise Model", scale=2, value="UVR5", choices=["UVR5", "DeepFilterNet"]) # "830M", "330M_TTSEnhanced", "830M_TTSEnhanced"])
             # whisper_backend_choice = gr.Radio(label="Whisper backend", value="", choices=["whisperX", "whisper"])
-            whisper_model_choice = gr.Radio(label="Whisper model", scale=3, value="small", choices=["base", "small", "medium", "large"])
+            whisper_model_choice = gr.Radio(label="Whisper model", scale=3, value="medium", choices=["base", "small", "medium", "large"])
             align_model_choice = gr.Radio(label="Forced alignment model", scale=2, value="MMS", choices=["whisperX", "MMS"], visible=False)
 
             with gr.Row():
 
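For scale, OpenAI's Whisper checkpoints are roughly 74M (base), 244M (small), 769M (medium), and 1.55B (large) parameters, so the new `medium` default spends about three times the ASR compute of `small` in exchange for noticeably better multilingual accuracy.
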
@@ -1174,7 +1141,7 @@ def get_app():
             with gr.Row():
                 nfe_step = gr.Number(
                     label="NFE Step",
-                    value=32,
+                    value=64,
                     precision=0,
                     info="Number of function evaluations (sampling steps).",
                 )
 
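On the NFE default: in a flow-matching sampler, nfe_step is the number of ODE integration steps, so sampling cost grows linearly with it; 64 steps cost roughly twice as much as 32. A minimal sketch of the idea, with `velocity_model` as a hypothetical stand-in for the learned network:

    import torch

    def sample(velocity_model, noise, nfe_step=64):
        # Plain Euler integration of the learned velocity field from t=0 to t=1;
        # each iteration is one "function evaluation" of the network.
        x = noise
        dt = 1.0 / nfe_step
        for i in range(nfe_step):
            t = torch.full((x.shape[0],), i * dt)
            x = x + dt * velocity_model(x, t)
        return x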