| |
| |
|
|
| import argparse |
| import json |
| import traceback |
| from pathlib import Path |
|
|
| import datasets |
| import torch |
|
|
| from inference_full import ( |
| TokenLayout, |
| batch_generate_segmentwise, |
| build_mucodec_decoder, |
| generate_segmentwise, |
| load_hf_template_sample_from_music_dataset, |
| save_outputs, |
| ) |
| from runtime_utils import ( |
| load_magel_checkpoint, |
| load_music_dataset, |
| maybe_compile_model, |
| resolve_device, |
| seed_everything, |
| ) |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse CLI options for batched multi-checkpoint audio inference.

    Exactly one of --checkpoint_list / --checkpoint_dir must be provided;
    otherwise ``parser.error`` exits with usage information.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Run audio inference on validation samples for multiple checkpoints."
    )
    # --- checkpoint discovery (exactly one of these is required) ---
    parser.add_argument(
        "--checkpoint_list",
        type=str,
        default=None,
        help="Text file with one checkpoint path per line.",
    )
    parser.add_argument(
        "--checkpoint_dir",
        type=str,
        default=None,
        help="Directory to scan for checkpoint-* subdirectories and optional final.",
    )
    # --- dataset / tokenizer ---
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="muse_mucodec_chord.ds",
    )
    parser.add_argument(
        "--split",
        type=str,
        default="validation",
    )
    parser.add_argument(
        "--tokenizer_path",
        type=str,
        default="checkpoints/Qwen3-0.6B",
    )
    # --- sample selection ---
    parser.add_argument(
        "--sample_indices",
        type=int,
        nargs="*",
        default=None,
        help="Specific sample indices to infer. Leave unset to run the full split.",
    )
    parser.add_argument(
        "--max_samples",
        type=int,
        default=0,
        help="Run only the first N samples from the split. Ignored if --sample_indices is set.",
    )
    parser.add_argument(
        "--infer_batch_size",
        type=int,
        default=1,
        help="Number of samples to decode together per step for the same checkpoint.",
    )
    # --- sampling hyperparameters ---
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--top_k", type=int, default=50)
    parser.add_argument("--top_p", type=float, default=0.90)
    parser.add_argument("--greedy", action="store_true", default=False)
    parser.add_argument("--max_audio_tokens", type=int, default=0)
    parser.add_argument("--fps", type=int, default=25)
    parser.add_argument("--seed", type=int, default=1234)
    # --- runtime / model loading ---
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument(
        "--dtype",
        type=str,
        default="bfloat16",
        choices=["float32", "float16", "bfloat16"],
    )
    parser.add_argument(
        "--attn_implementation",
        type=str,
        default="sdpa",
        choices=["eager", "sdpa", "flash_attention_2"],
    )
    # NOTE(review): a store_true flag with default=True can never be turned off
    # by itself — passing --use_cache is a no-op, and --no_cache is the only
    # effective switch (see `use_cache = args.use_cache and not args.no_cache`
    # in main). Kept as-is for CLI compatibility.
    parser.add_argument("--use_cache", action="store_true", default=True)
    parser.add_argument("--no_cache", action="store_true", default=False)
    parser.add_argument("--compile", action="store_true", default=False)
    parser.add_argument(
        "--compile_mode",
        type=str,
        default="reduce-overhead",
        choices=["default", "reduce-overhead", "max-autotune"],
    )
    # --- MuCodec decoder settings ---
    parser.add_argument("--mucodec_device", type=str, default="auto")
    parser.add_argument("--mucodec_layer_num", type=int, default=7)
    parser.add_argument("--mucodec_duration", type=float, default=40.96)
    parser.add_argument("--mucodec_guidance_scale", type=float, default=1.5)
    parser.add_argument("--mucodec_num_steps", type=int, default=20)
    parser.add_argument("--mucodec_sample_rate", type=int, default=48000)
    # --- outputs ---
    parser.add_argument(
        "--output_dir",
        type=str,
        default="/root/new_batch_predictions",
        help="Root output dir. Each checkpoint gets its own subdirectory.",
    )
    parser.add_argument(
        "--summary_json",
        type=str,
        default="/root/new_batch_predictions/summary.json",
    )
    args = parser.parse_args()
    # Enforce "at least one source of checkpoints" after parsing, since
    # argparse cannot express one-of-two required options directly.
    if not args.checkpoint_list and not args.checkpoint_dir:
        parser.error("one of --checkpoint_list or --checkpoint_dir is required")
    return args
|
|
|
|
def parse_checkpoint_list(path: str) -> list[str]:
    """Read checkpoint paths from a text file, one per line.

    Blank lines and lines starting with ``#`` are skipped.

    Args:
        path: path to the list file (UTF-8).

    Returns:
        Non-empty list of checkpoint path strings, in file order.

    Raises:
        ValueError: if the file contains no usable entries.
    """
    with open(path, "r", encoding="utf-8") as handle:
        entries = [
            stripped
            for stripped in (raw.strip() for raw in handle)
            if stripped and not stripped.startswith("#")
        ]
    if not entries:
        raise ValueError(f"No checkpoints found in list: {path}")
    return entries
|
|
|
|
def scan_checkpoint_dir(path: str) -> list[str]:
    """Collect checkpoint-* subdirectories of ``path``, plus an optional ``final/``.

    Numeric checkpoints (checkpoint-100, ...) are ordered by step number;
    non-numeric suffixes sort after them, lexicographically by full name.
    ``final`` (if present) is always appended last.

    Args:
        path: directory to scan.

    Returns:
        Non-empty list of checkpoint directory paths as strings.

    Raises:
        NotADirectoryError: if ``path`` is not a directory.
        ValueError: if no checkpoint-* subdirectories (and no final/) exist.
    """
    root = Path(path)
    if not root.is_dir():
        raise NotADirectoryError(f"Checkpoint directory not found: {path}")

    checkpoint_dirs = [
        item
        for item in root.iterdir()
        if item.is_dir() and item.name.startswith("checkpoint-")
    ]

    def sort_key(p: Path) -> tuple[int, int, str]:
        # BUG FIX: the previous key returned int for numeric suffixes and str
        # otherwise; sorted() raises TypeError as soon as both kinds coexist
        # (int and str are not comparable). A homogeneous tuple key keeps
        # numeric checkpoints first, in step order, then others by name.
        suffix = p.name.split("-", 1)[1]
        if suffix.isdigit():
            return (0, int(suffix), "")
        return (1, 0, p.name)

    checkpoint_dirs.sort(key=sort_key)

    final_dir = root / "final"
    if final_dir.is_dir():
        checkpoint_dirs.append(final_dir)

    checkpoints = [str(path_obj) for path_obj in checkpoint_dirs]
    if not checkpoints:
        raise ValueError(f"No checkpoint-* directories found under: {path}")
    return checkpoints
|
|
|
|
def get_dtype(name: str) -> torch.dtype:
    """Translate a dtype name into the corresponding torch dtype.

    Args:
        name: one of "float32", "float16", "bfloat16".

    Returns:
        The matching ``torch.dtype``.

    Raises:
        KeyError: if ``name`` is not a recognized dtype string.
    """
    name_to_dtype = {
        "float32": torch.float32,
        "float16": torch.float16,
        "bfloat16": torch.bfloat16,
    }
    return name_to_dtype[name]
|
|
|
|
def get_split_size(dataset_path: str, split: str) -> int:
    """Return the row count of ``split`` in the on-disk dataset at ``dataset_path``.

    If the saved object is a plain ``Dataset`` (not a ``DatasetDict``), the
    ``split`` argument is ignored and the full length is returned.

    Raises:
        KeyError: if the dataset is a ``DatasetDict`` lacking ``split``.
    """
    dataset_obj = datasets.load_from_disk(dataset_path)
    if not isinstance(dataset_obj, datasets.DatasetDict):
        return len(dataset_obj)
    if split not in dataset_obj:
        raise KeyError(f"Split not found: {split}")
    return len(dataset_obj[split])
|
|
|
|
| def resolve_sample_indices( |
| dataset_path: str, |
| split: str, |
| sample_indices: list[int] | None, |
| max_samples: int, |
| ) -> list[int]: |
| if sample_indices: |
| return list(sample_indices) |
| split_size = get_split_size(dataset_path, split) |
| if max_samples and max_samples > 0: |
| split_size = min(split_size, max_samples) |
| return list(range(split_size)) |
|
|
|
|
def sanitize_checkpoint_name(checkpoint_path: str) -> str:
    """Build a flat, filesystem-friendly name for a checkpoint path.

    Produces ``"<parent>__<leaf>"`` (e.g. ``run1__checkpoint-500``) so outputs
    from different runs never collide; a bare leaf with no parent component is
    returned unchanged. Trailing slashes are ignored.
    """
    leaf = Path(checkpoint_path.rstrip("/"))
    parent_name = leaf.parent.name
    return f"{parent_name}__{leaf.name}" if parent_name else leaf.name
|
|
|
|
def chunk_list(items: list[int], chunk_size: int) -> list[list[int]]:
    """Split ``items`` into consecutive pieces of at most ``chunk_size`` elements.

    The final chunk may be shorter; an empty input yields an empty list.
    """
    chunks: list[list[int]] = []
    for start in range(0, len(items), chunk_size):
        chunks.append(items[start : start + chunk_size])
    return chunks
|
|
|
|
def main() -> None:
    """Run segment-wise audio generation for every checkpoint over the chosen samples.

    For each checkpoint: load the model, decode the requested sample indices
    (optionally batched), save wav/JSON outputs, and record per-sample results.
    A JSON summary of all checkpoints is written at the end. A failure in one
    checkpoint is caught and recorded; later checkpoints still run.
    """
    args = parse_args()
    seed_everything(args.seed)

    # Checkpoints come from either an explicit list file or a directory scan.
    if args.checkpoint_list:
        checkpoints = parse_checkpoint_list(args.checkpoint_list)
    else:
        checkpoints = scan_checkpoint_dir(args.checkpoint_dir)
    sample_indices = resolve_sample_indices(
        dataset_path=args.dataset_path,
        split=args.split,
        sample_indices=args.sample_indices,
        max_samples=args.max_samples,
    )

    # --use_cache defaults to True, so --no_cache is the effective off switch.
    use_cache = args.use_cache and not args.no_cache
    device = resolve_device(args.device)
    dtype = get_dtype(args.dtype)
    if device.type == "cpu" and dtype != torch.float32:
        print(f"[WARN] dtype {dtype} on CPU may be unsupported; fallback to float32.")
        dtype = torch.float32

    output_root = Path(args.output_dir)
    output_root.mkdir(parents=True, exist_ok=True)

    print(f"[INFO] checkpoints={len(checkpoints)}")
    print(f"[INFO] samples_per_checkpoint={len(sample_indices)}")
    print(f"[INFO] device={device}, dtype={dtype}, use_cache={use_cache}")

    # The MuCodec decoder is built once and reused for every checkpoint.
    mucodec_decoder = build_mucodec_decoder(args)
    summary: list[dict] = []

    for checkpoint_path in checkpoints:
        ckpt_name = sanitize_checkpoint_name(checkpoint_path)
        # Per-checkpoint output layout: <output_dir>/<ckpt_name>/{json,wav}.
        ckpt_output_dir = output_root / ckpt_name
        json_dir = ckpt_output_dir / "json"
        wav_dir = ckpt_output_dir / "wav"

        print(f"\n[INFO] loading model from {checkpoint_path}")
        model = load_magel_checkpoint(
            checkpoint_path=checkpoint_path,
            device=device,
            dtype=dtype,
            attn_implementation=args.attn_implementation,
        )
        model = maybe_compile_model(
            model,
            enabled=bool(args.compile),
            mode=str(args.compile_mode),
        )
        # Codebook size comes from the checkpoint config; 16384 is the
        # fallback when the attribute is missing.
        num_audio_codebook = int(getattr(model.config, "magel_num_audio_token", 16384))
        # NOTE(review): the dataset is reloaded for every checkpoint, presumably
        # because tokenization depends on num_audio_codebook — if that value is
        # constant across checkpoints this could be hoisted out of the loop.
        music_ds = load_music_dataset(
            dataset_path=args.dataset_path,
            split=args.split,
            tokenizer_path=args.tokenizer_path,
            num_audio_token=num_audio_codebook,
            use_fast=True,
        )

        checkpoint_record = {
            "checkpoint_path": checkpoint_path,
            "checkpoint_name": ckpt_name,
            "status": "ok",
            "num_samples_requested": len(sample_indices),
            "results": [],
        }

        try:
            for batch_indices in chunk_list(sample_indices, max(1, int(args.infer_batch_size))):
                samples = []
                for sample_idx in batch_indices:
                    print(
                        f"[INFO] checkpoint={ckpt_name} sample_idx={sample_idx} split={args.split}"
                    )
                    samples.append(
                        load_hf_template_sample_from_music_dataset(
                            music_ds=music_ds,
                            sample_idx=sample_idx,
                            num_audio_codebook=num_audio_codebook,
                        )
                    )

                # One layout is shared by the whole batch; assumes
                # num_text_token is identical across samples — TODO confirm.
                layout = TokenLayout(
                    num_text_token=samples[0].num_text_token,
                    num_audio_codebook=num_audio_codebook,
                )

                if len(samples) == 1:
                    # Single-sample path: skip the batched code entirely.
                    batch_outputs = [
                        generate_segmentwise(
                            model=model,
                            sample=samples[0],
                            layout=layout,
                            device=device,
                            use_cache=use_cache,
                            temperature=float(args.temperature),
                            top_k=int(args.top_k),
                            top_p=float(args.top_p),
                            greedy=bool(args.greedy),
                            max_audio_tokens=max(0, int(args.max_audio_tokens)),
                        )
                    ]
                else:
                    try:
                        batch_outputs = batch_generate_segmentwise(
                            model=model,
                            samples=samples,
                            layout=layout,
                            device=device,
                            use_cache=use_cache,
                            temperature=float(args.temperature),
                            top_k=int(args.top_k),
                            top_p=float(args.top_p),
                            greedy=bool(args.greedy),
                            max_audio_tokens=max(0, int(args.max_audio_tokens)),
                        )
                    except Exception as exc:
                        # Best-effort fallback: if batched decoding fails for
                        # any reason, decode each sample independently instead
                        # of aborting the checkpoint.
                        print(
                            "[WARN] batch_generate_segmentwise failed; "
                            f"falling back to single-sample decode. error={exc!r}"
                        )
                        traceback.print_exc()
                        batch_outputs = [
                            generate_segmentwise(
                                model=model,
                                sample=sample,
                                layout=layout,
                                device=device,
                                use_cache=use_cache,
                                temperature=float(args.temperature),
                                top_k=int(args.top_k),
                                top_p=float(args.top_p),
                                greedy=bool(args.greedy),
                                max_audio_tokens=max(0, int(args.max_audio_tokens)),
                            )
                            for sample in samples
                        ]

                for sample_idx, sample, batch_output in zip(batch_indices, samples, batch_outputs):
                    generated_ids, sampled_count, sampled_chord_ids, sampled_segment_ids = batch_output
                    prefix = f"{sample_idx:05d}_{sample.song_id}"

                    # HACK: save_outputs reads these attributes from `args`, so
                    # they are patched onto the namespace per sample. Consider
                    # passing them as explicit parameters instead.
                    args.sample_idx = sample_idx
                    args.json_output_dir = str(json_dir)
                    args.wav_output_dir = str(wav_dir)

                    save_outputs(
                        output_dir=str(ckpt_output_dir),
                        output_prefix=prefix,
                        sample=sample,
                        layout=layout,
                        generated_ids=generated_ids,
                        sampled_chord_ids=sampled_chord_ids,
                        sampled_segment_ids=sampled_segment_ids,
                        args=args,
                        mucodec_decoder=mucodec_decoder,
                    )

                    # NOTE(review): wav_path/json_path assume save_outputs
                    # writes exactly these filenames — verify they stay in sync.
                    checkpoint_record["results"].append(
                        {
                            "sample_idx": sample_idx,
                            "song_id": sample.song_id,
                            "generated_audio_tokens": sampled_count,
                            "wav_path": str(wav_dir / f"{prefix}.wav"),
                            "json_path": str(json_dir / f"{prefix}.chord_segment.json"),
                        }
                    )
        except Exception as exc:
            # Record the failure but keep iterating over remaining checkpoints.
            checkpoint_record["status"] = "error"
            checkpoint_record["error"] = str(exc)
            print(f"[ERROR] checkpoint {checkpoint_path}: {exc!r}")
            traceback.print_exc()

        summary.append(checkpoint_record)

        # Free GPU memory before loading the next checkpoint.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    summary_path = Path(args.summary_json)
    summary_path.parent.mkdir(parents=True, exist_ok=True)
    with open(summary_path, "w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)

    print(f"\nSaved summary to: {summary_path}")
|
|
|
|
# Script entry point: only run inference when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|