Approximetal committed
Commit db5f9bf · verified · 1 Parent(s): 484e4a0

Update inference_gradio.py

Files changed (1): inference_gradio.py  +51 -95
inference_gradio.py CHANGED
@@ -16,29 +16,15 @@ from cached_path import cached_path
 
 from lemas_tts.api import TTS, PRETRAINED_ROOT, CKPTS_ROOT
 
-# Global variables
-tts_api = None
-last_checkpoint = ""
-last_device = ""
-last_ema = None
-
-# Detect whether we are running inside a HF Space with stateless GPU.
-IS_SPACES = os.getenv("SYSTEM") == "spaces"
-
-# Device detection
-if IS_SPACES:
-    # On Spaces main process we must not initialize CUDA; keep TTS on CPU.
-    device = "cpu"
-else:
-    device = (
-        "cuda"
-        if torch.cuda.is_available()
-        else "xpu"
-        if torch.xpu.is_available()
-        else "mps"
-        if torch.backends.mps.is_available()
-        else "cpu"
-    )
+device = (
+    "cuda"
+    if torch.cuda.is_available()
+    else "xpu"
+    if torch.xpu.is_available()
+    else "mps"
+    if torch.backends.mps.is_available()
+    else "cpu"
+)
 
 REPO_ROOT = Path(__file__).resolve().parent
 
@@ -72,7 +58,7 @@ class UVR5:
         sys.path.append(self.code_dir)
 
         # Reuse an already-loaded model if it matches the requested device.
-        if self.model is not None and self.device == device:
+        if self.model is not None:
             return self.model
 
         from multiprocess_cuda_infer import ModelData, Inference
@@ -85,42 +71,25 @@
             model_path=model_path,
             audio_path=self.model_dir,
             result_path=self.model_dir,
-            device=device,
+            device="cpu",
             process_method="MDX-Net",
             # keep base_dir and model_dir the same (paths under `pretrained_models`)
             base_dir=self.model_dir,
             **configs,
         )
 
-        uvr5_model = Inference(model_data, device)
-        # On HF Spaces with stateless GPU, we must not initialize CUDA in the
-        # main process. When running there and staying on CPU, temporarily
-        # spoof torch.cuda.is_available() so UVR5 never touches CUDA APIs.
-        if IS_SPACES and device == "cpu":
-            orig_is_available = _torch.cuda.is_available
-            _torch.cuda.is_available = lambda: False
-            try:
-                uvr5_model.load_model(model_path, 1)
-            finally:
-                _torch.cuda.is_available = orig_is_available
-        else:
-            uvr5_model.load_model(model_path, 1)
-
-        self.model = uvr5_model
-        self.device = device
+        uvr5_model = Inference(model_data, "cpu")
+        uvr5_model.load_model(model_path, 1)
+
+        self.model = uvr5_model.load_model(device="cpu")
+        self.device = "cpu"
         return self.model
 
     def denoise(self, audio_info):
         print("denoise UVR5: ", audio_info)
-        # On Spaces, force CPU; locally prefer CUDA if available.
-        if IS_SPACES:
-            dev = "cpu"
-        else:
-            dev = "cuda" if torch.cuda.is_available() else "cpu"
-        model = self.load_model(device=dev)
-
+        # # On Spaces, force CPU; locally prefer CUDA if available.
         input_audio = load_wav(audio_info, sr=44100, channel=2)
-        output_audio = model.demix_base({0: input_audio.squeeze()}, is_match_mix=False, device=dev)
+        output_audio = self.model.demix_base({0: input_audio.squeeze()}, is_match_mix=False, device="cpu")
         return output_audio.squeeze().T.cpu().numpy(), 44100
 
 
@@ -193,7 +162,6 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
         for f in files_checkpoints
         if "pretrained_" not in os.path.basename(f) and "model_last.pt" not in os.path.basename(f)
     ]
-    last_checkpoint = [f for f in files_checkpoints if "model_last.pt" in os.path.basename(f)]
 
     # Sort regular checkpoints by number
    try:
@@ -204,7 +172,7 @@ def get_checkpoints_project(project_name=None, is_gradio=True):
        regular_checkpoints = sorted(regular_checkpoints)
 
     # Combine in order: pretrained, regular, last
-    files_checkpoints = pretrained_checkpoints + regular_checkpoints + last_checkpoint
+    files_checkpoints = pretrained_checkpoints + regular_checkpoints
 
     select_checkpoint = None if not files_checkpoints else files_checkpoints[-1]
 
@@ -235,13 +203,13 @@ def get_available_projects():
     print("project_list:", project_list)
     return project_list
 
-@spaces.GPU(duration=240)
+@spaces.GPU
 @torch.no_grad()
 @torch.inference_mode()
 def infer(
     project, file_checkpoint, exp_name, ref_text, ref_audio, denoise_audio, gen_text, nfe_step, use_ema, separate_langs, frontend, speed, cfg_strength, use_acc_grl, ref_ratio, no_ref_audio, sway_sampling_coef, use_prosody_encoder, seed
 ):
-    global last_checkpoint, last_device, tts_api, last_ema
+    global tts_api, last_ema
 
     # Resolve checkpoint path (local or HF URL)
     ckpt_path = file_checkpoint
@@ -260,52 +228,40 @@ def infer(
     if denoise_audio:
         ref_audio = denoise_audio
 
-    device_test = device  # Use the global device
-
-    if last_checkpoint != ckpt_resolved or last_device != device_test or last_ema != use_ema or tts_api is None:
-        if last_checkpoint != ckpt_resolved:
-            last_checkpoint = ckpt_resolved
-
-        if last_device != device_test:
-            last_device = device_test
-
-        if last_ema != use_ema:
-            last_ema = use_ema
-
-        # Automatically enable prosody encoder when using the prosody checkpoint
-        use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
+    # Automatically enable prosody encoder when using the prosody checkpoint
+    use_prosody_encoder = True if "prosody" in str(ckpt_resolved) else False
 
-        # Resolve vocab file (local)
-        local_vocab = Path(PRETRAINED_ROOT) / "data" / project / "vocab.txt"
-        if not local_vocab.is_file():
-            return None, "Vocab file not found!", ""
-        vocab_file = str(local_vocab)
+    # Resolve vocab file (local)
+    local_vocab = Path(PRETRAINED_ROOT) / "data" / project / "vocab.txt"
+    if not local_vocab.is_file():
+        return None, "Vocab file not found!", ""
+    vocab_file = str(local_vocab)
 
-        # Resolve prosody encoder config & weights (local)
-        local_prosody_cfg = Path(CKPTS_ROOT) / "prosody_encoder" / "pretssel_cfg.json"
-        local_prosody_ckpt = Path(CKPTS_ROOT) / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
-        if not local_prosody_cfg.is_file() or not local_prosody_ckpt.is_file():
-            return None, "Prosody encoder files not found!", ""
-        prosody_cfg_path = str(local_prosody_cfg)
-        prosody_ckpt_path = str(local_prosody_ckpt)
+    # Resolve prosody encoder config & weights (local)
+    local_prosody_cfg = Path(CKPTS_ROOT) / "prosody_encoder" / "pretssel_cfg.json"
+    local_prosody_ckpt = Path(CKPTS_ROOT) / "prosody_encoder" / "prosody_encoder_UnitY2.pt"
+    if not local_prosody_cfg.is_file() or not local_prosody_ckpt.is_file():
+        return None, "Prosody encoder files not found!", ""
+    prosody_cfg_path = str(local_prosody_cfg)
+    prosody_ckpt_path = str(local_prosody_ckpt)
 
-        try:
-            tts_api = TTS(
-                model=exp_name,
-                ckpt_file=ckpt_resolved,
-                vocab_file=vocab_file,
-                device=device_test,
-                use_ema=use_ema,
-                frontend=frontend,
-                use_prosody_encoder=use_prosody_encoder,
-                prosody_cfg_path=prosody_cfg_path,
-                prosody_ckpt_path=prosody_ckpt_path,
-            )
-        except Exception as e:
-            traceback.print_exc()
-            return None, f"Error loading model: {str(e)}", ""
+    try:
+        tts_api = TTS(
+            model=exp_name,
+            ckpt_file=ckpt_resolved,
+            vocab_file=vocab_file,
+            device="cuda",
+            use_ema=use_ema,
+            frontend=frontend,
+            use_prosody_encoder=use_prosody_encoder,
+            prosody_cfg_path=prosody_cfg_path,
+            prosody_ckpt_path=prosody_ckpt_path,
+        )
+    except Exception as e:
+        traceback.print_exc()
+        return None, f"Error loading model: {str(e)}", ""
 
-    print("Model loaded >>", device_test, file_checkpoint, use_ema)
+    print("Model loaded >>", file_checkpoint, use_ema)
 
     if seed == -1:  # -1 used for random
         seed = None
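
Note: the module-level device probe kept by this commit can be read as the following standalone sketch; the pick_device helper name is hypothetical, and the hasattr guard is an extra defensive assumption for torch builds that ship without XPU support.

import torch

def pick_device() -> str:
    # Prefer CUDA, then Intel XPU, then Apple MPS, else fall back to CPU,
    # mirroring the chained-conditional expression in inference_gradio.py.
    if torch.cuda.is_available():
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        return "xpu"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

device = pick_device()

Pinning the TTS model to device="cuda" inside infer is consistent with the @spaces.GPU decorator added above: under ZeroGPU, CUDA is only available inside the decorated call. A minimal sketch of that pattern, with an illustrative synthesize function (not part of this repo):

import spaces
import torch

@spaces.GPU  # ZeroGPU attaches a GPU only for the duration of this call
def synthesize(text: str) -> torch.Tensor:
    # Heavy CUDA work belongs here, inside the GPU-scoped function;
    # module-level code still runs on CPU in the main Spaces process.
    return torch.zeros(1, device="cuda")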