Nekochu committed
Commit ff9f4ad · Parent: 3c15b8b

fix inference: add LM generation step, detokenize codes before DiT, full pipeline working

Files changed (1): train_engine.py (+302 -21)
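For orientation, the new LM knobs surface at the call site roughly as in this sketch. The keyword names are taken from the diff below; the module import and the concrete values are illustrative assumptions, not part of this commit.

# Sketch of invoking the updated entry point (hypothetical call site).
from train_engine import generate_audio  # assumption: import path

wav_path = generate_audio(
    caption="uplifting synthpop, female vocals, 120 bpm",
    lyrics="[Instrumental]",
    duration=30.0,                   # the LM's code count may override this
    checkpoint_dir="./checkpoints",  # must contain acestep-5Hz-lm-1.7B/
    use_lm=True,                     # new: run the Qwen3 LM for audio codes
    lm_temperature=0.85,             # new: LM sampling temperature
    lm_top_p=0.9,                    # new: nucleus sampling cutoff
    lm_top_k=0,                      # new: 0 = top-k disabled
)
print(wav_path)  # path to the 48 kHz stereo WAV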
train_engine.py CHANGED
@@ -2899,18 +2899,25 @@ def generate_audio(
     device: str = "auto",
     adapter_path: Optional[str] = None,
     adapter_scale: float = 1.0,
+    use_lm: bool = True,
+    lm_temperature: float = 0.85,
+    lm_top_p: float = 0.9,
+    lm_top_k: int = 0,
 ) -> str:
-    """Generate audio using the ACE-Step DiT pipeline (pure PyTorch, no server).
+    """Generate audio using the full ACE-Step pipeline (LM + DiT).
 
     Pipeline:
-    1. Text encoder -> text_hidden_states, lyric embeddings
-    2. Load full model (DiT + condition encoder + FSQ)
-    3. Optional: inject LoRA adapter via PEFT
-    4. model.generate_audio() -- runs condition encoder, FSQ detokenizer,
-       and the flow-matching diffusion loop internally
-    5. VAE decode latents -> waveform
-    6. Save waveform as 48 kHz stereo WAV
-    7. Unload all models, free memory
+    1. LM (Qwen3 1.7B) generates CoT metadata + audio codes from
+       caption and lyrics
+    2. Text encoder -> text_hidden_states, lyric embeddings
+    3. Load full model (DiT + condition encoder + FSQ)
+    4. Optional: inject LoRA adapter via PEFT
+    5. model.generate_audio() -- uses LM audio codes as context
+       conditioning via the FSQ detokenizer, then runs flow-matching
+       diffusion
+    6. VAE decode latents -> waveform
+    7. Save waveform as 48 kHz stereo WAV
+    8. Unload all models, free memory
 
     Args:
         caption: Text description of the desired music.
@@ -2926,6 +2933,11 @@ def generate_audio(
         device: ``"auto"``, ``"cpu"``, ``"cuda:0"``, etc.
         adapter_path: Path to a PEFT LoRA adapter directory (optional).
         adapter_scale: Scaling factor applied to the adapter.
+        use_lm: Run the LM to generate audio codes (True) or skip
+            and use silence context like before (False).
+        lm_temperature: LM sampling temperature.
+        lm_top_p: LM nucleus sampling cutoff.
+        lm_top_k: LM top-K sampling (0 = disabled).
 
     Returns:
         The *output_path* string (for convenience).
@@ -2938,8 +2950,8 @@ def generate_audio(
     device = detect_device(device)
     dtype = select_dtype(device)
     logger.info(
-        "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs",
-        device, dtype, variant, steps, duration,
+        "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs, use_lm=%s",
+        device, dtype, variant, steps, duration, use_lm,
     )
 
     # Resolve seed
@@ -2948,7 +2960,34 @@ def generate_audio(
     logger.info("Using seed=%d", seed)
 
     # ------------------------------------------------------------------
-    # 1. Text encoder -- encode caption and lyrics
+    # 1. LM generation -- produce audio codes from caption + lyrics
+    # ------------------------------------------------------------------
+    audio_codes_list: Optional[List[int]] = None
+    if use_lm:
+        logger.info("Running LM to generate audio codes...")
+        audio_codes_list = _generate_codes_with_lm(
+            checkpoint_dir=checkpoint_dir,
+            caption=caption,
+            lyrics=lyrics,
+            duration=duration,
+            device=device,
+            temperature=lm_temperature,
+            top_p=lm_top_p,
+            top_k=lm_top_k,
+        )
+        if audio_codes_list:
+            # The LM determines the actual duration via its code count
+            lm_duration = len(audio_codes_list) / 5.0
+            logger.info(
+                "LM generated %d codes (%.1fs). Overriding duration %.1f -> %.1f",
+                len(audio_codes_list), lm_duration, duration, lm_duration,
+            )
+            duration = lm_duration
+        else:
+            logger.warning("LM produced no codes, falling back to silence context.")
+
+    # ------------------------------------------------------------------
+    # 2. Text encoder -- encode caption and lyrics
     # ------------------------------------------------------------------
     logger.info("Loading text encoder...")
     tokenizer, text_encoder = load_text_encoder(checkpoint_dir, device)
@@ -2964,7 +3003,7 @@ def generate_audio(
     logger.info("Text encoder unloaded.")
 
     # ------------------------------------------------------------------
-    # 2. Load full model (DiT + CondEncoder + FSQ tokenizer/detokenizer)
+    # 3. Load full model (DiT + CondEncoder + FSQ tokenizer/detokenizer)
     # ------------------------------------------------------------------
     logger.info("Loading ACE-Step model (%s)...", variant)
     model = load_model_for_training(checkpoint_dir, variant=variant, device=device)
@@ -2972,7 +3011,7 @@ def generate_audio(
     model.eval()
 
     # ------------------------------------------------------------------
-    # 3. Optional: inject LoRA adapter
+    # 4. Optional: inject LoRA adapter
     # ------------------------------------------------------------------
     if adapter_path:
         logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale)
@@ -3001,7 +3040,7 @@ def generate_audio(
     logger.info("LoRA adapter applied.")
 
     # ------------------------------------------------------------------
-    # 4. Prepare inputs for model.generate_audio()
+    # 5. Prepare inputs for model.generate_audio()
     # ------------------------------------------------------------------
     # Latent frame rate is 25 Hz
     LATENT_HZ = 25
@@ -3015,10 +3054,31 @@ def generate_audio(
     silence_latent = silence_latent.repeat(1, repeats, 1)
     silence_latent = silence_latent[:, :latent_length, :].to(device=device, dtype=dtype)
 
-    # Build source latents and masks for text2music mode (all silence, all-ones mask)
+    # Build source latents and masks
     src_latents = silence_latent[:1, :latent_length, :]
     chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
-    is_covers = torch.zeros(1, device=device, dtype=torch.long)
+
+    # Detokenize LM audio codes into context latents for the DiT
+    if audio_codes_list:
+        indices_tensor = torch.tensor(
+            audio_codes_list, dtype=torch.long, device=device,
+        ).unsqueeze(0).unsqueeze(-1)  # [1, T_5Hz, 1]
+        with torch.no_grad():
+            lm_latents = model.tokenizer.quantizer.get_output_from_indices(indices_tensor)
+            # lm_latents: [1, T_5Hz, codebook_dim] -> detokenize to [1, T_25Hz, 64]
+            lm_latents = model.detokenize(lm_latents)
+        T_lm = lm_latents.shape[1]
+        # Use LM latents as src_latents context
+        if T_lm < latent_length:
+            pad = silence_latent[:, :latent_length - T_lm, :]
+            src_latents = torch.cat([lm_latents, pad], dim=1)
+        else:
+            src_latents = lm_latents[:, :latent_length, :]
+        chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
+        is_covers = torch.ones(1, device=device, dtype=torch.long)
+        logger.info("LM codes detokenized: %d codes -> %d latent frames, used as DiT context", len(audio_codes_list), T_lm)
+    else:
+        is_covers = torch.zeros(1, device=device, dtype=torch.long)
 
     # Dummy timbre reference (single silence frame -> no timbre conditioning)
     refer_audio = torch.zeros(1, 1, 64, device=device, dtype=dtype)
@@ -3028,7 +3088,7 @@ def generate_audio(
     shift = 3.0 if "turbo" in variant else 1.0
 
     # ------------------------------------------------------------------
-    # 5. Run diffusion (model.generate_audio handles everything internally)
+    # 6. Run diffusion (model.generate_audio handles everything internally)
     # ------------------------------------------------------------------
     logger.info("Running diffusion (%d steps, shift=%.1f)...", steps, shift)
     with torch.no_grad():
@@ -3062,7 +3122,7 @@ def generate_audio(
     logger.info("DiT model unloaded.")
 
     # ------------------------------------------------------------------
-    # 6. VAE decode latents -> waveform
+    # 7. VAE decode latents -> waveform
     # ------------------------------------------------------------------
     logger.info("Loading VAE decoder...")
     vae = load_vae(checkpoint_dir, device)
@@ -3077,7 +3137,7 @@ def generate_audio(
     logger.info("VAE unloaded.")
 
     # ------------------------------------------------------------------
-    # 7. Save as WAV (48 kHz stereo)
+    # 8. Save as WAV (48 kHz stereo)
     # ------------------------------------------------------------------
     audio_np = waveform[0].float().clamp(-1.0, 1.0).cpu().numpy()  # [2, samples]
 
@@ -3105,7 +3165,10 @@ _TOKEN_THINK = 151667
 _TOKEN_THINK_END = 151668
 _AUDIO_CODE_BASE = 151669
 
-# Understand system instruction (matches acestep.cpp task-types.h)
+# LM system instructions (matches acestep.cpp task-types.h)
+_LM_GENERATE_INSTRUCTION = (
+    "Generate audio semantic tokens based on the given conditions:"
+)
 _LM_UNDERSTAND_INSTRUCTION = (
     "Understand the given musical conditions and describe the audio semantics accordingly:"
 )
@@ -3154,6 +3217,224 @@ def _build_understand_prompt(
     return ids
 
 
+def _build_generate_prompt(
+    bpe_tokenizer, caption: str, lyrics: str,
+) -> List[int]:
+    """Build the Qwen3 chat prompt for audio code generation.
+
+    Format (matching C++ build_lm_prompt in prompt.h):
+        <|im_start|>system
+        # Instruction
+        {LM_GENERATE_INSTRUCTION}
+
+        <|im_end|>
+        <|im_start|>user
+        # Caption
+        {caption}
+
+        # Lyric
+        {lyrics}
+        <|im_end|>
+        <|im_start|>assistant
+    """
+    ids: List[int] = []
+
+    def append_text(text: str):
+        encoded = bpe_tokenizer.encode(text, add_special_tokens=False)
+        ids.extend(encoded)
+
+    ids.append(_TOKEN_IM_START)
+    append_text(
+        "system\n# Instruction\n"
+        + _LM_GENERATE_INSTRUCTION
+        + "\n\n"
+    )
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text(
+        "user\n# Caption\n" + caption + "\n\n"
+        "# Lyric\n" + lyrics + "\n"
+    )
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text("assistant\n")
+    return ids
+
+
+def _generate_codes_with_lm(
+    checkpoint_dir: str,
+    caption: str,
+    lyrics: str,
+    duration: float,
+    device: str,
+    temperature: float = 0.85,
+    top_p: float = 0.9,
+    top_k: int = 0,
+    max_new_tokens: int = 8192,
+) -> List[int]:
+    """Run the ACE-Step LM (Qwen3 1.7B) to generate audio codes from text.
+
+    The LM generates in two phases within a single autoregressive pass:
+      Phase 1 (CoT):   <think> metadata YAML (bpm, duration, key, etc.) </think>
+      Phase 2 (codes): audio code tokens (token_id >= AUDIO_CODE_BASE)
+
+    Args:
+        checkpoint_dir: Root directory containing acestep-5Hz-lm-1.7B/.
+        caption: Text description of the music.
+        lyrics: Lyrics text or "[Instrumental]".
+        duration: Target duration in seconds (the LM may override via CoT).
+        device: Torch device string.
+        temperature: Sampling temperature.
+        top_p: Nucleus sampling cutoff (0.0 = disabled).
+        top_k: Top-K sampling (0 = disabled).
+        max_new_tokens: Maximum tokens to generate.
+
+    Returns:
+        List of FSQ code indices (0-63999 range, NOT offset by AUDIO_CODE_BASE).
+        Length is approximately duration * 5 (5 Hz token rate).
+    """
+    ckpt = Path(checkpoint_dir).resolve()
+    lm_path = ckpt / "acestep-5Hz-lm-1.7B"
+    if not lm_path.is_dir():
+        raise FileNotFoundError(f"LM checkpoint not found: {lm_path}")
+
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    # Load BPE tokenizer
+    bpe_tokenizer = AutoTokenizer.from_pretrained(str(lm_path))
+
+    # Build generation prompt
+    prompt_ids = _build_generate_prompt(bpe_tokenizer, caption, lyrics)
+    logger.info(
+        "[LM Generate] Prompt: %d tokens, caption=%r, lyrics=%r",
+        len(prompt_ids), caption[:80], lyrics[:80],
+    )
+
+    # Load the LM (Qwen3Model with tied embeddings -> CausalLM)
+    from transformers import Qwen3Config
+    lm_config = Qwen3Config.from_pretrained(str(lm_path))
+    lm_config.architectures = ["Qwen3ForCausalLM"]
+
+    lm_dtype = select_dtype(device)
+    lm_model = AutoModelForCausalLM.from_pretrained(
+        str(lm_path),
+        config=lm_config,
+        torch_dtype=lm_dtype,
+        low_cpu_mem_usage=True,
+    )
+    lm_model = lm_model.to(device=device)
+    lm_model.eval()
+    logger.info("[LM Generate] LM loaded on %s (dtype=%s)", device, lm_dtype)
+
+    # Autoregressive decode: single sequence, no CFG.
+    prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+    with torch.inference_mode():
+        outputs = lm_model(input_ids=prompt_tensor, use_cache=True)
+    logits = outputs.logits[:, -1, :]  # [1, vocab_size]
+    past_kv = outputs.past_key_values

+    gen_tokens: List[int] = []
+    audio_codes: List[int] = []
+    past_think = False
+    in_think = False
+
+    for step in range(max_new_tokens):
+        logits = logits.clone()
+
+        # Phase 1 (inside <think>): block audio codes so only text is generated
+        if in_think:
+            logits[0, _AUDIO_CODE_BASE:] = float("-inf")
+
+        # Phase 2 (after </think>): only allow audio codes + im_end
+        if past_think:
+            # Zero out all non-audio-code logits except im_end (stop token)
+            mask = torch.full_like(logits[0], float("-inf"))
+            mask[_AUDIO_CODE_BASE:] = 0.0
+            mask[_TOKEN_IM_END] = 0.0
+            logits[0] = logits[0] + mask
+
+        # Sample
+        if temperature <= 0:
+            next_id = int(logits[0].argmax().item())
+        else:
+            scaled = logits[0].clone() / temperature
+            if top_k > 0:
+                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
+                scaled[scaled < topk_vals[-1]] = float("-inf")
+            if top_p > 0 and top_p < 1.0:
+                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
+                probs = torch.softmax(sorted_logits, dim=-1)
+                cumsum = torch.cumsum(probs, dim=-1)
+                nucleus_mask = cumsum - probs > top_p
+                sorted_logits[nucleus_mask] = float("-inf")
+                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
+            probs = torch.softmax(scaled, dim=-1)
+            next_id = int(torch.multinomial(probs, 1).item())
+
+        # Stop on im_end
+        if next_id == _TOKEN_IM_END:
+            break
+
+        # Track think state transitions
+        if next_id == _TOKEN_THINK:
+            in_think = True
+        elif next_id == _TOKEN_THINK_END:
+            in_think = False
+            past_think = True
+
+        gen_tokens.append(next_id)
+
+        # Collect audio codes (Phase 2 tokens)
+        if next_id >= _AUDIO_CODE_BASE:
+            audio_codes.append(next_id - _AUDIO_CODE_BASE)
+
+        # Next step
+        next_input = torch.tensor([[next_id]], dtype=torch.long, device=device)
+        with torch.inference_mode():
+            outputs = lm_model(
+                input_ids=next_input,
+                past_key_values=past_kv,
+                use_cache=True,
+            )
+        logits = outputs.logits[:, -1, :]
+        past_kv = outputs.past_key_values
+
+    # Log what the LM generated
+    cot_tokens = [
+        t for t in gen_tokens
+        if t < _AUDIO_CODE_BASE and t not in (
+            _TOKEN_IM_START, _TOKEN_IM_END, _TOKEN_THINK, _TOKEN_THINK_END,
+        )
+    ]
+    if cot_tokens:
+        cot_text = bpe_tokenizer.decode(cot_tokens, skip_special_tokens=False)
+        logger.info("[LM Generate] CoT output:\n%s", cot_text[:500])
+
+    logger.info(
+        "[LM Generate] Generated %d total tokens, %d audio codes (%.1fs @ 5Hz)",
+        len(gen_tokens), len(audio_codes), len(audio_codes) / 5.0,
+    )
+
+    # Unload LM
+    del outputs, logits, past_kv, prompt_tensor
+    unload_models(lm_model)
+    del lm_model, bpe_tokenizer
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[LM Generate] LM unloaded")
+
+    if not audio_codes:
+        logger.warning(
+            "[LM Generate] No audio codes generated! The DiT will fall back to "
+            "silence context. Check that the LM checkpoint is correct."
+        )
+
+    return audio_codes
+
+
 def _parse_understand_output(text: str) -> Dict[str, str]:
     """Parse CoT metadata + lyrics from understand LM output.
 
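A note on the rate arithmetic the detokenization hunk relies on: the LM emits codes at 5 Hz while the DiT consumes 25 Hz latents, so one code expands to five latent frames after FSQ detokenization. A minimal sketch (constants come from the diff; the helper name is illustrative):

# 5 Hz LM codes vs. 25 Hz DiT latents, as used by the duration override
# (len(audio_codes_list) / 5.0) and by latent_length (duration * LATENT_HZ).
CODE_HZ, LATENT_HZ = 5, 25

def expected_shapes(num_codes: int):
    """Return (duration_seconds, latent_frame_count) for a given code count."""
    duration = num_codes / CODE_HZ
    latent_frames = int(duration * LATENT_HZ)
    return duration, latent_frames

# e.g. 150 LM codes -> 30.0 s of audio -> 750 latent frames of width 64
assert expected_shapes(150) == (30.0, 750)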
 
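Rendered as plain text, the token sequence that _build_generate_prompt emits corresponds to the template below. This is a debugging sketch only: the real builder appends <|im_start|> and <|im_end|> as single special-token ids rather than encoding the literal strings, and the helper name here is hypothetical.

# Hypothetical helper rendering the same prompt as one plain string.
def render_generate_prompt(caption: str, lyrics: str) -> str:
    return (
        "<|im_start|>system\n# Instruction\n"
        "Generate audio semantic tokens based on the given conditions:\n\n"
        "<|im_end|>\n"
        "<|im_start|>user\n# Caption\n" + caption + "\n\n"
        "# Lyric\n" + lyrics + "\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

print(render_generate_prompt("lofi hip hop beat", "[Instrumental]"))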
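The additive mask in the Phase-2 branch drives every disallowed token to exactly zero probability after softmax. A toy-vocabulary check (the sizes and ids are illustrative stand-ins, not the real Qwen constants):

import torch

VOCAB, CODE_BASE, IM_END = 10, 6, 2  # toy stand-ins for the real constants
logits = torch.randn(VOCAB)
mask = torch.full((VOCAB,), float("-inf"))
mask[CODE_BASE:] = 0.0   # audio-code ids stay reachable
mask[IM_END] = 0.0       # stop token stays reachable
probs = torch.softmax(logits + mask, dim=-1)
allowed = {IM_END, *range(CODE_BASE, VOCAB)}
assert all(p == 0.0 for i, p in enumerate(probs.tolist()) if i not in allowed)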
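The nucleus cutoff in the sampling branch uses the shifted-cumsum formulation: cumsum - probs is the probability mass strictly before each sorted token, so a token is dropped only when the tokens ahead of it already exceed top_p, and the first token to cross the threshold is kept. A standalone check:

import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # already sorted descending
cumsum = torch.cumsum(probs, dim=-1)          # [0.50, 0.80, 0.95, 1.00]
keep_mass = cumsum - probs                    # mass before each token
mask = keep_mass > 0.9                        # only the 0.05 tail is dropped
assert mask.tolist() == [False, False, False, True]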