"""Post-training inference validation for adapter or merged model artifacts."""

from __future__ import annotations

import argparse
import json
import re
import sys
import time
from pathlib import Path
from typing import Any

# Make the repository root importable so the `app.*` packages resolve when
# this script is executed directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.env.env_core import PolyGuardEnv
from app.common.normalization import clamp_reward


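# Typical invocation from the repository root (the script filename shown is
# illustrative, not confirmed by the repo):
#   python scripts/postsave_inference.py --samples 5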
def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Validate inference from saved adapter/merged artifacts.")
    parser.add_argument("--merged-model", default="checkpoints/merged")
    parser.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    parser.add_argument("--base-model", default="")
    parser.add_argument("--prompts", default="data/processed/training_corpus_grpo_prompts.jsonl")
    parser.add_argument("--samples", type=int, default=3)
    parser.add_argument("--output", default="outputs/reports/postsave_inference.json")
    return parser.parse_args()


def _load_prompt_rows(path: Path, limit: int) -> list[dict[str, Any]]:
    """Read up to `limit` JSON objects from a JSONL file, skipping blank or malformed lines."""
    if not path.exists():
        return []
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(payload, dict):
                rows.append(payload)
            if len(rows) >= limit:
                break
    return rows


def _prompt_to_text(row: dict[str, Any]) -> str:
    """Render a prompt row as a compact JSON instruction string for the model."""
    prompt = row.get("prompt", {}) if isinstance(row.get("prompt"), dict) else {}
    candidates = prompt.get("candidates", prompt.get("candidate_set", []))
    if not isinstance(candidates, list):
        candidates = []
    candidate_ids = [
        str(item.get("candidate_id"))
        for item in candidates
        if isinstance(item, dict) and item.get("candidate_id")
    ]
    summary = prompt.get("patient_summary")
    summary = summary if isinstance(summary, dict) else {}
    text = {
        "instruction": "Choose one candidate_id and justify briefly.",
        "patient_id": prompt.get("patient_id", summary.get("patient_id", "unknown")),
        "candidate_ids": candidate_ids,
        "format": "candidate_id=<cand_xx>; rationale=<text>",
    }
    return json.dumps(text, ensure_ascii=True)


def _discover_base_model(adapter_dir: Path) -> str:
    """Read base_model_name_or_path from the adapter's PEFT config, if present."""
    cfg = adapter_dir / "adapter_config.json"
    if not cfg.exists():
        return ""
    try:
        payload = json.loads(cfg.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return ""
    value = payload.get("base_model_name_or_path")
    return value if isinstance(value, str) else ""


def _load_model(
    merged_model: Path,
    adapter_dir: Path,
    base_model_arg: str,
):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Prefer a fully merged checkpoint: it loads standalone, with no peft dependency.
    if merged_model.exists() and (merged_model / "config.json").exists():
        tokenizer = AutoTokenizer.from_pretrained(str(merged_model))
        model = AutoModelForCausalLM.from_pretrained(
            str(merged_model),
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
        )
        return model, tokenizer, "merged"

    if not adapter_dir.exists():
        raise FileNotFoundError(f"adapter_dir_not_found:{adapter_dir}")

    from peft import PeftModel

    # Otherwise load the base model and apply the saved PEFT adapter on top.
    base_model = base_model_arg.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise RuntimeError("missing_base_model_for_adapter")

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(base, str(adapter_dir))
    return model, tokenizer, "adapter"


def _fallback_completion(row: dict[str, Any]) -> tuple[str, str | None]:
    """Build a deterministic completion for when no model could be loaded."""
    prompt = row.get("prompt", {}) if isinstance(row.get("prompt"), dict) else {}
    candidates = prompt.get("candidates", prompt.get("candidate_set", []))
    if not isinstance(candidates, list):
        candidates = []
    candidate_ids = [
        str(item.get("candidate_id"))
        for item in candidates
        if isinstance(item, dict) and item.get("candidate_id")
    ]
    candidate_id = candidate_ids[0] if candidate_ids else None
    completion = (
        f"candidate_id={candidate_id}; rationale=fallback_policy_artifact"
        if candidate_id
        else "candidate_id=cand_01; rationale=fallback_policy_artifact"
    )
    return completion, candidate_id


def _extract_candidate_id(text: str) -> str | None:
    """Return the first cand_<n> token in the text (case-insensitive), if any."""
    match = re.search(r"cand_\d+", text.lower())
    return match.group(0) if match else None


def main() -> None:
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    merged_model = (root / args.merged_model).resolve()
    adapter_dir = (root / args.adapter_dir).resolve()
    prompts_path = (root / args.prompts).resolve()

    rows = _load_prompt_rows(prompts_path, limit=max(1, args.samples))
    if not rows:
        raise SystemExit(f"no_prompts_loaded:{prompts_path}")

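    # Try the real artifacts first; the recorded fallback policy is only an
    # acceptable substitute when a model genuinely cannot be loaded.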
    fallback_policy_file = (root / "checkpoints" / "sft_policy_fallback.json").resolve()
    model = None
    tokenizer = None
    model_source = "fallback_policy"
    model_load_error = ""
    try:
        model, tokenizer, model_source = _load_model(
            merged_model=merged_model,
            adapter_dir=adapter_dir,
            base_model_arg=args.base_model,
        )
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as exc:
        model_load_error = str(exc)
        if not fallback_policy_file.exists():
            raise

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if model is not None:
        model = model.to(device)
        model.eval()

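    # Roll each prompt through the environment once and score the action that
    # the model (or the fallback policy) selects.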
    env = PolyGuardEnv()
    results: list[dict[str, Any]] = []
    for idx, row in enumerate(rows):
        env.reset(seed=17_000 + idx, difficulty="medium")
        prompt_text = _prompt_to_text(row)
        started = time.perf_counter()

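        # Greedy decode when a model is available; otherwise use the
        # deterministic fallback completion.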
        if model is not None and tokenizer is not None:
            encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
            encoded = {key: value.to(device) for key, value in encoded.items()}
            with torch.no_grad():
                generated = model.generate(
                    **encoded,
                    max_new_tokens=80,
                    do_sample=False,
                    eos_token_id=tokenizer.eos_token_id,
                )
            decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
            # Drop the echoed prompt when the decoder reproduces it verbatim.
            completion = decoded[len(prompt_text):].strip() if decoded.startswith(prompt_text) else decoded
            candidate_id = _extract_candidate_id(completion)
        else:
            completion, candidate_id = _fallback_completion(row)
        latency_seconds = time.perf_counter() - started

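        # Map the extracted candidate_id onto the environment's action set,
        # preferring legal actions, then any known action, then the first
        # legal action as a last resort.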
        all_actions = env.get_candidate_actions()
        legal_actions = env.get_legal_actions()
        by_id_all = {str(item.get("candidate_id", "")).lower(): item for item in all_actions}
        by_id_legal = {str(item.get("candidate_id", "")).lower(): item for item in legal_actions}
        action = by_id_legal.get(str(candidate_id or "").lower())
        if action is None:
            action = by_id_all.get(str(candidate_id or "").lower())
        if action is None and legal_actions:
            action = legal_actions[0]

        if action is None:
            results.append(
                {
                    "idx": idx,
                    "prompt": prompt_text,
                    "completion": completion,
                    "candidate_id": candidate_id,
                    "selected_candidate": None,
                    "env_reward": 0.001,  # near-zero placeholder; no action was executed
                    "latency_seconds": round(latency_seconds, 3),
                    "valid": False,
                    "reason": "no_action_available",
                }
            )
            continue

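        # Execute the chosen action and record the environment's verdict.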
        _, reward, done, info = env.step(action)
        results.append(
            {
                "idx": idx,
                "prompt": prompt_text,
                "completion": completion,
                "candidate_id": candidate_id,
                "selected_candidate": action.get("candidate_id"),
                "env_reward": clamp_reward(float(reward)),
                "latency_seconds": round(latency_seconds, 3),
                "done": bool(done),
                "valid": bool(info.get("safety_report", {}).get("legal", False)),
                "termination_reason": info.get("termination_reason"),
            }
        )

    valid_rate = sum(1.0 for entry in results if entry.get("valid")) / len(results)
    avg_reward = clamp_reward(sum(float(entry.get("env_reward", 0.0)) for entry in results) / len(results))
    avg_latency_seconds = round(
        sum(float(entry.get("latency_seconds", 0.0)) for entry in results) / len(results),
        3,
    )

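    # Aggregate summary plus per-sample detail, written as the report artifact.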
    payload = {
        "status": "ok",
        "model_source": model_source,
        "model_load_error": model_load_error,
        "samples": len(results),
        "valid_rate": round(valid_rate, 3),
        "avg_env_reward": avg_reward,
        "avg_latency_seconds": avg_latency_seconds,
        "results": results,
    }

    output_path = root / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
    print("postsave_inference_ok")


if __name__ == "__main__":
    main()