Aman Khare commited on
Commit ·
b3d1ac3
1
Parent(s): ef8291e
this is the first commit creating the project
Browse files- .gitattributes +21 -1
- Dockerfile +18 -0
- README.md +199 -1
- __pycache__/inference.cpython-314.pyc +0 -0
- data/clarify_answers.json +14 -0
- data/transcripts/easy.txt +25 -0
- data/transcripts/hard.txt +67 -0
- data/transcripts/medium.txt +41 -0
- environment/__init__.py +13 -0
- environment/__pycache__/__init__.cpython-314.pyc +0 -0
- environment/__pycache__/env.cpython-314.pyc +0 -0
- environment/__pycache__/models.cpython-314.pyc +0 -0
- environment/__pycache__/reward.cpython-314.pyc +0 -0
- environment/env.py +383 -0
- environment/models.py +175 -0
- environment/reward.py +197 -0
- environment/tasks/__init__.py +22 -0
- environment/tasks/__pycache__/__init__.cpython-314.pyc +0 -0
- environment/tasks/__pycache__/task_easy.cpython-314.pyc +0 -0
- environment/tasks/__pycache__/task_hard.cpython-314.pyc +0 -0
- environment/tasks/__pycache__/task_medium.cpython-314.pyc +0 -0
- environment/tasks/task_easy.py +57 -0
- environment/tasks/task_hard.py +86 -0
- environment/tasks/task_medium.py +68 -0
- inference.py +324 -0
- openenv.yaml +211 -0
- requirements.txt +4 -0
- server/__init__.py +1 -0
- server/__pycache__/__init__.cpython-314.pyc +0 -0
- server/__pycache__/app.cpython-314.pyc +0 -0
- server/__pycache__/routes.cpython-314.pyc +0 -0
- server/app.py +46 -0
- server/routes.py +154 -0
- test_inference.py +26 -0
- test_output.txt +9 -0
- test_reward.py +75 -0
.gitattributes
CHANGED
|
@@ -1,2 +1,22 @@
|
|
| 1 |
-
# Auto
|
| 2 |
* text=auto
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auto-detect text files and normalize line endings
|
| 2 |
* text=auto
|
| 3 |
+
|
| 4 |
+
# Enforce LF line endings for source code and config files
|
| 5 |
+
*.py text eol=lf
|
| 6 |
+
*.sh text eol=lf
|
| 7 |
+
*.yaml text eol=lf
|
| 8 |
+
*.yml text eol=lf
|
| 9 |
+
*.json text eol=lf
|
| 10 |
+
*.md text eol=lf
|
| 11 |
+
*.txt text eol=lf
|
| 12 |
+
*.ipynb text eol=lf
|
| 13 |
+
|
| 14 |
+
# Binary assets (do not modify)
|
| 15 |
+
*.png binary
|
| 16 |
+
*.jpg binary
|
| 17 |
+
*.jpeg binary
|
| 18 |
+
*.gif binary
|
| 19 |
+
*.pdf binary
|
| 20 |
+
|
| 21 |
+
# Custom diff driver for Python files
|
| 22 |
+
*.py diff=python
|
Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# Prevent .pyc files and enable unbuffered stdout for structured logging
|
| 4 |
+
ENV PYTHONDONTWRITEBYTECODE=1 \
|
| 5 |
+
PYTHONUNBUFFERED=1
|
| 6 |
+
|
| 7 |
+
WORKDIR /app
|
| 8 |
+
|
| 9 |
+
# Install dependencies first (layer caching)
|
| 10 |
+
COPY requirements.txt .
|
| 11 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 12 |
+
|
| 13 |
+
# Copy project source
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
|
| 18 |
+
CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,2 +1,200 @@
|
|
| 1 |
-
#
|
| 2 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Clinical Note Scribe
|
| 2 |
|
| 3 |
+
An **OpenEnv-compliant** environment for evaluating AI agents on clinical SOAP-note generation from doctor–patient transcripts.
|
| 4 |
+
|
| 5 |
+
Built for the **Meta × Hugging Face OpenEnv Hackathon**.
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Environment Description
|
| 10 |
+
|
| 11 |
+
A doctor–patient conversation is recorded as a text transcript. The agent's goal is to read the transcript along with structured patient context (demographics, medications, labs) and produce a clinically accurate, concise **SOAP note** (Subjective, Objective, Assessment, Plan).
|
| 12 |
+
|
| 13 |
+
The agent interacts through a standard `reset()` / `step()` / `state()` API. Three action types are available: submit a full note, request clarification, or revise a single section. A multi-signal reward function scores each submission on clinical accuracy, conciseness, safe language, and structural validity, with penalties for excessive steps or invalid actions.
|
| 14 |
+
|
| 15 |
+
---
|
| 16 |
+
|
| 17 |
+
## Observation Space
|
| 18 |
+
|
| 19 |
+
| Field | Type | Description |
|
| 20 |
+
|---|---|---|
|
| 21 |
+
| `transcript` | `str` | Full doctor–patient transcript for the current task |
|
| 22 |
+
| `task_id` | `str` | Unique identifier for the active task |
|
| 23 |
+
| `patient_context` | `dict[str, Any]` | Structured patient demographics, conditions, medications, allergies, and labs |
|
| 24 |
+
| `current_draft` | `Optional[str]` | The agent's most recent SOAP-note draft (null until first submission or revision) |
|
| 25 |
+
| `errors_so_far` | `list[str]` | Accumulated error/feedback messages from prior invalid actions |
|
| 26 |
+
| `step_count` | `int` | Number of steps taken so far in the current episode (0-indexed at reset) |
|
| 27 |
+
|
| 28 |
+
---
|
| 29 |
+
|
| 30 |
+
## Action Space
|
| 31 |
+
|
| 32 |
+
| Field | Type | Description |
|
| 33 |
+
|---|---|---|
|
| 34 |
+
| `action_type` | `Literal["submit_note", "request_clarify", "revise_section"]` | **Required.** The kind of action the agent is taking |
|
| 35 |
+
| `soap_note` | `Optional[SOAPNote]` | Complete SOAP note — **required** when `action_type == "submit_note"` |
|
| 36 |
+
| `section` | `Optional[Literal["S", "O", "A", "P"]]` | Which SOAP section to revise — **required** when `action_type == "revise_section"` |
|
| 37 |
+
| `revision_text` | `Optional[str]` | Replacement text for the section — **required** when `action_type == "revise_section"` |
|
| 38 |
+
| `clarify_question` | `Optional[str]` | Free-text question — **required** when `action_type == "request_clarify"` |
|
| 39 |
+
|
| 40 |
+
### SOAPNote Schema
|
| 41 |
+
|
| 42 |
+
| Field | Type | Description |
|
| 43 |
+
|---|---|---|
|
| 44 |
+
| `subjective` | `str` | Patient's self-reported symptoms, history, and concerns |
|
| 45 |
+
| `objective` | `str` | Clinician's measurable findings — vitals, exam, labs, imaging |
|
| 46 |
+
| `assessment` | `str` | Differential diagnoses and clinical reasoning |
|
| 47 |
+
| `plan` | `str` | Treatment plan, medications, follow-ups, referrals |
|
| 48 |
+
|
| 49 |
+
---
|
| 50 |
+
|
| 51 |
+
## Tasks
|
| 52 |
+
|
| 53 |
+
### 🟢 Easy — Routine Check-Up
|
| 54 |
+
**Task ID:** `easy_routine_checkup` · **Max steps:** 5
|
| 55 |
+
|
| 56 |
+
A 6-turn dialogue about a common cold and blood pressure screening for a 34-year-old female. Straightforward clinical picture with no complications.
|
| 57 |
+
|
| 58 |
+
### 🟡 Medium — Chronic Disease Follow-Up
|
| 59 |
+
**Task ID:** `medium_chronic_disease_followup` · **Max steps:** 8
|
| 60 |
+
|
| 61 |
+
A 14-turn follow-up visit for a 58-year-old male with Type 2 Diabetes and Hypertension. Includes HbA1c lab review (7.2% → 7.8%), medication adjustments (adding glipizide 5 mg, uptitrating lisinopril 20 → 40 mg), a 2-week statin gap, and dietary counselling around restaurant meals.
|
| 62 |
+
|
| 63 |
+
### 🔴 Hard — Complex ER Visit
|
| 64 |
+
**Task ID:** `hard_complex_er_visit` · **Max steps:** 10
|
| 65 |
+
|
| 66 |
+
A rapid 20-turn emergency-room encounter for a 72-year-old female with CAD, AFib, and CKD Stage 3. Overlapping chest pain and shortness of breath with a dual ACS vs PE differential. Includes a patient self-contradiction (denied then admitted nitroglycerin use at home), contrast dye allergy complicating CT-PA workup (V/Q scan ordered instead), elevated D-dimer (1840 ng/mL), and Cardiac ICU admission.
|
| 67 |
+
|
| 68 |
+
---
|
| 69 |
+
|
| 70 |
+
## Reward Function
|
| 71 |
+
|
| 72 |
+
```
|
| 73 |
+
value = clamp(weighted_sum − deductions, 0.0, 1.0)
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
| Signal | Weight | Criteria |
|
| 77 |
+
|---|---|---|
|
| 78 |
+
| `grader_score` | × 0.60 | Clinical accuracy from task-specific grader |
|
| 79 |
+
| `conciseness_bonus` | × 0.10 | 1.0 if total SOAP note ≤ 400 words |
|
| 80 |
+
| `safe_language_score` | × 0.15 | 1.0 if no unsafe-certainty phrases detected |
|
| 81 |
+
| `format_valid` | × 0.15 | 1.0 if all four SOAP fields are non-empty |
|
| 82 |
+
|
| 83 |
+
| Deduction | Rate | Trigger |
|
| 84 |
+
|---|---|---|
|
| 85 |
+
| Step penalty | −0.05 | Per step beyond 3 (penalises excessive clarification) |
|
| 86 |
+
| Error penalty | −0.10 | Per invalid action in `errors_so_far` |
|
| 87 |
+
|
| 88 |
+
---
|
| 89 |
+
|
| 90 |
+
## Setup Instructions
|
| 91 |
+
|
| 92 |
+
### Prerequisites
|
| 93 |
+
|
| 94 |
+
- Python 3.11+
|
| 95 |
+
- An OpenAI-compatible API key (set as `OPENAI_API_KEY`)
|
| 96 |
+
|
| 97 |
+
### Local Development
|
| 98 |
+
|
| 99 |
+
```bash
|
| 100 |
+
# Clone the repository
|
| 101 |
+
git clone https://github.com/<your-org>/meta-huggingface-hackathon-team-silver-orca.git
|
| 102 |
+
cd meta-huggingface-hackathon-team-silver-orca
|
| 103 |
+
|
| 104 |
+
# Install dependencies
|
| 105 |
+
pip install -r requirements.txt
|
| 106 |
+
|
| 107 |
+
# Start the environment server
|
| 108 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860
|
| 109 |
+
|
| 110 |
+
# In another terminal — run the baseline inference
|
| 111 |
+
export OPENAI_API_KEY="sk-..."
|
| 112 |
+
export MODEL_NAME="gpt-4o-mini" # or any OpenAI-compatible model
|
| 113 |
+
export API_BASE_URL="https://api.openai.com/v1"
|
| 114 |
+
python inference.py
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
### Docker
|
| 118 |
+
|
| 119 |
+
```bash
|
| 120 |
+
docker build -t meta-huggingface-hackathon-team-silver-orca .
|
| 121 |
+
docker run -p 7860:7860 meta-huggingface-hackathon-team-silver-orca
|
| 122 |
+
```
|
| 123 |
+
|
| 124 |
+
### API Endpoints
|
| 125 |
+
|
| 126 |
+
| Method | Path | Description |
|
| 127 |
+
|---|---|---|
|
| 128 |
+
| `GET` | `/health` | Liveness probe → `{"status": "ok"}` |
|
| 129 |
+
| `POST` | `/reset` | Start a new episode → `Observation` |
|
| 130 |
+
| `POST` | `/step` | Submit an action → `{observation, reward, done, info}` |
|
| 131 |
+
| `GET` | `/state` | Inspect environment state → `EnvironmentState` |
|
| 132 |
+
|
| 133 |
+
---
|
| 134 |
+
|
| 135 |
+
## Baseline Scores
|
| 136 |
+
|
| 137 |
+
Scores obtained using `gpt-4o-mini` with `temperature=0.2` via `inference.py`:
|
| 138 |
+
|
| 139 |
+
| Task | Difficulty | Score |
|
| 140 |
+
|---|---|---|
|
| 141 |
+
| `easy_routine_checkup` | 🟢 Easy | 0.7000 |
|
| 142 |
+
| `medium_chronic_disease_followup` | 🟡 Medium | 0.7000 |
|
| 143 |
+
| `hard_complex_er_visit` | 🔴 Hard | 0.7000 |
|
| 144 |
+
| **Average** | | **0.7000** |
|
| 145 |
+
|
| 146 |
+
> **Note:** These baseline scores use placeholder graders (returning 0.5). Once task-specific graders are fully implemented, scores will vary by clinical accuracy.
|
| 147 |
+
|
| 148 |
+
---
|
| 149 |
+
|
| 150 |
+
## Structured Logging
|
| 151 |
+
|
| 152 |
+
Every episode emits JSON log lines to stdout, scraped by the OpenEnv validator:
|
| 153 |
+
|
| 154 |
+
```json
|
| 155 |
+
{"event": "START", "task_id": "easy_routine_checkup", "timestamp": 1700000000.0}
|
| 156 |
+
{"event": "STEP", "step": 1, "action_type": "submit_note", "reward": 0.82}
|
| 157 |
+
{"event": "END", "task_id": "easy_routine_checkup", "final_score": 0.82}
|
| 158 |
+
```
|
| 159 |
+
|
| 160 |
+
---
|
| 161 |
+
|
| 162 |
+
## Project Structure
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
meta-huggingface-hackathon-team-silver-orca/
|
| 166 |
+
├── openenv.yaml ← OpenEnv spec metadata + graders
|
| 167 |
+
├── inference.py ← Baseline inference (OpenAI client, all 3 tasks)
|
| 168 |
+
├── Dockerfile ← Containerised server (port 7860)
|
| 169 |
+
├── README.md ← This file
|
| 170 |
+
├── requirements.txt
|
| 171 |
+
│
|
| 172 |
+
├── environment/
|
| 173 |
+
│ ├── __init__.py
|
| 174 |
+
│ ├── models.py ← Pydantic v2 models (Observation, Action, Reward, …)
|
| 175 |
+
│ ├── env.py ← ClinicalNoteScribeEnv (reset/step/state)
|
| 176 |
+
│ ├── reward.py ← Multi-signal reward function
|
| 177 |
+
│ └── tasks/
|
| 178 |
+
│ ├── __init__.py ← Task & grader registries
|
| 179 |
+
│ ├── task_easy.py ← Routine check-up + grader stub
|
| 180 |
+
│ ├── task_medium.py ← Chronic disease follow-up + grader stub
|
| 181 |
+
│ └── task_hard.py ← Complex ER visit + grader stub
|
| 182 |
+
│
|
| 183 |
+
├── server/
|
| 184 |
+
│ ├── __init__.py
|
| 185 |
+
│ ├── app.py ← FastAPI application
|
| 186 |
+
│ └── routes.py ← API route definitions
|
| 187 |
+
│
|
| 188 |
+
└── data/
|
| 189 |
+
├── transcripts/
|
| 190 |
+
│ ├── easy.txt ← 6-turn routine check-up transcript
|
| 191 |
+
│ ├── medium.txt ← 14-turn chronic disease follow-up transcript
|
| 192 |
+
│ └── hard.txt ← 20-turn complex ER visit transcript
|
| 193 |
+
└── clarify_answers.json ← Clarification Q&A lookup (10 entries)
|
| 194 |
+
```
|
| 195 |
+
|
| 196 |
+
---
|
| 197 |
+
|
| 198 |
+
## License
|
| 199 |
+
|
| 200 |
+
MIT
|
__pycache__/inference.cpython-314.pyc
ADDED
|
Binary file (15.4 kB). View file
|
|
|
data/clarify_answers.json
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"did the patient report any fever?": "No. Jane Doe denied fever throughout the visit. She mentioned mild body aches in the first two days that had since resolved.",
|
| 3 |
+
"what over-the-counter medications has the patient tried?": "The patient took generic DayQuil for two days but stopped because she was uncertain about interactions with her penicillin allergy. No other OTC medications were tried.",
|
| 4 |
+
"what was the patient's blood pressure reading?": "Blood pressure was 118/76 mmHg with a heart rate of 72 bpm, recorded during the visit.",
|
| 5 |
+
|
| 6 |
+
"what is the patient's current hba1c and how does it compare to the previous?": "HbA1c is 7.8%, up from 7.2% three months ago, indicating worsening glycemic control over the past quarter.",
|
| 7 |
+
"what dietary changes were discussed with the patient?": "The patient reported eating out at a barbecue restaurant 1-2 times per week with high-carb sides. The physician recommended limiting restaurant meals to once weekly and substituting starchy sides with salads or steamed vegetables.",
|
| 8 |
+
"was a new medication added for diabetes and what are the side effects?": "Glipizide 5 mg once daily with breakfast was added. The main side effect discussed was potential hypoglycemia (shakiness, sweating, lightheadedness) and modest weight gain.",
|
| 9 |
+
"why was lisinopril increased?": "Home blood pressure readings were averaging 135/85 and the in-office reading was 142/88, both above the target of 130/80 for a diabetic patient. The dose was increased from 20 mg to 40 mg daily. The patient reported no cough or dizziness on the current dose.",
|
| 10 |
+
|
| 11 |
+
"did the patient mention any recent travel or immobilization?": "Yes. The patient initially denied travel but then corrected herself — her daughter drove her to Sacramento last weekend, approximately five hours each way, during which she was mostly sitting.",
|
| 12 |
+
"why was ct-pa not ordered for the pe workup?": "The patient has a documented contrast dye allergy and chronic kidney disease with eGFR of 34 mL/min. Both are relative contraindications to IV contrast, so a V/Q scan was ordered instead.",
|
| 13 |
+
"did the patient contradict herself during the encounter?": "Yes. The patient initially stated the chest pain started while watching TV and she came straight in. She later disclosed that she had taken a nitroglycerin tablet at home before calling 911, which provided transient partial relief for about five minutes before the pain returned."
|
| 14 |
+
}
|
data/transcripts/easy.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DR. PATEL: Good morning, Jane. What brings you in today?
|
| 2 |
+
|
| 3 |
+
JANE DOE: Hi, Doctor. I've had this cold for about five days now — runny nose, scratchy throat, and I've been sneezing a lot. I thought it would go away on its own, but it's lingering.
|
| 4 |
+
|
| 5 |
+
DR. PATEL: I see. Any fever, body aches, or chills along with it?
|
| 6 |
+
|
| 7 |
+
JANE DOE: No fever that I've noticed. Maybe some mild body aches the first couple of days, but nothing now. Mostly it's the congestion that's really bothering me — I can barely breathe through my nose at night.
|
| 8 |
+
|
| 9 |
+
DR. PATEL: That sounds uncomfortable. Have you tried any over-the-counter medications — decongestants, antihistamines, anything like that?
|
| 10 |
+
|
| 11 |
+
JANE DOE: I took some generic DayQuil for two days but honestly I couldn't tell if it helped much. I stopped because I wasn't sure if it was okay with my penicillin allergy.
|
| 12 |
+
|
| 13 |
+
DR. PATEL: That's good you're being cautious, but DayQuil doesn't contain penicillin so that's fine to use. Let me take a quick look at your throat and ears. — Okay, throat is mildly erythematous, no exudates. Tympanic membranes are clear bilaterally. Lungs are clear to auscultation. No lymphadenopathy. Let me also check your blood pressure since you're here — we like to keep an eye on that annually. — Blood pressure is 118 over 76, heart rate 72. That's excellent.
|
| 14 |
+
|
| 15 |
+
JANE DOE: Oh good, I was a little worried about that because my mom was just diagnosed with high blood pressure.
|
| 16 |
+
|
| 17 |
+
DR. PATEL: I understand the concern. Your numbers look great right now. Family history is something we'll keep monitoring, but at 34 with these readings, you're in a good spot. As for the cold — this looks like a straightforward upper respiratory infection, viral in nature. I'd recommend continuing the DayQuil during the day, switching to NyQuil at bedtime for the congestion, plenty of fluids, and rest. If it's not improving in another five to seven days, or if you develop a fever above 101, come back and we'll reassess.
|
| 18 |
+
|
| 19 |
+
JANE DOE: Sounds good. Should I worry about it turning into a sinus infection?
|
| 20 |
+
|
| 21 |
+
DR. PATEL: It's possible but unlikely at this stage. The things to watch for would be facial pressure or pain around your cheeks and forehead, thick yellow-green nasal discharge, or symptoms that seem to get better and then suddenly worsen. If any of that happens, give us a call. Otherwise, I think you'll be on the mend soon.
|
| 22 |
+
|
| 23 |
+
JANE DOE: Great, thank you so much, Doctor.
|
| 24 |
+
|
| 25 |
+
DR. PATEL: You're welcome, Jane. Feel better!
|
data/transcripts/hard.txt
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
TRIAGE NURSE: 72-year-old female, Maria Garcia, presenting with acute onset chest pain and shortness of breath. Pain started approximately ninety minutes ago. She's on warfarin. Vitals on arrival: BP 168 over 94, heart rate 112 and irregular, respiratory rate 22, SpO2 91 percent on room air, temp 37.2. She's being moved to Bed 4 in the acute bay.
|
| 2 |
+
|
| 3 |
+
DR. OKONKWO: Mrs. Garcia, I'm Dr. Okonkwo. I can see you're in some distress. Can you tell me about the pain you're feeling?
|
| 4 |
+
|
| 5 |
+
MARIA GARCIA: It's — it's right here in the center of my chest. It feels like a heavy pressure, like someone's sitting on me. And I can't catch my breath. It started when I was watching television, I wasn't doing anything strenuous.
|
| 6 |
+
|
| 7 |
+
DR. OKONKWO: On a scale of zero to ten, how severe is the pain right now?
|
| 8 |
+
|
| 9 |
+
MARIA GARCIA: Maybe a seven. It was worse when it started, maybe an eight or nine.
|
| 10 |
+
|
| 11 |
+
DR. OKONKWO: Does the pain go anywhere — into your arm, your jaw, your back?
|
| 12 |
+
|
| 13 |
+
MARIA GARCIA: A little into my left arm. Not my jaw. Maybe a tiny bit in my back between my shoulder blades.
|
| 14 |
+
|
| 15 |
+
DR. OKONKWO: Okay. Any nausea, vomiting, or dizziness?
|
| 16 |
+
|
| 17 |
+
MARIA GARCIA: I felt nauseous in the ambulance, but I didn't throw up. A little lightheaded, yes.
|
| 18 |
+
|
| 19 |
+
DR. OKONKWO: Have you had chest pain like this before?
|
| 20 |
+
|
| 21 |
+
MARIA GARCIA: I had a similar episode about two years ago and they said it was angina. They gave me nitroglycerin and it went away. But this feels different — stronger, and the breathing is worse this time.
|
| 22 |
+
|
| 23 |
+
DR. OKONKWO: I see from your chart you have coronary artery disease, atrial fibrillation, and chronic kidney disease stage 3. You're on warfarin, metoprolol, aspirin, furosemide, and amlodipine. Have you taken all your medications today?
|
| 24 |
+
|
| 25 |
+
MARIA GARCIA: Yes, I took everything this morning like I always do.
|
| 26 |
+
|
| 27 |
+
DR. OKONKWO: Any recent travel — long car rides, flights?
|
| 28 |
+
|
| 29 |
+
MARIA GARCIA: No, no travel. I've been — actually, wait. My daughter drove me to Sacramento last weekend. It was about a five-hour drive each way and I was mostly sitting in the car.
|
| 30 |
+
|
| 31 |
+
DR. OKONKWO: That's important. Any leg swelling, redness, or calf pain since that trip?
|
| 32 |
+
|
| 33 |
+
MARIA GARCIA: Now that you mention it, my left calf has been a little sore. I thought it was just from sitting so long. It's not swollen though — I don't think. Maybe a little.
|
| 34 |
+
|
| 35 |
+
DR. OKONKWO: Let me examine you. — Chest auscultation: irregular rhythm, no murmurs or gallops appreciated. Diminished breath sounds at the right base with fine crackles. Abdomen soft, non-tender. Left calf is mildly tender to palpation with trace edema compared to the right. No erythema. JVP appears mildly elevated.
|
| 36 |
+
|
| 37 |
+
DR. OKONKWO [to resident]: Alright, we need to move quickly. Differential here is broad — I'm most concerned about acute coronary syndrome versus pulmonary embolism, given the acute onset, the recent prolonged immobilization from the car trip, the calf tenderness, and the hypoxia. The atrial fib and CKD complicate both the workup and management. Let's get a 12-lead ECG stat, troponin I — I see the initial was 0.08 so borderline — CBC, BMP, coags with INR, BNP, and a D-dimer. Start her on two liters nasal cannula and let's get the SpO2 up. Do NOT give heparin yet — she's already on warfarin and her INR was 2.6 on arrival, so she's anticoagulated. Hold off on nitroglycerin too until we see the ECG — I want to rule out a right ventricular infarct first.
|
| 38 |
+
|
| 39 |
+
RESIDENT DR. CHEN: Got it. ECG is printing now. — ECG shows atrial fibrillation with rapid ventricular response at 114, ST-segment depression in leads V4 through V6, no ST elevation. No right-sided changes.
|
| 40 |
+
|
| 41 |
+
DR. OKONKWO: Okay, no STEMI and no RV involvement. The ST depression could be demand ischemia from the tachycardia, or it could be an NSTEMI. With the borderline troponin, I want a repeat troponin in three hours for delta. Go ahead and give sublingual nitro now — 0.4 milligrams — and let's see if the chest pain responds.
|
| 42 |
+
|
| 43 |
+
MARIA GARCIA: Doctor, I should mention — I took one of my old nitroglycerin tablets at home before calling 911. It didn't help much.
|
| 44 |
+
|
| 45 |
+
DR. OKONKWO: Wait — earlier you said the pain started while you were watching TV and you came straight in. Did you take nitroglycerin at home before the ambulance arrived?
|
| 46 |
+
|
| 47 |
+
MARIA GARCIA: Yes, sorry. I'm a little confused with everything happening. I did take one at home. It took the edge off for maybe five minutes and then the pain came right back.
|
| 48 |
+
|
| 49 |
+
DR. OKONKWO: Okay, that's very important information. A partial, transient response to nitro. Let's still give another dose here under monitoring. — And let's get that D-dimer. With the prolonged immobilization and calf tenderness, PE is still on my differential. However, given the CKD with eGFR of 34 and the contrast dye allergy, a CT-PA is going to be problematic. We'll need to consider a V/Q scan instead if the D-dimer comes back elevated.
|
| 50 |
+
|
| 51 |
+
RESIDENT DR. CHEN: D-dimer result just came in — it's 1,840 nanograms per milliliter. Significantly elevated.
|
| 52 |
+
|
| 53 |
+
DR. OKONKWO: That's four times the upper limit of normal. In a 72-year-old post-immobilization patient with hypoxia and calf tenderness, that's very concerning for PE. Let's order a V/Q scan — nuclear medicine should be able to do it within the hour. Also, I want a bilateral lower-extremity venous duplex to evaluate for DVT in that left leg.
|
| 54 |
+
|
| 55 |
+
DR. OKONKWO: Mrs. Garcia, we're running some additional tests. The nitro — has it helped the chest pressure at all?
|
| 56 |
+
|
| 57 |
+
MARIA GARCIA: A little bit, maybe down to a five now. But I still can't breathe well.
|
| 58 |
+
|
| 59 |
+
DR. OKONKWO: We're going to increase your oxygen to four liters. — Chen, let's also give a one-time dose of IV metoprolol, 5 milligrams, to try to rate-control the afib. Her heart rate is driving up oxygen demand. Check potassium first — with the CKD and furosemide I want to make sure we're not hypokalemic before we push a beta-blocker.
|
| 60 |
+
|
| 61 |
+
RESIDENT DR. CHEN: Potassium is 4.1, within normal limits. BNP is 450, creatinine 1.9 consistent with her baseline CKD. Hemoglobin 10.2. Pushing the metoprolol now.
|
| 62 |
+
|
| 63 |
+
DR. OKONKWO: Good. So to frame this for the chart: we're working up a dual-pathway differential — ACS, specifically NSTEMI, being evaluated with serial troponins and cardiology consult, and concurrent PE workup with V/Q scan and lower extremity duplex given the strong Wells score. The contrast allergy and CKD preclude CT-PA. She's therapeutically anticoagulated on warfarin with INR 2.6 so we're not adding heparin. Rate-controlling the afib with IV metoprolol. Monitoring on continuous telemetry. Admit to Cardiac ICU for observation pending results. — Mrs. Garcia, we're going to keep you here tonight while we sort this out. The team upstairs in the ICU is going to take excellent care of you.
|
| 64 |
+
|
| 65 |
+
MARIA GARCIA: Thank you, Doctor. I'm scared but I trust you all.
|
| 66 |
+
|
| 67 |
+
DR. OKONKWO: You're in the right place. We're going to figure out exactly what's going on and take care of it.
|
data/transcripts/medium.txt
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
DR. NAKAMURA: Good afternoon, Robert. It's been about three months since your last visit. How have things been going with the diabetes and the blood pressure?
|
| 2 |
+
|
| 3 |
+
ROBERT SMITH: Afternoon, Doc. Honestly, it's been a mixed bag. I've been pretty good about taking the metformin and the lisinopril, but I ran out of the atorvastatin for about two weeks last month because I forgot to call in the refill.
|
| 4 |
+
|
| 5 |
+
DR. NAKAMURA: Okay, that happens. We'll make a note of that. How about your blood sugars — have you been checking at home?
|
| 6 |
+
|
| 7 |
+
ROBERT SMITH: Yeah, I check fasting most mornings. They've been running between 140 and 170. I had one reading that was 198 after a birthday dinner, but I figured that was a one-off.
|
| 8 |
+
|
| 9 |
+
DR. NAKAMURA: Those fasting numbers are still higher than we'd like. Ideally we want them under 130 fasting. I have your recent lab results here — your HbA1c came back at 7.8 percent, which is up from 7.2 three months ago. That tells me your average blood sugar has been creeping up over the past few months.
|
| 10 |
+
|
| 11 |
+
ROBERT SMITH: That's not great. I was hoping it would hold steady or maybe come down a little.
|
| 12 |
+
|
| 13 |
+
DR. NAKAMURA: I understand. Let's talk about what might be contributing. You mentioned the birthday dinner — how has your diet been overall? Any changes?
|
| 14 |
+
|
| 15 |
+
ROBERT SMITH: Well, my wife and I started eating out more. There's a new barbecue place near us, and we've probably been going once or twice a week. I know the portions are huge and the sides are all carbs — mac and cheese, cornbread, baked beans. At home I've been doing okay, but not as strict as I was.
|
| 16 |
+
|
| 17 |
+
DR. NAKAMURA: That could definitely be a factor. Restaurant meals, especially barbecue with those sides, can have a significant glycemic impact. I'm not going to tell you to never eat out, but I'd like you to try limiting it to once a week, and when you do go, opt for a side salad or steamed vegetables instead of two starchy sides. Would that feel manageable?
|
| 18 |
+
|
| 19 |
+
ROBERT SMITH: Yeah, I think so. My wife's been on me about it too, so she'll be happy to hear that.
|
| 20 |
+
|
| 21 |
+
DR. NAKAMURA: Good, teamwork helps. Now, given that the HbA1c has gone up despite being on metformin 1000 twice daily, I think it's time to consider adding a second medication. I'd like to start you on a low dose of glipizide — 5 milligrams once daily with breakfast. It works differently from metformin; it helps your pancreas release more insulin. The main thing to watch for is low blood sugar, so if you feel shaky, sweaty, or lightheaded, check your sugar and have a snack.
|
| 22 |
+
|
| 23 |
+
ROBERT SMITH: Is that the one that can cause weight gain? I've been trying to keep my weight stable.
|
| 24 |
+
|
| 25 |
+
DR. NAKAMURA: It can cause modest weight gain in some people, yes. We'll monitor that closely. If weight becomes an issue, there are other classes we can consider, like an SGLT2 inhibitor, which can actually promote weight loss. But let's start here since it's well-studied and affordable on your insurance plan.
|
| 26 |
+
|
| 27 |
+
ROBERT SMITH: Okay. What about the blood pressure? My home readings have been around 135 over 85 most days.
|
| 28 |
+
|
| 29 |
+
DR. NAKAMURA: Let me check it here. — I'm getting 142 over 88. That's a bit above target; we want you under 130 over 80 given your diabetes. The lisinopril at 20 milligrams has been tolerable for you — no cough, no dizziness?
|
| 30 |
+
|
| 31 |
+
ROBERT SMITH: No, nothing like that.
|
| 32 |
+
|
| 33 |
+
DR. NAKAMURA: Good. I'd like to increase the lisinopril to 40 milligrams daily and see if that brings us closer to goal. Also, and this ties back to the diet — cutting back on the restaurant meals will help the blood pressure too, since those foods tend to be very high in sodium. Your creatinine is 1.1 and eGFR is 78, so kidney function looks stable, which is reassuring. LDL came back at 102 — not bad, but given the two-week gap on the statin, I want to make sure you stay consistent. Under 100 is our target for you.
|
| 34 |
+
|
| 35 |
+
ROBERT SMITH: Got it. I already set up auto-refills at the pharmacy so that shouldn't happen again.
|
| 36 |
+
|
| 37 |
+
DR. NAKAMURA: Perfect. So to summarize the plan: continue metformin 1000 twice daily, add glipizide 5 milligrams with breakfast, increase lisinopril from 20 to 40 milligrams daily, continue atorvastatin 40 milligrams daily with no gaps, cut restaurant meals to once weekly with healthier side choices, and I'll see you back in three months with a repeat HbA1c and a fasting lipid panel. Sound good?
|
| 38 |
+
|
| 39 |
+
ROBERT SMITH: Sounds like a plan. Thanks, Doc.
|
| 40 |
+
|
| 41 |
+
DR. NAKAMURA: You're doing the right things by showing up and staying engaged, Robert. We'll get these numbers where we want them.
|
environment/__init__.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Clinical Note Scribe environment package."""
|
| 2 |
+
|
| 3 |
+
from .models import Observation, Action, SOAPNote, Reward, EnvironmentState
|
| 4 |
+
from .env import ClinicalNoteScribeEnv
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"Observation",
|
| 8 |
+
"Action",
|
| 9 |
+
"SOAPNote",
|
| 10 |
+
"Reward",
|
| 11 |
+
"EnvironmentState",
|
| 12 |
+
"ClinicalNoteScribeEnv",
|
| 13 |
+
]
|
environment/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (443 Bytes). View file
|
|
|
environment/__pycache__/env.cpython-314.pyc
ADDED
|
Binary file (15.3 kB). View file
|
|
|
environment/__pycache__/models.cpython-314.pyc
ADDED
|
Binary file (5.51 kB). View file
|
|
|
environment/__pycache__/reward.cpython-314.pyc
ADDED
|
Binary file (7.6 kB). View file
|
|
|
environment/env.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ClinicalNoteScribeEnv — core environment loop.
|
| 2 |
+
|
| 3 |
+
Implements the ``reset() → Observation``, ``step(Action) → (Observation, Reward, bool, dict)``,
|
| 4 |
+
and ``state() → EnvironmentState`` interface required by the OpenEnv spec.
|
| 5 |
+
|
| 6 |
+
Structured logging
|
| 7 |
+
------------------
|
| 8 |
+
Every episode emits exactly three kinds of JSON log lines to **stdout**:
|
| 9 |
+
|
| 10 |
+
- ``{"event": "START", "task_id": "...", "timestamp": ...}``
|
| 11 |
+
- ``{"event": "STEP", "step": N, "action_type": "...", "reward": R}``
|
| 12 |
+
- ``{"event": "END", "task_id": "...", "final_score": S}``
|
| 13 |
+
|
| 14 |
+
The OpenEnv validator scrapes ``[START]``, ``[STEP]``, ``[END]`` keywords.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
import time
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Any
|
| 24 |
+
|
| 25 |
+
from environment.models import Action, EnvironmentState, Observation, Reward, SOAPNote
|
| 26 |
+
from environment.reward import compute_reward
|
| 27 |
+
from environment.tasks import GRADER_REGISTRY, TASK_REGISTRY
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger("clinical_note_scribe")
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Helpers
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
def _load_transcript(transcript_path: str) -> str:
|
| 37 |
+
"""Load a transcript text file from *project-root-relative* path."""
|
| 38 |
+
base = Path(__file__).resolve().parent.parent # clinical-note-scribe/
|
| 39 |
+
full_path = base / transcript_path
|
| 40 |
+
if full_path.exists():
|
| 41 |
+
return full_path.read_text(encoding="utf-8")
|
| 42 |
+
return f"[Transcript file not found: {transcript_path}]"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def _log_event(event: str, **kwargs: Any) -> None:
|
| 46 |
+
"""Emit a structured JSON log line to stdout via the logger."""
|
| 47 |
+
payload: dict[str, Any] = {"event": event, "timestamp": time.time()}
|
| 48 |
+
payload.update(kwargs)
|
| 49 |
+
logger.info(json.dumps(payload))
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _soap_to_text(soap: SOAPNote) -> str:
|
| 53 |
+
"""Flatten a SOAPNote into a readable multi-line string."""
|
| 54 |
+
return (
|
| 55 |
+
f"S: {soap.subjective}\n"
|
| 56 |
+
f"O: {soap.objective}\n"
|
| 57 |
+
f"A: {soap.assessment}\n"
|
| 58 |
+
f"P: {soap.plan}"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
# Main environment class
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
|
| 66 |
+
class ClinicalNoteScribeEnv:
    """Open-environment wrapper for the clinical note-scribe tasks.

    Lifecycle
    ---------
    1. ``env.reset(task_id)`` → returns initial ``Observation``
    2. ``env.step(action)`` → returns ``(Observation, Reward, done, info)``
    3. ``env.state()`` → returns full ``EnvironmentState`` snapshot

    Parameters
    ----------
    clarify_answers_path:
        Project-root-relative path to the clarification lookup JSON.
    """

    def __init__(
        self,
        clarify_answers_path: str = "data/clarify_answers.json",
    ) -> None:
        # Canned clarification answers; keys are matched case-insensitively
        # in _handle_clarify (the lookup lowercases the question).
        self._clarify_answers: dict[str, str] = {}
        base = Path(__file__).resolve().parent.parent
        ca_path = base / clarify_answers_path
        if ca_path.exists():
            self._clarify_answers = json.loads(ca_path.read_text(encoding="utf-8"))

        # Episode state (initialised properly in reset()).
        # _done starts True so that step() before reset() raises RuntimeError.
        self._task: dict[str, Any] = {}
        self._task_id: str = ""
        self._transcript: str = ""
        self._patient_context: dict[str, Any] = {}
        self._max_steps: int = 10
        self._step_count: int = 0
        self._done: bool = True
        self._current_draft: str | None = None
        self._errors_so_far: list[str] = []
        self._last_reward: Reward | None = None
        self._last_observation: Observation | None = None

    # --------------------------------------------------------------------- #
    # Public API
    # --------------------------------------------------------------------- #

    def reset(self, task_id: str | None = None) -> Observation:
        """Start (or restart) an episode for the given *task_id*.

        Parameters
        ----------
        task_id:
            One of the keys in ``TASK_REGISTRY``. When ``None`` the first
            registered task is used.

        Returns
        -------
        Observation
            The initial observation for the episode.

        Raises
        ------
        ValueError
            If *task_id* is not found in the registry.
        """
        if task_id is None:
            task_id = next(iter(TASK_REGISTRY))

        if task_id not in TASK_REGISTRY:
            available = ", ".join(TASK_REGISTRY.keys())
            raise ValueError(
                f"Unknown task_id '{task_id}'. Available: {available}"
            )

        self._task = TASK_REGISTRY[task_id]
        self._task_id = task_id
        self._transcript = _load_transcript(self._task["transcript_file"])
        self._patient_context = self._task.get("patient_context", {})
        self._max_steps = self._task.get("max_steps", 10)
        self._step_count = 0
        self._done = False
        self._current_draft = None
        self._errors_so_far = []
        self._last_reward = None

        _log_event("START", task_id=self._task_id)

        obs = self._build_observation()
        self._last_observation = obs
        return obs

    def step(self, action: Action) -> tuple[Observation, Reward, bool, dict]:
        """Execute one agent action and return the resulting observation, reward,
        done flag, and info dict.

        Parameters
        ----------
        action:
            The agent's chosen action.

        Returns
        -------
        tuple[Observation, Reward, bool, dict]

        Raises
        ------
        RuntimeError
            If called after the episode has terminated (call ``reset()`` first).
        """
        if self._done:
            raise RuntimeError(
                "Episode is done. Call reset() before stepping again."
            )

        self._step_count += 1
        info: dict[str, Any] = {}

        # ---- dispatch by action type ----
        if action.action_type == "submit_note":
            reward = self._handle_submit(action, info)
        elif action.action_type == "request_clarify":
            reward = self._handle_clarify(action, info)
        elif action.action_type == "revise_section":
            reward = self._handle_revise(action, info)
        else:
            # Should never happen thanks to the Literal type, but be safe.
            reward = self._invalid_action_reward(
                action,
                f"Unknown action_type: {action.action_type}",
                info_error="bad_action",
            )

        # ---- enforce max-step termination ----
        if self._step_count >= self._max_steps and not self._done:
            self._done = True
            # Rebuild the reward so done/info reflect forced termination.
            reward = Reward(
                value=reward.value,
                signals=reward.signals,
                done=True,
                info={**reward.info, "termination_reason": "max_steps_reached"},
            )

        self._last_reward = reward

        _log_event(
            "STEP",
            step=self._step_count,
            action_type=action.action_type,
            reward=reward.value,
        )

        if self._done:
            _log_event(
                "END",
                task_id=self._task_id,
                final_score=reward.value,
            )

        obs = self._build_observation()
        self._last_observation = obs
        return obs, reward, self._done, info

    def state(self) -> EnvironmentState:
        """Return the full internal state snapshot."""
        return EnvironmentState(
            task_id=self._task_id,
            step_count=self._step_count,
            max_steps=self._max_steps,
            done=self._done,
            current_draft=self._current_draft,
            errors_so_far=list(self._errors_so_far),
            last_reward=self._last_reward,
            observation=self._last_observation,
        )

    # --------------------------------------------------------------------- #
    # Action handlers
    # --------------------------------------------------------------------- #

    def _invalid_action_reward(
        self,
        action: Action,
        error: str,
        info_error: str | None = None,
    ) -> Reward:
        """Record *error* and build the zero-score reward for an invalid action.

        Centralises the error-handling boilerplate previously duplicated across
        the dispatch branch and every handler. ``info_error`` overrides the
        message placed in the reward's info dict (defaults to *error*).
        """
        self._errors_so_far.append(error)
        return compute_reward(
            action,
            grader_score=0.0,
            step_count=self._step_count,
            errors_so_far=self._errors_so_far,
            done=False,
            info={"error": info_error or error},
        )

    def _handle_submit(self, action: Action, info: dict) -> Reward:
        """Process a ``submit_note`` action.

        A valid submission terminates the episode and is graded by the
        task-specific grader (if one is registered).
        """
        if action.soap_note is None:
            return self._invalid_action_reward(
                action, "submit_note requires a non-null soap_note."
            )

        self._current_draft = _soap_to_text(action.soap_note)
        self._done = True

        # Attempt to grade via the task-specific grader.
        grader = GRADER_REGISTRY.get(self._task_id)
        if grader is None:
            info["warning"] = "No grader registered; returning default reward."
            return compute_reward(
                action,
                grader_score=0.5,
                step_count=self._step_count,
                errors_so_far=self._errors_so_far,
                done=True,
                info=info,
            )

        try:
            raw_signals = grader(action.soap_note, self._task)
            # Grader returns a signals dict; extract a single scalar score
            # as the mean of its values for use as grader_score.
            grader_score = (
                sum(raw_signals.values()) / len(raw_signals)
                if raw_signals else 0.0
            )
            info["grader_signals"] = raw_signals
        except NotImplementedError:
            info["warning"] = "Grader not yet implemented; returning placeholder."
            grader_score = 0.5

        return compute_reward(
            action,
            grader_score=grader_score,
            step_count=self._step_count,
            errors_so_far=self._errors_so_far,
            done=True,
            info=info,
        )

    def _handle_clarify(self, action: Action, info: dict) -> Reward:
        """Process a ``request_clarify`` action.

        Looks up a canned answer and places it in ``info["clarify_answer"]``;
        the episode continues.
        """
        question = (action.clarify_question or "").strip()
        if not question:
            return self._invalid_action_reward(
                action, "request_clarify requires a non-empty clarify_question."
            )

        # Lookup a canned answer (case-insensitive key match).
        answer = self._clarify_answers.get(question.lower())
        if answer:
            info["clarify_answer"] = answer
        else:
            info["clarify_answer"] = (
                "No additional information available for that question."
            )

        # Clarification steps earn no grader_score; step_penalty accrues naturally.
        return compute_reward(
            action,
            grader_score=0.0,
            step_count=self._step_count,
            errors_so_far=self._errors_so_far,
            done=False,
            info=info,
        )

    def _handle_revise(self, action: Action, info: dict) -> Reward:
        """Process a ``revise_section`` action.

        Patches (or appends) the ``{section}: `` line of the current draft;
        when no draft exists yet, the revision becomes the draft.
        """
        if action.section is None or action.revision_text is None:
            return self._invalid_action_reward(
                action, "revise_section requires both 'section' and 'revision_text'."
            )

        # If there is an existing draft, patch the requested section.
        if self._current_draft:
            lines = self._current_draft.split("\n")
            prefix = f"{action.section}: "
            patched = False
            for i, line in enumerate(lines):
                if line.startswith(prefix):
                    lines[i] = f"{prefix}{action.revision_text}"
                    patched = True
                    break
            if patched:
                self._current_draft = "\n".join(lines)
            else:
                # Section line absent — append it rather than dropping the edit.
                self._current_draft += f"\n{prefix}{action.revision_text}"
        else:
            self._current_draft = f"{action.section}: {action.revision_text}"

        info["revised_section"] = action.section

        # Revision steps earn no grader_score; deductions still apply.
        return compute_reward(
            action,
            grader_score=0.0,
            step_count=self._step_count,
            errors_so_far=self._errors_so_far,
            done=False,
            info=info,
        )

    # --------------------------------------------------------------------- #
    # Internal helpers
    # --------------------------------------------------------------------- #

    def _build_observation(self) -> Observation:
        """Assemble the Observation reflecting the current episode state."""
        return Observation(
            transcript=self._transcript,
            task_id=self._task_id,
            patient_context=self._patient_context,
            current_draft=self._current_draft,
            errors_so_far=list(self._errors_so_far),
            step_count=self._step_count,
        )
|
environment/models.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic v2 models for the Clinical Note Scribe environment.
|
| 2 |
+
|
| 3 |
+
Defines the typed contracts for observations, actions, rewards,
|
| 4 |
+
and overall environment state used by the OpenEnv spec.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from __future__ import annotations
|
| 8 |
+
|
| 9 |
+
from typing import Any, Literal, Optional
|
| 10 |
+
|
| 11 |
+
from pydantic import BaseModel, Field
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
# Observation — what the agent sees after each step
|
| 16 |
+
# ---------------------------------------------------------------------------
|
| 17 |
+
|
| 18 |
+
class Observation(BaseModel):
    """Snapshot of the environment returned to the agent."""

    # Raw dialogue the agent must turn into a SOAP note.
    transcript: str = Field(..., description="Full doctor–patient transcript for the current task.")
    # Which registered task this episode belongs to.
    task_id: str = Field(..., description="Unique identifier for the task (e.g. 'easy_routine_checkup').")
    # Supplementary structured information about the patient.
    patient_context: dict[str, Any] = Field(default_factory=dict, description="Structured patient demographics and history.")
    # Latest draft text, or None before any submission/revision.
    current_draft: Optional[str] = Field(default=None, description="The agent's most recent SOAP-note draft, if any.")
    # Error/feedback strings collected over the episode so far.
    errors_so_far: list[str] = Field(default_factory=list, description="Accumulated error/feedback messages from prior steps.")
    # Elapsed steps (never negative).
    step_count: int = Field(default=0, ge=0, description="Number of steps taken in the current episode.")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
# SOAPNote — structured clinical note
|
| 50 |
+
# ---------------------------------------------------------------------------
|
| 51 |
+
|
| 52 |
+
class SOAPNote(BaseModel):
    """Standard SOAP clinical-note format."""

    # S — what the patient reports.
    subjective: str = Field(..., description="Patient's self-reported symptoms and history.")
    # O — what the clinician measures/observes.
    objective: str = Field(..., description="Clinician's measurable findings (vitals, exam, labs).")
    # A — the clinician's interpretation.
    assessment: str = Field(..., description="Clinician's diagnosis or differential.")
    # P — what happens next.
    plan: str = Field(..., description="Treatment plan, follow-ups, and prescriptions.")
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# ---------------------------------------------------------------------------
|
| 74 |
+
# Action — what the agent can do
|
| 75 |
+
# ---------------------------------------------------------------------------
|
| 76 |
+
|
| 77 |
+
class Action(BaseModel):
    """An action the agent submits to the environment."""

    # Discriminator selecting which optional field group below applies.
    action_type: Literal["submit_note", "request_clarify", "revise_section"] = Field(
        ..., description="The kind of action the agent is taking."
    )

    # --- submit_note fields ---
    soap_note: Optional[SOAPNote] = Field(
        default=None,
        description="Complete SOAP note (required when action_type == 'submit_note').",
    )

    # --- revise_section fields ---
    section: Optional[Literal["S", "O", "A", "P"]] = Field(
        default=None,
        description="Which SOAP section to revise (required when action_type == 'revise_section').",
    )
    revision_text: Optional[str] = Field(
        default=None, description="Replacement text for the specified section."
    )

    # --- request_clarify fields ---
    clarify_question: Optional[str] = Field(
        default=None, description="Free-text question the agent asks for clarification."
    )
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Reward — multi-signal feedback
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
class Reward(BaseModel):
    """Reward returned after each step."""

    # Scalar reward, clamped by validation to the unit interval.
    value: float = Field(..., ge=0.0, le=1.0, description="Aggregate reward in the range [0.0, 1.0].")
    # Per-signal contributions behind the aggregate value.
    signals: dict[str, float] = Field(default_factory=dict, description="Breakdown of individual reward sub-signals.")
    # Episode-termination flag.
    done: bool = Field(..., description="Whether the episode has ended.")
    # Free-form extras such as grader diagnostics or warnings.
    info: dict[str, Any] = Field(default_factory=dict, description="Auxiliary metadata (e.g. grader diagnostics).")
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# ---------------------------------------------------------------------------
|
| 136 |
+
# EnvironmentState — full internal state exposed by state()
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
|
| 139 |
+
class EnvironmentState(BaseModel):
    """Complete snapshot of the environment's internal state."""

    # Which task the current episode is running.
    task_id: str = Field(..., description="Active task identifier.")
    # Progress counters.
    step_count: int = Field(default=0, ge=0, description="Steps taken so far in this episode.")
    max_steps: int = Field(default=10, ge=1, description="Maximum steps allowed per episode.")
    # Termination flag.
    done: bool = Field(default=False, description="Whether the current episode has terminated.")
    # Draft/feedback accumulated by the agent.
    current_draft: Optional[str] = Field(default=None, description="Latest SOAP-note draft text, if any.")
    errors_so_far: list[str] = Field(default_factory=list, description="Accumulated feedback/error messages.")
    # Last step's outputs (None until a step/reset has happened).
    last_reward: Optional[Reward] = Field(default=None, description="Most recent reward object, if a step has been taken.")
    observation: Optional[Observation] = Field(default=None, description="Most recent observation returned to the agent.")
|
environment/reward.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Multi-signal reward computation for the Clinical Note Scribe environment.
|
| 2 |
+
|
| 3 |
+
Reward formula (all weights sum to 1.0 before penalties):
|
| 4 |
+
|
| 5 |
+
weighted_sum = grader_score × 0.60 (clinical accuracy from task grader)
|
| 6 |
+
+ conciseness_bonus × 0.10 (1.0 if note ≤ 400 words, else 0.0)
|
| 7 |
+
+ safe_language_score× 0.15 (1.0 if no unsafe-certainty phrases)
|
| 8 |
+
+ format_valid × 0.15 (1.0 if SOAP JSON is well-formed)
|
| 9 |
+
|
| 10 |
+
Deductions (applied after weighted sum):
|
| 11 |
+
- 0.05 × max(0, step_count - 3) (penalty for excessive clarification steps)
|
| 12 |
+
- 0.10 × len(errors_so_far) (penalty for each invalid action)
|
| 13 |
+
|
| 14 |
+
Final value is clamped to [0.0, 1.0].
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import re
|
| 20 |
+
from typing import Any, Optional
|
| 21 |
+
|
| 22 |
+
from environment.models import Action, Reward, SOAPNote
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ---------------------------------------------------------------------------
|
| 26 |
+
# Weights
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
|
| 29 |
+
W_GRADER = 0.60
|
| 30 |
+
W_CONCISE = 0.10
|
| 31 |
+
W_SAFE_LANG = 0.15
|
| 32 |
+
W_FORMAT = 0.15
|
| 33 |
+
|
| 34 |
+
# Deduction constants
|
| 35 |
+
STEP_PENALTY_RATE = 0.05 # per step beyond FREE_STEPS
|
| 36 |
+
FREE_STEPS = 3
|
| 37 |
+
ERROR_PENALTY_RATE = 0.10 # per item in errors_so_far
|
| 38 |
+
|
| 39 |
+
# Conciseness threshold
|
| 40 |
+
WORD_LIMIT = 400
|
| 41 |
+
|
| 42 |
+
# Phrases that indicate unsafe clinical certainty
|
| 43 |
+
# (over-confident language that a scribe should avoid in a note)
|
| 44 |
+
_UNSAFE_PATTERNS: list[re.Pattern[str]] = [
|
| 45 |
+
re.compile(p, re.IGNORECASE)
|
| 46 |
+
for p in [
|
| 47 |
+
r"\bpatient definitely has\b",
|
| 48 |
+
r"\bdiagnosis is certain\b",
|
| 49 |
+
r"\bno doubt\b",
|
| 50 |
+
r"\babsolutely confirmed\b",
|
| 51 |
+
r"\b100%\s+certain\b",
|
| 52 |
+
r"\bwill definitely\b",
|
| 53 |
+
r"\bguaranteed to\b",
|
| 54 |
+
r"\bcannot be\s+\w+\s+else\b",
|
| 55 |
+
r"\bwithout question\b",
|
| 56 |
+
r"\bthis is clearly\b",
|
| 57 |
+
]
|
| 58 |
+
]
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
# ---------------------------------------------------------------------------
|
| 62 |
+
# Sub-signal helpers
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
|
| 65 |
+
def _conciseness_bonus(soap_note: Optional[SOAPNote]) -> float:
|
| 66 |
+
"""Return 1.0 if the total SOAP note word count is at or below WORD_LIMIT."""
|
| 67 |
+
if soap_note is None:
|
| 68 |
+
return 0.0
|
| 69 |
+
text = " ".join([
|
| 70 |
+
soap_note.subjective,
|
| 71 |
+
soap_note.objective,
|
| 72 |
+
soap_note.assessment,
|
| 73 |
+
soap_note.plan,
|
| 74 |
+
])
|
| 75 |
+
word_count = len(text.split())
|
| 76 |
+
return 1.0 if word_count <= WORD_LIMIT else 0.0
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _safe_language_score(soap_note: Optional[SOAPNote]) -> float:
|
| 80 |
+
"""Return 1.0 if no unsafe-certainty phrases are found in the SOAP note."""
|
| 81 |
+
if soap_note is None:
|
| 82 |
+
return 1.0 # no note submitted → no unsafe language
|
| 83 |
+
text = " ".join([
|
| 84 |
+
soap_note.subjective,
|
| 85 |
+
soap_note.objective,
|
| 86 |
+
soap_note.assessment,
|
| 87 |
+
soap_note.plan,
|
| 88 |
+
])
|
| 89 |
+
for pattern in _UNSAFE_PATTERNS:
|
| 90 |
+
if pattern.search(text):
|
| 91 |
+
return 0.0
|
| 92 |
+
return 1.0
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _format_valid(action: Action) -> float:
|
| 96 |
+
"""Return 1.0 if the submitted note has all required non-empty SOAP fields.
|
| 97 |
+
|
| 98 |
+
This acts as a lightweight structural / «JSON well-formed» check:
|
| 99 |
+
each of S, O, A, P must be a non-empty string, and the action_type
|
| 100 |
+
must be ``submit_note``.
|
| 101 |
+
"""
|
| 102 |
+
if action.action_type != "submit_note":
|
| 103 |
+
return 1.0 # non-submission actions are not graded on format
|
| 104 |
+
if action.soap_note is None:
|
| 105 |
+
return 0.0
|
| 106 |
+
soap = action.soap_note
|
| 107 |
+
fields = [soap.subjective, soap.objective, soap.assessment, soap.plan]
|
| 108 |
+
if all(isinstance(f, str) and f.strip() for f in fields):
|
| 109 |
+
return 1.0
|
| 110 |
+
return 0.0
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
# Public API
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
+
def compute_reward(
    action: Action,
    grader_score: float,
    step_count: int,
    errors_so_far: list[str],
    *,
    done: bool = False,
    info: Optional[dict[str, Any]] = None,
) -> Reward:
    """Compute the multi-signal reward for a completed step.

    Parameters
    ----------
    action:
        The action that was just executed.
    grader_score:
        Clinical-accuracy score returned by the task-specific grader;
        clamped into [0.0, 1.0]. Pass 0.0 for non-submission actions.
    step_count:
        Total number of steps taken so far in the episode (including this one).
    errors_so_far:
        Error messages accumulated during the episode; each one is penalised.
    done:
        Whether the episode ended with this step.
    info:
        Optional auxiliary metadata dict carried through into the Reward.

    Returns
    -------
    Reward
        Fully populated Reward with the clamped ``value`` and a
        per-signal ``signals`` breakdown.
    """
    # Clamp the grader's contribution before weighting.
    accuracy = max(0.0, min(1.0, grader_score))

    # Per-signal component scores (each in [0, 1]).
    concise = _conciseness_bonus(action.soap_note)
    safe = _safe_language_score(action.soap_note)
    structural = _format_valid(action)

    # Weighted positive contributions.
    positives = (
        accuracy * W_GRADER
        + concise * W_CONCISE
        + safe * W_SAFE_LANG
        + structural * W_FORMAT
    )

    # Deductions: steps beyond the free allowance, plus accumulated errors.
    over_budget = max(0, step_count - FREE_STEPS)
    step_cost = over_budget * STEP_PENALTY_RATE
    error_cost = len(errors_so_far) * ERROR_PENALTY_RATE

    # Final scalar value, clamped into [0, 1].
    total = max(0.0, min(1.0, positives - step_cost - error_cost))

    breakdown: dict[str, float] = {
        # positive contributions (already weighted)
        "grader_score": round(accuracy * W_GRADER, 4),
        "conciseness_bonus": round(concise * W_CONCISE, 4),
        "safe_language_score": round(safe * W_SAFE_LANG, 4),
        "format_valid": round(structural * W_FORMAT, 4),
        # deductions (stored as negative numbers for clarity)
        "step_penalty": round(-step_cost, 4),
        "error_penalty": round(-error_cost, 4),
        # raw sub-signal values (unweighted, for introspection)
        "_grader_score_raw": round(accuracy, 4),
        "_conciseness_raw": round(concise, 4),
        "_safe_language_raw": round(safe, 4),
        "_format_valid_raw": round(structural, 4),
        "_extra_steps": float(over_budget),
        "_error_count": float(len(errors_so_far)),
    }

    return Reward(
        value=round(total, 4),
        signals=breakdown,
        done=done,
        info=info or {},
    )
|
environment/tasks/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Task definitions for easy, medium, and hard scenarios."""
|
| 2 |
+
|
| 3 |
+
from .task_easy import EASY_TASK, grade_easy
|
| 4 |
+
from .task_medium import MEDIUM_TASK, grade_medium
|
| 5 |
+
from .task_hard import HARD_TASK, grade_hard
|
| 6 |
+
|
| 7 |
+
TASK_REGISTRY: dict[str, dict] = {
|
| 8 |
+
EASY_TASK["task_id"]: EASY_TASK,
|
| 9 |
+
MEDIUM_TASK["task_id"]: MEDIUM_TASK,
|
| 10 |
+
HARD_TASK["task_id"]: HARD_TASK,
|
| 11 |
+
}
|
| 12 |
+
|
| 13 |
+
GRADER_REGISTRY: dict[str, callable] = {
|
| 14 |
+
EASY_TASK["task_id"]: grade_easy,
|
| 15 |
+
MEDIUM_TASK["task_id"]: grade_medium,
|
| 16 |
+
HARD_TASK["task_id"]: grade_hard,
|
| 17 |
+
}
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"TASK_REGISTRY",
|
| 21 |
+
"GRADER_REGISTRY",
|
| 22 |
+
]
|
environment/tasks/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (1.15 kB). View file
|
|
|
environment/tasks/__pycache__/task_easy.cpython-314.pyc
ADDED
|
Binary file (1.69 kB). View file
|
|
|
environment/tasks/__pycache__/task_hard.cpython-314.pyc
ADDED
|
Binary file (2.4 kB). View file
|
|
|
environment/tasks/__pycache__/task_medium.cpython-314.pyc
ADDED
|
Binary file (2.02 kB). View file
|
|
|
environment/tasks/task_easy.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Easy task — routine check-up.
|
| 2 |
+
|
| 3 |
+
Grader is intentionally left unimplemented.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from environment.models import SOAPNote
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Task definition
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
EASY_TASK: dict[str, Any] = {
|
| 18 |
+
"task_id": "easy_routine_checkup",
|
| 19 |
+
"description": "Generate a SOAP note for a routine annual check-up visit.",
|
| 20 |
+
"transcript_file": "data/transcripts/easy.txt",
|
| 21 |
+
"patient_context": {
|
| 22 |
+
"patient_id": "P-1001",
|
| 23 |
+
"name": "Jane Doe",
|
| 24 |
+
"age": 34,
|
| 25 |
+
"sex": "F",
|
| 26 |
+
"known_conditions": [],
|
| 27 |
+
"current_medications": [],
|
| 28 |
+
"allergies": ["Penicillin"],
|
| 29 |
+
},
|
| 30 |
+
"max_steps": 5,
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Grader (not yet implemented)
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
def grade_easy(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 39 |
+
"""Score a submitted SOAP note against the easy-task rubric.
|
| 40 |
+
|
| 41 |
+
Parameters
|
| 42 |
+
----------
|
| 43 |
+
soap_note:
|
| 44 |
+
The agent's submitted clinical note.
|
| 45 |
+
task:
|
| 46 |
+
The task definition dict (``EASY_TASK``).
|
| 47 |
+
|
| 48 |
+
Returns
|
| 49 |
+
-------
|
| 50 |
+
dict mapping signal names → float scores in [0, 1].
|
| 51 |
+
|
| 52 |
+
Raises
|
| 53 |
+
------
|
| 54 |
+
NotImplementedError
|
| 55 |
+
Grader has not been implemented yet.
|
| 56 |
+
"""
|
| 57 |
+
raise NotImplementedError("Easy-task grader is not yet implemented.")
|
environment/tasks/task_hard.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Hard task — complex ER visit.
|
| 2 |
+
|
| 3 |
+
Grader is intentionally left unimplemented.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from environment.models import SOAPNote
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Task definition
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
HARD_TASK: dict[str, Any] = {
|
| 18 |
+
"task_id": "hard_complex_er_visit",
|
| 19 |
+
"description": (
|
| 20 |
+
"Generate a SOAP note for a complex emergency-room visit involving "
|
| 21 |
+
"chest pain, polytrauma assessment, and multiple co-morbidities."
|
| 22 |
+
),
|
| 23 |
+
"transcript_file": "data/transcripts/hard.txt",
|
| 24 |
+
"patient_context": {
|
| 25 |
+
"patient_id": "P-3782",
|
| 26 |
+
"name": "Maria Garcia",
|
| 27 |
+
"age": 72,
|
| 28 |
+
"sex": "F",
|
| 29 |
+
"known_conditions": [
|
| 30 |
+
"Coronary Artery Disease",
|
| 31 |
+
"Atrial Fibrillation",
|
| 32 |
+
"Chronic Kidney Disease Stage 3",
|
| 33 |
+
"Osteoarthritis",
|
| 34 |
+
],
|
| 35 |
+
"current_medications": [
|
| 36 |
+
"Aspirin 81 mg daily",
|
| 37 |
+
"Warfarin 5 mg daily",
|
| 38 |
+
"Metoprolol 50 mg BID",
|
| 39 |
+
"Furosemide 40 mg daily",
|
| 40 |
+
"Amlodipine 5 mg daily",
|
| 41 |
+
],
|
| 42 |
+
"allergies": ["Sulfa drugs", "Contrast dye"],
|
| 43 |
+
"recent_labs": {
|
| 44 |
+
"troponin_I": "0.08 ng/mL",
|
| 45 |
+
"BNP": "450 pg/mL",
|
| 46 |
+
"creatinine": "1.9 mg/dL",
|
| 47 |
+
"eGFR": "34 mL/min",
|
| 48 |
+
"INR": "2.6",
|
| 49 |
+
"hemoglobin": "10.2 g/dL",
|
| 50 |
+
},
|
| 51 |
+
"vitals_on_arrival": {
|
| 52 |
+
"BP": "168/94 mmHg",
|
| 53 |
+
"HR": "112 bpm (irregular)",
|
| 54 |
+
"RR": "22 breaths/min",
|
| 55 |
+
"SpO2": "91% on room air",
|
| 56 |
+
"Temp": "37.2°C",
|
| 57 |
+
},
|
| 58 |
+
},
|
| 59 |
+
"max_steps": 10,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
# ---------------------------------------------------------------------------
|
| 64 |
+
# Grader (not yet implemented)
|
| 65 |
+
# ---------------------------------------------------------------------------
|
| 66 |
+
|
| 67 |
+
def grade_hard(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 68 |
+
"""Score a submitted SOAP note against the hard-task rubric.
|
| 69 |
+
|
| 70 |
+
Parameters
|
| 71 |
+
----------
|
| 72 |
+
soap_note:
|
| 73 |
+
The agent's submitted clinical note.
|
| 74 |
+
task:
|
| 75 |
+
The task definition dict (``HARD_TASK``).
|
| 76 |
+
|
| 77 |
+
Returns
|
| 78 |
+
-------
|
| 79 |
+
dict mapping signal names → float scores in [0, 1].
|
| 80 |
+
|
| 81 |
+
Raises
|
| 82 |
+
------
|
| 83 |
+
NotImplementedError
|
| 84 |
+
Grader has not been implemented yet.
|
| 85 |
+
"""
|
| 86 |
+
raise NotImplementedError("Hard-task grader is not yet implemented.")
|
environment/tasks/task_medium.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Medium task — chronic disease follow-up.
|
| 2 |
+
|
| 3 |
+
Grader is intentionally left unimplemented.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
from typing import Any
|
| 9 |
+
|
| 10 |
+
from environment.models import SOAPNote
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
# Task definition
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
|
| 17 |
+
MEDIUM_TASK: dict[str, Any] = {
|
| 18 |
+
"task_id": "medium_chronic_disease_followup",
|
| 19 |
+
"description": "Generate a SOAP note for a Type 2 Diabetes follow-up visit.",
|
| 20 |
+
"transcript_file": "data/transcripts/medium.txt",
|
| 21 |
+
"patient_context": {
|
| 22 |
+
"patient_id": "P-2045",
|
| 23 |
+
"name": "Robert Smith",
|
| 24 |
+
"age": 58,
|
| 25 |
+
"sex": "M",
|
| 26 |
+
"known_conditions": ["Type 2 Diabetes Mellitus", "Hypertension"],
|
| 27 |
+
"current_medications": [
|
| 28 |
+
"Metformin 1000 mg BID",
|
| 29 |
+
"Lisinopril 20 mg daily",
|
| 30 |
+
"Atorvastatin 40 mg daily",
|
| 31 |
+
],
|
| 32 |
+
"allergies": [],
|
| 33 |
+
"recent_labs": {
|
| 34 |
+
"HbA1c": "7.8%",
|
| 35 |
+
"fasting_glucose": "156 mg/dL",
|
| 36 |
+
"creatinine": "1.1 mg/dL",
|
| 37 |
+
"eGFR": "78 mL/min",
|
| 38 |
+
"LDL": "102 mg/dL",
|
| 39 |
+
},
|
| 40 |
+
},
|
| 41 |
+
"max_steps": 8,
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
# Grader (not yet implemented)
|
| 47 |
+
# ---------------------------------------------------------------------------
|
| 48 |
+
|
| 49 |
+
def grade_medium(soap_note: SOAPNote, task: dict[str, Any]) -> dict[str, float]:
|
| 50 |
+
"""Score a submitted SOAP note against the medium-task rubric.
|
| 51 |
+
|
| 52 |
+
Parameters
|
| 53 |
+
----------
|
| 54 |
+
soap_note:
|
| 55 |
+
The agent's submitted clinical note.
|
| 56 |
+
task:
|
| 57 |
+
The task definition dict (``MEDIUM_TASK``).
|
| 58 |
+
|
| 59 |
+
Returns
|
| 60 |
+
-------
|
| 61 |
+
dict mapping signal names → float scores in [0, 1].
|
| 62 |
+
|
| 63 |
+
Raises
|
| 64 |
+
------
|
| 65 |
+
NotImplementedError
|
| 66 |
+
Grader has not been implemented yet.
|
| 67 |
+
"""
|
| 68 |
+
raise NotImplementedError("Medium-task grader is not yet implemented.")
|
inference.py
ADDED
|
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Baseline inference script for the Clinical Note Scribe environment.
|
| 2 |
+
|
| 3 |
+
Runs all three tasks (easy → medium → hard) sequentially using an
|
| 4 |
+
OpenAI-compatible API to generate SOAP notes from doctor–patient transcripts.
|
| 5 |
+
|
| 6 |
+
Environment variables
|
| 7 |
+
---------------------
|
| 8 |
+
OPENAI_API_KEY – API key for the model provider
|
| 9 |
+
API_BASE_URL – Base URL for the OpenAI-compatible endpoint (default: https://api.openai.com/v1)
|
| 10 |
+
MODEL_NAME – Model identifier to use (default: gpt-4o-mini)
|
| 11 |
+
|
| 12 |
+
Usage::
|
| 13 |
+
|
| 14 |
+
python inference.py
|
| 15 |
+
|
| 16 |
+
Designed to complete in under 20 minutes on 2 vCPU / 8 GB RAM.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import json
|
| 22 |
+
import logging
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import time
|
| 26 |
+
from typing import Any
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Bootstrap logging BEFORE importing environment modules so the root logger
|
| 30 |
+
# is configured and child loggers (clinical_note_scribe.*) propagate cleanly.
|
| 31 |
+
# ---------------------------------------------------------------------------
|
| 32 |
+
|
| 33 |
+
logging.basicConfig(
|
| 34 |
+
level=logging.INFO,
|
| 35 |
+
format="%(message)s",
|
| 36 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 37 |
+
)
|
| 38 |
+
logger = logging.getLogger("inference")
|
| 39 |
+
|
| 40 |
+
# ---------------------------------------------------------------------------
|
| 41 |
+
# Environment imports (after logging is configured)
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 45 |
+
|
| 46 |
+
from environment import ClinicalNoteScribeEnv, Action, SOAPNote # noqa: E402
|
| 47 |
+
from environment.tasks import TASK_REGISTRY # noqa: E402
|
| 48 |
+
|
| 49 |
+
# ---------------------------------------------------------------------------
|
| 50 |
+
# Config
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
|
| 53 |
+
API_KEY = os.environ.get("OPENAI_API_KEY", "")
|
| 54 |
+
API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 55 |
+
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 56 |
+
|
| 57 |
+
TASK_IDS = list(TASK_REGISTRY.keys()) # deterministic order
|
| 58 |
+
|
| 59 |
+
# Maximum tokens for the model response — keeps latency low
|
| 60 |
+
MAX_TOKENS = 1024
|
| 61 |
+
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
# System prompt
|
| 64 |
+
# ---------------------------------------------------------------------------
|
| 65 |
+
|
| 66 |
+
SYSTEM_PROMPT = """\
|
| 67 |
+
You are a clinical documentation assistant. Given a doctor–patient transcript \
|
| 68 |
+
and patient context, generate a concise, clinically accurate SOAP note.
|
| 69 |
+
|
| 70 |
+
RULES:
|
| 71 |
+
1. Use professional medical language. Avoid over-certain phrasing such as \
|
| 72 |
+
"patient definitely has", "diagnosis is certain", or "100% certain".
|
| 73 |
+
2. Keep the note concise — aim for under 400 words total across all four sections.
|
| 74 |
+
3. Return your output as a **single valid JSON object** matching this schema exactly:
|
| 75 |
+
|
| 76 |
+
{
|
| 77 |
+
"action_type": "submit_note",
|
| 78 |
+
"soap_note": {
|
| 79 |
+
"subjective": "<patient's reported symptoms, history, and concerns>",
|
| 80 |
+
"objective": "<exam findings, vitals, lab results, imaging>",
|
| 81 |
+
"assessment": "<differential diagnoses and clinical reasoning>",
|
| 82 |
+
"plan": "<treatment plan, medications, follow-up, referrals>"
|
| 83 |
+
}
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
Return ONLY the JSON object. No markdown fences, no commentary, no extra keys.
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
# ---------------------------------------------------------------------------
|
| 90 |
+
# Helpers
|
| 91 |
+
# ---------------------------------------------------------------------------
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def _build_user_prompt(transcript: str, patient_context: dict[str, Any]) -> str:
|
| 95 |
+
"""Build the user message containing the transcript and context."""
|
| 96 |
+
ctx_str = json.dumps(patient_context, indent=2, default=str)
|
| 97 |
+
return (
|
| 98 |
+
f"## Patient Context\n```json\n{ctx_str}\n```\n\n"
|
| 99 |
+
f"## Doctor–Patient Transcript\n```\n{transcript}\n```\n\n"
|
| 100 |
+
"Generate the SOAP note as a JSON Action object."
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def _call_model(user_prompt: str) -> dict[str, Any]:
    """Call the OpenAI-compatible API and return the parsed JSON action dict.

    Tries the ``openai`` SDK first; when the SDK is not installed
    (``ImportError`` from the in-function import), falls back to a
    stdlib-only ``urllib`` request. The fallback keeps the Docker image
    small and avoids SDK version conflicts.
    """
    try:
        # Preferred path: the official SDK, when available.
        return _call_model_sdk(user_prompt)
    except ImportError:
        # SDK missing — use the dependency-free HTTP fallback.
        return _call_model_urllib(user_prompt)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _call_model_sdk(user_prompt: str) -> dict[str, Any]:
    """Call the chat-completions endpoint via the ``openai`` Python SDK.

    Raises ``ImportError`` when the SDK is not installed, which the caller
    uses as the signal to fall back to the urllib transport.
    """
    from openai import OpenAI  # noqa: F811

    client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        max_tokens=MAX_TOKENS,
        temperature=0.2,
    )
    content = completion.choices[0].message.content.strip()
    return _parse_json(content)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def _call_model_urllib(user_prompt: str) -> dict[str, Any]:
    """Fallback transport: POST to the chat-completions endpoint via ``urllib``.

    Keeps the environment free of third-party HTTP dependencies.
    """
    import urllib.request

    endpoint = f"{API_BASE_URL.rstrip('/')}/chat/completions"
    request_body = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ],
        "max_tokens": MAX_TOKENS,
        "temperature": 0.2,
    }
    request = urllib.request.Request(
        endpoint,
        data=json.dumps(request_body).encode(),
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {API_KEY}",
        },
    )
    # 120 s guard so a hung endpoint cannot stall the whole run.
    with urllib.request.urlopen(request, timeout=120) as response:
        parsed = json.loads(response.read())

    answer = parsed["choices"][0]["message"]["content"].strip()
    return _parse_json(answer)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def _parse_json(raw: str) -> dict[str, Any]:
|
| 169 |
+
"""Parse the model's raw text output into a dict, tolerating markdown fences."""
|
| 170 |
+
# Strip markdown code fences if present
|
| 171 |
+
cleaned = raw
|
| 172 |
+
if cleaned.startswith("```"):
|
| 173 |
+
# remove opening fence (possibly ```json)
|
| 174 |
+
first_newline = cleaned.index("\n")
|
| 175 |
+
cleaned = cleaned[first_newline + 1:]
|
| 176 |
+
if cleaned.endswith("```"):
|
| 177 |
+
cleaned = cleaned[: -3]
|
| 178 |
+
cleaned = cleaned.strip()
|
| 179 |
+
|
| 180 |
+
try:
|
| 181 |
+
return json.loads(cleaned)
|
| 182 |
+
except json.JSONDecodeError as exc:
|
| 183 |
+
logger.error("Failed to parse model output as JSON: %s", exc)
|
| 184 |
+
logger.error("Raw output:\n%s", raw)
|
| 185 |
+
raise
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def _log_event(event: str, **kwargs: Any) -> None:
    """Emit one structured JSON log line tagged with event name and timestamp.

    Extra keyword arguments are merged into the record (and may override
    the default keys); ``default=str`` keeps serialization crash-free.
    """
    record = {"event": event, "timestamp": time.time(), **kwargs}
    logger.info(json.dumps(record, default=str))
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Main loop
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def run_all_tasks() -> list[dict[str, Any]]:
    """Run every registered task and return a list of result dicts.

    For each task: reset the environment, prompt the model for a SOAP note,
    validate its JSON against the Action schema, and submit it as a single
    step. Each result dict carries ``task_id``, ``score``, ``elapsed_s``
    and, on failure, an ``error`` message (with score 0.0). A failure in
    one task never aborts the remaining tasks.
    """
    env = ClinicalNoteScribeEnv()
    results: list[dict[str, Any]] = []

    for task_id in TASK_IDS:
        logger.info("")
        logger.info("=" * 60)
        logger.info(" TASK: %s", task_id)
        logger.info("=" * 60)

        t0 = time.time()
        _log_event("INFERENCE_START", task_id=task_id)

        # ---- reset: start a fresh episode for this task ----
        obs = env.reset(task_id)
        logger.info(" Transcript length : %d chars", len(obs.transcript))
        logger.info(" Patient context keys: %s", list(obs.patient_context.keys()))

        # ---- generate SOAP note via LLM ----
        user_prompt = _build_user_prompt(obs.transcript, obs.patient_context)
        logger.info(" Calling model (%s) ...", MODEL_NAME)

        try:
            action_dict = _call_model(user_prompt)
        except Exception as exc:
            # Network/API/JSON-parse failure: record a zero-score result
            # and move on to the next task rather than aborting the run.
            logger.error(" Model call failed: %s", exc)
            results.append({
                "task_id": task_id,
                "score": 0.0,
                "error": str(exc),
                "elapsed_s": round(time.time() - t0, 2),
            })
            _log_event("INFERENCE_ERROR", task_id=task_id, error=str(exc))
            continue

        # ---- validate and create Action ----
        try:
            action = Action(**action_dict)
        except Exception as exc:
            # Model returned JSON that does not match the Action schema.
            logger.error(" Invalid action schema: %s", exc)
            logger.error(" Model returned: %s", json.dumps(action_dict, indent=2))
            results.append({
                "task_id": task_id,
                "score": 0.0,
                "error": f"schema_error: {exc}",
                "elapsed_s": round(time.time() - t0, 2),
            })
            _log_event("INFERENCE_ERROR", task_id=task_id, error=str(exc))
            continue

        # ---- step (submit the note and collect the reward) ----
        obs2, reward, done, info = env.step(action)
        elapsed = round(time.time() - t0, 2)

        logger.info(" Done: %s | Reward: %.4f | Elapsed: %.1fs", done, reward.value, elapsed)
        # Only public signals are shown; "_"-prefixed ones are introspection-only.
        logger.info(" Signals: %s",
                    {k: v for k, v in reward.signals.items() if not k.startswith("_")})

        _log_event("INFERENCE_END", task_id=task_id, score=reward.value, elapsed_s=elapsed)

        results.append({
            "task_id": task_id,
            "score": reward.value,
            "elapsed_s": elapsed,
        })

    return results
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
def _print_summary(results: list[dict[str, Any]]) -> None:
    """Print a formatted per-task summary table of scores and timings.

    Handles an empty ``results`` list gracefully (headers plus a 0.0
    average). The previous ``max(len("Task"), *(...))`` raised
    ``TypeError`` on empty results because it degenerated to ``max(4)``.
    """
    logger.info("")
    logger.info("=" * 60)
    logger.info(" SUMMARY")
    logger.info("=" * 60)

    # Column widths; max(..., default=0) keeps this safe for empty results.
    widest_task = max((len(r["task_id"]) for r in results), default=0)
    col_task = max(len("Task"), widest_task)
    col_score = 7  # "Score" + padding
    col_time = 9  # "Time (s)"

    header = f" {'Task':<{col_task}} {'Score':>{col_score}} {'Time (s)':>{col_time}}"
    sep = f" {'-' * col_task} {'-' * col_score} {'-' * col_time}"
    logger.info(header)
    logger.info(sep)

    total_score = 0.0
    for r in results:
        # Failed tasks carry an "error" key and display ERROR, not a score.
        score_str = f"{r['score']:.4f}" if "error" not in r else "ERROR"
        time_str = f"{r['elapsed_s']:.1f}"
        logger.info(f" {r['task_id']:<{col_task}} {score_str:>{col_score}} {time_str:>{col_time}}")
        total_score += r["score"]

    logger.info(sep)
    avg = total_score / len(results) if results else 0.0
    logger.info(f" {'AVERAGE':<{col_task}} {avg:>{col_score}.4f}")
    logger.info("")
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
# ---------------------------------------------------------------------------
|
| 300 |
+
# Entry point
|
| 301 |
+
# ---------------------------------------------------------------------------
|
| 302 |
+
|
| 303 |
+
if __name__ == "__main__":
|
| 304 |
+
if not API_KEY:
|
| 305 |
+
logger.warning(
|
| 306 |
+
"OPENAI_API_KEY is not set. The model calls will fail unless "
|
| 307 |
+
"the API endpoint does not require authentication."
|
| 308 |
+
)
|
| 309 |
+
|
| 310 |
+
logger.info("Clinical Note Scribe — Baseline Inference")
|
| 311 |
+
logger.info(" Model : %s", MODEL_NAME)
|
| 312 |
+
logger.info(" API Base : %s", API_BASE_URL)
|
| 313 |
+
logger.info(" Tasks : %s", TASK_IDS)
|
| 314 |
+
logger.info("")
|
| 315 |
+
|
| 316 |
+
start = time.time()
|
| 317 |
+
results = run_all_tasks()
|
| 318 |
+
total_elapsed = round(time.time() - start, 2)
|
| 319 |
+
|
| 320 |
+
_print_summary(results)
|
| 321 |
+
logger.info(" Total wall-clock time: %.1fs", total_elapsed)
|
| 322 |
+
|
| 323 |
+
_log_event("INFERENCE_COMPLETE", total_elapsed_s=total_elapsed,
|
| 324 |
+
scores={r["task_id"]: r["score"] for r in results})
|
openenv.yaml
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv Environment Specification
|
| 2 |
+
# Clinical Note Scribe — Meta × Hugging Face OpenEnv Hackathon
|
| 3 |
+
|
| 4 |
+
name: meta-huggingface-hackathon-team-silver-orca
|
| 5 |
+
version: 1.0.0
|
| 6 |
+
description: >
|
| 7 |
+
An OpenEnv-compliant environment for evaluating AI agents on clinical
|
| 8 |
+
SOAP-note generation from doctor–patient transcripts. Agents receive a
|
| 9 |
+
transcript and patient context, then must produce a well-structured,
|
| 10 |
+
clinically accurate SOAP note through submit, revise, or clarify actions.
|
| 11 |
+
|
| 12 |
+
# ---------------------------------------------------------------------------
|
| 13 |
+
# Tasks
|
| 14 |
+
# ---------------------------------------------------------------------------
|
| 15 |
+
tasks:
|
| 16 |
+
- id: easy_routine_checkup
|
| 17 |
+
description: >
|
| 18 |
+
Generate a SOAP note for a routine annual check-up visit.
|
| 19 |
+
6-turn dialogue covering a simple upper respiratory infection
|
| 20 |
+
and a blood pressure screening.
|
| 21 |
+
difficulty: easy
|
| 22 |
+
max_steps: 5
|
| 23 |
+
grader: grade_easy
|
| 24 |
+
|
| 25 |
+
- id: medium_chronic_disease_followup
|
| 26 |
+
description: >
|
| 27 |
+
Generate a SOAP note for a Type 2 Diabetes and Hypertension
|
| 28 |
+
follow-up visit. 14-turn dialogue including medication adjustments
|
| 29 |
+
(glipizide addition, lisinopril uptitration), HbA1c lab review,
|
| 30 |
+
and dietary counseling.
|
| 31 |
+
difficulty: medium
|
| 32 |
+
max_steps: 8
|
| 33 |
+
grader: grade_medium
|
| 34 |
+
|
| 35 |
+
- id: hard_complex_er_visit
|
| 36 |
+
description: >
|
| 37 |
+
Generate a SOAP note for a complex emergency-room visit with
|
| 38 |
+
overlapping chest pain, shortness of breath, and a possible
|
| 39 |
+
pulmonary embolism. 20-turn dialogue with differential diagnoses,
|
| 40 |
+
urgent orders, a patient self-contradiction, and contrast-allergy
|
| 41 |
+
complications.
|
| 42 |
+
difficulty: hard
|
| 43 |
+
max_steps: 10
|
| 44 |
+
grader: grade_hard
|
| 45 |
+
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
# API Endpoints
|
| 48 |
+
# ---------------------------------------------------------------------------
|
| 49 |
+
api:
|
| 50 |
+
base_url: http://localhost:7860
|
| 51 |
+
endpoints:
|
| 52 |
+
reset:
|
| 53 |
+
method: POST
|
| 54 |
+
path: /reset
|
| 55 |
+
request_schema: ResetRequest
|
| 56 |
+
response_schema: Observation
|
| 57 |
+
description: Start a new episode for the specified task.
|
| 58 |
+
|
| 59 |
+
step:
|
| 60 |
+
method: POST
|
| 61 |
+
path: /step
|
| 62 |
+
request_schema: Action
|
| 63 |
+
response_schema: StepResponse
|
| 64 |
+
description: Submit an action and advance the environment by one step.
|
| 65 |
+
|
| 66 |
+
state:
|
| 67 |
+
method: GET
|
| 68 |
+
path: /state
|
| 69 |
+
response_schema: EnvironmentState
|
| 70 |
+
description: Return the full internal environment state without mutation.
|
| 71 |
+
|
| 72 |
+
health:
|
| 73 |
+
method: GET
|
| 74 |
+
path: /health
|
| 75 |
+
response_schema: HealthResponse
|
| 76 |
+
description: Liveness probe; returns {"status":"ok"}.
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# Schemas (Pydantic v2 models in environment/models.py)
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
schemas:
|
| 82 |
+
Observation:
|
| 83 |
+
fields:
|
| 84 |
+
- name: transcript
|
| 85 |
+
type: str
|
| 86 |
+
description: Full doctor–patient transcript for the current task.
|
| 87 |
+
- name: task_id
|
| 88 |
+
type: str
|
| 89 |
+
description: Unique identifier for the task.
|
| 90 |
+
- name: patient_context
|
| 91 |
+
type: "dict[str, Any]"
|
| 92 |
+
description: Structured patient demographics and history.
|
| 93 |
+
- name: current_draft
|
| 94 |
+
type: "Optional[str]"
|
| 95 |
+
description: The agent's most recent SOAP-note draft, if any.
|
| 96 |
+
- name: errors_so_far
|
| 97 |
+
type: "list[str]"
|
| 98 |
+
description: Accumulated error/feedback messages from prior steps.
|
| 99 |
+
- name: step_count
|
| 100 |
+
type: int
|
| 101 |
+
description: Number of steps taken in the current episode.
|
| 102 |
+
|
| 103 |
+
Action:
|
| 104 |
+
fields:
|
| 105 |
+
- name: action_type
|
| 106 |
+
type: "Literal['submit_note','request_clarify','revise_section']"
|
| 107 |
+
description: The kind of action the agent is taking.
|
| 108 |
+
- name: soap_note
|
| 109 |
+
type: "Optional[SOAPNote]"
|
| 110 |
+
description: Complete SOAP note (required for submit_note).
|
| 111 |
+
- name: section
|
| 112 |
+
type: "Optional[Literal['S','O','A','P']]"
|
| 113 |
+
description: Section to revise (required for revise_section).
|
| 114 |
+
- name: revision_text
|
| 115 |
+
type: "Optional[str]"
|
| 116 |
+
description: Replacement text for the section.
|
| 117 |
+
- name: clarify_question
|
| 118 |
+
type: "Optional[str]"
|
| 119 |
+
description: Clarification question to ask.
|
| 120 |
+
|
| 121 |
+
SOAPNote:
|
| 122 |
+
fields:
|
| 123 |
+
- name: subjective
|
| 124 |
+
type: str
|
| 125 |
+
- name: objective
|
| 126 |
+
type: str
|
| 127 |
+
- name: assessment
|
| 128 |
+
type: str
|
| 129 |
+
- name: plan
|
| 130 |
+
type: str
|
| 131 |
+
|
| 132 |
+
Reward:
|
| 133 |
+
fields:
|
| 134 |
+
- name: value
|
| 135 |
+
type: float
|
| 136 |
+
description: "Aggregate reward in [0.0, 1.0]."
|
| 137 |
+
- name: signals
|
| 138 |
+
type: "dict[str, float]"
|
| 139 |
+
description: Breakdown of individual reward sub-signals.
|
| 140 |
+
- name: done
|
| 141 |
+
type: bool
|
| 142 |
+
description: Whether the episode has ended.
|
| 143 |
+
- name: info
|
| 144 |
+
type: "dict[str, Any]"
|
| 145 |
+
description: Auxiliary metadata.
|
| 146 |
+
|
| 147 |
+
EnvironmentState:
|
| 148 |
+
fields:
|
| 149 |
+
- name: task_id
|
| 150 |
+
type: str
|
| 151 |
+
- name: step_count
|
| 152 |
+
type: int
|
| 153 |
+
- name: max_steps
|
| 154 |
+
type: int
|
| 155 |
+
- name: done
|
| 156 |
+
type: bool
|
| 157 |
+
- name: current_draft
|
| 158 |
+
type: "Optional[str]"
|
| 159 |
+
- name: errors_so_far
|
| 160 |
+
type: "list[str]"
|
| 161 |
+
- name: last_reward
|
| 162 |
+
type: "Optional[Reward]"
|
| 163 |
+
- name: observation
|
| 164 |
+
type: "Optional[Observation]"
|
| 165 |
+
|
| 166 |
+
# ---------------------------------------------------------------------------
|
| 167 |
+
# Reward function
|
| 168 |
+
# ---------------------------------------------------------------------------
|
| 169 |
+
reward:
|
| 170 |
+
range: [0.0, 1.0]
|
| 171 |
+
formula: >
|
| 172 |
+
weighted_sum = grader_score × 0.60
|
| 173 |
+
+ conciseness_bonus × 0.10
|
| 174 |
+
+ safe_language_score × 0.15
|
| 175 |
+
+ format_valid × 0.15
|
| 176 |
+
deductions = 0.05 × max(0, step_count - 3)
|
| 177 |
+
+ 0.10 × len(errors_so_far)
|
| 178 |
+
value = clamp(weighted_sum - deductions, 0.0, 1.0)
|
| 179 |
+
signals:
|
| 180 |
+
- grader_score: "Clinical accuracy from task-specific grader (0–1)"
|
| 181 |
+
- conciseness_bonus: "1.0 if SOAP note ≤ 400 words, else 0.0"
|
| 182 |
+
- safe_language_score: "1.0 if no unsafe-certainty phraseology, else 0.0"
|
| 183 |
+
- format_valid: "1.0 if all SOAP fields are non-empty strings"
|
| 184 |
+
- step_penalty: "−0.05 per step beyond 3"
|
| 185 |
+
- error_penalty: "−0.10 per invalid action error"
|
| 186 |
+
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
# Graders
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
graders:
|
| 191 |
+
- name: grade_easy
|
| 192 |
+
file: environment/tasks/task_easy.py
|
| 193 |
+
function: grade_easy
|
| 194 |
+
|
| 195 |
+
- name: grade_medium
|
| 196 |
+
file: environment/tasks/task_medium.py
|
| 197 |
+
function: grade_medium
|
| 198 |
+
|
| 199 |
+
- name: grade_hard
|
| 200 |
+
file: environment/tasks/task_hard.py
|
| 201 |
+
function: grade_hard
|
| 202 |
+
|
| 203 |
+
# ---------------------------------------------------------------------------
|
| 204 |
+
# Inference
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
inference:
|
| 207 |
+
script: inference.py
|
| 208 |
+
env_vars:
|
| 209 |
+
- OPENAI_API_KEY
|
| 210 |
+
- API_BASE_URL
|
| 211 |
+
- MODEL_NAME
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi>=0.110.0
|
| 2 |
+
uvicorn>=0.29.0
|
| 3 |
+
pydantic>=2.7.0
|
| 4 |
+
openai>=1.30.0
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
"""Server package for the Clinical Note Scribe environment."""
|
server/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (221 Bytes). View file
|
|
|
server/__pycache__/app.cpython-314.pyc
ADDED
|
Binary file (1.26 kB). View file
|
|
|
server/__pycache__/routes.cpython-314.pyc
ADDED
|
Binary file (6.37 kB). View file
|
|
|
server/app.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application for the Clinical Note Scribe environment.
|
| 2 |
+
|
| 3 |
+
Run locally::
|
| 4 |
+
|
| 5 |
+
uvicorn server.app:app --host 0.0.0.0 --port 8000 --reload
|
| 6 |
+
|
| 7 |
+
Or via Docker (see ``Dockerfile`` in project root).
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import sys
|
| 14 |
+
|
| 15 |
+
from fastapi import FastAPI
|
| 16 |
+
|
| 17 |
+
from server.routes import router
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Configure root logging → structured JSON to stdout
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
logging.basicConfig(
|
| 24 |
+
level=logging.INFO,
|
| 25 |
+
format="%(message)s",
|
| 26 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
# Silence noisy uvicorn access logs so our structured events stay clean
|
| 30 |
+
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
| 31 |
+
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
# Application factory
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
|
| 36 |
+
app = FastAPI(
|
| 37 |
+
title="Clinical Note Scribe – OpenEnv",
|
| 38 |
+
description=(
|
| 39 |
+
"An OpenEnv-compliant environment for evaluating AI agents on "
|
| 40 |
+
"clinical SOAP-note generation from doctor–patient transcripts."
|
| 41 |
+
),
|
| 42 |
+
version="0.1.0",
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
# Mount all routes at root (/)
|
| 46 |
+
app.include_router(router)
|
server/routes.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI route definitions for the Clinical Note Scribe environment.
|
| 2 |
+
|
| 3 |
+
Endpoints
|
| 4 |
+
---------
|
| 5 |
+
POST /reset – start a new episode (takes ``task_id``, returns ``Observation``)
|
| 6 |
+
POST /step – execute an action (takes ``Action``, returns step result)
|
| 7 |
+
GET /state – inspect env state (returns ``EnvironmentState``)
|
| 8 |
+
GET /health – liveness probe (returns ``{"status": "ok"}``)
|
| 9 |
+
|
| 10 |
+
Structured logging
|
| 11 |
+
------------------
|
| 12 |
+
The underlying ``ClinicalNoteScribeEnv`` already emits ``[START]``, ``[STEP]``,
|
| 13 |
+
and ``[END]`` JSON lines to stdout via Python's ``logging`` module. This router
|
| 14 |
+
adds a thin request-level log wrapper so every inbound HTTP call is also
|
| 15 |
+
traceable.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import logging
|
| 21 |
+
import json
|
| 22 |
+
import time
|
| 23 |
+
from typing import Any, Optional
|
| 24 |
+
|
| 25 |
+
from fastapi import APIRouter, HTTPException
|
| 26 |
+
from pydantic import BaseModel, Field
|
| 27 |
+
|
| 28 |
+
from environment.models import (
|
| 29 |
+
Action,
|
| 30 |
+
EnvironmentState,
|
| 31 |
+
Observation,
|
| 32 |
+
Reward,
|
| 33 |
+
)
|
| 34 |
+
from environment.env import ClinicalNoteScribeEnv
|
| 35 |
+
|
| 36 |
+
logger = logging.getLogger("clinical_note_scribe.server")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Singleton environment instance
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
# Module-level singleton: every HTTP request operates on this one environment,
# so the server hosts a single episode at a time.
# NOTE(review): concurrent clients would interleave steps on the same episode —
# confirm single-client use is the intended deployment model.
_env = ClinicalNoteScribeEnv()


# ---------------------------------------------------------------------------
# Request / response schemas
# ---------------------------------------------------------------------------

class ResetRequest(BaseModel):
    # Optional task selector; when omitted the environment falls back to its
    # first registered task (see description below).
    task_id: Optional[str] = Field(
        default=None,
        description=(
            "Task to load. One of: easy_routine_checkup, "
            "medium_chronic_disease_followup, hard_complex_er_visit. "
            "Defaults to the first registered task."
        ),
    )


class StepResponse(BaseModel):
    # Envelope returned by POST /step: new observation, reward breakdown,
    # episode-termination flag, and auxiliary metadata.
    observation: Observation
    reward: Reward
    done: bool
    # default_factory avoids a shared mutable default across instances.
    info: dict[str, Any] = Field(default_factory=dict)


class HealthResponse(BaseModel):
    # Fixed payload for GET /health.
    status: str = "ok"
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
# ---------------------------------------------------------------------------
|
| 73 |
+
# Helpers
|
| 74 |
+
# ---------------------------------------------------------------------------
|
| 75 |
+
|
| 76 |
+
def _log(event: str, **kwargs: Any) -> None:
    """Write one structured JSON log record to stdout.

    The record always carries ``event`` and a wall-clock ``timestamp``;
    extra keyword arguments are merged in (and may override those keys).
    """
    record: dict[str, Any] = {"event": event, "timestamp": time.time(), **kwargs}
    logger.info(json.dumps(record, default=str))
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
# Router
|
| 85 |
+
# ---------------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
router = APIRouter()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@router.post(
    "/reset",
    response_model=Observation,
    summary="Reset the environment and start a new episode",
)
async def reset(body: ResetRequest) -> Observation:
    """Load a task and return the initial ``Observation``.

    The underlying environment emits a ``[START]`` log event.

    Raises:
        HTTPException: 400 when ``body.task_id`` does not name a known task
            (surfaced from the environment's ``ValueError``).
    """
    _log("START", endpoint="/reset", task_id=body.task_id)
    try:
        obs = _env.reset(task_id=body.task_id)
    except ValueError as exc:
        # Unknown/invalid task_id → client error. Chain the original
        # exception (``from exc``) so server logs keep the full traceback
        # instead of an "During handling of the above..." context message.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return obs
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@router.post(
    "/step",
    response_model=StepResponse,
    summary="Submit an action and advance the environment by one step",
)
async def step(action: Action) -> StepResponse:
    """Execute *action* in the current episode.

    The underlying environment emits a ``[STEP]`` log event (and ``[END]``
    when the episode terminates).

    Raises:
        HTTPException: 409 when the episode cannot accept a step, e.g.
            stepping after the episode is done without a reset.
    """
    _log("STEP", endpoint="/step", action_type=action.action_type)
    try:
        obs, reward, done, info = _env.step(action)
    except RuntimeError as exc:
        # e.g. stepping after episode is done without reset. Chain the
        # cause (``from exc``) so the original traceback survives in logs.
        raise HTTPException(status_code=409, detail=str(exc)) from exc

    if done:
        _log("END", endpoint="/step", final_score=reward.value)

    return StepResponse(
        observation=obs,
        reward=reward,
        done=done,
        info=info,
    )
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
@router.get(
    "/state",
    response_model=EnvironmentState,
    summary="Return the full internal environment state",
)
async def state() -> EnvironmentState:
    """Inspect the environment without mutating it."""
    # Pure read: delegates to the env's state() accessor; no reset/step
    # side effects and no structured log event is emitted.
    return _env.state()
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@router.get(
    "/health",
    response_model=HealthResponse,
    summary="Liveness probe",
)
async def health() -> HealthResponse:
    """Returns HTTP 200 with ``{"status": "ok"}``."""
    # Constant response — intentionally does not touch _env, so the probe
    # succeeds even mid-episode.
    return HealthResponse()
|
test_inference.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Smoke test for inference.py helpers: imports, JSON parsing, markdown-fence
# stripping, Action construction, and prompt building. Prints progress lines
# and exits nonzero on any uncaught exception.
import sys
sys.path.insert(0, ".")
from inference import SYSTEM_PROMPT, TASK_IDS, _parse_json, _build_user_prompt
from environment import Action

print("Imports OK")
print("Tasks:", TASK_IDS)

# Test JSON parsing of a plain (unfenced) action payload.
j = _parse_json('{"action_type": "submit_note", "soap_note": {"subjective": "S", "objective": "O", "assessment": "A", "plan": "P"}}')
print("Parse OK:", j["action_type"])

# Test markdown fence stripping — models often wrap JSON in ```json ... ```.
fenced = '```json\n{"action_type": "submit_note", "soap_note": {"subjective": "S", "objective": "O", "assessment": "A", "plan": "P"}}\n```'
j2 = _parse_json(fenced)
print("Fence strip OK:", j2["action_type"])

# Test Action creation from parsed output.
# NOTE(review): reading ``model_fields`` on an *instance* is deprecated in
# Pydantic >= 2.7 — prefer ``type(action.soap_note).model_fields``; confirm
# against the pinned pydantic version before changing.
action = Action(**j2)
print("Action created:", action.action_type, "/ sections:", list(action.soap_note.model_fields.keys()))

# Test prompt building with a minimal patient-context dict.
p = _build_user_prompt("Hello doctor", {"name": "Test", "age": 30})
print("Prompt len:", len(p), "chars")

print("\nAll checks passed.")
|
test_output.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
--- Sub-signal unit tests ---
|
| 3 |
+
[OK] conciseness(short): got=1.0 want=1.0
|
| 4 |
+
[OK] conciseness(long) : got=0.0 want=0.0
|
| 5 |
+
[OK] safe_lang(clean) : got=1.0 want=1.0
|
| 6 |
+
[OK] safe_lang(unsafe) : got=0.0 want=0.0
|
| 7 |
+
[OK] format_valid(ok) : got=1.0 want=1.0
|
| 8 |
+
[OK] format_valid(bad) : got=0.0 want=0.0
|
| 9 |
+
[OK] format_valid(clfy): got=1.0 want=1.0
|
test_reward.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
sys.path.insert(0, ".")
|
| 3 |
+
|
| 4 |
+
from environment import ClinicalNoteScribeEnv, Action, SOAPNote
|
| 5 |
+
from environment.reward import (
|
| 6 |
+
compute_reward, _conciseness_bonus, _safe_language_score, _format_valid,
|
| 7 |
+
WORD_LIMIT, FREE_STEPS, STEP_PENALTY_RATE, ERROR_PENALTY_RATE,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
def check(label, got, want):
    """Print a one-line OK/FAIL report comparing *got* to *want*.

    Values are considered equal within an absolute tolerance of 1e-6.
    Returns the boolean comparison result.
    """
    passed = abs(got - want) < 1e-6
    status = "OK" if passed else "FAIL"
    print(f" [{status}] {label}: got={got} want={want}")
    return passed
|
| 15 |
+
|
| 16 |
+
# ---------------------------------------------------------------------------
# Fixtures: SOAP notes designed to exercise each reward sub-signal
# ---------------------------------------------------------------------------

# Well-formed, concise, safely-worded note — should max every sub-signal.
short_note = SOAPNote(
    subjective="Headache and runny nose for 5 days.",
    objective="BP 118/76, HR 72, afebrile, clear lungs.",
    assessment="Viral URI.",
    plan="DayQuil, fluids, rest. Follow up if fever develops.",
)
# One word past the conciseness limit → conciseness bonus should drop to 0.
long_note = SOAPNote(subjective=" ".join(["word"] * (WORD_LIMIT + 1)), objective="O", assessment="A", plan="P")
# "definitely" is unsafe-certainty phrasing → safe-language score should be 0.
unsafe_note = SOAPNote(subjective="Patient definitely has pneumonia.", objective="O", assessment="A", plan="P")
# Empty subjective field → format_valid should be 0 for a submit action.
empty_note = SOAPNote(subjective="", objective="O", assessment="A", plan="P")

submit_ok = Action(action_type="submit_note", soap_note=short_note)
submit_bad = Action(action_type="submit_note", soap_note=empty_note)
clarify = Action(action_type="request_clarify", clarify_question="fever?")

print("\n--- Sub-signal unit tests ---")
check("conciseness(short)", _conciseness_bonus(short_note), 1.0)
check("conciseness(long) ", _conciseness_bonus(long_note), 0.0)
check("safe_lang(clean) ", _safe_language_score(short_note), 1.0)
check("safe_lang(unsafe) ", _safe_language_score(unsafe_note), 0.0)
check("format_valid(ok) ", _format_valid(submit_ok), 1.0)
check("format_valid(bad) ", _format_valid(submit_bad), 0.0)
check("format_valid(clfy)", _format_valid(clarify), 1.0)

# Perfect grader, within the free-step allowance, no errors → full marks
# and each signal equal to its documented weight.
print("\n--- grader=1.0, steps=2, errors=0 → expect value=1.0 ---")
r = compute_reward(submit_ok, grader_score=1.0, step_count=2, errors_so_far=[])
check("value ", r.value, 1.0)
check("grader_score wt ", r.signals["grader_score"], 0.60)
check("conciseness wt ", r.signals["conciseness_bonus"], 0.10)
check("safe_lang wt ", r.signals["safe_language_score"], 0.15)
check("format_valid wt ", r.signals["format_valid"], 0.15)
check("step_penalty ", r.signals["step_penalty"], 0.0)
check("error_penalty ", r.signals["error_penalty"], 0.0)

# Two steps beyond the free allowance → deduction of 2 × STEP_PENALTY_RATE.
print("\n--- grader=1.0, steps=5 (+2 extra) → expect deduct 0.10 ---")
r2 = compute_reward(submit_ok, grader_score=1.0, step_count=5, errors_so_far=[])
check("step_penalty ", r2.signals["step_penalty"], -(2 * STEP_PENALTY_RATE))
check("value ", r2.value, round(1.0 - 2 * STEP_PENALTY_RATE, 4))

# Two accumulated errors → deduction of 2 × ERROR_PENALTY_RATE.
print("\n--- grader=1.0, steps=2, errors=2 → expect deduct 0.20 ---")
r3 = compute_reward(submit_ok, grader_score=1.0, step_count=2, errors_so_far=["e1", "e2"])
check("error_penalty ", r3.signals["error_penalty"], -(2 * ERROR_PENALTY_RATE))
check("value ", r3.value, round(1.0 - 2 * ERROR_PENALTY_RATE, 4))

# Everything bad at once — reward must clamp at the floor, never go negative.
print("\n--- all bad signals → expect value clamped to 0.0 ---")
bad_note = SOAPNote(subjective=" ".join(["word"] * 500) + " Patient definitely has cancer.", objective="", assessment="A", plan="P")
bad_act = Action(action_type="submit_note", soap_note=bad_note)
r4 = compute_reward(bad_act, grader_score=0.0, step_count=10, errors_so_far=["e1","e2","e3"])
check("value clamped ", r4.value, 0.0)

# End-to-end smoke test of the environment loop: a clarify action must not
# terminate the episode; a submit must terminate it with reward in [0, 1].
print("\n--- end-to-end env: clarify(step1) then submit(step2) ---")
env = ClinicalNoteScribeEnv()
env.reset("easy_routine_checkup")
_, rc, dc, _ = env.step(Action(action_type="request_clarify", clarify_question="did the patient report any fever?"))
check("clarify done=False", float(dc), 0.0)
_, rs, ds, _ = env.step(submit_ok)
check("submit done=True ", float(ds), 1.0)
assert 0.0 <= rs.value <= 1.0
print(f" Final value: {rs.value}")
# Private signals (leading underscore) are debug-only; hide them here.
print(f" Signals: { {k:v for k,v in rs.signals.items() if not k.startswith('_')} }")
print("\nAll done.")
|