Aman Khare
Optimize codebase + add minimalist frontend
8b7bdb7
"""FastAPI routes for the Clinical Note Scribe environment."""
from __future__ import annotations
import json
import logging
import time
from typing import Any, Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field, ValidationError
from environment.models import Action, EnvironmentState, Observation, Reward
from environment.env import ClinicalNoteScribeEnv
# Module-level singletons shared by every route defined below.
router = APIRouter()
logger = logging.getLogger("clinical_note_scribe.server")
# A single environment instance backs all routes, so episode state lives
# here between requests rather than per-request.
_env = ClinicalNoteScribeEnv()
def _log(event: str, **kw: Any) -> None:
    """Emit one structured log record, serialised as a single JSON line, at INFO.

    The record starts with ``event`` and a Unix ``timestamp``; extra keyword
    arguments are merged in afterwards, so they win on key clashes. Values that
    ``json`` cannot serialise natively are rendered with ``str``.
    """
    record: dict[str, Any] = {"event": event, "timestamp": time.time()}
    record.update(kw)
    logger.info(json.dumps(record, default=str))
class ResetRequest(BaseModel):
    """Optional JSON body accepted by ``POST /reset``."""

    # Left as None when the client does not name a task.
    task_id: Optional[str] = Field(
        default=None,
        description="Task to load. Defaults to first registered task.",
    )
class StepResponse(BaseModel):
    """Payload returned by ``POST /step``.

    Bundles the observation after the step, the reward for the submitted
    action, the episode ``done`` flag, and a free-form ``info`` mapping.
    """

    # Field order is kept as-is: it determines the key order of the JSON body.
    observation: Observation
    reward: Reward
    done: bool
    # Fresh empty dict per instance when the caller supplies no extra info.
    info: dict[str, Any] = Field(default_factory=dict)
class HealthResponse(BaseModel):
    """Body of the ``GET /health`` liveness response."""

    status: str = Field(default="ok")
@router.post("/reset", response_model=Observation, summary="Reset and start a new episode")
async def reset(body: Optional[ResetRequest] = None) -> Observation:
    """Start a new episode on the shared environment and return its first observation.

    Args:
        body: Optional request body. When it is omitted, or its ``task_id`` is
            null, ``task_id=None`` is forwarded to ``_env.reset`` and the
            environment chooses the task.

    Returns:
        The observation produced by ``_env.reset``.

    Raises:
        HTTPException: 400 when ``_env.reset`` raises ``ValueError``
            (presumably an unknown ``task_id`` — confirm in the environment).
    """
    # Explicit None check: we only care whether a body was sent at all.
    task_id = body.task_id if body is not None else None
    _log("START", endpoint="/reset", task_id=task_id)
    try:
        return _env.reset(task_id=task_id)
    except ValueError as exc:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
@router.post("/step", response_model=StepResponse, summary="Submit an action")
async def step(payload: dict[str, Any]) -> StepResponse:
    """Apply one action to the current episode of the shared environment.

    The body is taken as a raw JSON object and validated into ``Action`` here,
    rather than declared as ``Action``, so a malformed action does not get
    FastAPI's automatic 422. Instead it is:

    - appended to the environment's error list;
    - counted as a step;
    - answered with a zero-value reward and ``done=False``.

    This presumably penalises malformed actions like other agent mistakes.

    Args:
        payload: The JSON object sent by the client, expected to match ``Action``.

    Returns:
        A ``StepResponse`` with the observation, reward, done flag and info.

    Raises:
        HTTPException: 409 when ``_env.step`` raises ``RuntimeError`` for a
            well-formed action.
    """
    try:
        action = Action(**payload)
    except (ValidationError, TypeError) as exc:
        _log("STEP", endpoint="/step", action_type="invalid", error=str(exc))
        error_msg = f"Invalid action payload: {exc}"
        # NOTE(review): this path mutates private env state directly and does
        # not go through _env.step, so it skips the RuntimeError -> 409 guard
        # the valid path has (e.g. before /reset or after done). Confirm intended.
        _env._errors_so_far.append(error_msg)
        _env._step_count += 1
        return StepResponse(
            observation=_env._obs(),
            reward=Reward(value=0.0, signals={"error": 1.0}, done=False, info={"error": error_msg}),
            done=False,
            info={"error": error_msg},
        )

    _log("STEP", endpoint="/step", action_type=action.action_type)
    try:
        obs, reward, done, info = _env.step(action)
    except RuntimeError as exc:
        # Chain the cause so server-side tracebacks keep the original error.
        raise HTTPException(status_code=409, detail=str(exc)) from exc

    if done:
        _log("END", endpoint="/step", final_score=reward.value)
    return StepResponse(observation=obs, reward=reward, done=done, info=info)
@router.get("/state", response_model=EnvironmentState, summary="Inspect environment state")
async def state() -> EnvironmentState:
    """Return what ``_env.state()`` reports for the shared environment."""
    current = _env.state()
    return current
@router.get("/health", response_model=HealthResponse, summary="Liveness probe")
async def health() -> HealthResponse:
    """Answer liveness checks; the body is always ``{"status": "ok"}``."""
    return HealthResponse(status="ok")