| from __future__ import annotations |
|
|
| from contextlib import asynccontextmanager |
| import os |
| from typing import Any |
|
|
| import uvicorn |
| from fastapi import FastAPI, HTTPException |
| from fastapi.middleware.cors import CORSMiddleware |
|
|
| from trenches_env.env import FogOfWarDiplomacyEnv |
| from trenches_env.model_runtime import build_entity_model_bindings |
| from trenches_env.models import ( |
| BenchmarkRunRequest, |
| BenchmarkRunResponse, |
| CreateSessionRequest, |
| IngestNewsRequest, |
| IngestNewsResponse, |
| LiveControlRequest, |
| ProviderDiagnosticsResponse, |
| ReactionLogEntry, |
| ResetEnvRequest, |
| ResetEnvResponse, |
| ResetSessionRequest, |
| ScenarioSummary, |
| SessionState, |
| SourceMonitorReport, |
| StepEnvRequest, |
| StepEnvResponse, |
| StepSessionRequest, |
| StepSessionResponse, |
| ) |
| from trenches_env.openenv_adapter import ( |
| OPENENV_CORE_AVAILABLE, |
| OpenEnvAdapter, |
| TrenchesOpenEnvEnvironment, |
| create_openenv_fastapi_app, |
| ) |
| from trenches_env.session_manager import SessionManager |
| from trenches_env.source_ingestion import SourceHarvester |
|
|
# Fallback CORS origin pattern, applied by _resolve_cors_settings() when neither
# TRENCHES_CORS_ALLOW_ORIGINS nor TRENCHES_CORS_ALLOW_ORIGIN_REGEX is configured.
# Matches http:// or https:// origins on localhost or 127.0.0.1 with an optional
# port. The IPv6 loopback (e.g. http://[::1]:3000) is not matched.
DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX = r"https?://(localhost|127\.0\.0\.1)(:\d+)?$"
|
|
|
|
| def _parse_csv_env(raw_value: str | None) -> list[str]: |
| if not raw_value: |
| return [] |
| return [item.strip() for item in raw_value.split(",") if item.strip()] |
|
|
|
|
def _resolve_cors_settings() -> dict[str, Any]:
    """Build ``CORSMiddleware`` keyword arguments from ``TRENCHES_CORS_*`` env vars.

    Environment variables read:
      * ``TRENCHES_CORS_ALLOW_ORIGINS``: comma-separated list of allowed
        origins. If it contains a ``*`` entry, every origin is allowed and
        credentials are disabled.
      * ``TRENCHES_CORS_ALLOW_ORIGIN_REGEX``: regex for allowed origins.
        Blank or whitespace-only values are treated as unset.
      * ``TRENCHES_CORS_ALLOW_CREDENTIALS``: credentials stay enabled (the
        default) unless this is ``0``/``false``/``no``/``off``
        (case-insensitive, surrounding whitespace ignored).

    When neither origins nor a regex are configured, the origin regex falls
    back to ``DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX``.

    Returns:
        A dict with ``allow_origins``, ``allow_origin_regex``,
        ``allow_credentials``, ``allow_methods`` and ``allow_headers`` keys,
        suitable for ``app.add_middleware(CORSMiddleware, **settings)``.
    """
    allow_origins = _parse_csv_env(os.getenv("TRENCHES_CORS_ALLOW_ORIGINS"))
    # Normalise blank values to None. Otherwise a whitespace-only value would be
    # compiled as a regex and would also suppress the local-dev default below.
    allow_origin_regex = (os.getenv("TRENCHES_CORS_ALLOW_ORIGIN_REGEX") or "").strip() or None

    if "*" in allow_origins:
        # Allow-all mode. Credentials are switched off because browsers reject
        # a wildcard Access-Control-Allow-Origin on credentialed requests.
        return {
            "allow_origins": ["*"],
            "allow_origin_regex": None,
            "allow_credentials": False,
            "allow_methods": ["*"],
            "allow_headers": ["*"],
        }

    if not allow_origins and allow_origin_regex is None:
        allow_origin_regex = DEFAULT_LOCAL_DEV_CORS_ORIGIN_REGEX

    credentials_flag = os.getenv("TRENCHES_CORS_ALLOW_CREDENTIALS", "true").strip().lower()
    allow_credentials = credentials_flag not in {"0", "false", "no", "off"}

    return {
        "allow_origins": allow_origins,
        "allow_origin_regex": allow_origin_regex,
        "allow_credentials": allow_credentials,
        "allow_methods": ["*"],
        "allow_headers": ["*"],
    }
|
|
|
|
def create_app(session_manager: SessionManager | None = None) -> FastAPI:
    """Build and configure the Trenches FastAPI application.

    Args:
        session_manager: Manager backing the session routes and the legacy
            OpenEnv tuple routes (``/reset``, ``/step``, ``/state``). When
            ``None``, a ``SessionManager`` is built around a
            ``FogOfWarDiplomacyEnv`` with ``SourceHarvester(auto_start=True)``
            and ``enable_source_warm_start()`` applied.

    Returns:
        A ``FastAPI`` app configured as follows:
          * CORS middleware built from ``_resolve_cors_settings()``.
          * A lifespan that starts the manager's background runner and shuts
            the manager down on exit.
          * Session, scenario, benchmark and legacy OpenEnv routes.
          * A native OpenEnv sub-app mounted at ``/openenv``, only when
            ``create_openenv_fastapi_app`` returns one.
    """
    if session_manager is not None:
        manager = session_manager
    else:
        manager = SessionManager(
            env=FogOfWarDiplomacyEnv(
                source_harvester=SourceHarvester(auto_start=True),
            ).enable_source_warm_start()
        )

    @asynccontextmanager
    async def lifespan(_: FastAPI):
        # Keep the background runner alive for the app's lifetime. The finally
        # clause shuts the manager down even if startup or serving raises.
        try:
            manager.start_background_runner()
            yield
        finally:
            manager.shutdown()

    # Resolve CORS exactly once so the middleware and /capabilities always
    # describe the same configuration, even if the environment changes later.
    cors_settings = _resolve_cors_settings()

    app = FastAPI(title="Trenches OpenEnv Backend", version="0.1.0", lifespan=lifespan)
    app.add_middleware(CORSMiddleware, **cors_settings)
    runtime = OpenEnvAdapter(session_manager=manager)

    # Native OpenEnv sub-app. Its factory builds fresh environments whose
    # SourceHarvester is created with auto_start=False.
    openenv_app = create_openenv_fastapi_app(
        lambda: TrenchesOpenEnvEnvironment(
            env=FogOfWarDiplomacyEnv(
                source_harvester=SourceHarvester(auto_start=False),
            ).enable_source_warm_start()
        )
    )
    # Only advertise a base path that is actually mounted.
    openenv_base_path = None
    if openenv_app is not None:
        openenv_base_path = "/openenv"
        app.mount(openenv_base_path, openenv_app)

    # NOTE(review): the async handlers below call synchronous manager/runtime
    # methods directly on the event loop. If any of them block (for example
    # refresh_session_sources with force=True, or run_benchmark), consider
    # moving that work to a threadpool, after confirming the manager is
    # thread-safe.

    # Liveness probe.
    @app.get("/healthz")
    async def healthz() -> dict[str, str]:
        return {"status": "ok"}

    # Feature discovery: model bindings, available API surfaces, and the
    # CORS configuration resolved once at app creation.
    @app.get("/capabilities")
    async def capabilities() -> dict[str, Any]:
        return {
            "model_bindings": {
                agent_id: binding.model_dump(mode="json")
                for agent_id, binding in build_entity_model_bindings().items()
            },
            "session_api": True,
            "legacy_openenv_tuple_api": True,
            "native_openenv_api": OPENENV_CORE_AVAILABLE,
            "native_openenv_base_path": openenv_base_path,
            "cors": {
                "allow_origins": cors_settings["allow_origins"],
                "allow_origin_regex": cors_settings["allow_origin_regex"],
                "allow_credentials": cors_settings["allow_credentials"],
            },
        }

    # Create a session. ValueError maps to 400, like reset_session, which
    # accepts the same options.
    @app.post("/sessions", response_model=SessionState)
    async def create_session(request: CreateSessionRequest) -> SessionState:
        try:
            return manager.create_session(
                seed=request.seed,
                training_agent=request.training_agent,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Reset an existing session: KeyError -> 404, ValueError -> 400.
    @app.post("/sessions/{session_id}/reset", response_model=SessionState)
    async def reset_session(session_id: str, request: ResetSessionRequest) -> SessionState:
        try:
            return manager.reset_session(
                session_id=session_id,
                seed=request.seed,
                training_agent=request.training_agent,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.get("/scenarios", response_model=list[ScenarioSummary])
    async def list_scenarios() -> list[ScenarioSummary]:
        return manager.list_scenarios()

    @app.post("/benchmarks/run", response_model=BenchmarkRunResponse)
    async def run_benchmark(request: BenchmarkRunRequest) -> BenchmarkRunResponse:
        try:
            return manager.run_benchmark(request)
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.get("/sessions/{session_id}", response_model=SessionState)
    async def get_session(session_id: str) -> SessionState:
        try:
            return manager.get_session(session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    # Explicit refresh requested by the client, so force=True.
    @app.post("/sessions/{session_id}/sources/refresh", response_model=SessionState)
    async def refresh_session_sources(session_id: str) -> SessionState:
        try:
            return manager.refresh_session_sources(session_id=session_id, force=True)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    @app.get("/sessions/{session_id}/sources/monitor", response_model=SourceMonitorReport)
    async def source_monitor(session_id: str) -> SourceMonitorReport:
        try:
            return manager.source_monitor(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    @app.get("/sessions/{session_id}/reactions", response_model=list[ReactionLogEntry])
    async def reaction_log(session_id: str) -> list[ReactionLogEntry]:
        try:
            return manager.reaction_log(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    @app.get("/sessions/{session_id}/providers/diagnostics", response_model=ProviderDiagnosticsResponse)
    async def provider_diagnostics(session_id: str) -> ProviderDiagnosticsResponse:
        try:
            return manager.provider_diagnostics(session_id=session_id)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc

    @app.post("/sessions/{session_id}/news", response_model=IngestNewsResponse)
    async def ingest_news(session_id: str, request: IngestNewsRequest) -> IngestNewsResponse:
        try:
            return manager.ingest_news(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.post("/sessions/{session_id}/live", response_model=SessionState)
    async def set_live_mode(session_id: str, request: LiveControlRequest) -> SessionState:
        try:
            return manager.set_live_mode(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    @app.post("/sessions/{session_id}/step", response_model=StepSessionResponse)
    async def step_session(session_id: str, request: StepSessionRequest) -> StepSessionResponse:
        try:
            return manager.step_session(session_id=session_id, request=request)
        except KeyError as exc:
            raise HTTPException(status_code=404, detail=f"Unknown session: {session_id}") from exc
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # Legacy OpenEnv tuple API: reset the adapter-owned session. ValueError
    # maps to 400, consistent with /step.
    @app.post("/reset", response_model=ResetEnvResponse)
    async def reset_env(request: ResetEnvRequest) -> ResetEnvResponse:
        try:
            observations, info = runtime.reset(
                seed=request.seed,
                training_stage=request.training_stage,
                max_turns=request.max_turns,
                scenario_id=request.scenario_id,
                replay_id=request.replay_id,
                replay_start_index=request.replay_start_index,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return ResetEnvResponse(observations=observations, info=info)

    # Legacy OpenEnv tuple API: advance the adapter-owned session one step.
    @app.post("/step", response_model=StepEnvResponse)
    async def step_env(request: StepEnvRequest) -> StepEnvResponse:
        try:
            observations, rewards, terminated, truncated, info = runtime.step(
                actions=request.actions,
                predictions=request.predictions,
                external_signals=request.external_signals,
            )
        except ValueError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc
        return StepEnvResponse(
            observations=observations,
            rewards=rewards,
            terminated=terminated,
            truncated=truncated,
            info=info,
        )

    # Legacy OpenEnv tuple API: 404 until /reset has created a runtime session.
    @app.get("/state", response_model=SessionState)
    async def state_env() -> SessionState:
        session = runtime.state()
        if session is None:
            raise HTTPException(status_code=404, detail="No active OpenEnv runtime session.")
        return session

    return app
|
|
|
|
# Module-level ASGI application. run() hands uvicorn the import string
# "trenches_env.server:app" rather than this object.
# NOTE: this runs at import time and, with no argument, builds the default
# SessionManager, whose SourceHarvester is created with auto_start=True.
app = create_app()
|
|
|
|
def run(host: str = "0.0.0.0", port: int = 8000, *, reload: bool = False) -> None:
    """Serve the Trenches backend with uvicorn.

    Calling ``run()`` with no arguments behaves as before: it binds
    ``0.0.0.0:8000`` with auto-reload disabled.

    Args:
        host: Interface address to bind (``0.0.0.0`` = all IPv4 interfaces).
        port: TCP port to listen on.
        reload: Enable uvicorn's auto-reload (intended for development).
    """
    # The app is passed as an import string, not an object, because uvicorn
    # needs an import string to re-import the app in reload/worker modes.
    uvicorn.run("trenches_env.server:app", host=host, port=port, reload=reload)
|
|
|
|
# Start the server with run()'s defaults when this file is executed directly.
# NOTE(review): when executed as __main__, the module-level `app = create_app()`
# has already run once, and uvicorn then imports "trenches_env.server" again,
# which is a separate module object and builds a second app/manager.
# Prefer the package entry point or `uvicorn trenches_env.server:app`. Confirm.
if __name__ == "__main__":
    run()
|
|