# pharma-vigilance/tests/test_env.py
# modelbuilderhq's picture
# Upload folder using huggingface_hub
# 9ab33d8 verified
import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from env import Action, PharmaVigilanceEnv
from tasks import (
cluster_signal_medium_action_grader,
cluster_signal_medium_grader,
confounded_hard_action_grader,
confounded_hard_grader,
get_task,
get_tasks,
known_signal_easy_action_grader,
known_signal_easy_grader,
)
def test_reset_loads_easy_task():
    """Resetting to ``known_signal_easy`` yields that task's initial observation."""
    initial = PharmaVigilanceEnv().reset("known_signal_easy")

    # Identity and step bookkeeping of the freshly reset episode.
    observed = (initial.task_id, initial.step_number, initial.max_steps)
    assert observed == ("known_signal_easy", 0, 2)
    # The first observation carries exactly one report.
    assert len(initial.reports) == 1
def test_known_signal_grader_full_credit():
    """The easy action grader gives a total of 1.0 for this fully matching answer."""
    answer = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known reaction pattern.",
        confidence=91,
    )

    reward = known_signal_easy_action_grader(answer)

    # pytest.approx keeps the check stable if the total is built by summing
    # fractional float components.
    assert reward.total == pytest.approx(1.0)
def test_medium_cluster_grader_partial_credit():
    """The medium action grader gives a total of 0.75 for this partly matching answer.

    The action is built without an explicit ``confidence`` value.
    """
    answer = Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="moderate",
        recommended_action="escalate",
        reasoning="A cluster is forming.",
    )

    reward = cluster_signal_medium_action_grader(answer)

    # pytest.approx keeps the check stable if the total is built by summing
    # fractional float components.
    assert reward.total == pytest.approx(0.75)
def test_hard_grader_reasoning_bonus():
    """The hard action grader reports a 0.05 ``reasoning_bonus`` and a total of 1.0.

    The reasoning text names both drugs and the interaction.
    """
    answer = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="This looks like a tacrolimus-voriconazole drug interaction with toxic levels.",
    )

    reward = confounded_hard_action_grader(answer)

    # Compare floats with a tolerance rather than bit-for-bit, so rounding in
    # the grader's arithmetic cannot cause a spurious failure.
    assert reward.total == pytest.approx(1.0)
    assert reward.breakdown["reasoning_bonus"] == pytest.approx(0.05)
def test_hard_grader_substring_suspect_match():
    """Naming only ``Tacrolimus`` as the suspect scores 0.25 on ``suspect_drug``."""
    answer = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Voriconazole likely raised tacrolimus exposure.",
    )

    reward = confounded_hard_action_grader(answer)

    # Tolerance-based float comparison avoids failures from rounding noise.
    assert reward.breakdown["suspect_drug"] == pytest.approx(0.25)
def test_env_step_returns_done():
    """A hard-task episode is open after the first step and done after the second."""
    env = PharmaVigilanceEnv()
    env.reset("confounded_hard")

    # The same answer is submitted twice; a fresh Action is built for each step.
    answer_fields = dict(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Tacrolimus toxicity from an azole interaction.",
    )

    # First step: episode continues and a reward breakdown is reported.
    obs, reward, done, info = env.step(Action(**answer_fields))
    assert done is False
    assert obs.step_number == 1
    assert "reward_breakdown" in info
    assert reward.total >= 0.20

    # Second step: episode ends with a higher minimum reward.
    obs, reward, done, info = env.step(Action(**answer_fields))
    assert done is True
    assert obs.step_number == 2
    assert reward.total >= 0.85
def test_first_step_returns_partial_reward_and_review_feedback():
    """The first medium-task step gives a positive reward and a senior review note."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    triage = Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="severe",
        recommended_action="escalate",
        reasoning="Clustered bradycardia on a newer therapy.",
        confidence=88,
    )
    obs, reward, done, info = env.step(triage)

    # The episode is still open after one step.
    assert done is False
    assert obs.step_number == 1
    assert reward.total > 0.0
    # The step is labelled as the triage phase and carries review feedback.
    assert info["phase"] == "initial_triage"
    assert "Senior review note" in obs.feedback
def test_final_step_awards_revision_bonus_when_agent_improves():
    """A dismissive first answer followed by an escalation earns a 0.05 ``revision_bonus``."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    # Step 1: a dismissive first guess.
    first_guess = Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="Weak initial guess.",
        confidence=90,
    )
    env.step(first_guess)

    # Step 2: the revised answer.
    revised = Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="severe",
        recommended_action="escalate",
        reasoning="Follow-up reports confirm a coherent bradycardia cluster.",
        confidence=82,
    )
    _, reward, done, info = env.step(revised)

    assert done is True
    # Tolerance-based float comparison avoids failures from rounding noise.
    assert reward.breakdown["revision_bonus"] == pytest.approx(0.05)
    assert info["phase"] == "final_review"
def test_final_step_applies_stubborn_penalty_for_repeating_weak_answer():
    """Submitting the same dismissive answer twice yields a -0.05 ``stubborn_penalty``."""
    env = PharmaVigilanceEnv()
    env.reset("confounded_hard")

    weak = Action(
        classification="noise",
        suspect_drug="Trimethoprim-sulfamethoxazole",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="Reporter probably overcalled it.",
        confidence=85,
    )

    # The identical Action object is used for both steps of the episode.
    env.step(weak)
    _, reward, done, _ = env.step(weak)

    assert done is True
    # Tolerance-based float comparison avoids failures from rounding noise.
    assert reward.breakdown["stubborn_penalty"] == pytest.approx(-0.05)
def test_initial_step_can_return_negative_reward_for_unsafe_triage():
    """A confident dismissal on the first medium-task step gives a negative reward."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    dismissal = Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="No obvious concern.",
        confidence=95,
    )
    _, reward, done, info = env.step(dismissal)

    # Still in the first phase, and the reward has dropped below zero.
    assert done is False
    assert info["phase"] == "initial_triage"
    assert reward.total < 0.0
def test_single_step_action_grader_can_return_negative_total():
    """The medium action grader can return a total below zero for a dismissive answer."""
    dismissal = Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="Probably unrelated.",
        confidence=95,
    )

    graded = cluster_signal_medium_action_grader(dismissal)

    assert graded.total < 0.0
def test_overconfidence_penalty_applies_on_weak_single_step_grading():
    """A dismissive answer at confidence 95 gets a -0.10 ``confidence_adjustment``."""
    dismissal = Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="This is probably nothing.",
        confidence=95,
    )

    reward = cluster_signal_medium_action_grader(dismissal)

    # Tolerance-based float comparison avoids failures from rounding noise.
    assert reward.breakdown["confidence_adjustment"] == pytest.approx(-0.10)
def test_low_confidence_penalty_applies_on_strong_answer():
    """A matching easy-task answer at confidence 20 gets a -0.03 ``confidence_adjustment``."""
    hesitant = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known labeled ACE-inhibitor cough.",
        confidence=20,
    )

    reward = known_signal_easy_action_grader(hesitant)

    # Tolerance-based float comparison avoids failures from rounding noise.
    assert reward.breakdown["confidence_adjustment"] == pytest.approx(-0.03)
def test_episode_rejects_third_step_after_completion():
    """A third ``step`` on the easy task raises ``RuntimeError`` after two steps."""
    env = PharmaVigilanceEnv()
    env.reset("known_signal_easy")
    good = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known ACE-inhibitor cough.",
        confidence=90,
    )

    # Take the first two steps with the same action object.
    for _ in range(2):
        env.step(good)

    # Any further step must be refused.
    with pytest.raises(RuntimeError, match="Episode already complete"):
        env.step(good)
def test_state_tracks_last_action():
    """After two steps, ``state()`` reports step 2 and the last action's classification."""
    env = PharmaVigilanceEnv()
    env.reset("known_signal_easy")

    answer_fields = dict(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known adverse effect.",
        confidence=90,
    )
    # Two steps, each with its own freshly constructed (but equal) Action.
    for _ in range(2):
        env.step(Action(**answer_fields))

    snapshot = env.state()
    assert snapshot["step_number"] == 2
    assert snapshot["last_action"]["classification"] == "known_side_effect"
def test_all_tasks_available():
    """``get_tasks()`` is keyed by exactly the three expected task ids."""
    expected_ids = {
        "known_signal_easy",
        "cluster_signal_medium",
        "confounded_hard",
    }
    registry = get_tasks()
    assert set(registry.keys()) == expected_ids
def test_grouped_tasks_expose_easy_medium_hard_pools():
    """``get_tasks(grouped=True)`` has easy/medium/hard pools led by the expected tasks."""
    pools = get_tasks(grouped=True)
    assert set(pools.keys()) == {"easy", "medium", "hard"}

    # The first task of each pool, checked in easy -> medium -> hard order.
    first_task_ids = {
        "easy": "known_signal_easy",
        "medium": "cluster_signal_medium",
        "hard": "confounded_hard",
    }
    for tier, task_id in first_task_ids.items():
        assert pools[tier][0].task_id == task_id
def test_get_task_returns_hard_truth():
    """The hard task's ground truth names the two-drug combination as the suspect."""
    truth = get_task("confounded_hard").ground_truth
    assert truth.suspect_drug == "Tacrolimus+Voriconazole"
def test_public_graders_are_strictly_bounded():
    """The public graders return 0.99 and 0.01 for these extreme inputs.

    A reward of 1.0 and an out-of-range score of 1.5 both give 0.99, and a
    reward of 0.0 gives 0.01.
    """
    # Tolerance-based float comparison avoids failures from rounding noise.
    assert known_signal_easy_grader({"rewards": [1.0]}) == pytest.approx(0.99)
    assert cluster_signal_medium_grader({"rewards": [0.0]}) == pytest.approx(0.01)
    assert confounded_hard_grader({"score": 1.5}) == pytest.approx(0.99)
def test_inference_final_score_uses_public_task_grader():
    """``inference.final_score`` equals each task's public grader on the same rewards.

    Skipped when ``openenv`` cannot be imported.
    """
    pytest.importorskip("openenv")
    from inference import final_score

    rewards = [0.4, 1.0]
    graders_by_task = (
        ("known_signal_easy", known_signal_easy_grader),
        ("cluster_signal_medium", cluster_signal_medium_grader),
        ("confounded_hard", confounded_hard_grader),
    )
    # Both sides receive the very same rewards list, task by task.
    for task_id, grader in graders_by_task:
        assert final_score(task_id, rewards) == grader({"rewards": rewards})
def test_http_reset_then_step_roundtrip():
    """Run one episode over HTTP: ``/reset`` followed by two ``/step`` calls.

    Skipped when ``openenv`` cannot be imported. The first step must leave
    the episode open; the second must finish it with a reward of 1.0.
    """
    pytest.importorskip("openenv")
    from fastapi.testclient import TestClient
    from server.app import app

    def step_body():
        # Build a fresh request payload per call so the two posts share no state.
        return {
            "action": {
                "classification": "known_side_effect",
                "suspect_drug": "Lisinopril",
                "severity_assessment": "mild",
                "recommended_action": "log_and_monitor",
                "reasoning": "Known ACE inhibitor cough.",
                "confidence": 90,
            }
        }

    client = TestClient(app)

    reset_response = client.post("/reset", json={})
    assert reset_response.status_code == 200

    first_step = client.post("/step", json=step_body())
    assert first_step.status_code == 200
    first_payload = first_step.json()
    assert first_payload["done"] is False

    step_response = client.post("/step", json=step_body())
    assert step_response.status_code == 200
    payload = step_response.json()
    assert payload["done"] is True
    # The reward is a float that went through JSON; compare it with a tolerance.
    assert payload["reward"] == pytest.approx(1.0)