Spaces:
Sleeping
Sleeping
| """Evaluate base or MLX-adapted Qwen models on the local vulnops environment.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import re | |
| import sys | |
| from pathlib import Path | |
| from typing import Dict, List | |
| ROOT = Path(__file__).resolve().parents[1] | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from mlx_lm import generate, load | |
| from mlx_lm.sample_utils import make_sampler | |
| from models import VulnTriageAction | |
| from server.cases import TASK_ORDER | |
| from server.vuln_triage_env_environment import VulnTriageEnvironment | |
| from training_utils import render_prompt | |
# Matches a complete <think>...</think> block (case-insensitive, spanning
# newlines) so it can be removed before scanning the output for JSON.
THINK_BLOCK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)


def extract_last_json_object(text: str) -> str | None:
    """Return the last balanced top-level ``{...}`` span found in *text*.

    Complete ``<think>...</think>`` blocks are stripped first. The remaining
    text is scanned once, tracking brace depth and double-quoted strings so
    that braces inside string literals (including after escaped quotes) do
    not affect nesting.

    Args:
        text: Raw text that may contain one or more JSON objects mixed with
            other content.

    Returns:
        The substring of the last brace-balanced top-level object, or
        ``None`` when no such span exists. The span is not validated as
        JSON; callers are expected to parse it themselves.
    """
    cleaned = THINK_BLOCK_RE.sub("", text).strip()
    if "{" not in cleaned:
        return None

    depth = 0
    in_string = False
    escape = False
    last_candidate = None
    candidate_start = None
    for index, ch in enumerate(cleaned):
        if in_string:
            if escape:
                # The previous backslash consumed this character.
                escape = False
            elif ch == "\\":
                escape = True
            elif ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
        elif ch == "{":
            if depth == 0:
                candidate_start = index
            depth += 1
        elif ch == "}" and depth > 0:
            # Only close an object that is actually open: a stray "}" in the
            # surrounding text must not push depth below zero, otherwise
            # every later top-level object would be misaligned and skipped.
            depth -= 1
            if depth == 0:
                last_candidate = cleaned[candidate_start : index + 1]
    return last_candidate
def parse_action_output(text: str) -> Dict[str, object] | None:
    """Turn raw model output into a validated action payload.

    The last JSON object in *text* is located with
    ``extract_last_json_object``, decoded, and validated through
    ``VulnTriageAction``.

    Args:
        text: Raw text produced by the model.

    Returns:
        The validated action dumped to a dict with ``None`` fields excluded,
        or ``None`` when no object is found or decoding/validation fails.
    """
    raw_json = extract_last_json_object(text)
    if raw_json is None:
        return None
    try:
        validated = VulnTriageAction.model_validate(json.loads(raw_json))
    except Exception:
        # Best-effort parse: any decoding or validation failure is reported
        # to the caller as "no usable action".
        return None
    else:
        return validated.model_dump(exclude_none=True)
def next_action(model, tokenizer, observation: Dict[str, object]) -> Dict[str, object]:
    """Ask the model for the next action given the current observation.

    A prompt is rendered from *observation*, a completion of at most 192
    tokens is generated with a ``temp=0.0`` sampler, and the completion is
    parsed with ``parse_action_output``.

    Args:
        model: Model object passed through to ``generate``.
        tokenizer: Tokenizer object passed through to ``generate``.
        observation: Current environment observation as a plain dict.

    Returns:
        The parsed action payload, or a ``submit_triage`` fallback whose
        rationale quotes the first 120 characters of the unparseable output.
    """
    rendered = render_prompt(
        observation=observation,
        prompt_variant="Return only the best next action in JSON.",
    )
    sampler = make_sampler(temp=0.0)
    completion = generate(
        model,
        tokenizer,
        prompt=rendered,
        verbose=False,
        max_tokens=192,
        sampler=sampler,
    )

    parsed = parse_action_output(completion)
    if parsed is not None:
        return parsed

    # The model produced nothing usable: fall back to submitting so the
    # episode still advances instead of aborting the evaluation.
    fallback: Dict[str, object] = {
        "action_type": "submit_triage",
        "rationale": f"Fallback because model output could not be parsed: {completion[:120]}",
    }
    return fallback
def run_episode(model, tokenizer, task_id: str) -> Dict[str, object]:
    """Play one task to completion and summarize the outcome.

    A fresh ``VulnTriageEnvironment`` is reset for *task_id*; the model is
    then queried for actions until the observation reports ``done``.

    Args:
        model: Model object forwarded to ``next_action``.
        tokenizer: Tokenizer object forwarded to ``next_action``.
        task_id: Identifier of the task to reset the environment with.

    Returns:
        A dict with the task id, the difficulty and score breakdown from the
        final observation, the final score as a float (0.0 when missing or
        falsy), the number of steps taken, and the list of actions taken.
    """
    environment = VulnTriageEnvironment()
    state = environment.reset(task_id=task_id).model_dump()
    history: List[Dict[str, object]] = []

    while not state["done"]:
        proposed = next_action(model, tokenizer, state)
        # Re-validate so the environment always receives a typed action,
        # including when the payload is the hand-built fallback.
        validated = VulnTriageAction.model_validate(proposed)
        history.append(validated.model_dump(exclude_none=True))
        state = environment.step(validated).model_dump()

    difficulty = state["difficulty"]
    final_score = float(state.get("final_score") or 0.0)
    breakdown = state["score_breakdown"]
    return {
        "task_id": task_id,
        "difficulty": difficulty,
        "final_score": final_score,
        "score_breakdown": breakdown,
        "steps_used": len(history),
        "actions": history,
    }
def main() -> None:
    """Evaluate a model on every task in ``TASK_ORDER`` and report the scores.

    Command-line options:
        --model: Model name or path passed to ``load`` (default
            ``Qwen/Qwen3.5-4B``).
        --adapter-path: Optional adapter path passed to ``load``.
        --output-json: Optional file to write the report to; a relative path
            is resolved against ``ROOT`` rather than the working directory.

    The report (model, adapter path, average score rounded to 4 places, and
    per-episode results) is always printed to stdout as sorted, indented JSON.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default="Qwen/Qwen3.5-4B")
    parser.add_argument("--adapter-path")
    parser.add_argument("--output-json")
    args = parser.parse_args()

    model, tokenizer = load(args.model, adapter_path=args.adapter_path)
    episodes = [run_episode(model, tokenizer, task_id) for task_id in TASK_ORDER]

    # Guard the mean: with no tasks configured, report 0.0 instead of
    # raising ZeroDivisionError.
    if episodes:
        average_score = round(sum(item["final_score"] for item in episodes) / len(episodes), 4)
    else:
        average_score = 0.0

    payload = {
        "model": args.model,
        "adapter_path": args.adapter_path,
        "average_score": average_score,
        "episodes": episodes,
    }
    # Serialize once; the same text goes to the file and to stdout.
    serialized = json.dumps(payload, indent=2, sort_keys=True)

    if args.output_json:
        out = Path(args.output_json)
        if not out.is_absolute():
            out = (ROOT / out).resolve()
        out.parent.mkdir(parents=True, exist_ok=True)
        out.write_text(serialized + "\n", encoding="utf-8")
    print(serialized)


if __name__ == "__main__":
    main()