major update: PyTorch inference, Gradio 6, session isolation, /understand captioning

- generate_audio(): full PyTorch inference pipeline (no ace-server needed)
- tiled_vae_decode(): memory-bounded VAE decoding
- Gradio 6 migration (gr.update -> component constructors; see the sketch below)
- Session isolation (random suffix on LoRA names)
- /understand captioning before training (falls back to librosa)
- Active tab detection pattern from rvc-beatrice
- Gradio API fix (tempfile for adapter download)
- dtype fix for mixed precision inference

Files changed:
- Dockerfile +1 -1
- app.py +133 -20
- train_engine.py +287 -0
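
For reference, the Gradio 6 migration mentioned above follows the pattern sketched here: generator handlers that previously yielded gr.update(...) now yield component constructors (e.g. gr.Button(visible=...)) carrying only the changed properties. This is a minimal illustrative sketch, not the app's actual layout; the component names are invented for the example.

    import gradio as gr
    import time

    def long_task():
        # Swap the Start button for Cancel while the task runs.
        # Pre-migration code would yield gr.update(visible=...) here;
        # the migrated style yields fresh component constructors instead.
        yield gr.Button(visible=False), gr.Button(visible=True)
        time.sleep(2)  # stand-in for the real work
        yield gr.Button(visible=True), gr.Button(visible=False)

    with gr.Blocks() as demo:
        start_btn = gr.Button("Start")
        cancel_btn = gr.Button("Cancel", visible=False)
        start_btn.click(long_task, outputs=[start_btn, cancel_btn])

    demo.launch()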
Dockerfile
CHANGED
@@ -75,7 +75,7 @@ RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
 
 # Install Python deps for Gradio UI + training
 RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
-    "gradio[mcp]=
+    "gradio[mcp]>=6.0.0,<7.0.0" requests torch safetensors \
     "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
     loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
     einops vector_quantize_pytorch librosa mutagen
app.py
CHANGED
@@ -5,9 +5,12 @@ import sys
 import time
 import json
 import argparse
+import base64
 import tempfile
 import subprocess
 import shutil
+import string
+import random
 import requests
 import logging
 
@@ -97,6 +100,61 @@ def _fetch_result(job_id, timeout=60):
     return r
 
 
+def _caption_via_understand(audio_path, timeout=120):
+    """Call ace-server /understand to get a rich caption for an audio file.
+
+    Returns a dict with caption, bpm, key, signature, lyrics on success,
+    or None on failure (caller should fall back to librosa).
+    """
+    fname = os.path.basename(audio_path)
+    try:
+        with open(audio_path, "rb") as f:
+            audio_b64 = base64.b64encode(f.read()).decode("ascii")
+    except Exception as exc:
+        logger.warning("[Caption] %s: failed to read file: %s", fname, exc)
+        return None
+
+    # Submit
+    try:
+        r = requests.post(
+            f"{ACE_SERVER}/understand",
+            json={"audio": audio_b64},
+            timeout=30,
+        )
+        if r.status_code != 200:
+            logger.warning("[Caption] %s: /understand returned %d", fname, r.status_code)
+            return None
+        job_id = r.json().get("id")
+        if not job_id:
+            logger.warning("[Caption] %s: /understand returned no job id", fname)
+            return None
+    except Exception as exc:
+        logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
+        return None
+
+    # Poll until done
+    status, _ = _poll_job(job_id, timeout=timeout)
+    if status != "done":
+        logger.warning("[Caption] %s: /understand job %s -> %s", fname, job_id, status)
+        return None
+
+    # Fetch result
+    try:
+        r = _fetch_result(job_id, timeout=30)
+        if r.status_code != 200:
+            logger.warning("[Caption] %s: /understand result fetch failed: %d", fname, r.status_code)
+            return None
+        data = r.json()
+        # The result should contain caption, bpm, key, signature, lyrics
+        if isinstance(data, dict) and data.get("caption"):
+            return data
+        logger.warning("[Caption] %s: /understand returned no caption field", fname)
+        return None
+    except Exception as exc:
+        logger.warning("[Caption] %s: /understand result parse failed: %s", fname, exc)
+        return None
+
+
 def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
                   adapter=None, lm_model=None, progress_cb=None):
     """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises."""
@@ -460,17 +518,20 @@ def gradio_main():
         # -- Validation --
         if not audio_files:
            _log("[FAIL] No audio files uploaded.")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
            return
 
         if len(audio_files) > MAX_AUDIO_FILES:
            _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
            return
 
         lora_name = (lora_name or "").strip() or "my-lora"
         # Sanitize: alphanumeric, dash, underscore only
         lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name)
+        # Append random suffix to prevent naming collisions between users
+        suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
+        lora_name = f"{lora_name}-{suffix}"
 
         epochs = max(1, min(int(epochs), 10))
         lr = float(lr)
@@ -485,7 +546,7 @@ def gradio_main():
 
         # Copy uploaded audio files + check total duration
         _log(f"[INFO] Preparing {len(audio_files)} audio files...")
-        yield _log_text(), gr.
+        yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         import librosa as _lr
         total_dur = 0.0
@@ -530,18 +591,49 @@ def gradio_main():
 
         _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
              f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
-        yield _log_text(), gr.
+        yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
+
+        # Caption each audio file via ace-server /understand BEFORE stopping it
+        if _server_ok():
+            _log("[INFO] Captioning audio via ace-server /understand...")
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
+            for audio_fname in sorted(os.listdir(audio_dir)):
+                full_path = os.path.join(audio_dir, audio_fname)
+                if not os.path.isfile(full_path) or audio_fname.endswith(".json"):
+                    continue
+                caption_json_path = full_path + ".json"
+                caption_data = _caption_via_understand(full_path, timeout=120)
+                if caption_data:
+                    _log(f"[Caption] {audio_fname}: using ace-server /understand")
+                    with open(caption_json_path, "w") as cj:
+                        json.dump(caption_data, cj)
+                else:
+                    # Fallback to librosa for basic metadata
+                    _log(f"[Caption] {audio_fname}: fallback to librosa")
+                    try:
+                        y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
+                        tempo, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
+                        bpm_val = float(tempo) if hasattr(tempo, '__float__') else float(tempo[0])
+                        fallback = {"caption": "", "bpm": round(bpm_val), "key": "", "signature": "", "lyrics": ""}
+                        with open(caption_json_path, "w") as cj:
+                            json.dump(fallback, cj)
+                    except Exception as cap_exc:
+                        _log(f"[Caption] {audio_fname}: librosa fallback also failed: {cap_exc}")
+                yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
+        else:
+            _log("[INFO] ace-server not running, skipping /understand captioning")
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         # Stop ace-server before training (frees memory)
         _log("[INFO] Stopping ace-server for training...")
-        yield _log_text(), gr.
+        yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
         _stop_ace_server()
         _gc.collect()
 
         try:
            # -- Phase 1: Preprocessing --
            _log("[Step 1/2] Preprocessing audio...")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
            preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
 
@@ -558,24 +650,24 @@ def gradio_main():
                progress_callback=preprocess_progress,
                cancel_check=lambda: False,
            )
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
            processed = result.get("processed", 0)
            failed = result.get("failed", 0)
            total = result.get("total", 0)
            _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
            if processed == 0:
                _log("[FAIL] No files preprocessed successfully. Cannot train.")
-                yield _log_text(), gr.
+                yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
                return
 
            _gc.collect()
 
            # -- Phase 2: Training --
            _log("[Step 2/2] Training LoRA...")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
            for msg in train_lora_generator(
                dataset_dir=preprocessed_dir,
@@ -605,24 +697,24 @@ def gradio_main():
                    break
 
                _log(msg)
-                yield _log_text(), gr.
+                yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
                if msg.strip() == "[DONE]":
                    break
 
            _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
        except Exception as exc:
            _log(f"[FAIL] Training error: {exc}")
            import traceback
            _log(traceback.format_exc())
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
 
        finally:
            # Always restart ace-server
            _log("[INFO] Restarting ace-server...")
-            yield _log_text(), gr.
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
            _gc.collect()
            ok = _start_ace_server()
            if ok:
@@ -631,9 +723,20 @@ def gradio_main():
                _log("[WARN] ace-server may not have restarted -- check logs")
            adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
            if os.path.isfile(adapter_safetensors):
-
+                # Copy to a temp file so Gradio doesn't try to validate /app paths
+                # (avoids InvalidPathError: "Cannot move /app to the gradio cache dir
+                # because it was not uploaded by a user")
+                tmp_out = tempfile.NamedTemporaryFile(
+                    suffix=".safetensors",
+                    prefix=f"{lora_name}_",
+                    delete=False,
+                )
+                tmp_out.close()
+                shutil.copy2(adapter_safetensors, tmp_out.name)
+                _log(f"[OK] LoRA saved: {lora_name}")
+                yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
            else:
-                yield _log_text(), gr.
+                yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
 
    # -- Cancel handler --
    def _on_cancel():
@@ -657,7 +760,7 @@ def gradio_main():
    .status-box textarea { font-family: monospace; font-size: 13px; }
    """
 
-    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)"
+    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)") as demo:
 
        with gr.Tabs():
            # ============================================================
@@ -777,6 +880,14 @@ def gradio_main():
                elem_classes="status-box",
            )
 
+            # Button swap on click (separate handler, like rvc-beatrice)
+            # This fires immediately so user sees Cancel even if training
+            # queues behind concurrency_limit=1
+            train_btn.click(
+                lambda: (gr.Button(visible=False), gr.Button(visible=True)),
+                outputs=[train_btn, cancel_btn],
+            )
+
            # Training generator -- yields (log, train_btn, cancel_btn, output_file)
            train_event = train_btn.click(
                train_lora_ui,
@@ -787,11 +898,12 @@ def gradio_main():
            )
 
            # After training completes, restore buttons and refresh LoRA dropdown
+            # This ensures cleanup even if the user navigated away
            def _post_training():
                return (
-                    gr.
-                    gr.
-                    gr.
+                    gr.Button(visible=True),
+                    gr.Button(visible=False),
+                    gr.Dropdown(choices=_list_lora_choices()),
                )
 
            train_event.then(
@@ -816,6 +928,7 @@ def gradio_main():
        server_name="0.0.0.0",
        server_port=7860,
        mcp_server=True,
+        css=CSS,
    )
 
 
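
For context, the captioning step in the app.py diff above writes a JSON sidecar next to each uploaded training file (audio path + ".json") containing caption, bpm, key, signature and lyrics. A rough sketch of what the librosa fallback produces follows; the file name and values here are illustrative, not from the commit.

    import json

    # Hypothetical sidecar for an uploaded file: song01.wav -> song01.wav.json
    sidecar = {
        "caption": "",     # empty when ace-server /understand was unavailable
        "bpm": 128,
        "key": "",
        "signature": "",
        "lyrics": "",
    }
    with open("song01.wav.json", "w") as fh:
        json.dump(sidecar, fh)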
train_engine.py
CHANGED
@@ -15,6 +15,8 @@ Exports:
    train_lora_generator() - Generator-based LoRA training loop
    cancel_training() - Set the cancel flag
    get_trained_loras() - List saved adapters
+    generate_audio() - Standalone inference (text -> WAV, optional LoRA)
+    tiled_vae_decode() - Tiled VAE latent-to-waveform decode
 """
 
 from __future__ import annotations
@@ -2799,3 +2801,288 @@ def get_trained_loras(adapter_dir: str) -> List[str]:
            break
 
    return sorted(set(result))
+
+
+# ============================================================================
+# TILED VAE DECODE (mirror of tiled_vae_encode)
+# ============================================================================
+
+def tiled_vae_decode(
+    vae, latents: torch.Tensor, dtype: torch.dtype,
+    chunk_frames: int = 1024, overlap_frames: int = 64,
+) -> torch.Tensor:
+    """Decode latents [B, T, C] -> waveform [B, 2, samples] using tiled VAE.
+
+    Mirrors tiled_vae_encode but in the reverse direction. Tiles along
+    the time axis of the latent to keep peak memory bounded.
+
+    Args:
+        vae: AutoencoderOobleck decoder.
+        latents: Latent tensor in [B, T, C] layout (C=64).
+        dtype: Target dtype for the output waveform.
+        chunk_frames: Number of latent frames per tile.
+        overlap_frames: Overlap frames per side for crossfade.
+
+    Returns:
+        Waveform tensor [B, 2, total_samples] in *dtype*.
+    """
+    vae_device = next(vae.parameters()).device
+    vae_dtype = vae.dtype
+
+    # Transpose to VAE convention [B, C, T]
+    lat = latents.transpose(1, 2).contiguous()
+    B, C, T = lat.shape
+
+    if T <= chunk_frames:
+        with torch.inference_mode():
+            audio = vae.decode(lat.to(vae_device, dtype=vae_dtype)).sample
+        return audio.to(dtype=dtype, device="cpu")
+
+    # Upsample factor: unknown until first decode, so we discover it.
+    stride = chunk_frames - 2 * overlap_frames
+    if stride <= 0:
+        raise ValueError(f"chunk_frames ({chunk_frames}) must be > 2*overlap ({overlap_frames})")
+
+    num_tiles = math.ceil(T / stride)
+    us_factor: Optional[float] = None
+    write_pos = 0
+    final: Optional[torch.Tensor] = None
+
+    for i in range(num_tiles):
+        core_start = i * stride
+        core_end = min(core_start + stride, T)
+        win_start = max(0, core_start - overlap_frames)
+        win_end = min(T, core_end + overlap_frames)
+
+        chunk = lat[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype)
+        with torch.inference_mode():
+            decoded = vae.decode(chunk).sample  # [B, 2, samples_chunk]
+
+        if us_factor is None:
+            us_factor = decoded.shape[-1] / chunk.shape[-1]
+            total_samples = int(round(T * us_factor))
+            final = torch.zeros(B, decoded.shape[1], total_samples, dtype=decoded.dtype, device="cpu")
+
+        trim_start = int(round((core_start - win_start) * us_factor))
+        trim_end = int(round((win_end - core_end) * us_factor))
+        end_idx = decoded.shape[-1] - trim_end if trim_end > 0 else decoded.shape[-1]
+        core = decoded[:, :, trim_start:end_idx]
+        core_len = core.shape[-1]
+        final[:, :, write_pos:write_pos + core_len] = core.cpu()
+        write_pos += core_len
+        del chunk, decoded, core
+
+    final = final[:, :, :write_pos]
+    return final.to(dtype=dtype)
+
+
+# ============================================================================
+# INFERENCE -- generate_audio()
+# ============================================================================
+
+def generate_audio(
+    caption: str,
+    checkpoint_dir: str,
+    output_path: str,
+    lyrics: str = "[Instrumental]",
+    duration: float = 10.0,
+    bpm: int = 120,
+    steps: int = 8,
+    seed: int = -1,
+    variant: str = "turbo",
+    device: str = "auto",
+    adapter_path: Optional[str] = None,
+    adapter_scale: float = 1.0,
+) -> str:
+    """Generate audio using the ACE-Step DiT pipeline (pure PyTorch, no server).
+
+    Pipeline:
+      1. Text encoder -> text_hidden_states, lyric embeddings
+      2. Load full model (DiT + condition encoder + FSQ)
+      3. Optional: inject LoRA adapter via PEFT
+      4. model.generate_audio() -- runs condition encoder, FSQ detokenizer,
+         and the flow-matching diffusion loop internally
+      5. VAE decode latents -> waveform
+      6. Save waveform as 48 kHz stereo WAV
+      7. Unload all models, free memory
+
+    Args:
+        caption: Text description of the desired music.
+        checkpoint_dir: Root directory that contains model sub-dirs
+            (e.g. ``acestep-v15-turbo/``, ``vae/``, ``Qwen3-Embedding-0.6B/``).
+        output_path: Where to write the output WAV file.
+        lyrics: Lyrics text or ``"[Instrumental]"`` for no vocals.
+        duration: Desired audio length in seconds.
+        bpm: Beats per minute (metadata hint for the model).
+        steps: Number of diffusion steps (8 for turbo, 50 for base/SFT).
+        seed: RNG seed (-1 = random).
+        variant: Model variant name (e.g. ``"turbo"``, ``"base"``).
+        device: ``"auto"``, ``"cpu"``, ``"cuda:0"``, etc.
+        adapter_path: Path to a PEFT LoRA adapter directory (optional).
+        adapter_scale: Scaling factor applied to the adapter.
+
+    Returns:
+        The *output_path* string (for convenience).
+    """
+    import numpy as np
+
+    # ------------------------------------------------------------------
+    # 0. Device / dtype
+    # ------------------------------------------------------------------
+    device = detect_device(device)
+    dtype = select_dtype(device)
+    logger.info(
+        "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs",
+        device, dtype, variant, steps, duration,
+    )
+
+    # Resolve seed
+    if seed < 0:
+        seed = random.randint(0, 2**31 - 1)
+    logger.info("Using seed=%d", seed)
+
+    # ------------------------------------------------------------------
+    # 1. Text encoder -- encode caption and lyrics
+    # ------------------------------------------------------------------
+    logger.info("Loading text encoder...")
+    tokenizer, text_encoder = load_text_encoder(checkpoint_dir, device)
+
+    text_hs, text_mask = encode_text(text_encoder, tokenizer, caption, device, dtype)
+    lyric_hs, lyric_mask = encode_lyrics(text_encoder, tokenizer, lyrics, device, dtype)
+
+    # Free text encoder -- no longer needed
+    unload_models(text_encoder)
+    del text_encoder, tokenizer
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("Text encoder unloaded.")
+
+    # ------------------------------------------------------------------
+    # 2. Load full model (DiT + CondEncoder + FSQ tokenizer/detokenizer)
+    # ------------------------------------------------------------------
+    logger.info("Loading ACE-Step model (%s)...", variant)
+    model = load_model_for_training(checkpoint_dir, variant=variant, device=device)
+    model = model.to(dtype=dtype)
+    model.eval()
+
+    # ------------------------------------------------------------------
+    # 3. Optional: inject LoRA adapter
+    # ------------------------------------------------------------------
+    if adapter_path:
+        logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale)
+        from peft import PeftModel
+
+        decoder = model.decoder if hasattr(model, "decoder") else model
+        # Unwrap any wrappers
+        while hasattr(decoder, "_forward_module"):
+            decoder = decoder._forward_module
+        if hasattr(decoder, "base_model"):
+            bm = decoder.base_model
+            decoder = bm.model if hasattr(bm, "model") else bm
+        if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
+            decoder = decoder.model
+
+        model.decoder = PeftModel.from_pretrained(
+            decoder, adapter_path, is_trainable=False,
+        )
+        # Apply adapter scale if not 1.0
+        if abs(adapter_scale - 1.0) > 1e-6:
+            for name, module in model.decoder.named_modules():
+                if hasattr(module, "scaling"):
+                    for key in module.scaling:
+                        module.scaling[key] = adapter_scale
+        model.decoder.eval()
+        logger.info("LoRA adapter applied.")
+
+    # ------------------------------------------------------------------
+    # 4. Prepare inputs for model.generate_audio()
+    # ------------------------------------------------------------------
+    # Latent frame rate is 25 Hz
+    LATENT_HZ = 25
+    latent_length = int(duration * LATENT_HZ)
+
+    # Load silence latent for context building
+    silence_latent = load_silence_latent(checkpoint_dir, device, variant)
+    # Ensure silence latent covers the required length
+    if silence_latent.shape[1] < latent_length:
+        repeats = math.ceil(latent_length / silence_latent.shape[1])
+        silence_latent = silence_latent.repeat(1, repeats, 1)
+    silence_latent = silence_latent[:, :latent_length, :].to(device=device, dtype=dtype)
+
+    # Build source latents and masks for text2music mode (all silence, all-ones mask)
+    src_latents = silence_latent[:1, :latent_length, :]
+    chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
+    is_covers = torch.zeros(1, device=device, dtype=torch.long)
+
+    # Dummy timbre reference (single silence frame -> no timbre conditioning)
+    refer_audio = torch.zeros(1, 1, 64, device=device, dtype=dtype)
+    refer_order = torch.zeros(1, device=device, dtype=torch.long)
+
+    # Shift schedule: turbo uses 3.0, base/sft uses 1.0
+    shift = 3.0 if "turbo" in variant else 1.0
+
+    # ------------------------------------------------------------------
+    # 5. Run diffusion (model.generate_audio handles everything internally)
+    # ------------------------------------------------------------------
+    logger.info("Running diffusion (%d steps, shift=%.1f)...", steps, shift)
+    with torch.no_grad():
+        result = model.generate_audio(
+            text_hidden_states=text_hs.to(device=device, dtype=dtype),
+            text_attention_mask=text_mask.to(device=device, dtype=dtype),
+            lyric_hidden_states=lyric_hs.to(device=device, dtype=dtype),
+            lyric_attention_mask=lyric_mask.to(device=device, dtype=dtype),
+            refer_audio_acoustic_hidden_states_packed=refer_audio,
+            refer_audio_order_mask=refer_order,
+            src_latents=src_latents,
+            chunk_masks=chunk_masks,
+            is_covers=is_covers,
+            silence_latent=silence_latent,
+            seed=seed,
+            fix_nfe=steps,
+            shift=shift,
+        )
+
+    target_latents = result["target_latents"]  # [1, T, 64]
+    time_costs = result.get("time_costs", {})
+    logger.info("Diffusion done. Time costs: %s", time_costs)
+
+    # Free model weights -- keep latents on CPU
+    target_latents = target_latents.cpu().to(dtype)
+    unload_models(model)
+    del model, silence_latent, src_latents, chunk_masks
+    del text_hs, text_mask, lyric_hs, lyric_mask
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("DiT model unloaded.")
+
+    # ------------------------------------------------------------------
+    # 6. VAE decode latents -> waveform
+    # ------------------------------------------------------------------
+    logger.info("Loading VAE decoder...")
+    vae = load_vae(checkpoint_dir, device)
+
+    logger.info("Decoding latents -> waveform (tiled)...")
+    waveform = tiled_vae_decode(vae, target_latents.to(device), dtype)  # [1, 2, samples]
+
+    unload_models(vae)
+    del vae, target_latents
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("VAE unloaded.")
+
+    # ------------------------------------------------------------------
+    # 7. Save as WAV (48 kHz stereo)
+    # ------------------------------------------------------------------
+    audio_np = waveform[0].float().clamp(-1.0, 1.0).cpu().numpy()  # [2, samples]
+
+    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
+    try:
+        import soundfile as sf
+        # soundfile expects [samples, channels]
+        sf.write(output_path, audio_np.T, TARGET_SR, subtype="PCM_16")
+    except ImportError:
+        import torchaudio
+        torchaudio.save(output_path, torch.from_numpy(audio_np), TARGET_SR)
+
+    logger.info("Audio saved to %s (%.1fs @ %d Hz)", output_path, duration, TARGET_SR)
+    return output_path
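
As a usage sketch of the new standalone entry point: the call below exercises generate_audio() (and, internally, tiled_vae_decode()) as declared in the diff, but the paths and caption are assumptions for illustration, not part of the commit.

    from train_engine import generate_audio

    wav_path = generate_audio(
        caption="warm lo-fi hip hop beat with dusty Rhodes chords",
        checkpoint_dir="/app/models",   # hypothetical root holding acestep-v15-turbo/, vae/, ...
        output_path="/tmp/out.wav",
        lyrics="[Instrumental]",
        duration=20.0,
        steps=8,                        # 8 for the turbo variant, 50 for base/SFT
        seed=42,
        adapter_path=None,              # or a trained LoRA adapter directory
    )
    print("Wrote", wav_path)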