Nekochu committed on
Commit ff239f5 · 1 Parent(s): bc97006

major update: PyTorch inference, Gradio 6, session isolation, /understand captioning


- generate_audio(): full PyTorch inference pipeline (no ace-server needed)
- tiled_vae_decode(): memory-bounded VAE decoding
- Gradio 6 migration (gr.update -> component constructors; see the sketch after this list)
- Session isolation (random suffix on LoRA names)
- /understand captioning before training (falls back to librosa)
- Active tab detection pattern from rvc-beatrice
- Gradio API fix (tempfile for adapter download)
- dtype fix for mixed precision inference
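The Gradio 6 migration below follows one mechanical rule: event handlers stop yielding gr.update(...) patches and instead yield fresh component instances carrying only the changed properties. A minimal before/after sketch of that pattern (the handler and button names here are illustrative, not taken from app.py):

import gradio as gr

# Gradio 5.x style: property patches via gr.update()
def run_v5():
    yield gr.update(visible=False), gr.update(visible=True)

# Gradio 6.x style: yield component constructors instead
def run_v6():
    yield gr.Button(visible=False), gr.Button(visible=True)

with gr.Blocks() as demo:
    start_btn = gr.Button("Start")
    cancel_btn = gr.Button("Cancel", visible=False)
    start_btn.click(run_v6, outputs=[start_btn, cancel_btn])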

Files changed (3)
  1. Dockerfile +1 -1
  2. app.py +133 -20
  3. train_engine.py +287 -0
Dockerfile CHANGED
@@ -75,7 +75,7 @@ RUN curl -fL --retry 3 --retry-delay 5 -o /app/models/vae-BF16.gguf \
 
 # Install Python deps for Gradio UI + training
 RUN pip3 install --no-cache-dir --extra-index-url https://download.pytorch.org/whl/cpu \
-    "gradio[mcp]==5.29.0" requests torch safetensors \
+    "gradio[mcp]>=6.0.0,<7.0.0" requests torch safetensors \
     "transformers>=4.51.0,<4.58.0" peft>=0.18.0 \
     loguru "torchaudio==2.4.0" "diffusers==0.30.3" lightning numpy tensorboard soundfile \
     einops vector_quantize_pytorch librosa mutagen
app.py CHANGED
@@ -5,9 +5,12 @@ import sys
 import time
 import json
 import argparse
+import base64
 import tempfile
 import subprocess
 import shutil
+import string
+import random
 import requests
 import logging
 
@@ -97,6 +100,61 @@ def _fetch_result(job_id, timeout=60):
     return r
 
 
+def _caption_via_understand(audio_path, timeout=120):
+    """Call ace-server /understand to get a rich caption for an audio file.
+
+    Returns a dict with caption, bpm, key, signature, lyrics on success,
+    or None on failure (caller should fall back to librosa).
+    """
+    fname = os.path.basename(audio_path)
+    try:
+        with open(audio_path, "rb") as f:
+            audio_b64 = base64.b64encode(f.read()).decode("ascii")
+    except Exception as exc:
+        logger.warning("[Caption] %s: failed to read file: %s", fname, exc)
+        return None
+
+    # Submit
+    try:
+        r = requests.post(
+            f"{ACE_SERVER}/understand",
+            json={"audio": audio_b64},
+            timeout=30,
+        )
+        if r.status_code != 200:
+            logger.warning("[Caption] %s: /understand returned %d", fname, r.status_code)
+            return None
+        job_id = r.json().get("id")
+        if not job_id:
+            logger.warning("[Caption] %s: /understand returned no job id", fname)
+            return None
+    except Exception as exc:
+        logger.warning("[Caption] %s: /understand submit failed: %s", fname, exc)
+        return None
+
+    # Poll until done
+    status, _ = _poll_job(job_id, timeout=timeout)
+    if status != "done":
+        logger.warning("[Caption] %s: /understand job %s -> %s", fname, job_id, status)
+        return None
+
+    # Fetch result
+    try:
+        r = _fetch_result(job_id, timeout=30)
+        if r.status_code != 200:
+            logger.warning("[Caption] %s: /understand result fetch failed: %d", fname, r.status_code)
+            return None
+        data = r.json()
+        # The result should contain caption, bpm, key, signature, lyrics
+        if isinstance(data, dict) and data.get("caption"):
+            return data
+        logger.warning("[Caption] %s: /understand returned no caption field", fname)
+        return None
+    except Exception as exc:
+        logger.warning("[Caption] %s: /understand result parse failed: %s", fname, exc)
+        return None
+
+
 def _run_pipeline(caption, lyrics, bpm, duration, seed, steps, output_format,
                   adapter=None, lm_model=None, progress_cb=None):
     """Run full LM -> synth pipeline. Returns (audio_path, status_msg) or raises."""
@@ -460,17 +518,20 @@ def gradio_main():
         # -- Validation --
         if not audio_files:
             _log("[FAIL] No audio files uploaded.")
-            yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
+            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
             return
 
         if len(audio_files) > MAX_AUDIO_FILES:
            _log(f"[FAIL] Too many files ({len(audio_files)}). Max: {MAX_AUDIO_FILES}")
-            yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
+            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
            return
 
         lora_name = (lora_name or "").strip() or "my-lora"
         # Sanitize: alphanumeric, dash, underscore only
         lora_name = "".join(c if c.isalnum() or c in "-_" else "-" for c in lora_name)
+        # Append random suffix to prevent naming collisions between users
+        suffix = "".join(random.choices(string.ascii_lowercase + string.digits, k=4))
+        lora_name = f"{lora_name}-{suffix}"
 
         epochs = max(1, min(int(epochs), 10))
         lr = float(lr)
@@ -485,7 +546,7 @@ def gradio_main():
 
         # Copy uploaded audio files + check total duration
         _log(f"[INFO] Preparing {len(audio_files)} audio files...")
-        yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+        yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         import librosa as _lr
         total_dur = 0.0
@@ -530,18 +591,49 @@ def gradio_main():
 
         _log(f"[INFO] LoRA: '{lora_name}' | Files: {len(audio_files)} | "
             f"Epochs: {epochs} | LR: {lr} | Rank: {rank}")
-        yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+        yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
+
+        # Caption each audio file via ace-server /understand BEFORE stopping it
+        if _server_ok():
+            _log("[INFO] Captioning audio via ace-server /understand...")
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
+            for audio_fname in sorted(os.listdir(audio_dir)):
+                full_path = os.path.join(audio_dir, audio_fname)
+                if not os.path.isfile(full_path) or audio_fname.endswith(".json"):
+                    continue
+                caption_json_path = full_path + ".json"
+                caption_data = _caption_via_understand(full_path, timeout=120)
+                if caption_data:
+                    _log(f"[Caption] {audio_fname}: using ace-server /understand")
+                    with open(caption_json_path, "w") as cj:
+                        json.dump(caption_data, cj)
+                else:
+                    # Fallback to librosa for basic metadata
+                    _log(f"[Caption] {audio_fname}: fallback to librosa")
+                    try:
+                        y_cap, sr_cap = _lr.load(full_path, sr=None, mono=True)
+                        tempo, _ = _lr.beat.beat_track(y=y_cap, sr=sr_cap)
+                        bpm_val = float(tempo) if hasattr(tempo, '__float__') else float(tempo[0])
+                        fallback = {"caption": "", "bpm": round(bpm_val), "key": "", "signature": "", "lyrics": ""}
+                        with open(caption_json_path, "w") as cj:
+                            json.dump(fallback, cj)
+                    except Exception as cap_exc:
+                        _log(f"[Caption] {audio_fname}: librosa fallback also failed: {cap_exc}")
+                yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
+        else:
+            _log("[INFO] ace-server not running, skipping /understand captioning")
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         # Stop ace-server before training (frees memory)
         _log("[INFO] Stopping ace-server for training...")
-        yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+        yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
         _stop_ace_server()
         _gc.collect()
 
         try:
             # -- Phase 1: Preprocessing --
             _log("[Step 1/2] Preprocessing audio...")
-            yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
             preprocessed_dir = os.path.join(work_dir, "preprocessed_tensors")
 
@@ -558,24 +650,24 @@ def gradio_main():
                 progress_callback=preprocess_progress,
                 cancel_check=lambda: False,
             )
-            yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
             processed = result.get("processed", 0)
             failed = result.get("failed", 0)
            total = result.get("total", 0)
             _log(f"[OK] Preprocessed: {processed}/{total} (failed: {failed})")
-            yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
             if processed == 0:
                 _log("[FAIL] No files preprocessed successfully. Cannot train.")
-                yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
+                yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
                 return
 
             _gc.collect()
 
             # -- Phase 2: Training --
             _log("[Step 2/2] Training LoRA...")
-            yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
             for msg in train_lora_generator(
                 dataset_dir=preprocessed_dir,
@@ -605,24 +697,24 @@ def gradio_main():
                     break
 
                 _log(msg)
-                yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+                yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
                 if msg.strip() == "[DONE]":
                     break
 
             _log(f"[INFO] Total time: {time.time() - train_start:.0f}s")
-            yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
 
         except Exception as exc:
             _log(f"[FAIL] Training error: {exc}")
             import traceback
             _log(traceback.format_exc())
-            yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
+            yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
 
         finally:
             # Always restart ace-server
             _log("[INFO] Restarting ace-server...")
-            yield _log_text(), gr.update(visible=False), gr.update(visible=True), gr.update()
+            yield _log_text(), gr.Button(visible=False), gr.Button(visible=True), gr.File()
             _gc.collect()
             ok = _start_ace_server()
             if ok:
@@ -631,9 +723,20 @@ def gradio_main():
                 _log("[WARN] ace-server may not have restarted -- check logs")
             adapter_safetensors = os.path.join(adapter_out, "adapter_model.safetensors")
             if os.path.isfile(adapter_safetensors):
-                yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update(value=adapter_safetensors, visible=True)
+                # Copy to a temp file so Gradio doesn't try to validate /app paths
+                # (avoids InvalidPathError: "Cannot move /app to the gradio cache dir
+                # because it was not uploaded by a user")
+                tmp_out = tempfile.NamedTemporaryFile(
+                    suffix=".safetensors",
+                    prefix=f"{lora_name}_",
+                    delete=False,
+                )
+                tmp_out.close()
+                shutil.copy2(adapter_safetensors, tmp_out.name)
+                _log(f"[OK] LoRA saved: {lora_name}")
+                yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File(value=tmp_out.name, visible=True)
             else:
-                yield _log_text(), gr.update(visible=True), gr.update(visible=False), gr.update()
+                yield _log_text(), gr.Button(visible=True), gr.Button(visible=False), gr.File()
 
     # -- Cancel handler --
     def _on_cancel():
@@ -657,7 +760,7 @@ def gradio_main():
     .status-box textarea { font-family: monospace; font-size: 13px; }
     """
 
-    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)", css=CSS) as demo:
+    with gr.Blocks(title="ACE-Step 1.5 XL (CPU)") as demo:
 
        with gr.Tabs():
            # ============================================================
@@ -777,6 +880,14 @@ def gradio_main():
                     elem_classes="status-box",
                 )
 
+                # Button swap on click (separate handler, like rvc-beatrice)
+                # This fires immediately so user sees Cancel even if training
+                # queues behind concurrency_limit=1
+                train_btn.click(
+                    lambda: (gr.Button(visible=False), gr.Button(visible=True)),
+                    outputs=[train_btn, cancel_btn],
+                )
+
                # Training generator -- yields (log, train_btn, cancel_btn, output_file)
                train_event = train_btn.click(
                    train_lora_ui,
@@ -787,11 +898,12 @@ def gradio_main():
                 )
 
                 # After training completes, restore buttons and refresh LoRA dropdown
+                # This ensures cleanup even if the user navigated away
                 def _post_training():
                     return (
-                        gr.update(visible=True),
-                        gr.update(visible=False),
-                        gr.update(choices=_list_lora_choices()),
+                        gr.Button(visible=True),
+                        gr.Button(visible=False),
+                        gr.Dropdown(choices=_list_lora_choices()),
                     )
 
                 train_event.then(
@@ -816,6 +928,7 @@ def gradio_main():
         server_name="0.0.0.0",
         server_port=7860,
         mcp_server=True,
+        css=CSS,
     )
 
 
train_engine.py CHANGED
@@ -15,6 +15,8 @@ Exports:
     train_lora_generator() - Generator-based LoRA training loop
     cancel_training() - Set the cancel flag
     get_trained_loras() - List saved adapters
+    generate_audio() - Standalone inference (text -> WAV, optional LoRA)
+    tiled_vae_decode() - Tiled VAE latent-to-waveform decode
 """
 
 from __future__ import annotations
@@ -2799,3 +2801,288 @@ def get_trained_loras(adapter_dir: str) -> List[str]:
             break
 
     return sorted(set(result))
+
+
+# ============================================================================
+# TILED VAE DECODE (mirror of tiled_vae_encode)
+# ============================================================================
+
+def tiled_vae_decode(
+    vae, latents: torch.Tensor, dtype: torch.dtype,
+    chunk_frames: int = 1024, overlap_frames: int = 64,
+) -> torch.Tensor:
+    """Decode latents [B, T, C] -> waveform [B, 2, samples] using tiled VAE.
+
+    Mirrors tiled_vae_encode but in the reverse direction. Tiles along
+    the time axis of the latent to keep peak memory bounded.
+
+    Args:
+        vae: AutoencoderOobleck decoder.
+        latents: Latent tensor in [B, T, C] layout (C=64).
+        dtype: Target dtype for the output waveform.
+        chunk_frames: Number of latent frames per tile.
+        overlap_frames: Overlap frames per side for crossfade.
+
+    Returns:
+        Waveform tensor [B, 2, total_samples] in *dtype*.
+    """
+    vae_device = next(vae.parameters()).device
+    vae_dtype = vae.dtype
+
+    # Transpose to VAE convention [B, C, T]
+    lat = latents.transpose(1, 2).contiguous()
+    B, C, T = lat.shape
+
+    if T <= chunk_frames:
+        with torch.inference_mode():
+            audio = vae.decode(lat.to(vae_device, dtype=vae_dtype)).sample
+        return audio.to(dtype=dtype, device="cpu")
+
+    # Upsample factor: unknown until first decode, so we discover it.
+    stride = chunk_frames - 2 * overlap_frames
+    if stride <= 0:
+        raise ValueError(f"chunk_frames ({chunk_frames}) must be > 2*overlap ({overlap_frames})")
+
+    num_tiles = math.ceil(T / stride)
+    us_factor: Optional[float] = None
+    write_pos = 0
+    final: Optional[torch.Tensor] = None
+
+    for i in range(num_tiles):
+        core_start = i * stride
+        core_end = min(core_start + stride, T)
+        win_start = max(0, core_start - overlap_frames)
+        win_end = min(T, core_end + overlap_frames)
+
+        chunk = lat[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype)
+        with torch.inference_mode():
+            decoded = vae.decode(chunk).sample  # [B, 2, samples_chunk]
+
+        if us_factor is None:
+            us_factor = decoded.shape[-1] / chunk.shape[-1]
+            total_samples = int(round(T * us_factor))
+            final = torch.zeros(B, decoded.shape[1], total_samples, dtype=decoded.dtype, device="cpu")
+
+        trim_start = int(round((core_start - win_start) * us_factor))
+        trim_end = int(round((win_end - core_end) * us_factor))
+        end_idx = decoded.shape[-1] - trim_end if trim_end > 0 else decoded.shape[-1]
+        core = decoded[:, :, trim_start:end_idx]
+        core_len = core.shape[-1]
+        final[:, :, write_pos:write_pos + core_len] = core.cpu()
+        write_pos += core_len
+        del chunk, decoded, core
+
+    final = final[:, :, :write_pos]
+    return final.to(dtype=dtype)
+
+
+# ============================================================================
+# INFERENCE -- generate_audio()
+# ============================================================================
+
+def generate_audio(
+    caption: str,
+    checkpoint_dir: str,
+    output_path: str,
+    lyrics: str = "[Instrumental]",
+    duration: float = 10.0,
+    bpm: int = 120,
+    steps: int = 8,
+    seed: int = -1,
+    variant: str = "turbo",
+    device: str = "auto",
+    adapter_path: Optional[str] = None,
+    adapter_scale: float = 1.0,
+) -> str:
+    """Generate audio using the ACE-Step DiT pipeline (pure PyTorch, no server).
+
+    Pipeline:
+      1. Text encoder -> text_hidden_states, lyric embeddings
+      2. Load full model (DiT + condition encoder + FSQ)
+      3. Optional: inject LoRA adapter via PEFT
+      4. model.generate_audio() -- runs condition encoder, FSQ detokenizer,
+         and the flow-matching diffusion loop internally
+      5. VAE decode latents -> waveform
+      6. Save waveform as 48 kHz stereo WAV
+      7. Unload all models, free memory
+
+    Args:
+        caption: Text description of the desired music.
+        checkpoint_dir: Root directory that contains model sub-dirs
+            (e.g. ``acestep-v15-turbo/``, ``vae/``, ``Qwen3-Embedding-0.6B/``).
+        output_path: Where to write the output WAV file.
+        lyrics: Lyrics text or ``"[Instrumental]"`` for no vocals.
+        duration: Desired audio length in seconds.
+        bpm: Beats per minute (metadata hint for the model).
+        steps: Number of diffusion steps (8 for turbo, 50 for base/SFT).
+        seed: RNG seed (-1 = random).
+        variant: Model variant name (e.g. ``"turbo"``, ``"base"``).
+        device: ``"auto"``, ``"cpu"``, ``"cuda:0"``, etc.
+        adapter_path: Path to a PEFT LoRA adapter directory (optional).
+        adapter_scale: Scaling factor applied to the adapter.
+
+    Returns:
+        The *output_path* string (for convenience).
+    """
+    import numpy as np
+
+    # ------------------------------------------------------------------
+    # 0. Device / dtype
+    # ------------------------------------------------------------------
+    device = detect_device(device)
+    dtype = select_dtype(device)
+    logger.info(
+        "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs",
+        device, dtype, variant, steps, duration,
+    )
+
+    # Resolve seed
+    if seed < 0:
+        seed = random.randint(0, 2**31 - 1)
+    logger.info("Using seed=%d", seed)
+
+    # ------------------------------------------------------------------
+    # 1. Text encoder -- encode caption and lyrics
+    # ------------------------------------------------------------------
+    logger.info("Loading text encoder...")
+    tokenizer, text_encoder = load_text_encoder(checkpoint_dir, device)
+
+    text_hs, text_mask = encode_text(text_encoder, tokenizer, caption, device, dtype)
+    lyric_hs, lyric_mask = encode_lyrics(text_encoder, tokenizer, lyrics, device, dtype)
+
+    # Free text encoder -- no longer needed
+    unload_models(text_encoder)
+    del text_encoder, tokenizer
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("Text encoder unloaded.")
+
+    # ------------------------------------------------------------------
+    # 2. Load full model (DiT + CondEncoder + FSQ tokenizer/detokenizer)
+    # ------------------------------------------------------------------
+    logger.info("Loading ACE-Step model (%s)...", variant)
+    model = load_model_for_training(checkpoint_dir, variant=variant, device=device)
+    model = model.to(dtype=dtype)
+    model.eval()
+
+    # ------------------------------------------------------------------
+    # 3. Optional: inject LoRA adapter
+    # ------------------------------------------------------------------
+    if adapter_path:
+        logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale)
+        from peft import PeftModel
+
+        decoder = model.decoder if hasattr(model, "decoder") else model
+        # Unwrap any wrappers
+        while hasattr(decoder, "_forward_module"):
+            decoder = decoder._forward_module
+        if hasattr(decoder, "base_model"):
+            bm = decoder.base_model
+            decoder = bm.model if hasattr(bm, "model") else bm
+        if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module):
+            decoder = decoder.model
+
+        model.decoder = PeftModel.from_pretrained(
+            decoder, adapter_path, is_trainable=False,
+        )
+        # Apply adapter scale if not 1.0
+        if abs(adapter_scale - 1.0) > 1e-6:
+            for name, module in model.decoder.named_modules():
+                if hasattr(module, "scaling"):
+                    for key in module.scaling:
+                        module.scaling[key] = adapter_scale
+        model.decoder.eval()
+        logger.info("LoRA adapter applied.")
+
+    # ------------------------------------------------------------------
+    # 4. Prepare inputs for model.generate_audio()
+    # ------------------------------------------------------------------
+    # Latent frame rate is 25 Hz
+    LATENT_HZ = 25
+    latent_length = int(duration * LATENT_HZ)
+
+    # Load silence latent for context building
+    silence_latent = load_silence_latent(checkpoint_dir, device, variant)
+    # Ensure silence latent covers the required length
+    if silence_latent.shape[1] < latent_length:
+        repeats = math.ceil(latent_length / silence_latent.shape[1])
+        silence_latent = silence_latent.repeat(1, repeats, 1)
+    silence_latent = silence_latent[:, :latent_length, :].to(device=device, dtype=dtype)
+
+    # Build source latents and masks for text2music mode (all silence, all-ones mask)
+    src_latents = silence_latent[:1, :latent_length, :]
+    chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
+    is_covers = torch.zeros(1, device=device, dtype=torch.long)
+
+    # Dummy timbre reference (single silence frame -> no timbre conditioning)
+    refer_audio = torch.zeros(1, 1, 64, device=device, dtype=dtype)
+    refer_order = torch.zeros(1, device=device, dtype=torch.long)
+
+    # Shift schedule: turbo uses 3.0, base/sft uses 1.0
+    shift = 3.0 if "turbo" in variant else 1.0
+
+    # ------------------------------------------------------------------
+    # 5. Run diffusion (model.generate_audio handles everything internally)
+    # ------------------------------------------------------------------
+    logger.info("Running diffusion (%d steps, shift=%.1f)...", steps, shift)
+    with torch.no_grad():
+        result = model.generate_audio(
+            text_hidden_states=text_hs.to(device=device, dtype=dtype),
+            text_attention_mask=text_mask.to(device=device, dtype=dtype),
+            lyric_hidden_states=lyric_hs.to(device=device, dtype=dtype),
+            lyric_attention_mask=lyric_mask.to(device=device, dtype=dtype),
+            refer_audio_acoustic_hidden_states_packed=refer_audio,
+            refer_audio_order_mask=refer_order,
+            src_latents=src_latents,
+            chunk_masks=chunk_masks,
+            is_covers=is_covers,
+            silence_latent=silence_latent,
+            seed=seed,
+            fix_nfe=steps,
+            shift=shift,
+        )
+
+    target_latents = result["target_latents"]  # [1, T, 64]
+    time_costs = result.get("time_costs", {})
+    logger.info("Diffusion done. Time costs: %s", time_costs)
+
+    # Free model weights -- keep latents on CPU
+    target_latents = target_latents.cpu().to(dtype)
+    unload_models(model)
+    del model, silence_latent, src_latents, chunk_masks
+    del text_hs, text_mask, lyric_hs, lyric_mask
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("DiT model unloaded.")
+
+    # ------------------------------------------------------------------
+    # 6. VAE decode latents -> waveform
+    # ------------------------------------------------------------------
+    logger.info("Loading VAE decoder...")
+    vae = load_vae(checkpoint_dir, device)
+
+    logger.info("Decoding latents -> waveform (tiled)...")
+    waveform = tiled_vae_decode(vae, target_latents.to(device), dtype)  # [1, 2, samples]
+
+    unload_models(vae)
+    del vae, target_latents
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("VAE unloaded.")
+
+    # ------------------------------------------------------------------
+    # 7. Save as WAV (48 kHz stereo)
+    # ------------------------------------------------------------------
+    audio_np = waveform[0].float().clamp(-1.0, 1.0).cpu().numpy()  # [2, samples]
+
+    os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True)
+    try:
+        import soundfile as sf
+        # soundfile expects [samples, channels]
+        sf.write(output_path, audio_np.T, TARGET_SR, subtype="PCM_16")
+    except ImportError:
+        import torchaudio
+        torchaudio.save(output_path, torch.from_numpy(audio_np), TARGET_SR)
+
+    logger.info("Audio saved to %s (%.1fs @ %d Hz)", output_path, duration, TARGET_SR)
+    return output_path
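For reference, the new train_engine.generate_audio() entry point can be driven from a short standalone script; a minimal usage sketch following the signature above (the checkpoint root, adapter directory, and output path are placeholder values, not paths shipped with this repo):

from train_engine import generate_audio

# Placeholder paths -- point these at your own checkpoint root and LoRA adapter.
wav_path = generate_audio(
    caption="lo-fi hip hop, mellow piano, vinyl crackle",
    checkpoint_dir="/path/to/checkpoints",   # root containing the model sub-dirs named in the docstring
    output_path="/tmp/ace_step_demo.wav",
    lyrics="[Instrumental]",
    duration=20.0,
    steps=8,                                 # 8 for the turbo variant, 50 for base/SFT
    seed=42,
    variant="turbo",
    device="cpu",
    adapter_path="/path/to/my-lora",         # optional PEFT adapter directory
    adapter_scale=1.0,
)
print("wrote", wav_path)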