| """FastAPI routes for the Clinical Note Scribe environment.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import time |
| from typing import Any, Optional |
|
|
| from fastapi import APIRouter, HTTPException |
| from pydantic import BaseModel, Field, ValidationError |
|
|
| from environment.models import Action, EnvironmentState, Observation, Reward |
| from environment.env import ClinicalNoteScribeEnv |
|
|
| logger = logging.getLogger("clinical_note_scribe.server") |
| _env = ClinicalNoteScribeEnv() |
| router = APIRouter() |
|
|
|
|
| def _log(event: str, **kw: Any) -> None: |
| logger.info(json.dumps({"event": event, "timestamp": time.time(), **kw}, default=str)) |
|
|
|
|
| class ResetRequest(BaseModel): |
| task_id: Optional[str] = Field(None, description="Task to load. Defaults to first registered task.") |
|
|
|
|
| class StepResponse(BaseModel): |
| observation: Observation |
| reward: Reward |
| done: bool |
| info: dict[str, Any] = Field(default_factory=dict) |
|
|
|
|
| class HealthResponse(BaseModel): |
| status: str = "ok" |
|
|
|
|
| @router.post("/reset", response_model=Observation, summary="Reset and start a new episode") |
| async def reset(body: Optional[ResetRequest] = None) -> Observation: |
| task_id = body.task_id if body else None |
| _log("START", endpoint="/reset", task_id=task_id) |
| try: |
| return _env.reset(task_id=task_id) |
| except ValueError as exc: |
| raise HTTPException(status_code=400, detail=str(exc)) |
|
|
|
|
| @router.post("/step", response_model=StepResponse, summary="Submit an action") |
| async def step(payload: dict[str, Any]) -> StepResponse: |
| try: |
| action = Action(**payload) |
| except (ValidationError, TypeError) as exc: |
| _log("STEP", endpoint="/step", action_type="invalid", error=str(exc)) |
| error_msg = f"Invalid action payload: {exc}" |
| _env._errors_so_far.append(error_msg) |
| _env._step_count += 1 |
| return StepResponse( |
| observation=_env._obs(), |
| reward=Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": error_msg}), |
| done=False, info={"error": error_msg}, |
| ) |
|
|
| _log("STEP", endpoint="/step", action_type=action.action_type) |
| try: |
| obs, reward, done, info = _env.step(action) |
| except RuntimeError as exc: |
| raise HTTPException(status_code=409, detail=str(exc)) |
|
|
| if done: |
| _log("END", endpoint="/step", final_score=reward.value) |
| return StepResponse(observation=obs, reward=reward, done=done, info=info) |
|
|
|
|
| @router.get("/state", response_model=EnvironmentState, summary="Inspect environment state") |
| async def state() -> EnvironmentState: |
| return _env.state() |
|
|
|
|
| @router.get("/health", response_model=HealthResponse, summary="Liveness probe") |
| async def health() -> HealthResponse: |
| return HealthResponse() |
|
|