# pyre_env/examples/random_agent.py
# Uploaded by Krooz via huggingface_hub (commit 2d3e659, verified).
"""Random-action baseline agent for Pyre (single-agent).
Runs N episodes using the PyreEnv client and prints per-episode stats.
Use this to smoke-test the server and verify the reward distribution
spans a meaningful range.
Usage:
# Server must be running first:
# cd pyre_env && uv run server
#
python examples/random_agent.py --episodes 5
python examples/random_agent.py --episodes 5 --verbose
python examples/random_agent.py --url http://localhost:8000 --episodes 10
"""
import argparse
import random
import sys
from typing import List, Optional

import requests

from pyre_env import PyreEnv, PyreAction
# ---------------------------------------------------------------------------
# Action sampling
# ---------------------------------------------------------------------------
def _parse_hint(hint: str) -> PyreAction:
    """Translate one ``available_actions_hint`` string into a ``PyreAction``.

    Hint shapes understood (argument values are single-quoted in the hint):

    * ``move('<direction>')``
    * ``door(target_id='<id>', door_state='<state>')``
    * ``wait()``

    Any other string, or one too short to index (``IndexError``), or one whose
    values ``PyreAction`` rejects with a ``ValueError``, yields a ``wait``
    action instead of raising.
    """
    try:
        text = hint.strip()
        if text.startswith("move("):
            direction = text.split("'")[1]
            return PyreAction(action="move", direction=direction)
        if text.startswith("door("):
            # Splitting on single quotes yields:
            #   ["door(target_id=", <id>, ", door_state=", <state>, ")"]
            pieces = text.split("'")
            return PyreAction(
                action="door",
                target_id=pieces[1],
                door_state=pieces[3],
            )
        if text == "wait()":
            return PyreAction(action="wait")
    except (IndexError, ValueError):
        pass
    # Unrecognised or malformed hint: fall back to the always-safe no-op.
    return PyreAction(action="wait")
# Directions used when the agent ignores the hints and moves blindly.
_MOVE_DIRECTIONS = ("north", "south", "east", "west")


def random_action(
    hints: List[str],
    rng: random.Random,
    *,
    hint_prob: float = 0.7,
) -> PyreAction:
    """Sample one action for the random baseline.

    With probability ``hint_prob`` (and only when ``hints`` is non-empty) a
    hint is chosen uniformly at random and parsed with ``_parse_hint``.
    Otherwise a ``move`` in a uniformly random compass direction is returned;
    that move is not checked against the hints, so it may be illegal.

    Args:
        hints: Hint strings from the observation's ``available_actions_hint``.
            An empty (falsy) value always takes the random-move branch.
        rng: Random source; pass a seeded instance for reproducible runs.
        hint_prob: Probability of sampling from ``hints``. Defaults to 0.7,
            the previously hard-coded bias, so existing callers are unchanged.

    Returns:
        The sampled ``PyreAction``.
    """
    if hints and rng.random() < hint_prob:
        return _parse_hint(rng.choice(hints))
    # Fallback: random move. This draws from the RNG exactly as a list would.
    return PyreAction(action="move", direction=rng.choice(_MOVE_DIRECTIONS))
# ---------------------------------------------------------------------------
# Episode runner
# ---------------------------------------------------------------------------
def run_episode(env, max_steps: int, rng: random.Random, verbose: bool) -> dict:
    """Play one episode with random actions and return summary statistics.

    Args:
        env: Synchronous Pyre client. ``reset()`` and ``step(action)`` must
            return objects exposing ``observation``, ``reward`` and ``done``.
        max_steps: Upper bound on the number of ``step`` calls.
        rng: Random source passed through to ``random_action``.
        verbose: If true, print a one-line trace after every step.

    Returns:
        A dict with the step count, the summed reward (rounded to 3 decimals),
        the final ``done`` flag, final agent state taken from the last
        observation, and fire metadata (``"?"`` when the metadata lacks it).
    """
    outcome = env.reset()
    observation = outcome.observation
    total = 0.0
    step_count = 0
    is_done = outcome.done

    while not is_done and step_count < max_steps:
        chosen = random_action(observation.available_actions_hint, rng)
        outcome = env.step(chosen)
        observation = outcome.observation
        # A missing (None) reward counts as zero.
        step_reward = outcome.reward or 0.0
        is_done = outcome.done
        total += step_reward
        step_count += 1

        if not verbose:
            continue
        narrative = observation.narrative
        headline = narrative.partition("\n")[0] if narrative else ""
        print(
            f" step {step_count:3d} | hp={observation.agent_health:5.1f}"
            f" | r={step_reward:+.3f} | done={is_done} | {headline[:70]}"
        )

    meta = observation.metadata or {}
    summary = {
        "steps": step_count,
        "total_reward": round(total, 3),
        "done": is_done,
        "evacuated": observation.agent_evacuated,
        "final_health": observation.agent_health,
        "wind_dir": observation.wind_dir,
        "fire_sources": meta.get("fire_sources", "?"),
        "fire_spread": meta.get("fire_spread_rate", "?"),
        "last_narrative": observation.narrative[:120] if observation.narrative else "",
    }
    return summary
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _check_server(url: str) -> None:
    """Return if ``GET <url>/health`` succeeds; otherwise exit with status 1.

    Only ``requests`` errors (connection failure, timeout, non-2xx status,
    malformed URL) count as "unreachable". Anything else is a programming
    error and propagates with a traceback.
    """
    try:
        response = requests.get(f"{url}/health", timeout=5)
        response.raise_for_status()
    except requests.RequestException as exc:
        # Diagnostics go to stderr so stdout carries only the run output.
        print(f"Server not reachable at {url}: {exc}", file=sys.stderr)
        sys.exit(1)
    print(f"Server healthy: {url}")


def _print_summary(results: List[dict]) -> None:
    """Print reward/step aggregates over the per-episode stats dicts."""
    print("\n=== Summary ===")
    print(f"Episodes: {len(results)}")
    if not results:
        # min()/max() and the mean divisions are undefined with no episodes.
        return
    rewards = [stats["total_reward"] for stats in results]
    print(f"Reward min/max: {min(rewards):.3f} / {max(rewards):.3f}")
    print(f"Reward mean: {sum(rewards)/len(rewards):.3f}")
    print(f"Avg steps: {sum(stats['steps'] for stats in results) / len(results):.1f}")


def main(argv: Optional[List[str]] = None) -> None:
    """Command-line entry point: run random-agent episodes and print stats.

    Args:
        argv: Arguments to parse instead of ``sys.argv[1:]``. ``None`` (the
            default) keeps normal CLI behaviour. Mainly useful for tests.

    Exits with status 2 (via ``argparse``) on invalid arguments and with
    status 1 if the server health check fails.
    """
    parser = argparse.ArgumentParser(description="Pyre random-agent baseline")
    parser.add_argument("--url", default="http://localhost:8000", help="Server base URL")
    parser.add_argument("--episodes", type=int, default=5, help="Number of episodes (>= 1)")
    parser.add_argument("--max-steps", type=int, default=100, help="Step cap per episode (>= 0)")
    parser.add_argument("--seed", type=int, default=7, help="Seed for action sampling")
    parser.add_argument("--verbose", action="store_true", help="Print every step")
    args = parser.parse_args(argv)

    # Reject bad values before touching the network. Zero episodes used to
    # crash in the summary (min() of an empty list / division by zero).
    if args.episodes < 1:
        parser.error("--episodes must be at least 1")
    if args.max_steps < 0:
        parser.error("--max-steps must be non-negative")

    _check_server(args.url)

    rng = random.Random(args.seed)
    results: List[dict] = []
    with PyreEnv(base_url=args.url).sync() as env:
        for ep in range(args.episodes):
            print(f"\n=== Episode {ep + 1}/{args.episodes} ===")
            stats = run_episode(env, args.max_steps, rng, args.verbose)
            results.append(stats)
            print(
                f" DONE steps={stats['steps']} reward={stats['total_reward']:+.3f}"
                f" health={stats['final_health']:.1f}"
                f" wind={stats['wind_dir']} sources={stats['fire_sources']}"
                f" spread={stats['fire_spread']}"
            )

    _print_summary(results)


if __name__ == "__main__":
    main()