Spaces:
Running
Running
| """ | |
| Stress test simulating large-scale agentic RL with AWM environments. | |
| Simulates one RL step: 2000 environments reset in parallel, each runs a | |
| multi-turn episode (random tool calls with LLM-like latency), then closes. | |
| Phases per session: | |
| 1. connect + reset — env startup | |
| 2. list_tools — tool discovery | |
| 3. N turns of tool calls — simulate multi-turn agent interaction | |
| (random tool, empty args, random "thinking" delay between turns) | |
| 4. done + close — episode end | |
| Usage: | |
| # Terminal 1: Start server | |
| PYTHONPATH=src:envs uv run uvicorn \ | |
| envs.agent_world_model_env.server.app:app \ | |
| --host 0.0.0.0 --port 8899 | |
| # Terminal 2: Run RL simulation (default 2000 envs) | |
| PYTHONPATH=src:envs uv run python \ | |
| envs/agent_world_model_env/example_stress_test.py | |
| # Custom scale | |
| PYTHONPATH=src:envs uv run python \ | |
| envs/agent_world_model_env/example_stress_test.py \ | |
| --scale 500 --concurrency 100 --min-turns 1 --max-turns 3 | |
| """ | |
| import argparse | |
| import asyncio | |
| import json | |
| import logging | |
| import os | |
| import random | |
| import statistics | |
| import sys | |
| import time | |
| from dataclasses import dataclass, field | |
| import httpx | |
| import psutil | |
| from openenv.core.env_server.mcp_types import CallToolAction, ListToolsAction | |
| from agent_world_model_env import AWMEnv | |
# Console logging: concise timestamped lines for live progress output.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger("rl_stress")

# Defaults below are module-level so helpers can read them; main() overwrites
# BASE_URL/CLIENT_TIMEOUT/MIN_TURNS/MAX_TURNS/LLM_THINK_* from CLI args.
BASE_URL = "http://localhost:8899"
CLIENT_TIMEOUT: float = 600.0

# Scenarios to cycle through (sessions are assigned round-robin)
SCENARIOS = [
    "e_commerce_33",
    "inventory_management_7",
    "document_management_5",
    "billing_payments_3",
    "hris_employee_management_1",
]

# RL simulation defaults
MIN_TURNS = 3
MAX_TURNS = 20

# Simulate LLM rollout time: uniform [min, max] seconds
LLM_THINK_MIN = 1.0
LLM_THINK_MAX = 20.0
| # --------------------------------------------------------------------------- | |
| # Data classes | |
| # --------------------------------------------------------------------------- | |
| class SessionResult: | |
| session_id: int | |
| scenario: str | |
| task_idx: int | |
| num_turns: int # planned turns | |
| turns_completed: int = 0 | |
| connect_s: float = 0.0 | |
| reset_s: float = 0.0 | |
| list_tools_s: float = 0.0 | |
| tool_call_latencies: list[float] = field(default_factory=list) | |
| done_s: float = 0.0 | |
| total_s: float = 0.0 | |
| success: bool = False | |
| error: str | None = None | |
| num_tools: int = 0 | |
| tools_discovered: list[str] = field(default_factory=list) | |
| # --------------------------------------------------------------------------- | |
| # System resource monitor | |
| # --------------------------------------------------------------------------- | |
| class ResourceMonitor: | |
| """Periodically samples CPU and memory in a background task.""" | |
| def __init__(self, interval: float = 2.0): | |
| self._interval = interval | |
| self._samples: list[dict] = [] | |
| self._task: asyncio.Task | None = None | |
| self._process = psutil.Process(os.getpid()) | |
| self._server_pid: int | None = None | |
| def start(self, server_pid: int | None = None): | |
| self._server_pid = server_pid | |
| self._task = asyncio.create_task(self._loop()) | |
| async def stop(self): | |
| if self._task: | |
| self._task.cancel() | |
| try: | |
| await self._task | |
| except asyncio.CancelledError: | |
| pass | |
| async def _loop(self): | |
| while True: | |
| sample = { | |
| "time": time.monotonic(), | |
| "system_cpu_pct": psutil.cpu_percent(interval=0), | |
| "system_mem_pct": psutil.virtual_memory().percent, | |
| "system_mem_used_gb": round( | |
| psutil.virtual_memory().used / (1024**3), 2 | |
| ), | |
| "client_mem_mb": round(self._process.memory_info().rss / (1024**2), 1), | |
| } | |
| if self._server_pid: | |
| try: | |
| server_proc = psutil.Process(self._server_pid) | |
| children = server_proc.children(recursive=True) | |
| server_mem = server_proc.memory_info().rss | |
| for child in children: | |
| try: | |
| server_mem += child.memory_info().rss | |
| except (psutil.NoSuchProcess, psutil.AccessDenied): | |
| pass | |
| sample["server_tree_mem_mb"] = round(server_mem / (1024**2), 1) | |
| sample["server_children"] = len(children) | |
| except (psutil.NoSuchProcess, psutil.AccessDenied): | |
| pass | |
| self._samples.append(sample) | |
| await asyncio.sleep(self._interval) | |
| def summary(self) -> dict: | |
| if not self._samples: | |
| return {} | |
| cpu_vals = [s["system_cpu_pct"] for s in self._samples] | |
| mem_vals = [s["system_mem_used_gb"] for s in self._samples] | |
| client_mem = [s["client_mem_mb"] for s in self._samples] | |
| result = { | |
| "samples": len(self._samples), | |
| "cpu_pct": { | |
| "mean": round(statistics.mean(cpu_vals), 1), | |
| "max": round(max(cpu_vals), 1), | |
| }, | |
| "system_mem_gb": { | |
| "min": round(min(mem_vals), 2), | |
| "max": round(max(mem_vals), 2), | |
| }, | |
| "client_mem_mb": { | |
| "min": round(min(client_mem), 1), | |
| "max": round(max(client_mem), 1), | |
| }, | |
| } | |
| server_mem = [ | |
| s["server_tree_mem_mb"] for s in self._samples if "server_tree_mem_mb" in s | |
| ] | |
| if server_mem: | |
| result["server_tree_mem_mb"] = { | |
| "min": round(min(server_mem), 1), | |
| "max": round(max(server_mem), 1), | |
| } | |
| server_children = [ | |
| s["server_children"] for s in self._samples if "server_children" in s | |
| ] | |
| if server_children: | |
| result["server_subprocess_peak"] = max(server_children) | |
| return result | |
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def latency_stats(values: list[float]) -> dict:
    """Summarize a latency sample: count, min/max, mean, p50/p90/p99.

    Returns an empty dict for an empty sample. Percentiles use the
    floor-index convention on the sorted sample, all rounded to 3 places.
    """
    if not values:
        return {}
    ordered = sorted(values)
    n = len(ordered)

    def pct(frac: float) -> float:
        return ordered[int(n * frac)]

    return {
        "count": n,
        "min": round(ordered[0], 3),
        "p50": round(ordered[n // 2], 3),
        "p90": round(pct(0.9), 3),
        "p99": round(pct(0.99), 3),
        "max": round(ordered[-1], 3),
        "mean": round(statistics.mean(ordered), 3),
    }
async def check_server(url: str) -> int | None:
    """Verify the server answers HTTP, then best-effort locate its PID.

    Raises (httpx error) if the server is unreachable. PID discovery scans
    process command lines for a uvicorn process mentioning the URL's port;
    returns None when no match is found.
    """
    async with httpx.AsyncClient() as client:
        response = await client.get(f"{url}/docs", timeout=10)
        response.raise_for_status()
    # Try to find server PID by matching the port in the URL
    from urllib.parse import urlparse

    port = str(urlparse(url).port or "8899")
    for proc in psutil.process_iter(["pid", "cmdline"]):
        try:
            cmd = " ".join(proc.info["cmdline"] or [])
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue
        if "uvicorn" in cmd and port in cmd:
            return proc.info["pid"]
    return None
async def fetch_server_stats(url: str) -> dict | None:
    """Best-effort fetch of the server's /stats JSON; None on any failure."""
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(f"{url}/stats", timeout=5)
            return response.json()
    except Exception:
        # Stats are purely informational; never let them break the test.
        return None
| # --------------------------------------------------------------------------- | |
| # Single session: full RL episode | |
| # --------------------------------------------------------------------------- | |
@dataclass
class ProgressCounters:
    """Shared counters updated by each session for live progress reporting.

    Decorated with @dataclass so each instance owns its counters as
    instance attributes (matching SessionResult), rather than relying on
    the class-attribute fallback of bare annotations with defaults.
    """

    done: int = 0  # episodes finished (success or failure)
    ok: int = 0  # episodes that completed all phases
    fail: int = 0  # episodes that errored out
    resets_done: int = 0  # environments successfully reset
    turns_done: int = 0  # tool-call turns completed across all episodes
async def run_rl_episode(
    session_id: int,
    scenario: str,
    task_idx: int,
    num_turns: int,
    reset_semaphore: asyncio.Semaphore,
    interact_semaphore: asyncio.Semaphore,
    counters: ProgressCounters,
) -> SessionResult:
    """Simulate a full RL episode: reset -> list_tools -> N tool calls -> done.

    Args:
        session_id: Index of this episode (reporting only).
        scenario: AWM scenario name passed to env.reset().
        task_idx: Task index within the scenario.
        num_turns: Planned number of tool-call turns.
        reset_semaphore: Limits concurrent connect/reset (subprocess spawn
            on the server is heavy).
        interact_semaphore: Limits concurrent multi-turn interaction.
        counters: Shared progress counters, mutated in place.

    Returns:
        SessionResult with per-phase timings. `success` is True only when
        every phase completed; otherwise `error` holds "[phase] Type: msg".

    Fix vs. original: the reset-failure early return no longer bumps
    counters.done — the finally block is the single accounting point, so
    failed resets were previously double-counted in `done`.
    """
    r = SessionResult(
        session_id=session_id,
        scenario=scenario,
        task_idx=task_idx,
        num_turns=num_turns,
    )
    session_start = time.monotonic()
    # `phase` is updated before each step so errors can be attributed.
    phase = "init"
    env = AWMEnv(
        base_url=BASE_URL, message_timeout_s=CLIENT_TIMEOUT, connect_timeout_s=60.0
    )
    try:
        # -- Phase 1: connect + reset (rate-limited to avoid thundering herd) --
        async with reset_semaphore:
            phase = "connect"
            t0 = time.monotonic()
            await env.connect()
            r.connect_s = time.monotonic() - t0
            phase = "reset"
            t0 = time.monotonic()
            result = await env.reset(scenario=scenario, task_idx=task_idx)
            r.reset_s = time.monotonic() - t0
            if result.observation.reward_type not in ("reset_ok", "reset_warning"):
                r.error = f"reset failed: {result.observation.error}"
                counters.fail += 1
                # counters.done is bumped once, in the finally block.
                return r
            r.num_tools = result.observation.num_tools or 0
            counters.resets_done += 1
        # -- Phase 2: list_tools --
        phase = "list_tools"
        t0 = time.monotonic()
        result = await env.step(ListToolsAction())
        r.list_tools_s = time.monotonic() - t0
        # Collect tool names for random calling
        obs = result.observation
        if hasattr(obs, "tools") and obs.tools:
            r.tools_discovered = [
                t.get("name", t.get("tool_name", ""))
                for t in obs.tools
                if isinstance(t, dict)
            ]
        if not r.tools_discovered:
            # Fall back to a dummy name so phase 3 still exercises the server.
            r.tools_discovered = ["unknown_tool"]
        # -- Phase 3: multi-turn tool calling (simulate agent interaction) --
        async with interact_semaphore:
            for turn in range(num_turns):
                phase = f"turn_{turn}"
                # Simulate LLM thinking time (async sleep = non-blocking)
                think_time = random.uniform(LLM_THINK_MIN, LLM_THINK_MAX)
                await asyncio.sleep(think_time)
                # Pick a random tool and call with empty args (will fail, that's fine)
                tool_name = random.choice(r.tools_discovered)
                t0 = time.monotonic()
                try:
                    result = await env.step(
                        CallToolAction(tool_name=tool_name, arguments={})
                    )
                except Exception:
                    # Tool call failure is expected (no args), just measure latency
                    pass
                r.tool_call_latencies.append(time.monotonic() - t0)
                r.turns_completed += 1
                counters.turns_done += 1
        # -- Phase 4: done + close --
        phase = "done"
        t0 = time.monotonic()
        await env.step(
            CallToolAction(tool_name="done", arguments={"keep_session": False})
        )
        r.done_s = time.monotonic() - t0
        r.success = True
        counters.ok += 1
    except Exception as e:
        r.error = f"[{phase}] {type(e).__name__}: {str(e)[:200]}"
        counters.fail += 1
    finally:
        r.total_s = time.monotonic() - session_start
        counters.done += 1
        try:
            await env.close()
        except Exception:
            # Best-effort close; episode accounting already recorded.
            pass
    return r
| # --------------------------------------------------------------------------- | |
| # Progress reporter | |
| # --------------------------------------------------------------------------- | |
async def progress_reporter(
    counters: ProgressCounters,
    total: int,
    total_turns: int,
    monitor: ResourceMonitor,
    interval: float = 10.0,
):
    """Emit one progress log line every `interval` seconds until cancelled."""
    start = time.monotonic()
    while True:
        await asyncio.sleep(interval)
        elapsed = time.monotonic() - start
        remaining = total - counters.done
        stats = await fetch_server_stats(BASE_URL)
        server_sessions = "?" if not stats else stats.get("total_sessions", "?")
        # Most recent resource sample, if the monitor has produced one yet.
        latest = monitor._samples[-1] if monitor._samples else {}
        cpu = latest.get("system_cpu_pct", "?")
        mem = latest.get("system_mem_used_gb", "?")
        server_mem = latest.get("server_tree_mem_mb", "?")
        children = latest.get("server_children", "?")
        log.info(
            f"[{elapsed:.0f}s] episodes={counters.done}/{total} "
            f"ok={counters.ok} fail={counters.fail} "
            f"resets={counters.resets_done} "
            f"turns={counters.turns_done}/{total_turns} "
            f"in_flight={remaining} | "
            f"server={server_sessions} subprocs={children} | "
            f"cpu={cpu}% mem={mem}GB server={server_mem}MB"
        )
| # --------------------------------------------------------------------------- | |
| # Main test | |
| # --------------------------------------------------------------------------- | |
async def run_rl_step(
    scale: int,
    concurrency: int,
    min_turns: int,
    max_turns: int,
) -> tuple[list[SessionResult], dict]:
    """Run one RL step: launch `scale` episodes in parallel.

    Args:
        scale: Number of environments (episodes) to run concurrently.
        concurrency: Max concurrent connect/reset operations.
        min_turns: Minimum planned tool-call turns per episode.
        max_turns: Maximum planned tool-call turns per episode.

    Returns:
        (per-episode SessionResults, resource usage summary dict).

    Side effects: logs a latency/resource report and up to 20 failures.
    """
    log.info("=" * 78)
    log.info(
        f"RL STEP SIMULATION: {scale} envs, concurrency={concurrency}, "
        f"turns={min_turns}-{max_turns}, timeout={CLIENT_TIMEOUT}s"
    )
    log.info("=" * 78)
    # Discover server PID for resource monitoring
    server_pid = await check_server(BASE_URL)
    log.info(f"Server reachable (pid={server_pid})")
    monitor = ResourceMonitor(interval=2.0)
    monitor.start(server_pid)
    # Two semaphores:
    #   - reset_semaphore: limits concurrent resets (heavy: subprocess spawn)
    #   - interact_semaphore: limits concurrent multi-turn interaction
    reset_semaphore = asyncio.Semaphore(concurrency)
    interact_semaphore = asyncio.Semaphore(scale)  # no limit on interaction
    # Pre-assign turns per session so total work is known up front.
    turn_counts = [random.randint(min_turns, max_turns) for _ in range(scale)]
    total_planned_turns = sum(turn_counts)
    counters = ProgressCounters()
    # Launch progress reporter (cancelled once all episodes complete).
    progress_task = asyncio.create_task(
        progress_reporter(counters, scale, total_planned_turns, monitor)
    )
    wall_start = time.monotonic()
    tasks = []
    for i in range(scale):
        # Round-robin scenarios; task_idx cycles 0-9 within each scenario.
        scenario = SCENARIOS[i % len(SCENARIOS)]
        task_idx = i % 10
        tasks.append(
            run_rl_episode(
                session_id=i,
                scenario=scenario,
                task_idx=task_idx,
                num_turns=turn_counts[i],
                reset_semaphore=reset_semaphore,
                interact_semaphore=interact_semaphore,
                counters=counters,
            )
        )
    completed = await asyncio.gather(*tasks)
    wall_s = time.monotonic() - wall_start
    progress_task.cancel()
    try:
        await progress_task
    except asyncio.CancelledError:
        pass
    await monitor.stop()
    resource_summary = monitor.summary()
    # --------------- Report ---------------
    ok = [r for r in completed if r.success]
    failed = [r for r in completed if not r.success]
    total_turns = sum(r.turns_completed for r in completed)
    total_planned = sum(r.num_turns for r in completed)
    log.info("")
    log.info(f"{'=' * 78}")
    log.info(
        f"RESULTS: {len(ok)}/{scale} succeeded, {len(failed)} failed, wall={wall_s:.1f}s"
    )
    log.info(f"Total turns: {total_turns}/{total_planned} completed")
    log.info(f"{'=' * 78}")
    # Latency distributions (successful episodes only)
    for label, values in [
        ("connect", [r.connect_s for r in ok]),
        ("reset", [r.reset_s for r in ok]),
        ("list_tools", [r.list_tools_s for r in ok]),
        ("tool_call", [lat for r in ok for lat in r.tool_call_latencies]),
        ("done", [r.done_s for r in ok]),
        ("episode_total", [r.total_s for r in ok]),
    ]:
        stats = latency_stats(values)
        if stats:
            log.info(f"  {label:>14s}: {json.dumps(stats)}")
    # Resource summary
    log.info("")
    log.info(f"  {'RESOURCES':>14s}: {json.dumps(resource_summary)}")
    # Turn distribution
    if ok:
        turn_dist = [r.num_turns for r in ok]
        log.info(
            f"  {'turns/episode':>14s}: min={min(turn_dist)} max={max(turn_dist)} "
            f"mean={statistics.mean(turn_dist):.1f}"
        )
    # Failures (capped at 20 lines to keep logs readable)
    if failed:
        log.warning("")
        log.warning(f"  {len(failed)} failures:")
        for r in failed[:20]:
            log.warning(
                f"    session {r.session_id} ({r.scenario}/{r.task_idx}, "
                f"turns={r.turns_completed}/{r.num_turns}): {r.error}"
            )
        if len(failed) > 20:
            log.warning(f"    ... and {len(failed) - 20} more")
    return list(completed), resource_summary
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
def parse_args():
    """Build and parse the stress-test CLI.

    Flags: --scale, --concurrency, --min-turns, --max-turns, --think-min,
    --think-max, --url, --client-timeout. Defaults mirror the module-level
    constants; main() copies parsed values back into those globals.
    """
    parser = argparse.ArgumentParser(
        description="AWM stress test — simulates large-scale agentic RL"
    )
    # (flag, add_argument kwargs) — declared in display order.
    specs = [
        (
            "--scale",
            dict(
                type=int,
                default=2000,
                help="Number of parallel environments per RL step (default: 2000)",
            ),
        ),
        (
            "--concurrency",
            dict(type=int, default=256, help="Max concurrent resets (default: 256)"),
        ),
        (
            "--min-turns",
            dict(
                type=int,
                default=3,
                help="Min tool-call turns per episode (default: 3)",
            ),
        ),
        (
            "--max-turns",
            dict(
                type=int,
                default=20,
                help="Max tool-call turns per episode (default: 20)",
            ),
        ),
        (
            "--think-min",
            dict(
                type=float,
                default=1.0,
                help="Min LLM rollout time per turn in seconds (default: 1.0)",
            ),
        ),
        (
            "--think-max",
            dict(
                type=float,
                default=20.0,
                help="Max LLM rollout time per turn in seconds (default: 20.0)",
            ),
        ),
        (
            "--url",
            dict(
                default="http://localhost:8899",
                help="Server base URL (default: http://localhost:8899)",
            ),
        ),
        (
            "--client-timeout",
            dict(
                type=float,
                default=600.0,
                help="Client message timeout in seconds (default: 600)",
            ),
        ),
    ]
    for flag, kwargs in specs:
        parser.add_argument(flag, **kwargs)
    return parser.parse_args()
async def main():
    """CLI entry point: parse args, sanity-check the server, run one RL step.

    Copies parsed CLI values into the module-level config globals (which
    helpers read), then exits with status 1 if the server is unreachable
    or any episode failed.
    """
    args = parse_args()
    global BASE_URL, CLIENT_TIMEOUT, MIN_TURNS, MAX_TURNS, LLM_THINK_MIN, LLM_THINK_MAX
    BASE_URL = args.url
    CLIENT_TIMEOUT = args.client_timeout
    MIN_TURNS = args.min_turns
    MAX_TURNS = args.max_turns
    LLM_THINK_MIN = args.think_min
    LLM_THINK_MAX = args.think_max
    log.info(f"AWM RL Stress Test — server: {BASE_URL}")
    try:
        # Fail fast if the server is not up (PID return value unused here).
        await check_server(BASE_URL)
    except Exception as e:
        log.error(f"Cannot reach server at {BASE_URL}: {e}")
        sys.exit(1)
    results, resources = await run_rl_step(
        args.scale, args.concurrency, args.min_turns, args.max_turns
    )
    ok = sum(1 for r in results if r.success)
    fail = len(results) - ok
    log.info("")
    log.info("=" * 78)
    log.info("FINAL SUMMARY")
    log.info("=" * 78)
    log.info(
        f"  scale={args.scale} concurrency={args.concurrency} ok={ok} fail={fail}"
    )
    log.info(
        f"  turns_range=[{args.min_turns},{args.max_turns}] "
        f"total_turns={sum(r.turns_completed for r in results)}"
    )
    if resources:
        log.info(f"  resources={json.dumps(resources)}")
    if fail > 0:
        # Non-zero exit so CI / scripts can detect a failed stress run.
        log.error("SOME EPISODES FAILED")
        sys.exit(1)
    else:
        log.info("ALL EPISODES PASSED")


if __name__ == "__main__":
    asyncio.run(main())