| """ | |
| Standalone ACE-Step LoRA Training Engine (CPU + GPU). | |
| Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained | |
| module. No external Side-Step dependency required. | |
| Source commits: | |
| Side-Step: koda-dernet/Side-Step @ ecd13bd (2026-04-19) | |
| acestep.cpp: ServeurpersoCom/acestep.cpp @ 36e4db1 (prompt/understand format) | |
| ACE-Step: ace-step/ACE-Step-1.5 (model architecture, training_v2) | |
| Auto-detects GPU (CUDA > MPS > CPU) and uses it when available, | |
| falling back to CPU. bfloat16 is used on GPU; float32 is forced | |
| on CPU (bfloat16 deadlocks on CPU -- known PyTorch bug). | |
| Exports: | |
| detect_device() - Auto-detect best available device | |
| select_dtype() - Pick dtype for a device | |
| preprocess_audio() - 2-pass sequential preprocessing | |
| train_lora_generator() - Generator-based LoRA training loop | |
| cancel_training() - Set the cancel flag | |
| get_trained_loras() - List saved adapters | |
| generate_audio() - Standalone inference (text -> WAV, optional LoRA) | |
| tiled_vae_decode() - Tiled VAE latent-to-waveform decode | |
| understand_audio() - Reverse pipeline (audio -> caption + lyrics) | |
| """ | |
| from __future__ import annotations | |
| import gc | |
| import json | |
| import logging | |
| import math | |
| import os | |
| import random | |
| import re | |
| import shutil | |
| import sys | |
| import tempfile | |
import threading
import time
| import types | |
| import unicodedata | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Callable, Dict, Generator, List, Optional, Tuple | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.optim import AdamW | |
| from torch.optim.lr_scheduler import ( | |
| CosineAnnealingLR, | |
| LinearLR, | |
| SequentialLR, | |
| ) | |
| from torch.utils.data import DataLoader, Dataset | |
| logger = logging.getLogger(__name__) | |
| # --------------------------------------------------------------------------- | |
| # Configurable caps (edit these at the top of the file) | |
| # --------------------------------------------------------------------------- | |
| MAX_AUDIO_DURATION = 240.0 # seconds, cap per audio file | |
| MAX_TRAINING_TIME = 28800 # 8 hours hard timeout | |
| TARGET_SR = 48000 | |
| LATENT_HZ = 25 # latent frames per second (48000 / 1920) | |
| CHUNK_LATENT_MIN = 20 * LATENT_HZ # 500 frames (20s) | |
| CHUNK_LATENT_TARGET = 30 * LATENT_HZ # 750 frames (30s) | |
| CHUNK_LATENT_MAX = 40 * LATENT_HZ # 1000 frames (40s) | |
| AUDIO_EXTENSIONS = frozenset({".wav", ".mp3", ".flac", ".ogg", ".opus", ".m4a", ".aac"}) | |
| # bfloat16 deadlocks on CPU (known PyTorch bug) -- force float32 | |
| CPU_DTYPE = torch.float32 | |
_training_cancel = threading.Event()
| def cancel_training() -> None: | |
| _training_cancel.set() | |
| # ============================================================================ | |
| # DEVICE DETECTION & DTYPE SELECTION | |
| # ============================================================================ | |
| def detect_device(requested: str = "auto") -> str: | |
| """Return the best available device string. | |
| Priority: CUDA (best GPU by VRAM) > MPS (Apple Silicon) > CPU. | |
| Pass an explicit device string (e.g. "cuda:0", "cpu") to skip | |
| auto-detection. | |
| """ | |
| if requested != "auto": | |
| return requested | |
| if torch.cuda.is_available(): | |
| # Pick the GPU with the most VRAM when multiple are present | |
| count = torch.cuda.device_count() | |
| if count <= 1: | |
| best_idx = 0 | |
| else: | |
| best_idx, best_mem = 0, 0 | |
| for i in range(count): | |
| mem = torch.cuda.get_device_properties(i).total_memory | |
| if mem > best_mem: | |
| best_idx, best_mem = i, mem | |
| if best_idx != 0: | |
| logger.info( | |
| "Multiple CUDA devices (%d). Selected cuda:%d (%s, %.0f MiB).", | |
| count, best_idx, | |
| torch.cuda.get_device_name(best_idx), | |
| best_mem / (1024 ** 2), | |
| ) | |
| device = f"cuda:{best_idx}" | |
| logger.info("Auto-detected device: %s (%s)", device, torch.cuda.get_device_name(best_idx)) | |
| return device | |
| if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available(): | |
| logger.info("Auto-detected device: mps (Apple Silicon)") | |
| return "mps" | |
| logger.info("Auto-detected device: cpu") | |
| return "cpu" | |
| def select_dtype(device: str) -> torch.dtype: | |
| """Select the appropriate training dtype for *device*. | |
| GPU: bfloat16 if supported, else float16. | |
| CPU: MUST stay float32 (bfloat16 deadlocks on CPU). | |
| """ | |
| dev_type = device.split(":")[0] | |
| if dev_type == "cpu": | |
| return CPU_DTYPE # always float32 | |
| if dev_type == "cuda": | |
| # Prefer bfloat16 on Ampere+ (compute capability >= 8.0) | |
| try: | |
| idx = int(device.split(":")[1]) if ":" in device else 0 | |
| props = torch.cuda.get_device_properties(idx) | |
| if props.major >= 8: | |
| return torch.bfloat16 | |
| except Exception: | |
| pass | |
| return torch.float16 | |
| # MPS / other accelerators -- float32 is safest | |
| if dev_type == "mps": | |
| return torch.float32 | |
| return CPU_DTYPE | |
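# Example (sketch): pick a device/dtype pair once and reuse both for every
# model load in this module.
#
#     device = detect_device("auto")   # e.g. "cuda:0", "mps" or "cpu"
#     dtype = select_dtype(device)     # bfloat16 on Ampere+, fp16 on older CUDA, fp32 on CPU/MPS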
| def _cuda_sync(device: str) -> None: | |
| """Synchronize CUDA if the device is a CUDA device (no-op otherwise).""" | |
| if device.startswith("cuda") and torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| def _clear_gpu_cache(device: str) -> None: | |
| """Free cached GPU memory for the given device type.""" | |
| dev_type = device.split(":")[0] | |
| if dev_type == "cuda" and torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| elif dev_type == "mps" and hasattr(torch, "mps") and torch.mps.is_available(): | |
| torch.mps.empty_cache() | |
| # ============================================================================ | |
| # CONFIGS | |
| # ============================================================================ | |
@dataclass
class LoRAConfig:
    r: int = 64
    alpha: int = 128
    dropout: float = 0.1
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
    ])
    bias: str = "none"
    attention_type: str = "both"
    target_mlp: bool = True
| # ============================================================================ | |
| # TIMESTEP SAMPLING & CFG DROPOUT | |
| # ============================================================================ | |
| def sample_timesteps( | |
| batch_size: int, | |
| device: torch.device, | |
| dtype: torch.dtype, | |
| timestep_mu: float = -0.4, | |
| timestep_sigma: float = 1.0, | |
| ) -> Tuple[torch.Tensor, torch.Tensor]: | |
| t = torch.sigmoid( | |
| torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu | |
| ) | |
| r = torch.sigmoid( | |
| torch.randn((batch_size,), device=device, dtype=dtype) * timestep_sigma + timestep_mu | |
| ) | |
| t, r = torch.maximum(t, r), torch.minimum(t, r) | |
| # use_meanflow=False forces r=t (ACE-Step convention) | |
| return t, t | |
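# Example (sketch): with the default mu=-0.4 / sigma=1.0, the sigmoid of a
# shifted normal biases timesteps toward the low-to-middle part of (0, 1);
# both returned tensors are identical because meanflow is disabled here.
#
#     t, r = sample_timesteps(4, torch.device("cpu"), torch.float32)
#     assert torch.equal(t, r) and t.shape == (4,)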
| def apply_cfg_dropout( | |
| encoder_hidden_states: torch.Tensor, | |
| null_condition_emb: torch.Tensor, | |
| cfg_ratio: float = 0.15, | |
| ) -> torch.Tensor: | |
| bsz = encoder_hidden_states.shape[0] | |
| device = encoder_hidden_states.device | |
| dtype = encoder_hidden_states.dtype | |
| mask = torch.where( | |
| torch.rand(size=(bsz,), device=device, dtype=dtype) < cfg_ratio, | |
| torch.zeros(size=(bsz,), device=device, dtype=dtype), | |
| torch.ones(size=(bsz,), device=device, dtype=dtype), | |
| ).view(-1, 1, 1) | |
| return torch.where( | |
| mask > 0, | |
| encoder_hidden_states, | |
| null_condition_emb.expand_as(encoder_hidden_states), | |
| ) | |
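# Example (sketch): randomly replace ~15% of condition embeddings with a null
# embedding so the model also learns the unconditional branch needed for
# classifier-free guidance at inference time. `null_emb` below is a stand-in
# for whatever null/empty condition embedding the model provides.
#
#     null_emb = torch.zeros(1, enc_hs.shape[1], enc_hs.shape[2], dtype=enc_hs.dtype)
#     enc_hs = apply_cfg_dropout(enc_hs, null_emb, cfg_ratio=0.15)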
| # ============================================================================ | |
| # OPTIMIZER (Adafactor preferred for CPU -- 1.5 bytes/param) | |
| # ============================================================================ | |
| def build_optimizer( | |
| params, lr: float = 1e-4, weight_decay: float = 0.01, | |
| ) -> torch.optim.Optimizer: | |
| try: | |
| from transformers.optimization import Adafactor | |
| logger.info("Using Adafactor optimizer (minimal state memory)") | |
| return Adafactor( | |
| params, lr=lr, weight_decay=weight_decay, | |
| scale_parameter=False, relative_step=False, | |
| ) | |
| except ImportError: | |
| logger.warning("transformers not installed, falling back to AdamW") | |
| return AdamW(params, lr=lr, weight_decay=weight_decay) | |
| def build_scheduler( | |
| optimizer, total_steps: int, warmup_steps: int, lr: float, | |
| ): | |
| _max_warmup = max(1, total_steps // 10) | |
| if warmup_steps > _max_warmup: | |
| warmup_steps = _max_warmup | |
| warmup = LinearLR(optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps) | |
| remaining = max(1, total_steps - warmup_steps) | |
| main = CosineAnnealingLR(optimizer, T_max=remaining, eta_min=lr * 0.01) | |
| return SequentialLR(optimizer, [warmup, main], milestones=[warmup_steps]) | |
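# Example (sketch): wire the optimizer and warmup+cosine schedule over the
# LoRA parameters only (the base weights stay frozen).
#
#     trainable = [p for p in model.parameters() if p.requires_grad]
#     opt = build_optimizer(trainable, lr=1e-4)
#     sched = build_scheduler(opt, total_steps=1000, warmup_steps=100, lr=1e-4)
#     ...
#     opt.step(); sched.step(); opt.zero_grad()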
| # ============================================================================ | |
| # DATASET | |
| # ============================================================================ | |
| def _collate_batch(batch: List[Dict]) -> Dict[str, torch.Tensor]: | |
| max_t = max(s["target_latents"].shape[0] for s in batch) | |
| max_e = max(s["encoder_hidden_states"].shape[0] for s in batch) | |
| def pad(t, max_len, dim=0): | |
| diff = max_len - t.shape[dim] | |
| if diff <= 0: | |
| return t | |
| shape = list(t.shape) | |
| shape[dim] = diff | |
| return torch.cat([t, t.new_zeros(*shape)], dim=dim) | |
| return { | |
| "target_latents": torch.stack([pad(s["target_latents"], max_t) for s in batch]), | |
| "attention_mask": torch.stack([pad(s["attention_mask"], max_t) for s in batch]), | |
| "encoder_hidden_states": torch.stack([pad(s["encoder_hidden_states"], max_e) for s in batch]), | |
| "encoder_attention_mask": torch.stack([pad(s["encoder_attention_mask"], max_e) for s in batch]), | |
| "context_latents": torch.stack([pad(s["context_latents"], max_t) for s in batch]), | |
| } | |
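# Example (sketch): _collate_batch right-pads every tensor in the batch to the
# longest latent / encoder length so variable-length clips can be stacked.
# The tensor directory below is a placeholder for the preprocessing output.
#
#     ds = TensorDataset("/path/to/preprocessed_tensors")
#     dl = DataLoader(ds, batch_size=2, shuffle=True, collate_fn=_collate_batch)
#     batch = next(iter(dl))   # dict of [B, T, ...] tensors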
| class TensorDataset(Dataset): | |
| _REQUIRED = frozenset([ | |
| "target_latents", "attention_mask", "encoder_hidden_states", | |
| "encoder_attention_mask", "context_latents", | |
| ]) | |
| def __init__(self, tensor_dir: str): | |
| self.paths: List[str] = [] | |
| for f in sorted(os.listdir(tensor_dir)): | |
| if f.endswith(".pt") and not f.endswith(".tmp.pt") and f != "manifest.json": | |
| self.paths.append(str(Path(tensor_dir) / f)) | |
| def __len__(self) -> int: | |
| return len(self.paths) | |
| def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
| data = torch.load(self.paths[idx], map_location="cpu", weights_only=True) | |
| missing = self._REQUIRED - data.keys() | |
| if missing: | |
| raise KeyError(f"Missing keys {sorted(missing)} in {self.paths[idx]}") | |
| for k in ("target_latents", "encoder_hidden_states", "context_latents"): | |
| t = data[k] | |
| if torch.isnan(t).any() or torch.isinf(t).any(): | |
| t.nan_to_num_(nan=0.0, posinf=0.0, neginf=0.0) | |
| return {k: data[k] for k in self._REQUIRED} | |
| # ============================================================================ | |
| # GRADIENT CHECKPOINTING | |
| # ============================================================================ | |
| def _find_decoder_layers(decoder: nn.Module) -> Optional[nn.ModuleList]: | |
| for attr in ("layers", "blocks", "transformer_blocks"): | |
| c = getattr(decoder, attr, None) | |
| if isinstance(c, nn.ModuleList) and len(c) > 0: | |
| return c | |
| for child in decoder.children(): | |
| for attr in ("layers", "blocks", "transformer_blocks"): | |
| c = getattr(child, attr, None) | |
| if isinstance(c, nn.ModuleList) and len(c) > 0: | |
| return c | |
| return None | |
| def enable_gradient_checkpointing(decoder: nn.Module) -> bool: | |
| """Enable gradient checkpointing on the decoder to save memory.""" | |
| enabled = False | |
| # Walk wrapper chain | |
| stack = [decoder] | |
| visited = set() | |
| while stack: | |
| mod = stack.pop() | |
| if not isinstance(mod, nn.Module): | |
| continue | |
| mid = id(mod) | |
| if mid in visited: | |
| continue | |
| visited.add(mid) | |
| if hasattr(mod, "gradient_checkpointing_enable"): | |
| try: | |
| mod.gradient_checkpointing_enable() | |
| enabled = True | |
| except Exception: | |
| pass | |
| elif hasattr(mod, "gradient_checkpointing"): | |
| try: | |
| mod.gradient_checkpointing = True | |
| enabled = True | |
| except Exception: | |
| pass | |
| if hasattr(mod, "enable_input_require_grads"): | |
| try: | |
| mod.enable_input_require_grads() | |
| except Exception: | |
| pass | |
| cfg = getattr(mod, "config", None) | |
| if cfg is not None and hasattr(cfg, "use_cache"): | |
| try: | |
| cfg.use_cache = False | |
| except Exception: | |
| pass | |
| for a in ("_forward_module", "_orig_mod", "base_model", "model", "module"): | |
| child = getattr(mod, a, None) | |
| if isinstance(child, nn.Module): | |
| stack.append(child) | |
| return enabled | |
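# Example (sketch): call this after LoRA injection; it returns False when no
# wrapped module exposes a checkpointing hook, in which case training still
# works but uses more activation memory.
#
#     if not enable_gradient_checkpointing(model.decoder):
#         logger.warning("gradient checkpointing not available on this decoder")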
| # ============================================================================ | |
| # LORA INJECTION (PEFT only -- no DoRA/LoKR/LoHA/OFT) | |
| # ============================================================================ | |
| def _unwrap_decoder(model): | |
| decoder = model.decoder if hasattr(model, "decoder") else model | |
| while hasattr(decoder, "_forward_module"): | |
| decoder = decoder._forward_module | |
| if hasattr(decoder, "base_model"): | |
| bm = decoder.base_model | |
| decoder = bm.model if hasattr(bm, "model") else bm | |
| if hasattr(decoder, "model") and isinstance(decoder.model, nn.Module): | |
| decoder = decoder.model | |
| return decoder | |
| def inject_lora(model, lora_cfg: LoRAConfig) -> Tuple[Any, Dict[str, Any]]: | |
| from peft import get_peft_model, LoraConfig as PeftLoraConfig, TaskType | |
| decoder = _unwrap_decoder(model) | |
| model.decoder = decoder | |
| # Guard enable_input_require_grads for DiT (no get_input_embeddings) | |
| if hasattr(decoder, "enable_input_require_grads"): | |
| orig = decoder.enable_input_require_grads | |
| def _safe(self): | |
| try: | |
| return orig() | |
| except NotImplementedError: | |
| return None | |
| decoder.enable_input_require_grads = types.MethodType(_safe, decoder) | |
| if hasattr(decoder, "is_gradient_checkpointing"): | |
| try: | |
| decoder.is_gradient_checkpointing = False | |
| except Exception: | |
| pass | |
| peft_cfg = PeftLoraConfig( | |
| r=lora_cfg.r, | |
| lora_alpha=lora_cfg.alpha, | |
| lora_dropout=lora_cfg.dropout, | |
| target_modules=lora_cfg.target_modules, | |
| bias=lora_cfg.bias, | |
| task_type=TaskType.FEATURE_EXTRACTION, | |
| ) | |
| model.decoder = get_peft_model(decoder, peft_cfg) | |
| for name, param in model.named_parameters(): | |
| if "lora_" not in name: | |
| param.requires_grad = False | |
| total = sum(p.numel() for p in model.parameters()) | |
| trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| return model, { | |
| "total_params": total, | |
| "trainable_params": trainable, | |
| "trainable_ratio": trainable / total if total > 0 else 0, | |
| } | |
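# Example (sketch): typical wiring -- load the frozen base model, wrap its
# decoder with PEFT LoRA, and log how small the trainable footprint is.
# The checkpoint path is a placeholder.
#
#     model = load_model_for_training("/path/to/checkpoints", variant="turbo", device=device)
#     model, stats = inject_lora(model, LoRAConfig())
#     logger.info("trainable: %.3f%% of %d params",
#                 100 * stats["trainable_ratio"], stats["total_params"])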
| def save_lora_adapter(model, output_dir: str) -> None: | |
| os.makedirs(output_dir, exist_ok=True) | |
| # Use the PEFT-wrapped decoder (model.decoder), NOT the unwrapped base model. | |
| # _unwrap_decoder strips the PEFT wrapper, causing save_pretrained to save | |
| # the full model instead of just the LoRA adapter weights. | |
| decoder = model.decoder if hasattr(model, "decoder") else model | |
| if hasattr(decoder, "save_pretrained"): | |
| decoder.save_pretrained(output_dir) | |
| # Scrub base_model path for portability | |
| cfg_path = os.path.join(output_dir, "adapter_config.json") | |
| if os.path.isfile(cfg_path): | |
| try: | |
| with open(cfg_path, "r") as f: | |
| cfg = json.load(f) | |
| if cfg.get("base_model_name_or_path"): | |
| cfg["base_model_name_or_path"] = "" | |
| with open(cfg_path, "w") as f: | |
| json.dump(cfg, f, indent=2) | |
| except Exception: | |
| pass | |
| logger.info("LoRA adapter saved to %s", output_dir) | |
| else: | |
| # Fallback: manual extraction | |
| state = {} | |
| for name, param in decoder.named_parameters(): | |
| if "lora_" in name: | |
| state[name] = param.data.clone() | |
| if state: | |
| try: | |
| from safetensors.torch import save_file | |
| save_file(state, str(Path(output_dir) / "adapter_model.safetensors")) | |
| except ImportError: | |
| torch.save(state, str(Path(output_dir) / "lora_weights.pt")) | |
| logger.info("LoRA adapter saved (fallback) to %s", output_dir) | |
| # ============================================================================ | |
| # MODEL LOADING (FA2 -> SDPA -> eager fallback) | |
| # ============================================================================ | |
| _VARIANT_DIR = { | |
| "turbo": "acestep-v15-turbo", | |
| "xl-turbo": "acestep-v15-xl-turbo", | |
| "base": "acestep-v15-base", | |
| "xl-base": "acestep-v15-xl-base", | |
| "sft": "acestep-v15-sft", | |
| "xl-sft": "acestep-v15-xl-sft", | |
| } | |
def _resolve_model_dir(checkpoint_dir: str, variant: str) -> Path:
    base = Path(checkpoint_dir).resolve()
    subdir = _VARIANT_DIR.get(variant)
    if subdir:
        p = (base / subdir).resolve()
        if p.is_dir():
            return p
    p = (base / variant).resolve()
    if p.is_dir():
        return p
    raise FileNotFoundError(
        f"Model directory not found: tried {_VARIANT_DIR.get(variant, variant)!r} "
        f"and {variant!r} under {checkpoint_dir}"
    )
| def _ensure_acestep_imports(): | |
| """Register stub modules so AutoModel can load ACE-Step checkpoints.""" | |
| for name in ( | |
| "acestep", "acestep.models", "acestep.models.common", | |
| "acestep.models.xl_base", "acestep.models.xl_turbo", "acestep.models.xl_sft", | |
| ): | |
| if name not in sys.modules: | |
| stub = types.ModuleType(name) | |
| stub.__path__ = [] | |
| sys.modules[name] = stub | |
    # Stub the config / guidance submodules as empty modules as well
| for name in ( | |
| "acestep.models.common.configuration_acestep_v15", | |
| "acestep.models.common.apg_guidance", | |
| ): | |
| if name not in sys.modules: | |
| sys.modules[name] = types.ModuleType(name) | |
def _attn_candidates(device: str) -> List[str]:
    """Build the ordered list of attention backends to try.

    flash_attention_2 is tried first on Ampere/Hopper (SM 8.x-9.x) when the
    flash_attn package is installed; SDPA is preferred otherwise (faster on
    Blackwell SM 12.x, native cuDNN). "eager" is always the final fallback.
    On CPU, only SDPA and eager are tried.
    """
| candidates = [] | |
| if device.startswith("cuda"): | |
| candidates.append("sdpa") | |
| try: | |
| import flash_attn # noqa: F401 | |
| dev_idx = int(device.split(":")[1]) if ":" in device else 0 | |
| props = torch.cuda.get_device_properties(dev_idx) | |
| if props.major >= 8 and props.major < 12: | |
| # FA2 is faster on Ampere/Hopper (SM 8.x-9.x), slower on Blackwell (SM 12.x) | |
| candidates.insert(0, "flash_attention_2") | |
| logger.info("FA2 prioritized (compute %d.%d, Ampere/Hopper)", props.major, props.minor) | |
| else: | |
| logger.info("FA2 available but SDPA preferred (compute %d.%d)", props.major, props.minor) | |
| except ImportError: | |
| logger.info("flash_attention_2 skipped: flash_attn package not installed") | |
| except Exception as exc: | |
| logger.info("flash_attention_2 skipped: %s", exc) | |
| else: | |
| candidates.append("sdpa") | |
| logger.info("flash_attention_2 skipped: device is %s (not CUDA)", device) | |
| if "eager" not in candidates: | |
| candidates.append("eager") | |
| return list(dict.fromkeys(candidates)) | |
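# Example: on an Ampere GPU with flash_attn installed the candidate list is
# ["flash_attention_2", "sdpa", "eager"]; on a Blackwell GPU or without
# flash_attn it is ["sdpa", "eager"]; on CPU/MPS it is ["sdpa", "eager"].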
| def load_model_for_training( | |
| checkpoint_dir: str, variant: str = "base", device: str = "cpu", | |
| ) -> Any: | |
| from transformers import AutoModel | |
| model_dir = _resolve_model_dir(checkpoint_dir, variant) | |
| dtype = select_dtype(device) | |
| logger.info( | |
| "Loading model from %s (variant=%s, device=%s, dtype=%s)", | |
| model_dir, variant, device, dtype, | |
| ) | |
| _ensure_acestep_imports() | |
| candidates = _attn_candidates(device) | |
| model = None | |
| last_err = None | |
| for idx, attn in enumerate(candidates): | |
| try: | |
| load_kwargs = dict( | |
| # SECURITY: trust_remote_code=True is required because the | |
| # ACE-Step model config references custom Python code in its | |
| # checkpoint (config.json -> auto_map). Only load checkpoints | |
| # from trusted sources (the official ACE-Step HF repo). | |
| trust_remote_code=True, | |
| attn_implementation=attn, | |
| torch_dtype=dtype, | |
| low_cpu_mem_usage=False, | |
| ) | |
| if device != "cpu": | |
| load_kwargs["device_map"] = {"": device} | |
| model = AutoModel.from_pretrained(str(model_dir), **load_kwargs) | |
| logger.info("Model loaded with attn_implementation=%s on %s", attn, device) | |
| break | |
| except Exception as exc: | |
| err_text = str(exc) | |
| if "packages that were not found" in err_text or "No module named" in err_text: | |
| raise RuntimeError( | |
| f"Model files in {model_dir} require a missing Python package.\n" | |
| f" Original error: {err_text}" | |
| ) from exc | |
| last_err = exc | |
| next_attn = candidates[idx + 1] if idx + 1 < len(candidates) else None | |
| if next_attn: | |
| logger.warning("attn backend '%s' failed: %s; trying '%s'", attn, exc, next_attn) | |
| else: | |
| logger.warning("attn backend '%s' failed: %s", attn, exc) | |
| if model is None: | |
| raise RuntimeError(f"Failed to load model from {model_dir}: {last_err}") from last_err | |
    if device != "cpu":
        # device_map already placed the weights; just verify the dtype
        if any(p.dtype != dtype for p in model.parameters()):
            model = model.to(dtype=dtype)
    else:
        # No device_map on CPU -- move and cast explicitly
        model = model.to(device=device, dtype=dtype)
| for param in model.parameters(): | |
| param.requires_grad = False | |
| model.eval() | |
| return model | |
| def load_vae(checkpoint_dir: str, device: str = "cpu"): | |
| from diffusers.models import AutoencoderOobleck | |
| vae_path = Path(checkpoint_dir) / "vae" | |
| if not vae_path.is_dir(): | |
| raise FileNotFoundError(f"VAE directory not found: {vae_path}") | |
| dtype = select_dtype(device) | |
| vae = AutoencoderOobleck.from_pretrained(str(vae_path), torch_dtype=dtype) | |
| vae = vae.to(device=device) | |
| vae.eval() | |
| logger.info("VAE loaded on %s (dtype=%s)", device, dtype) | |
| return vae | |
| def load_text_encoder(checkpoint_dir: str, device: str = "cpu"): | |
| from transformers import AutoModel, AutoTokenizer | |
| text_path = Path(checkpoint_dir) / "Qwen3-Embedding-0.6B" | |
| if not text_path.is_dir(): | |
| raise FileNotFoundError(f"Text encoder not found: {text_path}") | |
| dtype = select_dtype(device) | |
| tokenizer = AutoTokenizer.from_pretrained(str(text_path)) | |
| encoder = AutoModel.from_pretrained(str(text_path), torch_dtype=dtype) | |
| encoder = encoder.to(device=device) | |
| encoder.eval() | |
| logger.info("Text encoder loaded on %s (dtype=%s)", device, dtype) | |
| return tokenizer, encoder | |
| def load_silence_latent( | |
| checkpoint_dir: str, device: str = "cpu", variant: str = "base", | |
| ) -> torch.Tensor: | |
| ckpt = Path(checkpoint_dir) | |
| dtype = select_dtype(device) | |
| candidates = [ckpt / "silence_latent.pt"] | |
| subdir = _VARIANT_DIR.get(variant) | |
| if subdir: | |
| candidates.append(ckpt / subdir / "silence_latent.pt") | |
| for sd in _VARIANT_DIR.values(): | |
| candidates.append(ckpt / sd / "silence_latent.pt") | |
| for c in candidates: | |
| if c.is_file(): | |
| sl = torch.load(str(c), weights_only=True).transpose(1, 2) | |
| return sl.to(device=device, dtype=dtype) | |
| raise FileNotFoundError(f"silence_latent.pt not found under {ckpt}") | |
| def unload_models(*models) -> None: | |
| for obj in models: | |
| if obj is None: | |
| continue | |
| if hasattr(obj, "to"): | |
| try: | |
| obj.to("cpu") | |
| except Exception: | |
| pass | |
| del obj | |
| gc.collect() | |
| # Free GPU memory after unloading | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| if hasattr(torch, "mps") and hasattr(torch.mps, "is_available") and torch.mps.is_available(): | |
| try: | |
| torch.mps.empty_cache() | |
| except Exception: | |
| pass | |
| # ============================================================================ | |
| # AUDIO LOADING | |
| # ============================================================================ | |
| def load_audio_stereo( | |
| audio_path: str, target_sr: int, max_duration: float, | |
| ) -> Tuple[torch.Tensor, int]: | |
| import numpy as np | |
| try: | |
| import soundfile as sf | |
| data, sr = sf.read(audio_path, dtype="float32", always_2d=True) | |
| audio_np = np.ascontiguousarray(data.T) | |
| sr = int(sr) | |
| if sr != target_sr: | |
| import librosa | |
| audio_np = librosa.resample(audio_np, orig_sr=sr, target_sr=target_sr, axis=1) | |
| sr = target_sr | |
| audio = torch.from_numpy(np.ascontiguousarray(audio_np)) | |
| except Exception: | |
| import torchaudio | |
| audio, sr = torchaudio.load(audio_path) | |
| sr = int(sr) | |
| if sr != target_sr: | |
| audio = torchaudio.transforms.Resample(sr, target_sr)(audio) | |
| sr = target_sr | |
| if audio.shape[0] == 1: | |
| audio = audio.repeat(2, 1) | |
| elif audio.shape[0] > 2: | |
| audio = audio[:2, :] | |
| max_samples = int(max_duration * target_sr) | |
| if audio.shape[1] > max_samples: | |
| audio = audio[:, :max_samples] | |
| return audio, sr | |
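# Example (sketch): load a clip as 48 kHz stereo, truncated at MAX_AUDIO_DURATION.
# The path is a placeholder.
#
#     audio, sr = load_audio_stereo("/data/song.flac", TARGET_SR, MAX_AUDIO_DURATION)
#     assert audio.shape[0] == 2 and sr == TARGET_SR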
| # ============================================================================ | |
| # TEXT / LYRICS ENCODING | |
| # ============================================================================ | |
| def encode_text(text_encoder, tokenizer, text_prompt: str, device, dtype): | |
| inputs = tokenizer( | |
| text_prompt, padding="max_length", max_length=256, | |
| truncation=True, return_tensors="pt", | |
| ) | |
| ids = inputs.input_ids.to(device) | |
| mask = inputs.attention_mask.to(device).to(dtype) | |
| enc_dev = next(text_encoder.parameters()).device | |
| if ids.device != enc_dev: | |
| ids = ids.to(enc_dev) | |
| mask = mask.to(enc_dev) | |
| with torch.no_grad(): | |
| hs = text_encoder(ids).last_hidden_state.to(dtype) | |
| return hs, mask | |
| def encode_lyrics(text_encoder, tokenizer, lyrics: str, device, dtype): | |
| inputs = tokenizer( | |
| lyrics, padding="max_length", max_length=512, | |
| truncation=True, return_tensors="pt", | |
| ) | |
| ids = inputs.input_ids.to(device) | |
| mask = inputs.attention_mask.to(device).to(dtype) | |
| enc_dev = next(text_encoder.parameters()).device | |
| if ids.device != enc_dev: | |
| ids = ids.to(enc_dev) | |
| mask = mask.to(enc_dev) | |
| with torch.no_grad(): | |
| hs = text_encoder.embed_tokens(ids).to(dtype) | |
| return hs, mask | |
| # ============================================================================ | |
| # AUDIO CHUNKING (split long audio into ~30s training samples) | |
| # ============================================================================ | |
| CHUNK_MIN_SAMPLES = 20 * TARGET_SR # 20s | |
| CHUNK_MAX_SAMPLES = 40 * TARGET_SR # 40s | |
| def _chunk_audio(audio: torch.Tensor) -> List[torch.Tensor]: | |
| """Split a [C, S] audio tensor into ~30s chunks for faster training. | |
| Uses RMS energy to find the quietest point within the 20-40s window | |
| around each cut, avoiding cuts through loud notes. | |
| Short files (<=40s) are returned as-is. | |
| """ | |
| S = audio.shape[-1] | |
| if S <= CHUNK_MAX_SAMPLES: | |
| return [audio] | |
| mono = audio.mean(dim=0) # [S] | |
| hop = TARGET_SR // 10 # 0.1s resolution | |
| frame_count = S // hop | |
| rms = torch.zeros(frame_count) | |
| for fi in range(frame_count): | |
| seg = mono[fi * hop:(fi + 1) * hop] | |
| rms[fi] = seg.pow(2).mean().sqrt() | |
| min_frames = 20 * 10 # 20s in 0.1s frames | |
| max_frames = 40 * 10 # 40s | |
| chunks = [] | |
| pos = 0 | |
| while pos < frame_count: | |
| remaining = frame_count - pos | |
| if remaining <= max_frames: | |
| chunks.append(audio[:, pos * hop:]) | |
| break | |
| search_start = pos + min_frames | |
| search_end = min(pos + max_frames, frame_count) | |
| window = rms[search_start:search_end] | |
| cut = search_start + window.argmin().item() | |
| # If cutting here leaves a tail shorter than 20s, take it all | |
| tail = frame_count - cut | |
| if tail < min_frames: | |
| chunks.append(audio[:, pos * hop:]) | |
| break | |
| chunks.append(audio[:, pos * hop:cut * hop]) | |
| pos = cut | |
| return chunks | |
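# Example (sketch): a ~95 s clip usually comes back as three chunks of roughly
# 30 s / 30 s / 35 s, with each cut placed at the quietest 0.1 s frame inside
# the 20-40 s search window; a would-be tail shorter than 20 s is merged into
# the previous chunk instead.
#
#     chunks = _chunk_audio(audio)   # audio: [2, S] at 48 kHz
#     assert all(c.shape[-1] >= CHUNK_MIN_SAMPLES for c in chunks[:-1])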
| # ============================================================================ | |
| # VAE TILED ENCODING | |
| # ============================================================================ | |
| def tiled_vae_encode( | |
| vae, audio: torch.Tensor, dtype: torch.dtype, | |
| chunk_size: Optional[int] = None, overlap: int = 96000, | |
| ) -> torch.Tensor: | |
| vae_device = next(vae.parameters()).device | |
| vae_dtype = vae.dtype | |
| if chunk_size is None: | |
| chunk_size = TARGET_SR * 30 | |
| B, C, S = audio.shape | |
| if S <= chunk_size: | |
| vae_input = audio.to(vae_device, dtype=vae_dtype) | |
| with torch.inference_mode(): | |
| latents = vae.encode(vae_input).latent_dist.sample() | |
| return latents.transpose(1, 2).to(dtype) | |
| stride = chunk_size - 2 * overlap | |
| if stride <= 0: | |
| raise ValueError(f"chunk_size ({chunk_size}) must be > 2 * overlap ({overlap})") | |
| num_steps = math.ceil(S / stride) | |
| ds_factor = None | |
| write_pos = 0 | |
| final = None | |
| for i in range(num_steps): | |
| core_start = i * stride | |
| core_end = min(core_start + stride, S) | |
| win_start = max(0, core_start - overlap) | |
| win_end = min(S, core_end + overlap) | |
| chunk = audio[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype) | |
| with torch.inference_mode(): | |
| lat = vae.encode(chunk).latent_dist.sample() | |
| if ds_factor is None: | |
| ds_factor = chunk.shape[-1] / lat.shape[-1] | |
| total_len = int(round(S / ds_factor)) | |
| final = torch.zeros(B, lat.shape[1], total_len, dtype=lat.dtype, device="cpu") | |
| trim_start = int(round((core_start - win_start) / ds_factor)) | |
| trim_end = int(round((win_end - core_end) / ds_factor)) | |
| end_idx = lat.shape[-1] - trim_end if trim_end > 0 else lat.shape[-1] | |
| core = lat[:, :, trim_start:end_idx] | |
| core_len = core.shape[-1] | |
| final[:, :, write_pos:write_pos + core_len] = core.cpu() | |
| write_pos += core_len | |
| del chunk, lat, core | |
| final = final[:, :, :write_pos] | |
| return final.transpose(1, 2).to(dtype) | |
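# Example (sketch): long clips are encoded in overlapping ~30 s windows; the
# overlap regions are trimmed off each window's latents before the cores are
# concatenated, so the stitched result covers the full clip without
# duplicated frames.
#
#     latents = tiled_vae_encode(vae, audio.unsqueeze(0), dtype)   # -> [1, T_latent, C]
#     # T_latent is roughly audio_samples / 1920 (~25 latent frames per second)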
| # ============================================================================ | |
| # ENCODER / CONTEXT HELPERS | |
| # ============================================================================ | |
| def run_encoder( | |
| model, text_hs, text_mask, lyric_hs, lyric_mask, device, dtype, | |
| ): | |
| refer = torch.zeros(1, 1, 64, device=device, dtype=dtype) | |
| order_mask = torch.zeros(1, device=device, dtype=torch.long) | |
| with torch.no_grad(): | |
| enc_hs, enc_mask = model.encoder( | |
| text_hidden_states=text_hs, | |
| text_attention_mask=text_mask, | |
| lyric_hidden_states=lyric_hs, | |
| lyric_attention_mask=lyric_mask, | |
| refer_audio_acoustic_hidden_states_packed=refer, | |
| refer_audio_order_mask=order_mask, | |
| ) | |
| return enc_hs, enc_mask | |
| def build_context_latents(silence_latent, latent_length: int, device, dtype): | |
| src = silence_latent[:, :latent_length, :].to(dtype) | |
| if src.shape[0] < 1: | |
| src = src.expand(1, -1, -1) | |
| if src.shape[1] < latent_length: | |
| pad_len = latent_length - src.shape[1] | |
| src = torch.cat([src, silence_latent[:, :pad_len, :].expand(1, -1, -1).to(dtype)], dim=1) | |
| elif src.shape[1] > latent_length: | |
| src = src[:, :latent_length, :] | |
| masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype) | |
| return torch.cat([src, masks], dim=-1) | |
| # ============================================================================ | |
| # AUDIO DISCOVERY | |
| # ============================================================================ | |
| def _discover_audio_files(audio_dir: str) -> List[Path]: | |
| files = [] | |
| for root, _, names in os.walk(audio_dir): | |
| for name in sorted(names): | |
| if Path(name).suffix.lower() in AUDIO_EXTENSIONS: | |
| files.append(Path(root) / name) | |
| return files | |
| def _detect_max_duration(files: List[Path]) -> float: | |
| """Return the longest audio file duration (capped at MAX_AUDIO_DURATION).""" | |
| max_dur = 0.0 | |
| try: | |
| import soundfile as sf | |
| for f in files[:50]: | |
| try: | |
| info = sf.info(str(f)) | |
| max_dur = max(max_dur, info.duration) | |
| except Exception: | |
| pass | |
| except ImportError: | |
| pass | |
| return min(max_dur if max_dur > 0 else MAX_AUDIO_DURATION, MAX_AUDIO_DURATION) | |
| # ============================================================================ | |
| # AUDIO ANALYSIS (ported from Side-Step -- faf / mid / sas modes) | |
| # ============================================================================ | |
| # | |
| # faf ("Fast As F*ck") ~2-3 s/file - single-method, no Demucs | |
| # mid ~12 s/file - 3-method ensemble, Demucs stems | |
| # sas ("Smart/Slow As Sh*t") ~30 s/file - deep multi-technique + chunked | |
| # | |
| # Demucs on CPU is SLOW (~2-5 min/file). mid/sas are designed for GPU | |
| # stem separation but will still work on CPU -- just much slower. | |
| # ============================================================================ | |
| _ANALYSIS_MODES = ("faf", "mid", "sas") | |
| _SAS_NUM_CHUNKS = 5 | |
| _SAS_CHUNK_SECONDS = 15 # seconds per analysis window | |
| # Key profile families for multi-profile voting (mid / sas) | |
| _KEY_PROFILES = { | |
| "krumhansl": { | |
| "major": [6.35, 2.23, 3.48, 2.33, 4.38, 4.09, | |
| 2.52, 5.19, 2.39, 3.66, 2.29, 2.88], | |
| "minor": [6.33, 2.68, 3.52, 5.38, 2.60, 3.53, | |
| 2.54, 4.75, 3.98, 2.69, 3.34, 3.17], | |
| }, | |
| "temperley": { | |
| "major": [5.0, 2.0, 3.5, 2.0, 4.5, 4.0, | |
| 2.0, 4.5, 2.0, 3.5, 1.5, 4.0], | |
| "minor": [5.0, 2.0, 3.5, 4.5, 2.0, 3.5, | |
| 2.0, 4.5, 3.5, 2.0, 1.5, 4.0], | |
| }, | |
| "albrecht": { | |
| "major": [0.238, 0.006, 0.111, 0.006, 0.137, 0.094, | |
| 0.016, 0.214, 0.009, 0.080, 0.008, 0.081], | |
| "minor": [0.220, 0.006, 0.104, 0.123, 0.019, 0.103, | |
| 0.012, 0.214, 0.062, 0.022, 0.061, 0.052], | |
| }, | |
| } | |
| _PITCH_CLASSES = ["C", "C#", "D", "D#", "E", "F", | |
| "F#", "G", "G#", "A", "A#", "B"] | |
| # Filename pattern: "Artist - Title" | |
| _FILENAME_RE = re.compile(r"^(.+?)\s*[-–—]\s*(.+)$") | |
| # ---- Demucs stem separation (mid / sas) -------------------------------- | |
| def separate_stems( | |
| audio_path: Path, | |
| tmp_dir: Path, | |
| device: str = "cpu", | |
| ) -> Tuple[Path, Path]: | |
| """Run Demucs HTDemucs and return (drums_path, harmonics_path). | |
| Harmonics = bass + other stems summed. Vocals are discarded. | |
| WARNING: On CPU this takes ~2-5 minutes per file. | |
| """ | |
| import torchaudio | |
| from demucs.pretrained import get_model | |
| from demucs.apply import apply_model | |
| torch_device = torch.device(device) | |
| logger.info("Loading Demucs HTDemucs model on %s", device) | |
| if device == "cpu": | |
| logger.warning( | |
| "Demucs on CPU is slow (~2-5 min per file). " | |
| "Consider using 'faf' mode or running on a GPU machine." | |
| ) | |
| model = get_model("htdemucs") | |
| model.to(torch_device) | |
| model.eval() | |
| wav, sr = torchaudio.load(str(audio_path)) | |
| # Resample to model's expected rate (44100 Hz) if needed | |
| if sr != model.samplerate: | |
| wav = torchaudio.functional.resample(wav, sr, model.samplerate) | |
| sr = model.samplerate | |
| # HTDemucs requires stereo input | |
| if wav.shape[0] == 1: | |
| wav = wav.repeat(2, 1) | |
| wav = wav.unsqueeze(0).to(torch_device) | |
| logger.info("Separating stems for %s", audio_path.name) | |
| with torch.no_grad(): | |
| sources = apply_model(model, wav, device=torch_device) | |
| source_map = {name: i for i, name in enumerate(model.sources)} | |
| drums = sources[0, source_map["drums"]].cpu() | |
| bass = sources[0, source_map["bass"]].cpu() | |
| other = sources[0, source_map["other"]].cpu() | |
| harmonics = bass + other | |
| drums_path = tmp_dir / "drums.wav" | |
| harmonics_path = tmp_dir / "harmonics.wav" | |
| torchaudio.save(str(drums_path), drums, sr) | |
| torchaudio.save(str(harmonics_path), harmonics, sr) | |
| del model, sources, wav, drums, bass, other, harmonics | |
| gc.collect() | |
| logger.info("Stems written: %s, %s", drums_path, harmonics_path) | |
| return drums_path, harmonics_path | |
| # ---- Chunk selection (sas mode) ---------------------------------------- | |
| def _select_chunks( | |
| y, # np.ndarray | |
| sr: int, | |
| n_chunks: int = _SAS_NUM_CHUNKS, | |
| chunk_sec: float = _SAS_CHUNK_SECONDS, | |
| min_gap_sec: float = 10.0, | |
| use_onset: bool = True, | |
| ) -> list: | |
| """Select the most informative audio chunks for sas analysis. | |
| Energy-gated + spread: rank windows by onset density (or RMS), | |
| discard below-median, then greedily pick chunks maximally spread apart. | |
| """ | |
| import librosa | |
| import numpy as np | |
| chunk_samples = int(chunk_sec * sr) | |
| hop_samples = chunk_samples // 2 | |
| if len(y) < chunk_samples: | |
| return [y] | |
| candidates = [] | |
| for start in range(0, len(y) - chunk_samples + 1, hop_samples): | |
| window = y[start : start + chunk_samples] | |
| if use_onset: | |
| onset_env = librosa.onset.onset_strength(y=window, sr=sr) | |
| score = float(np.mean(onset_env)) | |
| else: | |
| score = float(np.sqrt(np.mean(window ** 2))) | |
| candidates.append((start, score)) | |
| if not candidates: | |
| return [y] | |
| scores = np.array([s for _, s in candidates]) | |
| median_score = float(np.median(scores)) | |
| gated = [(start, score) for start, score in candidates if score >= median_score] | |
| if not gated: | |
| gated = candidates | |
| gated.sort(key=lambda x: x[1], reverse=True) | |
| min_gap_samples = int(min_gap_sec * sr) | |
| selected_starts = [] | |
| for start, score in gated: | |
| centre = start + chunk_samples // 2 | |
| too_close = any( | |
| abs(centre - (s + chunk_samples // 2)) < min_gap_samples | |
| for s in selected_starts | |
| ) | |
| if not too_close: | |
| selected_starts.append(start) | |
| if len(selected_starts) >= n_chunks: | |
| break | |
| if len(selected_starts) < n_chunks: | |
| for start, score in gated: | |
| if start not in selected_starts: | |
| selected_starts.append(start) | |
| if len(selected_starts) >= n_chunks: | |
| break | |
| selected_starts.sort() | |
| return [y[s : s + chunk_samples] for s in selected_starts] | |
| # ---- BPM helpers -------------------------------------------------------- | |
| def _octave_correct_bpm(bpm: float, lo: float = 70.0, hi: float = 180.0) -> float: | |
| """Fold BPM into the musical sweet-spot range [lo, hi].""" | |
| if bpm <= 0: | |
| return bpm | |
| candidate = bpm | |
| while candidate > hi: | |
| candidate /= 2.0 | |
| while candidate < lo: | |
| candidate *= 2.0 | |
| if candidate < lo or candidate > hi: | |
| return bpm | |
| return candidate | |
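# Worked example of the octave folding above:
#
#     _octave_correct_bpm(186.0)   # -> 93.0  (halved into range)
#     _octave_correct_bpm(55.0)    # -> 110.0 (doubled into range)
#     _octave_correct_bpm(120.0)   # -> 120.0 (already in [70, 180])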
| def _bpm_core_ensemble(y, sr) -> list: | |
| """Run the 3-method BPM ensemble on a single audio buffer (mid/sas). | |
| Returns a list of octave-corrected BPM estimates. | |
| """ | |
| import librosa | |
| import numpy as np | |
| estimates = [] | |
| # Method A: beat_track | |
| try: | |
| tempo_a, _ = librosa.beat.beat_track(y=y, sr=sr) | |
| val_a = float(np.atleast_1d(tempo_a)[0]) | |
| if val_a > 0: | |
| estimates.append(_octave_correct_bpm(val_a)) | |
| except Exception: | |
| pass | |
| # Method B: tempogram peak | |
| try: | |
| onset_env = librosa.onset.onset_strength(y=y, sr=sr) | |
| tempogram = librosa.feature.tempogram(onset_envelope=onset_env, sr=sr) | |
| avg_tempogram = np.mean(tempogram, axis=1) | |
| bpm_axis = librosa.tempo_frequencies(tempogram.shape[0], sr=sr) | |
| valid = (bpm_axis >= 30) & (bpm_axis <= 300) | |
| if np.any(valid): | |
| masked = avg_tempogram.copy() | |
| masked[~valid] = 0 | |
| peak_idx = np.argmax(masked) | |
| val_b = float(bpm_axis[peak_idx]) | |
| if val_b > 0: | |
| estimates.append(_octave_correct_bpm(val_b)) | |
| except Exception: | |
| pass | |
| # Method C: onset autocorrelation | |
| try: | |
| onset_env = librosa.onset.onset_strength(y=y, sr=sr) | |
| ac = librosa.autocorrelate(onset_env, max_size=len(onset_env)) | |
| hop = 512 | |
| min_lag = int(60.0 * sr / (300.0 * hop)) | |
| max_lag = int(60.0 * sr / (30.0 * hop)) | |
| max_lag = min(max_lag, len(ac) - 1) | |
| if min_lag < max_lag and max_lag > 0: | |
| segment = ac[min_lag:max_lag + 1] | |
| peak_offset = np.argmax(segment) | |
| peak_lag = min_lag + peak_offset | |
| if peak_lag > 0: | |
| val_c = 60.0 * sr / (peak_lag * hop) | |
| if val_c > 0: | |
| estimates.append(_octave_correct_bpm(val_c)) | |
| except Exception: | |
| pass | |
| return estimates | |
| def _bpm_consensus(estimates: list) -> Tuple[Optional[int], str]: | |
| """Find consensus BPM from a list of estimates + assign confidence.""" | |
| import numpy as np | |
| if not estimates: | |
| return None, "low" | |
| estimates_arr = np.array(estimates) | |
| best_cluster = [] | |
| for ref in estimates_arr: | |
| cluster = [e for e in estimates_arr | |
| if abs(e - ref) / max(ref, 1) < 0.08] | |
| if len(cluster) > len(best_cluster): | |
| best_cluster = cluster | |
| consensus = float(np.median(best_cluster)) if best_cluster else estimates[0] | |
| bpm = int(round(consensus)) | |
| if bpm <= 0: | |
| return None, "low" | |
| n_agree = len(best_cluster) | |
| n_total = len(estimates) | |
| if n_total >= 6: | |
| # sas thresholds (many data points) | |
| if n_agree / n_total >= 0.7: | |
| confidence = "high" | |
| elif n_agree / n_total >= 0.4: | |
| confidence = "medium" | |
| else: | |
| confidence = "low" | |
| else: | |
| # mid thresholds | |
| if n_agree >= 3: | |
| confidence = "high" | |
| elif n_agree >= 2: | |
| confidence = "medium" | |
| else: | |
| confidence = "low" | |
| return bpm, confidence | |
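# Worked example: for estimates [120.1, 119.8, 60.2, 121.0] the largest
# cluster within 8% relative distance is {119.8, 120.1, 121.0}, the consensus
# BPM is their median rounded to 120, and 3 of 4 estimates agreeing meets the
# mid-mode threshold for "high" confidence.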
| # ---- Unified BPM detection --------------------------------------------- | |
| def _detect_bpm(y, sr, mode: str = "faf") -> Tuple[Optional[int], str]: | |
| """Detect BPM with quality controlled by mode. | |
| faf: Single beat_track + octave correction. | |
| mid: 3-method ensemble (beat_track + tempogram + onset-AC). | |
| sas: mid ensemble + PLP + multi-hop + chunked analysis. | |
| Returns (bpm, confidence). | |
| """ | |
| import librosa | |
| import numpy as np | |
| try: | |
| # faf: single method | |
| if mode == "faf": | |
| try: | |
| tempo, _ = librosa.beat.beat_track(y=y, sr=sr) | |
| val = float(np.atleast_1d(tempo)[0]) | |
| if val > 0: | |
| bpm = int(round(_octave_correct_bpm(val))) | |
| logger.info("BPM faf: %d (raw: %.1f)", bpm, val) | |
| return bpm, "low" | |
| except Exception: | |
| pass | |
| return None, "low" | |
| # mid: 3-method ensemble | |
| estimates = _bpm_core_ensemble(y, sr) | |
| # sas: additional techniques | |
| ibi_cv = 0.5 | |
| if mode == "sas": | |
| # PLP (Predominant Local Pulse) | |
| try: | |
| onset_env = librosa.onset.onset_strength(y=y, sr=sr) | |
| pulse = librosa.beat.plp(onset_envelope=onset_env, sr=sr) | |
| plp_ac = librosa.autocorrelate(pulse, max_size=len(pulse)) | |
| hop = 512 | |
| min_lag = int(60.0 * sr / (300.0 * hop)) | |
| max_lag = int(60.0 * sr / (30.0 * hop)) | |
| max_lag = min(max_lag, len(plp_ac) - 1) | |
| if min_lag < max_lag and max_lag > 0: | |
| seg = plp_ac[min_lag:max_lag + 1] | |
| peak_lag = min_lag + np.argmax(seg) | |
| if peak_lag > 0: | |
| plp_bpm = 60.0 * sr / (peak_lag * hop) | |
| if plp_bpm > 0: | |
| estimates.append(_octave_correct_bpm(plp_bpm)) | |
| except Exception: | |
| pass | |
| # Multi-hop beat_track (256, 1024) | |
| for extra_hop in (256, 1024): | |
| try: | |
| tempo_h, _ = librosa.beat.beat_track(y=y, sr=sr, hop_length=extra_hop) | |
| val_h = float(np.atleast_1d(tempo_h)[0]) | |
| if val_h > 0: | |
| estimates.append(_octave_correct_bpm(val_h)) | |
| except Exception: | |
| pass | |
| # Chunked ensemble | |
| chunks = _select_chunks(y, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=True) | |
| for chunk in chunks: | |
| chunk_estimates = _bpm_core_ensemble(chunk, sr) | |
| estimates.extend(chunk_estimates) | |
| # IBI stability | |
| try: | |
| _, beat_frames = librosa.beat.beat_track(y=y, sr=sr) | |
| if beat_frames is not None and len(beat_frames) > 4: | |
| beat_times = librosa.frames_to_time(beat_frames, sr=sr) | |
| ibis = np.diff(beat_times) | |
| ibi_cv = float(np.std(ibis) / (np.mean(ibis) + 1e-10)) | |
| else: | |
| ibi_cv = 0.5 | |
| except Exception: | |
| ibi_cv = 0.5 | |
| bpm, confidence = _bpm_consensus(estimates) | |
| # sas: IBI stability can upgrade medium->high or downgrade | |
| if mode == "sas" and bpm is not None: | |
| if ibi_cv < 0.10 and confidence == "medium": | |
| confidence = "high" | |
| elif ibi_cv > 0.30 and confidence == "high": | |
| confidence = "medium" | |
| logger.info( | |
| "BPM [%s]: %s (estimates=%s, conf=%s)", | |
| mode, bpm, | |
| [round(e, 1) for e in estimates[:10]], | |
| confidence, | |
| ) | |
| return bpm, confidence | |
| except Exception as exc: | |
| logger.warning("BPM detection failed: %s", exc) | |
| return None, "low" | |
| # ---- Key detection helpers ---------------------------------------------- | |
| def _best_key_for_profile(chroma_avg, major_profile, minor_profile): | |
| """Find the best key match for a single profile family. | |
| Returns (key_label, correlation). | |
| """ | |
| import numpy as np | |
| major_norm = np.array(major_profile, dtype=float) | |
| major_norm = major_norm / major_norm.sum() | |
| minor_norm = np.array(minor_profile, dtype=float) | |
| minor_norm = minor_norm / minor_norm.sum() | |
| best_corr = -2.0 | |
| best_key = "C major" | |
| for shift in range(12): | |
| rotated = np.roll(chroma_avg, -shift) | |
| corr_maj = float(np.corrcoef(rotated, major_norm)[0, 1]) | |
| if corr_maj > best_corr: | |
| best_corr = corr_maj | |
| best_key = f"{_PITCH_CLASSES[shift]} major" | |
| corr_min = float(np.corrcoef(rotated, minor_norm)[0, 1]) | |
| if corr_min > best_corr: | |
| best_corr = corr_min | |
| best_key = f"{_PITCH_CLASSES[shift]} minor" | |
| return best_key, best_corr | |
| def _key_votes_from_chroma(chroma_avg, profiles=None) -> list: | |
| """Vote on key from a single chroma vector using specified profiles. | |
| Returns list of (key_label, correlation) -- one per profile family. | |
| """ | |
| if profiles is None: | |
| profiles = _KEY_PROFILES | |
| results = [] | |
| for name, pf in profiles.items(): | |
| key_label, corr = _best_key_for_profile( | |
| chroma_avg, pf["major"], pf["minor"], | |
| ) | |
| results.append((key_label, corr)) | |
| return results | |
| def _energy_weighted_chroma(chroma, y_harmonic): | |
| """Compute an energy-weighted average chroma vector. | |
| Returns normalized chroma_avg or None if zero energy. | |
| """ | |
| import librosa | |
| import numpy as np | |
| rms = librosa.feature.rms(y=y_harmonic, frame_length=2048, hop_length=512) | |
| rms_vec = rms[0] | |
| min_len = min(chroma.shape[1], len(rms_vec)) | |
| chroma = chroma[:, :min_len] | |
| rms_vec = rms_vec[:min_len] | |
| weights = rms_vec / (rms_vec.sum() + 1e-10) | |
| chroma_avg = (chroma * weights[None, :]).sum(axis=1) | |
| s = chroma_avg.sum() | |
| if s == 0: | |
| return None | |
| return chroma_avg / s | |
| # ---- Unified key detection ---------------------------------------------- | |
| def _detect_key(y, sr, mode: str = "faf") -> Tuple[Optional[str], str]: | |
| """Detect musical key with quality controlled by mode. | |
| faf: Single Krumhansl profile on chroma_cens. | |
| mid: 3-profile x energy-weighted chroma_cens x 8s segment voting. | |
| sas: mid + multi-chroma fusion + tonnetz + tuning correction + | |
| ending resolution + chunked voting. | |
| Returns (key, confidence). | |
| """ | |
| import librosa | |
| import numpy as np | |
| from collections import Counter | |
| try: | |
| # Harmonic enhancement | |
| margin = 4.0 if mode != "faf" else 2.0 | |
| y_harmonic = librosa.effects.harmonic(y, margin=margin) | |
| # sas: tuning correction | |
| tuning = 0.0 | |
| if mode == "sas": | |
| try: | |
| tuning = float(librosa.estimate_tuning(y=y_harmonic, sr=sr)) | |
| except Exception: | |
| tuning = 0.0 | |
| # faf: single chroma, single profile | |
| if mode == "faf": | |
| chroma = librosa.feature.chroma_cens(y=y_harmonic, sr=sr) | |
| chroma_avg = _energy_weighted_chroma(chroma, y_harmonic) | |
| if chroma_avg is None: | |
| return None, "low" | |
| kr = _KEY_PROFILES["krumhansl"] | |
| key_label, corr = _best_key_for_profile( | |
| chroma_avg, kr["major"], kr["minor"], | |
| ) | |
| logger.info("Key faf: %s (corr=%.3f)", key_label, corr) | |
| return key_label, "low" | |
| # mid / sas: multi-profile voting | |
| all_votes = [] | |
| all_weights = [] | |
| if mode == "sas": | |
| chroma_types = { | |
| "cens": lambda: librosa.feature.chroma_cens( | |
| y=y_harmonic, sr=sr, tuning=tuning, | |
| ), | |
| "cqt": lambda: librosa.feature.chroma_cqt( | |
| y=y_harmonic, sr=sr, tuning=tuning, | |
| ), | |
| "stft": lambda: librosa.feature.chroma_stft( | |
| y=y_harmonic, sr=sr, tuning=tuning, | |
| ), | |
| } | |
| else: | |
| chroma_types = { | |
| "cens": lambda: librosa.feature.chroma_cens( | |
| y=y_harmonic, sr=sr, | |
| ), | |
| } | |
| for chroma_name, chroma_fn in chroma_types.items(): | |
| try: | |
| chroma = chroma_fn() | |
| except Exception: | |
| continue | |
| chroma_avg = _energy_weighted_chroma(chroma, y_harmonic) | |
| if chroma_avg is None: | |
| continue | |
| # Global multi-profile vote | |
| for key_label, corr in _key_votes_from_chroma(chroma_avg): | |
| all_votes.append(key_label) | |
| all_weights.append(1.0) | |
| # Segment-based voting | |
| rms = librosa.feature.rms( | |
| y=y_harmonic, frame_length=2048, hop_length=512, | |
| ) | |
| rms_vec = rms[0] | |
| min_len = min(chroma.shape[1], len(rms_vec)) | |
| chroma_s = chroma[:, :min_len] | |
| rms_s = rms_vec[:min_len] | |
| seg_frames = int(8.0 * sr / 512) | |
| n_segments = max(1, chroma_s.shape[1] // seg_frames) | |
| for seg_i in range(n_segments): | |
| start = seg_i * seg_frames | |
| end = min(start + seg_frames, chroma_s.shape[1]) | |
| seg_chroma = chroma_s[:, start:end] | |
| seg_w = rms_s[start:end] | |
| w_sum = seg_w.sum() | |
| if w_sum < 1e-10: | |
| continue | |
| seg_w_norm = seg_w / w_sum | |
| seg_avg = (seg_chroma * seg_w_norm[None, :]).sum(axis=1) | |
| s = seg_avg.sum() | |
| if s < 1e-10: | |
| continue | |
| seg_avg = seg_avg / s | |
| for key_label, _ in _key_votes_from_chroma(seg_avg): | |
| all_votes.append(key_label) | |
| all_weights.append(1.0) | |
| # sas-only extras | |
| if mode == "sas": | |
| # Tonnetz -- weighted vote for major/minor disambiguation | |
| try: | |
| tonnetz = librosa.feature.tonnetz(y=y_harmonic, sr=sr) | |
| tonnetz_avg = np.mean(tonnetz, axis=1) | |
| major_energy = float(np.sum(tonnetz_avg[4:6] ** 2)) | |
| minor_energy = float(np.sum(tonnetz_avg[2:4] ** 2)) | |
| tonnetz_ratio = major_energy / (minor_energy + 1e-10) | |
| if all_votes: | |
| temp_counts = Counter(all_votes) | |
| leader = temp_counts.most_common(1)[0][0] | |
| leader_is_major = "major" in leader | |
| tonnetz_says_major = tonnetz_ratio > 1.0 | |
| if leader_is_major == tonnetz_says_major: | |
| all_votes.extend([leader] * 3) | |
| all_weights.extend([1.5] * 3) | |
| else: | |
| alt_mode = "minor" if leader_is_major else "major" | |
| chroma_cens = librosa.feature.chroma_cens( | |
| y=y_harmonic, sr=sr, tuning=tuning, | |
| ) | |
| ca = _energy_weighted_chroma(chroma_cens, y_harmonic) | |
| if ca is not None: | |
| for name, pf in _KEY_PROFILES.items(): | |
| prof = np.array(pf[alt_mode], dtype=float) | |
| prof_norm = prof / prof.sum() | |
| best_corr = -2.0 | |
| best_k = "" | |
| for shift in range(12): | |
| rotated = np.roll(ca, -shift) | |
| c = float(np.corrcoef(rotated, prof_norm)[0, 1]) | |
| if c > best_corr: | |
| best_corr = c | |
| best_k = f"{_PITCH_CLASSES[shift]} {alt_mode}" | |
| if best_k: | |
| all_votes.append(best_k) | |
| all_weights.append(1.0) | |
| except Exception: | |
| pass | |
| # Ending resolution -- last ~5 s weighted extra | |
| try: | |
| end_samples = min(int(5.0 * sr), len(y_harmonic)) | |
| y_end = y_harmonic[-end_samples:] | |
| chroma_end = librosa.feature.chroma_cens( | |
| y=y_end, sr=sr, tuning=tuning, | |
| ) | |
| end_avg = np.mean(chroma_end, axis=1) | |
| s = end_avg.sum() | |
| if s > 1e-10: | |
| end_avg = end_avg / s | |
| for key_label, _ in _key_votes_from_chroma(end_avg): | |
| all_votes.append(key_label) | |
| all_weights.append(2.0) | |
| except Exception: | |
| pass | |
| # Chunked voting | |
| chunks = _select_chunks( | |
| y_harmonic, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=False, | |
| ) | |
| for chunk in chunks: | |
| try: | |
| ch_chroma = librosa.feature.chroma_cens( | |
| y=chunk, sr=sr, tuning=tuning, | |
| ) | |
| ch_avg = _energy_weighted_chroma(ch_chroma, chunk) | |
| if ch_avg is not None: | |
| for key_label, _ in _key_votes_from_chroma(ch_avg): | |
| all_votes.append(key_label) | |
| all_weights.append(1.0) | |
| except Exception: | |
| pass | |
| # Final weighted majority vote | |
| if not all_votes: | |
| return None, "low" | |
| weighted_counts = {} | |
| for vote, w in zip(all_votes, all_weights): | |
| weighted_counts[vote] = weighted_counts.get(vote, 0.0) + w | |
| best_key = max(weighted_counts, key=weighted_counts.get) | |
| total_weight = sum(all_weights) | |
| best_weight = weighted_counts[best_key] | |
| share = best_weight / total_weight | |
| if share >= 0.55: | |
| confidence = "high" | |
| elif share >= 0.35: | |
| confidence = "medium" | |
| else: | |
| confidence = "low" | |
| logger.info( | |
| "Key [%s]: %s (share=%.0f%%, votes=%d, conf=%s)", | |
| mode, best_key, share * 100, len(all_votes), confidence, | |
| ) | |
| return best_key, confidence | |
| except Exception as exc: | |
| logger.warning("Key detection failed: %s", exc) | |
| return None, "low" | |
| # ---- Time-signature helpers --------------------------------------------- | |
| def _timesig_core_scores(y, sr) -> dict: | |
| """Compute 3-signal time-signature scores on a single buffer (mid/sas). | |
| Returns dict mapping signature labels to raw scores. | |
| """ | |
| import librosa | |
| import numpy as np | |
| scores = {} | |
| tempo, beat_frames = librosa.beat.beat_track(y=y, sr=sr) | |
| if beat_frames is None or len(beat_frames) < 8: | |
| return scores | |
| onset_env = librosa.onset.onset_strength(y=y, sr=sr) | |
| beat_strengths = onset_env[beat_frames[beat_frames < len(onset_env)]] | |
| if len(beat_strengths) < 8: | |
| return scores | |
| # Signal 1: Accent pattern analysis | |
| for label, grouping in [("3/4", 3), ("4/4", 4), ("6/8", 6)]: | |
| if len(beat_strengths) < grouping * 2: | |
| scores[label] = 0.0 | |
| continue | |
| usable = len(beat_strengths) - (len(beat_strengths) % grouping) | |
| grouped = beat_strengths[:usable].reshape(-1, grouping) | |
| downbeat_mean = float(np.mean(grouped[:, 0])) | |
| offbeat_mean = float(np.mean(grouped[:, 1:])) | |
| contrast = downbeat_mean / offbeat_mean if offbeat_mean > 0 else 1.0 | |
| scores[label] = contrast | |
| # Signal 2: Autocorrelation at meter periods | |
| hop = 512 | |
| beat_times = librosa.frames_to_time(beat_frames, sr=sr) | |
| intervals = np.diff(beat_times) | |
| if len(intervals) > 0: | |
| median_interval = float(np.median(intervals)) | |
| beat_period = int(round(median_interval * sr / hop)) | |
| if beat_period > 0: | |
| ac = librosa.autocorrelate(onset_env, max_size=len(onset_env)) | |
| for label, mult in [("3/4", 3), ("4/4", 4), ("6/8", 6)]: | |
| period = beat_period * mult | |
| if period < len(ac): | |
| lo = max(0, period - 2) | |
| hi = min(len(ac), period + 3) | |
| ac_score = float(np.mean(ac[lo:hi])) | |
| if ac[0] > 0: | |
| ac_score /= float(ac[0]) | |
| scores[label] = scores.get(label, 0.0) + ac_score | |
| # Signal 3: Beat-strength variance ratio | |
| for label, grouping in [("3/4", 3), ("4/4", 4)]: | |
| usable = len(beat_strengths) - (len(beat_strengths) % grouping) | |
| if usable >= grouping * 2: | |
| grouped = beat_strengths[:usable].reshape(-1, grouping) | |
| row_vars = np.var(grouped, axis=1) | |
| scores[label] = scores.get(label, 0.0) + float(np.mean(row_vars)) | |
| return scores | |
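| # Illustrative sketch (not executed by the pipeline; values made up): | |
| # the returned scores are unnormalized "higher is stronger" accumulators: | |
| # | |
| #     scores = _timesig_core_scores(y, sr) | |
| #     # -> {"3/4": 1.8, "4/4": 2.6, "6/8": 1.2} | |
| # | |
| # _detect_time_sig() layers extra signals and a 4/4 prior on top of these. | |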
| # ---- Unified time-signature detection ----------------------------------- | |
| def _detect_time_sig(y, sr, mode: str = "faf") -> Tuple[Optional[str], str]: | |
| """Estimate time signature with quality controlled by mode. | |
| faf: Hardcoded "4/4" (correct ~80%+ of the time). | |
| mid: Beat-sync accent + AC + variance + 4/4 prior. | |
| sas: mid signals + PLP periodicity + multi-band onset + | |
| tempogram harmonic ratios + chunked voting. | |
| Returns (signature, confidence). | |
| """ | |
| if mode == "faf": | |
| return "4/4", "low" | |
| import librosa | |
| import numpy as np | |
| try: | |
| # mid: core 3-signal scoring | |
| scores = _timesig_core_scores(y, sr) | |
| # sas: additional techniques | |
| if mode == "sas": | |
| onset_env = librosa.onset.onset_strength(y=y, sr=sr) | |
| # Tempo estimate shared by the PLP, multi-band, and tempogram | |
| # signals -- hoisted out of the PLP try-block so a PLP failure | |
| # cannot leave tempo_est/tempo_val undefined for the later blocks | |
| tempo_est, _ = librosa.beat.beat_track(y=y, sr=sr) | |
| tempo_val = float(np.atleast_1d(tempo_est)[0]) | |
| # PLP periodicity | |
| try: | |
| pulse = librosa.beat.plp(onset_envelope=onset_env, sr=sr) | |
| plp_ac = librosa.autocorrelate(pulse, max_size=len(pulse)) | |
| if tempo_val > 0: | |
| hop = 512 | |
| bp = int(round(60.0 / tempo_val * sr / hop)) | |
| if bp > 0: | |
| for label, mult in [("3/4", 3), ("4/4", 4), ("6/8", 6)]: | |
| lag = bp * mult | |
| if lag < len(plp_ac): | |
| lo = max(0, lag - 2) | |
| hi = min(len(plp_ac), lag + 3) | |
| s = float(np.mean(plp_ac[lo:hi])) | |
| if plp_ac[0] > 0: | |
| s /= float(plp_ac[0]) | |
| scores[label] = scores.get(label, 0.0) + s | |
| except Exception: | |
| pass | |
| # Multi-band onset analysis (low/mid/high) | |
| try: | |
| S = np.abs(librosa.stft(y)) | |
| n_bins = S.shape[0] | |
| third = n_bins // 3 | |
| bands = { | |
| "low": S[:third, :], | |
| "mid_band": S[third:2*third, :], | |
| "high": S[2*third:, :], | |
| } | |
| for band_name, band_S in bands.items(): | |
| band_onset = librosa.onset.onset_strength(S=band_S, sr=sr) | |
| band_ac = librosa.autocorrelate( | |
| band_onset, max_size=len(band_onset), | |
| ) | |
| tempo_val2 = tempo_val  # shared estimate hoisted above the PLP block | |
| if tempo_val2 > 0: | |
| hop = 512 | |
| bp2 = int(round(60.0 / tempo_val2 * sr / hop)) | |
| if bp2 > 0 and band_ac[0] > 0: | |
| for label, mult in [("3/4", 3), ("4/4", 4)]: | |
| lag = bp2 * mult | |
| if lag < len(band_ac): | |
| lo = max(0, lag - 2) | |
| hi = min(len(band_ac), lag + 3) | |
| s = float(np.mean(band_ac[lo:hi])) | |
| s /= float(band_ac[0]) | |
| w = 1.5 if band_name == "low" else 1.0 | |
| scores[label] = scores.get(label, 0.0) + s * w | |
| except Exception: | |
| pass | |
| # Tempogram harmonic ratios | |
| try: | |
| tempogram = librosa.feature.tempogram( | |
| onset_envelope=onset_env, sr=sr, | |
| ) | |
| avg_tg = np.mean(tempogram, axis=1) | |
| bpm_axis = librosa.tempo_frequencies(tempogram.shape[0], sr=sr) | |
| if tempo_val > 0: | |
| for mult_label, t_mult in [("duple", 2.0), ("triple", 3.0)]: | |
| target_bpm = tempo_val * t_mult | |
| if target_bpm < 300: | |
| idx = np.argmin(np.abs(bpm_axis - target_bpm)) | |
| energy = float(avg_tg[idx]) | |
| base_idx = np.argmin(np.abs(bpm_axis - tempo_val)) | |
| base_energy = float(avg_tg[base_idx]) + 1e-10 | |
| ratio = energy / base_energy | |
| if t_mult == 2.0: | |
| scores["4/4"] = scores.get("4/4", 0.0) + ratio | |
| else: | |
| scores["3/4"] = scores.get("3/4", 0.0) + ratio | |
| except Exception: | |
| pass | |
| # Chunked voting | |
| chunks = _select_chunks(y, sr, n_chunks=_SAS_NUM_CHUNKS, use_onset=True) | |
| chunk_votes = [] | |
| for chunk in chunks: | |
| cs = _timesig_core_scores(chunk, sr) | |
| if cs: | |
| cs["4/4"] = cs.get("4/4", 0.0) * 1.15 | |
| best_c = max(cs, key=cs.get) | |
| chunk_votes.append(best_c) | |
| for vote in chunk_votes: | |
| scores[vote] = scores.get(vote, 0.0) + 1.0 | |
| # Multiplicative prior: bias toward 4/4 (the most common meter) | |
| scores["4/4"] = scores.get("4/4", 0.0) * 1.15 | |
| if not scores: | |
| return "4/4", "low" | |
| best = max(scores, key=scores.get) | |
| # Confidence: ratio between the top-2 scores | |
| sorted_scores = sorted(scores.values(), reverse=True) | |
| if len(sorted_scores) >= 2 and sorted_scores[1] > 0: | |
| margin = sorted_scores[0] / sorted_scores[1] | |
| else: | |
| margin = 1.0 | |
| if margin > 1.4: | |
| confidence = "high" | |
| elif margin > 1.15: | |
| confidence = "medium" | |
| else: | |
| confidence = "low" | |
| logger.info( | |
| "TimeSig [%s]: %s (scores=%s, margin=%.2f, conf=%s)", | |
| mode, best, | |
| {k: round(v, 3) for k, v in scores.items()}, | |
| margin, confidence, | |
| ) | |
| return best, confidence | |
| except Exception as exc: | |
| logger.warning("Time signature detection failed: %s", exc) | |
| return "4/4", "low" | |
| def _sanitize_tag(value: str) -> str: | |
| """Normalize a tag value: NFKC normalize, strip invisible chars.""" | |
| value = unicodedata.normalize("NFKC", value) | |
| value = ( | |
| value | |
| # Zero-width / BOM / bidi control characters, written as explicit | |
| # escapes (reconstructed; the category-C filter below catches any | |
| # stragglers) | |
| .replace("\u200b", "").replace("\u200c", "") | |
| .replace("\u200d", "").replace("\u2060", "") | |
| .replace("\ufeff", "").replace("\u00ad", "") | |
| .replace("\u200e", "").replace("\u200f", "") | |
| .replace("\u180e", "") | |
| ) | |
| value = "".join( | |
| c for c in value | |
| if c in ("\n", "\r", "\t", " ") or unicodedata.category(c)[0] != "C" | |
| ) | |
| return value.strip() | |
| def _extract_metadata_from_tags(audio_path: Path) -> tuple: | |
| """Extract (title, artist) from audio tags via mutagen, fallback to filename.""" | |
| title, artist = None, None | |
| try: | |
| import mutagen | |
| mf = mutagen.File(str(audio_path)) | |
| if mf is not None and mf.tags is not None: | |
| # ID3 (MP3, AIFF) | |
| for key in ("TIT2",): | |
| val = mf.tags.get(key) | |
| if val: | |
| title = _sanitize_tag(str(val)) | |
| break | |
| for key in ("TPE1", "TPE2"): | |
| val = mf.tags.get(key) | |
| if val: | |
| artist = _sanitize_tag(str(val)) | |
| break | |
| # Vorbis (FLAC, OGG) and MP4 atoms | |
| if title is None: | |
| for key in ("title", "\xa9nam"): | |
| vals = mf.tags.get(key) | |
| if vals: | |
| raw = str(vals[0]) if isinstance(vals, list) else str(vals) | |
| title = _sanitize_tag(raw) | |
| break | |
| if artist is None: | |
| for key in ("artist", "\xa9ART", "albumartist", "aART"): | |
| vals = mf.tags.get(key) | |
| if vals: | |
| raw = str(vals[0]) if isinstance(vals, list) else str(vals) | |
| artist = _sanitize_tag(raw) | |
| break | |
| except Exception: | |
| pass | |
| # Fallback to filename parsing | |
| if not title: | |
| stem = audio_path.stem | |
| match = _FILENAME_RE.match(stem) | |
| if match: | |
| artist = artist or match.group(1).strip() | |
| title = match.group(2).strip() | |
| else: | |
| title = stem.strip() | |
| return title or audio_path.stem, artist or "" | |
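| # Example (hypothetical file with no usable tags): | |
| # | |
| #     _extract_metadata_from_tags(Path("Artist Name - Song Title.mp3")) | |
| #     # -> ("Song Title", "Artist Name") via the _FILENAME_RE fallback; | |
| #     # a file matching neither tags nor the pattern yields (stem, ""). | |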
| def analyze_and_caption( | |
| audio_path: str, | |
| mode: str = "faf", | |
| device: str = "cpu", | |
| ) -> Dict[str, Any]: | |
| """Analyze an audio file and build a training caption. | |
| Supports three quality modes: | |
| faf - CPU, ~2-3s/file. Single-method detection on raw mix. | |
| mid - ~12s/file. Demucs stems + 3-method ensemble. | |
| sas - ~30s/file. Deep multi-technique + chunked analysis. | |
| For mid/sas, Demucs separates drums and harmonics stems first. | |
| On CPU, Demucs adds ~2-5 minutes per file. | |
| Args: | |
| audio_path: Path to the audio file. | |
| mode: Analysis mode ("faf", "mid", or "sas"). | |
| device: Torch device for Demucs (e.g. "cpu" or "cuda:0"). | |
| Returns: | |
| Dict with keys: caption, bpm, key, signature, lyrics, title, artist, | |
| confidence (dict of per-field confidence levels). | |
| """ | |
| import librosa | |
| import numpy as np | |
| audio_path = Path(audio_path) | |
| if mode not in _ANALYSIS_MODES: | |
| logger.warning("Unknown analysis mode '%s', falling back to 'faf'", mode) | |
| mode = "faf" | |
| # Load audio once, reuse for all detectors | |
| try: | |
| y, sr = librosa.load(str(audio_path), sr=None, mono=True) | |
| # Trim silence + peak normalize | |
| y_trimmed, _ = librosa.effects.trim(y, top_db=30) | |
| if len(y_trimmed) >= sr: | |
| y = y_trimmed | |
| peak = np.max(np.abs(y)) | |
| if peak > 0: | |
| y = y / peak | |
| except Exception as exc: | |
| logger.warning("Could not load audio for analysis: %s: %s", audio_path.name, exc) | |
| title, artist = _extract_metadata_from_tags(audio_path) | |
| return { | |
| "caption": f"A track by {artist}" if artist else f"A track titled {title}", | |
| "bpm": None, "key": None, "signature": "4/4", | |
| "lyrics": "[Instrumental]", "title": title, "artist": artist, | |
| "confidence": {}, | |
| } | |
| confidence = {} | |
| tmp_dir = None | |
| try: | |
| if mode in ("mid", "sas"): | |
| # Demucs stem separation -- run BPM/timesig on drums, | |
| # key detection on harmonics | |
| tmp_dir = Path(tempfile.mkdtemp(prefix="ace_analysis_")) | |
| try: | |
| drums_path, harmonics_path = separate_stems( | |
| audio_path, tmp_dir, device=device, | |
| ) | |
| # Load separated stems for analysis | |
| y_drums, sr_drums = librosa.load( | |
| str(drums_path), sr=None, mono=True, | |
| ) | |
| y_harmonics, sr_harmonics = librosa.load( | |
| str(harmonics_path), sr=None, mono=True, | |
| ) | |
| # Preprocess stems | |
| y_drums_trimmed, _ = librosa.effects.trim(y_drums, top_db=30) | |
| if len(y_drums_trimmed) >= sr_drums: | |
| y_drums = y_drums_trimmed | |
| peak_d = np.max(np.abs(y_drums)) | |
| if peak_d > 0: | |
| y_drums = y_drums / peak_d | |
| y_harm_trimmed, _ = librosa.effects.trim(y_harmonics, top_db=30) | |
| if len(y_harm_trimmed) >= sr_harmonics: | |
| y_harmonics = y_harm_trimmed | |
| peak_h = np.max(np.abs(y_harmonics)) | |
| if peak_h > 0: | |
| y_harmonics = y_harmonics / peak_h | |
| # BPM + time sig on drums stem | |
| bpm, bpm_conf = _detect_bpm(y_drums, sr_drums, mode) | |
| signature, sig_conf = _detect_time_sig(y_drums, sr_drums, mode) | |
| # Key on harmonics stem | |
| key, key_conf = _detect_key(y_harmonics, sr_harmonics, mode) | |
| confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf} | |
| except Exception as exc: | |
| logger.warning( | |
| "Demucs separation failed for %s: %s -- " | |
| "falling back to analysis on raw mix", | |
| audio_path.name, exc, | |
| ) | |
| # Fallback: run detectors on raw mix | |
| bpm, bpm_conf = _detect_bpm(y, sr, mode) | |
| key, key_conf = _detect_key(y, sr, mode) | |
| signature, sig_conf = _detect_time_sig(y, sr, mode) | |
| confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf} | |
| else: | |
| # faf: all detectors on raw mix | |
| bpm, bpm_conf = _detect_bpm(y, sr, mode) | |
| key, key_conf = _detect_key(y, sr, mode) | |
| signature, sig_conf = _detect_time_sig(y, sr, mode) | |
| confidence = {"bpm": bpm_conf, "key": key_conf, "signature": sig_conf} | |
| finally: | |
| if tmp_dir is not None: | |
| try: | |
| shutil.rmtree(tmp_dir) | |
| except OSError as exc: | |
| logger.debug("Could not clean temp dir %s: %s", tmp_dir, exc) | |
| title, artist = _extract_metadata_from_tags(audio_path) | |
| # Build caption string for ACE-Step training | |
| parts = ["A"] | |
| if artist: | |
| parts.append(f"track by {artist}") | |
| else: | |
| parts.append("track") | |
| if bpm: | |
| parts.append(f"at {bpm} BPM") | |
| if key: | |
| parts.append(f"in {key}") | |
| parts.append(f"{signature} time") | |
| caption = " ".join(parts) | |
| lyrics = "[Instrumental]" | |
| result = { | |
| "caption": caption, | |
| "bpm": bpm, | |
| "key": key, | |
| "signature": signature, | |
| "lyrics": lyrics, | |
| "title": title, | |
| "artist": artist, | |
| "confidence": confidence, | |
| } | |
| logger.info("Auto-caption [%s] for %s: %s", mode, audio_path.name, caption) | |
| return result | |
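| # Usage sketch (path and outputs illustrative): | |
| # | |
| #     info = analyze_and_caption("song.mp3", mode="mid", device="cpu") | |
| #     info["caption"]     # "A track by Artist at 120 BPM in C major in 4/4 time" | |
| #     info["confidence"]  # {"bpm": "high", "key": "medium", "signature": "low"} | |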
| def _write_caption_sidecar(audio_path: Path, analysis: Dict[str, Any]) -> Path: | |
| """Write analysis results as a .json sidecar next to the audio file.""" | |
| sidecar_path = audio_path.with_suffix(".json") | |
| with open(sidecar_path, "w", encoding="utf-8") as f: | |
| json.dump(analysis, f, indent=2, ensure_ascii=False) | |
| logger.info("Wrote caption sidecar: %s", sidecar_path) | |
| return sidecar_path | |
| def _parse_txt_caption(text: str) -> Dict[str, Any]: | |
| """Parse user's .txt caption format into structured fields.""" | |
| result: Dict[str, Any] = {} | |
| lyrics_match = re.search(r'lyrics say "(.*?)" at tempo', text, re.DOTALL) | |
| if lyrics_match: | |
| result["lyrics"] = lyrics_match.group(1).strip() | |
| caption_part = text[:lyrics_match.start()].strip().rstrip(",").strip() | |
| else: | |
| result["lyrics"] = "[Instrumental]" | |
| caption_part = text.strip() | |
| bpm_match = re.search(r'at tempo (\d+) BPM', text) | |
| if bpm_match: | |
| result["bpm"] = bpm_match.group(1) | |
| caption_part = re.sub(r'\s*at tempo \d+ BPM.*', '', caption_part).strip() | |
| key_match = re.search(r'in the key of ([A-G][#b]?[-\d]*)', text) | |
| if key_match: | |
| result["key"] = key_match.group(1) | |
| result["caption"] = caption_part if caption_part else text[:200] | |
| return result | |
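| # Example of the .txt format this parses (hypothetical input): | |
| # | |
| #     _parse_txt_caption( | |
| #         'upbeat synthwave, lyrics say "neon nights" at tempo 118 BPM ' | |
| #         'in the key of F#' | |
| #     ) | |
| #     # -> {"lyrics": "neon nights", "bpm": "118", "key": "F#", | |
| #     #     "caption": "upbeat synthwave"} | |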
| def _read_caption_sidecar(audio_path: Path) -> Optional[Dict[str, Any]]: | |
| """Read .json or .txt caption sidecar.""" | |
| json_path = audio_path.with_suffix(".json") | |
| if json_path.is_file(): | |
| try: | |
| with open(json_path, "r", encoding="utf-8") as f: | |
| return json.load(f) | |
| except Exception: | |
| pass | |
| txt_path = audio_path.with_suffix(".txt") | |
| if txt_path.is_file(): | |
| try: | |
| with open(txt_path, "r", encoding="utf-8") as f: | |
| return _parse_txt_caption(f.read()) | |
| except Exception: | |
| pass | |
| return None | |
| # ============================================================================ | |
| # PREPROCESSING (2-pass sequential) | |
| # ============================================================================ | |
| def preprocess_audio( | |
| audio_dir: str, | |
| output_dir: str, | |
| checkpoint_dir: str, | |
| device: str = "auto", | |
| variant: str = "base", | |
| max_duration: float = 0, | |
| progress_callback: Optional[Callable] = None, | |
| cancel_check: Optional[Callable] = None, | |
| ) -> Dict[str, Any]: | |
| """2-pass sequential preprocessing. | |
| Pass 1: Load VAE + text encoder, encode audio + text, save intermediates. | |
| Pass 2: Load DIT model, run encoder, build context, save final .pt files. | |
| Args: | |
| device: "auto" to auto-detect GPU/CPU, or explicit device string. | |
| """ | |
| device = detect_device(device) | |
| logger.info("Preprocessing on device: %s", device) | |
| out = Path(output_dir) | |
| out.mkdir(parents=True, exist_ok=True) | |
| # Clean orphaned staging files | |
| for orphan in out.glob("*.__writing__"): | |
| try: | |
| orphan.unlink() | |
| except OSError: | |
| pass | |
| audio_files = _discover_audio_files(audio_dir) | |
| if not audio_files: | |
| return {"processed": 0, "failed": 0, "total": 0, "output_dir": str(out)} | |
| total = len(audio_files) | |
| if max_duration <= 0: | |
| max_duration = _detect_max_duration(audio_files) | |
| dtype = select_dtype(device) | |
| # ---- Pass 1: VAE + Text Encoder ---- | |
| logger.info("Pass 1/2: Loading VAE + Text Encoder...") | |
| vae = load_vae(checkpoint_dir, device) | |
| tokenizer, text_enc = load_text_encoder(checkpoint_dir, device) | |
| silence_lat = load_silence_latent(checkpoint_dir, device, variant=variant) | |
| intermediates: List[Path] = [] | |
| p1_failed = 0 | |
| try: | |
| for i, af in enumerate(audio_files): | |
| if cancel_check and cancel_check(): | |
| break | |
| stem = af.stem | |
| final_pt = out / f"{stem}.pt" | |
| if final_pt.exists(): | |
| continue | |
| try: | |
| audio, _ = load_audio_stereo(str(af), TARGET_SR, max_duration) | |
| # Auto-caption (once per file, shared across chunks) | |
| sidecar = _read_caption_sidecar(af) | |
| if sidecar is not None: | |
| caption = sidecar.get("caption", "") or af.stem | |
| lyrics = sidecar.get("lyrics", "[Instrumental]") | |
| logger.info("[Caption] %s: using existing sidecar", af.name) | |
| else: | |
| if device == "cpu": | |
| analysis_mode = "faf" | |
| elif total <= 20: | |
| analysis_mode = "sas" | |
| elif total <= 100: | |
| analysis_mode = "mid" | |
| else: | |
| analysis_mode = "faf" | |
| if i == 0: | |
| _MODE_DESC = { | |
| "faf": "fast, ~3s/file", | |
| "mid": "balanced, ~12s/file", | |
| "sas": "best quality, ~30s/file on GPU, slower on CPU", | |
| } | |
| logger.info( | |
| "[Analysis] Mode '%s' (%s) for %d files", | |
| analysis_mode, _MODE_DESC[analysis_mode], total, | |
| ) | |
| try: | |
| logger.info("[Caption] %s: analyzing (mode=%s)...", af.name, analysis_mode) | |
| analysis = analyze_and_caption( | |
| str(af), mode=analysis_mode, device=device, | |
| ) | |
| caption = analysis["caption"] | |
| lyrics = analysis.get("lyrics", "[Instrumental]") | |
| _write_caption_sidecar(af, analysis) | |
| logger.info("[Caption] %s: %s", af.name, caption) | |
| except Exception as exc: | |
| logger.warning("[Caption] %s: analysis failed (%s), using filename", af.name, exc) | |
| caption = af.stem | |
| lyrics = "[Instrumental]" | |
| with torch.no_grad(): | |
| text_hs, text_mask = encode_text(text_enc, tokenizer, caption, device, dtype) | |
| lyric_hs, lyric_mask = encode_lyrics(text_enc, tokenizer, lyrics, device, dtype) | |
| has_bad = any( | |
| torch.isnan(t).any() or torch.isinf(t).any() | |
| for t in [text_hs, lyric_hs] | |
| ) | |
| if has_bad: | |
| p1_failed += 1 | |
| del text_hs, text_mask, lyric_hs, lyric_mask | |
| continue | |
| # VAE encode full audio (tiled for memory, output is full-length) | |
| audio_in = audio.unsqueeze(0).to(device=device, dtype=vae.dtype) | |
| with torch.no_grad(): | |
| target_latents = tiled_vae_encode(vae, audio_in, dtype) | |
| del audio_in, audio | |
| if torch.isnan(target_latents).any() or torch.isinf(target_latents).any(): | |
| p1_failed += 1 | |
| del target_latents, text_hs, text_mask, lyric_hs, lyric_mask | |
| continue | |
| lat = target_latents.squeeze(0).cpu() | |
| lat_len = lat.shape[0] | |
| att_mask = torch.ones(lat_len, dtype=dtype) | |
| tmp_path = out / f"{stem}.tmp.pt" | |
| torch.save({ | |
| "target_latents": lat, | |
| "attention_mask": att_mask, | |
| "text_hidden_states": text_hs.cpu(), | |
| "text_attention_mask": text_mask.cpu(), | |
| "lyric_hidden_states": lyric_hs.cpu(), | |
| "lyric_attention_mask": lyric_mask.cpu(), | |
| "silence_latent": silence_lat.cpu(), | |
| "latent_length": lat_len, | |
| "metadata": { | |
| "audio_path": str(af), | |
| "filename": af.name, | |
| "caption": caption, | |
| "lyrics": lyrics, | |
| }, | |
| }, tmp_path) | |
| intermediates.append(tmp_path) | |
| del target_latents, lat, text_hs, text_mask, lyric_hs, lyric_mask | |
| logger.info("[OK] %s: %d latent frames (%.1fs)", af.name, lat_len, lat_len / LATENT_HZ) | |
| if progress_callback: | |
| progress_callback(i + 1, total, f"[Pass 1] {af.name}") | |
| except Exception as exc: | |
| p1_failed += 1 | |
| logger.error("Pass 1 FAIL %s: %s", af.name, exc) | |
| finally: | |
| logger.info("Unloading VAE + Text Encoder...") | |
| unload_models(vae, text_enc, tokenizer, silence_lat) | |
| _clear_gpu_cache(device) | |
| # ---- Pass 2: DIT Encoder ---- | |
| if not intermediates: | |
| return {"processed": 0, "failed": p1_failed, "total": total, "output_dir": str(out)} | |
| logger.info("Pass 2/2: Loading DIT model (variant=%s)...", variant) | |
| model = load_model_for_training(checkpoint_dir, variant, device) | |
| processed = 0 | |
| p2_failed = 0 | |
| p2_total = len(intermediates) | |
| try: | |
| for i, tmp_path in enumerate(intermediates): | |
| if cancel_check and cancel_check(): | |
| break | |
| try: | |
| data = torch.load(str(tmp_path), weights_only=True) | |
| m_device = next(model.parameters()).device | |
| m_dtype = next(model.parameters()).dtype | |
| text_hs = data["text_hidden_states"].to(m_device, dtype=m_dtype) | |
| text_mask = data["text_attention_mask"].to(m_device, dtype=m_dtype) | |
| lyric_hs = data["lyric_hidden_states"].to(m_device, dtype=m_dtype) | |
| lyric_mask = data["lyric_attention_mask"].to(m_device, dtype=m_dtype) | |
| silence_lat = data["silence_latent"].to(m_device, dtype=m_dtype) | |
| lat_len = data["latent_length"] | |
| enc_hs, enc_mask = run_encoder( | |
| model, text_hs, text_mask, lyric_hs, lyric_mask, | |
| str(m_device), m_dtype, | |
| ) | |
| del text_hs, text_mask, lyric_hs, lyric_mask | |
| if silence_lat.dim() == 2: | |
| silence_lat = silence_lat.unsqueeze(0) | |
| ctx = build_context_latents(silence_lat, lat_len, str(m_device), m_dtype) | |
| del silence_lat | |
| has_bad = any( | |
| torch.isnan(t).any() or torch.isinf(t).any() | |
| for t in [enc_hs, ctx] | |
| ) | |
| if has_bad: | |
| p2_failed += 1 | |
| del enc_hs, enc_mask, ctx, data | |
| continue | |
| base_name = tmp_path.name.replace(".tmp.pt", ".pt") | |
| final_path = out / base_name | |
| staging_path = out / (base_name + ".__writing__") | |
| torch.save({ | |
| "target_latents": data["target_latents"], | |
| "attention_mask": data["attention_mask"], | |
| "encoder_hidden_states": enc_hs.squeeze(0).cpu(), | |
| "encoder_attention_mask": enc_mask.squeeze(0).cpu(), | |
| "context_latents": ctx.squeeze(0).cpu(), | |
| "metadata": data.get("metadata", {}), | |
| }, staging_path) | |
| os.replace(staging_path, final_path) | |
| del enc_hs, enc_mask, ctx, data | |
| tmp_path.unlink(missing_ok=True) | |
| processed += 1 | |
| if progress_callback: | |
| progress_callback(i + 1, p2_total, f"[Pass 2] {tmp_path.stem}") | |
| except Exception as exc: | |
| p2_failed += 1 | |
| logger.error("Pass 2 FAIL %s: %s", tmp_path.stem, exc) | |
| finally: | |
| logger.info("Unloading DIT model...") | |
| unload_models(model) | |
| _clear_gpu_cache(device) | |
| failed = p1_failed + p2_failed | |
| return {"processed": processed, "failed": failed, "total": total, | |
| "chunks": len(intermediates), "output_dir": str(out)} | |
| # ============================================================================ | |
| # TRAINING LOOP (generator for Gradio compatibility) | |
| # ============================================================================ | |
| def train_lora_generator( | |
| dataset_dir: str, | |
| output_dir: str, | |
| checkpoint_dir: str, | |
| epochs: int = 1000, | |
| lr: float = 3e-4, | |
| rank: int = 64, | |
| alpha: int = 128, | |
| dropout: float = 0.1, | |
| batch_size: int = 1, | |
| gradient_accumulation_steps: int = 4, | |
| warmup_steps: int = 100, | |
| weight_decay: float = 0.01, | |
| max_grad_norm: float = 1.0, | |
| save_every_n_epochs: int = 0, | |
| seed: int = 42, | |
| variant: str = "base", | |
| device: str = "auto", | |
| cfg_ratio: float = 0.15, | |
| timestep_mu: float = -0.4, | |
| timestep_sigma: float = 1.0, | |
| target_modules: Optional[List[str]] = None, | |
| log_every: int = 10, | |
| resume_from: Optional[str] = None, | |
| chunk_duration: float = 0, | |
| ) -> Generator[str, None, None]: | |
| """Run LoRA training, yielding progress strings each epoch. | |
| This is a generator for Gradio live-update compatibility. | |
| Call cancel_training() to stop after the current epoch. | |
| Args: | |
| device: "auto" to auto-detect GPU/CPU, or explicit device string. | |
| GPU uses mixed-precision (bfloat16/float16); CPU stays float32. | |
| """ | |
| _training_cancel.clear() | |
| train_start = time.time() | |
| # Auto-detect device | |
| device = detect_device(device) | |
| dtype = select_dtype(device) | |
| dev_type = device.split(":")[0] | |
| use_amp = dev_type == "cuda" | |
| if target_modules is None: | |
| target_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] | |
| ds_path = Path(dataset_dir) | |
| if not ds_path.is_dir(): | |
| yield f"[FAIL] Dataset directory not found: {ds_path}" | |
| return | |
| out_path = Path(output_dir) | |
| out_path.mkdir(parents=True, exist_ok=True) | |
| yield f"[INFO] Device: {device}, dtype: {dtype}, AMP: {use_amp}" | |
| if dev_type == "cuda": | |
| gpu_name = torch.cuda.get_device_name(device) | |
| gpu_mem = torch.cuda.get_device_properties(device).total_memory / (1024 ** 3) | |
| yield f"[INFO] GPU: {gpu_name} ({gpu_mem:.1f} GiB VRAM)" | |
| yield "[INFO] Loading model..." | |
| try: | |
| model = load_model_for_training(checkpoint_dir, variant, device) | |
| except Exception as exc: | |
| yield f"[FAIL] Model load failed: {exc}" | |
| return | |
| # Ensure model is in the correct dtype (load_model_for_training handles this, | |
| # but be explicit for safety) | |
| model = model.to(dtype=dtype) | |
| # Move model to device if not already there (CPU path) | |
| if dev_type == "cpu": | |
| model = model.to(device=device) | |
| yield "[INFO] Injecting LoRA..." | |
| lora_cfg = LoRAConfig( | |
| r=rank, alpha=alpha, dropout=dropout, | |
| target_modules=target_modules, bias="none", | |
| ) | |
| try: | |
| model, info = inject_lora(model, lora_cfg) | |
| except Exception as exc: | |
| yield f"[FAIL] LoRA injection failed: {exc}" | |
| unload_models(model) | |
| return | |
| yield f"[OK] LoRA injected: {info['trainable_params']:,} trainable params" | |
| # Gradient checkpointing + cache disable (enable_gradient_checkpointing | |
| # also walks the module tree and sets use_cache=False on any config it finds) | |
| ckpt_ok = enable_gradient_checkpointing(model.decoder) | |
| force_input_grads = ckpt_ok | |
| if ckpt_ok: | |
| yield "[INFO] Gradient checkpointing enabled" | |
| # Dataset | |
| dataset = TensorDataset(dataset_dir) | |
| if len(dataset) == 0: | |
| yield "[FAIL] No valid .pt files found in dataset directory" | |
| unload_models(model) | |
| return | |
| yield f"[OK] Loaded {len(dataset)} preprocessed samples" | |
| loader = DataLoader( | |
| dataset, batch_size=batch_size, shuffle=True, | |
| num_workers=0, collate_fn=_collate_batch, drop_last=False, | |
| pin_memory=(dev_type == "cuda"), | |
| ) | |
| # Optimizer & scheduler | |
| torch.manual_seed(seed) | |
| random.seed(seed) | |
| if dev_type == "cuda": | |
| torch.cuda.manual_seed_all(seed) | |
| trainable_params = [p for p in model.parameters() if p.requires_grad] | |
| if not trainable_params: | |
| yield "[FAIL] No trainable parameters found" | |
| unload_models(model) | |
| return | |
| optimizer = build_optimizer(trainable_params, lr=lr, weight_decay=weight_decay) | |
| steps_per_epoch = max(1, math.ceil(len(loader) / gradient_accumulation_steps)) | |
| total_steps = steps_per_epoch * epochs | |
| scheduler = build_scheduler(optimizer, total_steps, warmup_steps, lr) | |
| yield f"[INFO] Training {sum(p.numel() for p in trainable_params):,} params for {epochs} epochs" | |
| yield f"[INFO] Steps/epoch: {steps_per_epoch}, total: {total_steps}" | |
| # GradScaler for mixed precision on GPU (only for float16, not bfloat16) | |
| use_grad_scaler = use_amp and dtype == torch.float16 | |
| grad_scaler = None | |
| if use_grad_scaler: | |
| grad_scaler = torch.cuda.amp.GradScaler() | |
| yield "[INFO] GradScaler enabled (float16 mixed precision)" | |
| # Null condition embedding for CFG dropout | |
| null_cond = getattr(model, "null_condition_emb", None) | |
| # Resume checkpoint | |
| start_epoch = 0 | |
| global_step = 0 | |
| if resume_from and Path(resume_from).exists(): | |
| try: | |
| yield f"[INFO] Resuming from {resume_from}" | |
| ckpt_dir = Path(resume_from) | |
| if ckpt_dir.is_file(): | |
| ckpt_dir = ckpt_dir.parent | |
| # Load adapter weights | |
| aw = ckpt_dir / "adapter_model.safetensors" | |
| if aw.exists(): | |
| from safetensors.torch import load_file | |
| state = load_file(str(aw)) | |
| decoder = _unwrap_decoder(model) | |
| decoder.load_state_dict(state, strict=False) | |
| # Load training state | |
| ts = ckpt_dir / "training_state.pt" | |
| if ts.exists(): | |
| tstate = torch.load(str(ts), map_location=device, weights_only=True) | |
| start_epoch = tstate.get("epoch", 0) | |
| global_step = tstate.get("global_step", 0) | |
| if "optimizer_state_dict" in tstate: | |
| try: | |
| optimizer.load_state_dict(tstate["optimizer_state_dict"]) | |
| except Exception: | |
| pass | |
| if "scheduler_state_dict" in tstate: | |
| try: | |
| scheduler.load_state_dict(tstate["scheduler_state_dict"]) | |
| except Exception: | |
| pass | |
| yield f"[OK] Resumed from epoch {start_epoch}, step {global_step}" | |
| except Exception as exc: | |
| yield f"[WARN] Checkpoint load failed: {exc}, starting fresh" | |
| start_epoch = 0 | |
| global_step = 0 | |
| # Training loop | |
| model.decoder.train() | |
| acc_step = 0 | |
| acc_loss = 0.0 | |
| optimizer.zero_grad(set_to_none=True) | |
| best_loss = float("inf") | |
| best_epoch = 0 | |
| consecutive_nan = 0 | |
| MAX_NAN = 10 | |
| for epoch in range(start_epoch, epochs): | |
| # Cancel check | |
| if _training_cancel.is_set(): | |
| _training_cancel.clear() | |
| if epoch > start_epoch: | |
| model.decoder.eval() | |
| save_lora_adapter(model, str(out_path)) | |
| yield f"[OK] Cancelled at epoch {epoch + 1}, adapter saved" | |
| else: | |
| yield f"[CANCELLED] Stopped before any epoch completed" | |
| yield "[DONE]" | |
| _cuda_sync(device) | |
| unload_models(model) | |
| _clear_gpu_cache(device) | |
| return | |
| # Timeout check | |
| elapsed = time.time() - train_start | |
| if elapsed > MAX_TRAINING_TIME: | |
| model.decoder.eval() | |
| save_lora_adapter(model, str(out_path)) | |
| yield f"[WARN] Training timed out after {int(elapsed)}s, adapter saved" | |
| yield "[DONE]" | |
| _cuda_sync(device) | |
| unload_models(model) | |
| _clear_gpu_cache(device) | |
| return | |
| epoch_loss = 0.0 | |
| num_updates = 0 | |
| epoch_start = time.time() | |
| for batch in loader: | |
| # Move batch tensors to device | |
| nb = dev_type != "cpu" | |
| tgt = batch["target_latents"].to(device, dtype=dtype, non_blocking=nb) | |
| att = batch["attention_mask"].to(device, dtype=dtype, non_blocking=nb) | |
| enc_hs = batch["encoder_hidden_states"].to(device, dtype=dtype, non_blocking=nb) | |
| enc_mask = batch["encoder_attention_mask"].to(device, dtype=dtype, non_blocking=nb) | |
| ctx = batch["context_latents"].to(device, dtype=dtype, non_blocking=nb) | |
| # Random crop to chunk_duration (data augmentation + speed) | |
| if chunk_duration > 0: | |
| max_len = int(chunk_duration * LATENT_HZ) | |
| T = tgt.shape[1] | |
| if T > max_len: | |
| start = random.randint(0, T - max_len) | |
| tgt = tgt[:, start:start + max_len, :] | |
| att = att[:, start:start + max_len] | |
| ctx = ctx[:, start:start + max_len, :] | |
| bsz = tgt.shape[0] | |
| # CFG dropout | |
| if null_cond is not None and cfg_ratio > 0: | |
| enc_hs = apply_cfg_dropout(enc_hs, null_cond, cfg_ratio) | |
| # Timestep sampling | |
| t, _r = sample_timesteps(bsz, torch.device(device), dtype, timestep_mu, timestep_sigma) | |
| # Flow matching noise | |
| x1 = torch.randn_like(tgt) | |
| x0 = tgt | |
| t_ = t.unsqueeze(-1).unsqueeze(-1) | |
| xt = t_ * x1 + (1.0 - t_) * x0 | |
| if force_input_grads: | |
| xt = xt.requires_grad_(True) | |
| # Decoder forward -- use AMP autocast on GPU for mixed precision | |
| if use_amp: | |
| with torch.cuda.amp.autocast(dtype=dtype): | |
| dec_out = model.decoder( | |
| hidden_states=xt, | |
| timestep=t, | |
| timestep_r=t, | |
| attention_mask=att, | |
| encoder_hidden_states=enc_hs, | |
| encoder_attention_mask=enc_mask, | |
| context_latents=ctx, | |
| ) | |
| flow = x1 - x0 | |
| loss = F.mse_loss(dec_out[0], flow) | |
| else: | |
| # CPU path -- no autocast | |
| dec_out = model.decoder( | |
| hidden_states=xt, | |
| timestep=t, | |
| timestep_r=t, | |
| attention_mask=att, | |
| encoder_hidden_states=enc_hs, | |
| encoder_attention_mask=enc_mask, | |
| context_latents=ctx, | |
| ) | |
| flow = x1 - x0 | |
| loss = F.mse_loss(dec_out[0], flow) | |
| loss = loss.float() # fp32 for stable backward | |
| # NaN guard | |
| if torch.isnan(loss) or torch.isinf(loss): | |
| consecutive_nan += 1 | |
| del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow | |
| if consecutive_nan >= MAX_NAN: | |
| yield f"[FAIL] {consecutive_nan} consecutive NaN losses, halting" | |
| _cuda_sync(device) | |
| unload_models(model) | |
| _clear_gpu_cache(device) | |
| return | |
| if acc_step > 0: | |
| optimizer.zero_grad(set_to_none=True) | |
| acc_loss = 0.0 | |
| acc_step = 0 | |
| continue | |
| consecutive_nan = 0 | |
| loss = loss / gradient_accumulation_steps | |
| # Backward -- use GradScaler on float16 GPU | |
| if grad_scaler is not None: | |
| grad_scaler.scale(loss).backward() | |
| else: | |
| loss.backward() | |
| acc_loss += loss.item() | |
| del loss, tgt, att, enc_hs, enc_mask, ctx, xt, dec_out, flow | |
| acc_step += 1 | |
| if acc_step >= gradient_accumulation_steps: | |
| if grad_scaler is not None: | |
| grad_scaler.unscale_(optimizer) | |
| torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm) | |
| grad_scaler.step(optimizer) | |
| grad_scaler.update() | |
| else: | |
| torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm) | |
| optimizer.step() | |
| scheduler.step() | |
| global_step += 1 | |
| avg_loss = acc_loss * gradient_accumulation_steps / acc_step | |
| if global_step % log_every == 0: | |
| current_lr = scheduler.get_last_lr()[0] | |
| yield ( | |
| f"Epoch {epoch + 1}/{epochs}, " | |
| f"Step {global_step}, " | |
| f"Loss: {avg_loss:.4f}, " | |
| f"LR: {current_lr:.2e}" | |
| ) | |
| optimizer.zero_grad(set_to_none=True) | |
| epoch_loss += avg_loss | |
| num_updates += 1 | |
| acc_loss = 0.0 | |
| acc_step = 0 | |
| # Periodic GPU cache cleanup | |
| if dev_type == "cuda" and global_step % log_every == 0: | |
| torch.cuda.empty_cache() | |
| # Flush remainder | |
| if acc_step > 0: | |
| if grad_scaler is not None: | |
| grad_scaler.unscale_(optimizer) | |
| torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm) | |
| grad_scaler.step(optimizer) | |
| grad_scaler.update() | |
| else: | |
| torch.nn.utils.clip_grad_norm_(trainable_params, max_grad_norm) | |
| optimizer.step() | |
| scheduler.step() | |
| global_step += 1 | |
| avg_loss = acc_loss * gradient_accumulation_steps / acc_step | |
| optimizer.zero_grad(set_to_none=True) | |
| epoch_loss += avg_loss | |
| num_updates += 1 | |
| acc_loss = 0.0 | |
| acc_step = 0 | |
| epoch_time = time.time() - epoch_start | |
| avg_epoch_loss = epoch_loss / max(num_updates, 1) | |
| is_best = avg_epoch_loss < best_loss - 0.001 | |
| if is_best: | |
| best_loss = avg_epoch_loss | |
| best_epoch = epoch + 1 | |
| best_str = f" (best: {best_loss:.4f} @ ep{best_epoch})" if best_epoch > 0 else "" | |
| yield ( | |
| f"[OK] Epoch {epoch + 1}/{epochs} in {epoch_time:.1f}s, " | |
| f"Loss: {avg_epoch_loss:.4f}{best_str}" | |
| ) | |
| # Save best (directly to output dir so ace-server finds it) | |
| if is_best and epoch + 1 >= 10: | |
| model.decoder.eval() | |
| save_lora_adapter(model, str(out_path)) | |
| model.decoder.train() | |
| yield f"[OK] Best model saved (epoch {epoch + 1}, loss: {best_loss:.4f})" | |
| # Periodic checkpoint (0 = disabled, only save on cancel/finish) | |
| if save_every_n_epochs > 0 and (epoch + 1) % save_every_n_epochs == 0: | |
| ckpt_path = str(out_path / "checkpoints" / f"epoch_{epoch + 1}") | |
| model.decoder.eval() | |
| save_lora_adapter(model, ckpt_path) | |
| tstate = { | |
| "epoch": epoch + 1, | |
| "global_step": global_step, | |
| "optimizer_state_dict": optimizer.state_dict(), | |
| "scheduler_state_dict": scheduler.state_dict(), | |
| } | |
| os.makedirs(ckpt_path, exist_ok=True) | |
| torch.save(tstate, str(Path(ckpt_path) / "training_state.pt")) | |
| model.decoder.train() | |
| yield f"[OK] Checkpoint saved at epoch {epoch + 1}" | |
| # Clear GPU cache after epoch + checkpoint save | |
| _clear_gpu_cache(device) | |
| # Sanity check | |
| if global_step == 0: | |
| yield "[FAIL] Training completed 0 steps -- no batches processed" | |
| _cuda_sync(device) | |
| unload_models(model) | |
| _clear_gpu_cache(device) | |
| return | |
| # Final save (directly to output_dir, not a subdirectory) | |
| model.decoder.eval() | |
| save_lora_adapter(model, str(out_path)) | |
| final_loss = avg_epoch_loss if num_updates > 0 else 0.0 | |
| best_note = "" | |
| # Best adapters are saved directly into out_path (no "best/" subdir), | |
| # so just report which epoch produced the best loss | |
| if best_epoch > 0: | |
| best_note = f"\n Best epoch: {best_epoch} (loss: {best_loss:.4f})" | |
| yield ( | |
| f"[OK] Training complete! LoRA saved to {out_path}{best_note}\n" | |
| f" Adapter ready for inference." | |
| ) | |
| yield "[DONE]" | |
| _cuda_sync(device) | |
| unload_models(model) | |
| _clear_gpu_cache(device) | |
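| # Usage sketch -- the generator yields progress strings, which maps | |
| # directly onto a Gradio streaming handler (names illustrative): | |
| # | |
| #     def on_train_click(): | |
| #         log = [] | |
| #         for msg in train_lora_generator( | |
| #             "preprocessed/", "loras/my_style", "checkpoints/", | |
| #         ): | |
| #             log.append(msg) | |
| #             yield "\n".join(log[-30:])  # live-update a gr.Textbox | |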
| # ============================================================================ | |
| # ADAPTER LISTING | |
| # ============================================================================ | |
| def get_trained_loras(adapter_dir: str) -> List[str]: | |
| """List all saved LoRA adapter directories under adapter_dir.""" | |
| result = [] | |
| base = Path(adapter_dir) | |
| if not base.is_dir(): | |
| return result | |
| for root, dirs, files in os.walk(str(base)): | |
| for f in files: | |
| if f in ("adapter_config.json", "adapter_model.safetensors", "lora_weights.pt"): | |
| result.append(root) | |
| break | |
| return sorted(set(result)) | |
| # ============================================================================ | |
| # TILED VAE DECODE (mirror of tiled_vae_encode) | |
| # ============================================================================ | |
| def tiled_vae_decode( | |
| vae, latents: torch.Tensor, dtype: torch.dtype, | |
| chunk_frames: int = 1024, overlap_frames: int = 64, | |
| ) -> torch.Tensor: | |
| """Decode latents [B, T, C] -> waveform [B, 2, samples] using tiled VAE. | |
| Mirrors tiled_vae_encode but in the reverse direction. Tiles along | |
| the time axis of the latent to keep peak memory bounded. | |
| Args: | |
| vae: AutoencoderOobleck decoder. | |
| latents: Latent tensor in [B, T, C] layout (C=64). | |
| dtype: Target dtype for the output waveform. | |
| chunk_frames: Number of latent frames per tile. | |
| overlap_frames: Overlap frames decoded per side and then trimmed | |
| away (avoids tile-edge artifacts; no crossfade blending). | |
| Returns: | |
| Waveform tensor [B, 2, total_samples] in *dtype*. | |
| """ | |
| vae_device = next(vae.parameters()).device | |
| vae_dtype = vae.dtype | |
| # Transpose to VAE convention [B, C, T] | |
| lat = latents.transpose(1, 2).contiguous() | |
| B, C, T = lat.shape | |
| if T <= chunk_frames: | |
| with torch.inference_mode(): | |
| audio = vae.decode(lat.to(vae_device, dtype=vae_dtype)).sample | |
| return audio.to(dtype=dtype, device="cpu") | |
| # Upsample factor: unknown until first decode, so we discover it. | |
| stride = chunk_frames - 2 * overlap_frames | |
| if stride <= 0: | |
| raise ValueError(f"chunk_frames ({chunk_frames}) must be > 2*overlap ({overlap_frames})") | |
| num_tiles = math.ceil(T / stride) | |
| us_factor: Optional[float] = None | |
| write_pos = 0 | |
| final: Optional[torch.Tensor] = None | |
| for i in range(num_tiles): | |
| core_start = i * stride | |
| core_end = min(core_start + stride, T) | |
| win_start = max(0, core_start - overlap_frames) | |
| win_end = min(T, core_end + overlap_frames) | |
| chunk = lat[:, :, win_start:win_end].to(vae_device, dtype=vae_dtype) | |
| with torch.inference_mode(): | |
| decoded = vae.decode(chunk).sample # [B, 2, samples_chunk] | |
| if us_factor is None: | |
| us_factor = decoded.shape[-1] / chunk.shape[-1] | |
| total_samples = int(round(T * us_factor)) | |
| final = torch.zeros(B, decoded.shape[1], total_samples, dtype=decoded.dtype, device="cpu") | |
| trim_start = int(round((core_start - win_start) * us_factor)) | |
| trim_end = int(round((win_end - core_end) * us_factor)) | |
| end_idx = decoded.shape[-1] - trim_end if trim_end > 0 else decoded.shape[-1] | |
| core = decoded[:, :, trim_start:end_idx] | |
| # Clamp against rounding drift in us_factor so the final tile can | |
| # never write past the preallocated buffer | |
| core = core[:, :, :max(0, final.shape[-1] - write_pos)] | |
| core_len = core.shape[-1] | |
| final[:, :, write_pos:write_pos + core_len] = core.cpu() | |
| write_pos += core_len | |
| del chunk, decoded, core | |
| final = final[:, :, :write_pos] | |
| return final.to(dtype=dtype) | |
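| # Sizing sketch: at LATENT_HZ = 25 the default chunk_frames=1024 decodes | |
| # ~41s of audio per tile, advancing by 1024 - 2*64 = 896 frames (~36s). | |
| # Illustrative call (tensors hypothetical): | |
| # | |
| #     wav = tiled_vae_decode(vae, latents, torch.float32)  # [B, 2, samples] | |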
| # ============================================================================ | |
| # INFERENCE -- generate_audio() | |
| # ============================================================================ | |
| def generate_audio( | |
| caption: str, | |
| checkpoint_dir: str, | |
| output_path: str, | |
| lyrics: str = "[Instrumental]", | |
| duration: float = 10.0, | |
| bpm: int = 120, | |
| steps: int = 8, | |
| seed: int = -1, | |
| variant: str = "turbo", | |
| device: str = "auto", | |
| adapter_path: Optional[str] = None, | |
| adapter_scale: float = 1.0, | |
| use_lm: bool = True, | |
| lm_temperature: float = 0.85, | |
| lm_top_p: float = 0.9, | |
| lm_top_k: int = 0, | |
| ) -> str: | |
| """Generate audio using the full ACE-Step pipeline (LM + DiT). | |
| Pipeline: | |
| 1. LM (Qwen3 1.7B) generates CoT metadata + audio codes from | |
| caption and lyrics | |
| 2. Text encoder -> text_hidden_states, lyric embeddings | |
| 3. Load full model (DiT + condition encoder + FSQ) | |
| 4. Optional: inject LoRA adapter via PEFT | |
| 5. model.generate_audio() -- uses LM audio codes as context | |
| conditioning via the FSQ detokenizer, then runs flow-matching | |
| diffusion | |
| 6. VAE decode latents -> waveform | |
| 7. Save waveform as 48 kHz stereo WAV | |
| 8. Unload all models, free memory | |
| Args: | |
| caption: Text description of the desired music. | |
| checkpoint_dir: Root directory that contains model sub-dirs | |
| (e.g. ``acestep-v15-turbo/``, ``vae/``, ``Qwen3-Embedding-0.6B/``). | |
| output_path: Where to write the output WAV file. | |
| lyrics: Lyrics text or ``"[Instrumental]"`` for no vocals. | |
| duration: Desired audio length in seconds. | |
| bpm: Beats per minute (metadata hint for the model). | |
| steps: Number of diffusion steps (8 for turbo, 50 for base/SFT). | |
| seed: RNG seed (-1 = random). | |
| variant: Model variant name (e.g. ``"turbo"``, ``"base"``). | |
| device: ``"auto"``, ``"cpu"``, ``"cuda:0"``, etc. | |
| adapter_path: Path to a PEFT LoRA adapter directory (optional). | |
| adapter_scale: Scaling factor applied to the adapter. | |
| use_lm: Run the LM to generate audio codes (True), or skip the | |
| LM and condition on silence context only (False). | |
| lm_temperature: LM sampling temperature. | |
| lm_top_p: LM nucleus sampling cutoff. | |
| lm_top_k: LM top-K sampling (0 = disabled). | |
| Returns: | |
| The *output_path* string (for convenience). | |
| """ | |
| import numpy as np | |
| # ------------------------------------------------------------------ | |
| # 0. Device / dtype | |
| # ------------------------------------------------------------------ | |
| device = detect_device(device) | |
| dtype = select_dtype(device) | |
| logger.info( | |
| "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs, use_lm=%s", | |
| device, dtype, variant, steps, duration, use_lm, | |
| ) | |
| # Resolve seed | |
| if seed < 0: | |
| seed = random.randint(0, 2**31 - 1) | |
| logger.info("Using seed=%d", seed) | |
| # ------------------------------------------------------------------ | |
| # 1. LM generation -- produce audio codes from caption + lyrics | |
| # ------------------------------------------------------------------ | |
| audio_codes_list: Optional[List[int]] = None | |
| if use_lm: | |
| logger.info("Running LM to generate audio codes...") | |
| audio_codes_list = _generate_codes_with_lm( | |
| checkpoint_dir=checkpoint_dir, | |
| caption=caption, | |
| lyrics=lyrics, | |
| duration=duration, | |
| device=device, | |
| temperature=lm_temperature, | |
| top_p=lm_top_p, | |
| top_k=lm_top_k, | |
| ) | |
| if audio_codes_list: | |
| # The LM determines the actual duration via its code count | |
| lm_duration = len(audio_codes_list) / 5.0 | |
| logger.info( | |
| "LM generated %d codes (%.1fs). Overriding duration %.1f -> %.1f", | |
| len(audio_codes_list), lm_duration, duration, lm_duration, | |
| ) | |
| duration = lm_duration | |
| else: | |
| logger.warning("LM produced no codes, falling back to silence context.") | |
| # ------------------------------------------------------------------ | |
| # 2. Text encoder -- encode caption and lyrics | |
| # ------------------------------------------------------------------ | |
| logger.info("Loading text encoder...") | |
| tokenizer, text_encoder = load_text_encoder(checkpoint_dir, device) | |
| text_hs, text_mask = encode_text(text_encoder, tokenizer, caption, device, dtype) | |
| lyric_hs, lyric_mask = encode_lyrics(text_encoder, tokenizer, lyrics, device, dtype) | |
| # Free text encoder -- no longer needed | |
| unload_models(text_encoder) | |
| del text_encoder, tokenizer | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("Text encoder unloaded.") | |
| # ------------------------------------------------------------------ | |
| # 3. Load full model (DiT + CondEncoder + FSQ tokenizer/detokenizer) | |
| # ------------------------------------------------------------------ | |
| logger.info("Loading ACE-Step model (%s)...", variant) | |
| model = load_model_for_training(checkpoint_dir, variant=variant, device=device) | |
| model = model.to(dtype=dtype) | |
| model.eval() | |
| # ------------------------------------------------------------------ | |
| # 4. Optional: inject LoRA adapter | |
| # ------------------------------------------------------------------ | |
| if adapter_path: | |
| logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale) | |
| from peft import PeftModel | |
| decoder = _unwrap_decoder(model) | |
| model.decoder = PeftModel.from_pretrained( | |
| decoder, adapter_path, is_trainable=False, | |
| ) | |
| # Apply adapter scale if not 1.0 -- multiplicative on top of PEFT's | |
| # default alpha/r scaling (overwriting it would silently discard the | |
| # alpha/r ratio the adapter was trained with) | |
| if abs(adapter_scale - 1.0) > 1e-6: | |
| for name, module in model.decoder.named_modules(): | |
| if hasattr(module, "scaling"): | |
| for key in module.scaling: | |
| module.scaling[key] = module.scaling[key] * adapter_scale | |
| model.decoder.eval() | |
| logger.info("LoRA adapter applied.") | |
| # ------------------------------------------------------------------ | |
| # 5. Prepare inputs for model.generate_audio() | |
| # ------------------------------------------------------------------ | |
| # Latent frame rate is 25 Hz (module-level LATENT_HZ constant) | |
| latent_length = int(duration * LATENT_HZ) | |
| # Load silence latent for context building | |
| silence_latent = load_silence_latent(checkpoint_dir, device, variant) | |
| # Ensure silence latent covers the required length | |
| if silence_latent.shape[1] < latent_length: | |
| repeats = math.ceil(latent_length / silence_latent.shape[1]) | |
| silence_latent = silence_latent.repeat(1, repeats, 1) | |
| silence_latent = silence_latent[:, :latent_length, :].to(device=device, dtype=dtype) | |
| # Build source latents and masks | |
| src_latents = silence_latent[:1, :latent_length, :] | |
| chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype) | |
| # Detokenize LM audio codes into context latents for the DiT | |
| if audio_codes_list: | |
| indices_tensor = torch.tensor( | |
| audio_codes_list, dtype=torch.long, device=device, | |
| ).unsqueeze(0).unsqueeze(-1) # [1, T_5Hz, 1] | |
| with torch.no_grad(): | |
| lm_latents = model.tokenizer.quantizer.get_output_from_indices(indices_tensor) | |
| # lm_latents: [1, T_5Hz, codebook_dim] -> detokenize to [1, T_25Hz, 64] | |
| lm_latents = model.detokenize(lm_latents) | |
| T_lm = lm_latents.shape[1] | |
| # Use LM latents as src_latents context | |
| if T_lm < latent_length: | |
| pad = silence_latent[:, :latent_length - T_lm, :] | |
| src_latents = torch.cat([lm_latents, pad], dim=1) | |
| else: | |
| src_latents = lm_latents[:, :latent_length, :] | |
| chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype) | |
| is_covers = torch.ones(1, device=device, dtype=torch.long) | |
| logger.info("LM codes detokenized: %d codes -> %d latent frames, used as DiT context", len(audio_codes_list), T_lm) | |
| else: | |
| is_covers = torch.zeros(1, device=device, dtype=torch.long) | |
| # Dummy timbre reference (single silence frame -> no timbre conditioning) | |
| refer_audio = torch.zeros(1, 1, 64, device=device, dtype=dtype) | |
| refer_order = torch.zeros(1, device=device, dtype=torch.long) | |
| # Shift schedule: turbo uses 3.0, base/sft uses 1.0 | |
| shift = 3.0 if "turbo" in variant else 1.0 | |
| # ------------------------------------------------------------------ | |
| # 6. Run diffusion (model.generate_audio handles everything internally) | |
| # ------------------------------------------------------------------ | |
| logger.info("Running diffusion (%d steps, shift=%.1f)...", steps, shift) | |
| with torch.no_grad(): | |
| result = model.generate_audio( | |
| text_hidden_states=text_hs.to(device=device, dtype=dtype), | |
| text_attention_mask=text_mask.to(device=device, dtype=dtype), | |
| lyric_hidden_states=lyric_hs.to(device=device, dtype=dtype), | |
| lyric_attention_mask=lyric_mask.to(device=device, dtype=dtype), | |
| refer_audio_acoustic_hidden_states_packed=refer_audio, | |
| refer_audio_order_mask=refer_order, | |
| src_latents=src_latents, | |
| chunk_masks=chunk_masks, | |
| is_covers=is_covers, | |
| silence_latent=silence_latent, | |
| seed=seed, | |
| fix_nfe=steps, | |
| shift=shift, | |
| ) | |
| target_latents = result["target_latents"] # [1, T, 64] | |
| time_costs = result.get("time_costs", {}) | |
| logger.info("Diffusion done. Time costs: %s", time_costs) | |
| # Free model weights -- keep latents on CPU | |
| target_latents = target_latents.cpu().to(dtype) | |
| unload_models(model) | |
| del model, silence_latent, src_latents, chunk_masks | |
| del text_hs, text_mask, lyric_hs, lyric_mask | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("DiT model unloaded.") | |
| # ------------------------------------------------------------------ | |
| # 7. VAE decode latents -> waveform | |
| # ------------------------------------------------------------------ | |
| logger.info("Loading VAE decoder...") | |
| vae = load_vae(checkpoint_dir, device) | |
| logger.info("Decoding latents -> waveform (tiled)...") | |
| waveform = tiled_vae_decode(vae, target_latents.to(device), dtype) # [1, 2, samples] | |
| unload_models(vae) | |
| del vae, target_latents | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("VAE unloaded.") | |
| # ------------------------------------------------------------------ | |
| # 8. Save as WAV (48 kHz stereo) | |
| # ------------------------------------------------------------------ | |
| audio_np = waveform[0].float().clamp(-1.0, 1.0).cpu().numpy() # [2, samples] | |
| os.makedirs(os.path.dirname(os.path.abspath(output_path)), exist_ok=True) | |
| try: | |
| import soundfile as sf | |
| # soundfile expects [samples, channels] | |
| sf.write(output_path, audio_np.T, TARGET_SR, subtype="PCM_16") | |
| except ImportError: | |
| import torchaudio | |
| torchaudio.save(output_path, torch.from_numpy(audio_np), TARGET_SR) | |
| logger.info("Audio saved to %s (%.1fs @ %d Hz)", output_path, duration, TARGET_SR) | |
| return output_path | |
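| # Usage sketch (paths and prompt illustrative): | |
| # | |
| #     generate_audio( | |
| #         caption="dreamy lo-fi hip hop, mellow keys", | |
| #         checkpoint_dir="checkpoints/", | |
| #         output_path="out/lofi.wav", | |
| #         duration=30.0, steps=8, variant="turbo", | |
| #         adapter_path="loras/my_style",  # optional LoRA | |
| #     ) | |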
| # ============================================================================ | |
| # UNDERSTAND MODE (reverse pipeline: audio -> caption + lyrics) | |
| # ============================================================================ | |
| # Qwen3 special token IDs (ACE-Step LM vocabulary) | |
| _TOKEN_IM_START = 151644 | |
| _TOKEN_IM_END = 151645 | |
| _TOKEN_THINK = 151667 | |
| _TOKEN_THINK_END = 151668 | |
| _AUDIO_CODE_BASE = 151669 | |
| # LM system instructions (matches acestep.cpp task-types.h) | |
| _LM_GENERATE_INSTRUCTION = ( | |
| "Generate audio semantic tokens based on the given conditions:" | |
| ) | |
| _LM_UNDERSTAND_INSTRUCTION = ( | |
| "Understand the given musical conditions and describe the audio semantics accordingly:" | |
| ) | |
| def _sample_next_token( | |
| logits: torch.Tensor, temperature: float, top_k: int, top_p: float, | |
| ) -> int: | |
| """Sample a single token from logits with temperature, top-k, and top-p. | |
| Args: | |
| logits: 1-D logits tensor (vocab_size,). | |
| temperature: Sampling temperature (<=0 for argmax). | |
| top_k: Top-K filtering (0 = disabled). | |
| top_p: Nucleus sampling cutoff (0 or >=1 = disabled). | |
| Returns: | |
| Selected token ID as int. | |
| """ | |
| if temperature <= 0: | |
| return int(logits.argmax().item()) | |
| scaled = logits.clone() / temperature | |
| if top_k > 0: | |
| topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0])) | |
| scaled[scaled < topk_vals[-1]] = float("-inf") | |
| if top_p > 0 and top_p < 1.0: | |
| sorted_logits, sorted_idx = torch.sort(scaled, descending=True) | |
| probs = torch.softmax(sorted_logits, dim=-1) | |
| cumsum = torch.cumsum(probs, dim=-1) | |
| nucleus_mask = cumsum - probs > top_p | |
| sorted_logits[nucleus_mask] = float("-inf") | |
| scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits) | |
| probs = torch.softmax(scaled, dim=-1) | |
| return int(torch.multinomial(probs, 1).item()) | |
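| # Behavior sketch (logits hypothetical): | |
| # | |
| #     logits = torch.tensor([2.0, 1.0, 0.5]) | |
| #     _sample_next_token(logits, 0.0, 0, 0.9)  # -> 0 (argmax path) | |
| #     _sample_next_token(logits, 1.0, 2, 0.0)  # samples from tokens {0, 1} | |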
| def _build_understand_prompt( | |
| bpe_tokenizer, codes: List[int], | |
| ) -> List[int]: | |
| """Build the Qwen3 chat prompt for understand mode. | |
| Format (matching C++ build_understand_prompt in prompt.h): | |
| <|im_start|>system | |
| # Instruction | |
| {LM_UNDERSTAND_INSTRUCTION} | |
| <|im_end|> | |
| <|im_start|>user | |
| {audio_code_tokens} | |
| <|im_end|> | |
| <|im_start|>assistant | |
| """ | |
| ids: List[int] = [] | |
| def append_text(text: str): | |
| encoded = bpe_tokenizer.encode(text, add_special_tokens=False) | |
| ids.extend(encoded) | |
| ids.append(_TOKEN_IM_START) | |
| append_text( | |
| "system\n# Instruction\n" | |
| + _LM_UNDERSTAND_INSTRUCTION | |
| + "\n\n" | |
| ) | |
| ids.append(_TOKEN_IM_END) | |
| append_text("\n") | |
| ids.append(_TOKEN_IM_START) | |
| append_text("user\n") | |
| # Audio codes as raw token IDs (not BPE text) | |
| for code in codes: | |
| ids.append(_AUDIO_CODE_BASE + code) | |
| append_text("\n") | |
| ids.append(_TOKEN_IM_END) | |
| append_text("\n") | |
| ids.append(_TOKEN_IM_START) | |
| append_text("assistant\n") | |
| return ids | |
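| # Note: audio codes enter the prompt as raw token IDs (code 42 becomes | |
| # token 151669 + 42 = 151711) -- they are never round-tripped through | |
| # the BPE tokenizer, matching the C++ prompt builder in prompt.h. | |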
| def _build_generate_prompt( | |
| bpe_tokenizer, caption: str, lyrics: str, | |
| ) -> List[int]: | |
| """Build the Qwen3 chat prompt for audio code generation. | |
| Format (matching C++ build_lm_prompt in prompt.h): | |
| <|im_start|>system | |
| # Instruction | |
| {LM_GENERATE_INSTRUCTION} | |
| <|im_end|> | |
| <|im_start|>user | |
| # Caption | |
| {caption} | |
| # Lyric | |
| {lyrics} | |
| <|im_end|> | |
| <|im_start|>assistant | |
| """ | |
| ids: List[int] = [] | |
| def append_text(text: str): | |
| encoded = bpe_tokenizer.encode(text, add_special_tokens=False) | |
| ids.extend(encoded) | |
| ids.append(_TOKEN_IM_START) | |
| append_text( | |
| "system\n# Instruction\n" | |
| + _LM_GENERATE_INSTRUCTION | |
| + "\n\n" | |
| ) | |
| ids.append(_TOKEN_IM_END) | |
| append_text("\n") | |
| ids.append(_TOKEN_IM_START) | |
| append_text( | |
| "user\n# Caption\n" + caption + "\n\n" | |
| "# Lyric\n" + lyrics + "\n" | |
| ) | |
| ids.append(_TOKEN_IM_END) | |
| append_text("\n") | |
| ids.append(_TOKEN_IM_START) | |
| append_text("assistant\n") | |
| return ids | |
| def _generate_codes_with_lm( | |
| checkpoint_dir: str, | |
| caption: str, | |
| lyrics: str, | |
| duration: float, | |
| device: str, | |
| temperature: float = 0.85, | |
| top_p: float = 0.9, | |
| top_k: int = 0, | |
| max_new_tokens: int = 8192, | |
| ) -> List[int]: | |
| """Run the ACE-Step LM (Qwen3 1.7B) to generate audio codes from text. | |
| The LM generates in two phases within a single autoregressive pass: | |
| Phase 1 (CoT): <think> metadata YAML (bpm, duration, key, etc.) </think> | |
| Phase 2 (codes): audio code tokens (token_id >= AUDIO_CODE_BASE) | |
| Args: | |
| checkpoint_dir: Root directory containing acestep-5Hz-lm-1.7B/. | |
| caption: Text description of the music. | |
| lyrics: Lyrics text or "[Instrumental]". | |
| duration: Target duration in seconds (the LM may override via CoT). | |
| device: Torch device string. | |
| temperature: Sampling temperature. | |
| top_p: Nucleus sampling cutoff (0.0 = disabled). | |
| top_k: Top-K sampling (0 = disabled). | |
| max_new_tokens: Maximum tokens to generate. | |
| Returns: | |
| List of FSQ code indices (0-63999 range, NOT offset by AUDIO_CODE_BASE). | |
| Length is approximately duration * 5 (5 Hz token rate). | |
| """ | |
| ckpt = Path(checkpoint_dir).resolve() | |
| lm_path = ckpt / "acestep-5Hz-lm-1.7B" | |
| if not lm_path.is_dir(): | |
| raise FileNotFoundError(f"LM checkpoint not found: {lm_path}") | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # Load BPE tokenizer | |
| bpe_tokenizer = AutoTokenizer.from_pretrained(str(lm_path)) | |
| # Build generation prompt | |
| prompt_ids = _build_generate_prompt(bpe_tokenizer, caption, lyrics) | |
| logger.info( | |
| "[LM Generate] Prompt: %d tokens, caption=%r, lyrics=%r", | |
| len(prompt_ids), caption[:80], lyrics[:80], | |
| ) | |
| # Load the LM (Qwen3Model with tied embeddings -> CausalLM) | |
| from transformers import Qwen3Config | |
| lm_config = Qwen3Config.from_pretrained(str(lm_path)) | |
| lm_config.architectures = ["Qwen3ForCausalLM"] | |
| lm_dtype = select_dtype(device) | |
| lm_model = AutoModelForCausalLM.from_pretrained( | |
| str(lm_path), | |
| config=lm_config, | |
| torch_dtype=lm_dtype, | |
| low_cpu_mem_usage=True, | |
| ) | |
| lm_model = lm_model.to(device=device) | |
| lm_model.eval() | |
| logger.info("[LM Generate] LM loaded on %s (dtype=%s)", device, lm_dtype) | |
| # Autoregressive decode: single sequence, no CFG. | |
| prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device) | |
| with torch.inference_mode(): | |
| outputs = lm_model(input_ids=prompt_tensor, use_cache=True) | |
| logits = outputs.logits[:, -1, :] # [1, vocab_size] | |
| past_kv = outputs.past_key_values | |
| gen_tokens: List[int] = [] | |
| audio_codes: List[int] = [] | |
| past_think = False | |
| in_think = False | |
| for step in range(max_new_tokens): | |
| logits = logits.clone() | |
| # Phase 1 (inside <think>): block audio codes so only text is generated | |
| if in_think: | |
| logits[0, _AUDIO_CODE_BASE:] = float("-inf") | |
| # Phase 2 (after </think>): only allow audio codes + im_end | |
| if past_think: | |
| # Mask all non-audio-code logits to -inf, keeping im_end (stop token) | |
| mask = torch.full_like(logits[0], float("-inf")) | |
| mask[_AUDIO_CODE_BASE:] = 0.0 | |
| mask[_TOKEN_IM_END] = 0.0 | |
| logits[0] = logits[0] + mask | |
| # Sample | |
| next_id = _sample_next_token(logits[0], temperature, top_k, top_p) | |
| # Stop on im_end | |
| if next_id == _TOKEN_IM_END: | |
| break | |
| # Track think state transitions | |
| if next_id == _TOKEN_THINK: | |
| in_think = True | |
| elif next_id == _TOKEN_THINK_END: | |
| in_think = False | |
| past_think = True | |
| gen_tokens.append(next_id) | |
| # Collect audio codes (Phase 2 tokens) | |
| if next_id >= _AUDIO_CODE_BASE: | |
| audio_codes.append(next_id - _AUDIO_CODE_BASE) | |
| # Next step | |
| next_input = torch.tensor([[next_id]], dtype=torch.long, device=device) | |
| with torch.inference_mode(): | |
| outputs = lm_model( | |
| input_ids=next_input, | |
| past_key_values=past_kv, | |
| use_cache=True, | |
| ) | |
| logits = outputs.logits[:, -1, :] | |
| past_kv = outputs.past_key_values | |
| # Log what the LM generated | |
| cot_tokens = [ | |
| t for t in gen_tokens | |
| if t < _AUDIO_CODE_BASE and t not in ( | |
| _TOKEN_IM_START, _TOKEN_IM_END, _TOKEN_THINK, _TOKEN_THINK_END, | |
| ) | |
| ] | |
| if cot_tokens: | |
| cot_text = bpe_tokenizer.decode(cot_tokens, skip_special_tokens=False) | |
| logger.info("[LM Generate] CoT output:\n%s", cot_text[:500]) | |
| logger.info( | |
| "[LM Generate] Generated %d total tokens, %d audio codes (%.1fs @ 5Hz)", | |
| len(gen_tokens), len(audio_codes), len(audio_codes) / 5.0, | |
| ) | |
| # Unload LM | |
| del outputs, logits, past_kv, prompt_tensor | |
| unload_models(lm_model) | |
| del lm_model, bpe_tokenizer | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("[LM Generate] LM unloaded") | |
| if not audio_codes: | |
| logger.warning( | |
| "[LM Generate] No audio codes generated! The DiT will fall back to " | |
| "silence context. Check that the LM checkpoint is correct." | |
| ) | |
| return audio_codes | |
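| # Phase-masking summary (restates the sampling loop above): | |
| #   before <think>  -> no mask; the LM is free to open the CoT block | |
| #   inside <think>  -> audio codes blocked: logits[_AUDIO_CODE_BASE:] = -inf | |
| #   after  </think> -> only audio codes and <|im_end|> stay finite | |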
| def _parse_understand_output(text: str) -> Dict[str, str]: | |
| """Parse CoT metadata + lyrics from understand LM output. | |
| The LM generates: | |
| <think> | |
| bpm: 120 | |
| caption: ... | |
| duration: 180 | |
| keyscale: C major | |
| language: en | |
| timesignature: 4 | |
| </think> | |
| [Verse 1] | |
| ...lyrics... | |
| Returns dict with: caption, lyrics, bpm, key, signature, duration, | |
| language. | |
| """ | |
| result: Dict[str, str] = {} | |
| # Split at <think> / </think> boundaries | |
| cot = "" | |
| lyrics_after = "" | |
| ts = text.find("<think>") | |
| te = text.find("</think>") | |
| if ts != -1 and te != -1: | |
| cot = text[ts + 7:te] | |
| lyrics_after = text[te + 8:] | |
| elif te != -1: | |
| cot = text[:te] | |
| lyrics_after = text[te + 8:] | |
| else: | |
| cot = text | |
| # Parse YAML-like fields from CoT | |
| def get_field(key: str) -> str: | |
| needle = key + ":" | |
| p = cot.find(needle) | |
| if p == -1: | |
| return "" | |
| p += len(needle) | |
| # Skip leading spaces and single quotes | |
| while p < len(cot) and cot[p] in (" ", "'"): | |
| p += 1 | |
| end = cot.find("\n", p) | |
| if end == -1: | |
| end = len(cot) | |
| val = cot[p:end].rstrip(" '\r") | |
| return val | |
| bpm_s = get_field("bpm") | |
| if bpm_s: | |
| result["bpm"] = bpm_s | |
| dur_s = get_field("duration") | |
| if dur_s: | |
| result["duration"] = dur_s | |
| ks = get_field("keyscale") | |
| if ks: | |
| result["key"] = ks | |
| ts_s = get_field("timesignature") | |
| if ts_s: | |
| result["signature"] = ts_s | |
| lang = get_field("language") | |
| if lang: | |
| result["language"] = lang | |
| # Caption may span multiple lines (YAML word-wrap) | |
| cap_needle = "caption:" | |
| cp = cot.find(cap_needle) | |
| if cp != -1: | |
| cp += len(cap_needle) | |
| # Read until next known field or end of CoT | |
| end = len(cot) | |
| for next_field in ("duration:", "keyscale:", "language:", "timesignature:", "bpm:"): | |
| nf = cot.find("\n" + next_field, cp) | |
| if nf != -1 and nf < end: | |
| end = nf | |
| full_cap = cot[cp:end] | |
| # Collapse whitespace | |
| cleaned = " ".join(full_cap.split()).strip() | |
| if cleaned: | |
| result["caption"] = cleaned | |
| # Lyrics after </think> | |
| if lyrics_after: | |
| lyrics = lyrics_after.strip() | |
| # Strip "# Lyric\n" header the LM may echo back | |
| lp = lyrics.find("# Lyric\n") | |
| if lp != -1 and lp < 64: | |
| lyrics = lyrics[lp + 8:] | |
| lyrics = lyrics.strip() | |
| if lyrics: | |
| result["lyrics"] = lyrics | |
| return result | |
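| # Illustrative example (made-up LM output, not from a real run): | |
| #   _parse_understand_output( | |
| #       "<think>\nbpm: 120\ncaption: warm lo-fi piano\nduration: 30\n" | |
| #       "</think>\n[Verse 1]\nhello world\n" | |
| #   ) | |
| # returns | |
| #   {"bpm": "120", "caption": "warm lo-fi piano", "duration": "30", | |
| #    "lyrics": "[Verse 1]\nhello world"} | |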
| def understand_audio( | |
| audio_path: str, | |
| checkpoint_dir: str, | |
| device: str = "auto", | |
| variant: str = "turbo", | |
| temperature: float = 0.3, | |
| top_p: float = 0.0, | |
| top_k: int = 0, | |
| max_new_tokens: int = 4096, | |
| ) -> Dict[str, str]: | |
| """Extract caption, lyrics, BPM, key, signature from audio using the LM. | |
| Pipeline: audio -> VAE encode -> FSQ tokenize -> LM understand -> text | |
| Returns dict with: caption, lyrics, bpm, key, signature, duration, | |
| language. | |
| Args: | |
| audio_path: Path to input audio file (WAV, MP3, FLAC, etc.) | |
| checkpoint_dir: Path to ACE-Step checkpoints root directory | |
| (must contain vae/, acestep-v15-turbo/ or variant subdir, | |
| and acestep-5Hz-lm-1.7B/). | |
| device: Device string ("auto", "cuda:0", "cpu", etc.) | |
| variant: DiT variant to load for FSQ tokenizer ("turbo", "sft", | |
| "base", etc.) | |
| temperature: LM sampling temperature (default 0.3, lower = more | |
| deterministic). | |
| top_p: Nucleus sampling cutoff (0.0 = disabled). | |
| top_k: Top-K sampling (0 = disabled). | |
| max_new_tokens: Maximum tokens to generate. | |
| Returns: | |
| Dict with extracted metadata. Keys may include: | |
| caption, lyrics, bpm, key, signature, duration, language. | |
| """ | |
| device = detect_device(device) | |
| dtype = select_dtype(device) | |
| ckpt = Path(checkpoint_dir).resolve() | |
| # ------------------------------------------------------------------ | |
| # Step 1: Load audio -> VAE encode -> latents [1, T_25Hz, 64] | |
| # ------------------------------------------------------------------ | |
| logger.info("[Understand] Step 1: VAE encode") | |
| audio, sr = load_audio_stereo(audio_path, TARGET_SR, MAX_AUDIO_DURATION) | |
| audio = audio.unsqueeze(0) # [1, 2, samples] | |
| logger.info( | |
| "[Understand] Audio loaded: %.1fs, %d samples @ %d Hz", | |
| audio.shape[-1] / TARGET_SR, audio.shape[-1], TARGET_SR, | |
| ) | |
| vae = load_vae(checkpoint_dir, device) | |
| latents = tiled_vae_encode(vae, audio, dtype) # [1, T_25Hz, 64] | |
| T_25Hz = latents.shape[1] | |
| logger.info("[Understand] VAE encoded: %d latent frames (%.2fs)", T_25Hz, T_25Hz * 1920.0 / TARGET_SR) | |
| unload_models(vae) | |
| del vae, audio | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("[Understand] VAE unloaded") | |
| # ------------------------------------------------------------------ | |
| # Step 2: Load DiT (for FSQ tokenizer) -> tokenize latents -> codes | |
| # ------------------------------------------------------------------ | |
| logger.info("[Understand] Step 2: FSQ tokenize") | |
| # Load silence_latent for padding | |
| silence_latent = load_silence_latent(checkpoint_dir, device="cpu", variant=variant) | |
| # Load DiT model (only need its tokenizer submodule) | |
| model = load_model_for_training(checkpoint_dir, variant=variant, device=device) | |
| model = model.to(dtype=dtype) | |
| pool_window = model.config.pool_window_size # 5 (25Hz -> 5Hz) | |
| # Pad latents to multiple of pool_window_size | |
| lat = latents.to(device=device, dtype=dtype) | |
| pad_len = 0 | |
| if T_25Hz % pool_window != 0: | |
| pad_len = pool_window - (T_25Hz % pool_window) | |
| # Use silence_latent for padding | |
| sl = silence_latent[:1, :pad_len, :].to(device=device, dtype=dtype) | |
| lat = torch.cat([lat, sl.expand(lat.shape[0], -1, -1)], dim=1) | |
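| # Worked example of the padding arithmetic (numbers are illustrative): | |
| # T_25Hz=1013, pool_window=5 -> pad_len=2, T_padded=1015 -> T_5Hz=203. | |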
| # Tokenize: lat [1, T_padded, 64] -> indices [1, T_5Hz, 1] | |
| with torch.inference_mode(): | |
| _quantized, indices = model.tokenizer.tokenize(lat) | |
| # indices shape: [1, T_5Hz, num_quantizers=1] -> flatten to [T_5Hz] | |
| codes = indices.squeeze(0).squeeze(-1).cpu().tolist() # List[int] | |
| T_5Hz = len(codes) | |
| logger.info( | |
| "[Understand] FSQ tokenized: %d codes (%.2fs @ 5Hz)", | |
| T_5Hz, T_5Hz / 5.0, | |
| ) | |
| unload_models(model) | |
| del model, lat, latents, _quantized, indices, silence_latent | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("[Understand] DiT unloaded") | |
| # ------------------------------------------------------------------ | |
| # Step 3: Load LM -> build understand prompt -> generate text | |
| # ------------------------------------------------------------------ | |
| logger.info("[Understand] Step 3: LM generation") | |
| lm_path = ckpt / "acestep-5Hz-lm-1.7B" | |
| if not lm_path.is_dir(): | |
| raise FileNotFoundError(f"LM checkpoint not found: {lm_path}") | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| # Load BPE tokenizer | |
| bpe_tokenizer = AutoTokenizer.from_pretrained(str(lm_path)) | |
| # Build understand prompt | |
| prompt_ids = _build_understand_prompt(bpe_tokenizer, codes) | |
| logger.info( | |
| "[Understand] Prompt: %d tokens (%d codes + framing)", | |
| len(prompt_ids), len(codes), | |
| ) | |
| # Load the LM (Qwen3Model with tied embeddings). | |
| # Config says "Qwen3Model" but we need generation (lm_head). Since | |
| # tie_word_embeddings=true, Qwen3ForCausalLM will tie the lm_head | |
| # to embed_tokens automatically. We override the architecture to load | |
| # as CausalLM. | |
| from transformers import Qwen3Config | |
| lm_config = Qwen3Config.from_pretrained(str(lm_path)) | |
| lm_config.architectures = ["Qwen3ForCausalLM"] | |
| lm_dtype = select_dtype(device) | |
| lm_model = AutoModelForCausalLM.from_pretrained( | |
| str(lm_path), | |
| config=lm_config, | |
| torch_dtype=lm_dtype, | |
| low_cpu_mem_usage=True, | |
| ) | |
| lm_model = lm_model.to(device=device) | |
| lm_model.eval() | |
| logger.info("[Understand] LM loaded on %s (dtype=%s)", device, lm_dtype) | |
| # Full LM vocab (lm_config.vocab_size == 217204) spans BPE text, special, | |
| # and audio-code tokens; logits below cover the whole range. | |
| # Autoregressive decode: no CFG, no batching, single sequence. | |
| # FSM-constrained decoding is not ported to Python (it would require | |
| # the prefix tree); at low temperature the LM produces well-structured | |
| # CoT without it. | |
| prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device) | |
| with torch.inference_mode(): | |
| # Prefill | |
| outputs = lm_model(input_ids=prompt_tensor, use_cache=True) | |
| logits = outputs.logits[:, -1, :] # [1, vocab_size] | |
| past_kv = outputs.past_key_values | |
| gen_tokens: List[int] = [] | |
| past_think = False | |
| for step in range(max_new_tokens): | |
| logits = logits.clone() | |
| # After </think>: block audio codes so the LM only generates text | |
| if past_think: | |
| logits[0, _AUDIO_CODE_BASE:] = float("-inf") | |
| # Sample | |
| next_id = _sample_next_token(logits[0], temperature, top_k, top_p) | |
| if next_id == _TOKEN_IM_END: | |
| break | |
| if next_id == _TOKEN_THINK_END: | |
| past_think = True | |
| gen_tokens.append(next_id) | |
| # Next step | |
| next_input = torch.tensor([[next_id]], dtype=torch.long, device=device) | |
| with torch.inference_mode(): | |
| outputs = lm_model( | |
| input_ids=next_input, | |
| past_key_values=past_kv, | |
| use_cache=True, | |
| ) | |
| logits = outputs.logits[:, -1, :] | |
| past_kv = outputs.past_key_values | |
| logger.info("[Understand] Generated %d tokens", len(gen_tokens)) | |
| # Decode generated tokens to text, skipping audio-code tokens and chat | |
| # framing, and re-inserting <think> / </think> markers for the parser. | |
| think_text = "" | |
| for t in gen_tokens: | |
| if t == _TOKEN_THINK: | |
| think_text += "<think>" | |
| elif t == _TOKEN_THINK_END: | |
| think_text += "</think>" | |
| elif t < _AUDIO_CODE_BASE and t not in (_TOKEN_IM_START, _TOKEN_IM_END): | |
| think_text += bpe_tokenizer.decode([t], skip_special_tokens=False) | |
| logger.info("[Understand] Raw output:\n%s", think_text[:500]) | |
| # Unload LM | |
| del outputs, logits, past_kv, prompt_tensor | |
| unload_models(lm_model) | |
| del lm_model, bpe_tokenizer | |
| gc.collect() | |
| _clear_gpu_cache(device) | |
| logger.info("[Understand] LM unloaded") | |
| # ------------------------------------------------------------------ | |
| # Step 4: Parse generated text into structured fields | |
| # ------------------------------------------------------------------ | |
| result = _parse_understand_output(think_text) | |
| logger.info("[Understand] Parsed result: %s", {k: v[:80] + "..." if len(v) > 80 else v for k, v in result.items()}) | |
| return result | |
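| # Example usage (paths are hypothetical): | |
| #   meta = understand_audio("song.wav", "./checkpoints", variant="turbo") | |
| #   print(meta.get("caption", ""), meta.get("bpm", "")) | |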