Spaces:
Running
Running
| """OpenEnv-compatible wrapper around local env service. | |
| The wrapper intentionally exposes meaningful clinician-facing tool methods for | |
| LLM policy training instead of a single opaque ``step(action)`` interface. | |
| """ | |
| from __future__ import annotations | |
| from typing import Any, Literal | |
| from app.env.client import PolyGuardEnvClient | |
| try: | |
| from openenv import GenericEnvClient | |
| except Exception: # noqa: BLE001 | |
| GenericEnvClient = None # type: ignore[assignment] | |
class LocalOpenEnvWrapper:
    """Clinician-facing tool API over the PolyGuard environment service.

    Episode control (``reset`` / ``step`` / ``state``) goes through an OpenEnv
    ``GenericEnvClient`` sync client when the ``openenv`` package is importable
    and connecting succeeds; otherwise it falls back to the project's
    ``PolyGuardEnvClient`` HTTP client. The diagnostic helpers (``trace``,
    ``legal_actions``, ``reward_breakdown``, ``uncertainty``) always use the
    HTTP client.

    NOTE(review): while the OpenEnv client is active, ``step`` and
    ``legal_actions`` travel over different transports — confirm the server
    shares episode state between them.

    Instances can be used as context managers; ``close`` is called on exit.
    """

    # Maps the public ``adjust_dose`` direction to (action_type, dose_bucket).
    _DOSE_DIRECTIONS = {
        "increase": ("INCREASE_DOSE_BUCKET", "HIGH"),
        "reduce": ("REDUCE_DOSE_BUCKET", "LOW"),
        "hold": ("DOSE_HOLD", "HOLD"),
    }
    # Maps the public ``request_review`` review_type to its action_type.
    _REVIEW_ACTIONS = {
        "pharmacist": "REQUEST_PHARMACIST_REVIEW",
        "specialist": "REQUEST_SPECIALIST_REVIEW",
    }

    def __init__(self, base_url: str = "http://127.0.0.1:8100") -> None:
        """Create the HTTP client and, when possible, a connected OpenEnv client.

        Args:
            base_url: Root URL of the environment service; both clients use it.
        """
        self.http_client = PolyGuardEnvClient(base_url=base_url)
        self.base_url = base_url
        self._sync_client: Any = None
        if GenericEnvClient is not None:
            self._sync_client = self._connect_sync_client(base_url)

    @staticmethod
    def _connect_sync_client(base_url: str) -> Any:
        """Return a connected OpenEnv sync client, or ``None`` on any failure.

        A client that was created but failed to connect is closed (best effort)
        so the half-open connection is not leaked.
        """
        client: Any = None
        try:
            client = GenericEnvClient(base_url=base_url).sync()
            client.connect()
        except Exception:  # noqa: BLE001 - any failure means "fall back to HTTP"
            if client is not None:
                try:
                    client.close()
                except Exception:  # noqa: BLE001
                    pass
            return None
        return client

    @staticmethod
    def _result_to_dict(result: Any) -> dict[str, Any]:
        """Flatten an OpenEnv reset/step result into an observation/reward/done dict."""
        return {
            "observation": result.observation,
            "reward": result.reward,
            "done": result.done,
        }

    def reset(self, **kwargs: Any) -> dict[str, Any]:
        """Start a new episode; ``kwargs`` are forwarded unchanged to the active client."""
        if self._sync_client is None:
            return self.http_client.reset(**kwargs)
        return self._result_to_dict(self._sync_client.reset(**kwargs))

    def step(self, action: dict[str, Any]) -> dict[str, Any]:
        """Submit one action payload to the active client and return its result."""
        if self._sync_client is None:
            return self.http_client.step(action)
        return self._result_to_dict(self._sync_client.step(action))

    def state(self) -> dict[str, Any]:
        """Return the current environment state from the active client.

        NOTE(review): the OpenEnv path returns ``state()`` unchanged, while
        ``inspect_regimen`` assumes a dict — confirm the client returns one.
        """
        if self._sync_client is None:
            return self.http_client.state()
        return self._sync_client.state()

    def trace(self) -> list[dict[str, Any]]:
        """Return the episode trace (always via the HTTP client)."""
        return self.http_client.trace()

    def legal_actions(self) -> list[dict[str, Any]]:
        """Return the currently legal candidate actions (always via the HTTP client)."""
        return self.http_client.legal_actions()

    def reward_breakdown(self) -> dict[str, Any]:
        """Return the reward breakdown (always via the HTTP client)."""
        return self.http_client.reward_breakdown()

    def uncertainty(self) -> dict[str, Any]:
        """Return the uncertainty report (always via the HTTP client)."""
        return self.http_client.uncertainty()

    def inspect_regimen(self) -> dict[str, Any]:
        """Return a compact clinical snapshot of the active case.

        A missing or ``null`` ``patient`` / ``medications`` entry is treated
        as empty, so the snapshot can always be built.
        """
        state = self.state()
        # `or` (not a .get default) also covers keys present with a None value.
        patient = state.get("patient") or {}
        meds = patient.get("medications") or []
        return {
            "patient_id": patient.get("patient_id"),
            "age": patient.get("age"),
            "comorbidities": patient.get("comorbidities", []),
            "medication_count": len(meds),
            "medications": meds,
            "risk_summary": state.get("risk_summary", {}),
            "burden_score": state.get("burden_score"),
            "step_count": state.get("step_count"),
            "max_steps": state.get("max_steps"),
        }

    def evaluate_candidate(self, candidate_id: str) -> dict[str, Any]:
        """Look up a legal candidate action by ``candidate_id``.

        Returns:
            The matching candidate dict exactly as ``legal_actions`` returned it
            (no ``found`` key is added, so it can be passed on unchanged), or
            ``{"candidate_id": candidate_id, "found": False}`` when no
            candidate matches.
        """
        return next(
            (c for c in self.legal_actions() if c.get("candidate_id") == candidate_id),
            {"candidate_id": candidate_id, "found": False},
        )

    def _execute_action(
        self,
        mode: str,
        action_type: str,
        target_drug: str | None = None,
        replacement_drug: str | None = None,
        dose_bucket: str = "NA",
        taper_days: int | None = None,
        monitoring_plan: str | None = None,
        candidate_id: str = "cand_manual",
        confidence: float = 0.65,
        rationale_brief: str = "tool_action",
    ) -> dict[str, Any]:
        """Build the full action payload from keyword fields and pass it to ``step``."""
        payload = {
            "mode": mode,
            "action_type": action_type,
            "target_drug": target_drug,
            "replacement_drug": replacement_drug,
            "dose_bucket": dose_bucket,
            "taper_days": taper_days,
            "monitoring_plan": monitoring_plan,
            "candidate_id": candidate_id,
            "confidence": confidence,
            "rationale_brief": rationale_brief,
        }
        return self.step(payload)

    def stop_drug(
        self,
        target_drug: str,
        taper_days: int | None = None,
        candidate_id: str = "cand_stop_tool",
    ) -> dict[str, Any]:
        """Issue a STOP_DRUG action for a single medication."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="STOP_DRUG",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"stop_drug:{target_drug}",
        )

    def substitute_drug(
        self,
        target_drug: str,
        replacement_drug: str,
        candidate_id: str = "cand_substitute_tool",
    ) -> dict[str, Any]:
        """Issue a SUBSTITUTE_WITHIN_CLASS action replacing ``target_drug``."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="SUBSTITUTE_WITHIN_CLASS",
            target_drug=target_drug,
            replacement_drug=replacement_drug,
            candidate_id=candidate_id,
            rationale_brief=f"substitute:{target_drug}->{replacement_drug}",
        )

    def start_taper(
        self,
        target_drug: str,
        taper_days: int = 14,
        candidate_id: str = "cand_taper_start_tool",
    ) -> dict[str, Any]:
        """Issue a TAPER_INITIATE action (default taper length 14 days)."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="TAPER_INITIATE",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"taper_start:{target_drug}",
        )

    def continue_taper(
        self,
        target_drug: str,
        taper_days: int = 7,
        candidate_id: str = "cand_taper_continue_tool",
    ) -> dict[str, Any]:
        """Issue a TAPER_CONTINUE action (default taper length 7 days)."""
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="TAPER_CONTINUE",
            target_drug=target_drug,
            taper_days=taper_days,
            candidate_id=candidate_id,
            rationale_brief=f"taper_continue:{target_drug}",
        )

    def adjust_dose(
        self,
        target_drug: str,
        direction: Literal["increase", "reduce", "hold"],
        candidate_id: str = "cand_adjust_dose_tool",
    ) -> dict[str, Any]:
        """Adjust the dose bucket of ``target_drug`` in an explicit direction.

        Raises:
            ValueError: if ``direction`` is not "increase", "reduce" or "hold".
                ``Literal`` is not enforced at runtime, and silently mapping an
                unknown direction to a hold would submit a different action.
        """
        spec = self._DOSE_DIRECTIONS.get(direction)
        if spec is None:
            raise ValueError(
                f"direction must be one of {sorted(self._DOSE_DIRECTIONS)}, got {direction!r}"
            )
        action_type, dose_bucket = spec
        return self._execute_action(
            mode="DOSE_OPT",
            action_type=action_type,
            target_drug=target_drug,
            dose_bucket=dose_bucket,
            candidate_id=candidate_id,
            rationale_brief=f"adjust_dose:{direction}:{target_drug}",
        )

    def request_review(
        self,
        review_type: Literal["pharmacist", "specialist"] = "specialist",
        candidate_id: str = "cand_review_tool",
    ) -> dict[str, Any]:
        """Request human review when uncertainty or legality concerns are high.

        Raises:
            ValueError: if ``review_type`` is not "pharmacist" or "specialist".
        """
        action_type = self._REVIEW_ACTIONS.get(review_type)
        if action_type is None:
            raise ValueError(
                f"review_type must be one of {sorted(self._REVIEW_ACTIONS)}, got {review_type!r}"
            )
        return self._execute_action(
            mode="ABSTAIN_REVIEW",
            action_type=action_type,
            candidate_id=candidate_id,
            rationale_brief=f"request_review:{review_type}",
        )

    def finish_case(self, candidate_id: str = "cand_finish_tool") -> dict[str, Any]:
        """Submit a conservative KEEP_REGIMEN action intended to close the episode.

        Whether the episode actually ends is decided by the server.
        """
        return self._execute_action(
            mode="REGIMEN_OPT",
            action_type="KEEP_REGIMEN",
            candidate_id=candidate_id,
            rationale_brief="finish_case",
        )

    def close(self) -> None:
        """Close the OpenEnv client, if any (best effort, idempotent).

        The client reference is dropped before closing, so a second call is a
        no-op and later calls fall back to the HTTP client instead of hitting
        a closed connection.
        NOTE(review): ``http_client`` is not closed here — confirm whether
        ``PolyGuardEnvClient`` holds resources that need releasing.
        """
        client, self._sync_client = self._sync_client, None
        if client is None:
            return
        try:
            client.close()
        except Exception:  # noqa: BLE001 - shutdown is best effort
            pass

    def __enter__(self) -> LocalOpenEnvWrapper:
        """Return ``self`` so the wrapper can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the wrapper on leaving the ``with`` block; exceptions propagate."""
        self.close()