# adithya9903's picture
# Flatten project to root for OpenEnv submission readiness.
# fa51dd9
"""Tests for the FastAPI HTTP server (OpenEnv create_app endpoints).
OpenEnv HTTP endpoints are *stateless*: each /reset and /step creates a
fresh environment instance. Multi-step sessions only work via WebSocket.
These tests validate single-call behaviour and schema contracts.
"""
from __future__ import annotations

import json
from collections.abc import Iterator

import pytest
from fastapi.testclient import TestClient

from polypharmacy_env.api.server import app
@pytest.fixture
def client() -> Iterator[TestClient]:
    """Yield a ``TestClient`` bound to the FastAPI ``app``.

    The client is entered as a context manager so that the application's
    lifespan (startup/shutdown) handlers run, as they would in a real
    server. The client is also closed once the test finishes, instead of
    being left for the garbage collector.
    """
    with TestClient(app) as test_client:
        yield test_client
class TestHealth:
    """Liveness check for the HTTP server."""

    def test_health(self, client: TestClient) -> None:
        """GET /health answers 200 and reports a "healthy" status."""
        response = client.get("/health")
        assert response.status_code == 200
        payload = response.json()
        assert payload["status"] == "healthy"
class TestReset:
    """Single-call contracts for POST /reset."""

    @staticmethod
    def _reset_ok(client: TestClient, payload: dict) -> dict:
        """POST ``payload`` to /reset, require HTTP 200, and return the JSON body."""
        response = client.post("/reset", json=payload)
        assert response.status_code == 200
        return response.json()

    def test_reset_default(self, client: TestClient) -> None:
        """An empty body resets with defaults and returns a not-done episode."""
        body = self._reset_ok(client, {})
        assert "observation" in body
        assert body["done"] is False

    def test_reset_with_task(self, client: TestClient) -> None:
        """The requested task_id is echoed back in the observation."""
        body = self._reset_ok(client, {"task_id": "easy_screening"})
        observation = body["observation"]
        assert observation["task_id"] == "easy_screening"

    def test_reset_observation_has_medications(self, client: TestClient) -> None:
        """A seeded easy_screening reset lists at least three medications."""
        body = self._reset_ok(client, {"task_id": "easy_screening", "seed": 42})
        observation = body["observation"]
        assert len(observation["current_medications"]) >= 3
class TestStep:
    """Single-call contracts for POST /step.

    Each request is handled independently (the HTTP endpoints are stateless),
    so no prior /reset is sent.
    """

    @staticmethod
    def _post_action(client: TestClient, action_type: str):
        """POST a one-field action with the given ``action_type`` to /step."""
        return client.post(
            "/step",
            json={"action": {"action_type": action_type}},
        )

    def test_step_finish(self, client: TestClient) -> None:
        """A finish_review action is accepted and yields an observation."""
        response = self._post_action(client, "finish_review")
        assert response.status_code == 200
        assert "observation" in response.json()

    def test_invalid_action_422(self, client: TestClient) -> None:
        """An unknown action_type is rejected by request validation (422)."""
        response = self._post_action(client, "invalid_type")
        assert response.status_code == 422
class TestSchema:
    """Contract for the GET /schema endpoint."""

    def test_schema(self, client: TestClient) -> None:
        """GET /schema answers 200 and exposes all three schema sections."""
        resp = client.get("/schema")
        assert resp.status_code == 200
        data = resp.json()
        # OpenEnv schema endpoint returns keys: action, observation, state.
        # Check every documented key, not just the first two.
        assert "action" in data
        assert "observation" in data
        assert "state" in data
class TestWebSocketSession:
    """Test multi-step sessions through the /ws WebSocket endpoint.

    OpenEnv WS protocol:
        Send: {"type": "reset", "data": {"task_id": "...", "seed": ...}}
        Recv: {"type": "observation",
               "data": {"observation": {...}, "reward": ..., "done": ...}}
        Send: {"type": "step", "data": {"action_type": "...", ...}}
        Recv: {"type": "observation", "data": {"observation": {...}, ...}}
        Send: {"type": "state"}
        Recv: {"type": "state", "data": {...state fields...}}
    """

    def test_ws_reset_step_finish(self, client: TestClient) -> None:
        """Reset, query one drug-drug interaction, then finish the review."""
        with client.websocket_connect("/ws") as ws:
            # Reset
            ws.send_json({
                "type": "reset",
                "data": {"task_id": "easy_screening", "seed": 42},
            })
            reset_resp = ws.receive_json()
            assert reset_resp["type"] == "observation"
            reset_data = reset_resp["data"]
            assert reset_data["done"] is False
            obs = reset_data["observation"]
            assert obs["task_id"] == "easy_screening"
            meds = obs["current_medications"]
            assert len(meds) >= 3

            # Step – query DDI between the first two medications. The length
            # assertion above guarantees both exist. Do not guard this step
            # with an `if`: a guard would silently skip it if the invariant
            # ever broke.
            ws.send_json({
                "type": "step",
                "data": {
                    "action_type": "query_ddi",
                    "drug_id_1": meds[0]["drug_id"],
                    "drug_id_2": meds[1]["drug_id"],
                },
            })
            step_resp = ws.receive_json()
            assert step_resp["type"] == "observation"
            assert step_resp["data"]["done"] is False

            # Finish
            ws.send_json({
                "type": "step",
                "data": {"action_type": "finish_review"},
            })
            finish_resp = ws.receive_json()
            assert finish_resp["type"] == "observation"
            assert finish_resp["data"]["done"] is True

    def test_ws_state(self, client: TestClient) -> None:
        """Immediately after reset, state reports step 0 and the chosen task."""
        with client.websocket_connect("/ws") as ws:
            ws.send_json({
                "type": "reset",
                "data": {"task_id": "easy_screening", "seed": 0},
            })
            ws.receive_json()  # consume reset response

            ws.send_json({"type": "state"})
            state_resp = ws.receive_json()
            assert state_resp["type"] == "state"
            state_data = state_resp["data"]
            assert state_data["step_count"] == 0
            assert state_data["task_id"] == "easy_screening"