Spaces:
Running
Running
| """Contract tests for judge audit payloads (TST 11). | |
| Verifies that terminal StepInfo and EpisodeLog contain all required | |
| audit fields with correct types and semantics. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import pytest | |
| from fastapi.testclient import TestClient | |
| from replicalab.models import EpisodeLog, RewardBreakdown, StepInfo | |
| from server.app import app | |
@pytest.fixture
def client():
    """Provide a ``TestClient`` wrapping the application under test.

    Registered as a pytest fixture so that test functions and methods which
    declare a ``client`` parameter receive the returned object by name.
    A new ``TestClient`` is constructed on every call.
    """
    return TestClient(app)
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def _good_action_payload(client: TestClient) -> dict:
    """Assemble a ``propose_protocol`` action dict from a generated scenario.

    The scenario is generated with seed 42, the ``math_reasoning`` template
    and ``easy`` difficulty; values are read from its lab-manager observation
    and hidden reference spec.

    Args:
        client: Accepted for call-site symmetry; not used by this helper.

    Returns:
        A dict holding the action fields sent as the ``action`` of ``/step``.
    """
    from replicalab.scenarios import generate_scenario

    generated = generate_scenario(seed=42, template="math_reasoning", difficulty="easy")
    manager_view = generated.lab_manager_observation
    reference = generated.hidden_reference_spec

    # Use at most 60 characters of the summary, or a fixed label when empty.
    if reference.summary:
        technique = reference.summary[:60]
    else:
        technique = "replication_plan"

    # Clamp the duration into the range [1, 2] days.
    capped_days = min(2, manager_view.time_limit_days)
    duration_days = max(1, capped_days)

    # Request no more than the first available equipment item / reagent.
    equipment: list = []
    if manager_view.equipment_available:
        equipment = list(manager_view.equipment_available[:1])
    reagents: list = []
    if manager_view.reagents_in_stock:
        reagents = list(manager_view.reagents_in_stock[:1])

    addressed = ", ".join(reference.required_elements[:2])
    rationale = (
        f"Plan addresses: {addressed}. "
        f"Target metric: {reference.target_metric}. "
        f"Target value: {reference.target_value}. "
        "Stay within budget and schedule."
    )

    return {
        "action_type": "propose_protocol",
        "sample_size": 10,
        "controls": ["baseline", "ablation"],
        "technique": technique,
        "duration_days": duration_days,
        "required_equipment": equipment,
        "required_reagents": reagents,
        "questions": [],
        "rationale": rationale,
    }
| def _accept_action() -> dict: | |
| return { | |
| "action_type": "accept", | |
| "sample_size": 0, | |
| "controls": [], | |
| "technique": "", | |
| "duration_days": 0, | |
| "required_equipment": [], | |
| "required_reagents": [], | |
| "questions": [], | |
| "rationale": "", | |
| } | |
def _run_to_terminal(client: TestClient, *, accept: bool = True) -> tuple[str, dict]:
    """Reset, propose, and then accept or repeat until the episode is done.

    The episode is reset with seed 42, scenario ``math_reasoning`` and
    difficulty ``easy``; one proposal is always submitted first.

    Args:
        client: Test client used to call ``/reset`` and ``/step``.
        accept: When true, send a single accept action after the proposal.
            When false, re-submit the proposal (at most 10 more times) until
            a response reports ``done``.

    Returns:
        ``(episode_id, terminal_response_json)``.

    Raises:
        AssertionError: If ``accept`` is false and no response reports
            ``done`` within the 10 extra steps. Previously the last,
            non-terminal response was returned silently in that case.
    """
    reset = client.post(
        "/reset",
        json={"seed": 42, "scenario": "math_reasoning", "difficulty": "easy"},
    ).json()
    session_id = reset["session_id"]
    episode_id = reset["episode_id"]

    action = _good_action_payload(client)
    # Opening proposal; its response is intentionally not inspected.
    client.post("/step", json={"session_id": session_id, "action": action})

    if accept:
        resp = client.post("/step", json={"session_id": session_id, "action": _accept_action()})
        return episode_id, resp.json()

    # Run to timeout by repeating the same proposal.
    for _ in range(10):
        resp = client.post("/step", json={"session_id": session_id, "action": action})
        body = resp.json()
        if body["done"]:
            return episode_id, body

    raise AssertionError(
        "episode did not reach a terminal state within 10 additional steps"
    )
| # --------------------------------------------------------------------------- | |
| # StepInfo audit field contracts | |
| # --------------------------------------------------------------------------- | |
class TestStepInfoAuditContract:
    """Audit-field contract for the ``info`` payload of ``/step`` responses."""

    def test_terminal_info_has_verdict(self, client: TestClient) -> None:
        """The terminal verdict is one of the three known outcomes."""
        _episode, payload = _run_to_terminal(client)
        verdict = payload["info"]["verdict"]
        assert verdict in ("accept", "revise", "timeout")

    def test_terminal_info_has_judge_notes(self, client: TestClient) -> None:
        """Judge notes are a non-empty string."""
        _episode, payload = _run_to_terminal(client)
        assert isinstance(payload["info"]["judge_notes"], str)
        assert len(payload["info"]["judge_notes"]) > 0

    def test_terminal_info_has_reward_breakdown(self, client: TestClient) -> None:
        """Each listed component is present and lies within [0, 1]."""
        _episode, payload = _run_to_terminal(client)
        breakdown = payload["info"]["reward_breakdown"]
        assert breakdown is not None
        for component in ("rigor", "feasibility", "fidelity", "parsimony"):
            assert component in breakdown
            assert 0.0 <= breakdown[component] <= 1.0

    def test_terminal_info_has_top_failure_reasons(self, client: TestClient) -> None:
        """Failure reasons are delivered as a list."""
        _episode, payload = _run_to_terminal(client)
        failure_reasons = payload["info"]["top_failure_reasons"]
        assert isinstance(failure_reasons, list)

    def test_terminal_info_has_agreement_reached(self, client: TestClient) -> None:
        """After an accept action the agreement flag is the boolean ``True``."""
        _episode, payload = _run_to_terminal(client)
        assert isinstance(payload["info"]["agreement_reached"], bool)
        assert payload["info"]["agreement_reached"] is True

    def test_non_terminal_info_has_no_verdict(self, client: TestClient) -> None:
        """A first proposal leaves the episode open with no verdict or breakdown."""
        session = client.post("/reset", json={"seed": 42}).json()
        proposal = _good_action_payload(client)
        step_request = {"session_id": session["session_id"], "action": proposal}
        payload = client.post("/step", json=step_request).json()
        assert payload["done"] is False
        assert payload["info"]["verdict"] is None
        assert payload["info"]["reward_breakdown"] is None
| # --------------------------------------------------------------------------- | |
| # EpisodeLog audit field contracts | |
| # --------------------------------------------------------------------------- | |
class TestEpisodeLogAuditContract:
    """Audit-metadata contract for the body of GET /replay/{episode_id}."""

    @staticmethod
    def _finished_replay(client: TestClient) -> dict:
        """Run one episode to its end and return the JSON body of its replay."""
        episode_id, _terminal = _run_to_terminal(client)
        return client.get(f"/replay/{episode_id}").json()

    def test_replay_has_verdict_and_judge_notes(self, client: TestClient) -> None:
        """Replay carries a known verdict and non-empty string notes."""
        log = self._finished_replay(client)
        assert log["verdict"] in ("accept", "revise", "timeout")
        assert isinstance(log["judge_notes"], str)
        assert len(log["judge_notes"]) > 0

    def test_replay_has_reward_breakdown(self, client: TestClient) -> None:
        """Replay carries a breakdown containing each listed component."""
        log = self._finished_replay(client)
        breakdown = log["reward_breakdown"]
        assert breakdown is not None
        for component in ("rigor", "feasibility", "fidelity", "parsimony"):
            assert component in breakdown

    def test_replay_has_top_failure_reasons(self, client: TestClient) -> None:
        """Failure reasons are delivered as a list."""
        log = self._finished_replay(client)
        assert isinstance(log["top_failure_reasons"], list)

    def test_replay_has_transcript(self, client: TestClient) -> None:
        """Transcript is a non-empty list."""
        log = self._finished_replay(client)
        assert isinstance(log["transcript"], list)
        assert len(log["transcript"]) > 0

    def test_replay_has_invalid_action_fields(self, client: TestClient) -> None:
        """Invalid-action count and rate are present, typed, and in range."""
        log = self._finished_replay(client)
        assert "invalid_action_count" in log
        assert "invalid_action_rate" in log
        count = log["invalid_action_count"]
        rate = log["invalid_action_rate"]
        assert isinstance(count, int)
        assert isinstance(rate, (int, float))
        assert count >= 0
        assert 0.0 <= rate <= 1.0

    def test_replay_has_total_reward(self, client: TestClient) -> None:
        """Total reward is numeric."""
        log = self._finished_replay(client)
        assert isinstance(log["total_reward"], (int, float))
| # --------------------------------------------------------------------------- | |
| # Pydantic model contracts (unit-level) | |
| # --------------------------------------------------------------------------- | |
class TestAuditModelContracts:
    """Unit-level checks of audit fields on ``StepInfo`` and ``EpisodeLog``."""

    def test_step_info_default_audit_fields(self) -> None:
        """A bare ``StepInfo`` has no agreement, verdict, notes or breakdown."""
        blank = StepInfo()
        assert blank.agreement_reached is False
        assert blank.verdict is None
        assert blank.judge_notes is None
        assert blank.reward_breakdown is None
        assert blank.top_failure_reasons == []

    def test_step_info_with_audit_fields(self) -> None:
        """Explicitly supplied audit values are readable from the model."""
        breakdown = RewardBreakdown(rigor=0.9, feasibility=0.8, fidelity=0.7)
        populated = StepInfo(
            agreement_reached=True,
            verdict="accept",
            judge_notes="All clear.",
            reward_breakdown=breakdown,
            top_failure_reasons=["minor issue"],
        )
        assert populated.verdict == "accept"
        assert populated.reward_breakdown.rigor == 0.9
        assert populated.top_failure_reasons == ["minor issue"]

    def test_episode_log_invalid_action_fields_default(self) -> None:
        """A bare ``EpisodeLog`` reports zero invalid actions."""
        blank = EpisodeLog()
        assert blank.invalid_action_count == 0
        assert blank.invalid_action_rate == 0.0

    def test_episode_log_invalid_action_fields_set(self) -> None:
        """Supplied invalid-action values are stored as given."""
        populated = EpisodeLog(invalid_action_count=3, invalid_action_rate=0.25)
        assert populated.invalid_action_count == 3
        assert populated.invalid_action_rate == 0.25

    def test_episode_log_json_round_trip_preserves_audit(self) -> None:
        """Audit fields survive ``model_dump_json`` / ``model_validate_json``."""
        breakdown = RewardBreakdown(
            rigor=0.9,
            feasibility=0.85,
            fidelity=0.8,
            penalties={"stalling": 0.05},
        )
        original = EpisodeLog(
            episode_id="audit-test",
            verdict="accept",
            judge_notes="Solid work.",
            top_failure_reasons=["none"],
            invalid_action_count=1,
            invalid_action_rate=0.1,
            reward_breakdown=breakdown,
        )
        serialized = original.model_dump_json()
        revived = EpisodeLog.model_validate_json(serialized)
        assert revived.verdict == "accept"
        assert revived.invalid_action_count == 1
        assert revived.invalid_action_rate == 0.1
        assert revived.reward_breakdown.penalties == {"stalling": 0.05}