fix inference: add LM generation step, detokenize codes before DiT, full pipeline working

train_engine.py (+302 -21)
@@ -2899,18 +2899,25 @@ def generate_audio(
     device: str = "auto",
     adapter_path: Optional[str] = None,
     adapter_scale: float = 1.0,
+    use_lm: bool = True,
+    lm_temperature: float = 0.85,
+    lm_top_p: float = 0.9,
+    lm_top_k: int = 0,
 ) -> str:
-    """Generate audio using the ACE-Step
+    """Generate audio using the full ACE-Step pipeline (LM + DiT).
 
     Pipeline:
-    1.
-
-
-
-
-    5.
-
-
+    1. LM (Qwen3 1.7B) generates CoT metadata + audio codes from
+       caption and lyrics
+    2. Text encoder -> text_hidden_states, lyric embeddings
+    3. Load full model (DiT + condition encoder + FSQ)
+    4. Optional: inject LoRA adapter via PEFT
+    5. model.generate_audio() -- uses LM audio codes as context
+       conditioning via the FSQ detokenizer, then runs flow-matching
+       diffusion
+    6. VAE decode latents -> waveform
+    7. Save waveform as 48 kHz stereo WAV
+    8. Unload all models, free memory
 
     Args:
         caption: Text description of the desired music.

@@ -2926,6 +2933,11 @@ def generate_audio(
         device: ``"auto"``, ``"cpu"``, ``"cuda:0"``, etc.
         adapter_path: Path to a PEFT LoRA adapter directory (optional).
         adapter_scale: Scaling factor applied to the adapter.
+        use_lm: Run the LM to generate audio codes (True) or skip
+            and use silence context like before (False).
+        lm_temperature: LM sampling temperature.
+        lm_top_p: LM nucleus sampling cutoff.
+        lm_top_k: LM top-K sampling (0 = disabled).
 
     Returns:
         The *output_path* string (for convenience).

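For reference, a minimal call exercising the new keyword arguments might look like the sketch below. This is illustrative only, not part of the patch; the values are examples, and the pre-existing required arguments (checkpoint_dir, output_path, variant, steps, and so on) are elided.

    # Illustrative invocation of the patched entry point (values are examples only)
    generate_audio(
        caption="dreamy lo-fi hip hop with warm Rhodes chords",
        lyrics="[Instrumental]",
        duration=60.0,        # advisory: overridden by the LM's code count
        use_lm=True,          # False restores the old silence-context path
        lm_temperature=0.85,
        lm_top_p=0.9,
        lm_top_k=0,           # 0 disables top-k
    )
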
@@ -2938,8 +2950,8 @@ def generate_audio(
     device = detect_device(device)
     dtype = select_dtype(device)
     logger.info(
-        "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs",
-        device, dtype, variant, steps, duration,
+        "generate_audio: device=%s, dtype=%s, variant=%s, steps=%d, duration=%.1fs, use_lm=%s",
+        device, dtype, variant, steps, duration, use_lm,
     )
 
     # Resolve seed

@@ -2948,7 +2960,34 @@ def generate_audio(
     logger.info("Using seed=%d", seed)
 
     # ------------------------------------------------------------------
-    # 1.
+    # 1. LM generation -- produce audio codes from caption + lyrics
+    # ------------------------------------------------------------------
+    audio_codes_list: Optional[List[int]] = None
+    if use_lm:
+        logger.info("Running LM to generate audio codes...")
+        audio_codes_list = _generate_codes_with_lm(
+            checkpoint_dir=checkpoint_dir,
+            caption=caption,
+            lyrics=lyrics,
+            duration=duration,
+            device=device,
+            temperature=lm_temperature,
+            top_p=lm_top_p,
+            top_k=lm_top_k,
+        )
+        if audio_codes_list:
+            # The LM determines the actual duration via its code count
+            lm_duration = len(audio_codes_list) / 5.0
+            logger.info(
+                "LM generated %d codes (%.1fs). Overriding duration %.1f -> %.1f",
+                len(audio_codes_list), lm_duration, duration, lm_duration,
+            )
+            duration = lm_duration
+        else:
+            logger.warning("LM produced no codes, falling back to silence context.")
+
+    # ------------------------------------------------------------------
+    # 2. Text encoder -- encode caption and lyrics
     # ------------------------------------------------------------------
     logger.info("Loading text encoder...")
     tokenizer, text_encoder = load_text_encoder(checkpoint_dir, device)

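The duration override above is pure arithmetic on the LM's output: codes are emitted at 5 Hz (one code per 200 ms), while DiT latents run at 25 Hz. A sketch of the bookkeeping, with an assumed code count:

    CODE_HZ = 5.0                                # LM audio-code rate
    LATENT_HZ = 25                               # DiT latent frame rate (step 5 below)
    n_codes = 900                                # assumed LM output length
    duration_s = n_codes / CODE_HZ               # 180.0 s
    latent_frames = int(duration_s * LATENT_HZ)  # 4500 latent frames
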
@@ -2964,7 +3003,7 @@ def generate_audio(
     logger.info("Text encoder unloaded.")
 
     # ------------------------------------------------------------------
-    #
+    # 3. Load full model (DiT + CondEncoder + FSQ tokenizer/detokenizer)
     # ------------------------------------------------------------------
     logger.info("Loading ACE-Step model (%s)...", variant)
     model = load_model_for_training(checkpoint_dir, variant=variant, device=device)

@@ -2972,7 +3011,7 @@ def generate_audio(
     model.eval()
 
     # ------------------------------------------------------------------
-    #
+    # 4. Optional: inject LoRA adapter
     # ------------------------------------------------------------------
     if adapter_path:
         logger.info("Loading LoRA adapter from %s (scale=%.2f)...", adapter_path, adapter_scale)

@@ -3001,7 +3040,7 @@ def generate_audio(
     logger.info("LoRA adapter applied.")
 
     # ------------------------------------------------------------------
-    #
+    # 5. Prepare inputs for model.generate_audio()
     # ------------------------------------------------------------------
     # Latent frame rate is 25 Hz
     LATENT_HZ = 25

@@ -3015,10 +3054,31 @@ def generate_audio(
     silence_latent = silence_latent.repeat(1, repeats, 1)
     silence_latent = silence_latent[:, :latent_length, :].to(device=device, dtype=dtype)
 
-    # Build source latents and masks
+    # Build source latents and masks
     src_latents = silence_latent[:1, :latent_length, :]
     chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
-
+
+    # Detokenize LM audio codes into context latents for the DiT
+    if audio_codes_list:
+        indices_tensor = torch.tensor(
+            audio_codes_list, dtype=torch.long, device=device,
+        ).unsqueeze(0).unsqueeze(-1)  # [1, T_5Hz, 1]
+        with torch.no_grad():
+            lm_latents = model.tokenizer.quantizer.get_output_from_indices(indices_tensor)
+        # lm_latents: [1, T_5Hz, codebook_dim] -> detokenize to [1, T_25Hz, 64]
+        lm_latents = model.detokenize(lm_latents)
+        T_lm = lm_latents.shape[1]
+        # Use LM latents as src_latents context
+        if T_lm < latent_length:
+            pad = silence_latent[:, :latent_length - T_lm, :]
+            src_latents = torch.cat([lm_latents, pad], dim=1)
+        else:
+            src_latents = lm_latents[:, :latent_length, :]
+        chunk_masks = torch.ones(1, latent_length, 64, device=device, dtype=dtype)
+        is_covers = torch.ones(1, device=device, dtype=torch.long)
+        logger.info("LM codes detokenized: %d codes -> %d latent frames, used as DiT context", len(audio_codes_list), T_lm)
+    else:
+        is_covers = torch.zeros(1, device=device, dtype=torch.long)
 
     # Dummy timbre reference (single silence frame -> no timbre conditioning)
     refer_audio = torch.zeros(1, 1, 64, device=device, dtype=dtype)

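The detokenize branch is easiest to follow as a shape walk-through. Below is a self-contained sketch with dummy tensors; the 5x upsampling factor is an assumption implied by the T_5Hz -> T_25Hz comment (5 Hz codes to 25 Hz latents), and torch.randn/zeros stand in for the real quantizer and silence latents.

    import torch

    latent_length = 750                      # e.g. 30 s * 25 Hz
    codes = torch.randint(0, 64000, (140,))  # 28 s of LM codes at 5 Hz
    idx = codes.view(1, -1, 1)               # [1, 140, 1], like unsqueeze(0).unsqueeze(-1)
    lm_latents = torch.randn(1, idx.shape[1] * 5, 64)  # stand-in for detokenize(): [1, 700, 64]
    T_lm = lm_latents.shape[1]
    if T_lm < latent_length:                 # pad the tail with silence latents
        pad = torch.zeros(1, latent_length - T_lm, 64)
        src_latents = torch.cat([lm_latents, pad], dim=1)
    else:                                    # or truncate to the requested length
        src_latents = lm_latents[:, :latent_length, :]
    assert src_latents.shape == (1, latent_length, 64)
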
@@ -3028,7 +3088,7 @@ def generate_audio(
     shift = 3.0 if "turbo" in variant else 1.0
 
     # ------------------------------------------------------------------
-    #
+    # 6. Run diffusion (model.generate_audio handles everything internally)
     # ------------------------------------------------------------------
     logger.info("Running diffusion (%d steps, shift=%.1f)...", steps, shift)
     with torch.no_grad():

@@ -3062,7 +3122,7 @@ def generate_audio(
     logger.info("DiT model unloaded.")
 
     # ------------------------------------------------------------------
-    #
+    # 7. VAE decode latents -> waveform
     # ------------------------------------------------------------------
     logger.info("Loading VAE decoder...")
     vae = load_vae(checkpoint_dir, device)

@@ -3077,7 +3137,7 @@ def generate_audio(
     logger.info("VAE unloaded.")
 
     # ------------------------------------------------------------------
-    #
+    # 8. Save as WAV (48 kHz stereo)
     # ------------------------------------------------------------------
     audio_np = waveform[0].float().clamp(-1.0, 1.0).cpu().numpy()  # [2, samples]
 

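The actual WAV write falls outside the shown hunks. For reference, one way to persist a [2, samples] float array as 48 kHz stereo is sketched below; the choice of soundfile is an assumption, not necessarily what train_engine.py uses.

    import soundfile as sf

    # audio_np has shape [2, samples]; soundfile expects [samples, channels]
    sf.write(output_path, audio_np.T, 48000, subtype="PCM_16")
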
@@ -3105,7 +3165,10 @@ _TOKEN_THINK = 151667
 _TOKEN_THINK_END = 151668
 _AUDIO_CODE_BASE = 151669
 
-#
+# LM system instructions (matches acestep.cpp task-types.h)
+_LM_GENERATE_INSTRUCTION = (
+    "Generate audio semantic tokens based on the given conditions:"
+)
 _LM_UNDERSTAND_INSTRUCTION = (
     "Understand the given musical conditions and describe the audio semantics accordingly:"
 )

@@ -3154,6 +3217,224 @@ def _build_understand_prompt(
     return ids
 
 
+def _build_generate_prompt(
+    bpe_tokenizer, caption: str, lyrics: str,
+) -> List[int]:
+    """Build the Qwen3 chat prompt for audio code generation.
+
+    Format (matching C++ build_lm_prompt in prompt.h):
+        <|im_start|>system
+        # Instruction
+        {LM_GENERATE_INSTRUCTION}
+
+        <|im_end|>
+        <|im_start|>user
+        # Caption
+        {caption}
+
+        # Lyric
+        {lyrics}
+        <|im_end|>
+        <|im_start|>assistant
+    """
+    ids: List[int] = []
+
+    def append_text(text: str):
+        encoded = bpe_tokenizer.encode(text, add_special_tokens=False)
+        ids.extend(encoded)
+
+    ids.append(_TOKEN_IM_START)
+    append_text(
+        "system\n# Instruction\n"
+        + _LM_GENERATE_INSTRUCTION
+        + "\n\n"
+    )
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text(
+        "user\n# Caption\n" + caption + "\n\n"
+        "# Lyric\n" + lyrics + "\n"
+    )
+    ids.append(_TOKEN_IM_END)
+    append_text("\n")
+    ids.append(_TOKEN_IM_START)
+    append_text("assistant\n")
+    return ids
+
+
+def _generate_codes_with_lm(
+    checkpoint_dir: str,
+    caption: str,
+    lyrics: str,
+    duration: float,
+    device: str,
+    temperature: float = 0.85,
+    top_p: float = 0.9,
+    top_k: int = 0,
+    max_new_tokens: int = 8192,
+) -> List[int]:
+    """Run the ACE-Step LM (Qwen3 1.7B) to generate audio codes from text.
+
+    The LM generates in two phases within a single autoregressive pass:
+      Phase 1 (CoT): <think> metadata YAML (bpm, duration, key, etc.) </think>
+      Phase 2 (codes): audio code tokens (token_id >= AUDIO_CODE_BASE)
+
+    Args:
+        checkpoint_dir: Root directory containing acestep-5Hz-lm-1.7B/.
+        caption: Text description of the music.
+        lyrics: Lyrics text or "[Instrumental]".
+        duration: Target duration in seconds (the LM may override via CoT).
+        device: Torch device string.
+        temperature: Sampling temperature.
+        top_p: Nucleus sampling cutoff (0.0 = disabled).
+        top_k: Top-K sampling (0 = disabled).
+        max_new_tokens: Maximum tokens to generate.
+
+    Returns:
+        List of FSQ code indices (0-63999 range, NOT offset by AUDIO_CODE_BASE).
+        Length is approximately duration * 5 (5 Hz token rate).
+    """
+    ckpt = Path(checkpoint_dir).resolve()
+    lm_path = ckpt / "acestep-5Hz-lm-1.7B"
+    if not lm_path.is_dir():
+        raise FileNotFoundError(f"LM checkpoint not found: {lm_path}")
+
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+    # Load BPE tokenizer
+    bpe_tokenizer = AutoTokenizer.from_pretrained(str(lm_path))
+
+    # Build generation prompt
+    prompt_ids = _build_generate_prompt(bpe_tokenizer, caption, lyrics)
+    logger.info(
+        "[LM Generate] Prompt: %d tokens, caption=%r, lyrics=%r",
+        len(prompt_ids), caption[:80], lyrics[:80],
+    )
+
+    # Load the LM (Qwen3Model with tied embeddings -> CausalLM)
+    from transformers import Qwen3Config
+    lm_config = Qwen3Config.from_pretrained(str(lm_path))
+    lm_config.architectures = ["Qwen3ForCausalLM"]
+
+    lm_dtype = select_dtype(device)
+    lm_model = AutoModelForCausalLM.from_pretrained(
+        str(lm_path),
+        config=lm_config,
+        torch_dtype=lm_dtype,
+        low_cpu_mem_usage=True,
+    )
+    lm_model = lm_model.to(device=device)
+    lm_model.eval()
+    logger.info("[LM Generate] LM loaded on %s (dtype=%s)", device, lm_dtype)
+
+    # Autoregressive decode: single sequence, no CFG.
+    prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
+
+    with torch.inference_mode():
+        outputs = lm_model(input_ids=prompt_tensor, use_cache=True)
+    logits = outputs.logits[:, -1, :]  # [1, vocab_size]
+    past_kv = outputs.past_key_values
+
+    gen_tokens: List[int] = []
+    audio_codes: List[int] = []
+    past_think = False
+    in_think = False
+
+    for step in range(max_new_tokens):
+        logits = logits.clone()
+
+        # Phase 1 (inside <think>): block audio codes so only text is generated
+        if in_think:
+            logits[0, _AUDIO_CODE_BASE:] = float("-inf")
+
+        # Phase 2 (after </think>): only allow audio codes + im_end
+        if past_think:
+            # Zero out all non-audio-code logits except im_end (stop token)
+            mask = torch.full_like(logits[0], float("-inf"))
+            mask[_AUDIO_CODE_BASE:] = 0.0
+            mask[_TOKEN_IM_END] = 0.0
+            logits[0] = logits[0] + mask
+
+        # Sample
+        if temperature <= 0:
+            next_id = int(logits[0].argmax().item())
+        else:
+            scaled = logits[0].clone() / temperature
+            if top_k > 0:
+                topk_vals, _ = torch.topk(scaled, min(top_k, scaled.shape[0]))
+                scaled[scaled < topk_vals[-1]] = float("-inf")
+            if top_p > 0 and top_p < 1.0:
+                sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
+                probs = torch.softmax(sorted_logits, dim=-1)
+                cumsum = torch.cumsum(probs, dim=-1)
+                nucleus_mask = cumsum - probs > top_p
+                sorted_logits[nucleus_mask] = float("-inf")
+                scaled = torch.zeros_like(scaled).scatter(0, sorted_idx, sorted_logits)
+            probs = torch.softmax(scaled, dim=-1)
+            next_id = int(torch.multinomial(probs, 1).item())
+
+        # Stop on im_end
+        if next_id == _TOKEN_IM_END:
+            break
+
+        # Track think state transitions
+        if next_id == _TOKEN_THINK:
+            in_think = True
+        elif next_id == _TOKEN_THINK_END:
+            in_think = False
+            past_think = True
+
+        gen_tokens.append(next_id)
+
+        # Collect audio codes (Phase 2 tokens)
+        if next_id >= _AUDIO_CODE_BASE:
+            audio_codes.append(next_id - _AUDIO_CODE_BASE)
+
+        # Next step
+        next_input = torch.tensor([[next_id]], dtype=torch.long, device=device)
+        with torch.inference_mode():
+            outputs = lm_model(
+                input_ids=next_input,
+                past_key_values=past_kv,
+                use_cache=True,
+            )
+        logits = outputs.logits[:, -1, :]
+        past_kv = outputs.past_key_values
+
+    # Log what the LM generated
+    cot_tokens = [
+        t for t in gen_tokens
+        if t < _AUDIO_CODE_BASE and t not in (
+            _TOKEN_IM_START, _TOKEN_IM_END, _TOKEN_THINK, _TOKEN_THINK_END,
+        )
+    ]
+    if cot_tokens:
+        cot_text = bpe_tokenizer.decode(cot_tokens, skip_special_tokens=False)
+        logger.info("[LM Generate] CoT output:\n%s", cot_text[:500])
+
+    logger.info(
+        "[LM Generate] Generated %d total tokens, %d audio codes (%.1fs @ 5Hz)",
+        len(gen_tokens), len(audio_codes), len(audio_codes) / 5.0,
+    )
+
+    # Unload LM
+    del outputs, logits, past_kv, prompt_tensor
+    unload_models(lm_model)
+    del lm_model, bpe_tokenizer
+    gc.collect()
+    _clear_gpu_cache(device)
+    logger.info("[LM Generate] LM unloaded")
+
+    if not audio_codes:
+        logger.warning(
+            "[LM Generate] No audio codes generated! The DiT will fall back to "
+            "silence context. Check that the LM checkpoint is correct."
+        )
+
+    return audio_codes
+
+
 def _parse_understand_output(text: str) -> Dict[str, str]:
     """Parse CoT metadata + lyrics from understand LM output.
 

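Taken together, the two new helpers compose as sketched below. The paths, caption, and lyrics are illustrative, and the decoded output shown in comments is an assumption about how the tokenizer renders the special tokens, though the chat layout follows the _build_generate_prompt docstring.

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("/ckpts/acestep-5Hz-lm-1.7B")
    ids = _build_generate_prompt(tok, caption="upbeat synthwave", lyrics="[Instrumental]")
    print(tok.decode(ids))
    # <|im_start|>system
    # # Instruction
    # Generate audio semantic tokens based on the given conditions:
    #
    # <|im_end|>
    # <|im_start|>user
    # # Caption
    # upbeat synthwave
    #
    # # Lyric
    # [Instrumental]
    # <|im_end|>
    # <|im_start|>assistant

    codes = _generate_codes_with_lm(
        checkpoint_dir="/ckpts",
        caption="upbeat synthwave",
        lyrics="[Instrumental]",
        duration=30.0,
        device="cuda:0",
    )
    assert all(0 <= c < 64000 for c in codes)  # FSQ indices, AUDIO_CODE_BASE offset removed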