vicky0406 commited on
Commit
ee117a1
·
verified ·
1 Parent(s): a829d4f

Upload 18 files

Browse files
.gitignore ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ *.manifest
32
+ *.spec
33
+
34
+ # Unit test / coverage reports
35
+ htmlcov/
36
+ .tox/
37
+ .nox/
38
+ .coverage
39
+ .coverage.*
40
+ .cache
41
+ nosetests.xml
42
+ coverage.xml
43
+ *.cover
44
+ *.py,cover
45
+ .hypothesis/
46
+ .pytest_cache/
47
+
48
+ # Translations
49
+ *.mo
50
+ *.pot
51
+
52
+ # Django stuff:
53
+ *.log
54
+ local_settings.py
55
+ db.sqlite3
56
+ db.sqlite3-journal
57
+
58
+ # Flask stuff:
59
+ instance/
60
+ .webassets-cache
61
+
62
+ # Scrapy stuff:
63
+ .scrapy
64
+
65
+ # Sphinx documentation
66
+ docs/_build/
67
+
68
+ # PyBuilder
69
+ target/
70
+
71
+ # Jupyter Notebook
72
+ .ipynb_checkpoints
73
+
74
+ # IPython
75
+ profile_default/
76
+ ipython_config.py
77
+
78
+ # pyenv
79
+ .python-version
80
+
81
+ # pipenv
82
+ Pipfile.lock
83
+
84
+ # PEP 582
85
+ __pypackages__/
86
+
87
+ # Celery stuff
88
+ celerybeat-schedule
89
+ celerybeat.pid
90
+
91
+ # SageMath parsed files
92
+ *.sage.py
93
+
94
+ # Environments
95
+ .env
96
+ .venv
97
+ env/
98
+ venv/
99
+ ENV/
100
+ env.bak/
101
+ venv.bak/
102
+
103
+ # Spyder project settings
104
+ .spyderproject
105
+ .spyproject
106
+
107
+ # Rope project settings
108
+ .ropeproject
109
+
110
+ # mkdocs documentation
111
+ /site
112
+
113
+ # mypy
114
+ .mypy_cache/
115
+ .dmypy.json
116
+ dmypy.json
117
+
118
+ # Pyre type checker
119
+ .pyre/
120
+
121
+ # IDE
122
+ .vscode/
123
+ .idea/
124
+ *.swp
125
+ *.swo
126
+ *~
127
+ .DS_Store
128
+
129
+ # Project specific
130
+ venv/
131
+ .env.local
132
+ .env.*.local
133
+ instance/
134
+ .webassets-cache
135
+ docker-compose.override.yml
136
+
137
+ # API Keys
138
+ .openai_api_key
139
+ .hf_token
140
+
141
+ # Model checkpoints
142
+ models/
143
+ checkpoints/
144
+
145
+ # Logs
146
+ logs/
147
+ *.log
148
+ log/
149
+
150
+ # Temporary files
151
+ tmp/
152
+ temp/
153
+ *.tmp
154
+
155
+ # OS
156
+ .DS_Store
157
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Set environment variables
6
+ ENV PYTHONUNBUFFERED=1 \
7
+ PYTHONDONTWRITEBYTECODE=1 \
8
+ PIP_NO_CACHE_DIR=1 \
9
+ PIP_DISABLE_PIP_VERSION_CHECK=1
10
+
11
+ # Install system dependencies (minimal)
12
+ RUN apt-get update && apt-get install -y --no-install-recommends \
13
+ curl \
14
+ && rm -rf /var/lib/apt/lists/*
15
+
16
+ # Copy and install Python dependencies first (for layer caching)
17
+ COPY server/requirements.txt /tmp/requirements.txt
18
+ RUN pip install --no-cache-dir -r /tmp/requirements.txt && \
19
+ rm /tmp/requirements.txt
20
+
21
+ # Copy application code
22
+ COPY models.py .
23
+ COPY client.py .
24
+ COPY server/ ./server/
25
+
26
+ # Create __init__ files for Python packages
27
+ RUN touch __init__.py && \
28
+ touch server/__init__.py
29
+
30
+ # Health check - validates the server is running and responsive
31
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
32
+ CMD curl -f http://localhost:8000/health || exit 1
33
+
34
+ # Expose port
35
+ EXPOSE 8000
36
+
37
+ # Run the FastAPI server with uvicorn
38
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Vicky-220
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,12 +1,214 @@
1
- ---
2
- title: Synapse Openenv
3
- emoji:
4
- colorFrom: gray
5
- colorTo: pink
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- short_description: Medical diagnostic OpenEnv environment for RL training
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Medical Diagnostic Environment
2
+
3
+ A lightweight OpenEnv environment for training agents to diagnose patients using clinical reasoning. The agent interacts through a turn-based dialog, orders tests, and submits a final diagnosis.
4
+
5
+ ## What this project does
6
+
7
+ The environment simulates a real medical workflow:
8
+ - Present a patient with symptoms and context
9
+ - Let the agent ask clinical questions
10
+ - Allow the agent to order diagnostic tests
11
+ - Score the agent on diagnosis accuracy and process quality
12
+
13
+ It is designed to be used for training or evaluation with reinforcement learning systems.
14
+
15
+ ## Why this environment is useful
16
+
17
+ This is not a toy problem. It is a small clinical reasoning task with:
18
+ - real clinical cases and realistic feedback
19
+ - multi-step decisions
20
+ - partial reward signals for progress
21
+ - a clear end goal: accurate final diagnosis
22
+
23
+ ## Tasks included
24
+
25
+ There are three difficulty tiers built into the environment:
26
+
27
+ ### Easy
28
+ - Seasonal Influenza
29
+ - Urinary Tract Infection
30
+
31
+ ### Medium
32
+ - Community-Acquired Pneumonia
33
+ - Acute Appendicitis
34
+
35
+ ### Hard
36
+ - Infective Endocarditis
37
+ - Bacterial Meningitis
38
+
39
+ Each case is graded from 0.0 to 1.0 based on the agent's final diagnosis and stepwise decisions.
40
+
41
+ ## Action and observation interface
42
+
43
+ ### Actions
44
+ The agent sends one of three actions:
45
+
46
+ ```python
47
+ class DiagnosticAction(Action):
48
+ action_type: str # ask_question | order_test | submit_diagnosis
49
+ question: Optional[str] = None
50
+ test_name: Optional[str] = None
51
+ diagnosis: Optional[str] = None
52
+ ```
53
+
54
+ ### Observations
55
+ Each step returns a structured observation:
56
+
57
+ ```python
58
+ class PatientObservation(Observation):
59
+ done: bool
60
+ reward: Optional[float]
61
+ message: str
62
+ patient_response: Optional[Dict]
63
+ test_result: Optional[Dict]
64
+ questions_asked: List[str]
65
+ tests_completed: List[str]
66
+ patient_data_revealed: Dict
67
+ steps_taken: int
68
+ max_steps: int
69
+ ```
70
+
71
+ ## Setup
72
+
73
+ ### Requirements
74
+ - Python 3.10+
75
+ - Docker for containerized deployment
76
+
77
+ ### Local setup
78
+
79
+ ```bash
80
+ git clone <repository-url>
81
+ cd meta_synapse_hackathon
82
+ python -m venv venv
83
+ source venv/bin/activate
84
+ pip install -r server/requirements.txt
85
+ ```
86
+
87
+ ### Run validation
88
+
89
+ ```bash
90
+ python validate.py
91
+ ```
92
+
93
+ ## Running the environment
94
+
95
+ ### Start the server
96
+
97
+ ```bash
98
+ cd server
99
+ python app.py
100
+ ```
101
+
102
+ Then the environment is available at:
103
+ - WebSocket: `ws://localhost:8000/ws`
104
+ - Health: `http://localhost:8000/health`
105
+ - Swagger: `http://localhost:8000/docs`
106
+
107
+ ### Use the client
108
+
109
+ ```python
110
+ from client import DiagnosticEnv
111
+
112
+ async with DiagnosticEnv(base_url="ws://localhost:8000/ws") as env:
113
+ obs = await env.reset(difficulty="easy")
114
+ print(obs.message)
115
+ ```
116
+
117
+ ## Training-ready wrapper
118
+
119
+ A simple, training-ready wrapper is available in `training_wrapper.py`. It provides a minimal async interface for use in training loops.
120
+
121
+ ```bash
122
+ python training_wrapper.py
123
+ ```
124
+
125
+ Use it in your own code like this:
126
+
127
+ ```python
128
+ from training_wrapper import TrainingEnv
129
+
130
+ async with TrainingEnv() as env:
131
+ obs = await env.reset(difficulty="easy")
132
+ step = await env.step(action_type="ask_question", question="Do you have a fever?")
133
+ ```
134
+
135
+ ## Baseline inference
136
+
137
+ Set the required environment variables then run the baseline script:
138
+
139
+ ```bash
140
+ export API_BASE_URL="https://router.huggingface.co/v1"
141
+ export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
142
+ export HF_TOKEN="your-huggingface-token"
143
+ export ENV_URL="ws://localhost:8000/ws"
144
+ python inference.py
145
+ ```
146
+
147
+ ## Optional dataset support
148
+
149
+ The environment always includes core static cases and can optionally load Hugging Face datasets when enabled.
150
+
151
+ To use Hugging Face dataset generation:
152
+
153
+ ```bash
154
+ export OPENENV_USE_HF_DATASETS=true
155
+ export OPENENV_DATASET_SEED=42
156
+ ```
157
+
158
+ If dataset loading is disabled or unavailable, the environment still works with the built-in cases.
159
+
160
+ ## Docker deployment
161
+
162
+ ### Build locally
163
+
164
+ ```bash
165
+ docker build -t medical-diagnostic-env ./server
166
+ ```
167
+
168
+ ### Run locally
169
+
170
+ ```bash
171
+ docker run -p 8000:8000 medical-diagnostic-env
172
+ ```
173
+
174
+ ### Deploy to Hugging Face Spaces
175
+
176
+ 1. Create a new Space using Docker.
177
+ 2. Upload the repository files.
178
+ 3. The Space should build and expose the server automatically.
179
+
180
+ ## Notes for judges and trainers
181
+
182
+ - The environment exposes standard reset/step/state semantics.
183
+ - It supports concurrent sessions and WebSocket interaction.
184
+ - The training wrapper is intentionally minimal so any agent loop can be added on top.
185
+
186
+ ## Project structure
187
+
188
+ ```
189
+ ├── models.py
190
+ ├── client.py
191
+ ├── training_wrapper.py
192
+ ├── inference.py
193
+ ├── validate.py
194
+ ├── openenv.yaml
195
+ ├── server/
196
+ │ ├── app.py
197
+ │ ├── environment.py
198
+ │ ├── medical_data.py
199
+ │ ├── requirements.txt
200
+ │ └── Dockerfile
201
+ └── tests/
202
+ └── test_environment.py
203
+ ```
204
+
205
+ ## Testing
206
+
207
+ ```bash
208
+ python -m pytest tests/
209
+ python validate.py
210
+ ```
211
+
212
+ ## License
213
+
214
+ MIT License - see LICENSE file for details.
client.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ client.py — OpenEnv client for the Medical Diagnostic Environment.
3
+
4
+ This client enables training code to interact with the environment via WebSocket.
5
+ Provides both async and sync interfaces for flexibility.
6
+ """
7
+
8
+ from typing import Optional
9
+ import asyncio
10
+ import json
11
+
12
+ try:
13
+ import websockets
14
+ except ImportError:
15
+ websockets = None
16
+
17
+ from openenv.core.env_client import EnvClient
18
+ from openenv.core.client_types import StepResult
19
+ import sys
20
+ import os
21
+
22
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
23
+
24
+ from models import DiagnosticAction, PatientObservation, ClinicalState
25
+
26
+
27
class DiagnosticEnv(EnvClient[DiagnosticAction, PatientObservation, ClinicalState]):
    """
    Client for interacting with the Medical Diagnostic Environment.

    Supports both async and sync usage:

    Async (recommended for training):
        async with DiagnosticEnv(base_url="...") as env:
            obs = await env.reset()
            obs = await env.step(DiagnosticAction(...))

    Sync (for notebooks/simple scripts):
        with DiagnosticEnv(base_url="...").sync() as env:
            obs = env.reset()
            obs = env.step(DiagnosticAction(...))
    """

    @classmethod
    async def from_docker_image(cls, image_name: Optional[str] = None, base_url: Optional[str] = None, **kwargs):
        """Create client connected to a running OpenEnv environment URL.

        ``image_name`` is accepted for interface compatibility but is not
        used here: the environment server is expected to be running already.
        Falls back to the ENV_URL environment variable, then to localhost.
        """
        if base_url is None:
            base_url = os.getenv("ENV_URL", "ws://localhost:8000/ws")
        return cls(base_url=base_url, **kwargs)

    def _step_payload(self, action: DiagnosticAction) -> dict:
        """Convert an action into the JSON payload expected by the server."""
        return {
            "action_type": action.action_type,
            "question": action.question,
            "test_name": action.test_name,
            "diagnosis": action.diagnosis,
        }

    @staticmethod
    def _decode_if_json(value, default):
        """Decode *value* when it arrives as a JSON string; pass dicts through.

        Returns *default* when the string is not valid JSON.  Replaces the
        previous bare ``except:`` blocks, which also swallowed
        KeyboardInterrupt/SystemExit.
        """
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return default
        return value

    def _parse_result(self, payload: dict) -> StepResult:
        """Parse a server step/reset response into a StepResult."""
        obs_data = payload.get("observation", {})

        # The server may serialize nested dicts as JSON strings.
        patient_data_revealed = self._decode_if_json(
            obs_data.get("patient_data_revealed", {}), {}
        )
        test_result = self._decode_if_json(obs_data.get("test_result"), None)

        observation = PatientObservation(
            done=payload.get("done", False),
            reward=payload.get("reward"),
            message=obs_data.get("message", ""),
            patient_response=obs_data.get("patient_response"),
            test_result=test_result,
            questions_asked=obs_data.get("questions_asked", []),
            tests_completed=obs_data.get("tests_completed", []),
            patient_data_revealed=patient_data_revealed,
            steps_taken=obs_data.get("steps_taken", 0),
            max_steps=obs_data.get("max_steps", 15),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: dict) -> ClinicalState:
        """Parse a state response (includes hidden episode information)."""
        patient_details = self._decode_if_json(payload.get("patient_details", {}), {})

        return ClinicalState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            true_diagnosis=payload.get("true_diagnosis", ""),
            patient_case=payload.get("patient_case", ""),
            patient_details=patient_details,
            difficulty=payload.get("difficulty", "easy"),
            questions_asked=payload.get("questions_asked", []),
            tests_completed=payload.get("tests_completed", []),
            final_diagnosis_submitted=payload.get("final_diagnosis_submitted"),
            final_accuracy=payload.get("final_accuracy", 0.0),
        )

    # ─────────────────────────────────────────────────────────────────────
    # Sync wrapper for convenience
    # ─────────────────────────────────────────────────────────────────────

    def sync(self) -> "SyncDiagnosticEnv":
        """
        Return synchronous wrapper for use in notebooks/simple scripts.

        Usage:
            with DiagnosticEnv(url).sync() as env:
                obs = env.reset()
                obs = env.step(action)
        """
        return SyncDiagnosticEnv(self.base_url)
134
+
135
+
136
class SyncDiagnosticEnv:
    """Blocking facade over the async DiagnosticEnv client.

    Owns a private event loop and drives each async call to completion,
    so the environment can be used from plain, non-async code.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self._loop = asyncio.new_event_loop()
        self._async_client = DiagnosticEnv(base_url)

    def _run(self, coro):
        # Drive a coroutine to completion on the private loop.
        return self._loop.run_until_complete(coro)

    def __enter__(self):
        self._run(self._async_client.__aenter__())
        return self

    def __exit__(self, *exc_info):
        self._run(self._async_client.__aexit__(*exc_info))
        self._loop.close()

    def reset(self, difficulty: str = "easy") -> PatientObservation:
        """Start a new episode and return the initial observation."""
        return self._run(self._async_client.reset(difficulty=difficulty))

    def step(self, action: DiagnosticAction) -> StepResult:
        """Execute one action in the environment."""
        return self._run(self._async_client.step(action))

    def state(self) -> ClinicalState:
        """Fetch the full episode state (includes hidden information)."""
        return self._run(self._async_client.state())
docker-compose.yml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Docker Compose for local development and testing
2
+ #
3
+ # Start with: docker-compose up --build
4
+ # Access server at: http://localhost:8000
5
+ # Docs at: http://localhost:8000/docs
6
+ # WebSocket at: ws://localhost:8000/ws
7
+
8
+ version: "3.9"
9
+
10
+ services:
11
+ medical-diagnostic-env:
12
+ build:
13
+ context: .
14
+ dockerfile: server/Dockerfile
15
+
16
+ container_name: medical-diagnostic-env
17
+
18
+ ports:
19
+ - "8000:8000"
20
+
21
+ environment:
22
+ - PORT=8000
23
+ - WORKERS=2 # Reduce for local development
24
+
25
+ restart: unless-stopped
26
+
27
+ healthcheck:
28
+ test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
29
+ interval: 30s
30
+ timeout: 10s
31
+ retries: 3
32
+ start_period: 5s
33
+
34
+ volumes:
35
+ - .:/app # For live code reloading during development (optional)
36
+
37
+ networks:
38
+ - medical-net
39
+
40
+ networks:
41
+ medical-net:
42
+ driver: bridge
inference.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inference.py — Baseline inference script for the Medical Diagnostic Environment.
3
+ This script demonstrates how to run the environment with an LLM (using OpenAI client)
4
+
5
+ and logs results in the exact format required by the hackathon.
6
+
7
+ Format requirements:
8
+ [START] task=<task_name> env=<env_name> model=<model_name>
9
+ [STEP] step=<n> action=<action_str> reward=<float> done=<bool> error=<str|null>
10
+ [END] success=<bool> steps=<n> score=<float> rewards=<comma_separated_list>
11
+
12
+ All fields on a single line with NO NEWLINES within a line.
13
+ """
14
+
15
+ import asyncio
16
+ import os
17
+ import json
18
+ import textwrap
19
+ from typing import List, Optional, Dict
20
+
21
+ from openai import OpenAI
22
+ import sys
23
+
24
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
25
+
26
+ from client import DiagnosticEnv
27
+ from models import DiagnosticAction
28
+
29
+
30
+ # ==============================================================================
31
+ # CONFIGURATION
32
+ # ==============================================================================
33
+
34
+ # Configuration
35
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
36
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
37
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
38
+ LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "medical-diagnostic-env:latest")
39
+ ENV_URL = os.getenv("ENV_URL", "ws://localhost:8000/ws")
40
+ BENCHMARK = os.getenv("BENCHMARK", "medical_diagnostic_env")
41
+
42
+ # Inference configuration
43
+ MAX_STEPS = 15 # Maximum steps per episode
44
+ TEMPERATURE = 0.7 # LLM temperature for reasoning
45
+ MAX_TOKENS = 256 # Max tokens per completion
46
+ TASK_NAMES = ["easy_diagnosis", "medium_diagnosis", "hard_diagnosis"]
47
+ DIFFICULTY_LEVELS = ["easy", "medium", "hard"]
48
+
49
+
50
+ # ==============================================================================
51
+ # LOGGING FUNCTIONS
52
+ # ==============================================================================
53
+
54
def log_start(task: str, env: str, model: str) -> None:
    """Emit the single-line [START] record required by the log format."""
    # Strip any org prefix (e.g. "Qwen/Qwen2.5-72B-Instruct" -> "Qwen2.5-72B-Instruct").
    short_model = model.rsplit("/", 1)[-1]
    print(f"[START] task={task} env={env} model={short_model}", flush=True)
59
+
60
+
61
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one single-line [STEP] record; error is quoted or "null"."""
    err = f'"{error}"' if error else "null"
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={err}"
    )
    print(line, flush=True)
69
+
70
+
71
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the closing single-line [END] summary for an episode."""
    joined_rewards = ",".join(format(r, ".2f") for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={joined_rewards}",
        flush=True,
    )
79
+
80
+ # ==============================================================================
81
+ # LLM INTERACTION
82
+ # ==============================================================================
83
+
84
def create_system_prompt() -> str:
    """Create the system prompt framing the LLM as a diagnostic agent.

    NOTE(review): the ACTION/QUESTION/TEST/DIAGNOSIS format spelled out
    here is what extract_action_from_response() parses — keep the two in
    sync if either changes.
    """
    # textwrap.dedent strips the common leading indentation of the literal;
    # .strip() removes the leading/trailing blank lines.
    return textwrap.dedent("""
        You are an expert medical diagnostic AI assistant. Your role is to:

        1. GATHER INFORMATION: Ask relevant clinical questions about symptoms,
        history, and presentation.

        2. ORDER TESTS: Request appropriate diagnostic tests based on the
        differential diagnosis.

        3. REASON DIAGNOSTICALLY: Consider the patient's presentation,
        synthesize findings, and make a diagnosis.

        Your reasoning should follow clinical guidelines and prioritize:
        - Life-threatening conditions first (red flags)
        - Most common diagnoses for the presentation
        - Efficiency (minimize unnecessary tests)

        When responding, use EXACTLY ONE of these actions:

        ACTION: ask_question
        QUESTION: <your question here>

        OR

        ACTION: order_test
        TEST: <test name>

        OR

        ACTION: submit_diagnosis
        DIAGNOSIS: <final diagnosis>

        Be concise. Diagnose within 10-15 steps if possible.
    """).strip()
120
+
121
+
122
def extract_action_from_response(response: str) -> Optional[Dict]:
    """
    Extract a structured action from a free-form LLM response.

    The preferred format is an explicit "ACTION:" directive accompanied by a
    QUESTION:/TEST:/DIAGNOSIS: line.  When no directive is present, a series
    of keyword heuristics tries to infer the intent.

    Returns:
        dict with keys {action_type, question, test_name, diagnosis}, where
        action_type is one of "ask_question" | "order_test" |
        "submit_diagnosis", or None if no action could be recognized.
    """
    response_lower = response.lower()

    # Preferred path: explicit "ACTION:" directive.
    if "action:" in response_lower:
        action_type = None
        question = None
        test_name = None
        diagnosis = None

        for line in response.split("\n"):
            lowered = line.lower()
            if "action:" in lowered:
                action_part = line.split(":", 1)[1].strip().lower()
                if "question" in action_part:
                    action_type = "ask_question"
                elif "test" in action_part:
                    action_type = "order_test"
                elif "diagnosis" in action_part:
                    action_type = "submit_diagnosis"

            if "question:" in lowered:
                question = line.split(":", 1)[1].strip()
            elif "test:" in lowered:
                test_name = line.split(":", 1)[1].strip()
            elif "diagnosis:" in lowered:
                diagnosis = line.split(":", 1)[1].strip()

        if action_type:
            return {
                "action_type": action_type,
                "question": question,
                "test_name": test_name,
                "diagnosis": diagnosis,
            }

    # Fallback 1: treat the first line containing "?" as a question.
    if "question" in response_lower or "ask" in response_lower:
        for line in response.split("\n"):
            if "?" in line:
                return {
                    "action_type": "ask_question",
                    "question": line.strip(),
                    "test_name": None,
                    "diagnosis": None,
                }

    # Fallback 2: take the word immediately following a "test" token as the
    # test name (crude, but matches the original heuristic).
    if "test" in response_lower or "order" in response_lower:
        words = response.split()
        for i, word in enumerate(words):
            if "test" in word.lower() and i + 1 < len(words):
                return {
                    "action_type": "order_test",
                    "question": None,
                    "test_name": words[i + 1],
                    "diagnosis": None,
                }

    # Fallback 3: treat the whole response as a diagnosis, provided it
    # contains at least one substantive word (> 3 chars).  The original
    # spelled this as a loop that returned on the first long word, which
    # was equivalent to this any() check.
    if "diagnos" in response_lower and any(len(word) > 3 for word in response.split()):
        return {
            "action_type": "submit_diagnosis",
            "question": None,
            "test_name": None,
            "diagnosis": response.strip(),
        }

    return None
206
+
207
+
208
def build_conversation_history(episode_history: List[Dict]) -> List[Dict]:
    """Assemble the multi-turn chat message list from recorded episode turns."""
    messages: List[Dict] = [
        {"role": "system", "content": create_system_prompt()},
    ]

    for turn in episode_history:
        # Replay the agent's action, then the environment's feedback, so the
        # transcript alternates assistant/user as the chat API expects.
        agent_action = turn.get("agent_action")
        if agent_action:
            messages.append({"role": "assistant", "content": agent_action})

        feedback = turn.get("environment_feedback")
        if feedback:
            messages.append({"role": "user", "content": feedback})

    return messages
233
+
234
+
235
+ # ==============================================================================
236
+ # EPISODE EXECUTION
237
+ # ==============================================================================
238
+
239
async def run_episode_async(
    client: OpenAI,
    image_name: str,
    difficulty: str,
    task_name: str,
) -> Dict:
    """
    Run a single episode against the environment with an LLM policy.

    Args:
        client: OpenAI-compatible chat client.
        image_name: passed through to the env client (unused when a URL is set).
        difficulty: "easy" | "medium" | "hard".
        task_name: label used in the [START]/[END] log lines.

    Returns: {
        "task": task_name,
        "success": bool,
        "steps_taken": int,
        "total_reward": float,
        "episode_rewards": [float],
        "final_diagnosis_accuracy": float,
    }
    """

    log_start(task_name, "medical_diagnostic_env", MODEL_NAME)

    # BUG FIX: from_docker_image is a coroutine, so it must be awaited
    # BEFORE being used as an async context manager — ``async with <coroutine>``
    # fails at runtime because a coroutine object has no __aenter__.
    env = await DiagnosticEnv.from_docker_image(image_name=image_name, base_url=ENV_URL)
    async with env:
        obs_result = await env.reset(difficulty=difficulty)
        obs = obs_result.observation if hasattr(obs_result, 'observation') else obs_result

        episode_history = []
        episode_rewards = []
        step_count = 0

        # Initial environment message, injected only on the first turn.
        initial_message = f"Patient presentation: {obs.message}"

        while step_count < MAX_STEPS and not obs.done:
            step_count += 1

            # Rebuild the chat transcript from scratch each turn.
            conversation = [
                {"role": "system", "content": create_system_prompt()},
            ]
            for turn in episode_history:
                if turn.get("agent_thought"):
                    conversation.append({
                        "role": "assistant",
                        "content": f"Thinking: {turn['agent_thought']}\nAction: {turn['agent_action']}",
                    })
                if turn.get("environment_response"):
                    conversation.append({
                        "role": "user",
                        "content": turn["environment_response"],
                    })

            # Later steps see the presentation via the recorded history.
            if step_count == 1:
                conversation.append({"role": "user", "content": initial_message})

            try:
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=conversation,
                    temperature=TEMPERATURE,
                    max_tokens=MAX_TOKENS,
                )
                llm_response = response.choices[0].message.content

                action_dict = extract_action_from_response(llm_response)
                if not action_dict:
                    log_step(step_count, "parse_error", 0.0, False,
                             "Could not parse action from response")
                    break

                action = DiagnosticAction(
                    action_type=action_dict["action_type"],
                    question=action_dict.get("question"),
                    test_name=action_dict.get("test_name"),
                    diagnosis=action_dict.get("diagnosis"),
                )

                # Short human-readable action string for the [STEP] log line.
                action_str = f"{action.action_type}"
                if action.question:
                    action_str += f"('{action.question[:30]}...')"
                elif action.test_name:
                    action_str += f"('{action.test_name}')"
                elif action.diagnosis:
                    action_str += f"('{action.diagnosis[:40]}...')"

                step_result = await env.step(action)
                obs = step_result.observation if hasattr(step_result, 'observation') else step_result

                reward = obs.reward or 0.0
                episode_rewards.append(reward)

                log_step(step_count, action_str, reward, obs.done, None)

                episode_history.append({
                    "agent_thought": llm_response[:100],
                    "agent_action": action_str,
                    "environment_response": obs.message[:200],
                })

            except Exception as e:
                log_step(step_count, "error", 0.0, True, str(e)[:100])
                break

        # Pull the final grading from the server; best-effort only — any
        # failure just means "ungraded".  (Was a bare ``except:`` which also
        # swallowed KeyboardInterrupt/SystemExit.)
        try:
            state = await env.state()
            final_accuracy = state.final_accuracy if hasattr(state, 'final_accuracy') else 0.0
        except Exception:
            final_accuracy = 0.0

        total_reward = sum(episode_rewards)
        # Success = episode finished AND the final diagnosis scored above 0.3.
        success = obs.done and final_accuracy > 0.3

        log_end(success, step_count, final_accuracy, episode_rewards)

        return {
            "task": task_name,
            "success": success,
            "steps_taken": step_count,
            "total_reward": total_reward,
            "episode_rewards": episode_rewards,
            "final_diagnosis_accuracy": final_accuracy,
        }
384
+
385
+
386
+ # ==============================================================================
387
+ # MAIN ORCHESTRATION
388
+ # ==============================================================================
389
+
390
async def run_all_tasks() -> Dict:
    """Run the easy/medium/hard tasks sequentially and summarize results."""

    # Fail fast on missing configuration before touching the network.
    if not API_KEY:
        print("ERROR: API key not found. Set HF_TOKEN, API_KEY, or OPENAI_API_KEY.", flush=True)
        return {}

    if not ENV_URL:
        print("ERROR: ENV_URL is not set. Set ENV_URL to the environment WebSocket URL.", flush=True)
        return {}

    llm_client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)

    summary = {
        "timestamp": None,
        "model": MODEL_NAME,
        "environment": "medical_diagnostic_env",
        "tasks_completed": 0,
        "task_results": [],
        "overall_score": 0.0,
    }

    for idx, (task_name, difficulty) in enumerate(zip(TASK_NAMES, DIFFICULTY_LEVELS)):
        print(f"\n--- Task {idx+1}/3: {difficulty} difficulty ---", flush=True)

        try:
            outcome = await run_episode_async(
                llm_client,
                LOCAL_IMAGE_NAME,
                difficulty=difficulty,
                task_name=task_name,
            )
        except Exception as e:
            print(f"ERROR in task {task_name}: {str(e)}", flush=True)
        else:
            summary["task_results"].append(outcome)
            summary["tasks_completed"] += 1

    # Overall score = mean diagnostic accuracy over completed tasks.
    if summary["tasks_completed"] > 0:
        accuracies = [r["final_diagnosis_accuracy"] for r in summary["task_results"]]
        summary["overall_score"] = sum(accuracies) / len(accuracies)

    print("\n" + "=" * 60, flush=True)
    print("Baseline Inference Complete", flush=True)
    print(f"Tasks completed: {summary['tasks_completed']}/3", flush=True)
    print(f"Overall diagnostic accuracy: {summary['overall_score']:.3f}", flush=True)
    print("=" * 60, flush=True)

    return summary
445
+
446
+
447
def main():
    """Entry point: drive the async orchestration to completion."""
    asyncio.run(run_all_tasks())


if __name__ == "__main__":
    main()
models.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ models.py — Type-safe contracts for the Medical Diagnostic Environment.
3
+
4
+ These Pydantic models define the interface between the LLM agent and the environment:
5
+ - DiagnosticAction: What the agent sends (questions, tests, diagnoses)
6
+ - PatientObservation: What the agent receives (feedback, test results, progress)
7
+ - ClinicalState: Full episode state (for debugging, not sent to agent)
8
+ """
9
+
10
+ from typing import Optional, List, Dict
11
+ from openenv.core.env_server import Action, Observation, State
12
+ from pydantic import Field
13
+
14
class DiagnosticAction(Action):
    """
    Actions the LLM agent can take during diagnosis.

    The agent must choose one action per step:
    1. ask_question: Gather patient history
    2. order_test: Request diagnostic test results
    3. submit_diagnosis: Make final diagnosis (ends episode)

    Exactly one of the optional payload fields should be populated, matching
    the chosen ``action_type``. The environment does not reject unknown
    action types at parse time; it answers them with a small step penalty.
    """

    action_type: str  # "ask_question", "order_test", "submit_diagnosis"
    question: Optional[str] = None  # Used when action_type="ask_question"
    test_name: Optional[str] = None  # Used when action_type="order_test"
    diagnosis: Optional[str] = None  # Used when action_type="submit_diagnosis"
28
+
29
+
30
class PatientObservation(Observation):
    """
    What the agent observes after taking an action.

    Inherits from Observation:
    - done: bool → Is the episode over?
    - reward: Optional[float] → Reward signal

    Adds medical-specific fields:
    - message: Human-readable feedback
    - patient_response: Answer to question (if applicable)
    - test_result: Test outcome with interpretation
    - questions_asked: History of all questions
    - tests_completed: History of all completed tests
    - patient_data_revealed: What the agent has discovered so far
    """

    message: str  # Feedback from environment
    patient_response: Optional[str] = None  # Answer to a question asked
    test_result: Optional[Dict] = None  # {"test_name": "X", "result": "...", "interpretation": "..."}
    # default_factory gives every observation its own fresh list/dict — a
    # shared mutable default would leak history across observations.
    questions_asked: List[str] = Field(default_factory=list)
    tests_completed: List[str] = Field(default_factory=list)
    patient_data_revealed: Dict = Field(default_factory=dict)
    steps_taken: int = 0  # How many actions so far
    max_steps: int = 15  # Maximum steps allowed (mirrors the environment's MAX_STEPS)
55
+
56
+
57
class ClinicalState(State):
    """
    Complete internal state snapshot. Contains hidden information (diagnosis, true findings).
    Use for debugging only - NEVER send to agent.

    Inherits from State:
    - episode_id: str → Unique episode identifier
    - step_count: int → Current step number

    Adds clinical fields:
    - true_diagnosis: The correct diagnosis (hidden from agent)
    - patient_case: Case identifier
    - patient_details: Full patient information (hidden)
    - difficulty: easy|medium|hard
    """

    true_diagnosis: str = ""
    patient_case: str = ""
    # NOTE: the environment currently fills patient_id with the same case id
    # as patient_case (see MedicalDiagnosticEnvironment.state()).
    patient_id: str = ""
    patient_details: Dict = Field(default_factory=dict)
    difficulty: str = "easy"
    # default_factory avoids a shared mutable default across snapshots.
    questions_asked: List[str] = Field(default_factory=list)
    tests_completed: List[str] = Field(default_factory=list)
    final_diagnosis_submitted: Optional[str] = None  # None until a diagnosis is submitted
    final_accuracy: float = 0.0  # Accuracy of the submitted diagnosis, 0.0-1.0
pyproject.toml ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=68.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "medical-diagnostic-env"
7
+ version = "1.0.0"
8
+ description = "OpenEnv environment for medical diagnosis RL training"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+
12
+ authors = [
13
+ {name = "Team SYNAPSE", email = "synapse@example.com"}
14
+ ]
15
+
16
+ keywords = [
17
+ "reinforcement-learning",
18
+ "medical",
19
+ "diagnosis",
20
+ "healthcare",
21
+ "rl-training",
22
+ "llm",
23
+ "openenv"
24
+ ]
25
+
26
+ classifiers = [
27
+ "Development Status :: 4 - Beta",
28
+ "Intended Audience :: Developers",
29
+ "Intended Audience :: Science/Research",
30
+ "License :: OSI Approved :: MIT License",
31
+ "Programming Language :: Python :: 3",
32
+ "Programming Language :: Python :: 3.10",
33
+ "Programming Language :: Python :: 3.11",
34
+ "Programming Language :: Python :: 3.12",
35
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
36
+ ]
37
+
38
+ dependencies = [
39
+ "openenv-core>=0.2.3",
40
+ "fastapi>=0.104.0",
41
+ "uvicorn[standard]>=0.24.0",
42
+ "websockets>=16.0",
43
+ "pydantic>=2.5.0",
44
+ "pydantic-settings>=2.1.0",
45
+ "openai>=1.3.0",
46
+ "requests>=2.31.0",
47
+ ]
48
+
49
+ [project.optional-dependencies]
50
+ dev = [
51
+ "pytest>=7.0",
52
+ "pytest-asyncio>=0.21.0",
53
+ "black>=23.0",
54
+ "ruff>=0.1.0",
55
+ "mypy>=1.0",
56
+ ]
57
+
58
+ [project.urls]
59
+ Homepage = "https://github.com/meta-pytorch/OpenEnv"
60
+ Documentation = "https://meta-pytorch.org/OpenEnv/"
61
+ Repository = "https://github.com/meta-pytorch/OpenEnv"
62
+ Issues = "https://github.com/meta-pytorch/OpenEnv/issues"
63
+
64
+ [tool.setuptools]
65
+ packages = ["medical_diagnostic_env"]
66
+
67
+ [tool.black]
68
+ line-length = 100
69
+ target-version = ["py310", "py311", "py312"]
70
+
71
+ [tool.ruff]
72
+ line-length = 100
73
+ target-version = "py310"
74
+ select = ["E", "F", "W", "I"]
75
+
76
+ [tool.mypy]
77
+ python_version = "3.10"
78
+ check_untyped_defs = true
79
+ disallow_untyped_defs = false
80
+ warn_unused_ignores = true
server/Dockerfile ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Medical Diagnostic Environment - Production Dockerfile
#
# This Dockerfile containerizes the complete Medical Diagnostic Environment.
# It can be deployed to Docker Hub, GitHub Container Registry, or Hugging Face Spaces.

FROM python:3.11-slim

WORKDIR /app

# Set environment variables:
# - PYTHONUNBUFFERED: stream logs immediately (no stdout buffering)
# - PYTHONDONTWRITEBYTECODE: keep the image free of .pyc files
# - PIP_*: smaller layers, quieter installs
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1

# Install system dependencies (minimal; curl is required by HEALTHCHECK below)
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy and install Python dependencies first (for layer caching)
COPY server/requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt && \
    rm /tmp/requirements.txt

# Copy application code
COPY models.py .
COPY client.py .
COPY server/ ./server/

# Create __init__ files for Python packages
# (touch is a no-op on content if the files already exist from the COPY above)
RUN touch __init__.py && \
    touch server/__init__.py

# Health check - validates the server is running and responsive
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
  CMD curl -f http://localhost:8000/health || exit 1

# Expose port
EXPOSE 8000

# Run the FastAPI server with uvicorn.
# NOTE(review): --workers 4 forks independent processes, each holding its own
# in-memory environment state — confirm session affinity (or drop to a single
# worker) if stateful /reset and /step calls must land on the same process.
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
server/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Medical Diagnostic Environment Server - Package initialization
3
+ """
4
+
5
+ from .environment import MedicalDiagnosticEnvironment
6
+ from .medical_data import (
7
+ PATIENT_CASES,
8
+ calculate_question_reward,
9
+ calculate_test_reward,
10
+ calculate_diagnosis_accuracy,
11
+ get_patient_response,
12
+ )
13
+
14
+ __all__ = [
15
+ "MedicalDiagnosticEnvironment",
16
+ "PATIENT_CASES",
17
+ "calculate_question_reward",
18
+ "calculate_test_reward",
19
+ "calculate_diagnosis_accuracy",
20
+ "get_patient_response",
21
+ ]
server/app.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ server/app.py — FastAPI server for the Medical Diagnostic Environment.
3
+
4
+ This exposes the environment over WebSocket and HTTP using OpenEnv's built-in
5
+ create_fastapi_app helper. One line of meaningful code!
6
+
7
+ The helper automatically creates:
8
+ - /ws endpoint for WebSocket connections (stateful, for training)
9
+ - /reset, /step, /state endpoints (stateless, for testing)
10
+ - /health endpoint (for Docker health checks)
11
+ - /docs endpoint (auto-generated OpenAPI documentation)
12
+ """
13
+
14
+ from openenv.core.env_server import create_fastapi_app
15
+ import sys
16
+ import os
17
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
18
+
19
+ from models import DiagnosticAction, PatientObservation
20
+ from environment import MedicalDiagnosticEnvironment
21
+
22
+
23
+ # Create the environment instance
24
+ env = MedicalDiagnosticEnvironment()
25
+
26
+ # Create FastAPI app with all endpoints
27
+ app = create_fastapi_app(
28
+ env,
29
+ DiagnosticAction,
30
+ PatientObservation,
31
+ max_concurrent_envs=100, # Support up to 100 parallel training sessions
32
+ )
33
+
34
+ # Optional: Add custom middleware or endpoints here if needed
35
+ # (Most common use cases are already handled by create_fastapi_app)
36
+
37
+
38
+ if __name__ == "__main__":
39
+ import uvicorn
40
+ uvicorn.run(
41
+ "app:app",
42
+ host="0.0.0.0",
43
+ port=8000,
44
+ workers=4,
45
+ reload=False,
46
+ )
server/environment.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ server/environment.py — Core medical diagnostic environment logic.
3
+
4
+ This is the BRAIN of the Medical Diagnostic Environment. It:
5
+ 1. Manages patient cases and episode state
6
+ 2. Processes agent actions (questions, tests, diagnoses)
7
+ 3. Calculates rewards based on diagnostic quality
8
+ 4. Provides trajectory-based reward signals (not sparse)
9
+
10
+ Pure Python - no HTTP or WebSocket code here.
11
+ All logic is deterministic and reproducible.
12
+ """
13
+
14
+ import random
15
+ import uuid
16
+ from typing import Dict, List, Optional
17
+ from openenv.core.env_server import Environment
18
+
19
+ import sys
20
+ import os
21
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
22
+
23
+ from models import DiagnosticAction, PatientObservation, ClinicalState
24
+ from server.medical_data import (
25
+ PATIENT_CASES,
26
+ calculate_question_reward,
27
+ calculate_test_reward,
28
+ calculate_diagnosis_accuracy,
29
+ get_patient_response,
30
+ normalize_test_name,
31
+ )
32
+
33
+
34
class MedicalDiagnosticEnvironment(Environment):
    """
    Medical Diagnostic Environment for RL Training.

    Simulates doctor-patient interaction where an LLM agent must:
    1. Ask relevant clinical questions
    2. Order appropriate diagnostic tests
    3. Make accurate diagnoses

    The environment provides rich reward signals throughout the trajectory:
    - +0.05 per relevant question asked
    - +0.10 per informative test ordered
    - +1.0 for correct final diagnosis
    - Penalizes inefficient or irrelevant actions

    This is NOT a sparse reward environment - the agent sees meaningful progress
    at each step, which is crucial for learning.
    """

    # Allow multiple parallel training sessions (consumed by the server layer).
    SUPPORTS_CONCURRENT_SESSIONS = True

    # Episode configuration
    MAX_STEPS = 15  # Maximum actions per episode
    DIFFICULTY_LEVELS = ["easy", "medium", "hard"]

    def __init__(self):
        """Initialize environment state (all per-episode fields start empty)."""
        super().__init__()

        # Episode state variables
        self._episode_id: str = ""
        self._case_id: str = ""
        self._difficulty: str = ""
        self._step_count: int = 0
        self._total_reward: float = 0.0

        # Patient interaction tracking
        self._questions_asked: List[str] = []
        self._tests_ordered: List[str] = []
        self._test_results: Dict = {}
        self._diagnosis_submitted: Optional[str] = None
        self._final_accuracy: float = 0.0

        # Episode status
        self._done: bool = False
        # Per-component reward bookkeeping for get_episode_summary().
        self._episode_reward_breakdown: Dict = {
            "question_rewards": 0.0,
            "test_rewards": 0.0,
            "diagnosis_reward": 0.0,
            "efficiency_penalty": 0.0,
        }

    @property
    def current_case_id(self) -> str:
        """Current patient case identifier."""
        return self._case_id

    @property
    def current_difficulty(self) -> str:
        """Current episode difficulty level."""
        return self._difficulty

    # ─────────────────────────────────────────────────────────────────────
    # Core API Methods
    # ─────────────────────────────────────────────────────────────────────

    def reset(self, difficulty: str = "easy", **kwargs) -> PatientObservation:
        """
        Reset the environment for a new diagnostic episode.

        Args:
            difficulty: "easy", "medium", or "hard" (controls case selection).
                Unknown values silently fall back to "easy".

        Returns:
            Initial PatientObservation for the agent to read
        """
        # Initialize episode
        self._episode_id = str(uuid.uuid4())
        self._difficulty = difficulty if difficulty in self.DIFFICULTY_LEVELS else "easy"
        self._case_id = self._select_case_by_difficulty(self._difficulty)
        self._step_count = 0
        self._total_reward = 0.0

        # Reset tracking
        self._questions_asked = []
        self._tests_ordered = []
        self._test_results = {}
        self._diagnosis_submitted = None
        self._final_accuracy = 0.0
        self._done = False
        self._episode_reward_breakdown = {
            "question_rewards": 0.0,
            "test_rewards": 0.0,
            "diagnosis_reward": 0.0,
            "efficiency_penalty": 0.0,
        }

        # Get case information
        case = PATIENT_CASES[self._case_id]

        # Create initial observation (only the public presentation is revealed)
        initial_message = (
            f"Patient presents with: {case['presentation']}\n"
            f"Age: {case['age']}, Gender: {case['gender']}\n"
            f"You have up to {self.MAX_STEPS} steps to diagnose this patient.\n"
            f"Please start by asking questions or ordering tests."
        )

        return PatientObservation(
            done=False,
            reward=0.0,
            message=initial_message,
            patient_response=None,
            test_result=None,
            questions_asked=[],
            tests_completed=[],
            patient_data_revealed={
                "age": case["age"],
                "gender": case["gender"],
                "presentation": case["presentation"],
            },
            steps_taken=0,
            max_steps=self.MAX_STEPS,
        )

    def step(self, action: DiagnosticAction, **kwargs) -> PatientObservation:
        """
        Process one diagnostic action (question, test, or diagnosis).

        Returns immediate reward and next observation. Unknown action types
        are penalized (-0.05) rather than raising.
        """
        if self._done:
            return self._create_done_observation(
                message="Episode already ended. Call reset() to start a new case."
            )

        self._step_count += 1
        step_reward = 0.0
        message = ""
        patient_response = None
        test_result = None

        # ── Process action based on type ──
        if action.action_type == "ask_question":
            step_reward, message, patient_response = self._handle_question(action.question)

        elif action.action_type == "order_test":
            step_reward, message, test_result = self._handle_test(action.test_name)

        elif action.action_type == "submit_diagnosis":
            # NOTE(review): action.diagnosis may be None here (it defaults to
            # None on DiagnosticAction) — confirm the scorer tolerates that.
            step_reward, message = self._handle_diagnosis(action.diagnosis)
            self._done = True  # submitting a diagnosis always ends the episode

        else:
            message = f"Unknown action type: {action.action_type}"
            step_reward = -0.05

        # Accumulate rewards
        self._total_reward += step_reward

        # Check if episode should end (hard cap on steps)
        if self._step_count >= self.MAX_STEPS and not self._done:
            message += f"\nMax steps reached. Episode ending."
            self._done = True

        # Get current case for patient data revelation
        case = PATIENT_CASES[self._case_id]

        return PatientObservation(
            done=self._done,
            reward=step_reward,
            message=message,
            patient_response=patient_response,
            test_result=test_result,
            questions_asked=self._questions_asked.copy(),
            tests_completed=self._tests_ordered.copy(),
            patient_data_revealed=self._build_patient_data_view(case),
            steps_taken=self._step_count,
            max_steps=self.MAX_STEPS,
        )

    def state(self) -> ClinicalState:
        """
        Return complete internal state (includes hidden information).
        Used for debugging only - NEVER send to agent.
        """
        case = PATIENT_CASES.get(self._case_id, {})

        return ClinicalState(
            episode_id=self._episode_id,
            step_count=self._step_count,
            true_diagnosis=case.get("true_diagnosis", ""),
            patient_case=self._case_id,
            patient_id=self._case_id,
            patient_details=case,
            difficulty=self._difficulty,
            questions_asked=self._questions_asked.copy(),
            tests_completed=self._tests_ordered.copy(),
            final_diagnosis_submitted=self._diagnosis_submitted,
            final_accuracy=self._final_accuracy,
        )

    # ─────────────────────────────────────────────────────────────────────
    # Action Processing
    # ─────────────────────────────────────────────────────────────────────

    def _handle_question(self, question: Optional[str]) -> tuple:
        """
        Process a question about the patient.

        Returns:
            (reward, message, patient_response)
        """
        if not question or not isinstance(question, str) or not question.strip():
            message = "No valid question was provided. Please ask a clinical question."
            return -0.05, message, None

        # Calculate reward for asking this question
        reward = calculate_question_reward(self._case_id, question)

        # Record question
        self._questions_asked.append(question)
        self._episode_reward_breakdown["question_rewards"] += reward

        # Get patient response
        response = get_patient_response(self._case_id, question)

        message = f"Patient response: {response}"
        # NOTE(review): calculate_question_reward returns 0.01/0.05/0.08, so
        # the 0.00 branch below appears unreachable — confirm the intended
        # mapping between reward tiers and feedback text.
        if reward == 0.00:
            message += " (Question may not be directly relevant)"
        elif reward == 0.01:
            message += " (Somewhat relevant question)"
        else:
            message += " (Good clinical question!)"

        return reward, message, response

    def _handle_test(self, test_name: Optional[str]) -> tuple:
        """
        Process a test order.

        Returns:
            (reward, message, test_result_dict)
        """
        if not test_name or not isinstance(test_name, str) or not test_name.strip():
            message = "No valid test name was provided. Please order a valid diagnostic test."
            return -0.05, message, None

        # Calculate reward for ordering this test
        reward = calculate_test_reward(self._case_id, test_name)

        # Get case data
        case = PATIENT_CASES[self._case_id]

        # Try to find matching test result via loose bidirectional substring
        # matching against the case's available test keys.
        test_result_data = None
        matched_test_key = None

        test_lower = normalize_test_name(test_name)
        for test_key, result in case.get("test_results", {}).items():
            if test_key.lower() == test_lower or test_key.lower() in test_lower or test_lower in test_key.lower():
                test_result_data = result
                matched_test_key = test_key
                break

        if test_result_data is None:
            # Unavailable test overrides any relevance reward computed above.
            message = f"Test '{test_name}' not available for this patient or unavailable in this setting."
            reward = -0.02
            return reward, message, None

        # Record test
        self._tests_ordered.append(matched_test_key)
        self._test_results[matched_test_key] = test_result_data
        self._episode_reward_breakdown["test_rewards"] += reward

        # Format test result for agent
        # NOTE(review): assumes test_result_data is dict-like (.get) — confirm
        # PATIENT_CASES never stores plain-string test results.
        test_result_dict = {
            "test_name": matched_test_key,
            "result": str(test_result_data),
            "interpretation": test_result_data.get("interpretation", test_result_data.get("finding", ""))
        }

        message = f"Test result received for {matched_test_key}:\n{test_result_dict['interpretation']}"
        if reward == 0.10:
            message += " (Excellent diagnostic test!)"
        elif reward == 0.05:
            message += " (Useful supporting test)"
        else:
            message += " (Test ordered but may be less relevant)"

        return reward, message, test_result_dict

    def _handle_diagnosis(self, diagnosis: str) -> tuple:
        """
        Process final diagnosis submission.

        Returns:
            (reward, message)
        """
        # Calculate diagnostic accuracy
        accuracy = calculate_diagnosis_accuracy(self._case_id, diagnosis)
        self._final_accuracy = accuracy
        self._diagnosis_submitted = diagnosis

        # Create diagnosis reward (not just accuracy, but also process quality)
        case = PATIENT_CASES[self._case_id]
        true_diagnosis = case["true_diagnosis"]

        # Map accuracy bands to reward + feedback (>= comparisons, high first).
        if accuracy >= 0.95:
            reward = 1.0
            message = f"Correct diagnosis: {diagnosis}"
        elif accuracy >= 0.7:
            reward = accuracy
            message = f"Acceptable diagnosis: {diagnosis}. True: {true_diagnosis}"
        elif accuracy >= 0.3:
            reward = accuracy
            message = f"Partially correct. True: {true_diagnosis}"
        else:
            reward = 0.0
            message = f"Incorrect. True: {true_diagnosis}"

        self._episode_reward_breakdown["diagnosis_reward"] = reward

        # Add efficiency feedback when more than 80% of the step budget is used.
        # NOTE(review): the penalty is recorded in the breakdown and mentioned
        # in the message but is NOT subtracted from the returned reward —
        # confirm whether that is intentional.
        if self._step_count > self.MAX_STEPS * 0.8:
            penalty = 0.1 * (self._step_count / self.MAX_STEPS - 0.8)
            self._episode_reward_breakdown["efficiency_penalty"] = penalty
            message += f"\n(Efficiency penalty: -{penalty:.2f} for taking many steps)"

        return reward, message

    # ─────────────────────────────────────────────────────────────────────
    # Helper Methods
    # ─────────────────────────────────────────────────────────────────────

    def _select_case_by_difficulty(self, difficulty: str) -> str:
        # Pick a random case of the requested difficulty; fall back to the
        # whole pool if no case matches.
        matching_keys = [k for k, v in PATIENT_CASES.items() if v["difficulty"] == difficulty]
        if not matching_keys:
            matching_keys = list(PATIENT_CASES.keys())
        return random.choice(matching_keys)

    def _build_patient_data_view(self, case: Dict) -> Dict:
        """
        Build what the agent has learned about the patient so far.
        Only includes information revealed through questions/tests.
        """
        revealed = {
            "age": case.get("age"),
            "gender": case.get("gender"),
            "presentation": case.get("presentation"),
        }

        # Add findings based on questions asked: a hidden finding is revealed
        # when its key (underscores → spaces) appears verbatim in a question.
        findings = case.get("hidden_findings", {})
        for question in self._questions_asked:
            q_lower = question.lower()
            for finding, value in findings.items():
                if finding.lower().replace("_", " ") in q_lower:
                    revealed[f"finding_{finding}"] = value

        # Add test results
        if self._test_results:
            revealed["test_results"] = self._test_results

        return revealed

    def _create_done_observation(self, message: str) -> PatientObservation:
        """Create a terminal observation."""
        return PatientObservation(
            done=True,
            reward=0.0,
            message=message,
            patient_response=None,
            test_result=None,
            questions_asked=self._questions_asked.copy(),
            tests_completed=self._tests_ordered.copy(),
            patient_data_revealed={},
            steps_taken=self._step_count,
            max_steps=self.MAX_STEPS,
        )

    def get_episode_summary(self) -> Dict:
        """
        Return a summary of the episode for logging/evaluation.
        """
        case = PATIENT_CASES.get(self._case_id, {})
        # NOTE(review): "accuracy"/"diagnostic_accuracy" and "steps"/
        # "steps_taken" are duplicated — presumably for backward-compatible
        # consumers; confirm before removing either alias.
        return {
            "episode_id": self._episode_id,
            "case_id": self._case_id,
            "difficulty": self._difficulty,
            "true_diagnosis": case.get("true_diagnosis", ""),
            "submitted_diagnosis": self._diagnosis_submitted,
            "accuracy": self._final_accuracy,
            "diagnostic_accuracy": self._final_accuracy,
            "total_reward": self._total_reward,
            "steps": self._step_count,
            "steps_taken": self._step_count,
            "max_steps": self.MAX_STEPS,
            "questions_asked": len(self._questions_asked),
            "tests_ordered": len(self._tests_ordered),
            "reward_breakdown": self._episode_reward_breakdown,
        }
server/medical_data.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import random
from typing import Dict, List, Tuple, Optional
from functools import lru_cache

# Optional dependency: Hugging Face `datasets` is only needed when real
# dataset loading is enabled via OPENENV_USE_HF_DATASETS.
try:
    from datasets import load_dataset
    _DATASETS_AVAILABLE = True
except ImportError:
    _DATASETS_AVAILABLE = False

# Feature flag: opt in to downloading real HF datasets (off by default).
USE_HF_DATASETS = os.getenv("OPENENV_USE_HF_DATASETS", "false").lower() in ("1", "true", "yes")
# Optional RNG seed for reproducible case sampling; non-integer values are
# silently treated as "no seed".
DATASET_SEED = os.getenv("OPENENV_DATASET_SEED")
if DATASET_SEED is not None:
    try:
        DATASET_SEED = int(DATASET_SEED)
    except ValueError:
        DATASET_SEED = None
# ==============================================================================
# LOAD REAL DATASETS FROM HUGGING FACE
# ==============================================================================
21
+
22
def load_medical_datasets() -> Dict:
    """Fetch the raw Hugging Face medical datasets.

    Returns a dict with keys "medmcqa" and "medqa" on success, or an empty
    dict when loading is disabled, the `datasets` package is missing, or
    any download fails.
    """
    if not USE_HF_DATASETS:
        return {}

    if not _DATASETS_AVAILABLE:
        print("Warning: datasets package not installed; skipping Hugging Face dataset loading.")
        return {}

    try:
        loaded = {
            # MedMCQA: medical multiple-choice questions
            "medmcqa": load_dataset("medmcqa", split="train"),
            # BigBio MedQA: medical question answering
            "medqa": load_dataset("bigbio/med_qa", split="train"),
        }
    except Exception as e:
        print(f"Warning: Could not load datasets: {e}. No dataset cases loaded.")
        return {}
    return loaded
46
+ # ==============================================================================
47
+ # INNOVATIVE REWARD SYSTEM USING LLM JUDGMENT
48
+ # ==============================================================================
49
+
50
+
51
# Per-case keyword lists used to score how relevant a clinical question is.
RELEVANT_QUESTION_KEYWORDS = {
    "easy_flu": ["fever", "cough", "ache", "fatigue", "onset", "symptom", "contact", "vaccine", "temperature"],
    "easy_uti": ["urination", "burning", "pain", "frequent", "abdominal", "bladder", "kidney"],
    "medium_pneumonia": ["cough", "fever", "breath", "chest", "pain", "smoking", "sputum", "productive"],
    "medium_appendicitis": ["abdominal", "pain", "nausea", "vomiting", "rebound", "right lower quadrant", "appetite", "fever"],
    "hard_endocarditis": ["fever", "murmur", "drug", "iv", "dental", "hemorrhage", "splinter", "heart"],
    "hard_meningitis": ["headache", "neck", "stiff", "fever", "photophobia", "confusion", "vomit", "seizure"],
}


def calculate_question_reward(case_id: str, question: str) -> float:
    """Score a clinical question by keyword relevance to the case.

    Returns 0.08 for two or more keyword hits, 0.05 for exactly one, and
    0.01 otherwise (including unknown case ids and empty/None questions).
    """
    # Fix: guard against None/non-string questions — previously this raised
    # AttributeError on question.lower(); treat them as irrelevant instead.
    if not question or not isinstance(question, str):
        return 0.01
    keywords = RELEVANT_QUESTION_KEYWORDS.get(case_id, [])
    q_lower = question.lower()
    matches = sum(1 for kw in keywords if kw in q_lower)
    if matches >= 2:
        return 0.08
    if matches == 1:
        return 0.05
    return 0.01
67
+
68
# Free-text → canonical test-key mapping (lookups are lowercase).
TEST_NAME_ALIASES = {
    "complete blood count": "cbc",
    "cbc": "cbc",
    "urinalysis": "urinalysis",
    "urine culture": "urine_culture",
    "blood cultures": "blood_cultures",
    "echocardiogram": "echocardiogram",
    "ct head": "ct_head",
    "ct scan": "ct_head",
    "chest xray": "chest_xray",
    "chest radiograph": "chest_xray",
    "sputum culture": "sputum_culture",
    "rapid flu test": "rapid_flu_test",
    "flu test": "rapid_flu_test",
}


def normalize_test_name(test_name: Optional[str]) -> str:
    """Map a free-text test name onto its canonical lowercase key.

    Non-string or empty input yields "". Names without a known alias are
    returned stripped and lowercased as-is.
    """
    if not isinstance(test_name, str) or not test_name:
        return ""
    key = test_name.strip().lower()
    return TEST_NAME_ALIASES.get(key, key)
90
+
91
+
92
def calculate_test_reward(case_id: str, test_name: Optional[str]) -> float:
    """
    Calculate reward for ordering a diagnostic test.

    Returns 0.10 for high-value tests (blood work, flu testing) that the
    case can actually run, 0.05 for other available tests, and -0.02 for
    invalid, unavailable, or irrelevant orders.
    """
    if not isinstance(test_name, str) or not test_name:
        return -0.02

    normalized = normalize_test_name(test_name)
    available = PATIENT_CASES.get(case_id, {}).get("test_results", {})

    # The case must offer a matching test (loose bidirectional substring match).
    is_available = any(
        key.lower() in normalized or normalized in key.lower()
        for key in available
    )
    if not is_available:
        return -0.02

    # High-value categories earn the top reward; everything else is "useful".
    high_value_terms = ("cbc", "blood", "flu", "influenza")
    if any(term in normalized for term in high_value_terms):
        return 0.10
    return 0.05
119
+
120
def calculate_diagnosis_accuracy(case_id: str, submitted: str) -> float:
    """Score a submitted diagnosis against the case's ground truth.

    Returns 1.0 for an exact or accepted-alias match, a ground-truth
    word-overlap fraction capped at 0.7 for partial matches, and 0.0 for
    empty/None submissions.
    """
    # Fix: an empty/None submission previously scored 1.0, because
    # `"" in acceptable.lower()` is always True in the alias check below
    # (and None crashed on .lower()). Empty input must never score.
    if not submitted or not isinstance(submitted, str) or not submitted.strip():
        return 0.0

    case = PATIENT_CASES.get(case_id, {})
    s = submitted.lower().strip()
    true = case.get("true_diagnosis", "").lower()

    # Exact match with the canonical diagnosis.
    if s == true:
        return 1.0

    # Accept any listed alternative phrasing (bidirectional substring match).
    for acceptable in case.get("correct_diagnoses", []):
        if acceptable.lower() in s or s in acceptable.lower():
            return 1.0

    # Partial credit: fraction of ground-truth words present, capped at 0.7.
    true_words = set(true.split())
    sub_words = set(s.split())
    overlap = len(true_words & sub_words) / max(len(true_words), 1)
    return round(min(overlap, 0.7), 2)
132
+
133
+ # ==============================================================================
134
+ # PATIENT CASES DATABASE (FORMATTED FROM REAL DATASETS)
135
+ # ==============================================================================
136
+
137
def format_medmcqa_to_case(entry: Dict) -> Dict:
    """Convert one raw MedMCQA record into the environment's case schema."""
    question = entry.get("question", "")
    options = entry.get("options", {})
    correct_answer = entry.get("answer", "")
    subject = entry.get("subject_name", "")  # retained for parity; not used yet

    # The multiple-choice options double as the candidate diagnoses.
    candidate_diagnoses = list(options.values())
    if correct_answer in options:
        true_diagnosis = options[correct_answer]
    elif candidate_diagnoses:
        true_diagnosis = candidate_diagnoses[0]
    else:
        true_diagnosis = "Unknown"

    # Synthesize an id when the record lacks one. The RNG call order matches
    # the original exactly: id fallback, then age, then gender.
    fallback_id = random.randint(1000, 9999)
    case_id = f"medmcqa_{entry.get('id', fallback_id)}"

    return {
        "case_id": case_id,
        "difficulty": "medium",  # MedMCQA entries default to medium
        "true_diagnosis": true_diagnosis,
        "age": random.randint(25, 75),
        "gender": random.choice(["Male", "Female"]),
        "presentation": f"Patient presents with: {question}",
        "hidden_findings": {},  # would need more processing
        "test_results": {},
        "correct_diagnoses": [true_diagnosis],
        "differential_diagnoses": candidate_diagnoses[:3],  # first 3 options
        "source": "medmcqa",
    }
+
165
def format_medqa_to_case(entry: Dict) -> Dict:
    """Format a BigBio MedQA entry into our case structure.

    Args:
        entry: Raw MedQA record with "question" and "answer" fields.

    Returns:
        A case dict in the same shape as STATIC_PATIENT_CASES values,
        tagged "hard" since MedQA questions are more complex.
    """
    import hashlib  # local: only needed for the stable case-id digest

    question = entry.get("question", "")
    answer = entry.get("answer", "")

    presentation = f"Medical question: {question}"

    # Built-in hash() is salted per process (PYTHONHASHSEED), so case ids
    # derived from it would change between runs; use a stable digest so the
    # same question always maps to the same case_id.
    stable_id = int(hashlib.md5(question.encode("utf-8")).hexdigest(), 16) % 10000

    return {
        "case_id": f"medqa_{stable_id}",
        "difficulty": "hard",  # MedQA is more complex
        "true_diagnosis": answer,
        "age": random.randint(30, 80),
        "gender": random.choice(["Male", "Female"]),
        "presentation": presentation,
        "hidden_findings": {},
        "test_results": {},
        "correct_diagnoses": [answer],
        "differential_diagnoses": [],
        "source": "medqa",
    }
185
+
186
def generate_patient_cases_from_datasets() -> Dict:
    """Generate patient cases from real Hugging Face datasets.

    Returns:
        Mapping of case keys ("easy_real_0", "medium_real_0", "hard_real_0",
        ...) to case dicts. Empty when no datasets could be loaded.
    """
    cases = {}
    datasets = load_medical_datasets()
    if not datasets:
        return cases

    # Deterministic sampling when a seed is configured (note: reseeds the
    # module-wide RNG).
    if DATASET_SEED is not None:
        random.seed(DATASET_SEED)

    if "medmcqa" in datasets:
        medmcqa_data = datasets["medmcqa"]
        # Draw the easy and medium indices in a single sample so the two
        # tiers never reuse the same underlying question (independent draws
        # could pick overlapping indices).
        n_easy = min(3, len(medmcqa_data))
        n_medium = min(2, max(len(medmcqa_data) - n_easy, 0))
        picked = random.sample(range(len(medmcqa_data)), n_easy + n_medium)
        for i, idx in enumerate(picked[:n_easy]):
            case = format_medmcqa_to_case(medmcqa_data[idx])
            case["difficulty"] = "easy"
            cases[f"easy_real_{i}"] = case
        for i, idx in enumerate(picked[n_easy:]):
            case = format_medmcqa_to_case(medmcqa_data[idx])
            case["difficulty"] = "medium"
            cases[f"medium_real_{i}"] = case

    # Generate hard cases from MedQA.
    if "medqa" in datasets:
        medqa_data = datasets["medqa"]
        hard_indices = random.sample(range(len(medqa_data)), min(2, len(medqa_data)))
        for i, idx in enumerate(hard_indices):
            case = format_medqa_to_case(medqa_data[idx])
            case["difficulty"] = "hard"
            cases[f"hard_real_{i}"] = case

    return cases
227
+
228
# Hand-authored fallback case library. Each case carries the fields the
# environment relies on: presentation text shown at reset, hidden findings
# revealed by questioning, orderable test results (result + interpretation),
# the accepted diagnosis strings, and plausible differentials.
STATIC_PATIENT_CASES = {
    # Easy: classic viral syndrome confirmed by a single rapid test.
    "easy_flu": {
        "case_id": "easy_flu",
        "difficulty": "easy",
        "true_diagnosis": "Seasonal Influenza",
        "age": 28, "gender": "Female",
        "presentation": "Patient presents with sudden fever (38.9°C), body aches, headache, fatigue, and dry cough for 2 days. No shortness of breath.",
        "hidden_findings": {"fever": "38.9°C", "duration": "2 days", "onset": "sudden"},
        "test_results": {
            "rapid_flu_test": {"result": "Positive for Influenza A", "interpretation": "Positive Influenza A — confirms influenza diagnosis"},
            "cbc": {"result": "WBC 9.2, lymphocytosis", "interpretation": "Mild lymphocytosis consistent with viral infection"},
        },
        "correct_diagnoses": ["Seasonal Influenza", "Influenza A", "Flu"],
        "differential_diagnoses": ["COVID-19", "Common Cold", "Strep Throat"],
    },
    # Easy: textbook lower urinary tract infection.
    "easy_uti": {
        "case_id": "easy_uti",
        "difficulty": "easy",
        "true_diagnosis": "Urinary Tract Infection",
        "age": 35, "gender": "Female",
        "presentation": "Patient presents with frequent urination, burning sensation during urination, and lower abdominal pain for 3 days.",
        "hidden_findings": {"frequency": "frequent urination", "pain": "burning during urination", "duration": "3 days"},
        "test_results": {
            "urinalysis": {"result": "Positive for nitrites, leukocytes >10", "interpretation": "Urinalysis shows signs of bacterial infection consistent with UTI"},
            "urine_culture": {"result": "E. coli >100,000 CFU/mL", "interpretation": "Urine culture confirms E. coli urinary tract infection"},
        },
        "correct_diagnoses": ["Urinary Tract Infection", "UTI", "Bladder Infection"],
        "differential_diagnoses": ["Cystitis", "Pyelonephritis", "Vaginitis"],
    },
    # Medium: requires imaging plus labs to separate from bronchitis/PE.
    "medium_pneumonia": {
        "case_id": "medium_pneumonia",
        "difficulty": "medium",
        "true_diagnosis": "Community-Acquired Pneumonia",
        "age": 45, "gender": "Male",
        "presentation": "Patient presents with productive cough, fever (39.2°C), shortness of breath, and right-sided chest pain for 5 days. Smoker with 20 pack-year history.",
        "hidden_findings": {"cough": "productive", "fever": "39.2°C", "breathing": "shortness of breath", "smoking": "20 pack-years"},
        "test_results": {
            "chest_xray": {"result": "Right lower lobe consolidation", "interpretation": "Chest X-ray shows consolidation in right lower lobe consistent with pneumonia"},
            "cbc": {"result": "WBC 14.5, neutrophilia", "interpretation": "Elevated white blood cell count with neutrophilia suggesting bacterial infection"},
            "sputum_culture": {"result": "Streptococcus pneumoniae", "interpretation": "Sputum culture positive for Streptococcus pneumoniae"},
        },
        "correct_diagnoses": ["Community-Acquired Pneumonia", "Pneumonia", "Bacterial Pneumonia"],
        "differential_diagnoses": ["Bronchitis", "Pulmonary Embolism", "Lung Cancer"],
    },
    # Hard: diagnosis hinges on combining risk factors with blood cultures
    # and echocardiography.
    "hard_endocarditis": {
        "case_id": "hard_endocarditis",
        "difficulty": "hard",
        "true_diagnosis": "Infective Endocarditis",
        "age": 55, "gender": "Male",
        "presentation": "Patient with history of IV drug use presents with fever (38.8°C), new heart murmur, and splinter hemorrhages. Recent dental procedure 2 weeks ago.",
        "hidden_findings": {"drug_use": "IV drug user", "murmur": "new heart murmur", "hemorrhages": "splinter hemorrhages", "dental": "recent dental procedure"},
        "test_results": {
            "blood_cultures": {"result": "Staphylococcus aureus in 3/3 bottles", "interpretation": "Blood cultures positive for Staphylococcus aureus in multiple bottles"},
            "echocardiogram": {"result": "Vegetation on aortic valve", "interpretation": "Echocardiogram shows vegetation on aortic valve consistent with endocarditis"},
            "cbc": {"result": "WBC 12.8, anemia", "interpretation": "Elevated white blood cells with anemia of chronic disease"},
        },
        "correct_diagnoses": ["Infective Endocarditis", "Endocarditis", "Bacterial Endocarditis"],
        "differential_diagnoses": ["Sepsis", "Acute Rheumatic Fever", "Myocardial Infarction"],
    },
    # Medium: classic surgical abdomen with an imaging-confirmed answer.
    "medium_appendicitis": {
        "case_id": "medium_appendicitis",
        "difficulty": "medium",
        "true_diagnosis": "Acute Appendicitis",
        "age": 23, "gender": "Female",
        "presentation": "Patient presents with right lower quadrant abdominal pain, nausea, anorexia, and low-grade fever for 24 hours.",
        "hidden_findings": {"pain": "right lower quadrant", "nausea": "yes", "anorexia": "yes", "fever": "low-grade"},
        "test_results": {
            "abdominal_ultrasound": {"result": "Enlarged appendix with periappendiceal fluid", "interpretation": "Findings are consistent with acute appendicitis"},
            "cbc": {"result": "WBC 13.4, neutrophilia", "interpretation": "Elevated white blood cell count with neutrophils suggests acute inflammation"},
            "urinalysis": {"result": "Trace leukocytes", "interpretation": "Urinalysis slightly abnormal but not diagnostic"},
        },
        "correct_diagnoses": ["Acute Appendicitis", "Appendicitis"],
        "differential_diagnoses": ["Ovarian Cyst", "Ectopic Pregnancy", "Gastroenteritis"],
    },
    # Hard: time-critical diagnosis confirmed by lumbar puncture.
    "hard_meningitis": {
        "case_id": "hard_meningitis",
        "difficulty": "hard",
        "true_diagnosis": "Bacterial Meningitis",
        "age": 34, "gender": "Male",
        "presentation": "Patient presents with severe headache, neck stiffness, fever, photophobia, and confusion over the last 12 hours.",
        "hidden_findings": {"headache": "severe", "neck": "stiff", "fever": "high", "photophobia": "yes"},
        "test_results": {
            "lumbar_puncture": {"result": "Cloudy CSF with neutrophil predominance", "interpretation": "CSF findings are consistent with bacterial meningitis"},
            "blood_cultures": {"result": "Gram-positive cocci in pairs", "interpretation": "Blood cultures positive for likely Streptococcus pneumoniae"},
            "ct_head": {"result": "No mass effect or hemorrhage", "interpretation": "CT head is unremarkable prior to lumbar puncture"},
        },
        "correct_diagnoses": ["Bacterial Meningitis", "Meningitis"],
        "differential_diagnoses": ["Viral Meningitis", "Migraine", "Subarachnoid Hemorrhage"],
    },
}
318
+
319
# Optionally augment the static library with cases built from HF datasets.
real_cases = generate_patient_cases_from_datasets() if USE_HF_DATASETS else {}
# Merge: static cases are always available, real cases supplement
# (later mapping wins on a key collision, though the key namespaces differ).
PATIENT_CASES = {**STATIC_PATIENT_CASES, **real_cases}
322
+
323
+
324
+ # ==============================================================================================================================
325
+ # PATIENT RESPONSE GENERATION
326
+ # ==============================================================================================================================
327
+
328
def get_patient_response(case_id: str, question: str) -> str:
    """
    Generate a patient response to a question.

    For dataset-driven cases, responses are generic but clinically plausible.
    """
    q = question.lower()
    # Keyword groups are checked in priority order; the first match wins.
    canned_replies = (
        (("pain",), "Yes, I am experiencing pain in that area."),
        (("fever", "temperature"), "I feel warm and may have a fever."),
        (("nausea", "vomit"), "Yes, I am nauseated and may vomit."),
        (("cough", "breath"), "I have some coughing and breathing discomfort."),
        (("symptom", "feel"), "I have concerning symptoms right now."),
    )
    for keywords, reply in canned_replies:
        if any(word in q for word in keywords):
            return reply
    return "I'm not sure about that. Can you ask in a different way?"
server/requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi==0.104.1
2
+ uvicorn[standard]==0.24.0
3
+ openenv-core==0.2.3
4
+ websockets==16.0
5
+ pydantic==2.5.0
6
+ pydantic-settings==2.1.0
7
+ openai>=2.7.2
8
+ requests==2.31.0
9
+ datasets==2.14.5
tests/test_environment.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unit tests for Medical Diagnostic Environment
3
+
4
+ Run with: python -m pytest tests/test_environment.py -v
5
+ """
6
+
7
+ import pytest
8
+ from server.environment import MedicalDiagnosticEnvironment
9
+ from server.medical_data import (
10
+ PATIENT_CASES,
11
+ calculate_question_reward,
12
+ calculate_test_reward,
13
+ calculate_diagnosis_accuracy,
14
+ )
15
+ from models import DiagnosticAction
16
+
17
+
18
class TestMedicalDiagnosticEnvironment:
    """Test suite for MedicalDiagnosticEnvironment"""

    @pytest.fixture
    def env(self):
        """Create a fresh environment for each test"""
        # Fresh instance per test so episodes cannot leak state between tests.
        return MedicalDiagnosticEnvironment()

    def test_environment_initialization(self, env):
        """Test that environment initializes correctly"""
        assert env is not None
        assert hasattr(env, "reset")
        assert hasattr(env, "step")
        assert hasattr(env, "state")

    def test_reset_easy(self, env):
        """Test reset with easy difficulty"""
        observation = env.reset(difficulty="easy")
        assert observation is not None
        assert hasattr(observation, "message")
        # The opening message should contain the case presentation text.
        assert "presentation" in observation.message.lower() or "patient" in observation.message.lower()
        assert env.current_case_id is not None
        assert env.current_difficulty == "easy"

    def test_reset_medium(self, env):
        """Test reset with medium difficulty"""
        observation = env.reset(difficulty="medium")
        assert observation is not None
        assert env.current_difficulty == "medium"

    def test_reset_hard(self, env):
        """Test reset with hard difficulty"""
        observation = env.reset(difficulty="hard")
        assert observation is not None
        assert env.current_difficulty == "hard"

    def test_ask_question_action(self, env):
        """Test asking a question"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="ask_question",
            question="Does the patient have a fever?"
        )
        result = env.step(action)
        assert result is not None
        assert hasattr(result, "reward")
        assert result.reward >= 0  # Questions give non-negative reward

    def test_order_test_action(self, env):
        """Test ordering a test"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="order_test",
            test_name="Complete Blood Count"
        )
        result = env.step(action)
        assert result is not None
        assert hasattr(result, "reward")
        assert result.reward >= 0  # Tests give non-negative reward

    def test_submit_diagnosis_action(self, env):
        """Test submitting a diagnosis"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis",
            diagnosis="Common Flu"
        )
        result = env.step(action)
        assert result is not None
        assert hasattr(result, "reward")
        assert result.done is True  # Episode should end on diagnosis

    def test_max_steps_enforcement(self, env):
        """Test that episodes end after max steps"""
        env.reset(difficulty="easy")
        # Keep asking questions until the environment terminates the episode.
        for _ in range(15):  # Max 15 steps
            action = DiagnosticAction(
                action_type="ask_question",
                question="Test question"
            )
            result = env.step(action)
            if result.done:
                break
        assert result.done is True

    def test_episode_summary(self, env):
        """Test episode summary generation"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis",
            diagnosis="Test Diagnosis"
        )
        env.step(action)
        # Summary must expose all the bookkeeping fields callers rely on.
        summary = env.get_episode_summary()
        assert summary is not None
        assert "case_id" in summary
        assert "difficulty" in summary
        assert "accuracy" in summary
        assert "total_reward" in summary
        assert "steps" in summary

    def test_state_property(self, env):
        """Test the state property"""
        env.reset(difficulty="easy")
        state = env.state
        assert state is not None
        assert hasattr(state, "patient_id")
        assert hasattr(state, "step_count")
        assert hasattr(state, "true_diagnosis")

    def test_concurrent_sessions(self):
        """Test that environment supports concurrent sessions"""
        env = MedicalDiagnosticEnvironment()
        assert env.SUPPORTS_CONCURRENT_SESSIONS is True

    def test_multiple_episodes(self, env):
        """Test running multiple episodes"""
        # The same environment instance must support back-to-back resets.
        for difficulty in ["easy", "medium", "hard"]:
            observation = env.reset(difficulty=difficulty)
            assert observation is not None
            assert env.current_difficulty == difficulty
139
+
140
+
141
class TestMedicalData:
    """Test suite for medical data functions"""

    def test_question_reward_calculation(self):
        """Test question reward calculation"""
        # This is case-specific, so we just verify the function works
        case_id = next(iter(PATIENT_CASES))
        reward = calculate_question_reward(
            case_id=case_id,
            question="Does the patient have a fever?"
        )
        assert 0.0 <= reward <= 1.0

    def test_test_reward_calculation(self):
        """Test test reward calculation"""
        case_id = next(iter(PATIENT_CASES))
        reward = calculate_test_reward(
            case_id=case_id,
            test_name="CBC"
        )
        assert 0.0 <= reward <= 1.0

    def test_diagnosis_accuracy_exact_match(self):
        """Test exact diagnosis match"""
        case_id = next(iter(PATIENT_CASES))
        # Pass the diagnosis positionally: calculate_diagnosis_accuracy's
        # second parameter is not named `submitted_diagnosis`, so the old
        # keyword call raised TypeError before the assertion ever ran.
        accuracy = calculate_diagnosis_accuracy(
            case_id,
            PATIENT_CASES[case_id].get("true_diagnosis", "")
        )
        assert accuracy == 1.0

    def test_diagnosis_accuracy_partial(self):
        """Test partial diagnosis accuracy"""
        case_id = next(iter(PATIENT_CASES))
        # Positional call for the same reason as the exact-match test above.
        accuracy = calculate_diagnosis_accuracy(case_id, "Pneumonia")
        assert 0.0 <= accuracy <= 1.0
180
+
181
+
182
# Allow running this file directly (outside the pytest collector).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
training_wrapper.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """training_wrapper.py — Minimal training-ready wrapper for the Medical Diagnostic Environment.
2
+
3
+ This module exposes a small async/sync wrapper that is easy to plug into a training
4
+ loop or evaluation script. It is not a full RL algorithm, but it makes the
5
+ environment easy to consume for model-based training.
6
+ """
7
+
8
+ import os
9
+ import asyncio
10
+ from typing import Optional
11
+
12
+ from client import DiagnosticEnv
13
+ from models import DiagnosticAction, PatientObservation
14
+
15
+
16
class TrainingEnv:
    """Minimal wrapper exposing a training-friendly environment interface."""

    def __init__(self, base_url: Optional[str] = None):
        # Fall back to the ENV_URL variable, then the local default endpoint.
        self.base_url = base_url or os.getenv("ENV_URL", "ws://localhost:8000/ws")
        self._env = DiagnosticEnv(base_url=self.base_url)

    async def __aenter__(self):
        # Delegate connection lifecycle to the underlying client.
        await self._env.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self._env.__aexit__(exc_type, exc, tb)

    async def reset(self, difficulty: str = "easy") -> PatientObservation:
        """Start a new episode and return its initial observation."""
        outcome = await self._env.reset(difficulty=difficulty)
        # Unwrap the observation if the client returns a result envelope.
        return getattr(outcome, "observation", outcome)

    async def step(
        self,
        action_type: str,
        question: Optional[str] = None,
        test_name: Optional[str] = None,
        diagnosis: Optional[str] = None,
    ) -> PatientObservation:
        """Execute one diagnostic action and return the resulting observation."""
        chosen_action = DiagnosticAction(
            action_type=action_type,
            question=question,
            test_name=test_name,
            diagnosis=diagnosis,
        )
        outcome = await self._env.step(chosen_action)
        return getattr(outcome, "observation", outcome)

    async def state(self):
        """Fetch the current environment state from the underlying client."""
        return await self._env.state()
52
+
53
+
54
async def run_demo():
    """Example usage of the training wrapper."""
    async with TrainingEnv() as env:
        first_obs = await env.reset(difficulty="easy")
        print("Reset observation:", first_obs.message)
        step_obs = await env.step(action_type="ask_question", question="Do you have a fever?")
        print("Step result:", step_obs.message)
61
+
62
+
63
# Script entry point: run the demo episode against a local server.
if __name__ == "__main__":
    asyncio.run(run_demo())
validate.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Quick validation script for Medical Diagnostic Environment
4
+
5
+ This script validates that the core environment works correctly without
6
+ requiring the server to be running or external dependencies beyond models.
7
+
8
+ Run with: python validate.py
9
+ """
10
+
11
+ import sys
12
+ import traceback
13
+ from pathlib import Path
14
+ from typing import Dict, List
15
+
16
+ # Add parent directory to path
17
+ sys.path.insert(0, str(Path(__file__).parent))
18
+
19
+ from models import DiagnosticAction, PatientObservation, ClinicalState
20
+ from server.environment import MedicalDiagnosticEnvironment
21
+ from server.medical_data import (
22
+ PATIENT_CASES,
23
+ calculate_question_reward,
24
+ calculate_test_reward,
25
+ calculate_diagnosis_accuracy,
26
+ )
27
+
28
+
29
class ValidationResult:
    """Result of a single validation check.

    Attributes:
        name: Human-readable name of the check.
        passed: Whether the check succeeded.
        error: Failure detail, or None when the check passed. (The old
            annotation `error: str = None` contradicted the None default.)
    """

    def __init__(self, name: str, passed: bool, error: "str | None" = None):
        self.name = name
        self.passed = passed
        self.error = error

    def __str__(self) -> str:
        status = "PASS" if self.passed else "FAIL"
        msg = f"{status}: {self.name}"
        if self.error:
            msg += f"\n Error: {self.error}"
        return msg
42
+
43
+
44
def validate_imports() -> ValidationResult:
    """Check that all imports work"""
    try:
        from models import DiagnosticAction, PatientObservation, ClinicalState
        from server.environment import MedicalDiagnosticEnvironment
        from server.medical_data import (
            calculate_question_reward,
            calculate_test_reward,
            calculate_diagnosis_accuracy,
        )
    except Exception as exc:
        return ValidationResult("Imports", False, str(exc))
    # Reaching this point means every project module imported cleanly.
    return ValidationResult("Imports", True)
57
+
58
+
59
def validate_model_creation() -> ValidationResult:
    """Check that models can be instantiated"""
    check = "Model Creation"
    try:
        probe = DiagnosticAction(
            action_type="ask_question",
            question="Test question?",
        )
        assert probe.action_type == "ask_question"
        assert probe.question == "Test question?"
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
71
+
72
+
73
def validate_environment_init() -> ValidationResult:
    """Check that environment initializes"""
    check = "Environment Initialization"
    try:
        env = MedicalDiagnosticEnvironment()
        assert env is not None
        # The core Gym-style surface must be present.
        for required in ("reset", "step"):
            assert hasattr(env, required)
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
83
+
84
+
85
def validate_reset_all_difficulties() -> ValidationResult:
    """Check that reset works for all difficulties"""
    check = "Reset All Difficulties"
    try:
        env = MedicalDiagnosticEnvironment()
        # Every supported tier must start a fresh, well-formed episode.
        for level in ("easy", "medium", "hard"):
            first_obs = env.reset(difficulty=level)
            assert first_obs is not None
            assert env.current_difficulty == level
            assert env.current_case_id is not None
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
97
+
98
+
99
def _validate_step_action(check_name: str, expect_done: bool, **action_kwargs) -> ValidationResult:
    """Run one action on a fresh easy episode and verify the step result.

    Shared body for the three action validators below (their logic was
    duplicated three times with only the action and terminal flag varying).

    Args:
        check_name: Label reported in the ValidationResult.
        expect_done: Whether the episode should terminate after this action.
        **action_kwargs: Keyword arguments forwarded to DiagnosticAction.
    """
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        action = DiagnosticAction(**action_kwargs)
        result = env.step(action)
        assert result is not None
        assert result.reward is not None
        if not expect_done:
            # Informational actions must not be penalized.
            assert result.reward >= 0
        assert result.done is expect_done
        return ValidationResult(check_name, True)
    except Exception as e:
        return ValidationResult(check_name, False, str(e))


def validate_question_action() -> ValidationResult:
    """Check that asking questions works"""
    return _validate_step_action(
        "Question Action",
        expect_done=False,  # Should not end on question
        action_type="ask_question",
        question="Does the patient have symptoms?",
    )


def validate_test_action() -> ValidationResult:
    """Check that ordering tests works"""
    return _validate_step_action(
        "Test Action",
        expect_done=False,  # Should not end on test
        action_type="order_test",
        test_name="Complete Blood Count",
    )


def validate_diagnosis_action() -> ValidationResult:
    """Check that diagnosis submission works"""
    return _validate_step_action(
        "Diagnosis Action",
        expect_done=True,  # Should end on diagnosis
        action_type="submit_diagnosis",
        diagnosis="Common Flu",
    )
151
+
152
+
153
def validate_episode_summary() -> ValidationResult:
    """Check that episode summaries are generated correctly"""
    check = "Episode Summary"
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        env.step(DiagnosticAction(action_type="submit_diagnosis", diagnosis="Test"))
        summary = env.get_episode_summary()
        assert summary is not None
        # All bookkeeping fields downstream tooling relies on must be present.
        for key in ("case_id", "difficulty", "accuracy", "total_reward", "steps"):
            assert key in summary
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
173
+
174
+
175
def validate_reward_functions() -> ValidationResult:
    """Check that reward functions work"""
    check = "Reward Functions"
    try:
        case_id = next(iter(PATIENT_CASES))
        true_diag = PATIENT_CASES[case_id].get("true_diagnosis", "")
        # Each scorer must return a float in the unit interval.
        for value in (
            calculate_question_reward(case_id, "Test question?"),
            calculate_test_reward(case_id, "CBC"),
            calculate_diagnosis_accuracy(case_id, true_diag),
        ):
            assert isinstance(value, float)
            assert 0.0 <= value <= 1.0
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
195
+
196
+
197
def validate_state_property() -> ValidationResult:
    """Check that state exposure works.

    NOTE(review): the unit tests read `env.state` as a property while this
    script originally invoked `env.state()` as a method — one of the two had
    to be wrong for any given implementation. Accept both forms here so the
    check validates whichever surface the environment actually exposes.
    """
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        state = env.state
        if callable(state):
            state = state()
        assert state is not None
        for attr in ("patient_id", "step_count", "true_diagnosis", "final_accuracy"):
            assert hasattr(state, attr)
        return ValidationResult("State Property", True)
    except Exception as e:
        return ValidationResult("State Property", False, str(e))
211
+
212
+
213
def validate_concurrent_support() -> ValidationResult:
    """Check that environment supports concurrent sessions"""
    check = "Concurrent Sessions Support"
    try:
        env = MedicalDiagnosticEnvironment()
        # Must both exist and be exactly True (not merely truthy).
        flag = getattr(env, "SUPPORTS_CONCURRENT_SESSIONS", None)
        assert flag is True
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
222
+
223
+
224
def main():
    """Run all validation checks and report a pass/fail summary."""
    banner = "=" * 70
    print(banner)
    print("MEDICAL DIAGNOSTIC ENVIRONMENT - VALIDATION SUITE")
    print(banner)
    print()

    validators = [
        validate_imports,
        validate_model_creation,
        validate_environment_init,
        validate_reset_all_difficulties,
        validate_question_action,
        validate_test_action,
        validate_diagnosis_action,
        validate_episode_summary,
        validate_reward_functions,
        validate_state_property,
        validate_concurrent_support,
    ]

    results: List[ValidationResult] = []
    for check in validators:
        # A validator that raises (instead of returning a failure result)
        # still produces a FAIL entry with the full traceback attached.
        try:
            outcome = check()
        except Exception:
            outcome = ValidationResult(check.__name__, False, traceback.format_exc())
        results.append(outcome)
        print(outcome)

    print()
    print(banner)
    passed = sum(1 for r in results if r.passed)
    total = len(results)
    print(f"SUMMARY: {passed}/{total} checks passed")
    print(banner)

    # Non-zero exit signals that at least one check failed.
    return 0 if passed == total else 1
266
+
267
+
268
# Exit code mirrors the validation outcome (0 = all checks passed).
if __name__ == "__main__":
    sys.exit(main())