| """FastAPI server exposing the OpenSleuth environment over HTTP. |
| |
| Two HTTP surfaces are served from this app: |
| |
| * The legacy OpenSleuth contract (``/health``, ``/functions``, ``/tasks``, |
| ``/reset``, ``/step``, ``/state/{episode_id}``, ``/probe_once``) used by the |
| in-flight trainer and eval harness. |
| * The OpenEnv-conformant sub-app mounted at ``/openenv/*`` (added in v0.5.0 |
| for hackathon conformance) -- exposes ``/openenv/reset``, ``/openenv/step``, |
| ``/openenv/state``, ``/openenv/health``, ``/openenv/metadata``, |
| ``/openenv/schema``, and the canonical ``/openenv/ws`` WebSocket. See |
| :mod:`opensleuth_env.openenv_adapter` and |
| https://github.com/meta-pytorch/OpenEnv (v0.2.3). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| import random |
| from typing import Optional |
|
|
| from fastapi import FastAPI, HTTPException, Query |
|
|
| from opensleuth_env import ( |
| BLACK_BOX_FUNCTIONS, |
| OpenSleuthEnv, |
| ProbeAction, |
| ResetRequest, |
| StepRequest, |
| StepResponse, |
| SubmitAction, |
| TaskCatalog, |
| ) |
| from opensleuth_env.task_catalog import TaskResolutionError |
|
|
# Configure process-wide logging once at import time so all handlers below share
# one format. NOTE(review): ``basicConfig`` at module import is a global side
# effect; confirm it does not clash with the embedding app's logging setup.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
log = logging.getLogger("opensleuth.server")


# Single FastAPI app plus one shared environment instance used by every route
# handler in this module. NOTE(review): handlers are sync, so FastAPI runs them
# in a threadpool -- assumes OpenSleuthEnv is safe for that; TODO confirm.
app = FastAPI(title="OpenSleuth Env", version="0.5.0")
env = OpenSleuthEnv()
|
|
|
|
| |
| |
| |
| |
| |
|
|
# Best-effort mount of the OpenEnv-conformant sub-app at /openenv. openenv-core
# is an optional dependency: any import or registration failure is logged and
# swallowed so the legacy routes below keep working regardless.
try:
    from openenv.core.env_server.http_server import HTTPEnvServer


    from opensleuth_env.openenv_adapter import (
        OPENENV_AVAILABLE,
        OpenSleuthAction,
        OpenSleuthEnvironment,
        OpenSleuthObservation,
    )


    if OPENENV_AVAILABLE:
        # Separate FastAPI instance so the OpenEnv routes get their own
        # OpenAPI schema, then mounted under the parent app.
        openenv_app = FastAPI(
            title="OpenSleuth (OpenEnv-conformant)",
            version="0.5.0",
            description=(
                "OpenEnv 0.2.x conformant surface for the OpenSleuth environment.\n\n"
                "See https://github.com/meta-pytorch/OpenEnv -- this sub-app implements"
                " the canonical reset/step/state/health/metadata/schema HTTP routes plus"
                " the /ws WebSocket session protocol."
            ),
        )
        # HTTPEnvServer owns the canonical reset/step/state/... routes; we pass
        # the environment *class* (not an instance) plus the action/observation
        # dataclasses it (de)serializes.
        _openenv_server = HTTPEnvServer(
            env=OpenSleuthEnvironment,
            action_cls=OpenSleuthAction,
            observation_cls=OpenSleuthObservation,
            max_concurrent_envs=8,
        )
        _openenv_server.register_routes(openenv_app)
        app.mount("/openenv", openenv_app)
        # NOTE(review): this logs the server class's module path, not an actual
        # openenv-core version string -- confirm that is intended.
        log.info("Mounted OpenEnv-conformant sub-app at /openenv (openenv-core %s)",
                 _openenv_server.__class__.__module__)
    else:
        log.warning("openenv-core not importable; /openenv/* will be unavailable.")
except Exception as e:
    # Broad on purpose: a broken optional integration must never prevent the
    # legacy OpenSleuth surface from starting.
    log.warning("Could not register OpenEnv sub-app: %s: %s", type(e).__name__, e)
|
|
|
|
@app.get("/health")
def health():
    """Liveness/readiness probe for the legacy OpenSleuth surface.

    Reports how many episodes the shared environment is currently tracking and
    the task hub's connectivity status.
    """
    payload = {"status": "ok"}
    # NOTE(review): reads the private ``_states`` mapping directly -- a public
    # episode-count accessor on OpenSleuthEnv would be cleaner.
    payload["episodes_tracked"] = len(env._states)
    payload["hub"] = env.catalog.hub_status()
    return payload
|
|
|
|
@app.get("/functions")
def list_functions(
    difficulty: Optional[str] = Query(
        None,
        description="Optional filter: easy / medium / hard. Used by the trainer for curriculum scheduling.",
    ),
):
    """List the built-in black-box functions, optionally filtered by difficulty.

    Returns ``{"functions": [...]}`` where each entry carries the task name,
    signature, description, difficulty tag (``None`` when the spec defines
    none), and the number of known edge cases.
    """

    def _row(spec):
        # getattr guards keep older/partial spec objects from raising.
        return {
            "name": spec.name,
            "signature": spec.signature,
            "description": spec.description,
            "difficulty": getattr(spec, "difficulty", None),
            "edge_case_count": len(getattr(spec, "edge_cases", []) or []),
            "source": "builtin",
        }

    specs = BLACK_BOX_FUNCTIONS.values()
    if difficulty is not None:
        specs = [s for s in specs if getattr(s, "difficulty", None) == difficulty]
    return {"functions": [_row(s) for s in specs]}
|
|
|
|
@app.get("/tasks")
def list_tasks(
    source: str = Query(
        "all",
        description="Filter by source: 'builtin', 'hub', or 'all' (default).",
    ),
    difficulty: Optional[str] = Query(None, description="Optional curriculum filter."),
):
    """Enumerate tasks from the builtin catalog, the hub, or both.

    An unrecognized ``source`` is rejected with HTTP 400.  The response embeds
    the hub status so callers can tell whether hub tasks were reachable.
    """
    # Dispatch table instead of an if/elif chain; keys are the accepted
    # (case-insensitive) source values.
    listers = {
        "builtin": env.catalog.list_builtin,
        "hub": env.catalog.list_hub,
        "all": env.catalog.list_all,
    }
    lister = listers.get(source.lower())
    if lister is None:
        raise HTTPException(
            status_code=400, detail="source must be one of: builtin, hub, all"
        )
    tasks = lister()
    if difficulty is not None:
        tasks = [task for task in tasks if task.get("difficulty") == difficulty]
    return {
        "tasks": tasks,
        "count": len(tasks),
        "hub": env.catalog.hub_status(),
    }
|
|
|
|
@app.post("/reset")
def reset(req: ResetRequest):
    """Start a new episode against a named builtin/hub task or ad-hoc code.

    Exactly one of two modes must be supplied: ``target_name`` (catalog lookup)
    or ``target_code`` together with ``target_function_name`` (ad-hoc target).
    Malformed requests and resolution failures surface as HTTP 400.
    """
    # Guard clauses: validate the request shape before touching the env.
    if not (req.target_name or req.target_code):
        raise HTTPException(
            status_code=400,
            detail="Either 'target_name' or ('target_code' + 'target_function_name') must be set.",
        )
    if req.target_code and not req.target_function_name:
        raise HTTPException(
            status_code=400,
            detail="'target_function_name' is required when 'target_code' is provided.",
        )
    try:
        return env.reset(
            target_name=req.target_name,
            seed=req.seed,
            max_steps=req.max_steps,
            target_code=req.target_code,
            target_function_name=req.target_function_name,
            edge_cases=req.edge_cases,
            fuzz_spec=req.fuzz_spec,
        )
    except ValueError as exc:
        # Resolution/validation failures from the env are client errors.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
|
|
|
|
@app.post("/step", response_model=StepResponse)
def step(req: StepRequest):
    """Apply one action to an episode; unknown episode ids yield HTTP 404."""
    try:
        result = env.step(req.episode_id, req.action)
    except KeyError as exc:
        # env.step raises KeyError for episode ids it is not tracking.
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    return result
|
|
|
|
@app.get("/state/{episode_id}")
def get_state(episode_id: str):
    """Return the tracked state for *episode_id*, or HTTP 404 if unknown."""
    state = env.get_state(episode_id)
    if state:
        return state
    raise HTTPException(status_code=404, detail=f"Unknown episode_id {episode_id!r}")
|
|
|
|
@app.post("/probe_once")
def probe_once(target_name: str, input_repr: str):
    """One-shot convenience: reset an episode for *target_name* and probe once.

    Args (query parameters):
        target_name: Name of a catalog task to resolve.
        input_repr: Python-literal ``repr`` string forwarded as a ProbeAction.

    Raises:
        HTTPException 400: if the target cannot be resolved (previously this
            escaped as an unhandled ``ValueError`` and returned HTTP 500).

    NOTE(review): each call leaves a tracked episode behind in the env (it is
    never cleaned up here) -- confirm whether the env expires episodes itself.
    """
    try:
        obs = env.reset(target_name=target_name)
    except ValueError as e:
        # Consistent with POST /reset: resolution failures are client errors.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return env.step(obs.episode_id, ProbeAction(input_repr=input_repr))
|
|
|
|
@app.get("/tasks/{name}/sample_inputs")
def sample_inputs(
    name: str,
    n: int = Query(8, ge=1, le=64, description="How many inputs to draw."),
    seed: int = Query(0, description="Deterministic seed for the fuzzer."),
):
    """Draw ``n`` deterministic sample inputs from the task's fuzzer.

    Resolves *name* through the catalog (404 on failure), seeds a fresh RNG,
    and returns each drawn value as a Python-literal ``repr`` string suitable
    for posting back to ``/step`` as a ``ProbeAction.input_repr``.  A fuzzer
    crash is reported as HTTP 500 with the exception details.
    """
    try:
        spec = env.catalog.resolve(target_name=name)
    except TaskResolutionError as err:
        raise HTTPException(status_code=404, detail=str(err)) from err

    try:
        # Fresh, seed-keyed RNG: the same (name, n, seed) always yields the
        # same inputs.
        drawn = spec.fuzzer(random.Random(seed), n)
    except Exception as err:
        raise HTTPException(
            status_code=500,
            detail=f"fuzzer for {name!r} failed: {type(err).__name__}: {err}",
        ) from err

    return {
        "name": name,
        "n": n,
        "seed": seed,
        "unpack_args": bool(getattr(spec, "unpack_args", False)),
        "inputs": [repr(item) for item in drawn],
    }
|
|