| """ |
| SepsisPilot β Pre-Submission Validation Script |
| Run this before submitting to verify OpenEnv spec compliance. |
| |
| Usage: python validate.py [--url http://localhost:7860] |
| """ |
|
|
| from __future__ import annotations |
| import argparse |
| import sys |
| import requests |
|
|
| PASS = "\033[92m[PASS]\033[0m" |
| FAIL = "\033[91m[FAIL]\033[0m" |
| WARN = "\033[93m[WARN]\033[0m" |
| INFO = "\033[94m[INFO]\033[0m" |
|
|
| errors = 0 |
|
|
|
|
| def check(label: str, condition: bool, msg: str = ""): |
| global errors |
| if condition: |
| print(f" {PASS} {label}") |
| else: |
| print(f" {FAIL} {label} {msg}") |
| errors += 1 |
|
|
|
|
| def section(title: str): |
| print(f"\n{'β'*50}\n {title}\n{'β'*50}") |
|
|
|
|
| def validate(base_url: str): |
| global errors |
| print(f"\n㪠SepsisPilot OpenEnv Validation\n Target: {base_url}\n") |
|
|
| |
| section("1. Health Check") |
| try: |
| r = requests.get(f"{base_url}/health", timeout=10) |
| check("GET /health returns 200", r.status_code == 200) |
| data = r.json() |
| check("Response contains 'status'", "status" in data) |
| check("Status is 'ok'", data.get("status") == "ok") |
| except Exception as e: |
| check("Server reachable", False, str(e)) |
| print("\n [ABORT] Server not reachable. Start the server first.\n") |
| sys.exit(1) |
|
|
| |
| section("2. Task Listing") |
| try: |
| r = requests.get(f"{base_url}/tasks", timeout=10) |
| check("GET /tasks returns 200", r.status_code == 200) |
| tasks = r.json() |
| check("Returns a list", isinstance(tasks, list)) |
| check("At least 3 tasks", len(tasks) >= 3) |
| task_names = [t["name"] for t in tasks] |
| check("mild_sepsis present", "mild_sepsis" in task_names) |
| check("septic_shock present", "septic_shock" in task_names) |
| check("severe_mods present", "severe_mods" in task_names) |
| for t in tasks: |
| check(f" Task '{t['name']}' has difficulty", "difficulty" in t) |
| check(f" Task '{t['name']}' has description", "description" in t) |
| check(f" Task '{t['name']}' has max_steps", "max_steps" in t) |
| except Exception as e: |
| check("Tasks endpoint works", False, str(e)) |
|
|
| |
| section("3. Episode Flow β mild_sepsis (Easy)") |
| _validate_episode(base_url, "mild_sepsis", max_steps=24) |
|
|
| |
| section("4. Episode Flow β septic_shock (Medium)") |
| _validate_episode(base_url, "septic_shock", max_steps=48) |
|
|
| |
| section("5. Episode Flow β severe_mods (Hard)") |
| _validate_episode(base_url, "severe_mods", max_steps=72) |
|
|
| |
| section("6. Grader Score Variance (anti-trivial check)") |
| scores = [] |
| actions_list = [ |
| [5, 5, 5, 1, 1, 1], |
| [0, 0, 0, 0, 0, 0], |
| [4, 4, 4, 4, 4, 4], |
| ] |
| for i, actions in enumerate(actions_list): |
| try: |
| r = requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 42}, timeout=10) |
| for a in actions: |
| r = requests.post(f"{base_url}/step", json={"action": a}, timeout=10) |
| if r.json().get("done"): |
| break |
| |
| while not r.json().get("done"): |
| r = requests.post(f"{base_url}/step", json={"action": 0}, timeout=10) |
| grade = requests.get(f"{base_url}/grade", timeout=10).json() |
| scores.append(grade["score"]) |
| except Exception as e: |
| scores.append(None) |
| print(f" {WARN} Strategy {i} failed: {e}") |
|
|
| valid_scores = [s for s in scores if s is not None] |
| check("Grader returns different scores for different strategies", |
| len(set(round(s, 2) for s in valid_scores)) > 1, |
| f"(scores: {[round(s,4) for s in valid_scores]})") |
| check("All scores in [0.0, 1.0]", |
| all(0.0 <= s <= 1.0 for s in valid_scores)) |
|
|
| |
| section("7. Reproducibility (same seed = same result)") |
| try: |
| scores_run1, scores_run2 = [], [] |
| for run_scores in (scores_run1, scores_run2): |
| requests.post(f"{base_url}/reset", json={"task": "mild_sepsis", "seed": 99}, timeout=10) |
| for _ in range(5): |
| r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10) |
| run_scores.append(round(r.json()["reward"], 4)) |
| if r.json()["done"]: |
| break |
| check("Reward sequences are identical across runs", scores_run1 == scores_run2, |
| f"\n run1={scores_run1}\n run2={scores_run2}") |
| except Exception as e: |
| check("Reproducibility check", False, str(e)) |
|
|
| |
| section("8. Error Handling") |
| try: |
| r = requests.post(f"{base_url}/step", json={"action": 99}, timeout=10) |
| check("Invalid action returns 4xx", r.status_code in (400, 422)) |
| except Exception as e: |
| check("Invalid action error handling", False, str(e)) |
|
|
| |
| print(f"\n{'β'*50}") |
| if errors == 0: |
| print(f" β
All checks passed. Ready for submission!") |
| else: |
| print(f" β {errors} check(s) failed. Fix before submitting.") |
| print(f"{'β'*50}\n") |
| sys.exit(0 if errors == 0 else 1) |
|
|
|
|
| def _validate_episode(base_url: str, task: str, max_steps: int): |
| """Run a short episode and verify all OpenEnv contracts.""" |
| try: |
| |
| r = requests.post(f"{base_url}/reset", json={"task": task, "seed": 42}, timeout=10) |
| check(f"POST /reset 200", r.status_code == 200) |
| state = r.json() |
| check("Reset returns vitals", "vitals" in state) |
| check("Reset returns step=0", state.get("step") == 0) |
| check("Reset returns done=False", state.get("done") == False) |
| check("Reset returns alive=True", state.get("alive") == True) |
|
|
| |
| r = requests.get(f"{base_url}/state", timeout=10) |
| check("GET /state 200", r.status_code == 200) |
|
|
| |
| r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10) |
| check("POST /step 200", r.status_code == 200) |
| result = r.json() |
| check("Step returns state", "state" in result) |
| check("Step returns reward (float)", isinstance(result.get("reward"), (int, float))) |
| check("Step returns done (bool)", isinstance(result.get("done"), bool)) |
| check("Step returns info (dict)", isinstance(result.get("info"), dict)) |
| check("Step increments step counter", result["state"]["step"] == 1) |
|
|
| |
| reward = result["reward"] |
| check("Reward is finite and in expected range", |
| -15.0 <= reward <= 10.0, f"(got {reward})") |
|
|
| |
| done = result["done"] |
| for _ in range(max_steps): |
| if done: |
| break |
| r = requests.post(f"{base_url}/step", json={"action": 5}, timeout=10) |
| done = r.json()["done"] |
|
|
| |
| r = requests.get(f"{base_url}/grade", timeout=10) |
| check("GET /grade 200 after episode", r.status_code == 200) |
| grade = r.json() |
| check("Grade has score in [0,1]", |
| isinstance(grade.get("score"), (int, float)) and 0.0 <= grade["score"] <= 1.0, |
| f"(got {grade.get('score')})") |
| check("Grade has reason string", isinstance(grade.get("reason"), str)) |
| check("Grade has metrics dict", isinstance(grade.get("metrics"), dict)) |
| check("Grade has passed bool", isinstance(grade.get("passed"), bool)) |
|
|
| except Exception as e: |
| check(f"Episode for {task} completed without error", False, str(e)) |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--url", default="http://localhost:7860") |
| args = parser.parse_args() |
| validate(args.url) |
|
|