Commit: add understand_audio (LM reverse), demucs-infer fix, commit refs, dtype fixes

train_engine.py  CHANGED  (+411 -0)
@@ -4,6 +4,11 @@ Standalone ACE-Step LoRA Training Engine (CPU + GPU).
 Ported from Side-Step (koda-dernet/Side-Step) into a single self-contained
 module. No external Side-Step dependency required.
 
+Source commits:
+    Side-Step:   koda-dernet/Side-Step @ ecd13bd (2026-04-19)
+    acestep.cpp: ServeurpersoCom/acestep.cpp @ 36e4db1 (prompt/understand format)
+    ACE-Step:    ace-step/ACE-Step-1.5 (model architecture, training_v2)
+
 Auto-detects GPU (CUDA > MPS > CPU) and uses it when available,
 falling back to CPU. bfloat16 is used on GPU; float32 is forced
 on CPU (bfloat16 deadlocks on CPU -- known PyTorch bug).
@@ -17,6 +22,7 @@ Exports:
     get_trained_loras() - List saved adapters
     generate_audio() - Standalone inference (text -> WAV, optional LoRA)
     tiled_vae_decode() - Tiled VAE latent-to-waveform decode
+    understand_audio() - Reverse pipeline (audio -> caption + lyrics)
 """
 
 from __future__ import annotations
@@ -3086,3 +3092,408 @@ def generate_audio(
 
     logger.info("Audio saved to %s (%.1fs @ %d Hz)", output_path, duration, TARGET_SR)
     return output_path
+
+
+# ============================================================================
+# UNDERSTAND MODE (reverse pipeline: audio -> caption + lyrics)
+# ============================================================================
+
+# Qwen3 special token IDs (ACE-Step LM vocabulary)
+_TOKEN_IM_START = 151644
+_TOKEN_IM_END = 151645
+_TOKEN_THINK = 151667
+_TOKEN_THINK_END = 151668
+_AUDIO_CODE_BASE = 151669
+
+# Understand system instruction (matches acestep.cpp task-types.h)
+_LM_UNDERSTAND_INSTRUCTION = (
+    "Understand the given musical conditions and describe the audio semantics accordingly:"
+)
+
+
+def _build_understand_prompt(
+    bpe_tokenizer, codes: List[int],
+) -> List[int]:
+    """Build the Qwen3 chat prompt for understand mode.
+
+    Format (matching C++ build_understand_prompt in prompt.h):
+        <|im_start|>system
+        # Instruction
+        {LM_UNDERSTAND_INSTRUCTION}
+
+        <|im_end|>
+        <|im_start|>user
+        {audio_code_tokens}
+        <|im_end|>
+        <|im_start|>assistant
+    """
+    ids: List[int] = []
+
+    def append_text(text: str):
+        encoded = bpe_tokenizer.encode(text, add_special_tokens=False)
+        ids.extend(encoded)
+
+    ids.append(_TOKEN_IM_START)
+    append_text(
+        "system\n# Instruction\n"
+        + _LM_UNDERSTAND_INSTRUCTION
+        + "\n\n"
+    )
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text("user\n")
+    # Audio codes as raw token IDs (not BPE text)
+    for code in codes:
+        ids.append(_AUDIO_CODE_BASE + code)
+    append_text("\n")
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text("assistant\n")
+    return ids
+
+
+def _parse_understand_output(text: str) -> Dict[str, str]:
+    """Parse CoT metadata + lyrics from understand LM output.
+
+    The LM generates:
+        <think>
+        bpm: 120
+        caption: ...
+        duration: 180
+        keyscale: C major
+        language: en
+        timesignature: 4
+        </think>
+        [Verse 1]
+        ...lyrics...
+
+    Returns dict with: caption, lyrics, bpm, key, signature, duration,
+    language.
+    """
+    result: Dict[str, str] = {}
+
+    # Split at <think> / </think> boundaries
+    cot = ""
+    lyrics_after = ""
+    ts = text.find("<think>")
+    te = text.find("</think>")
+
+    if ts != -1 and te != -1:
+        cot = text[ts + 7:te]
+        lyrics_after = text[te + 8:]
+    elif te != -1:
+        cot = text[:te]
+        lyrics_after = text[te + 8:]
+    else:
+        cot = text
+
+    # Parse YAML-like fields from CoT
+    def get_field(key: str) -> str:
+        needle = key + ":"
+        p = cot.find(needle)
+        if p == -1:
+            return ""
+        p += len(needle)
+        # Skip leading whitespace and quotes
+        while p < len(cot) and cot[p] in (" ", "'"):
+            p += 1
+        end = cot.find("\n", p)
+        if end == -1:
+            end = len(cot)
+        val = cot[p:end].rstrip(" '\r")
+        return val
+
+    bpm_s = get_field("bpm")
+    if bpm_s:
+        result["bpm"] = bpm_s
+
+    dur_s = get_field("duration")
+    if dur_s:
+        result["duration"] = dur_s
+
+    ks = get_field("keyscale")
+    if ks:
+        result["key"] = ks
+
+    ts_s = get_field("timesignature")
+    if ts_s:
+        result["signature"] = ts_s
+
+    lang = get_field("language")
+    if lang:
+        result["language"] = lang
+
+    # Caption may span multiple lines (YAML word-wrap)
+    cap_needle = "caption:"
+    cp = cot.find(cap_needle)
+    if cp != -1:
+        cp += len(cap_needle)
+        # Read until next known field or end of CoT
+        end = len(cot)
+        for next_field in ("duration:", "keyscale:", "language:", "timesignature:", "bpm:"):
+            nf = cot.find("\n" + next_field, cp)
+            if nf != -1 and nf < end:
+                end = nf
+        full_cap = cot[cp:end]
+        # Collapse whitespace
+        cleaned = " ".join(full_cap.split()).strip()
+        if cleaned:
+            result["caption"] = cleaned
+
+    # Lyrics after </think>
+    if lyrics_after:
+        lyrics = lyrics_after.strip()
+        # Strip "# Lyric\n" header the LM may echo back
+        lp = lyrics.find("# Lyric\n")
+        if lp != -1 and lp < 64:
+            lyrics = lyrics[lp + 8:]
+        lyrics = lyrics.strip()
+        if lyrics:
+            result["lyrics"] = lyrics
+
+    return result
+
+
+def understand_audio(
+    audio_path: str,
+    checkpoint_dir: str,
+    device: str = "auto",
+    variant: str = "turbo",
+    temperature: float = 0.3,
+    top_p: float = 0.0,
+    top_k: int = 0,
+    max_new_tokens: int = 4096,
+) -> Dict[str, str]:
+    """Extract caption, lyrics, BPM, key, signature from audio using the LM.
+
+    Pipeline: audio -> VAE encode -> FSQ tokenize -> LM understand -> text
+    Returns dict with: caption, lyrics, bpm, key, signature, duration,
+    language.
+
+    Args:
+        audio_path: Path to input audio file (WAV, MP3, FLAC, etc.)
+        checkpoint_dir: Path to ACE-Step checkpoints root directory
+            (must contain vae/, acestep-v15-turbo/ or variant subdir,
+            and acestep-5Hz-lm-1.7B/).
+        device: Device string ("auto", "cuda:0", "cpu", etc.)
+        variant: DiT variant to load for FSQ tokenizer ("turbo", "sft",
+            "base", etc.)
+        temperature: LM sampling temperature (default 0.3, lower = more
+            deterministic).
+        top_p: Nucleus sampling cutoff (0.0 = disabled).
+        top_k: Top-K sampling (0 = disabled).
+        max_new_tokens: Maximum tokens to generate.
+
+    Returns:
+        Dict with extracted metadata. Keys may include:
+        caption, lyrics, bpm, key, signature, duration, language.
+    """
+    device = detect_device(device)
+    dtype = select_dtype(device)
+    ckpt = Path(checkpoint_dir).resolve()
+
+    # ------------------------------------------------------------------
+    # Step 1: Load audio -> VAE encode -> latents [1, T_25Hz, 64]
+    # ------------------------------------------------------------------
+    logger.info("[Understand] Step 1: VAE encode")
+    audio, sr = load_audio_stereo(audio_path, TARGET_SR, MAX_AUDIO_DURATION)
+    audio = audio.unsqueeze(0)  # [1, 2, samples]
+    logger.info(
+        "[Understand] Audio loaded: %.1fs, %d samples @ %d Hz",
+        audio.shape[-1] / TARGET_SR, audio.shape[-1], TARGET_SR,
+    )
+
+    vae = load_vae(checkpoint_dir, device)
+    latents = tiled_vae_encode(vae, audio, dtype)  # [1, T_25Hz, 64]
+    T_25Hz = latents.shape[1]
+    logger.info("[Understand] VAE encoded: %d latent frames (%.2fs)", T_25Hz, T_25Hz * 1920.0 / TARGET_SR)
+
+    unload_models(vae)
+    del vae, audio
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[Understand] VAE unloaded")
+
+    # ------------------------------------------------------------------
+    # Step 2: Load DiT (for FSQ tokenizer) -> tokenize latents -> codes
+    # ------------------------------------------------------------------
+    logger.info("[Understand] Step 2: FSQ tokenize")
+
+    # Load silence_latent for padding
+    silence_latent = load_silence_latent(checkpoint_dir, device="cpu", variant=variant)
+
+    # Load DiT model (only need its tokenizer submodule)
+    model = load_model_for_training(checkpoint_dir, variant=variant, device=device)
+    model = model.to(dtype=dtype)
+    pool_window = model.config.pool_window_size  # 5 (25Hz -> 5Hz)
+
+    # Pad latents to a multiple of pool_window_size
+    lat = latents.to(device=device, dtype=dtype)
+    pad_len = 0
+    if T_25Hz % pool_window != 0:
+        pad_len = pool_window - (T_25Hz % pool_window)
+        # Use silence_latent for padding
+        sl = silence_latent[:1, :pad_len, :].to(device=device, dtype=dtype)
+        lat = torch.cat([lat, sl.expand(lat.shape[0], -1, -1)], dim=1)
+
+    # Tokenize: lat [1, T_padded, 64] -> indices [1, T_5Hz, 1]
+    with torch.inference_mode():
+        _quantized, indices = model.tokenizer.tokenize(lat)
+
+    # indices shape: [1, T_5Hz, num_quantizers=1] -> flatten to [T_5Hz]
+    codes = indices.squeeze(0).squeeze(-1).cpu().tolist()  # List[int]
+    T_5Hz = len(codes)
+    logger.info(
+        "[Understand] FSQ tokenized: %d codes (%.2fs @ 5Hz)",
+        T_5Hz, T_5Hz / 5.0,
+    )
+
+    unload_models(model)
+    del model, lat, latents, _quantized, indices, silence_latent
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[Understand] DiT unloaded")
+
+    # ------------------------------------------------------------------
+    # Step 3: Load LM -> build understand prompt -> generate text
+    # ------------------------------------------------------------------
+    logger.info("[Understand] Step 3: LM generation")
+
+    lm_path = ckpt / "acestep-5Hz-lm-1.7B"
+    if not lm_path.is_dir():
+        raise FileNotFoundError(f"LM checkpoint not found: {lm_path}")
+
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    # Load BPE tokenizer
+    bpe_tokenizer = AutoTokenizer.from_pretrained(str(lm_path))
+
+    # Build understand prompt
+    prompt_ids = _build_understand_prompt(bpe_tokenizer, codes)
+    logger.info(
+        "[Understand] Prompt: %d tokens (%d codes + framing)",
+        len(prompt_ids), len(codes),
+    )
+
+    # Load the LM (Qwen3Model with tied embeddings).
+    # Config says "Qwen3Model" but we need generation (lm_head). Since
+    # tie_word_embeddings=true, Qwen3ForCausalLM will tie the lm_head
+    # to embed_tokens automatically. We override the architecture to load
+    # as CausalLM.
+    from transformers import Qwen3Config
+    lm_config = Qwen3Config.from_pretrained(str(lm_path))
+    lm_config.architectures = ["Qwen3ForCausalLM"]
+
+    lm_dtype = select_dtype(device)
+    lm_model = AutoModelForCausalLM.from_pretrained(
+        str(lm_path),
+        config=lm_config,
+        torch_dtype=lm_dtype,
+        low_cpu_mem_usage=True,
+    )
+    lm_model = lm_model.to(device=device)
+    lm_model.eval()
+    logger.info("[Understand] LM loaded on %s (dtype=%s)", device, lm_dtype)
+
+    vocab_size = lm_config.vocab_size  # 217204
+
+    # Autoregressive decode: no CFG, no batch, single sequence.
+    # FSM is not implemented in Python (would require the prefix tree);
+    # the LM generates structured CoT well enough without it at low temp.
+    prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+    with torch.inference_mode():
+        # Prefill
+        outputs = lm_model(input_ids=prompt_tensor, use_cache=True)
+        logits = outputs.logits[:, -1, :]  # [1, vocab_size]
+        past_kv = outputs.past_key_values
+
+    gen_tokens: List[int] = []
+    past_think = False
+
+    for step in range(max_new_tokens):
+        # After </think>: block audio codes so the LM only generates text
+        if past_think:
+            logits = logits.clone()  # inference tensor: clone before in-place masking
+            logits[0, _AUDIO_CODE_BASE:] = float("-inf")
+
+        # Sample
+        if temperature <= 0:
+            next_id = int(logits[0].argmax().item())
+        else:
+            scaled = logits[0] / temperature
+            if top_k > 0:
+                # Mask everything below the top_k-th logit
+                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
+                scaled[scaled < topk_vals[-1]] = float("-inf")
+            if top_p > 0 and top_p < 1.0:
+                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
+                probs = torch.softmax(sorted_logits, dim=-1)
+                cumsum = torch.cumsum(probs, dim=-1)
+                mask = cumsum - probs > top_p
+                sorted_logits[mask] = float("-inf")
+                # Scatter masked values back to original positions
+                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
+            probs = torch.softmax(scaled, dim=-1)
+            next_id = int(torch.multinomial(probs, 1).item())
+
+        if next_id == _TOKEN_IM_END:
+            break
+
+        if next_id == _TOKEN_THINK_END:
+            past_think = True
+
+        gen_tokens.append(next_id)
+
+        # Next step
+        next_input = torch.tensor([[next_id]], dtype=torch.long, device=device)
+        with torch.inference_mode():
+            outputs = lm_model(
+                input_ids=next_input,
+                past_key_values=past_kv,
+                use_cache=True,
+            )
+            logits = outputs.logits[:, -1, :]
+            past_kv = outputs.past_key_values
+
+    logger.info("[Understand] Generated %d tokens", len(gen_tokens))
+
+    # Decode tokens to text (skip audio code tokens and special tokens)
+    text_tokens = [
+        t for t in gen_tokens
+        if t < _AUDIO_CODE_BASE and t not in (
+            _TOKEN_IM_START, _TOKEN_IM_END, _TOKEN_THINK, _TOKEN_THINK_END,
+        )
+    ]
+    generated_text = bpe_tokenizer.decode(text_tokens, skip_special_tokens=False)
+
+    # Re-insert <think> / </think> markers for the parser
+    think_text = ""
+    in_think = False
+    for t in gen_tokens:
+        if t == _TOKEN_THINK:
+            think_text += "<think>"
+            in_think = True
+        elif t == _TOKEN_THINK_END:
+            think_text += "</think>"
+            in_think = False
+        elif t < _AUDIO_CODE_BASE and t not in (_TOKEN_IM_START, _TOKEN_IM_END):
+            think_text += bpe_tokenizer.decode([t], skip_special_tokens=False)
+
+    logger.info("[Understand] Raw output:\n%s", think_text[:500])
+
+    # Unload LM
+    del outputs, logits, past_kv, prompt_tensor
+    unload_models(lm_model)
+    del lm_model, bpe_tokenizer
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[Understand] LM unloaded")
+
+    # ------------------------------------------------------------------
+    # Step 4: Parse generated text into structured fields
+    # ------------------------------------------------------------------
+    result = _parse_understand_output(think_text)
+    logger.info("[Understand] Parsed result: %s", {k: v[:80] + "..." if len(v) > 80 else v for k, v in result.items()})
+    return result
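
For reference, a minimal usage sketch for the new entry point. Only understand_audio() itself comes from this commit; the file paths are placeholders and the import assumes train_engine.py is on the Python path:

    # Hypothetical caller -- "song.mp3" and "./checkpoints" are placeholder paths.
    from train_engine import understand_audio

    meta = understand_audio(
        audio_path="song.mp3",           # WAV/MP3/FLAC input
        checkpoint_dir="./checkpoints",  # root with vae/, the DiT variant, acestep-5Hz-lm-1.7B/
        device="auto",                   # CUDA > MPS > CPU auto-detection
        temperature=0.3,                 # low temperature keeps the CoT fields well-formed
    )

    # Every key is optional; the parser only emits fields the LM actually produced.
    print(meta.get("caption", "<no caption>"))
    print(meta.get("bpm"), meta.get("key"), meta.get("signature"))
    print(meta.get("lyrics", "")[:200])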
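
The top-p branch of the decode loop is its densest part. A self-contained sketch with toy logits (same filtering math as the loop above; nothing here is from the commit) shows why `cumsum - probs > top_p` keeps the smallest prefix of ranked tokens whose mass exceeds top_p, and why the top-ranked token can never be masked:

    import torch

    scaled = torch.tensor([2.0, 1.0, 0.5, -1.0])  # toy logits
    top_p = 0.8

    sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
    probs = torch.softmax(sorted_logits, dim=-1)
    cumsum = torch.cumsum(probs, dim=-1)

    # `cumsum - probs` is the probability mass strictly above each rank,
    # so rank 0 compares 0.0 > top_p and always survives.
    mask = cumsum - probs > top_p
    sorted_logits[mask] = float("-inf")

    # Scatter survivors back to vocabulary order, renormalize, sample.
    filtered = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
    next_id = int(torch.multinomial(torch.softmax(filtered, dim=-1), 1).item())
    print(mask.tolist(), next_id)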