Nekochu committed
Commit 6bfdc38 · 1 Parent(s): ff239f5

add understand_audio (LM reverse), demucs-infer fix, commit refs, dtype fixes

Files changed (1): train_engine.py (+411 -0)
train_engine.py CHANGED
@@ -4,6 +4,11 @@ Standalone ACE-Step LoRA Training Engine (CPU + GPU).
 Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
 module. No external Side-Step dependency required.
 
+Source commits:
+    Side-Step: koda-dernet/Side-Step @ ecd13bd (2026-04-19)
+    acestep.cpp: ServeurpersoCom/acestep.cpp @ 36e4db1 (prompt/understand format)
+    ACE-Step: ace-step/ACE-Step-1.5 (model architecture, training_v2)
+
 Auto-detects GPU (CUDA > MPS > CPU) and uses it when available,
 falling back to CPU. bfloat16 is used on GPU; float32 is forced
 on CPU (bfloat16 deadlocks on CPU -- known PyTorch bug).
@@ -17,6 +22,7 @@ Exports:
     get_trained_loras() - List saved adapters
     generate_audio() - Standalone inference (text -> WAV, optional LoRA)
     tiled_vae_decode() - Tiled VAE latent-to-waveform decode
+    understand_audio() - Reverse pipeline (audio -> caption + lyrics)
 """
 
 from __future__ import annotations
@@ -3086,3 +3092,408 @@ def generate_audio(
 
     logger.info("Audio saved to %s (%.1fs @ %d Hz)", output_path, duration, TARGET_SR)
     return output_path
+
+
+# ============================================================================
+# UNDERSTAND MODE (reverse pipeline: audio -> caption + lyrics)
+# ============================================================================
+
+# Qwen3 special token IDs (ACE-Step LM vocabulary)
+_TOKEN_IM_START = 151644
+_TOKEN_IM_END = 151645
+_TOKEN_THINK = 151667
+_TOKEN_THINK_END = 151668
+_AUDIO_CODE_BASE = 151669
+
+# Understand system instruction (matches acestep.cpp task-types.h)
+_LM_UNDERSTAND_INSTRUCTION = (
+    "Understand the given musical conditions and describe the audio semantics accordingly:"
+)
+
+
+def _build_understand_prompt(
+    bpe_tokenizer, codes: List[int],
+) -> List[int]:
+    """Build the Qwen3 chat prompt for understand mode.
+
+    Format (matching C++ build_understand_prompt in prompt.h):
+        <|im_start|>system
+        # Instruction
+        {LM_UNDERSTAND_INSTRUCTION}
+
+        <|im_end|>
+        <|im_start|>user
+        {audio_code_tokens}
+        <|im_end|>
+        <|im_start|>assistant
+    """
+    ids: List[int] = []
+
+    def append_text(text: str):
+        encoded = bpe_tokenizer.encode(text, add_special_tokens=False)
+        ids.extend(encoded)
+
+    ids.append(_TOKEN_IM_START)
+    append_text(
+        "system\n# Instruction\n"
+        + _LM_UNDERSTAND_INSTRUCTION
+        + "\n\n"
+    )
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text("user\n")
+    # Audio codes as raw token IDs (not BPE text)
+    for code in codes:
+        ids.append(_AUDIO_CODE_BASE + code)
+    append_text("\n")
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text("assistant\n")
+    return ids
+
+
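The builder returns a flat list of token IDs in which the audio codes enter as raw IDs at or above _AUDIO_CODE_BASE rather than as BPE-encoded text. A minimal sanity-check sketch of that layout, using a hypothetical stub in place of the real Qwen3 BPE tokenizer (the stub and the code values are illustrative, not part of the commit):

    class _StubTokenizer:
        # Stand-in for the Qwen3 BPE tokenizer so the sketch runs standalone:
        # one fake BPE id per character.
        def encode(self, text, add_special_tokens=False):
            return [ord(c) for c in text]

    ids = _build_understand_prompt(_StubTokenizer(), codes=[12, 7, 3051])
    assert ids[0] == _TOKEN_IM_START       # <|im_start|> opens the system turn
    assert _AUDIO_CODE_BASE + 12 in ids    # audio codes are raw token IDs
    assert ids[-1] == ord("\n")            # prompt ends right after "assistant\n"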
+def _parse_understand_output(text: str) -> Dict[str, str]:
+    """Parse CoT metadata + lyrics from understand LM output.
+
+    The LM generates:
+        <think>
+        bpm: 120
+        caption: ...
+        duration: 180
+        keyscale: C major
+        language: en
+        timesignature: 4
+        </think>
+        [Verse 1]
+        ...lyrics...
+
+    Returns dict with: caption, lyrics, bpm, key, signature, duration,
+    language.
+    """
+    result: Dict[str, str] = {}
+
+    # Split at <think> / </think> boundaries
+    cot = ""
+    lyrics_after = ""
+    ts = text.find("<think>")
+    te = text.find("</think>")
+
+    if ts != -1 and te != -1:
+        cot = text[ts + 7:te]
+        lyrics_after = text[te + 8:]
+    elif te != -1:
+        cot = text[:te]
+        lyrics_after = text[te + 8:]
+    else:
+        cot = text
+
+    # Parse YAML-like fields from CoT
+    def get_field(key: str) -> str:
+        needle = key + ":"
+        p = cot.find(needle)
+        if p == -1:
+            return ""
+        p += len(needle)
+        # Skip leading whitespace and quotes
+        while p < len(cot) and cot[p] in (" ", "'"):
+            p += 1
+        end = cot.find("\n", p)
+        if end == -1:
+            end = len(cot)
+        val = cot[p:end].rstrip(" '\r")
+        return val
+
+    bpm_s = get_field("bpm")
+    if bpm_s:
+        result["bpm"] = bpm_s
+
+    dur_s = get_field("duration")
+    if dur_s:
+        result["duration"] = dur_s
+
+    ks = get_field("keyscale")
+    if ks:
+        result["key"] = ks
+
+    ts_s = get_field("timesignature")
+    if ts_s:
+        result["signature"] = ts_s
+
+    lang = get_field("language")
+    if lang:
+        result["language"] = lang
+
+    # Caption may span multiple lines (YAML word-wrap)
+    cap_needle = "caption:"
+    cp = cot.find(cap_needle)
+    if cp != -1:
+        cp += len(cap_needle)
+        # Read until next known field or end of CoT
+        end = len(cot)
+        for next_field in ("duration:", "keyscale:", "language:", "timesignature:", "bpm:"):
+            nf = cot.find("\n" + next_field, cp)
+            if nf != -1 and nf < end:
+                end = nf
+        full_cap = cot[cp:end]
+        # Collapse whitespace
+        cleaned = " ".join(full_cap.split()).strip()
+        if cleaned:
+            result["caption"] = cleaned
+
+    # Lyrics after </think>
+    if lyrics_after:
+        lyrics = lyrics_after.strip()
+        # Strip "# Lyric\n" header the LM may echo back
+        lp = lyrics.find("# Lyric\n")
+        if lp != -1 and lp < 64:
+            lyrics = lyrics[lp + 8:]
+        lyrics = lyrics.strip()
+        if lyrics:
+            result["lyrics"] = lyrics
+
+    return result
+
+
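To make the parser's contract concrete, a round-trip on a hand-written string in the documented CoT format (all values invented for the example):

    sample = (
        "<think>\n"
        "bpm: 120\n"
        "caption: warm lo-fi piano with vinyl crackle\n"
        "duration: 180\n"
        "keyscale: C major\n"
        "language: en\n"
        "timesignature: 4\n"
        "</think>\n"
        "[Verse 1]\nRain on the window, slow and low\n"
    )
    parsed = _parse_understand_output(sample)
    assert parsed["bpm"] == "120"
    assert parsed["key"] == "C major"      # "keyscale" is exposed as "key"
    assert parsed["signature"] == "4"      # "timesignature" -> "signature"
    assert parsed["lyrics"].startswith("[Verse 1]")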
+def understand_audio(
+    audio_path: str,
+    checkpoint_dir: str,
+    device: str = "auto",
+    variant: str = "turbo",
+    temperature: float = 0.3,
+    top_p: float = 0.0,
+    top_k: int = 0,
+    max_new_tokens: int = 4096,
+) -> Dict[str, str]:
+    """Extract caption, lyrics, BPM, key, signature from audio using the LM.
+
+    Pipeline: audio -> VAE encode -> FSQ tokenize -> LM understand -> text
+    Returns dict with: caption, lyrics, bpm, key, signature, duration,
+    language.
+
+    Args:
+        audio_path: Path to input audio file (WAV, MP3, FLAC, etc.)
+        checkpoint_dir: Path to ACE-Step checkpoints root directory
+            (must contain vae/, acestep-v15-turbo/ or variant subdir,
+            and acestep-5Hz-lm-1.7B/).
+        device: Device string ("auto", "cuda:0", "cpu", etc.)
+        variant: DiT variant to load for FSQ tokenizer ("turbo", "sft",
+            "base", etc.)
+        temperature: LM sampling temperature (default 0.3, lower = more
+            deterministic).
+        top_p: Nucleus sampling cutoff (0.0 = disabled).
+        top_k: Top-K sampling (0 = disabled).
+        max_new_tokens: Maximum tokens to generate.
+
+    Returns:
+        Dict with extracted metadata. Keys may include:
+        caption, lyrics, bpm, key, signature, duration, language.
+    """
+    device = detect_device(device)
+    dtype = select_dtype(device)
+    ckpt = Path(checkpoint_dir).resolve()
+
+    # ------------------------------------------------------------------
+    # Step 1: Load audio -> VAE encode -> latents [1, T_25Hz, 64]
+    # ------------------------------------------------------------------
+    logger.info("[Understand] Step 1: VAE encode")
+    audio, sr = load_audio_stereo(audio_path, TARGET_SR, MAX_AUDIO_DURATION)
+    audio = audio.unsqueeze(0)  # [1, 2, samples]
+    logger.info(
+        "[Understand] Audio loaded: %.1fs, %d samples @ %d Hz",
+        audio.shape[-1] / TARGET_SR, audio.shape[-1], TARGET_SR,
+    )
+
+    vae = load_vae(checkpoint_dir, device)
+    latents = tiled_vae_encode(vae, audio, dtype)  # [1, T_25Hz, 64]
+    T_25Hz = latents.shape[1]
+    logger.info("[Understand] VAE encoded: %d latent frames (%.2fs)", T_25Hz, T_25Hz * 1920.0 / TARGET_SR)
+
+    unload_models(vae)
+    del vae, audio
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[Understand] VAE unloaded")
+
+    # ------------------------------------------------------------------
+    # Step 2: Load DiT (for FSQ tokenizer) -> tokenize latents -> codes
+    # ------------------------------------------------------------------
+    logger.info("[Understand] Step 2: FSQ tokenize")
+
+    # Load silence_latent for padding
+    silence_latent = load_silence_latent(checkpoint_dir, device="cpu", variant=variant)
+
+    # Load DiT model (only need its tokenizer submodule)
+    model = load_model_for_training(checkpoint_dir, variant=variant, device=device)
+    model = model.to(dtype=dtype)
+    pool_window = model.config.pool_window_size  # 5 (25Hz -> 5Hz)
+
+    # Pad latents to a multiple of pool_window_size
+    lat = latents.to(device=device, dtype=dtype)
+    pad_len = 0
+    if T_25Hz % pool_window != 0:
+        pad_len = pool_window - (T_25Hz % pool_window)
+        # Use silence_latent for padding
+        sl = silence_latent[:1, :pad_len, :].to(device=device, dtype=dtype)
+        lat = torch.cat([lat, sl.expand(lat.shape[0], -1, -1)], dim=1)
+
+    # Tokenize: lat [1, T_padded, 64] -> indices [1, T_5Hz, 1]
+    with torch.inference_mode():
+        _quantized, indices = model.tokenizer.tokenize(lat)
+
+    # indices shape: [1, T_5Hz, num_quantizers=1] -> flatten to [T_5Hz]
+    codes = indices.squeeze(0).squeeze(-1).cpu().tolist()  # List[int]
+    T_5Hz = len(codes)
+    logger.info(
+        "[Understand] FSQ tokenized: %d codes (%.2fs @ 5Hz)",
+        T_5Hz, T_5Hz / 5.0,
+    )
+
+    unload_models(model)
+    del model, lat, latents, _quantized, indices, silence_latent
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[Understand] DiT unloaded")
+
+    # ------------------------------------------------------------------
+    # Step 3: Load LM -> build understand prompt -> generate text
+    # ------------------------------------------------------------------
+    logger.info("[Understand] Step 3: LM generation")
+
+    lm_path = ckpt / "acestep-5Hz-lm-1.7B"
+    if not lm_path.is_dir():
+        raise FileNotFoundError(f"LM checkpoint not found: {lm_path}")
+
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    # Load BPE tokenizer
+    bpe_tokenizer = AutoTokenizer.from_pretrained(str(lm_path))
+
+    # Build understand prompt
+    prompt_ids = _build_understand_prompt(bpe_tokenizer, codes)
+    logger.info(
+        "[Understand] Prompt: %d tokens (%d codes + framing)",
+        len(prompt_ids), len(codes),
+    )
+
+    # Load the LM (Qwen3Model with tied embeddings). The config says
+    # "Qwen3Model" but we need generation (lm_head). Since
+    # tie_word_embeddings=true, Qwen3ForCausalLM ties the lm_head to
+    # embed_tokens automatically, so we override the architecture to
+    # load as CausalLM.
+    from transformers import Qwen3Config
+    lm_config = Qwen3Config.from_pretrained(str(lm_path))
+    lm_config.architectures = ["Qwen3ForCausalLM"]
+
+    lm_dtype = select_dtype(device)
+    lm_model = AutoModelForCausalLM.from_pretrained(
+        str(lm_path),
+        config=lm_config,
+        torch_dtype=lm_dtype,
+        low_cpu_mem_usage=True,
+    )
+    lm_model = lm_model.to(device=device)
+    lm_model.eval()
+    logger.info("[Understand] LM loaded on %s (dtype=%s)", device, lm_dtype)
+
+    vocab_size = lm_config.vocab_size  # 217204
+
+    # Autoregressive decode: no CFG, no batching, a single sequence.
+    # The FSM is not implemented in Python (it would require the prefix
+    # tree); at low temperature the LM generates structured CoT well
+    # enough without it.
+    prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+    gen_tokens: List[int] = []
+    past_think = False
+
+    # The whole decode loop runs under one inference_mode block: the
+    # logits produced here are inference tensors, and the in-place
+    # masking below would raise if done outside of it.
+    with torch.inference_mode():
+        # Prefill
+        outputs = lm_model(input_ids=prompt_tensor, use_cache=True)
+        logits = outputs.logits[:, -1, :]  # [1, vocab_size]
+        past_kv = outputs.past_key_values
+
+        for step in range(max_new_tokens):
+            # After </think>: block audio codes so the LM only generates text
+            if past_think:
+                logits[0, _AUDIO_CODE_BASE:] = float("-inf")
+
+            # Sample
+            if temperature <= 0:
+                next_id = int(logits[0].argmax().item())
+            else:
+                scaled = logits[0] / temperature
+                if top_k > 0:
+                    # Mask everything outside the top-k
+                    topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
+                    scaled[scaled < topk_vals[-1]] = float("-inf")
+                if 0 < top_p < 1.0:
+                    sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
+                    probs = torch.softmax(sorted_logits, dim=-1)
+                    cumsum = torch.cumsum(probs, dim=-1)
+                    mask = cumsum - probs > top_p
+                    sorted_logits[mask] = float("-inf")
+                    # Scatter masked values back to original positions
+                    scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
+                probs = torch.softmax(scaled, dim=-1)
+                next_id = int(torch.multinomial(probs, 1).item())
+
+            if next_id == _TOKEN_IM_END:
+                break
+
+            if next_id == _TOKEN_THINK_END:
+                past_think = True
+
+            gen_tokens.append(next_id)
+
+            # Next step
+            next_input = torch.tensor([[next_id]], dtype=torch.long, device=device)
+            outputs = lm_model(
+                input_ids=next_input,
+                past_key_values=past_kv,
+                use_cache=True,
+            )
+            logits = outputs.logits[:, -1, :]
+            past_kv = outputs.past_key_values
+
+    logger.info("[Understand] Generated %d tokens", len(gen_tokens))
+
+    # Decode tokens to text (skip audio code tokens and special tokens)
+    text_tokens = [
+        t for t in gen_tokens
+        if t < _AUDIO_CODE_BASE and t not in (
+            _TOKEN_IM_START, _TOKEN_IM_END, _TOKEN_THINK, _TOKEN_THINK_END,
+        )
+    ]
+    generated_text = bpe_tokenizer.decode(text_tokens, skip_special_tokens=False)
+
+    # Re-insert <think> / </think> markers for the parser
+    think_text = ""
+    in_think = False
+    for t in gen_tokens:
+        if t == _TOKEN_THINK:
+            think_text += "<think>"
+            in_think = True
+        elif t == _TOKEN_THINK_END:
+            think_text += "</think>"
+            in_think = False
+        elif t < _AUDIO_CODE_BASE and t not in (_TOKEN_IM_START, _TOKEN_IM_END):
+            think_text += bpe_tokenizer.decode([t], skip_special_tokens=False)
+
+    logger.info("[Understand] Raw output:\n%s", think_text[:500])
+
+    # Unload LM
+    del outputs, logits, past_kv, prompt_tensor
+    unload_models(lm_model)
+    del lm_model, bpe_tokenizer
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[Understand] LM unloaded")
+
+    # ------------------------------------------------------------------
+    # Step 4: Parse generated text into structured fields
+    # ------------------------------------------------------------------
+    result = _parse_understand_output(think_text)
+    logger.info(
+        "[Understand] Parsed result: %s",
+        {k: v[:80] + "..." if len(v) > 80 else v for k, v in result.items()},
+    )
+    return result
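End to end, a usage sketch of the new entry point (the import path and file locations are assumptions based on this file's name and the docstring, not part of the commit):

    from train_engine import understand_audio  # assumes train_engine.py is importable

    meta = understand_audio(
        audio_path="song.wav",           # placeholder input file
        checkpoint_dir="./checkpoints",  # must contain vae/, the DiT variant dir, acestep-5Hz-lm-1.7B/
        device="auto",                   # CUDA > MPS > CPU auto-detection
        temperature=0.3,                 # low temperature keeps the CoT well-formed
    )
    print(meta.get("caption", ""))
    print(meta.get("bpm", "?"), meta.get("key", "?"), meta.get("signature", "?"))
    print(meta.get("lyrics", "")[:200])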