Spaces:
Sleeping
Sleeping
Upload 18 files
Browse files- .gitignore +157 -0
- Dockerfile +38 -0
- LICENSE +21 -0
- README.md +214 -12
- client.py +168 -0
- docker-compose.yml +42 -0
- inference.py +454 -0
- models.py +81 -0
- pyproject.toml +80 -0
- server/Dockerfile +43 -0
- server/__init__.py +21 -0
- server/app.py +46 -0
- server/environment.py +436 -0
- server/medical_data.py +345 -0
- server/requirements.txt +9 -0
- tests/test_environment.py +183 -0
- training_wrapper.py +64 -0
- validate.py +269 -0
.gitignore
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
pip-wheel-metadata/
|
| 24 |
+
share/python-wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
MANIFEST
|
| 29 |
+
|
| 30 |
+
# PyInstaller
|
| 31 |
+
*.manifest
|
| 32 |
+
*.spec
|
| 33 |
+
|
| 34 |
+
# Unit test / coverage reports
|
| 35 |
+
htmlcov/
|
| 36 |
+
.tox/
|
| 37 |
+
.nox/
|
| 38 |
+
.coverage
|
| 39 |
+
.coverage.*
|
| 40 |
+
.cache
|
| 41 |
+
nosetests.xml
|
| 42 |
+
coverage.xml
|
| 43 |
+
*.cover
|
| 44 |
+
*.py,cover
|
| 45 |
+
.hypothesis/
|
| 46 |
+
.pytest_cache/
|
| 47 |
+
|
| 48 |
+
# Translations
|
| 49 |
+
*.mo
|
| 50 |
+
*.pot
|
| 51 |
+
|
| 52 |
+
# Django stuff:
|
| 53 |
+
*.log
|
| 54 |
+
local_settings.py
|
| 55 |
+
db.sqlite3
|
| 56 |
+
db.sqlite3-journal
|
| 57 |
+
|
| 58 |
+
# Flask stuff:
|
| 59 |
+
instance/
|
| 60 |
+
.webassets-cache
|
| 61 |
+
|
| 62 |
+
# Scrapy stuff:
|
| 63 |
+
.scrapy
|
| 64 |
+
|
| 65 |
+
# Sphinx documentation
|
| 66 |
+
docs/_build/
|
| 67 |
+
|
| 68 |
+
# PyBuilder
|
| 69 |
+
target/
|
| 70 |
+
|
| 71 |
+
# Jupyter Notebook
|
| 72 |
+
.ipynb_checkpoints
|
| 73 |
+
|
| 74 |
+
# IPython
|
| 75 |
+
profile_default/
|
| 76 |
+
ipython_config.py
|
| 77 |
+
|
| 78 |
+
# pyenv
|
| 79 |
+
.python-version
|
| 80 |
+
|
| 81 |
+
# pipenv
|
| 82 |
+
Pipfile.lock
|
| 83 |
+
|
| 84 |
+
# PEP 582
|
| 85 |
+
__pypackages__/
|
| 86 |
+
|
| 87 |
+
# Celery stuff
|
| 88 |
+
celerybeat-schedule
|
| 89 |
+
celerybeat.pid
|
| 90 |
+
|
| 91 |
+
# SageMath parsed files
|
| 92 |
+
*.sage.py
|
| 93 |
+
|
| 94 |
+
# Environments
|
| 95 |
+
.env
|
| 96 |
+
.venv
|
| 97 |
+
env/
|
| 98 |
+
venv/
|
| 99 |
+
ENV/
|
| 100 |
+
env.bak/
|
| 101 |
+
venv.bak/
|
| 102 |
+
|
| 103 |
+
# Spyder project settings
|
| 104 |
+
.spyderproject
|
| 105 |
+
.spyproject
|
| 106 |
+
|
| 107 |
+
# Rope project settings
|
| 108 |
+
.ropeproject
|
| 109 |
+
|
| 110 |
+
# mkdocs documentation
|
| 111 |
+
/site
|
| 112 |
+
|
| 113 |
+
# mypy
|
| 114 |
+
.mypy_cache/
|
| 115 |
+
.dmypy.json
|
| 116 |
+
dmypy.json
|
| 117 |
+
|
| 118 |
+
# Pyre type checker
|
| 119 |
+
.pyre/
|
| 120 |
+
|
| 121 |
+
# IDE
|
| 122 |
+
.vscode/
|
| 123 |
+
.idea/
|
| 124 |
+
*.swp
|
| 125 |
+
*.swo
|
| 126 |
+
*~
|
| 127 |
+
.DS_Store
|
| 128 |
+
|
| 129 |
+
# Project specific
|
| 130 |
+
venv/
|
| 131 |
+
.env.local
|
| 132 |
+
.env.*.local
|
| 133 |
+
instance/
|
| 134 |
+
.webassets-cache
|
| 135 |
+
docker-compose.override.yml
|
| 136 |
+
|
| 137 |
+
# API Keys
|
| 138 |
+
.openai_api_key
|
| 139 |
+
.hf_token
|
| 140 |
+
|
| 141 |
+
# Model checkpoints
|
| 142 |
+
models/
|
| 143 |
+
checkpoints/
|
| 144 |
+
|
| 145 |
+
# Logs
|
| 146 |
+
logs/
|
| 147 |
+
*.log
|
| 148 |
+
log/
|
| 149 |
+
|
| 150 |
+
# Temporary files
|
| 151 |
+
tmp/
|
| 152 |
+
temp/
|
| 153 |
+
*.tmp
|
| 154 |
+
|
| 155 |
+
# OS
|
| 156 |
+
.DS_Store
|
| 157 |
+
Thumbs.db
|
Dockerfile
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Container image for the Medical Diagnostic Environment server (FastAPI + uvicorn).
FROM python:3.11-slim

WORKDIR /app

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    PYTHONDONTWRITEBYTECODE=1 \
    PIP_NO_CACHE_DIR=1 \
    PIP_DISABLE_PIP_VERSION_CHECK=1

# Install system dependencies (minimal)
# curl is required by the HEALTHCHECK below — do not remove it.
RUN apt-get update && apt-get install -y --no-install-recommends \
    curl \
    && rm -rf /var/lib/apt/lists/*

# Copy and install Python dependencies first (for layer caching)
COPY server/requirements.txt /tmp/requirements.txt
RUN pip install --no-cache-dir -r /tmp/requirements.txt && \
    rm /tmp/requirements.txt

# Copy application code
COPY models.py .
COPY client.py .
COPY server/ ./server/

# Create __init__ files for Python packages
# (makes /app and /app/server importable as packages for "server.app:app")
RUN touch __init__.py && \
    touch server/__init__.py

# Health check - validates the server is running and responsive
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:8000/health || exit 1

# Expose port
EXPOSE 8000

# Run the FastAPI server with uvicorn
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2026 Vicky-220
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
README.md
CHANGED
|
@@ -1,12 +1,214 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Diagnostic Environment
|
| 2 |
+
|
| 3 |
+
A lightweight OpenEnv environment for training agents to diagnose patients using clinical reasoning. The agent interacts through a turn-based dialog, orders tests, and submits a final diagnosis.
|
| 4 |
+
|
| 5 |
+
## What this project does
|
| 6 |
+
|
| 7 |
+
The environment simulates a real medical workflow:
|
| 8 |
+
- Present a patient with symptoms and context
|
| 9 |
+
- Let the agent ask clinical questions
|
| 10 |
+
- Allow the agent to order diagnostic tests
|
| 11 |
+
- Score the agent on diagnosis accuracy and process quality
|
| 12 |
+
|
| 13 |
+
It is designed to be used for training or evaluation with reinforcement learning systems.
|
| 14 |
+
|
| 15 |
+
## Why this environment is useful
|
| 16 |
+
|
| 17 |
+
This is not a toy problem. It is a small clinical reasoning task with:
|
| 18 |
+
- real clinical cases and realistic feedback
|
| 19 |
+
- multi-step decisions
|
| 20 |
+
- partial reward signals for progress
|
| 21 |
+
- a clear end goal: accurate final diagnosis
|
| 22 |
+
|
| 23 |
+
## Tasks included
|
| 24 |
+
|
| 25 |
+
There are three difficulty tiers built into the environment:
|
| 26 |
+
|
| 27 |
+
### Easy
|
| 28 |
+
- Seasonal Influenza
|
| 29 |
+
- Urinary Tract Infection
|
| 30 |
+
|
| 31 |
+
### Medium
|
| 32 |
+
- Community-Acquired Pneumonia
|
| 33 |
+
- Acute Appendicitis
|
| 34 |
+
|
| 35 |
+
### Hard
|
| 36 |
+
- Infective Endocarditis
|
| 37 |
+
- Bacterial Meningitis
|
| 38 |
+
|
| 39 |
+
Each case is graded from 0.0 to 1.0 based on the agent's final diagnosis and stepwise decisions.
|
| 40 |
+
|
| 41 |
+
## Action and observation interface
|
| 42 |
+
|
| 43 |
+
### Actions
|
| 44 |
+
The agent sends one of three actions:
|
| 45 |
+
|
| 46 |
+
```python
|
| 47 |
+
class DiagnosticAction(Action):
|
| 48 |
+
action_type: str # ask_question | order_test | submit_diagnosis
|
| 49 |
+
question: Optional[str] = None
|
| 50 |
+
test_name: Optional[str] = None
|
| 51 |
+
diagnosis: Optional[str] = None
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
### Observations
|
| 55 |
+
Each step returns a structured observation:
|
| 56 |
+
|
| 57 |
+
```python
|
| 58 |
+
class PatientObservation(Observation):
|
| 59 |
+
done: bool
|
| 60 |
+
reward: Optional[float]
|
| 61 |
+
message: str
|
| 62 |
+
patient_response: Optional[Dict]
|
| 63 |
+
test_result: Optional[Dict]
|
| 64 |
+
questions_asked: List[str]
|
| 65 |
+
tests_completed: List[str]
|
| 66 |
+
patient_data_revealed: Dict
|
| 67 |
+
steps_taken: int
|
| 68 |
+
max_steps: int
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## Setup
|
| 72 |
+
|
| 73 |
+
### Requirements
|
| 74 |
+
- Python 3.10+
|
| 75 |
+
- Docker for containerized deployment
|
| 76 |
+
|
| 77 |
+
### Local setup
|
| 78 |
+
|
| 79 |
+
```bash
|
| 80 |
+
git clone <repository-url>
|
| 81 |
+
cd meta_synapse_hackathon
|
| 82 |
+
python -m venv venv
|
| 83 |
+
source venv/bin/activate
|
| 84 |
+
pip install -r server/requirements.txt
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
### Run validation
|
| 88 |
+
|
| 89 |
+
```bash
|
| 90 |
+
python validate.py
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
## Running the environment
|
| 94 |
+
|
| 95 |
+
### Start the server
|
| 96 |
+
|
| 97 |
+
```bash
|
| 98 |
+
cd server
|
| 99 |
+
python app.py
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
Then the environment is available at:
|
| 103 |
+
- WebSocket: `ws://localhost:8000/ws`
|
| 104 |
+
- Health: `http://localhost:8000/health`
|
| 105 |
+
- Swagger: `http://localhost:8000/docs`
|
| 106 |
+
|
| 107 |
+
### Use the client
|
| 108 |
+
|
| 109 |
+
```python
|
| 110 |
+
from client import DiagnosticEnv
|
| 111 |
+
|
| 112 |
+
async with DiagnosticEnv(base_url="ws://localhost:8000/ws") as env:
|
| 113 |
+
obs = await env.reset(difficulty="easy")
|
| 114 |
+
print(obs.message)
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
## Training-ready wrapper
|
| 118 |
+
|
| 119 |
+
A simple, training-ready wrapper is available in `training_wrapper.py`. It provides a minimal async interface for use in training loops.
|
| 120 |
+
|
| 121 |
+
```bash
|
| 122 |
+
python training_wrapper.py
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
Use it in your own code like this:
|
| 126 |
+
|
| 127 |
+
```python
|
| 128 |
+
from training_wrapper import TrainingEnv
|
| 129 |
+
|
| 130 |
+
async with TrainingEnv() as env:
|
| 131 |
+
obs = await env.reset(difficulty="easy")
|
| 132 |
+
step = await env.step(action_type="ask_question", question="Do you have a fever?")
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
## Baseline inference
|
| 136 |
+
|
| 137 |
+
Set the required environment variables then run the baseline script:
|
| 138 |
+
|
| 139 |
+
```bash
|
| 140 |
+
export API_BASE_URL="https://router.huggingface.co/v1"
|
| 141 |
+
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
|
| 142 |
+
export HF_TOKEN="your-huggingface-token"
|
| 143 |
+
export ENV_URL="ws://localhost:8000/ws"
|
| 144 |
+
python inference.py
|
| 145 |
+
```
|
| 146 |
+
|
| 147 |
+
## Optional dataset support
|
| 148 |
+
|
| 149 |
+
The environment always includes core static cases and can optionally load Hugging Face datasets when enabled.
|
| 150 |
+
|
| 151 |
+
To use Hugging Face dataset generation:
|
| 152 |
+
|
| 153 |
+
```bash
|
| 154 |
+
export OPENENV_USE_HF_DATASETS=true
|
| 155 |
+
export OPENENV_DATASET_SEED=42
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
If dataset loading is disabled or unavailable, the environment still works with the built-in cases.
|
| 159 |
+
|
| 160 |
+
## Docker deployment
|
| 161 |
+
|
| 162 |
+
### Build locally
|
| 163 |
+
|
| 164 |
+
```bash
|
| 165 |
+
docker build -t medical-diagnostic-env ./server
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Run locally
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
docker run -p 8000:8000 medical-diagnostic-env
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### Deploy to Hugging Face Spaces
|
| 175 |
+
|
| 176 |
+
1. Create a new Space using Docker.
|
| 177 |
+
2. Upload the repository files.
|
| 178 |
+
3. The Space should build and expose the server automatically.
|
| 179 |
+
|
| 180 |
+
## Notes for judges and trainers
|
| 181 |
+
|
| 182 |
+
- The environment exposes standard reset/step/state semantics.
|
| 183 |
+
- It supports concurrent sessions and WebSocket interaction.
|
| 184 |
+
- The training wrapper is intentionally minimal so any agent loop can be added on top.
|
| 185 |
+
|
| 186 |
+
## Project structure
|
| 187 |
+
|
| 188 |
+
```
|
| 189 |
+
├── models.py
|
| 190 |
+
├── client.py
|
| 191 |
+
├── training_wrapper.py
|
| 192 |
+
├── inference.py
|
| 193 |
+
├── validate.py
|
| 194 |
+
├── openenv.yaml
|
| 195 |
+
├── server/
|
| 196 |
+
│ ├── app.py
|
| 197 |
+
│ ├── environment.py
|
| 198 |
+
│   ├── medical_data.py
|
| 199 |
+
│ ├── requirements.txt
|
| 200 |
+
│ └── Dockerfile
|
| 201 |
+
└── tests/
|
| 202 |
+
└── test_environment.py
|
| 203 |
+
```
|
| 204 |
+
|
| 205 |
+
## Testing
|
| 206 |
+
|
| 207 |
+
```bash
|
| 208 |
+
python -m pytest tests/
|
| 209 |
+
python validate.py
|
| 210 |
+
```
|
| 211 |
+
|
| 212 |
+
## License
|
| 213 |
+
|
| 214 |
+
MIT License - see LICENSE file for details.
|
client.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
client.py — OpenEnv client for the Medical Diagnostic Environment.
|
| 3 |
+
|
| 4 |
+
This client enables training code to interact with the environment via WebSocket.
|
| 5 |
+
Provides both async and sync interfaces for flexibility.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from typing import Optional
|
| 9 |
+
import asyncio
|
| 10 |
+
import json
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import websockets
|
| 14 |
+
except ImportError:
|
| 15 |
+
websockets = None
|
| 16 |
+
|
| 17 |
+
from openenv.core.env_client import EnvClient
|
| 18 |
+
from openenv.core.client_types import StepResult
|
| 19 |
+
import sys
|
| 20 |
+
import os
|
| 21 |
+
|
| 22 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 23 |
+
|
| 24 |
+
from models import DiagnosticAction, PatientObservation, ClinicalState
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DiagnosticEnv(EnvClient[DiagnosticAction, PatientObservation, ClinicalState]):
    """
    Client for interacting with the Medical Diagnostic Environment.

    Supports both async and sync usage:

    Async (recommended for training):
        async with DiagnosticEnv(base_url="...") as env:
            obs = await env.reset()
            obs = await env.step(DiagnosticAction(...))

    Sync (for notebooks/simple scripts):
        with DiagnosticEnv(base_url="...").sync() as env:
            obs = env.reset()
            obs = env.step(DiagnosticAction(...))
    """

    @classmethod
    async def from_docker_image(cls, image_name: Optional[str] = None, base_url: Optional[str] = None, **kwargs):
        """Create client connected to a running OpenEnv environment URL.

        ``image_name`` is accepted for interface compatibility but is not
        used here; the connection target is ``base_url`` (or the ``ENV_URL``
        environment variable when ``base_url`` is None).
        """
        if base_url is None:
            base_url = os.getenv("ENV_URL", "ws://localhost:8000/ws")
        return cls(base_url=base_url, **kwargs)

    def _step_payload(self, action: DiagnosticAction) -> dict:
        """Convert an action to the JSON payload expected by the server."""
        return {
            "action_type": action.action_type,
            "question": action.question,
            "test_name": action.test_name,
            "diagnosis": action.diagnosis,
        }

    @staticmethod
    def _maybe_json(value, default):
        """Decode *value* if it is a JSON-encoded string, else return it as-is.

        Falls back to *default* on malformed JSON. Uses narrow exception
        types — the previous bare ``except:`` also swallowed
        KeyboardInterrupt/SystemExit, which could hide real failures.
        """
        if isinstance(value, str):
            try:
                return json.loads(value)
            except (json.JSONDecodeError, TypeError, ValueError):
                return default
        return value

    def _parse_result(self, payload: dict) -> StepResult:
        """Parse a server step/reset response into a StepResult."""
        obs_data = payload.get("observation", {})

        # Servers may send nested structures either inline or JSON-encoded.
        patient_data_revealed = self._maybe_json(
            obs_data.get("patient_data_revealed", {}), {}
        )
        test_result = self._maybe_json(obs_data.get("test_result"), None)

        observation = PatientObservation(
            done=payload.get("done", False),
            reward=payload.get("reward"),
            message=obs_data.get("message", ""),
            patient_response=obs_data.get("patient_response"),
            test_result=test_result,
            questions_asked=obs_data.get("questions_asked", []),
            tests_completed=obs_data.get("tests_completed", []),
            patient_data_revealed=patient_data_revealed,
            steps_taken=obs_data.get("steps_taken", 0),
            max_steps=obs_data.get("max_steps", 15),
        )

        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: dict) -> ClinicalState:
        """Parse a server state response (includes hidden episode info)."""
        patient_details = self._maybe_json(payload.get("patient_details", {}), {})

        return ClinicalState(
            episode_id=payload.get("episode_id", ""),
            step_count=payload.get("step_count", 0),
            true_diagnosis=payload.get("true_diagnosis", ""),
            patient_case=payload.get("patient_case", ""),
            patient_details=patient_details,
            difficulty=payload.get("difficulty", "easy"),
            questions_asked=payload.get("questions_asked", []),
            tests_completed=payload.get("tests_completed", []),
            final_diagnosis_submitted=payload.get("final_diagnosis_submitted"),
            final_accuracy=payload.get("final_accuracy", 0.0),
        )

    # ─────────────────────────────────────────────────────────────────────
    # Sync wrapper for convenience
    # ─────────────────────────────────────────────────────────────────────

    def sync(self) -> "SyncDiagnosticEnv":
        """
        Return synchronous wrapper for use in notebooks/simple scripts.

        Usage:
            with DiagnosticEnv(url).sync() as env:
                obs = env.reset()
                obs = env.step(action)
        """
        return SyncDiagnosticEnv(self.base_url)
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class SyncDiagnosticEnv:
    """Synchronous wrapper around the async DiagnosticEnv client.

    Owns a private event loop so plain (non-async) code such as notebooks
    and simple scripts can drive the environment.
    """

    def __init__(self, base_url: str):
        self.base_url = base_url
        self._loop = asyncio.new_event_loop()
        self._async_client = DiagnosticEnv(base_url)

    def __enter__(self):
        self._loop.run_until_complete(self._async_client.__aenter__())
        return self

    def __exit__(self, *args):
        # Always close the private loop, even if the async teardown raises;
        # previously an exception from __aexit__ leaked the loop.
        try:
            self._loop.run_until_complete(self._async_client.__aexit__(*args))
        finally:
            self._loop.close()

    def reset(self, difficulty: str = "easy") -> PatientObservation:
        """Reset environment and start a new episode."""
        return self._loop.run_until_complete(
            self._async_client.reset(difficulty=difficulty)
        )

    def step(self, action: DiagnosticAction) -> StepResult:
        """Take a step in the environment."""
        return self._loop.run_until_complete(
            self._async_client.step(action)
        )

    def state(self) -> ClinicalState:
        """Get current state (includes hidden information)."""
        return self._loop.run_until_complete(
            self._async_client.state()
        )
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Docker Compose for local development and testing
|
| 2 |
+
#
|
| 3 |
+
# Start with: docker-compose up --build
|
| 4 |
+
# Access server at: http://localhost:8000
|
| 5 |
+
# Docs at: http://localhost:8000/docs
|
| 6 |
+
# WebSocket at: ws://localhost:8000/ws
|
| 7 |
+
|
| 8 |
+
version: "3.9"
|
| 9 |
+
|
| 10 |
+
services:
|
| 11 |
+
medical-diagnostic-env:
|
| 12 |
+
build:
|
| 13 |
+
context: .
|
| 14 |
+
dockerfile: server/Dockerfile
|
| 15 |
+
|
| 16 |
+
container_name: medical-diagnostic-env
|
| 17 |
+
|
| 18 |
+
ports:
|
| 19 |
+
- "8000:8000"
|
| 20 |
+
|
| 21 |
+
environment:
|
| 22 |
+
- PORT=8000
|
| 23 |
+
- WORKERS=2 # Reduce for local development
|
| 24 |
+
|
| 25 |
+
restart: unless-stopped
|
| 26 |
+
|
| 27 |
+
healthcheck:
|
| 28 |
+
test: ["CMD", "curl", "-f", "http://localhost:8000/health"]
|
| 29 |
+
interval: 30s
|
| 30 |
+
timeout: 10s
|
| 31 |
+
retries: 3
|
| 32 |
+
start_period: 5s
|
| 33 |
+
|
| 34 |
+
volumes:
|
| 35 |
+
- .:/app # For live code reloading during development (optional)
|
| 36 |
+
|
| 37 |
+
networks:
|
| 38 |
+
- medical-net
|
| 39 |
+
|
| 40 |
+
networks:
|
| 41 |
+
medical-net:
|
| 42 |
+
driver: bridge
|
inference.py
ADDED
|
@@ -0,0 +1,454 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py — Baseline inference script for the Medical Diagnostic Environment.
|
| 3 |
+
This script demonstrates how to run the environment with an LLM (using OpenAI client)
|
| 4 |
+
|
| 5 |
+
and logs results in the exact format required by the hackathon.
|
| 6 |
+
|
| 7 |
+
Format requirements:
|
| 8 |
+
[START] task=<task_name> env=<env_name> model=<model_name>
|
| 9 |
+
[STEP] step=<n> action=<action_str> reward=<float> done=<bool> error=<str|null>
|
| 10 |
+
[END] success=<bool> steps=<n> score=<float> rewards=<comma_separated_list>
|
| 11 |
+
|
| 12 |
+
All fields on a single line with NO NEWLINES within a line.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import asyncio
|
| 16 |
+
import os
|
| 17 |
+
import json
|
| 18 |
+
import textwrap
|
| 19 |
+
from typing import List, Optional, Dict
|
| 20 |
+
|
| 21 |
+
from openai import OpenAI
|
| 22 |
+
import sys
|
| 23 |
+
|
| 24 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 25 |
+
|
| 26 |
+
from client import DiagnosticEnv
|
| 27 |
+
from models import DiagnosticAction
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ==============================================================================
|
| 31 |
+
# CONFIGURATION
|
| 32 |
+
# ==============================================================================
|
| 33 |
+
|
| 34 |
+
# Configuration
|
| 35 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
|
| 36 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 37 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 38 |
+
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "medical-diagnostic-env:latest")
|
| 39 |
+
ENV_URL = os.getenv("ENV_URL", "ws://localhost:8000/ws")
|
| 40 |
+
BENCHMARK = os.getenv("BENCHMARK", "medical_diagnostic_env")
|
| 41 |
+
|
| 42 |
+
# Inference configuration
|
| 43 |
+
MAX_STEPS = 15 # Maximum steps per episode
|
| 44 |
+
TEMPERATURE = 0.7 # LLM temperature for reasoning
|
| 45 |
+
MAX_TOKENS = 256 # Max tokens per completion
|
| 46 |
+
TASK_NAMES = ["easy_diagnosis", "medium_diagnosis", "hard_diagnosis"]
|
| 47 |
+
DIFFICULTY_LEVELS = ["easy", "medium", "hard"]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# ==============================================================================
|
| 51 |
+
# LOGGING FUNCTIONS
|
| 52 |
+
# ==============================================================================
|
| 53 |
+
|
| 54 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] line for an episode in the required single-line format."""
    # Log only the final path component of the model id ("org/name" -> "name").
    short_name = model.rsplit("/", 1)[-1] if "/" in model else model
    print(f"[START] task={task} env={env} model={short_name}", flush=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit the [STEP] line in the required single-line format.

    The logging protocol forbids newlines within a line, but the previous
    code interpolated *error* raw — a multi-line error message would split
    the [STEP] record across lines. All whitespace runs (including
    newlines) in the error text are collapsed to single spaces.
    """
    if error:
        # Collapse any internal newlines/whitespace so the record stays one line.
        error_val = '"' + " ".join(str(error).split()) + '"'
    else:
        error_val = "null"
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit one [END] line summarizing the finished episode."""
    parts = [
        f"[END] success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        # Per-step rewards joined as a comma-separated list, 2 decimals each.
        "rewards=" + ",".join(format(r, ".2f") for r in rewards),
    ]
    print(" ".join(parts), flush=True)
|
| 79 |
+
|
| 80 |
+
# ==============================================================================
|
| 81 |
+
# LLM INTERACTION
|
| 82 |
+
# ==============================================================================
|
| 83 |
+
|
| 84 |
+
def create_system_prompt() -> str:
    """Create the system prompt for medical diagnostic reasoning.

    Returns the instruction text given to the model at the start of every
    conversation. The ACTION/QUESTION/TEST/DIAGNOSIS directive format below
    must stay in sync with extract_action_from_response's parser.
    """
    # textwrap.dedent keeps the literal readable in-source; .strip() removes
    # the leading/trailing blank lines introduced by the triple quotes.
    return textwrap.dedent("""
        You are an expert medical diagnostic AI assistant. Your role is to:

        1. GATHER INFORMATION: Ask relevant clinical questions about symptoms,
           history, and presentation.

        2. ORDER TESTS: Request appropriate diagnostic tests based on the
           differential diagnosis.

        3. REASON DIAGNOSTICALLY: Consider the patient's presentation,
           synthesize findings, and make a diagnosis.

        Your reasoning should follow clinical guidelines and prioritize:
        - Life-threatening conditions first (red flags)
        - Most common diagnoses for the presentation
        - Efficiency (minimize unnecessary tests)

        When responding, use EXACTLY ONE of these actions:

        ACTION: ask_question
        QUESTION: <your question here>

        OR

        ACTION: order_test
        TEST: <test name>

        OR

        ACTION: submit_diagnosis
        DIAGNOSIS: <final diagnosis>

        Be concise. Diagnose within 10-15 steps if possible.
        """).strip()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def extract_action_from_response(response: str) -> Optional[Dict]:
    """
    Extract a structured action from a free-form LLM response.

    Prefers the explicit "ACTION:" directive format defined in the system
    prompt; falls back to keyword heuristics when the directive is missing.

    Args:
        response: Raw completion text from the model.

    Returns:
        {
            "action_type": "ask_question" | "order_test" | "submit_diagnosis",
            "question": str or None,
            "test_name": str or None,
            "diagnosis": str or None,
        }
        or None when no action could be inferred.
    """
    response_lower = response.lower()

    # --- Preferred path: explicit ACTION:/QUESTION:/TEST:/DIAGNOSIS: lines ---
    if "action:" in response_lower:
        action_type = None
        question = None
        test_name = None
        diagnosis = None

        for line in response.split("\n"):
            lowered = line.lower()
            if "action:" in lowered:
                directive = line.split(":", 1)[1].strip().lower()
                if "question" in directive:
                    action_type = "ask_question"
                elif "test" in directive:
                    action_type = "order_test"
                elif "diagnosis" in directive:
                    action_type = "submit_diagnosis"

            if "question:" in lowered:
                question = line.split(":", 1)[1].strip()
            elif "test:" in lowered:
                test_name = line.split(":", 1)[1].strip()
            elif "diagnosis:" in lowered:
                diagnosis = line.split(":", 1)[1].strip()

        if action_type:
            return {
                "action_type": action_type,
                "question": question,
                "test_name": test_name,
                "diagnosis": diagnosis,
            }

    # --- Fallback 1: the model asked a question without the directive ---
    if "question" in response_lower or "ask" in response_lower:
        for line in response.split("\n"):
            if "?" in line:
                return {
                    "action_type": "ask_question",
                    "question": line.strip(),
                    "test_name": None,
                    "diagnosis": None,
                }

    # --- Fallback 2: the model mentioned ordering a test ---
    # Heuristic: the word following the first token containing "test" is
    # taken as the test name (may include trailing punctuation).
    if "test" in response_lower or "order" in response_lower:
        words = response.split()
        for i, word in enumerate(words):
            if "test" in word.lower() and i + 1 < len(words):
                return {
                    "action_type": "order_test",
                    "question": None,
                    "test_name": words[i + 1],
                    "diagnosis": None,
                }

    # --- Fallback 3: treat the whole response as a diagnosis statement ---
    # Bug fix: the original looped over words but ignored them, returning on
    # the first word longer than 3 chars; any() states the real intent
    # (require at least one substantive word before accepting a diagnosis).
    if "diagnos" in response_lower and any(len(word) > 3 for word in response.split()):
        return {
            "action_type": "submit_diagnosis",
            "question": None,
            "test_name": None,
            "diagnosis": response.strip(),
        }

    return None
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def build_conversation_history(episode_history: List[Dict]) -> List[Dict]:
    """Assemble the multi-turn chat message list from stored episode turns.

    Starts with the system prompt, then interleaves each turn's assistant
    action and the environment's feedback as a user message.
    """
    messages: List[Dict] = [{"role": "system", "content": create_system_prompt()}]

    for turn in episode_history:
        agent_action = turn.get("agent_action")
        if agent_action:
            messages.append({"role": "assistant", "content": agent_action})

        feedback = turn.get("environment_feedback")
        if feedback:
            messages.append({"role": "user", "content": feedback})

    return messages
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
# ==============================================================================
|
| 236 |
+
# EPISODE EXECUTION
|
| 237 |
+
# ==============================================================================
|
| 238 |
+
|
| 239 |
+
async def run_episode_async(
    client: OpenAI,
    image_name: str,
    difficulty: str,
    task_name: str,
) -> Dict:
    """
    Run a single diagnostic episode against the containerized environment.

    Args:
        client: OpenAI-compatible chat client used for the agent's reasoning.
        image_name: Docker image providing the environment server.
        difficulty: Case difficulty ("easy", "medium", or "hard").
        task_name: Label used in logging and in the returned result dict.

    Returns: {
        "task": task_name,
        "success": bool,
        "steps_taken": int,
        "total_reward": float,
        "episode_rewards": [float],
        "final_diagnosis_accuracy": float,
    }
    """

    log_start(task_name, "medical_diagnostic_env", MODEL_NAME)

    # The context manager starts/stops the environment container around the episode.
    async with DiagnosticEnv.from_docker_image(image_name=image_name, base_url=ENV_URL) as env:
        obs_result = await env.reset(difficulty=difficulty)
        # Some client versions wrap the observation in a result object; unwrap defensively.
        obs = obs_result.observation if hasattr(obs_result, 'observation') else obs_result

        episode_history = []
        episode_rewards = []
        step_count = 0
        error_occurred = False

        # Initial environment message (only injected on the first step).
        initial_message = f"Patient presentation: {obs.message}"

        while step_count < MAX_STEPS and not obs.done:
            step_count += 1

            # Rebuild the chat transcript each turn: system prompt + prior turns.
            conversation = [
                {
                    "role": "system",
                    "content": create_system_prompt(),
                }
            ]

            for turn in episode_history:
                if turn.get("agent_thought"):
                    conversation.append({
                        "role": "assistant",
                        "content": f"Thinking: {turn['agent_thought']}\nAction: {turn['agent_action']}",
                    })
                if turn.get("environment_response"):
                    conversation.append({
                        "role": "user",
                        "content": turn["environment_response"],
                    })

            if step_count == 1:
                conversation.append({
                    "role": "user",
                    "content": initial_message,
                })

            try:
                # NOTE(review): this SDK call is blocking inside an async
                # function; acceptable while episodes run sequentially, but
                # wrap in asyncio.to_thread if episodes ever run concurrently.
                response = client.chat.completions.create(
                    model=MODEL_NAME,
                    messages=conversation,
                    temperature=TEMPERATURE,
                    max_tokens=MAX_TOKENS,
                )

                llm_response = response.choices[0].message.content

                action_dict = extract_action_from_response(llm_response)

                if not action_dict:
                    error_msg = "Could not parse action from response"
                    log_step(step_count, "parse_error", 0.0, False, error_msg)
                    error_occurred = True
                    break

                action = DiagnosticAction(
                    action_type=action_dict["action_type"],
                    question=action_dict.get("question"),
                    test_name=action_dict.get("test_name"),
                    diagnosis=action_dict.get("diagnosis"),
                )

                # Short human-readable action string for the [STEP] log line.
                action_str = f"{action.action_type}"
                if action.question:
                    action_str += f"('{action.question[:30]}...')"
                elif action.test_name:
                    action_str += f"('{action.test_name}')"
                elif action.diagnosis:
                    action_str += f"('{action.diagnosis[:40]}...')"

                step_result = await env.step(action)
                obs = step_result.observation if hasattr(step_result, 'observation') else step_result

                reward = obs.reward or 0.0
                episode_rewards.append(reward)

                log_step(step_count, action_str, reward, obs.done, None)

                # Truncated transcript entries keep the prompt size bounded.
                episode_history.append({
                    "agent_thought": llm_response[:100],
                    "agent_action": action_str,
                    "environment_response": obs.message[:200],
                })

            except Exception as e:
                error_msg = str(e)[:100]
                log_step(step_count, "error", 0.0, True, error_msg)
                error_occurred = True
                break

        # Pull the hidden final accuracy from the environment's state endpoint.
        try:
            state = await env.state()
            final_accuracy = state.final_accuracy if hasattr(state, 'final_accuracy') else 0.0
        except Exception:
            # Fix: was a bare except (would also swallow KeyboardInterrupt/
            # SystemExit). Best-effort: state may be unavailable after an abort.
            final_accuracy = 0.0

        # Calculate results
        total_reward = sum(episode_rewards)
        # Success requires a finished episode AND a minimally accurate diagnosis.
        success = obs.done and final_accuracy > 0.3

        log_end(success, step_count, final_accuracy, episode_rewards)

        return {
            "task": task_name,
            "success": success,
            "steps_taken": step_count,
            "total_reward": total_reward,
            "episode_rewards": episode_rewards,
            "final_diagnosis_accuracy": final_accuracy,
        }
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
# ==============================================================================
|
| 387 |
+
# MAIN ORCHESTRATION
|
| 388 |
+
# ==============================================================================
|
| 389 |
+
|
| 390 |
+
async def run_all_tasks() -> Dict:
    """Run all 3 difficulty levels sequentially and report aggregate results.

    Returns:
        Results dict with per-task outcomes and the mean diagnostic accuracy
        across completed tasks, or {} when required configuration is missing.
    """
    from datetime import datetime, timezone  # local: only needed for the run stamp

    if not API_KEY:
        print("ERROR: API key not found. Set HF_TOKEN, API_KEY, or OPENAI_API_KEY.", flush=True)
        return {}

    if not ENV_URL:
        print("ERROR: ENV_URL is not set. Set ENV_URL to the environment WebSocket URL.", flush=True)
        return {}

    # Initialize OpenAI-compatible client against the configured router.
    client = OpenAI(
        api_key=API_KEY,
        base_url=API_BASE_URL,
    )

    results = {
        # Fix: was declared as None and never populated; stamp the run so
        # saved reports are traceable.
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "model": MODEL_NAME,
        "environment": "medical_diagnostic_env",
        "tasks_completed": 0,
        "task_results": [],
        "overall_score": 0.0,
    }

    # Run each (task, difficulty) pair; a failed task is logged and skipped
    # so the remaining difficulties still run.
    for i, (task_name, difficulty) in enumerate(zip(TASK_NAMES, DIFFICULTY_LEVELS)):
        print(f"\n--- Task {i+1}/3: {difficulty} difficulty ---", flush=True)

        try:
            result = await run_episode_async(
                client,
                LOCAL_IMAGE_NAME,
                difficulty=difficulty,
                task_name=task_name,
            )
            results["task_results"].append(result)
            results["tasks_completed"] += 1
        except Exception as e:
            print(f"ERROR in task {task_name}: {str(e)}", flush=True)

    # Overall score = mean diagnostic accuracy over completed tasks only.
    if results["tasks_completed"] > 0:
        accuracies = [r["final_diagnosis_accuracy"] for r in results["task_results"]]
        results["overall_score"] = sum(accuracies) / len(accuracies)

    # Print summary
    print("\n" + "="*60, flush=True)
    print("Baseline Inference Complete", flush=True)
    print(f"Tasks completed: {results['tasks_completed']}/3", flush=True)
    print(f"Overall diagnostic accuracy: {results['overall_score']:.3f}", flush=True)
    print("="*60, flush=True)

    return results
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def main():
    """Entry point: execute the full benchmark run inside a fresh event loop."""
    asyncio.run(run_all_tasks())
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
if __name__ == "__main__":
|
| 454 |
+
main()
|
models.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
models.py — Type-safe contracts for the Medical Diagnostic Environment.
|
| 3 |
+
|
| 4 |
+
These Pydantic models define the interface between the LLM agent and the environment:
|
| 5 |
+
- DiagnosticAction: What the agent sends (questions, tests, diagnoses)
|
| 6 |
+
- PatientObservation: What the agent receives (feedback, test results, progress)
|
| 7 |
+
- ClinicalState: Full episode state (for debugging, not sent to agent)
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from typing import Optional, List, Dict
|
| 11 |
+
from openenv.core.env_server import Action, Observation, State
|
| 12 |
+
from pydantic import Field
|
| 13 |
+
|
| 14 |
+
class DiagnosticAction(Action):
    """
    Actions the LLM agent can take during diagnosis.

    The agent must choose one action per step:
    1. ask_question: Gather patient history
    2. order_test: Request diagnostic test results
    3. submit_diagnosis: Make final diagnosis (ends episode)

    Only the payload field matching action_type is expected to be populated;
    the other optional fields stay None.
    """

    action_type: str  # "ask_question", "order_test", or "submit_diagnosis"
    question: Optional[str] = None  # Used when action_type="ask_question"
    test_name: Optional[str] = None  # Used when action_type="order_test"
    diagnosis: Optional[str] = None  # Used when action_type="submit_diagnosis"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class PatientObservation(Observation):
    """
    What the agent observes after taking an action.

    Inherits from Observation:
    - done: bool → Is the episode over?
    - reward: Optional[float] → Reward signal

    Adds medical-specific fields:
    - message: Human-readable feedback
    - patient_response: Answer to question (if applicable)
    - test_result: Test outcome with interpretation
    - questions_asked: History of all questions
    - tests_completed: History of all completed tests
    - patient_data_revealed: What the agent has discovered so far
    """

    message: str  # Human-readable feedback from the environment
    patient_response: Optional[str] = None  # Answer to a question asked
    test_result: Optional[Dict] = None  # {"test_name": "X", "result": "...", "interpretation": "..."}
    questions_asked: List[str] = Field(default_factory=list)  # all questions asked so far
    tests_completed: List[str] = Field(default_factory=list)  # all tests completed so far
    patient_data_revealed: Dict = Field(default_factory=dict)  # facts the agent has uncovered
    steps_taken: int = 0  # How many actions so far
    max_steps: int = 15  # Maximum steps allowed
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ClinicalState(State):
    """
    Complete internal state snapshot. Contains hidden information (diagnosis, true findings).
    Use for debugging only - NEVER send to agent.

    Inherits from State:
    - episode_id: str → Unique episode identifier
    - step_count: int → Current step number

    Adds clinical fields:
    - true_diagnosis: The correct diagnosis (hidden from agent)
    - patient_case: Case identifier
    - patient_details: Full patient information (hidden)
    - difficulty: easy|medium|hard
    """

    true_diagnosis: str = ""  # ground-truth diagnosis (never exposed to the agent)
    patient_case: str = ""  # case identifier
    patient_id: str = ""
    patient_details: Dict = Field(default_factory=dict)  # full hidden case record
    difficulty: str = "easy"
    questions_asked: List[str] = Field(default_factory=list)
    tests_completed: List[str] = Field(default_factory=list)
    final_diagnosis_submitted: Optional[str] = None  # set once submit_diagnosis fires
    final_accuracy: float = 0.0  # scored accuracy of the submitted diagnosis
|
pyproject.toml
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68.0", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "medical-diagnostic-env"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "OpenEnv environment for medical diagnosis RL training"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
|
| 12 |
+
authors = [
|
| 13 |
+
{name = "Team SYNAPSE", email = "synapse@example.com"}
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
keywords = [
|
| 17 |
+
"reinforcement-learning",
|
| 18 |
+
"medical",
|
| 19 |
+
"diagnosis",
|
| 20 |
+
"healthcare",
|
| 21 |
+
"rl-training",
|
| 22 |
+
"llm",
|
| 23 |
+
"openenv"
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
classifiers = [
|
| 27 |
+
"Development Status :: 4 - Beta",
|
| 28 |
+
"Intended Audience :: Developers",
|
| 29 |
+
"Intended Audience :: Science/Research",
|
| 30 |
+
"License :: OSI Approved :: MIT License",
|
| 31 |
+
"Programming Language :: Python :: 3",
|
| 32 |
+
"Programming Language :: Python :: 3.10",
|
| 33 |
+
"Programming Language :: Python :: 3.11",
|
| 34 |
+
"Programming Language :: Python :: 3.12",
|
| 35 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
dependencies = [
|
| 39 |
+
"openenv-core>=0.2.3",
|
| 40 |
+
"fastapi>=0.104.0",
|
| 41 |
+
"uvicorn[standard]>=0.24.0",
|
| 42 |
+
"websockets>=16.0",
|
| 43 |
+
"pydantic>=2.5.0",
|
| 44 |
+
"pydantic-settings>=2.1.0",
|
| 45 |
+
"openai>=1.3.0",
|
| 46 |
+
"requests>=2.31.0",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
[project.optional-dependencies]
|
| 50 |
+
dev = [
|
| 51 |
+
"pytest>=7.0",
|
| 52 |
+
"pytest-asyncio>=0.21.0",
|
| 53 |
+
"black>=23.0",
|
| 54 |
+
"ruff>=0.1.0",
|
| 55 |
+
"mypy>=1.0",
|
| 56 |
+
]
|
| 57 |
+
|
| 58 |
+
[project.urls]
|
| 59 |
+
Homepage = "https://github.com/meta-pytorch/OpenEnv"
|
| 60 |
+
Documentation = "https://meta-pytorch.org/OpenEnv/"
|
| 61 |
+
Repository = "https://github.com/meta-pytorch/OpenEnv"
|
| 62 |
+
Issues = "https://github.com/meta-pytorch/OpenEnv/issues"
|
| 63 |
+
|
| 64 |
+
[tool.setuptools]
|
| 65 |
+
packages = ["medical_diagnostic_env"]
|
| 66 |
+
|
| 67 |
+
[tool.black]
|
| 68 |
+
line-length = 100
|
| 69 |
+
target-version = ["py310", "py311", "py312"]
|
| 70 |
+
|
| 71 |
+
[tool.ruff]
|
| 72 |
+
line-length = 100
|
| 73 |
+
target-version = "py310"
|
| 74 |
+
select = ["E", "F", "W", "I"]
|
| 75 |
+
|
| 76 |
+
[tool.mypy]
|
| 77 |
+
python_version = "3.10"
|
| 78 |
+
check_untyped_defs = true
|
| 79 |
+
disallow_untyped_defs = false
|
| 80 |
+
warn_unused_ignores = true
|
server/Dockerfile
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Medical Diagnostic Environment - Production Dockerfile
|
| 2 |
+
#
|
| 3 |
+
# This Dockerfile containerizes the complete Medical Diagnostic Environment.
|
| 4 |
+
# It can be deployed to Docker Hub, GitHub Container Registry, or Hugging Face Spaces.
|
| 5 |
+
|
| 6 |
+
FROM python:3.11-slim
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
# Set environment variables
|
| 11 |
+
ENV PYTHONUNBUFFERED=1 \
|
| 12 |
+
PYTHONDONTWRITEBYTECODE=1 \
|
| 13 |
+
PIP_NO_CACHE_DIR=1 \
|
| 14 |
+
PIP_DISABLE_PIP_VERSION_CHECK=1
|
| 15 |
+
|
| 16 |
+
# Install system dependencies (minimal)
|
| 17 |
+
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 18 |
+
curl \
|
| 19 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 20 |
+
|
| 21 |
+
# Copy and install Python dependencies first (for layer caching)
|
| 22 |
+
COPY server/requirements.txt /tmp/requirements.txt
|
| 23 |
+
RUN pip install --no-cache-dir -r /tmp/requirements.txt && \
|
| 24 |
+
rm /tmp/requirements.txt
|
| 25 |
+
|
| 26 |
+
# Copy application code
|
| 27 |
+
COPY models.py .
|
| 28 |
+
COPY client.py .
|
| 29 |
+
COPY server/ ./server/
|
| 30 |
+
|
| 31 |
+
# Create __init__ files for Python packages
|
| 32 |
+
RUN touch __init__.py && \
|
| 33 |
+
touch server/__init__.py
|
| 34 |
+
|
| 35 |
+
# Health check - validates the server is running and responsive
|
| 36 |
+
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
|
| 37 |
+
CMD curl -f http://localhost:8000/health || exit 1
|
| 38 |
+
|
| 39 |
+
# Expose port
|
| 40 |
+
EXPOSE 8000
|
| 41 |
+
|
| 42 |
+
# Run the FastAPI server with uvicorn
|
| 43 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
|
server/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Medical Diagnostic Environment Server - Package initialization
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from .environment import MedicalDiagnosticEnvironment
|
| 6 |
+
from .medical_data import (
|
| 7 |
+
PATIENT_CASES,
|
| 8 |
+
calculate_question_reward,
|
| 9 |
+
calculate_test_reward,
|
| 10 |
+
calculate_diagnosis_accuracy,
|
| 11 |
+
get_patient_response,
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"MedicalDiagnosticEnvironment",
|
| 16 |
+
"PATIENT_CASES",
|
| 17 |
+
"calculate_question_reward",
|
| 18 |
+
"calculate_test_reward",
|
| 19 |
+
"calculate_diagnosis_accuracy",
|
| 20 |
+
"get_patient_response",
|
| 21 |
+
]
|
server/app.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
server/app.py — FastAPI server for the Medical Diagnostic Environment.
|
| 3 |
+
|
| 4 |
+
This exposes the environment over WebSocket and HTTP using OpenEnv's built-in
|
| 5 |
+
create_fastapi_app helper. One line of meaningful code!
|
| 6 |
+
|
| 7 |
+
The helper automatically creates:
|
| 8 |
+
- /ws endpoint for WebSocket connections (stateful, for training)
|
| 9 |
+
- /reset, /step, /state endpoints (stateless, for testing)
|
| 10 |
+
- /health endpoint (for Docker health checks)
|
| 11 |
+
- /docs endpoint (auto-generated OpenAPI documentation)
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from openenv.core.env_server import create_fastapi_app
|
| 15 |
+
import sys
|
| 16 |
+
import os
|
| 17 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 18 |
+
|
| 19 |
+
from models import DiagnosticAction, PatientObservation
|
| 20 |
+
from environment import MedicalDiagnosticEnvironment
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# Create the environment instance
|
| 24 |
+
env = MedicalDiagnosticEnvironment()
|
| 25 |
+
|
| 26 |
+
# Create FastAPI app with all endpoints
|
| 27 |
+
app = create_fastapi_app(
|
| 28 |
+
env,
|
| 29 |
+
DiagnosticAction,
|
| 30 |
+
PatientObservation,
|
| 31 |
+
max_concurrent_envs=100, # Support up to 100 parallel training sessions
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Optional: Add custom middleware or endpoints here if needed
|
| 35 |
+
# (Most common use cases are already handled by create_fastapi_app)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
if __name__ == "__main__":
|
| 39 |
+
import uvicorn
|
| 40 |
+
uvicorn.run(
|
| 41 |
+
"app:app",
|
| 42 |
+
host="0.0.0.0",
|
| 43 |
+
port=8000,
|
| 44 |
+
workers=4,
|
| 45 |
+
reload=False,
|
| 46 |
+
)
|
server/environment.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
server/environment.py — Core medical diagnostic environment logic.
|
| 3 |
+
|
| 4 |
+
This is the BRAIN of the Medical Diagnostic Environment. It:
|
| 5 |
+
1. Manages patient cases and episode state
|
| 6 |
+
2. Processes agent actions (questions, tests, diagnoses)
|
| 7 |
+
3. Calculates rewards based on diagnostic quality
|
| 8 |
+
4. Provides trajectory-based reward signals (not sparse)
|
| 9 |
+
|
| 10 |
+
Pure Python - no HTTP or WebSocket code here.
|
| 11 |
+
All logic is deterministic and reproducible.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import random
|
| 15 |
+
import uuid
|
| 16 |
+
from typing import Dict, List, Optional
|
| 17 |
+
from openenv.core.env_server import Environment
|
| 18 |
+
|
| 19 |
+
import sys
|
| 20 |
+
import os
|
| 21 |
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
| 22 |
+
|
| 23 |
+
from models import DiagnosticAction, PatientObservation, ClinicalState
|
| 24 |
+
from server.medical_data import (
|
| 25 |
+
PATIENT_CASES,
|
| 26 |
+
calculate_question_reward,
|
| 27 |
+
calculate_test_reward,
|
| 28 |
+
calculate_diagnosis_accuracy,
|
| 29 |
+
get_patient_response,
|
| 30 |
+
normalize_test_name,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class MedicalDiagnosticEnvironment(Environment):
|
| 35 |
+
"""
|
| 36 |
+
Medical Diagnostic Environment for RL Training.
|
| 37 |
+
|
| 38 |
+
Simulates doctor-patient interaction where an LLM agent must:
|
| 39 |
+
1. Ask relevant clinical questions
|
| 40 |
+
2. Order appropriate diagnostic tests
|
| 41 |
+
3. Make accurate diagnoses
|
| 42 |
+
|
| 43 |
+
The environment provides rich reward signals throughout the trajectory:
|
| 44 |
+
- +0.05 per relevant question asked
|
| 45 |
+
- +0.10 per informative test ordered
|
| 46 |
+
- +1.0 for correct final diagnosis
|
| 47 |
+
- Penalizes inefficient or irrelevant actions
|
| 48 |
+
|
| 49 |
+
This is NOT a sparse reward environment - the agent sees meaningful progress
|
| 50 |
+
at each step, which is crucial for learning.
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
SUPPORTS_CONCURRENT_SESSIONS = True # Allow multiple parallel training sessions
|
| 54 |
+
|
| 55 |
+
# Episode configuration
|
| 56 |
+
MAX_STEPS = 15 # Maximum actions per episode
|
| 57 |
+
DIFFICULTY_LEVELS = ["easy", "medium", "hard"]
|
| 58 |
+
|
| 59 |
+
    def __init__(self):
        """Initialize environment state.

        All per-episode fields are set to neutral defaults here and are fully
        re-initialized by reset(); nothing episode-specific survives __init__.
        """
        super().__init__()

        # Episode identity / bookkeeping
        self._episode_id: str = ""
        self._case_id: str = ""  # key into PATIENT_CASES once reset() runs
        self._difficulty: str = ""
        self._step_count: int = 0
        self._total_reward: float = 0.0

        # Patient interaction tracking
        self._questions_asked: List[str] = []
        self._tests_ordered: List[str] = []
        self._test_results: Dict = {}
        self._diagnosis_submitted: Optional[str] = None
        self._final_accuracy: float = 0.0

        # Episode status
        self._done: bool = False
        # Per-category reward totals — useful when debugging reward shaping.
        self._episode_reward_breakdown: Dict = {
            "question_rewards": 0.0,
            "test_rewards": 0.0,
            "diagnosis_reward": 0.0,
            "efficiency_penalty": 0.0,
        }
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def current_case_id(self) -> str:
|
| 88 |
+
"""Current patient case identifier."""
|
| 89 |
+
return self._case_id
|
| 90 |
+
|
| 91 |
+
@property
|
| 92 |
+
def current_difficulty(self) -> str:
|
| 93 |
+
"""Current episode difficulty level."""
|
| 94 |
+
return self._difficulty
|
| 95 |
+
|
| 96 |
+
# ─────────────────────────────────────────────────────────────────────
|
| 97 |
+
# Core API Methods
|
| 98 |
+
# ─────────────────────────────────────────────────────────────────────
|
| 99 |
+
|
| 100 |
+
def reset(self, difficulty: str = "easy", **kwargs) -> PatientObservation:
|
| 101 |
+
"""
|
| 102 |
+
Reset the environment for a new diagnostic episode.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
difficulty: "easy", "medium", or "hard" (controls case selection)
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
Initial PatientObservation for the agent to read
|
| 109 |
+
"""
|
| 110 |
+
# Initialize episode
|
| 111 |
+
self._episode_id = str(uuid.uuid4())
|
| 112 |
+
self._difficulty = difficulty if difficulty in self.DIFFICULTY_LEVELS else "easy"
|
| 113 |
+
self._case_id = self._select_case_by_difficulty(self._difficulty)
|
| 114 |
+
self._step_count = 0
|
| 115 |
+
self._total_reward = 0.0
|
| 116 |
+
|
| 117 |
+
# Reset tracking
|
| 118 |
+
self._questions_asked = []
|
| 119 |
+
self._tests_ordered = []
|
| 120 |
+
self._test_results = {}
|
| 121 |
+
self._diagnosis_submitted = None
|
| 122 |
+
self._final_accuracy = 0.0
|
| 123 |
+
self._done = False
|
| 124 |
+
self._episode_reward_breakdown = {
|
| 125 |
+
"question_rewards": 0.0,
|
| 126 |
+
"test_rewards": 0.0,
|
| 127 |
+
"diagnosis_reward": 0.0,
|
| 128 |
+
"efficiency_penalty": 0.0,
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
# Get case information
|
| 132 |
+
case = PATIENT_CASES[self._case_id]
|
| 133 |
+
|
| 134 |
+
# Create initial observation
|
| 135 |
+
initial_message = (
|
| 136 |
+
f"Patient presents with: {case['presentation']}\n"
|
| 137 |
+
f"Age: {case['age']}, Gender: {case['gender']}\n"
|
| 138 |
+
f"You have up to {self.MAX_STEPS} steps to diagnose this patient.\n"
|
| 139 |
+
f"Please start by asking questions or ordering tests."
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
return PatientObservation(
|
| 143 |
+
done=False,
|
| 144 |
+
reward=0.0,
|
| 145 |
+
message=initial_message,
|
| 146 |
+
patient_response=None,
|
| 147 |
+
test_result=None,
|
| 148 |
+
questions_asked=[],
|
| 149 |
+
tests_completed=[],
|
| 150 |
+
patient_data_revealed={
|
| 151 |
+
"age": case["age"],
|
| 152 |
+
"gender": case["gender"],
|
| 153 |
+
"presentation": case["presentation"],
|
| 154 |
+
},
|
| 155 |
+
steps_taken=0,
|
| 156 |
+
max_steps=self.MAX_STEPS,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def step(self, action: DiagnosticAction, **kwargs) -> PatientObservation:
|
| 160 |
+
"""
|
| 161 |
+
Process one diagnostic action (question, test, or diagnosis).
|
| 162 |
+
|
| 163 |
+
Returns immediate reward and next observation.
|
| 164 |
+
"""
|
| 165 |
+
if self._done:
|
| 166 |
+
return self._create_done_observation(
|
| 167 |
+
message="Episode already ended. Call reset() to start a new case."
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
self._step_count += 1
|
| 171 |
+
step_reward = 0.0
|
| 172 |
+
message = ""
|
| 173 |
+
patient_response = None
|
| 174 |
+
test_result = None
|
| 175 |
+
|
| 176 |
+
# ── Process action based on type ──
|
| 177 |
+
if action.action_type == "ask_question":
|
| 178 |
+
step_reward, message, patient_response = self._handle_question(action.question)
|
| 179 |
+
|
| 180 |
+
elif action.action_type == "order_test":
|
| 181 |
+
step_reward, message, test_result = self._handle_test(action.test_name)
|
| 182 |
+
|
| 183 |
+
elif action.action_type == "submit_diagnosis":
|
| 184 |
+
step_reward, message = self._handle_diagnosis(action.diagnosis)
|
| 185 |
+
self._done = True
|
| 186 |
+
|
| 187 |
+
else:
|
| 188 |
+
message = f"Unknown action type: {action.action_type}"
|
| 189 |
+
step_reward = -0.05
|
| 190 |
+
|
| 191 |
+
# Accumulate rewards
|
| 192 |
+
self._total_reward += step_reward
|
| 193 |
+
|
| 194 |
+
# Check if episode should end
|
| 195 |
+
if self._step_count >= self.MAX_STEPS and not self._done:
|
| 196 |
+
message += f"\nMax steps reached. Episode ending."
|
| 197 |
+
self._done = True
|
| 198 |
+
|
| 199 |
+
# Get current case for patient data revelation
|
| 200 |
+
case = PATIENT_CASES[self._case_id]
|
| 201 |
+
|
| 202 |
+
return PatientObservation(
|
| 203 |
+
done=self._done,
|
| 204 |
+
reward=step_reward,
|
| 205 |
+
message=message,
|
| 206 |
+
patient_response=patient_response,
|
| 207 |
+
test_result=test_result,
|
| 208 |
+
questions_asked=self._questions_asked.copy(),
|
| 209 |
+
tests_completed=self._tests_ordered.copy(),
|
| 210 |
+
patient_data_revealed=self._build_patient_data_view(case),
|
| 211 |
+
steps_taken=self._step_count,
|
| 212 |
+
max_steps=self.MAX_STEPS,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
def state(self) -> ClinicalState:
|
| 216 |
+
"""
|
| 217 |
+
Return complete internal state (includes hidden information).
|
| 218 |
+
Used for debugging only - NEVER send to agent.
|
| 219 |
+
"""
|
| 220 |
+
case = PATIENT_CASES.get(self._case_id, {})
|
| 221 |
+
|
| 222 |
+
return ClinicalState(
|
| 223 |
+
episode_id=self._episode_id,
|
| 224 |
+
step_count=self._step_count,
|
| 225 |
+
true_diagnosis=case.get("true_diagnosis", ""),
|
| 226 |
+
patient_case=self._case_id,
|
| 227 |
+
patient_id=self._case_id,
|
| 228 |
+
patient_details=case,
|
| 229 |
+
difficulty=self._difficulty,
|
| 230 |
+
questions_asked=self._questions_asked.copy(),
|
| 231 |
+
tests_completed=self._tests_ordered.copy(),
|
| 232 |
+
final_diagnosis_submitted=self._diagnosis_submitted,
|
| 233 |
+
final_accuracy=self._final_accuracy,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
# ─────────────────────────────────────────────────────────────────────
|
| 237 |
+
# Action Processing
|
| 238 |
+
# ─────────────────────────────────────────────────────────────────────
|
| 239 |
+
|
| 240 |
+
def _handle_question(self, question: Optional[str]) -> tuple:
|
| 241 |
+
"""
|
| 242 |
+
Process a question about the patient.
|
| 243 |
+
|
| 244 |
+
Returns:
|
| 245 |
+
(reward, message, patient_response)
|
| 246 |
+
"""
|
| 247 |
+
if not question or not isinstance(question, str) or not question.strip():
|
| 248 |
+
message = "No valid question was provided. Please ask a clinical question."
|
| 249 |
+
return -0.05, message, None
|
| 250 |
+
|
| 251 |
+
# Calculate reward for asking this question
|
| 252 |
+
reward = calculate_question_reward(self._case_id, question)
|
| 253 |
+
|
| 254 |
+
# Record question
|
| 255 |
+
self._questions_asked.append(question)
|
| 256 |
+
self._episode_reward_breakdown["question_rewards"] += reward
|
| 257 |
+
|
| 258 |
+
# Get patient response
|
| 259 |
+
response = get_patient_response(self._case_id, question)
|
| 260 |
+
|
| 261 |
+
message = f"Patient response: {response}"
|
| 262 |
+
if reward == 0.00:
|
| 263 |
+
message += " (Question may not be directly relevant)"
|
| 264 |
+
elif reward == 0.01:
|
| 265 |
+
message += " (Somewhat relevant question)"
|
| 266 |
+
else:
|
| 267 |
+
message += " (Good clinical question!)"
|
| 268 |
+
|
| 269 |
+
return reward, message, response
|
| 270 |
+
|
| 271 |
+
def _handle_test(self, test_name: Optional[str]) -> tuple:
|
| 272 |
+
"""
|
| 273 |
+
Process a test order.
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
(reward, message, test_result_dict)
|
| 277 |
+
"""
|
| 278 |
+
if not test_name or not isinstance(test_name, str) or not test_name.strip():
|
| 279 |
+
message = "No valid test name was provided. Please order a valid diagnostic test."
|
| 280 |
+
return -0.05, message, None
|
| 281 |
+
|
| 282 |
+
# Calculate reward for ordering this test
|
| 283 |
+
reward = calculate_test_reward(self._case_id, test_name)
|
| 284 |
+
|
| 285 |
+
# Get case data
|
| 286 |
+
case = PATIENT_CASES[self._case_id]
|
| 287 |
+
|
| 288 |
+
# Try to find matching test result
|
| 289 |
+
test_result_data = None
|
| 290 |
+
matched_test_key = None
|
| 291 |
+
|
| 292 |
+
test_lower = normalize_test_name(test_name)
|
| 293 |
+
for test_key, result in case.get("test_results", {}).items():
|
| 294 |
+
if test_key.lower() == test_lower or test_key.lower() in test_lower or test_lower in test_key.lower():
|
| 295 |
+
test_result_data = result
|
| 296 |
+
matched_test_key = test_key
|
| 297 |
+
break
|
| 298 |
+
|
| 299 |
+
if test_result_data is None:
|
| 300 |
+
message = f"Test '{test_name}' not available for this patient or unavailable in this setting."
|
| 301 |
+
reward = -0.02
|
| 302 |
+
return reward, message, None
|
| 303 |
+
|
| 304 |
+
# Record test
|
| 305 |
+
self._tests_ordered.append(matched_test_key)
|
| 306 |
+
self._test_results[matched_test_key] = test_result_data
|
| 307 |
+
self._episode_reward_breakdown["test_rewards"] += reward
|
| 308 |
+
|
| 309 |
+
# Format test result for agent
|
| 310 |
+
test_result_dict = {
|
| 311 |
+
"test_name": matched_test_key,
|
| 312 |
+
"result": str(test_result_data),
|
| 313 |
+
"interpretation": test_result_data.get("interpretation", test_result_data.get("finding", ""))
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
message = f"Test result received for {matched_test_key}:\n{test_result_dict['interpretation']}"
|
| 317 |
+
if reward == 0.10:
|
| 318 |
+
message += " (Excellent diagnostic test!)"
|
| 319 |
+
elif reward == 0.05:
|
| 320 |
+
message += " (Useful supporting test)"
|
| 321 |
+
else:
|
| 322 |
+
message += " (Test ordered but may be less relevant)"
|
| 323 |
+
|
| 324 |
+
return reward, message, test_result_dict
|
| 325 |
+
|
| 326 |
+
def _handle_diagnosis(self, diagnosis: str) -> tuple:
|
| 327 |
+
"""
|
| 328 |
+
Process final diagnosis submission.
|
| 329 |
+
|
| 330 |
+
Returns:
|
| 331 |
+
(reward, message)
|
| 332 |
+
"""
|
| 333 |
+
# Calculate diagnostic accuracy
|
| 334 |
+
accuracy = calculate_diagnosis_accuracy(self._case_id, diagnosis)
|
| 335 |
+
self._final_accuracy = accuracy
|
| 336 |
+
self._diagnosis_submitted = diagnosis
|
| 337 |
+
|
| 338 |
+
# Create diagnosis reward (not just accuracy, but also process quality)
|
| 339 |
+
case = PATIENT_CASES[self._case_id]
|
| 340 |
+
true_diagnosis = case["true_diagnosis"]
|
| 341 |
+
|
| 342 |
+
# Change the if/elif chain to use >= comparisons:
|
| 343 |
+
if accuracy >= 0.95:
|
| 344 |
+
reward = 1.0
|
| 345 |
+
message = f"Correct diagnosis: {diagnosis}"
|
| 346 |
+
elif accuracy >= 0.7:
|
| 347 |
+
reward = accuracy
|
| 348 |
+
message = f"Acceptable diagnosis: {diagnosis}. True: {true_diagnosis}"
|
| 349 |
+
elif accuracy >= 0.3:
|
| 350 |
+
reward = accuracy
|
| 351 |
+
message = f"Partially correct. True: {true_diagnosis}"
|
| 352 |
+
else:
|
| 353 |
+
reward = 0.0
|
| 354 |
+
message = f"Incorrect. True: {true_diagnosis}"
|
| 355 |
+
|
| 356 |
+
self._episode_reward_breakdown["diagnosis_reward"] = reward
|
| 357 |
+
|
| 358 |
+
# Add efficiency feedback
|
| 359 |
+
if self._step_count > self.MAX_STEPS * 0.8:
|
| 360 |
+
penalty = 0.1 * (self._step_count / self.MAX_STEPS - 0.8)
|
| 361 |
+
self._episode_reward_breakdown["efficiency_penalty"] = penalty
|
| 362 |
+
message += f"\n(Efficiency penalty: -{penalty:.2f} for taking many steps)"
|
| 363 |
+
|
| 364 |
+
return reward, message
|
| 365 |
+
|
| 366 |
+
# ─────────────────────────────────────────────────────────────────────
|
| 367 |
+
# Helper Methods
|
| 368 |
+
# ─────────────────────────────────────────────────────────────────────
|
| 369 |
+
|
| 370 |
+
def _select_case_by_difficulty(self, difficulty: str) -> str:
|
| 371 |
+
matching_keys = [k for k, v in PATIENT_CASES.items() if v["difficulty"] == difficulty]
|
| 372 |
+
if not matching_keys:
|
| 373 |
+
matching_keys = list(PATIENT_CASES.keys())
|
| 374 |
+
return random.choice(matching_keys)
|
| 375 |
+
|
| 376 |
+
def _build_patient_data_view(self, case: Dict) -> Dict:
|
| 377 |
+
"""
|
| 378 |
+
Build what the agent has learned about the patient so far.
|
| 379 |
+
Only includes information revealed through questions/tests.
|
| 380 |
+
"""
|
| 381 |
+
revealed = {
|
| 382 |
+
"age": case.get("age"),
|
| 383 |
+
"gender": case.get("gender"),
|
| 384 |
+
"presentation": case.get("presentation"),
|
| 385 |
+
}
|
| 386 |
+
|
| 387 |
+
# Add findings based on questions asked
|
| 388 |
+
findings = case.get("hidden_findings", {})
|
| 389 |
+
for question in self._questions_asked:
|
| 390 |
+
q_lower = question.lower()
|
| 391 |
+
for finding, value in findings.items():
|
| 392 |
+
if finding.lower().replace("_", " ") in q_lower:
|
| 393 |
+
revealed[f"finding_{finding}"] = value
|
| 394 |
+
|
| 395 |
+
# Add test results
|
| 396 |
+
if self._test_results:
|
| 397 |
+
revealed["test_results"] = self._test_results
|
| 398 |
+
|
| 399 |
+
return revealed
|
| 400 |
+
|
| 401 |
+
def _create_done_observation(self, message: str) -> PatientObservation:
|
| 402 |
+
"""Create a terminal observation."""
|
| 403 |
+
return PatientObservation(
|
| 404 |
+
done=True,
|
| 405 |
+
reward=0.0,
|
| 406 |
+
message=message,
|
| 407 |
+
patient_response=None,
|
| 408 |
+
test_result=None,
|
| 409 |
+
questions_asked=self._questions_asked.copy(),
|
| 410 |
+
tests_completed=self._tests_ordered.copy(),
|
| 411 |
+
patient_data_revealed={},
|
| 412 |
+
steps_taken=self._step_count,
|
| 413 |
+
max_steps=self.MAX_STEPS,
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
def get_episode_summary(self) -> Dict:
|
| 417 |
+
"""
|
| 418 |
+
Return a summary of the episode for logging/evaluation.
|
| 419 |
+
"""
|
| 420 |
+
case = PATIENT_CASES.get(self._case_id, {})
|
| 421 |
+
return {
|
| 422 |
+
"episode_id": self._episode_id,
|
| 423 |
+
"case_id": self._case_id,
|
| 424 |
+
"difficulty": self._difficulty,
|
| 425 |
+
"true_diagnosis": case.get("true_diagnosis", ""),
|
| 426 |
+
"submitted_diagnosis": self._diagnosis_submitted,
|
| 427 |
+
"accuracy": self._final_accuracy,
|
| 428 |
+
"diagnostic_accuracy": self._final_accuracy,
|
| 429 |
+
"total_reward": self._total_reward,
|
| 430 |
+
"steps": self._step_count,
|
| 431 |
+
"steps_taken": self._step_count,
|
| 432 |
+
"max_steps": self.MAX_STEPS,
|
| 433 |
+
"questions_asked": len(self._questions_asked),
|
| 434 |
+
"tests_ordered": len(self._tests_ordered),
|
| 435 |
+
"reward_breakdown": self._episode_reward_breakdown,
|
| 436 |
+
}
|
server/medical_data.py
ADDED
|
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from typing import Dict, List, Tuple, Optional
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
try:
|
| 7 |
+
from datasets import load_dataset
|
| 8 |
+
_DATASETS_AVAILABLE = True
|
| 9 |
+
except ImportError:
|
| 10 |
+
_DATASETS_AVAILABLE = False
|
| 11 |
+
|
| 12 |
+
# Opt-in flag: set OPENENV_USE_HF_DATASETS=1/true/yes to pull real cases from Hugging Face.
USE_HF_DATASETS = os.getenv("OPENENV_USE_HF_DATASETS", "false").lower() in ("1", "true", "yes")
# Optional integer seed for reproducible dataset sampling; ignored if unset or non-numeric.
DATASET_SEED = os.getenv("OPENENV_DATASET_SEED")
if DATASET_SEED is not None:
    try:
        DATASET_SEED = int(DATASET_SEED)
    except ValueError:
        DATASET_SEED = None
|
| 19 |
+
# LOAD REAL DATASETS FROM HUGGING FACE
|
| 20 |
+
# ==============================================================================
|
| 21 |
+
|
| 22 |
+
def load_medical_datasets() -> Dict:
    """Fetch the MedMCQA and BigBio MedQA training splits from Hugging Face.

    Returns an empty dict when loading is disabled via config, when the
    `datasets` package is unavailable, or when any download error occurs.
    """
    if not USE_HF_DATASETS:
        return {}

    if not _DATASETS_AVAILABLE:
        print("Warning: datasets package not installed; skipping Hugging Face dataset loading.")
        return {}

    try:
        # MedMCQA first, then BigBio MedQA (dict literal evaluates in order)
        loaded = {
            "medmcqa": load_dataset("medmcqa", split="train"),
            "medqa": load_dataset("bigbio/med_qa", split="train"),
        }
    except Exception as e:
        print(f"Warning: Could not load datasets: {e}. No dataset cases loaded.")
        return {}
    return loaded
|
| 45 |
+
|
| 46 |
+
# ==============================================================================
|
| 47 |
+
# INNOVATIVE REWARD SYSTEM USING LLM JUDGMENT
|
| 48 |
+
# ==============================================================================
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Per-case keyword lists used to judge how relevant a clinical question is.
RELEVANT_QUESTION_KEYWORDS = {
    "easy_flu": ["fever", "cough", "ache", "fatigue", "onset", "symptom", "contact", "vaccine", "temperature"],
    "easy_uti": ["urination", "burning", "pain", "frequent", "abdominal", "bladder", "kidney"],
    "medium_pneumonia": ["cough", "fever", "breath", "chest", "pain", "smoking", "sputum", "productive"],
    "medium_appendicitis": ["abdominal", "pain", "nausea", "vomiting", "rebound", "right lower quadrant", "appetite", "fever"],
    "hard_endocarditis": ["fever", "murmur", "drug", "iv", "dental", "hemorrhage", "splinter", "heart"],
    "hard_meningitis": ["headache", "neck", "stiff", "fever", "photophobia", "confusion", "vomit", "seizure"],
}

def calculate_question_reward(case_id: str, question: str) -> float:
    """Score a question by how many case-relevant keywords it mentions.

    Returns 0.08 for two or more keyword hits, 0.05 for exactly one,
    and 0.01 otherwise (including unknown case ids).
    """
    lowered = question.lower()
    hits = sum(
        keyword in lowered
        for keyword in RELEVANT_QUESTION_KEYWORDS.get(case_id, [])
    )
    if hits >= 2:
        return 0.08
    return 0.05 if hits == 1 else 0.01
|
| 67 |
+
|
| 68 |
+
# Maps common free-text test names to the canonical keys used in case data.
TEST_NAME_ALIASES = {
    "complete blood count": "cbc",
    "cbc": "cbc",
    "urinalysis": "urinalysis",
    "urine culture": "urine_culture",
    "blood cultures": "blood_cultures",
    "echocardiogram": "echocardiogram",
    "ct head": "ct_head",
    "ct scan": "ct_head",
    "chest xray": "chest_xray",
    "chest radiograph": "chest_xray",
    "sputum culture": "sputum_culture",
    "rapid flu test": "rapid_flu_test",
    "flu test": "rapid_flu_test",
}


def normalize_test_name(test_name: Optional[str]) -> str:
    """Lower-case, trim, and canonicalize a test name via the alias table.

    Non-string or empty input normalizes to "".
    """
    if not isinstance(test_name, str) or not test_name:
        return ""
    key = test_name.strip().lower()
    return TEST_NAME_ALIASES.get(key, key)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def calculate_test_reward(case_id: str, test_name: Optional[str]) -> float:
    """
    Calculate reward for ordering a diagnostic test.

    Returns higher reward for tests that are more relevant to the case.
    A test that is unavailable for the case (or an invalid name) scores -0.02.
    """
    if not isinstance(test_name, str) or not test_name:
        return -0.02

    normalized = normalize_test_name(test_name)
    available = PATIENT_CASES.get(case_id, {}).get("test_results", {})

    # Match against the case's available tests (substring in either direction)
    for known_key in available:
        known_lower = known_key.lower()
        if known_lower in normalized or normalized in known_lower:
            # Available test — reward scaled by general diagnostic usefulness
            if "cbc" in normalized or "blood" in normalized:
                return 0.10  # Common useful test
            if "flu" in normalized or "influenza" in normalized:
                return 0.10  # Specific relevant test
            return 0.05  # Somewhat useful test

    # Test not available or irrelevant
    return -0.02
|
| 119 |
+
|
| 120 |
+
def calculate_diagnosis_accuracy(case_id: str, submitted: str) -> float:
    """Score a submitted diagnosis against the case's ground truth.

    Returns 1.0 for an exact or accepted-alias match, partial word-overlap
    credit (capped at 0.7) otherwise, and 0.0 for empty/unknown input.

    BUGFIX: previously an empty/whitespace submission scored 1.0, because
    "" is a substring of every acceptable diagnosis (and equals the ""
    ground truth of an unknown case id). Empty strings are now rejected.
    """
    case = PATIENT_CASES.get(case_id, {})
    s = (submitted or "").lower().strip()
    true = case.get("true_diagnosis", "").lower()
    # Guard: without both a submission and a ground truth, no credit.
    if not s or not true:
        return 0.0
    if s == true:
        return 1.0
    for acceptable in case.get("correct_diagnoses", []):
        a = acceptable.lower()
        if a and (a in s or s in a):
            return 1.0
    # Partial credit: fraction of ground-truth words present in the submission
    true_words = set(true.split())
    sub_words = set(s.split())
    overlap = len(true_words & sub_words) / max(len(true_words), 1)
    return round(min(overlap, 0.7), 2)
|
| 132 |
+
|
| 133 |
+
# ==============================================================================
|
| 134 |
+
# PATIENT CASES DATABASE (FORMATTED FROM REAL DATASETS)
|
| 135 |
+
# ==============================================================================
|
| 136 |
+
|
| 137 |
+
def format_medmcqa_to_case(entry: Dict) -> Dict:
    """Convert one MedMCQA record into the environment's patient-case schema.

    The question text becomes the presentation; the answer option becomes
    the true diagnosis; age/gender are randomly synthesized.
    """
    question = entry.get("question", "")
    options = entry.get("options", {})
    answer_key = entry.get("answer", "")

    # Answer options double as the candidate diagnoses
    diagnoses = list(options.values())
    fallback = diagnoses[0] if diagnoses else "Unknown"
    true_diagnosis = options.get(answer_key, fallback)

    return {
        "case_id": f"medmcqa_{entry.get('id', random.randint(1000,9999))}",
        "difficulty": "medium",  # Default to medium
        "true_diagnosis": true_diagnosis,
        "age": random.randint(25, 75),
        "gender": random.choice(["Male", "Female"]),
        "presentation": f"Patient presents with: {question}",
        "hidden_findings": {},  # Would need more processing
        "test_results": {},
        "correct_diagnoses": [true_diagnosis],
        "differential_diagnoses": diagnoses[:3],  # First 3 options
        "source": "medmcqa"
    }
|
| 164 |
+
|
| 165 |
+
def format_medqa_to_case(entry: Dict) -> Dict:
    """Convert one BigBio MedQA record into the environment's case schema.

    MedQA has free-text answers, so the answer string is the true diagnosis
    and no differential list is available.
    """
    question = entry.get("question", "")
    answer = entry.get("answer", "")

    return {
        # hash() keeps ids stable within a process run
        "case_id": f"medqa_{hash(question) % 10000}",
        "difficulty": "hard",  # MedQA is more complex
        "true_diagnosis": answer,
        "age": random.randint(30, 80),
        "gender": random.choice(["Male", "Female"]),
        "presentation": f"Medical question: {question}",
        "hidden_findings": {},
        "test_results": {},
        "correct_diagnoses": [answer],
        "differential_diagnoses": [],
        "source": "medqa"
    }
|
| 185 |
+
|
| 186 |
+
def generate_patient_cases_from_datasets() -> Dict:
    """Generate patient cases from real Hugging Face datasets.

    Samples 3 easy + 2 medium cases from MedMCQA and 2 hard cases from
    MedQA. Returns {} when no datasets could be loaded. The sampling
    sequence is deterministic when DATASET_SEED is set.
    """
    cases: Dict = {}
    datasets = load_medical_datasets()
    if not datasets:
        return cases

    if DATASET_SEED is not None:
        random.seed(DATASET_SEED)

    # (source, sample count, difficulty, formatter) — order preserved so the
    # random call sequence (and thus seeded output) matches exactly.
    plan = [
        ("medmcqa", 3, "easy", format_medmcqa_to_case),
        ("medmcqa", 2, "medium", format_medmcqa_to_case),
        ("medqa", 2, "hard", format_medqa_to_case),
    ]
    for source, count, difficulty, formatter in plan:
        data = datasets.get(source)
        if data is None:
            continue
        picks = random.sample(range(len(data)), min(count, len(data)))
        for i, idx in enumerate(picks):
            case = formatter(data[idx])
            case["difficulty"] = difficulty
            cases[f"{difficulty}_real_{i}"] = case

    return cases
|
| 227 |
+
|
| 228 |
+
STATIC_PATIENT_CASES = {
|
| 229 |
+
"easy_flu": {
|
| 230 |
+
"case_id": "easy_flu",
|
| 231 |
+
"difficulty": "easy",
|
| 232 |
+
"true_diagnosis": "Seasonal Influenza",
|
| 233 |
+
"age": 28, "gender": "Female",
|
| 234 |
+
"presentation": "Patient presents with sudden fever (38.9°C), body aches, headache, fatigue, and dry cough for 2 days. No shortness of breath.",
|
| 235 |
+
"hidden_findings": {"fever": "38.9°C", "duration": "2 days", "onset": "sudden"},
|
| 236 |
+
"test_results": {
|
| 237 |
+
"rapid_flu_test": {"result": "Positive for Influenza A", "interpretation": "Positive Influenza A — confirms influenza diagnosis"},
|
| 238 |
+
"cbc": {"result": "WBC 9.2, lymphocytosis", "interpretation": "Mild lymphocytosis consistent with viral infection"},
|
| 239 |
+
},
|
| 240 |
+
"correct_diagnoses": ["Seasonal Influenza", "Influenza A", "Flu"],
|
| 241 |
+
"differential_diagnoses": ["COVID-19", "Common Cold", "Strep Throat"],
|
| 242 |
+
},
|
| 243 |
+
"easy_uti": {
|
| 244 |
+
"case_id": "easy_uti",
|
| 245 |
+
"difficulty": "easy",
|
| 246 |
+
"true_diagnosis": "Urinary Tract Infection",
|
| 247 |
+
"age": 35, "gender": "Female",
|
| 248 |
+
"presentation": "Patient presents with frequent urination, burning sensation during urination, and lower abdominal pain for 3 days.",
|
| 249 |
+
"hidden_findings": {"frequency": "frequent urination", "pain": "burning during urination", "duration": "3 days"},
|
| 250 |
+
"test_results": {
|
| 251 |
+
"urinalysis": {"result": "Positive for nitrites, leukocytes >10", "interpretation": "Urinalysis shows signs of bacterial infection consistent with UTI"},
|
| 252 |
+
"urine_culture": {"result": "E. coli >100,000 CFU/mL", "interpretation": "Urine culture confirms E. coli urinary tract infection"},
|
| 253 |
+
},
|
| 254 |
+
"correct_diagnoses": ["Urinary Tract Infection", "UTI", "Bladder Infection"],
|
| 255 |
+
"differential_diagnoses": ["Cystitis", "Pyelonephritis", "Vaginitis"],
|
| 256 |
+
},
|
| 257 |
+
"medium_pneumonia": {
|
| 258 |
+
"case_id": "medium_pneumonia",
|
| 259 |
+
"difficulty": "medium",
|
| 260 |
+
"true_diagnosis": "Community-Acquired Pneumonia",
|
| 261 |
+
"age": 45, "gender": "Male",
|
| 262 |
+
"presentation": "Patient presents with productive cough, fever (39.2°C), shortness of breath, and right-sided chest pain for 5 days. Smoker with 20 pack-year history.",
|
| 263 |
+
"hidden_findings": {"cough": "productive", "fever": "39.2°C", "breathing": "shortness of breath", "smoking": "20 pack-years"},
|
| 264 |
+
"test_results": {
|
| 265 |
+
"chest_xray": {"result": "Right lower lobe consolidation", "interpretation": "Chest X-ray shows consolidation in right lower lobe consistent with pneumonia"},
|
| 266 |
+
"cbc": {"result": "WBC 14.5, neutrophilia", "interpretation": "Elevated white blood cell count with neutrophilia suggesting bacterial infection"},
|
| 267 |
+
"sputum_culture": {"result": "Streptococcus pneumoniae", "interpretation": "Sputum culture positive for Streptococcus pneumoniae"},
|
| 268 |
+
},
|
| 269 |
+
"correct_diagnoses": ["Community-Acquired Pneumonia", "Pneumonia", "Bacterial Pneumonia"],
|
| 270 |
+
"differential_diagnoses": ["Bronchitis", "Pulmonary Embolism", "Lung Cancer"],
|
| 271 |
+
},
|
| 272 |
+
"hard_endocarditis": {
|
| 273 |
+
"case_id": "hard_endocarditis",
|
| 274 |
+
"difficulty": "hard",
|
| 275 |
+
"true_diagnosis": "Infective Endocarditis",
|
| 276 |
+
"age": 55, "gender": "Male",
|
| 277 |
+
"presentation": "Patient with history of IV drug use presents with fever (38.8°C), new heart murmur, and splinter hemorrhages. Recent dental procedure 2 weeks ago.",
|
| 278 |
+
"hidden_findings": {"drug_use": "IV drug user", "murmur": "new heart murmur", "hemorrhages": "splinter hemorrhages", "dental": "recent dental procedure"},
|
| 279 |
+
"test_results": {
|
| 280 |
+
"blood_cultures": {"result": "Staphylococcus aureus in 3/3 bottles", "interpretation": "Blood cultures positive for Staphylococcus aureus in multiple bottles"},
|
| 281 |
+
"echocardiogram": {"result": "Vegetation on aortic valve", "interpretation": "Echocardiogram shows vegetation on aortic valve consistent with endocarditis"},
|
| 282 |
+
"cbc": {"result": "WBC 12.8, anemia", "interpretation": "Elevated white blood cells with anemia of chronic disease"},
|
| 283 |
+
},
|
| 284 |
+
"correct_diagnoses": ["Infective Endocarditis", "Endocarditis", "Bacterial Endocarditis"],
|
| 285 |
+
"differential_diagnoses": ["Sepsis", "Acute Rheumatic Fever", "Myocardial Infarction"],
|
| 286 |
+
},
|
| 287 |
+
"medium_appendicitis": {
|
| 288 |
+
"case_id": "medium_appendicitis",
|
| 289 |
+
"difficulty": "medium",
|
| 290 |
+
"true_diagnosis": "Acute Appendicitis",
|
| 291 |
+
"age": 23, "gender": "Female",
|
| 292 |
+
"presentation": "Patient presents with right lower quadrant abdominal pain, nausea, anorexia, and low-grade fever for 24 hours.",
|
| 293 |
+
"hidden_findings": {"pain": "right lower quadrant", "nausea": "yes", "anorexia": "yes", "fever": "low-grade"},
|
| 294 |
+
"test_results": {
|
| 295 |
+
"abdominal_ultrasound": {"result": "Enlarged appendix with periappendiceal fluid", "interpretation": "Findings are consistent with acute appendicitis"},
|
| 296 |
+
"cbc": {"result": "WBC 13.4, neutrophilia", "interpretation": "Elevated white blood cell count with neutrophils suggests acute inflammation"},
|
| 297 |
+
"urinalysis": {"result": "Trace leukocytes", "interpretation": "Urinalysis slightly abnormal but not diagnostic"},
|
| 298 |
+
},
|
| 299 |
+
"correct_diagnoses": ["Acute Appendicitis", "Appendicitis"],
|
| 300 |
+
"differential_diagnoses": ["Ovarian Cyst", "Ectopic Pregnancy", "Gastroenteritis"],
|
| 301 |
+
},
|
| 302 |
+
"hard_meningitis": {
|
| 303 |
+
"case_id": "hard_meningitis",
|
| 304 |
+
"difficulty": "hard",
|
| 305 |
+
"true_diagnosis": "Bacterial Meningitis",
|
| 306 |
+
"age": 34, "gender": "Male",
|
| 307 |
+
"presentation": "Patient presents with severe headache, neck stiffness, fever, photophobia, and confusion over the last 12 hours.",
|
| 308 |
+
"hidden_findings": {"headache": "severe", "neck": "stiff", "fever": "high", "photophobia": "yes"},
|
| 309 |
+
"test_results": {
|
| 310 |
+
"lumbar_puncture": {"result": "Cloudy CSF with neutrophil predominance", "interpretation": "CSF findings are consistent with bacterial meningitis"},
|
| 311 |
+
"blood_cultures": {"result": "Gram-positive cocci in pairs", "interpretation": "Blood cultures positive for likely Streptococcus pneumoniae"},
|
| 312 |
+
"ct_head": {"result": "No mass effect or hemorrhage", "interpretation": "CT head is unremarkable prior to lumbar puncture"},
|
| 313 |
+
},
|
| 314 |
+
"correct_diagnoses": ["Bacterial Meningitis", "Meningitis"],
|
| 315 |
+
"differential_diagnoses": ["Viral Meningitis", "Migraine", "Subarachnoid Hemorrhage"],
|
| 316 |
+
},
|
| 317 |
+
}
|
| 318 |
+
|
| 319 |
+
real_cases = generate_patient_cases_from_datasets() if USE_HF_DATASETS else {}
|
| 320 |
+
# Merge: static cases are always available, real cases supplement
|
| 321 |
+
PATIENT_CASES = {**STATIC_PATIENT_CASES, **real_cases}
|
| 322 |
+
|
| 323 |
+
|
| 324 |
+
# ==============================================================================================================================
|
| 325 |
+
# PATIENT RESPONSE GENERATION
|
| 326 |
+
# ==============================================================================================================================
|
| 327 |
+
|
| 328 |
+
def get_patient_response(case_id: str, question: str) -> str:
    """
    Generate a patient response to a question.

    For dataset-driven cases, responses are generic but clinically plausible.
    """
    normalized = question.lower()
    # Ordered keyword table: the first rule whose keywords match wins,
    # mirroring symptom priority (pain, fever, GI, respiratory, generic).
    keyword_replies = (
        (("pain",), "Yes, I am experiencing pain in that area."),
        (("fever", "temperature"), "I feel warm and may have a fever."),
        (("nausea", "vomit"), "Yes, I am nauseated and may vomit."),
        (("cough", "breath"), "I have some coughing and breathing discomfort."),
        (("symptom", "feel"), "I have concerning symptoms right now."),
    )
    for keywords, reply in keyword_replies:
        if any(keyword in normalized for keyword in keywords):
            return reply
    return "I'm not sure about that. Can you ask in a different way?"
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.104.1
|
| 2 |
+
uvicorn[standard]==0.24.0
|
| 3 |
+
openenv-core==0.2.3
|
| 4 |
+
websockets==16.0
|
| 5 |
+
pydantic==2.5.0
|
| 6 |
+
pydantic-settings==2.1.0
|
| 7 |
+
openai>=2.7.2
|
| 8 |
+
requests==2.31.0
|
| 9 |
+
datasets==2.14.5
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Unit tests for Medical Diagnostic Environment
|
| 3 |
+
|
| 4 |
+
Run with: python -m pytest tests/test_environment.py -v
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import pytest
|
| 8 |
+
from server.environment import MedicalDiagnosticEnvironment
|
| 9 |
+
from server.medical_data import (
|
| 10 |
+
PATIENT_CASES,
|
| 11 |
+
calculate_question_reward,
|
| 12 |
+
calculate_test_reward,
|
| 13 |
+
calculate_diagnosis_accuracy,
|
| 14 |
+
)
|
| 15 |
+
from models import DiagnosticAction
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class TestMedicalDiagnosticEnvironment:
    """Test suite for MedicalDiagnosticEnvironment"""

    @pytest.fixture
    def env(self):
        """Create a fresh environment for each test"""
        return MedicalDiagnosticEnvironment()

    def test_environment_initialization(self, env):
        """Test that environment initializes correctly"""
        assert env is not None
        assert hasattr(env, "reset")
        assert hasattr(env, "step")
        assert hasattr(env, "state")

    def test_reset_easy(self, env):
        """Test reset with easy difficulty"""
        observation = env.reset(difficulty="easy")
        assert observation is not None
        assert hasattr(observation, "message")
        # The opening observation should describe the patient presentation.
        assert "presentation" in observation.message.lower() or "patient" in observation.message.lower()
        assert env.current_case_id is not None
        assert env.current_difficulty == "easy"

    def test_reset_medium(self, env):
        """Test reset with medium difficulty"""
        observation = env.reset(difficulty="medium")
        assert observation is not None
        assert env.current_difficulty == "medium"

    def test_reset_hard(self, env):
        """Test reset with hard difficulty"""
        observation = env.reset(difficulty="hard")
        assert observation is not None
        assert env.current_difficulty == "hard"

    def test_ask_question_action(self, env):
        """Test asking a question"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="ask_question",
            question="Does the patient have a fever?"
        )
        result = env.step(action)
        assert result is not None
        assert hasattr(result, "reward")
        assert result.reward >= 0  # Questions give non-negative reward

    def test_order_test_action(self, env):
        """Test ordering a test"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="order_test",
            test_name="Complete Blood Count"
        )
        result = env.step(action)
        assert result is not None
        assert hasattr(result, "reward")
        assert result.reward >= 0  # Tests give non-negative reward

    def test_submit_diagnosis_action(self, env):
        """Test submitting a diagnosis"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis",
            diagnosis="Common Flu"
        )
        result = env.step(action)
        assert result is not None
        assert hasattr(result, "reward")
        assert result.done is True  # Episode should end on diagnosis

    def test_max_steps_enforcement(self, env):
        """Test that episodes end after max steps"""
        env.reset(difficulty="easy")
        # Keep asking questions until the environment terminates the episode;
        # the cap of 15 iterations matches the environment's step limit.
        for _ in range(15):  # Max 15 steps
            action = DiagnosticAction(
                action_type="ask_question",
                question="Test question"
            )
            result = env.step(action)
            if result.done:
                break
        assert result.done is True

    def test_episode_summary(self, env):
        """Test episode summary generation"""
        env.reset(difficulty="easy")
        action = DiagnosticAction(
            action_type="submit_diagnosis",
            diagnosis="Test Diagnosis"
        )
        env.step(action)
        summary = env.get_episode_summary()
        assert summary is not None
        # Summary must expose all keys that downstream tooling reads.
        assert "case_id" in summary
        assert "difficulty" in summary
        assert "accuracy" in summary
        assert "total_reward" in summary
        assert "steps" in summary

    def test_state_property(self, env):
        """Test the state property"""
        env.reset(difficulty="easy")
        # NOTE: `state` is accessed as a property here (no call parentheses).
        state = env.state
        assert state is not None
        assert hasattr(state, "patient_id")
        assert hasattr(state, "step_count")
        assert hasattr(state, "true_diagnosis")

    def test_concurrent_sessions(self):
        """Test that environment supports concurrent sessions"""
        env = MedicalDiagnosticEnvironment()
        assert env.SUPPORTS_CONCURRENT_SESSIONS is True

    def test_multiple_episodes(self, env):
        """Test running multiple episodes"""
        # A single environment instance should be reusable across resets.
        for difficulty in ["easy", "medium", "hard"]:
            observation = env.reset(difficulty=difficulty)
            assert observation is not None
            assert env.current_difficulty == difficulty
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class TestMedicalData:
    """Test suite for medical data functions"""

    def test_question_reward_calculation(self):
        """Test question reward calculation"""
        # This is case-specific, so we just verify the function works
        case_id = next(iter(PATIENT_CASES))
        reward = calculate_question_reward(
            case_id=case_id,
            question="Does the patient have a fever?"
        )
        assert 0.0 <= reward <= 1.0

    def test_test_reward_calculation(self):
        """Test test reward calculation"""
        case_id = next(iter(PATIENT_CASES))
        reward = calculate_test_reward(
            case_id=case_id,
            test_name="CBC"
        )
        assert 0.0 <= reward <= 1.0

    def test_diagnosis_accuracy_exact_match(self):
        """Test exact diagnosis match"""
        case_id = next(iter(PATIENT_CASES))
        # Submitting the case's own ground-truth diagnosis must score 1.0.
        accuracy = calculate_diagnosis_accuracy(
            case_id=case_id,
            submitted_diagnosis=PATIENT_CASES[case_id].get("true_diagnosis", "")
        )
        assert accuracy == 1.0

    def test_diagnosis_accuracy_partial(self):
        """Test partial diagnosis accuracy"""
        case_id = next(iter(PATIENT_CASES))
        # An arbitrary (likely wrong) diagnosis must still score within [0, 1].
        accuracy = calculate_diagnosis_accuracy(
            case_id=case_id,
            submitted_diagnosis="Pneumonia"
        )
        assert 0.0 <= accuracy <= 1.0
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
# Allow running this module directly (python tests/test_environment.py)
# in addition to the usual `pytest` invocation.
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])
|
training_wrapper.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""training_wrapper.py — Minimal training-ready wrapper for the Medical Diagnostic Environment.
|
| 2 |
+
|
| 3 |
+
This module exposes a small async/sync wrapper that is easy to plug into a training
|
| 4 |
+
loop or evaluation script. It is not a full RL algorithm, but it makes the
|
| 5 |
+
environment easy to consume for model-based training.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import asyncio
|
| 10 |
+
from typing import Optional
|
| 11 |
+
|
| 12 |
+
from client import DiagnosticEnv
|
| 13 |
+
from models import DiagnosticAction, PatientObservation
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TrainingEnv:
    """Minimal wrapper exposing a training-friendly environment interface.

    Wraps the websocket-based ``DiagnosticEnv`` client so a training loop can
    drive episodes with plain keyword arguments instead of action objects.
    Use as an async context manager to manage the connection lifetime.
    """

    def __init__(self, base_url: Optional[str] = None):
        # Resolve the server URL: explicit arg > ENV_URL env var > local default.
        self.base_url = base_url or os.getenv("ENV_URL", "ws://localhost:8000/ws")
        self._env = DiagnosticEnv(base_url=self.base_url)

    async def __aenter__(self):
        # Delegate connection setup to the wrapped client's context manager.
        await self._env.__aenter__()
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self._env.__aexit__(exc_type, exc, tb)

    async def reset(self, difficulty: str = "easy") -> PatientObservation:
        """Start a new episode at *difficulty* and return the first observation."""
        result = await self._env.reset(difficulty=difficulty)
        # Some client versions wrap the observation in a result envelope;
        # unwrap it when present so callers always get a PatientObservation.
        return result.observation if hasattr(result, "observation") else result

    async def step(
        self,
        action_type: str,
        question: Optional[str] = None,
        test_name: Optional[str] = None,
        diagnosis: Optional[str] = None,
    ) -> PatientObservation:
        """Build a DiagnosticAction from keyword fields and apply it.

        Only the field matching *action_type* needs to be supplied
        (``question``, ``test_name``, or ``diagnosis``); the rest stay None.
        """
        action = DiagnosticAction(
            action_type=action_type,
            question=question,
            test_name=test_name,
            diagnosis=diagnosis,
        )
        result = await self._env.step(action)
        return result.observation if hasattr(result, "observation") else result

    async def state(self):
        """Fetch the server-side episode state for the current session."""
        return await self._env.state()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
async def run_demo():
    """Example usage of the training wrapper.

    Connects to the environment server, starts an easy episode, asks one
    question, and prints the resulting observations.
    """
    async with TrainingEnv() as env:
        obs = await env.reset(difficulty="easy")
        print("Reset observation:", obs.message)
        result = await env.step(action_type="ask_question", question="Do you have a fever?")
        print("Step result:", result.message)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# Run the demo episode when executed as a script.
if __name__ == "__main__":
    asyncio.run(run_demo())
|
validate.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Quick validation script for Medical Diagnostic Environment
|
| 4 |
+
|
| 5 |
+
This script validates that the core environment works correctly without
|
| 6 |
+
requiring the server to be running or external dependencies beyond models.
|
| 7 |
+
|
| 8 |
+
Run with: python validate.py
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import sys
|
| 12 |
+
import traceback
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
from typing import Dict, List
|
| 15 |
+
|
| 16 |
+
# Add parent directory to path
|
| 17 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 18 |
+
|
| 19 |
+
from models import DiagnosticAction, PatientObservation, ClinicalState
|
| 20 |
+
from server.environment import MedicalDiagnosticEnvironment
|
| 21 |
+
from server.medical_data import (
|
| 22 |
+
PATIENT_CASES,
|
| 23 |
+
calculate_question_reward,
|
| 24 |
+
calculate_test_reward,
|
| 25 |
+
calculate_diagnosis_accuracy,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ValidationResult:
    """Result of a validation check.

    Attributes:
        name: Human-readable name of the check.
        passed: Whether the check succeeded.
        error: Error description when the check failed, else None.
    """

    # FIX: the original annotated `error: str = None`, which is a type error
    # (the default None is not a str). Quoted union keeps it version-portable.
    def __init__(self, name: str, passed: bool, error: "str | None" = None):
        self.name = name
        self.passed = passed
        self.error = error

    def __str__(self) -> str:
        # Render "PASS: <name>" / "FAIL: <name>", with the error detail on an
        # indented continuation line when one was recorded.
        status = "PASS" if self.passed else "FAIL"
        msg = f"{status}: {self.name}"
        if self.error:
            msg += f"\n  Error: {self.error}"
        return msg
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def validate_imports() -> ValidationResult:
    """Check that all imports work"""
    try:
        from models import DiagnosticAction, PatientObservation, ClinicalState
        from server.environment import MedicalDiagnosticEnvironment
        from server.medical_data import (
            calculate_question_reward,
            calculate_test_reward,
            calculate_diagnosis_accuracy,
        )
    except Exception as exc:
        return ValidationResult("Imports", False, str(exc))
    # All project modules resolved cleanly.
    return ValidationResult("Imports", True)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def validate_model_creation() -> ValidationResult:
    """Check that models can be instantiated"""
    check = "Model Creation"
    try:
        probe = DiagnosticAction(
            action_type="ask_question",
            question="Test question?"
        )
        # Fields must round-trip through the model unchanged.
        assert probe.action_type == "ask_question"
        assert probe.question == "Test question?"
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def validate_environment_init() -> ValidationResult:
    """Check that environment initializes"""
    check = "Environment Initialization"
    try:
        environment = MedicalDiagnosticEnvironment()
        assert environment is not None
        # Core episode API must be present.
        assert hasattr(environment, "reset")
        assert hasattr(environment, "step")
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def validate_reset_all_difficulties() -> ValidationResult:
    """Check that reset works for all difficulties"""
    check = "Reset All Difficulties"
    try:
        environment = MedicalDiagnosticEnvironment()
        for level in ("easy", "medium", "hard"):
            first_obs = environment.reset(difficulty=level)
            assert first_obs is not None
            # Reset must record both the requested difficulty and a case id.
            assert environment.current_difficulty == level
            assert environment.current_case_id is not None
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def validate_question_action() -> ValidationResult:
    """Check that asking questions works"""
    check = "Question Action"
    try:
        environment = MedicalDiagnosticEnvironment()
        environment.reset(difficulty="easy")
        outcome = environment.step(
            DiagnosticAction(
                action_type="ask_question",
                question="Does the patient have symptoms?",
            )
        )
        assert outcome is not None
        assert outcome.reward >= 0
        # Asking a question must never terminate the episode.
        assert outcome.done is False
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def validate_test_action() -> ValidationResult:
    """Check that ordering tests works"""
    check = "Test Action"
    try:
        environment = MedicalDiagnosticEnvironment()
        environment.reset(difficulty="easy")
        outcome = environment.step(
            DiagnosticAction(
                action_type="order_test",
                test_name="Complete Blood Count",
            )
        )
        assert outcome is not None
        assert outcome.reward >= 0
        # Ordering a test must never terminate the episode.
        assert outcome.done is False
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def validate_diagnosis_action() -> ValidationResult:
    """Check that diagnosis submission works"""
    check = "Diagnosis Action"
    try:
        environment = MedicalDiagnosticEnvironment()
        environment.reset(difficulty="easy")
        outcome = environment.step(
            DiagnosticAction(
                action_type="submit_diagnosis",
                diagnosis="Common Flu",
            )
        )
        assert outcome is not None
        assert outcome.reward is not None
        # Submitting a diagnosis always ends the episode.
        assert outcome.done is True
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def validate_episode_summary() -> ValidationResult:
    """Check that episode summaries are generated correctly"""
    check = "Episode Summary"
    try:
        environment = MedicalDiagnosticEnvironment()
        environment.reset(difficulty="easy")
        # End the episode immediately so a summary becomes available.
        environment.step(
            DiagnosticAction(action_type="submit_diagnosis", diagnosis="Test")
        )
        summary = environment.get_episode_summary()
        assert summary is not None
        for required_key in ("case_id", "difficulty", "accuracy", "total_reward", "steps"):
            assert required_key in summary
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def validate_reward_functions() -> ValidationResult:
    """Check that reward functions work"""
    check = "Reward Functions"
    try:
        case_id = next(iter(PATIENT_CASES))

        # Each reward function must return a float clamped to [0, 1].
        question_reward = calculate_question_reward(case_id, "Test question?")
        assert isinstance(question_reward, float)
        assert 0.0 <= question_reward <= 1.0

        test_reward = calculate_test_reward(case_id, "CBC")
        assert isinstance(test_reward, float)
        assert 0.0 <= test_reward <= 1.0

        true_diag = PATIENT_CASES[case_id].get("true_diagnosis", "")
        diag_accuracy = calculate_diagnosis_accuracy(case_id, true_diag)
        assert isinstance(diag_accuracy, float)
        assert 0.0 <= diag_accuracy <= 1.0
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def validate_state_property() -> ValidationResult:
    """Check that state property works"""
    try:
        env = MedicalDiagnosticEnvironment()
        env.reset(difficulty="easy")
        # FIX: `state` is a property, not a method — the unit tests read
        # `env.state` and then hasattr-check its fields, which only works on
        # a state object; calling `env.state()` would raise TypeError here.
        state = env.state
        assert state is not None
        assert hasattr(state, "patient_id")
        assert hasattr(state, "step_count")
        assert hasattr(state, "true_diagnosis")
        assert hasattr(state, "final_accuracy")
        return ValidationResult("State Property", True)
    except Exception as e:
        return ValidationResult("State Property", False, str(e))
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def validate_concurrent_support() -> ValidationResult:
    """Check that environment supports concurrent sessions"""
    check = "Concurrent Sessions Support"
    try:
        environment = MedicalDiagnosticEnvironment()
        # The flag must both exist and be explicitly True.
        assert hasattr(environment, "SUPPORTS_CONCURRENT_SESSIONS")
        assert environment.SUPPORTS_CONCURRENT_SESSIONS is True
    except Exception as exc:
        return ValidationResult(check, False, str(exc))
    return ValidationResult(check, True)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def main():
    """Run all validation checks and return a process exit code (0 = all passed)."""
    print("=" * 70)
    print("MEDICAL DIAGNOSTIC ENVIRONMENT - VALIDATION SUITE")
    print("=" * 70)
    print()

    # Ordered roughly by dependency: imports -> models -> environment -> behaviors.
    validators = [
        validate_imports,
        validate_model_creation,
        validate_environment_init,
        validate_reset_all_difficulties,
        validate_question_action,
        validate_test_action,
        validate_diagnosis_action,
        validate_episode_summary,
        validate_reward_functions,
        validate_state_property,
        validate_concurrent_support,
    ]

    results: List[ValidationResult] = []
    for validator in validators:
        try:
            result = validator()
        except Exception as e:
            # A validator that itself crashes is recorded as a failure with the
            # full traceback, instead of aborting the rest of the suite.
            result = ValidationResult(
                validator.__name__,
                False,
                traceback.format_exc()
            )
        results.append(result)
        print(result)

    print()
    print("=" * 70)
    passed = sum(1 for r in results if r.passed)
    total = len(results)
    print(f"SUMMARY: {passed}/{total} checks passed")
    print("=" * 70)

    # Non-zero exit when any check failed, so this is usable in CI.
    return 0 if passed == total else 1
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
# Propagate main()'s exit code to the shell when run as a script.
if __name__ == "__main__":
    sys.exit(main())
|