#!/usr/bin/env python3
"""Post-training inference validation for adapter or merged model artifacts."""
from __future__ import annotations

import argparse
import json
import re
import sys
import time
from pathlib import Path
from typing import Any

# Make the repo root importable so `app.*` resolves when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from app.common.normalization import clamp_reward  # noqa: E402
from app.env.env_core import PolyGuardEnv  # noqa: E402


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Validate inference from saved adapter/merged artifacts.")
    parser.add_argument("--merged-model", default="checkpoints/merged")
    parser.add_argument("--adapter-dir", default="checkpoints/sft_adapter")
    parser.add_argument("--base-model", default="")
    parser.add_argument("--prompts", default="data/processed/training_corpus_grpo_prompts.jsonl")
    parser.add_argument("--samples", type=int, default=3)
    parser.add_argument("--output", default="outputs/reports/postsave_inference.json")
    return parser.parse_args()
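
# To validate an adapter without a merged checkpoint, pass the base weights
# explicitly (the model name below is a placeholder) or rely on the value
# recorded in adapter_config.json (see _discover_base_model):
#   python scripts/test_inference_postsave.py --base-model <org/base-model-name>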


def _load_prompt_rows(path: Path, limit: int) -> list[dict[str, Any]]:
    """Read up to `limit` JSON-object rows from a JSONL file, skipping malformed lines."""
    if not path.exists():
        return []
    rows: list[dict[str, Any]] = []
    with path.open("r", encoding="utf-8") as handle:
        for line in handle:
            line = line.strip()
            if not line:
                continue
            try:
                payload = json.loads(line)
            except json.JSONDecodeError:
                continue
            if isinstance(payload, dict):
                rows.append(payload)
            if len(rows) >= limit:
                break
    return rows
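
# Expected prompt-row shape (only the keys read below are assumed; the
# upstream corpus may carry additional fields):
#   {"prompt": {"patient_id": "...",
#               "candidates": [{"candidate_id": "cand_01", ...}, ...]}}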


def _candidate_ids(prompt: dict[str, Any]) -> list[str]:
    """Collect candidate ids from either of the two prompt schema variants."""
    candidates = prompt.get("candidates", prompt.get("candidate_set", []))
    if not isinstance(candidates, list):
        return []
    return [
        str(item.get("candidate_id"))
        for item in candidates
        if isinstance(item, dict) and item.get("candidate_id")
    ]


def _prompt_to_text(row: dict[str, Any]) -> str:
    prompt = row.get("prompt", {}) if isinstance(row.get("prompt"), dict) else {}
    summary = prompt.get("patient_summary")
    summary = summary if isinstance(summary, dict) else {}
    text = {
        "instruction": "Choose one candidate_id and justify briefly.",
        "patient_id": prompt.get("patient_id", summary.get("patient_id", "unknown")),
        "candidate_ids": _candidate_ids(prompt),
        "format": "candidate_id=<cand_xx>; rationale=<text>",
    }
    return json.dumps(text, ensure_ascii=True)
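
# For illustration, the serialized prompt looks like (values illustrative):
#   {"instruction": "Choose one candidate_id and justify briefly.",
#    "patient_id": "p_0001", "candidate_ids": ["cand_01", "cand_02"],
#    "format": "candidate_id=<cand_xx>; rationale=<text>"}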


def _discover_base_model(adapter_dir: Path) -> str:
    """Recover the base model id that PEFT records in adapter_config.json at save time."""
    cfg = adapter_dir / "adapter_config.json"
    if not cfg.exists():
        return ""
    try:
        payload = json.loads(cfg.read_text(encoding="utf-8"))
    except json.JSONDecodeError:
        return ""
    value = payload.get("base_model_name_or_path")
    return str(value) if isinstance(value, str) else ""


def _load_model(
    merged_model: Path,
    adapter_dir: Path,
    base_model_arg: str,
):
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Prefer a fully merged checkpoint when one exists on disk.
    if merged_model.exists() and (merged_model / "config.json").exists():
        tokenizer = AutoTokenizer.from_pretrained(str(merged_model))
        model = AutoModelForCausalLM.from_pretrained(
            str(merged_model),
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            low_cpu_mem_usage=True,
        )
        return model, tokenizer, "merged"

    # Otherwise load the base model and attach the saved PEFT adapter.
    if not adapter_dir.exists():
        raise FileNotFoundError(f"adapter_dir_not_found:{adapter_dir}")
    from peft import PeftModel

    base_model = base_model_arg.strip() or _discover_base_model(adapter_dir)
    if not base_model:
        raise RuntimeError("missing_base_model_for_adapter")
    tokenizer = AutoTokenizer.from_pretrained(base_model)
    base = AutoModelForCausalLM.from_pretrained(
        base_model,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(base, str(adapter_dir))
    return model, tokenizer, "adapter"
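
# For reference, a merged checkpoint like the one probed above is typically
# produced from an adapter roughly as follows (sketch, not executed here):
#   merged = PeftModel.from_pretrained(base, adapter_dir).merge_and_unload()
#   merged.save_pretrained("checkpoints/merged")
#   tokenizer.save_pretrained("checkpoints/merged")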


def _fallback_completion(row: dict[str, Any]) -> tuple[str, str | None]:
    prompt = row.get("prompt", {}) if isinstance(row.get("prompt"), dict) else {}
    candidate_ids = _candidate_ids(prompt)
    candidate_id = candidate_ids[0] if candidate_ids else None
    completion = (
        f"candidate_id={candidate_id}; rationale=fallback_policy_artifact"
        if candidate_id
        else "candidate_id=cand_01; rationale=fallback_policy_artifact"
    )
    return completion, candidate_id


def _extract_candidate_id(text: str) -> str | None:
    """Lenient id match, e.g. 'Candidate_ID=CAND_07; rationale=ok' -> 'cand_07'."""
    match = re.search(r"cand_\d+", text.lower())
    return match.group(0) if match else None


def main() -> None:
    args = parse_args()
    root = Path(__file__).resolve().parents[1]
    merged_model = (root / args.merged_model).resolve()
    adapter_dir = (root / args.adapter_dir).resolve()
    prompts_path = (root / args.prompts).resolve()
    rows = _load_prompt_rows(prompts_path, limit=max(1, args.samples))
    if not rows:
        raise SystemExit(f"no_prompts_loaded:{prompts_path}")

    # A model-load failure is tolerated only when the fallback policy artifact
    # from SFT exists; otherwise the original exception propagates.
    fallback_policy_file = (root / "checkpoints" / "sft_policy_fallback.json").resolve()
    model = None
    tokenizer = None
    model_source = "fallback_policy"
    model_load_error = ""
    try:
        model, tokenizer, model_source = _load_model(
            merged_model=merged_model,
            adapter_dir=adapter_dir,
            base_model_arg=args.base_model,
        )
        # Many causal-LM tokenizers ship without a pad token; reuse EOS.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    except Exception as exc:  # noqa: BLE001
        model_load_error = str(exc)
        if not fallback_policy_file.exists():
            raise

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if model is not None:
        model = model.to(device)
        model.eval()

    env = PolyGuardEnv()
    results: list[dict[str, Any]] = []
    for idx, row in enumerate(rows):
        env.reset(seed=17_000 + idx, difficulty="medium")
        prompt_text = _prompt_to_text(row)
        started = time.perf_counter()
        if model is not None and tokenizer is not None:
            encoded = tokenizer(prompt_text, return_tensors="pt", truncation=True, max_length=512)
            encoded = {key: value.to(device) for key, value in encoded.items()}
            with torch.no_grad():
                # Greedy decoding; `temperature` is omitted because it is
                # ignored (with a warning) when do_sample=False.
                generated = model.generate(
                    **encoded,
                    max_new_tokens=80,
                    do_sample=False,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                )
            decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
            # Strip the echoed prompt so only the model's continuation remains.
            completion = decoded[len(prompt_text):].strip() if decoded.startswith(prompt_text) else decoded
            candidate_id = _extract_candidate_id(completion)
        else:
            completion, candidate_id = _fallback_completion(row)
        latency_seconds = time.perf_counter() - started

        # Resolve the model's pick against the environment: prefer an exact
        # legal match, then any known candidate, then the first legal action.
        all_actions = env.get_candidate_actions()
        legal_actions = env.get_legal_actions()
        by_id_all = {str(item.get("candidate_id", "")).lower(): item for item in all_actions}
        by_id_legal = {str(item.get("candidate_id", "")).lower(): item for item in legal_actions}
        action = by_id_legal.get(str(candidate_id or "").lower())
        if action is None:
            action = by_id_all.get(str(candidate_id or "").lower())
        if action is None and legal_actions:
            action = legal_actions[0]
        if action is None:
            results.append(
                {
                    "idx": idx,
                    "prompt": prompt_text,
                    "completion": completion,
                    "candidate_id": candidate_id,
                    "selected_candidate": None,
                    "env_reward": 0.001,
                    "latency_seconds": round(latency_seconds, 3),
                    "valid": False,
                    "reason": "no_action_available",
                }
            )
            continue

        _, reward, done, info = env.step(action)
        results.append(
            {
                "idx": idx,
                "prompt": prompt_text,
                "completion": completion,
                "candidate_id": candidate_id,
                "selected_candidate": action.get("candidate_id"),
                "env_reward": clamp_reward(float(reward)),
                "latency_seconds": round(latency_seconds, 3),
                "done": bool(done),
                "valid": bool(info.get("safety_report", {}).get("legal", False)),
                "termination_reason": info.get("termination_reason"),
            }
        )

    valid_rate = sum(1.0 for item in results if item.get("valid")) / len(results)
    avg_reward = clamp_reward(sum(float(item.get("env_reward", 0.0)) for item in results) / len(results))
    avg_latency_seconds = round(
        sum(float(item.get("latency_seconds", 0.0)) for item in results) / len(results),
        3,
    )
    payload = {
        "status": "ok",
        "model_source": model_source,
        "model_load_error": model_load_error,
        "samples": len(results),
        "valid_rate": round(valid_rate, 3),
        "avg_env_reward": avg_reward,
        "avg_latency_seconds": avg_latency_seconds,
        "results": results,
    }
    output_path = root / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8")
    print("postsave_inference_ok")

if __name__ == "__main__":
    main()