Nekochu committed on
Commit
a4a86a8
·
1 Parent(s): a0e1f4c

run training as detached subprocess to survive Gradio session timeout

Browse files
Files changed (1) hide show
  1. app.py +131 -143
app.py CHANGED
@@ -285,15 +285,21 @@ def gradio_main():
285
  lines.append(json.dumps(props, indent=2))
286
  return "\n".join(lines)
287
 
288
- # -- Training --
 
 
289
  def train_lora(audio_files, lora_name, epochs, lr, rank,
290
  progress=gr.Progress(track_tqdm=True)):
291
- import shutil
292
- import gc
293
 
294
  if not audio_files:
295
  return "No audio files uploaded."
296
 
 
 
 
 
 
297
  lora_name = (lora_name or "").strip() or "my-lora"
298
  epochs = max(1, min(int(epochs), 10))
299
  lr = float(lr)
@@ -301,152 +307,126 @@ def gradio_main():
301
 
302
  output_dir = os.path.join(ADAPTER_DIR, lora_name)
303
  os.makedirs(output_dir, exist_ok=True)
304
-
305
  audio_dir = os.path.join(output_dir, "audio_input")
306
  os.makedirs(audio_dir, exist_ok=True)
307
  for f in audio_files:
308
  src = f.name if hasattr(f, "name") else str(f)
309
  shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
310
 
311
- log_lines = [
312
- f"LoRA Training: '{lora_name}'",
313
- f"Audio files: {len(audio_files)}",
314
- f"Epochs: {epochs}, LR: {lr}, Rank: {rank}",
315
- f"Output: {output_dir}",
316
- "",
317
- ]
318
-
319
- def _log(msg):
320
- log_lines.append(msg)
321
- print(f"[train] {msg}", flush=True)
322
-
323
- try:
324
- import subprocess, signal
325
- _log("Stopping ace-server to free RAM for training...")
326
- subprocess.run(["pkill", "-f", "ace-server"], stderr=subprocess.DEVNULL)
327
- time.sleep(2)
328
- gc.collect()
329
-
330
- ckpt_files = os.listdir(ACE_CHECKPOINT_DIR) if os.path.isdir(ACE_CHECKPOINT_DIR) else []
331
- if len(ckpt_files) < 3:
332
- _log("[Step 0] Downloading model checkpoints...")
333
- progress(0.02, desc="Downloading checkpoints...")
334
- from huggingface_hub import snapshot_download
335
- snapshot_download(
336
- ACE_HF_MODEL,
337
- local_dir=ACE_CHECKPOINT_DIR,
338
- ignore_patterns=["*.md", "*.txt", ".gitattributes"],
339
- )
340
- _log(" Checkpoints downloaded.")
341
-
342
- if ACE_SOURCE_DIR not in sys.path:
343
- sys.path.insert(0, ACE_SOURCE_DIR)
344
-
345
- import torchaudio
346
- _orig_load = torchaudio.load
347
- def _load_soundfile(filepath, *args, **kwargs):
348
- kwargs.setdefault('backend', 'soundfile')
349
- return _orig_load(filepath, *args, **kwargs)
350
- torchaudio.load = _load_soundfile
351
-
352
- _log("[Step 1/2] Preprocessing audio files...")
353
- progress(0.10, desc="Preprocessing audio...")
354
-
355
- tensor_dir = os.path.join(output_dir, "preprocessed_tensors")
356
- os.makedirs(tensor_dir, exist_ok=True)
357
-
358
- from acestep.training_v2.preprocess import preprocess_audio_files
359
- result = preprocess_audio_files(
360
- audio_dir=audio_dir,
361
- output_dir=tensor_dir,
362
- checkpoint_dir=ACE_CHECKPOINT_DIR,
363
- variant="turbo",
364
- max_duration=60.0,
365
- device="cpu",
366
- precision="bfloat16",
367
- )
368
-
369
- processed = result.get("processed", 0)
370
- total_files = result.get("total", 0)
371
- failed = result.get("failed", 0)
372
- _log(f" Preprocessed: {processed}/{total_files} (failed: {failed})")
373
-
374
- if processed == 0:
375
- _log("ERROR: No files preprocessed successfully.")
376
- return "\n".join(log_lines)
377
-
378
- _log("[Step 2/2] Training LoRA adapter (CPU, this will be slow)...")
379
- progress(0.30, desc="Loading model for training...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
380
 
381
- from acestep.training_v2.model_loader import load_decoder_for_training
382
- from acestep.training_v2.trainer_fixed import FixedLoRATrainer
383
- from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
 
384
 
385
- model = load_decoder_for_training(
386
- checkpoint_dir=ACE_CHECKPOINT_DIR,
387
- variant="turbo",
388
- device="cpu",
389
- precision="bfloat16",
390
- )
391
- model = model.bfloat16()
392
-
393
- adapter_cfg = LoRAConfigV2(r=rank, alpha=rank, dropout=0.0)
394
- train_cfg = TrainingConfigV2(
395
- checkpoint_dir=ACE_CHECKPOINT_DIR,
396
- model_variant="turbo",
397
- dataset_dir=tensor_dir,
398
- output_dir=output_dir,
399
- max_epochs=epochs,
400
- batch_size=1,
401
- learning_rate=lr,
402
- device="cpu",
403
- precision="bfloat16",
404
- seed=42,
405
- num_workers=0,
406
- pin_memory=False,
407
- )
408
-
409
- trainer = FixedLoRATrainer(model, adapter_cfg, train_cfg)
410
-
411
- step_count = 0
412
- last_loss = 0.0
413
- for update in trainer.train():
414
- if hasattr(update, "step"):
415
- step_count = update.step
416
- last_loss = update.loss
417
- elif isinstance(update, tuple) and len(update) >= 2:
418
- step_count = update[0]
419
- last_loss = update[1]
420
- if step_count % 5 == 0:
421
- log_lines.append(f" Step {step_count}: loss={last_loss:.4f}")
422
- pct = 0.30 + 0.65 * min(step_count / max(epochs * processed, 1), 1.0)
423
- progress(pct, desc=f"Step {step_count}, loss={last_loss:.4f}")
424
-
425
- _log(f"Training complete! Final: step {step_count}, loss={last_loss:.4f}")
426
- _log(f"LoRA saved to: {output_dir}")
427
-
428
- del model, trainer
429
- gc.collect()
430
-
431
- except ImportError as e:
432
- _log(f"Import error: {e}")
433
- _log(f"Check ACE-Step source at {ACE_SOURCE_DIR}")
434
- import traceback
435
- log_lines.append(traceback.format_exc())
436
- except Exception as e:
437
- import traceback
438
- _log(f"ERROR: {e}")
439
- log_lines.append(traceback.format_exc())
440
- finally:
441
- _log("Restarting ace-server...")
442
- import subprocess
443
- subprocess.Popen([
444
- "/app/ace-server", "--host", "127.0.0.1", "--port", "8085",
445
- "--models", "/app/models", "--adapters", "/app/adapters",
446
- "--max-batch", "1",
447
- ], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
448
-
449
- return "\n".join(log_lines)
450
 
451
  # -- Build UI --
452
  CSS = """
@@ -548,11 +528,13 @@ def gradio_main():
548
  lr = gr.Number(label="Learning Rate", value=1e-4)
549
  rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
550
 
551
- train_btn = gr.Button("Train", variant="primary")
 
 
552
  train_log = gr.Textbox(
553
  label="Training Log",
554
  interactive=False,
555
- lines=10,
556
  elem_classes="status-box",
557
  )
558
 
@@ -562,6 +544,12 @@ def gradio_main():
562
  outputs=[train_log],
563
  api_name="train_lora",
564
  )
 
 
 
 
 
 
565
 
566
  demo.launch(
567
  server_name="0.0.0.0",
 
285
  lines.append(json.dumps(props, indent=2))
286
  return "\n".join(lines)
287
 
288
+ # -- Training (runs as detached subprocess to survive Gradio session timeout) --
289
+ TRAIN_LOG = "/app/outputs/train.log"
290
+
291
  def train_lora(audio_files, lora_name, epochs, lr, rank,
292
  progress=gr.Progress(track_tqdm=True)):
293
+ import shutil, subprocess
 
294
 
295
  if not audio_files:
296
  return "No audio files uploaded."
297
 
298
+ if os.path.exists(TRAIN_LOG):
299
+ last_line = open(TRAIN_LOG).readlines()[-1] if os.path.getsize(TRAIN_LOG) > 0 else ""
300
+ if "DONE" not in last_line and "ERROR" not in last_line and last_line.strip():
301
+ return f"Training already in progress. Click 'Check Log' to monitor.\n\nLast: {last_line.strip()}"
302
+
303
  lora_name = (lora_name or "").strip() or "my-lora"
304
  epochs = max(1, min(int(epochs), 10))
305
  lr = float(lr)
 
307
 
308
  output_dir = os.path.join(ADAPTER_DIR, lora_name)
309
  os.makedirs(output_dir, exist_ok=True)
 
310
  audio_dir = os.path.join(output_dir, "audio_input")
311
  os.makedirs(audio_dir, exist_ok=True)
312
  for f in audio_files:
313
  src = f.name if hasattr(f, "name") else str(f)
314
  shutil.copy2(src, os.path.join(audio_dir, os.path.basename(src)))
315
 
316
+ train_script = f"""
317
+ import os, sys, time, gc
318
+ sys.path.insert(0, "{ACE_SOURCE_DIR}")
319
+ os.environ["TORCHAUDIO_USE_BACKEND_DISPATCHER"] = "1"
320
+
321
+ LOG = "{TRAIN_LOG}"
322
+ def log(msg):
323
+ print(f"[train] {{msg}}", flush=True)
324
+ with open(LOG, "a") as f:
325
+ f.write(msg + "\\n")
326
+ f.flush()
327
+
328
+ open(LOG, "w").close()
329
+ log("LoRA Training: '{lora_name}' | files={len(audio_files)} | epochs={epochs} lr={lr} rank={rank}")
330
+
331
+ import subprocess
332
+ log("Stopping ace-server...")
333
+ subprocess.run(["pkill", "-f", "ace-server"], stderr=subprocess.DEVNULL)
334
+ time.sleep(2)
335
+ gc.collect()
336
+
337
+ try:
338
+ import torchaudio
339
+ _orig = torchaudio.load
340
+ def _sf(p, *a, **kw):
341
+ kw.setdefault("backend", "soundfile")
342
+ return _orig(p, *a, **kw)
343
+ torchaudio.load = _sf
344
+
345
+ log("[Step 1/2] Preprocessing audio...")
346
+ from acestep.training_v2.preprocess import preprocess_audio_files
347
+ result = preprocess_audio_files(
348
+ audio_dir="{audio_dir}",
349
+ output_dir="{output_dir}/preprocessed_tensors",
350
+ checkpoint_dir="{ACE_CHECKPOINT_DIR}",
351
+ variant="turbo", max_duration=60.0,
352
+ device="cpu", precision="bfloat16",
353
+ )
354
+ processed = result.get("processed", 0)
355
+ failed = result.get("failed", 0)
356
+ log(f" Preprocessed: {{processed}}/{{result.get('total',0)}} (failed: {{failed}})")
357
+ if processed == 0:
358
+ log("ERROR: No files preprocessed. DONE")
359
+ raise SystemExit(1)
360
+
361
+ gc.collect()
362
+ log("[Step 2/2] Training LoRA...")
363
+ from acestep.training_v2.model_loader import load_decoder_for_training
364
+ from acestep.training_v2.trainer_fixed import FixedLoRATrainer
365
+ from acestep.training_v2.configs import TrainingConfigV2, LoRAConfigV2
366
+
367
+ model = load_decoder_for_training(
368
+ checkpoint_dir="{ACE_CHECKPOINT_DIR}", variant="turbo",
369
+ device="cpu", precision="bfloat16",
370
+ ).bfloat16()
371
+
372
+ trainer = FixedLoRATrainer(model,
373
+ LoRAConfigV2(r={rank}, alpha={rank}, dropout=0.0),
374
+ TrainingConfigV2(
375
+ checkpoint_dir="{ACE_CHECKPOINT_DIR}", model_variant="turbo",
376
+ dataset_dir="{output_dir}/preprocessed_tensors",
377
+ output_dir="{output_dir}",
378
+ max_epochs={epochs}, batch_size=1, learning_rate={lr},
379
+ device="cpu", precision="bfloat16", seed=42,
380
+ num_workers=0, pin_memory=False,
381
+ ))
382
+
383
+ step_count, last_loss = 0, 0.0
384
+ for update in trainer.train():
385
+ if hasattr(update, "step"):
386
+ step_count, last_loss = update.step, update.loss
387
+ elif isinstance(update, tuple) and len(update) >= 2:
388
+ step_count, last_loss = update[0], update[1]
389
+ if step_count % 5 == 0:
390
+ log(f" Step {{step_count}}: loss={{last_loss:.4f}}")
391
+
392
+ log(f"Training complete! step={{step_count}} loss={{last_loss:.4f}}")
393
+ log(f"LoRA saved to: {output_dir}")
394
+ del model, trainer
395
+ gc.collect()
396
+ log("DONE")
397
+
398
+ except Exception as e:
399
+ import traceback
400
+ log(f"ERROR: {{e}}")
401
+ log(traceback.format_exc())
402
+ log("DONE")
403
+ finally:
404
+ log("Restarting ace-server...")
405
+ subprocess.Popen(["/app/ace-server", "--host", "127.0.0.1", "--port", "8085",
406
+ "--models", "/app/models", "--adapters", "/app/adapters", "--max-batch", "1"],
407
+ stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
408
+ """
409
+ script_path = os.path.join(output_dir, "_train.py")
410
+ with open(script_path, "w") as f:
411
+ f.write(train_script)
412
+
413
+ subprocess.Popen(
414
+ ["python3", script_path],
415
+ stdout=open("/dev/null", "w"),
416
+ stderr=open("/dev/null", "w"),
417
+ start_new_session=True,
418
+ )
419
 
420
+ return (f"Training started in background for '{lora_name}'.\n"
421
+ f"Audio: {len(audio_files)} files, Epochs: {epochs}, Rank: {rank}\n\n"
422
+ f"Click 'Check Log' to monitor progress.\n"
423
+ f"Inference will be unavailable until training completes (ace-server stopped).")
424
 
425
+ def check_train_log():
426
+ if not os.path.exists(TRAIN_LOG):
427
+ return "No training log found."
428
+ with open(TRAIN_LOG) as f:
429
+ return f.read() or "Log is empty."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
  # -- Build UI --
432
  CSS = """
 
528
  lr = gr.Number(label="Learning Rate", value=1e-4)
529
  rank = gr.Number(label="Rank (r)", value=16, minimum=1, maximum=64)
530
 
531
+ with gr.Row(elem_classes="compact-row"):
532
+ train_btn = gr.Button("Train", variant="primary", scale=2)
533
+ log_btn = gr.Button("Check Log", scale=1)
534
  train_log = gr.Textbox(
535
  label="Training Log",
536
  interactive=False,
537
+ lines=12,
538
  elem_classes="status-box",
539
  )
540
 
 
544
  outputs=[train_log],
545
  api_name="train_lora",
546
  )
547
+ log_btn.click(
548
+ fn=check_train_log,
549
+ inputs=[],
550
+ outputs=[train_log],
551
+ api_name="check_train_log",
552
+ )
553
 
554
  demo.launch(
555
  server_name="0.0.0.0",