Nekochu committed
Commit 829ed0c · 1 Parent(s): a5741b1

fix all review issues: dedup sampling/unwrap, thread-safe lock, cleanup, retry, security docs

Files changed (2)
  1. app.py +63 -22
  2. train_engine.py +42 -72
app.py CHANGED
@@ -13,12 +13,14 @@ import string
 import random
 import requests
 import logging
+import threading
 
 from train_engine import (
     preprocess_audio,
     train_lora_generator,
     cancel_training,
     get_trained_loras as _get_trained_loras_engine,
+    MAX_TRAINING_TIME,
 )
 
 logger = logging.getLogger(__name__)
@@ -28,7 +30,7 @@ logger = logging.getLogger(__name__)
 # ---------------------------------------------------------------------------
 
 MAX_TOTAL_AUDIO = 1800  # seconds total across all uploaded files (30 min)
-MAX_TRAINING_TIME = 28800  # 8 hours hard training timeout (seconds)
+# MAX_TRAINING_TIME is imported from train_engine (single source of truth)
 MAX_AUDIO_FILES = 50  # max number of training audio files per run
 
 # ---------------------------------------------------------------------------
@@ -39,6 +41,21 @@ ACE_SERVER = os.environ.get("ACE_SERVER", "http://127.0.0.1:8085")
 OUTPUT_DIR = os.environ.get("ACE_OUTPUT_DIR", "/app/outputs")
 os.makedirs(OUTPUT_DIR, exist_ok=True)
 
+# Clean up old inference temp files (older than 1 hour) at startup
+_CLEANUP_MAX_AGE = 3600  # seconds
+try:
+    _now = time.time()
+    for _fname in os.listdir(OUTPUT_DIR):
+        if _fname.lower().endswith((".wav", ".mp3")):
+            _fpath = os.path.join(OUTPUT_DIR, _fname)
+            try:
+                if os.path.isfile(_fpath) and (_now - os.path.getmtime(_fpath)) > _CLEANUP_MAX_AGE:
+                    os.remove(_fpath)
+            except OSError:
+                pass
+except Exception:
+    pass
+
 ACE_CHECKPOINT_DIR = os.environ.get("ACE_CHECKPOINT_DIR", "/app/checkpoints")
 ACE_SOURCE_DIR = "/app/ace-step-source"
 ACE_HF_MODEL = "ACE-Step/Ace-Step1.5"
@@ -49,7 +66,7 @@ ACE_SERVER_BIN = "/app/ace-server"
 
 # Detect if running on HF Space (ace-server available) vs locally (PyTorch only)
 _is_space = os.path.isfile(ACE_SERVER_BIN) or os.environ.get("SPACE_ID") is not None
-_training_in_progress = False
+_training_lock = threading.Lock()
 
 # HF repo for on-demand GGUF downloads
 GGUF_HF_REPO = "Serveurperso/ACE-Step-1.5-GGUF"
@@ -329,25 +346,47 @@ def _stop_ace_server():
         time.sleep(1)
 
 
-def _start_ace_server():
-    """Start ace-server in background and wait for health."""
+def _start_ace_server(max_retries: int = 3, retry_delay: float = 5.0):
+    """Start ace-server in background and wait for health.
+
+    Retries up to max_retries times with retry_delay seconds between attempts.
+    """
     global _ace_proc
-    logger.info("[ace-server] Starting with --adapters %s", ADAPTER_DIR)
-    try:
-        _ace_proc = subprocess.Popen(
-            [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
-             "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
+    for attempt in range(1, max_retries + 1):
+        logger.info(
+            "[ace-server] Starting (attempt %d/%d) with --adapters %s",
+            attempt, max_retries, ADAPTER_DIR,
         )
-    except Exception as exc:
-        logger.error("[ace-server] Failed to start: %s", exc)
-        return False
+        try:
+            _ace_proc = subprocess.Popen(
+                [ACE_SERVER_BIN, "--host", "127.0.0.1", "--port", "8085",
+                 "--models", MODELS_DIR, "--adapters", ADAPTER_DIR, "--max-batch", "1"],
+            )
+        except Exception as exc:
+            logger.error("[ace-server] Failed to start: %s", exc)
+            if attempt < max_retries:
+                time.sleep(retry_delay)
+                continue
+            return False
 
-    for _ in range(30):
-        if _server_ok():
-            logger.info("[ace-server] Healthy")
-            return True
-        time.sleep(2)
-    logger.error("[ace-server] Health check timeout")
+        for _ in range(30):
+            if _server_ok():
+                logger.info("[ace-server] Healthy")
+                return True
+            time.sleep(2)
+
+        logger.warning("[ace-server] Health check timeout on attempt %d/%d", attempt, max_retries)
+        # Kill the failed process before retrying
+        if _ace_proc and _ace_proc.poll() is None:
+            _ace_proc.kill()
+            try:
+                _ace_proc.wait(timeout=5)
+            except subprocess.TimeoutExpired:
+                pass
+        if attempt < max_retries:
+            time.sleep(retry_delay)
+
+    logger.error("[ace-server] Failed to start after %d attempts", max_retries)
     return False
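With the defaults above (3 attempts, 30 health polls of 2 s each, 5 s between failed attempts), the function blocks for at most about three minutes before returning False. A quick back-of-the-envelope check, using only the constants visible in this hunk:

    # Worst-case time _start_ace_server() can spend before giving up,
    # assuming every attempt starts the process but never passes the health check.
    max_retries = 3
    retry_delay = 5.0      # seconds slept between failed attempts
    health_polls = 30      # _server_ok() polls per attempt
    poll_interval = 2.0    # seconds between polls

    worst_case = max_retries * health_polls * poll_interval + (max_retries - 1) * retry_delay
    print(worst_case)      # 190.0 seconds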
 
@@ -449,8 +488,9 @@ def gradio_main():
     def generate_music(caption, lyrics, instrumental, bpm, duration, seed,
                        steps, lora_select, lm_model_select,
                        progress=gr.Progress(track_tqdm=True)):
-        if _training_in_progress:
+        if not _training_lock.acquire(blocking=False):
             return None, "Training in progress. Inference unavailable until training completes. Press Cancel to stop training."
+        _training_lock.release()
         if not _server_ok():
             return None, "ace-server not running. Check logs."
 
@@ -631,8 +671,7 @@ def gradio_main():
         yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         # Stop ace-server before training (frees memory)
-        global _training_in_progress
-        _training_in_progress = True
+        _training_lock.acquire()
         _log("[INFO] Stopping ace-server for training...")
         yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
         _stop_ace_server()
@@ -720,7 +759,7 @@ def gradio_main():
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
 
         finally:
-            _training_in_progress = False
+            _training_lock.release()
             # Always restart ace-server
             _log("[INFO] Restarting ace-server...")
             yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
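Taken together with the two hunks above, the diff replaces the _training_in_progress boolean with a threading.Lock: the training handler holds the lock for the whole run and releases it in the finally block, while generate_music uses a non-blocking acquire as a race-free "is training running?" probe. A standalone sketch of the same pattern (the function names here are illustrative, not the ones in app.py):

    import threading
    import time

    _busy = threading.Lock()

    def long_job():
        _busy.acquire()          # held for the entire job, like the training handler
        try:
            time.sleep(5)        # stand-in for the actual work
        finally:
            _busy.release()      # always released, even if the job raises

    def quick_request():
        # Non-blocking probe: if the lock is already held, a job is running.
        if not _busy.acquire(blocking=False):
            return "busy"
        _busy.release()          # only checking, so release immediately
        return "ok"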
@@ -746,6 +785,8 @@ def gradio_main():
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
         else:
             yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
+        # Clean up training workspace (preprocessed tensors, temp audio, etc.)
+        shutil.rmtree(work_dir, ignore_errors=True)
 
     # -- Cancel handler --
     def _on_cancel():
 
train_engine.py CHANGED
@@ -374,28 +374,6 @@ def enable_gradient_checkpointing(decoder: nn.Module) -> bool:
     return enabled
 
 
-def force_disable_cache(decoder: nn.Module) -> None:
-    stack = [decoder]
-    visited = set()
-    while stack:
-        mod = stack.pop()
-        if not isinstance(mod, nn.Module):
-            continue
-        mid = id(mod)
-        if mid in visited:
-            continue
-        visited.add(mid)
-        cfg = getattr(mod, "config", None)
-        if cfg is not None and hasattr(cfg, "use_cache"):
-            try:
-                cfg.use_cache = False
-            except Exception:
-                pass
-        for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
-            child = getattr(mod, a, None)
-            if isinstance(child, nn.Module):
-                stack.append(child)
-
 
 # ============================================================================
 # LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT)
@@ -464,9 +442,7 @@ def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]:
 
 def save_lora_adapter(model, output_dir: str) -> None:
     os.makedirs(output_dir, exist_ok=True)
-    decoder = model.decoder if hasattr(model, "decoder") else model
-    while hasattr(decoder, "_forward_module"):
-        decoder = decoder._forward_module
+    decoder = _unwrap_decoder(model)
 
     if hasattr(decoder, "save_pretrained"):
         decoder.save_pretrained(output_dir)
@@ -602,6 +578,10 @@ def load_model_for_training(
     for idx, attn in enumerate(candidates):
         try:
             load_kwargs = dict(
+                # SECURITY: trust_remote_code=True is required because the
+                # ACE-Step model config references custom Python code in its
+                # checkpoint (config.json -> auto_map). Only load checkpoints
+                # from trusted sources (the official ACE-Step HF repo).
                 trust_remote_code=True,
                 attn_implementation=attn,
                 torch_dtype=dtype,
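The new comment documents why trust_remote_code=True cannot simply be dropped here. A common complementary mitigation, not part of this commit, is pinning the checkpoint revision so the remote code cannot change underneath the app; revision is a standard Hugging Face from_pretrained argument, and the value below is a placeholder:

    # Sketch only: extend load_kwargs with a pinned revision.
    load_kwargs = dict(
        trust_remote_code=True,   # required: the checkpoint ships custom model code via auto_map
        revision="main",          # replace with a vetted commit hash to freeze that code
    )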
@@ -2443,8 +2423,8 @@ def train_lora_generator(
 
     yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params"
 
-    # Gradient checkpointing + cache disable
-    force_disable_cache(model.decoder)
+    # Gradient checkpointing + cache disable (enable_gradient_checkpointing
+    # also walks the module tree and sets use_cache=False on any config it finds)
     ckpt_ok = enable_gradient_checkpointing(model.decoder)
     force_input_grads = ckpt_ok
     if ckpt_ok:
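The rewritten comment states that enable_gradient_checkpointing now performs the use_cache=False walk that the deleted force_disable_cache helper used to do, which is why the first train_engine.py hunk can remove that helper outright. A sketch of that folded-in walk, assuming it reuses the deleted traversal (the actual body of enable_gradient_checkpointing is not shown in this diff):

    import torch.nn as nn

    def _disable_cache_everywhere(decoder: nn.Module) -> None:
        # Sketch only: the cache-disable walk that used to live in force_disable_cache.
        stack, visited = [decoder], set()
        while stack:
            mod = stack.pop()
            if not isinstance(mod, nn.Module) or id(mod) in visited:
                continue
            visited.add(id(mod))
            cfg = getattr(mod, "config", None)
            if cfg is not None and hasattr(cfg, "use_cache"):
                cfg.use_cache = False  # KV caching is useless under gradient checkpointing
            for name in ("_forward_module", "_orig_mod", "base_model", "model", "module"):
                child = getattr(mod, name, None)
                if isinstance(child, nn.Module):
                    stack.append(child)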
@@ -2511,9 +2491,7 @@
     if aw.exists():
         from safetensors.torch import load_file
         state = load_file(str(aw))
-        decoder = model.decoder
-        while hasattr(decoder, "_forward_module"):
-            decoder = decoder._forward_module
+        decoder = _unwrap_decoder(model)
         decoder.load_state_dict(state, strict=False)
 
     # Load training state
@@ -3017,15 +2995,7 @@ def generate_audio(
         logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale)
         from peft import PeftModel
 
-        decoder = model.decoder if hasattr(model, "decoder") else model
-        # Unwrap any wrappers
-        while hasattr(decoder, "_forward_module"):
-            decoder = decoder._forward_module
-        if hasattr(decoder, "base_model"):
-            bm = decoder.base_model
-            decoder = bm.model if hasattr(bm, "model") else bm
-        if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
-            decoder = decoder.model
+        decoder = _unwrap_decoder(model)
 
         model.decoder = PeftModel.from_pretrained(
             decoder, adapter_path, is_trainable=False,
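All three call sites in this file now delegate to a shared _unwrap_decoder helper whose definition is not part of this diff. A minimal sketch of what it would look like if it simply consolidates the unwrapping removed here and in the earlier hunks (an assumption; the real helper in train_engine.py may differ):

    import torch.nn as nn

    def _unwrap_decoder(model):
        """Return the innermost decoder module, stripping common wrappers (sketch only)."""
        decoder = model.decoder if hasattr(model, "decoder") else model
        while hasattr(decoder, "_forward_module"):   # wrapper handled by the removed code above
            decoder = decoder._forward_module
        if hasattr(decoder, "base_model"):           # PEFT-style wrapper
            bm = decoder.base_model
            decoder = bm.model if hasattr(bm, "model") else bm
        if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
            decoder = decoder.model
        return decoder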
@@ -3174,6 +3144,37 @@ _LM_UNDERSTAND_INSTRUCTION = (
 )
 
 
+def _sample_next_token(
+    logits: torch.Tensor, temperature: float, top_k: int, top_p: float,
+) -> int:
+    """Sample a single token from logits with temperature, top-k, and top-p.
+
+    Args:
+        logits: 1-D logits tensor (vocab_size,).
+        temperature: Sampling temperature (<=0 for argmax).
+        top_k: Top-K filtering (0 = disabled).
+        top_p: Nucleus sampling cutoff (0 or >=1 = disabled).
+
+    Returns:
+        Selected token ID as int.
+    """
+    if temperature <= 0:
+        return int(logits.argmax().item())
+    scaled = logits.clone() / temperature
+    if top_k > 0:
+        topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
+        scaled[scaled < topk_vals[-1]] = float("-inf")
+    if top_p > 0 and top_p < 1.0:
+        sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
+        probs = torch.softmax(sorted_logits, dim=-1)
+        cumsum = torch.cumsum(probs, dim=-1)
+        nucleus_mask = cumsum - probs > top_p
+        sorted_logits[nucleus_mask] = float("-inf")
+        scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
+    probs = torch.softmax(scaled, dim=-1)
+    return int(torch.multinomial(probs, 1).item())
+
+
 def _build_understand_prompt(
     bpe_tokenizer, codes: List[int],
 ) -> List[int]:
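Factoring the duplicated sampling code into _sample_next_token also makes it testable in isolation, since it is a pure function of a 1-D logits tensor. An illustrative call (values are arbitrary; assumes train_engine.py is importable in the current environment):

    import torch
    from train_engine import _sample_next_token

    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])

    greedy = _sample_next_token(logits, temperature=0.0, top_k=0, top_p=0.0)
    assert greedy == 0                 # temperature <= 0 falls back to argmax

    sampled = _sample_next_token(logits, temperature=0.9, top_k=3, top_p=0.95)
    assert sampled in (0, 1, 2)        # token 3 is filtered out by top_k=3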
@@ -3357,22 +3358,7 @@ def _generate_codes_with_lm(
         logits[0] = logits[0] + mask
 
         # Sample
-        if temperature <= 0:
-            next_id = int(logits[0].argmax().item())
-        else:
-            scaled = logits[0].clone() / temperature
-            if top_k > 0:
-                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
-                scaled[scaled < topk_vals[-1]] = float("-inf")
-            if top_p > 0 and top_p < 1.0:
-                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
-                probs = torch.softmax(sorted_logits, dim=-1)
-                cumsum = torch.cumsum(probs, dim=-1)
-                nucleus_mask = cumsum - probs > top_p
-                sorted_logits[nucleus_mask] = float("-inf")
-                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
-            probs = torch.softmax(scaled, dim=-1)
-            next_id = int(torch.multinomial(probs, 1).item())
+        next_id = _sample_next_token(logits[0], temperature, top_k, top_p)
 
         # Stop on im_end
         if next_id == _TOKEN_IM_END:
@@ -3701,23 +3687,7 @@ def understand_audio(
         logits[0, _AUDIO_CODE_BASE:] = float("-inf")
 
         # Sample
-        if temperature <= 0:
-            next_id = int(logits[0].argmax().item())
-        else:
-            scaled = logits[0].clone() / temperature
-            if top_k > 0:
-                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
-                scaled[scaled < topk_vals[-1]] = float("-inf")
-            if top_p > 0 and top_p < 1.0:
-                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
-                probs = torch.softmax(sorted_logits, dim=-1)
-                cumsum = torch.cumsum(probs, dim=-1)
-                mask = cumsum - probs > top_p
-                sorted_logits[mask] = float("-inf")
-                # Scatter masked values back to original positions
-                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
-            probs = torch.softmax(scaled, dim=-1)
-            next_id = int(torch.multinomial(probs, 1).item())
+        next_id = _sample_next_token(logits[0], temperature, top_k, top_p)
 
         if next_id == _TOKEN_IM_END:
             break
 