import argparse
import json
import math
import re
import sys
from pathlib import Path
from typing import Dict, Tuple

import torch
from datasets import load_dataset
from jsonschema import Draft7Validator
from peft import AutoPeftModelForCausalLM
from transformers import AutoModelForCausalLM, AutoTokenizer
|
|
# Instruction header placed in front of every evaluation prompt by format_prompt.
# It ends with a blank line so the task text starts on its own paragraph.
SYSTEM_PREFIX = "".join(
    (
        "You are GravityLLM, a Spatial9 scene generation model. ",
        "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. ",
        "Do not return markdown. Do not explain your answer.\n\n",
    )
)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the evaluation command-line interface and parse ``sys.argv``.

    Only ``--model_dir`` is mandatory; every other option falls back to the
    default given in ``option_table``. Path-like options are converted to
    ``pathlib.Path`` objects.
    """
    # (flag, add_argument keyword arguments) in the order they appear in --help.
    option_table = (
        ("--model_dir", dict(type=str, required=True)),
        ("--data_file", dict(type=str, default="data/valid.jsonl")),
        ("--schema_path", dict(type=Path, default=Path("schemas/scene.schema.json"))),
        ("--max_new_tokens", dict(type=int, default=900)),
        ("--temperature", dict(type=float, default=0.2)),
        ("--top_p", dict(type=float, default=0.9)),
        ("--limit", dict(type=int, default=0, help="0 means evaluate all rows.")),
        ("--report_path", dict(type=Path, default=Path("reports/eval_report.json"))),
    )

    cli = argparse.ArgumentParser(description="Evaluate GravityLLM outputs on a JSONL validation set.")
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
|
|
|
|
def load_model_and_tokenizer(model_dir: str):
    """Load the evaluation model and its tokenizer from ``model_dir``.

    The checkpoint is first loaded as a PEFT adapter
    (``AutoPeftModelForCausalLM``). If that raises, the reason is reported on
    stderr and the directory is loaded as a plain causal-LM checkpoint
    instead. When CUDA is available the weights are loaded in bfloat16 with
    ``device_map="auto"``; otherwise library defaults are used.

    Args:
        model_dir: Local path or hub id of the checkpoint.

    Returns:
        ``(model, tokenizer)``. The model is switched to eval mode, and the
        tokenizer's pad token falls back to its EOS token when it has none.

    Note:
        ``trust_remote_code=True`` executes Python code shipped with the
        checkpoint; only point this at checkpoints you trust.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Ensure a pad token exists; reuse EOS when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token

    use_cuda = torch.cuda.is_available()
    # Shared by both loading paths so the two attempts cannot drift apart.
    load_kwargs = {
        "torch_dtype": torch.bfloat16 if use_cuda else None,
        "device_map": "auto" if use_cuda else None,
        "trust_remote_code": True,
    }

    try:
        model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    except Exception as peft_error:  # deliberate best-effort fallback, but say why
        print(
            f"[load_model_and_tokenizer] PEFT adapter load failed for {model_dir!r} "
            f"({type(peft_error).__name__}: {peft_error}); loading as a full model instead.",
            file=sys.stderr,
        )
        model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)

    model.eval()
    return model, tokenizer
|
|
|
|
def format_prompt(raw_prompt: str) -> str:
    """Wrap a dataset prompt in the GravityLLM instruction template.

    Surrounding whitespace is trimmed, and a leading ``GravityLLM:`` tag
    (matched case-insensitively) is removed. The result is
    ``SYSTEM_PREFIX + task text + "\\n\\nOUTPUT:\\n"``.
    """
    body = raw_prompt.strip()
    if body.lower().startswith("gravityllm:"):
        # Keep only what follows the first colon, i.e. the text after the tag.
        body = body.partition(":")[2].strip()
    return f"{SYSTEM_PREFIX}{body}\n\nOUTPUT:\n"
|
|
|
|
def extract_first_json(text: str) -> str:
    """Return the JSON object that starts at the first ``{`` in ``text``.

    When the text from the first ``{`` onward begins with a complete JSON
    object, exactly that object's source text is returned. Anything after it
    is dropped, including stray ``}`` characters or a second object, which
    would otherwise make the whole span unparseable.

    When it does not decode (truncated or malformed output), the function
    falls back to the span from the first ``{`` to the last ``}``. If there is
    no such span, it returns the whole stripped text. Malformed output is
    therefore still passed through rather than replaced by some nested
    fragment that happens to parse.
    """
    start = text.find("{")
    if start != -1:
        candidate = text[start:]
        try:
            _, length = json.JSONDecoder().raw_decode(candidate)
            return candidate[:length]
        except ValueError:
            # Not a complete object from the first brace; use the raw span below.
            pass
    match = re.search(r"\{.*\}", text, flags=re.DOTALL)
    return match.group(0).strip() if match else text.strip()
|
|
|
|
def validate_schema(schema, output_text: str) -> Tuple[bool, Dict]:
    """Parse ``output_text`` as JSON and check it against a Draft-7 JSON Schema.

    Args:
        schema: The JSON Schema document, already loaded into Python objects.
        output_text: Candidate JSON text produced by the model.

    Returns:
        ``(is_valid, data)``, where ``data`` is the parsed JSON value. The
        parsed data is returned even when validation fails.

    Raises:
        json.JSONDecodeError: If ``output_text`` is not valid JSON.
    """
    data = json.loads(output_text)
    # Only a pass/fail answer is needed, so stop at the first violation
    # instead of collecting and sorting every error.
    return Draft7Validator(schema).is_valid(data), data
|
|
|
|
def check_budget(input_payload: Dict, scene_payload: Dict) -> bool:
    """Return True when the scene respects the prompt's object budget.

    A missing or ``None`` ``max_objects`` means there is no budget. Otherwise
    the number of entries in the scene's ``objects`` list (empty if absent)
    must not exceed it.
    """
    limit = input_payload.get("max_objects")
    if limit is not None:
        object_count = len(scene_payload.get("objects", []))
        return object_count <= limit
    return True
|
|
|
|
def check_anchor_rules(input_payload: Dict, scene_payload: Dict, *, atol: float = 0.0) -> bool:
    """Return True when every ``anchor`` rule in the prompt is honoured by the scene.

    Each rule with ``type == "anchor"`` is checked against the scene object
    whose ``class`` equals the rule's ``track_class``. That object's
    ``az_deg``, ``el_deg`` and ``dist_m`` must match the rule's values when
    compared as floats. If several objects share a class, the last one in
    ``scene_payload["objects"]`` is checked. Rules of other types are ignored.

    Args:
        input_payload: Constraint payload parsed from the prompt; ``rules`` is read.
        scene_payload: Generated scene; ``objects`` is read.
        atol: Absolute tolerance for the coordinate comparison. The default
            ``0.0`` requires exact equality.

    Returns:
        False if any of the following holds; True otherwise:

        - an anchored class is absent from the scene;
        - its object lacks a coordinate or holds a non-numeric one;
        - a coordinate differs from the rule by more than ``atol``.

    Raises:
        KeyError, TypeError, ValueError: If an anchor rule itself is missing a
            coordinate or holds a non-numeric one (a malformed input row).
    """
    anchor_fields = ("az_deg", "el_deg", "dist_m")

    objects_by_class = {}
    for obj in scene_payload.get("objects") or []:
        # Generated scenes may be malformed: skip entries without a class instead of raising.
        if isinstance(obj, dict) and "class" in obj:
            objects_by_class[obj["class"]] = obj

    for rule in input_payload.get("rules", []):
        if rule.get("type") != "anchor":
            continue
        obj = objects_by_class.get(rule.get("track_class"))
        if obj is None:
            return False
        for field in anchor_fields:
            expected = float(rule[field])
            try:
                actual = float(obj[field])
            except (KeyError, TypeError, ValueError):
                # The model omitted the coordinate or emitted a non-number: rule not met.
                return False
            # With rel_tol=0 and abs_tol=0 this is exact equality (NaN never matches).
            if not math.isclose(actual, expected, rel_tol=0.0, abs_tol=atol):
                return False
    return True
|
|
|
|
def generate_scene(model, tokenizer, prompt_text: str, max_new_tokens: int, temperature: float, top_p: float) -> str:
    """Generate a completion for ``prompt_text`` and return its JSON portion.

    Args:
        model: Causal LM exposing ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        prompt_text: Fully formatted prompt.
        max_new_tokens: Cap on the number of generated tokens.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding (``do_sample=False``) instead of sampling.
        top_p: Nucleus-sampling probability mass; used only when sampling.

    Returns:
        The completion text (prompt excluded) after ``extract_first_json``.
    """
    encoded = tokenizer(prompt_text, return_tensors="pt")
    # Put the inputs on the model's device (a no-op on CPU).
    inputs = {name: tensor.to(model.device) for name, tensor in encoded.items()}
    prompt_length = inputs["input_ids"].shape[-1]

    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "eos_token_id": tokenizer.eos_token_id,
        "pad_token_id": tokenizer.pad_token_id,
    }
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Sampling requires a strictly positive temperature; treat <= 0 as greedy.
        generation_kwargs["do_sample"] = False

    with torch.no_grad():
        outputs = model.generate(**inputs, **generation_kwargs)

    # Decode only the new token ids. Slicing the decoded full sequence by the
    # length of the separately decoded prompt misaligns when detokenization
    # differs at the prompt/completion boundary.
    completion_ids = outputs[0][prompt_length:]
    raw_completion = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return extract_first_json(raw_completion)
|
|
|
|
def _extract_input_payload(raw_prompt: str) -> Dict:
    """Parse the JSON constraint payload that follows ``INPUT:\\n`` in a prompt."""
    _, separator, payload_text = raw_prompt.partition("INPUT:\n")
    if not separator:
        raise ValueError("prompt has no 'INPUT:' section to read constraints from")
    return json.loads(payload_text)


def _score_generation(schema: Dict, raw_prompt: str, generated: str) -> Tuple[Dict, bool]:
    """Score one generated scene and return ``(sample_report, parsed_as_json)``.

    The report always has ``prompt`` and ``generated``, plus ``schema_valid``
    once the text parses. For schema-valid scenes it also has ``budget_pass``
    and ``anchor_pass``. Any exception raised while scoring is stored under
    ``error`` as ``"<Type>: <message>"`` instead of aborting the run.
    """
    sample_report = {"prompt": raw_prompt, "generated": generated}
    parsed = False
    try:
        json.loads(generated)
        parsed = True
        valid, gen_scene = validate_schema(schema, generated)
        sample_report["schema_valid"] = valid
        if valid:
            input_payload = _extract_input_payload(raw_prompt)
            # Each check runs exactly once; main counts from these recorded results.
            sample_report["budget_pass"] = check_budget(input_payload, gen_scene)
            sample_report["anchor_pass"] = check_anchor_rules(input_payload, gen_scene)
    except Exception as exc:  # per-sample boundary: one bad output must not stop the run
        sample_report["error"] = f"{type(exc).__name__}: {exc}"
    return sample_report, parsed


def main() -> None:
    """Evaluate a model on a JSONL set; write the aggregate report and print it.

    Each row's ``prompt`` is formatted and a scene is generated for it. The
    output is scored for JSON parseability, schema validity, the object budget
    and anchor rules. Every rate is a fraction of all evaluated rows, and the
    report keeps the first 10 sample reports.
    """
    max_report_samples = 10

    args = parse_args()
    schema = json.loads(args.schema_path.read_text(encoding="utf-8"))
    ds = load_dataset("json", data_files=args.data_file, split="train")
    if args.limit > 0:
        ds = ds.select(range(min(args.limit, len(ds))))

    model, tokenizer = load_model_and_tokenizer(args.model_dir)

    total = len(ds)
    parse_ok = schema_ok = budget_ok = anchor_ok = 0
    samples = []

    for row in ds:
        prompt_text = format_prompt(row["prompt"])
        generated = generate_scene(model, tokenizer, prompt_text, args.max_new_tokens, args.temperature, args.top_p)

        sample_report, parsed = _score_generation(schema, row["prompt"], generated)
        if parsed:
            parse_ok += 1
        if sample_report.get("schema_valid"):
            schema_ok += 1
        if sample_report.get("budget_pass"):
            budget_ok += 1
        if sample_report.get("anchor_pass"):
            anchor_ok += 1

        # Only the first few samples are reported, so do not hold every output in memory.
        if len(samples) < max_report_samples:
            samples.append(sample_report)

    def rate(count: int) -> float:
        return round(count / total, 4) if total else 0.0

    report = {
        "examples": total,
        "json_parse_rate": rate(parse_ok),
        "schema_valid_rate": rate(schema_ok),
        "budget_pass_rate": rate(budget_ok),
        "anchor_pass_rate": rate(anchor_ok),
        "samples": samples,
    }

    args.report_path.parent.mkdir(parents=True, exist_ok=True)
    args.report_path.write_text(json.dumps(report, indent=2), encoding="utf-8")
    print(json.dumps(report, indent=2))
|
|
|
|
# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|