| """ |
| test_env.py — Simulation Runner & Sanity Tests |
| ================================================ |
| |
Provides three entry-points:
| |
| run_simulation(mode) – Run one full episode and print a formatted report. |
| run_all() – Run all three difficulty modes and compare. |
| run_sanity_checks() – Fast correctness assertions (no pytest needed). |
| |
| Usage |
| ----- |
| python test_env.py # runs all modes + sanity checks |
| python test_env.py easy # run a single mode |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| import builtins |
| from typing import Dict, Any |
|
|
| from env import TrafficEnv |
| from tasks import get_config |
| from baseline_agent import RuleBasedAgent |
|
|
|
|
| |
| |
| |
|
|
| _COL = 80 |
|
|
|
|
| def _separator(char: str = "─") -> str: |
| return char * _COL |
|
|
|
|
| _ASCII_FALLBACKS = ( |
| ("\u2550", "="), |
| ("\u2500", "-"), |
| ("\u2502", "|"), |
| ("\u00b7", "-"), |
| ("\U0001F6A8", "EV"), |
| ("\u2713", "PASS"), |
| ("\u2717", "FAIL"), |
| ("\u26a0\ufe0f", "WARNING"), |
| ("\u2705", "PASS"), |
| ("\u2014", "-"), |
| ("\u2265", ">="), |
| ("\u2264", "<="), |
| ("\u2208", "in"), |
| ) |
|
|
|
|
| def _safe_text(text: str) -> str: |
| encoding = getattr(sys.stdout, "encoding", None) or "utf-8" |
| try: |
| text.encode(encoding) |
| return text |
| except UnicodeEncodeError: |
| for src, dest in _ASCII_FALLBACKS: |
| text = text.replace(src, dest) |
| return text |
|
|
|
|
def print(*args, **kwargs) -> None:
    """
    Module-local, encoding-safe replacement for the builtin ``print``.

    Output aimed at ``sys.stdout`` is routed through ``_safe_text`` so rich
    Unicode glyphs degrade gracefully on limited encodings (e.g. cp1252);
    anything directed at another stream is forwarded untouched.
    """
    target = kwargs.get("file", sys.stdout)
    if target is not sys.stdout:
        # Not our stream — defer entirely to the real print.
        builtins.print(*args, **kwargs)
        return

    joined = kwargs.get("sep", " ").join(str(a) for a in args)
    builtins.print(
        _safe_text(joined),
        end=kwargs.get("end", "\n"),
        flush=kwargs.get("flush", False),
        file=target,
    )
|
|
|
|
| def _fmt_metric(key: str, value: Any) -> str: |
| label = key.replace("_", " ").title() |
| if isinstance(value, float): |
| return f" {label:<30} {value:.4f}" |
| return f" {label:<30} {value}" |
|
|
|
|
| |
| |
| |
|
|
def run_simulation(mode: str = "medium", verbose: bool = True) -> Dict[str, Any]:
    """
    Execute a single full episode with the rule-based agent.

    Parameters
    ----------
    mode : str
        One of "easy", "medium", or "hard".
    verbose : bool
        When True, print a formatted step trace and a final metrics report.

    Returns
    -------
    dict
        The environment's final info metrics, augmented with
        'cumulative_reward' and 'mode'.
    """
    env = TrafficEnv(get_config(mode))
    agent = RuleBasedAgent(
        min_green_time=5,
        imbalance_threshold=5,
        max_green_time=15,
        emergency_min_green=2,
    )

    state = env.reset()
    agent.reset()

    if verbose:
        print()
        print(_separator("═"))
        print(f" TRAFFIC SIGNAL SIMULATION · Mode: {mode.upper()}")
        print(_separator("═"))
        print(
            f"{'Step':<6} │ {'Phase':<4} │ "
            f"{'N':>4} {'S':>4} {'E':>4} {'W':>4} │ "
            f"{'NS':>4} {'EW':>4} │ "
            f"{'Reward':>8} │ EV"
        )
        print(_separator())

    done = False
    cumulative = 0.0
    rewards = []

    while not done:
        state, reward, done, info = env.step(agent.select_action(state))
        cumulative += reward
        rewards.append(reward)

        if verbose:
            emergency = any(state["emergency_flags"].values())
            # Trace every 5th step, plus any step with an active emergency.
            if env.step_count % 5 == 0 or emergency:
                print(
                    f"{env.step_count:<6} │ "
                    f"{('NS' if state['phase'] == 0 else 'EW'):<4} │ "
                    f"{state['north_cars']:>4} "
                    f"{state['south_cars']:>4} "
                    f"{state['east_cars']:>4} "
                    f"{state['west_cars']:>4} │ "
                    f"{state['north_cars'] + state['south_cars']:>4} "
                    f"{state['east_cars'] + state['west_cars']:>4} │ "
                    f"{reward:>8.3f} │ {'🚨' if emergency else ' '}"
                )

    if verbose:
        print(_separator())
        print(f"\n FINAL METRICS ({mode.upper()})")
        print(_separator())
        for key, value in info.items():
            print(_fmt_metric(key, value))
        print(_fmt_metric("cumulative_reward", cumulative))
        if rewards:
            print(_fmt_metric("min_step_reward", min(rewards)))
            print(_fmt_metric("max_step_reward", max(rewards)))
        print()

    report = dict(info)
    report["cumulative_reward"] = cumulative
    report["mode"] = mode
    return report
|
|
|
|
| |
| |
| |
|
|
def run_all() -> None:
    """Run every difficulty mode in sequence; print a comparison table."""
    modes = ("easy", "medium", "hard")
    results = {m: run_simulation(m, verbose=True) for m in modes}

    print()
    print(_separator("═"))
    print(" CROSS-MODE COMPARISON")
    print(_separator("═"))

    col_w = 18
    print(f" {'Metric':<30}" + "".join(f"{m.upper():>{col_w}}" for m in modes))
    print(_separator())

    for metric in (
        "total_cleared", "avg_waiting_time",
        "max_queue_length", "signal_switch_count",
        "congestion_score", "avg_ev_clear_time",
        "fairness_score", "cumulative_reward",
    ):
        cells = [f" {metric.replace('_', ' ').title():<30}"]
        for m in modes:
            # "—" marks a metric a mode's episode did not report.
            value = results[m].get(metric, "—")
            if isinstance(value, float):
                cells.append(f"{value:>{col_w}.3f}")
            else:
                cells.append(f"{value:>{col_w}}")
        print("".join(cells))

    print(_separator("═"))
    print()
|
|
|
|
| |
| |
| |
|
|
def run_sanity_checks() -> None:
    """Assert basic correctness invariants for all difficulty modes."""
    print()
    print(_separator("═"))
    print(" SANITY CHECKS")
    print(_separator("═"))

    # One boolean per executed check; tallied at the end.
    outcomes = []

    def check(name: str, condition: bool) -> None:
        # Print the verdict immediately and record it for the summary.
        print(f" [{'✓ PASS' if condition else '✗ FAIL'}] {name}")
        outcomes.append(bool(condition))

    for mode in ("easy", "medium", "hard"):
        cfg = get_config(mode)
        env = TrafficEnv(cfg)
        agent = RuleBasedAgent()

        state = env.reset()
        agent.reset()
        check(
            f"[{mode}] reset() returns all-zero queues",
            all(state[f"{d}_cars"] == 0 for d in ("north", "south", "east", "west")),
        )

        result = env.step(agent.select_action(state))
        check(f"[{mode}] step() returns 4-tuple", len(result) == 4)

        ns, reward, done, info = result

        check(f"[{mode}] reward in [-1, 1]", -1.0 <= reward <= 1.0)

        check(
            f"[{mode}] state has required keys",
            {
                "north_cars", "south_cars", "east_cars", "west_cars",
                "waiting_times", "phase", "emergency_flags", "step_count",
            }.issubset(ns.keys()),
        )

        check(
            f"[{mode}] info has required keys",
            {
                "total_cleared", "avg_waiting_time",
                "max_queue_length", "signal_switch_count",
                "congestion_score", "avg_ev_clear_time",
                "fairness_score",
            }.issubset(info.keys()),
        )

        # Drive the rest of the episode to exercise long-run invariants.
        for _ in range(cfg["max_steps"]):
            ns, _, done, _ = env.step(agent.select_action(ns))
            if done:
                break

        check(
            f"[{mode}] queues never go negative (full episode)",
            all(q >= 0 for q in env.queues.values()),
        )

        check(
            f"[{mode}] queues never exceed max_queue ({cfg['max_queue']})",
            all(q <= cfg["max_queue"] for q in env.queues.values()),
        )

        check(f"[{mode}] phase is always 0 or 1", env.phase in (0, 1))

        check(f"[{mode}] total_cleared ≥ 0", env.total_cleared >= 0)

        check(
            f"[{mode}] congestion_score ∈ [0, 1]",
            0.0 <= info["congestion_score"] <= 1.0,
        )

        print()

    passed = sum(outcomes)
    failed = len(outcomes) - passed
    print(_separator())
    print(f" Results: {passed} passed, {failed} failed")
    print(_separator("═"))
    if failed:
        print(" ⚠️ Some checks failed — review the environment logic.")
    else:
        print(" ✅ All sanity checks passed.")
    print()
|
|
|
|
| |
| |
| |
|
|
if __name__ == "__main__":
    # A single recognized mode argument runs just that mode; anything else
    # (including no arguments) runs the full suite plus sanity checks.
    cli_args = [a.lower() for a in sys.argv[1:]]
    if len(cli_args) == 1 and cli_args[0] in ("easy", "medium", "hard"):
        run_simulation(cli_args[0], verbose=True)
    else:
        run_all()
        run_sanity_checks()
|
|