| import argparse |
| import json |
| import re |
| from pathlib import Path |
|
|
| import torch |
| from jsonschema import Draft7Validator |
| from peft import AutoPeftModelForCausalLM |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Fixed instruction header prepended to every prompt. It steers the model to
# emit bare Spatial9Scene JSON only — no markdown fences, no explanations —
# which is what extract_first_json() downstream relies on.
SYSTEM_PREFIX = (
    "You are GravityLLM, a Spatial9 scene generation model. "
    "Given music constraints and stem features, output ONLY valid Spatial9Scene JSON. "
    "Do not return markdown. Do not explain your answer.\n\n"
)
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the CLI for a single GravityLLM inference run.

    Returns:
        argparse.Namespace with model/run settings: ``model_dir`` (required),
        ``input_json`` (required), ``schema_path``, ``output_json``,
        sampling knobs (``max_new_tokens``, ``temperature``, ``top_p``),
        and the ``validate`` flag.
    """
    cli = argparse.ArgumentParser(description="Run GravityLLM inference on a Spatial9 constraint payload.")
    cli.add_argument("--model_dir", type=str, required=True, help="Path or Hub repo id for trained model or adapter.")
    cli.add_argument("--input_json", type=Path, required=True)
    cli.add_argument("--schema_path", type=Path, default=Path("schemas/scene.schema.json"))
    cli.add_argument("--output_json", type=Path, default=None)
    cli.add_argument("--max_new_tokens", type=int, default=900)
    cli.add_argument("--temperature", type=float, default=0.35)
    cli.add_argument("--top_p", type=float, default=0.9)
    cli.add_argument("--validate", action="store_true")
    return cli.parse_args()
|
|
|
|
def load_model_and_tokenizer(model_dir: str):
    """Load the tokenizer and model (PEFT adapter or full checkpoint).

    First tries to load ``model_dir`` as a PEFT adapter; if that raises for
    any reason, falls back to loading it as a plain causal-LM checkpoint.
    On CUDA machines the model is loaded in bfloat16 with an automatic
    device map; on CPU both are left at library defaults.

    Returns:
        (model, tokenizer) — model is already in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=True, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Many causal LMs ship without a pad token; reuse EOS for batching/generate.
        tokenizer.pad_token = tokenizer.eos_token

    has_cuda = torch.cuda.is_available()
    load_kwargs = {
        "torch_dtype": torch.bfloat16 if has_cuda else None,
        "device_map": "auto" if has_cuda else None,
        "trust_remote_code": True,
    }
    try:
        # Adapter-style checkpoint (LoRA etc.) takes priority.
        model = AutoPeftModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    except Exception:
        # Not a PEFT directory (or adapter load failed) — treat as a full model.
        model = AutoModelForCausalLM.from_pretrained(model_dir, **load_kwargs)
    model.eval()
    return model, tokenizer
|
|
|
|
def extract_first_json(text: str) -> str:
    """Extract the JSON-object-looking span from generated text.

    Greedily takes everything from the first ``{`` to the last ``}`` (so a
    complete nested object survives intact). If no braces are found, the
    whole stripped text is returned unchanged.
    """
    found = re.search(r"\{.*\}", text, flags=re.DOTALL)
    if found is None:
        return text.strip()
    return found.group(0).strip()
|
|
|
|
def validate_output(schema_path: Path, output_text: str):
    """Parse model output and validate it against the Spatial9Scene schema.

    Args:
        schema_path: Path to a Draft-7 JSON schema file.
        output_text: Raw JSON text produced by the model.

    Returns:
        (data, errors) — the parsed payload and a list of jsonschema
        ValidationError objects, sorted by JSON path (empty when valid).

    Raises:
        json.JSONDecodeError: if either file or output is not valid JSON.
    """
    scene_schema = json.loads(schema_path.read_text(encoding="utf-8"))
    parsed = json.loads(output_text)
    checker = Draft7Validator(scene_schema)
    sorted_errors = sorted(checker.iter_errors(parsed), key=lambda err: list(err.path))
    return parsed, sorted_errors
|
|
|
|
def main() -> None:
    """Run one GravityLLM inference pass: load payload, generate, extract JSON.

    Reads the constraint payload from ``--input_json``, samples a completion,
    extracts the first JSON object from it, optionally validates it against
    the scene schema (``--validate``), optionally writes it to
    ``--output_json``, and always prints the JSON to stdout.
    """
    args = parse_args()
    payload = json.loads(args.input_json.read_text(encoding="utf-8"))

    model, tokenizer = load_model_and_tokenizer(args.model_dir)
    prompt = SYSTEM_PREFIX + "INPUT:\n" + json.dumps(payload, ensure_ascii=False, indent=2) + "\n\nOUTPUT:\n"

    inputs = tokenizer(prompt, return_tensors="pt")
    if torch.cuda.is_available():
        # device_map="auto" may shard the model; model.device is the entry device.
        inputs = {k: v.to(model.device) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=args.max_new_tokens,
            do_sample=True,
            temperature=args.temperature,
            top_p=args.top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )

    # Strip the prompt by token count, not by decoded-string length:
    # decode(encode(prompt)) is not guaranteed to round-trip byte-for-byte
    # (normalization, special-token spacing), so slicing the decoded text by
    # len(decoded_prompt) could cut into or duplicate part of the completion.
    prompt_token_count = inputs["input_ids"].shape[1]
    raw_completion = tokenizer.decode(
        outputs[0][prompt_token_count:], skip_special_tokens=True
    ).strip()
    json_text = extract_first_json(raw_completion)

    if args.validate:
        try:
            _, errors = validate_output(args.schema_path, json_text)
            if errors:
                print("Validation: INVALID")
                for err in errors[:20]:
                    path = ".".join(str(p) for p in err.path)
                    print(f"- {path}: {err.message}")
            else:
                print("Validation: VALID")
        except Exception as exc:
            # Best-effort validation: a bad schema file or unparseable output
            # is reported but does not abort writing/printing the result.
            print(f"Validation failed: {exc}")

    if args.output_json:
        args.output_json.parent.mkdir(parents=True, exist_ok=True)
        args.output_json.write_text(json_text + "\n", encoding="utf-8")
        print(f"Wrote output to {args.output_json}")

    print(json_text)
|
|
|
|
# Script entry point: run a single inference pass when executed directly.
if __name__ == "__main__":
    main()
|
|