| | import json |
| | import math |
| | from dataclasses import dataclass, asdict |
| | from typing import Dict, List, Tuple, Optional |
| |
|
| | import numpy as np |
| | from PIL import Image, ImageDraw |
| |
|
| | import gradio as gr |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
# --- World geometry ---
GRID_W, GRID_H = 21, 15  # grid size in tiles (width x height)
TILE = 22                # pixel size of one tile in the top-down render

# --- First-person view / raycaster ---
VIEW_W, VIEW_H = 640, 360  # POV image resolution in pixels
RAY_W = 320                # number of rays marched across the FOV
FOV_DEG = 78               # horizontal field of view in degrees
MAX_DEPTH = 20             # maximum ray-march distance in tiles

# --- Orientation: Agent.ori indexes both tables below ---
DIRS = [(1, 0), (0, 1), (-1, 0), (0, -1)]  # (dx, dy) step per orientation
ORI_DEG = [0, 90, 180, 270]                # facing angle in degrees per orientation

# --- Tile type codes stored in the grid ---
EMPTY = 0
WALL = 1
FOOD = 2
NOISE = 3
DOOR = 4
TELE = 5

# Human-readable names for tile codes (UI labels / event log).
TILE_NAMES = {
    EMPTY: "Empty",
    WALL: "Wall",
    FOOD: "Food",
    NOISE: "Noise",
    DOOR: "Door",
    TELE: "Teleporter",
}

# RGB color used to draw each agent in both renderers.
AGENT_COLORS = {
    "Predator": (255, 120, 90),
    "Prey": (120, 255, 160),
    "Scout": (120, 190, 255),
}

# Raycaster palette (RGB, uint8 arrays so they can be scaled by fog factors).
SKY = np.array([14, 16, 26], dtype=np.uint8)
FLOOR_NEAR = np.array([24, 26, 40], dtype=np.uint8)
FLOOR_FAR = np.array([10, 11, 18], dtype=np.uint8)
WALL_BASE = np.array([210, 210, 225], dtype=np.uint8)
WALL_SIDE = np.array([150, 150, 170], dtype=np.uint8)
DOOR_COL = np.array([180, 210, 255], dtype=np.uint8)

# Discrete action space shared by all agents: turn left, step forward, turn right.
ACTIONS = ["L", "F", "R"]
| |
|
| | |
| | |
| | |
def rng_for(seed: int, step: int, stream: int = 0) -> np.random.Generator:
    """Return a deterministically seeded RNG for (seed, step, stream).

    The same triple always produces an identically seeded generator, so every
    stochastic decision in a tick is reproducible and independent per stream.
    """
    mixed = seed * 1_000_003
    mixed ^= step * 9_999_937
    mixed ^= stream * 97_531
    # Clamp the mix to 64 bits before seeding.
    return np.random.default_rng(mixed & 0xFFFFFFFFFFFFFFFF)
| |
|
| | |
| | |
| | |
@dataclass
class Agent:
    """A single actor on the grid."""
    name: str  # "Predator", "Prey" or "Scout"
    x: int     # tile column
    y: int     # tile row
    ori: int   # orientation index into DIRS / ORI_DEG
    energy: int = 100  # raised by food (see consume_food), capped at 200
| |
|
@dataclass
class TrainConfig:
    """Hyper-parameters for the two tabular Q-learners and reward shaping."""
    use_q_pred: bool = True   # predator picks actions from its Q-table
    use_q_prey: bool = True   # prey picks actions from its Q-table
    alpha: float = 0.15       # Q-learning step size
    gamma: float = 0.95       # discount factor
    epsilon: float = 0.10     # initial exploration rate (live value lives in Metrics)
    epsilon_min: float = 0.02    # floor for epsilon decay
    epsilon_decay: float = 0.995  # multiplicative decay applied once per episode

    # Predator reward shaping.
    pred_step_penalty: float = -0.02  # cost per tick
    pred_dist_coeff: float = 0.03     # bonus per tile of Manhattan distance closed
    pred_catch_reward: float = 3.0    # terminal bonus on catch

    # Prey reward shaping.
    prey_step_penalty: float = -0.02
    prey_food_reward: float = 0.6     # bonus for eating a FOOD tile
    prey_survive_reward: float = 0.02 # per-tick survival bonus
    prey_caught_penalty: float = -3.0 # terminal penalty when caught
| |
|
@dataclass
class Metrics:
    """Rolling training statistics; persists across episode resets."""
    episodes: int = 0    # total training episodes run so far
    catches: int = 0     # total successful catches
    avg_steps_to_catch: float = 0.0   # EMA (0.85 old / 0.15 new) of steps per catch
    avg_path_efficiency: float = 0.0  # EMA of optimal-distance / steps-used
    last_episode_steps: int = 0       # steps taken by the most recent episode
    last_episode_eff: float = 0.0     # efficiency of the most recent episode
    epsilon: float = 0.10             # live exploration rate (decayed during training)
| |
|
@dataclass
class WorldState:
    """The entire mutable simulation state (one instance lives in gr.State)."""
    seed: int                  # base seed for all derived RNG streams
    step: int                  # current tick counter
    grid: List[List[int]]      # tile codes, indexed grid[y][x]
    agents: Dict[str, Agent]   # keyed "Predator" / "Prey" / "Scout"
    controlled: str            # agent that receives manual actions
    pov: str                   # agent whose first-person view is rendered
    overlay: bool              # draw the crosshair reticle in the POV

    caught: bool               # True once predator and prey share a tile
    branches: Dict[str, int]   # branch name -> step (only "main" is created here)

    # Audit trails shown in the UI.
    event_log: List[str]       # sparse human events (door, food, catch, ...)
    trace_log: List[str]       # dense per-step "why" lines, capped at TRACE_MAX

    # Learning state — intentionally shared across episode resets.
    cfg: TrainConfig
    q_pred: Dict[str, List[float]]  # obs_key -> [Q(L), Q(F), Q(R)]
    q_prey: Dict[str, List[float]]
    metrics: Metrics
| |
|
@dataclass
class Snapshot:
    """Serializable point-in-time capture used for rewind and export."""
    step: int
    agents: Dict[str, Dict]    # agent name -> asdict(Agent)
    grid: List[List[int]]      # deep-copied tile codes
    caught: bool
    event_log_tail: List[str]  # last 20 event-log lines at capture time
    trace_tail: List[str]      # last 40 trace lines at capture time
| |
|
| | |
| | |
| | |
def default_grid() -> List[List[int]]:
    """Build the starting map: border walls, a dividing wall with a door, pickups."""
    rows = [[EMPTY] * GRID_W for _ in range(GRID_H)]

    # Solid border all the way around.
    for cx in range(GRID_W):
        rows[0][cx] = WALL
        rows[GRID_H - 1][cx] = WALL
    for cy in range(GRID_H):
        rows[cy][0] = WALL
        rows[cy][GRID_W - 1] = WALL

    # Horizontal dividing wall with a single door in the middle.
    for cx in range(4, 17):
        rows[7][cx] = WALL
    rows[7][10] = DOOR

    # Scatter interactive tiles.
    rows[3][4] = FOOD
    rows[11][15] = FOOD
    rows[4][14] = NOISE
    rows[12][5] = NOISE
    rows[2][18] = TELE
    rows[13][2] = TELE
    return rows
| |
|
def init_state(seed: int) -> WorldState:
    """Create a fresh world: default map, three agents at fixed spawns, empty logs."""
    cfg = TrainConfig()
    roster = {
        name: Agent(name, x, y, ori, 100)
        for name, x, y, ori in (
            ("Predator", 2, 2, 0),
            ("Prey", 18, 12, 2),
            ("Scout", 10, 3, 1),
        )
    }
    return WorldState(
        seed=seed,
        step=0,
        grid=default_grid(),
        agents=roster,
        controlled="Predator",
        pov="Predator",
        overlay=False,
        caught=False,
        branches={"main": 0},
        event_log=["Initialized world."],
        trace_log=[],
        cfg=cfg,
        q_pred={},
        q_prey={},
        metrics=Metrics(epsilon=cfg.epsilon),
    )
| |
|
| | |
| | |
| | |
def init_belief() -> Dict[str, np.ndarray]:
    """Fresh belief maps: one (H, W) int16 grid per agent; -1 means 'never seen'."""
    return {
        name: np.full((GRID_H, GRID_W), -1, dtype=np.int16)
        for name in ("Predator", "Prey", "Scout")
    }
| |
|
| | |
| | |
| | |
def in_bounds(x: int, y: int) -> bool:
    """True when tile coordinate (x, y) lies inside the grid."""
    return (0 <= x < GRID_W) and (0 <= y < GRID_H)
| |
|
def is_blocking(tile: int) -> bool:
    """Only walls block movement; doors and teleporters are passable."""
    if tile == WALL:
        return True
    return False
| |
|
def manhattan(a: "Agent", b: "Agent") -> int:
    """L1 (taxicab) distance between two agents' tile positions."""
    dx = a.x - b.x
    dy = a.y - b.y
    return abs(dx) + abs(dy)
| |
|
def bresenham_los(grid: List[List[int]], x0: int, y0: int, x1: int, y1: int) -> bool:
    """Line of sight via Bresenham: True if no WALL lies strictly between endpoints.

    The endpoints themselves are never tested, so standing on/next to a wall
    does not blind the observer.
    """
    run = abs(x1 - x0)
    rise = abs(y1 - y0)
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    err = run - rise
    cx, cy = x0, y0
    while True:
        # Only interior cells can occlude.
        if (cx, cy) != (x0, y0) and (cx, cy) != (x1, y1) and grid[cy][cx] == WALL:
            return False
        if cx == x1 and cy == y1:
            return True
        doubled = 2 * err
        if doubled > -rise:
            err -= rise
            cx += step_x
        if doubled < run:
            err += run
            cy += step_y
| |
|
def within_fov(observer: Agent, tx: int, ty: int, fov_deg: float = FOV_DEG) -> bool:
    """True when tile (tx, ty) falls inside the observer's angular field of view."""
    rel_x = tx - observer.x
    rel_y = ty - observer.y
    if rel_x == 0 and rel_y == 0:
        return True  # an agent always "sees" its own tile
    bearing = math.degrees(math.atan2(rel_y, rel_x)) % 360
    # Signed angular offset from the facing direction, folded into (-180, 180].
    off = (bearing - ORI_DEG[observer.ori] + 540) % 360 - 180
    return abs(off) <= fov_deg / 2
| |
|
def visible(observer: Agent, target: Agent, grid: List[List[int]]) -> bool:
    """Target is visible when inside the FOV cone AND no wall blocks the line."""
    if not within_fov(observer, target.x, target.y, FOV_DEG):
        return False
    return bresenham_los(grid, observer.x, observer.y, target.x, target.y)
| |
|
| | |
| | |
| | |
def turn_left(a: "Agent") -> None:
    """Rotate the agent one orientation step counter-clockwise (mod 4)."""
    a.ori = (a.ori + 3) % 4
| |
|
def turn_right(a: "Agent") -> None:
    """Rotate the agent one orientation step clockwise (mod 4)."""
    a.ori = 0 if a.ori == 3 else a.ori + 1
| |
|
def move_forward(state: WorldState, a: Agent) -> str:
    """Advance the agent one tile in its facing direction.

    Side effects, in order: a DOOR tile is consumed (set to EMPTY) as the agent
    steps onto it; landing on a TELE tile warps the agent to the next
    teleporter in sorted-(x, y) cyclic order. Returns a short outcome string
    for the step trace.
    """
    dx, dy = DIRS[a.ori]
    nx, ny = a.x + dx, a.y + dy
    if not in_bounds(nx, ny):
        return "blocked: bounds"
    if is_blocking(state.grid[ny][nx]):
        return "blocked: wall"
    if state.grid[ny][nx] == DOOR:
        # Opening a door destroys it (one-shot) before the move completes.
        state.grid[ny][nx] = EMPTY
        state.event_log.append(f"t={state.step}: {a.name} opened a door.")
    a.x, a.y = nx, ny

    if state.grid[ny][nx] == TELE:
        # Collect all teleporters and hop to the next one after this tile.
        teles = [(x, y) for y in range(GRID_H) for x in range(GRID_W) if state.grid[y][x] == TELE]
        if len(teles) >= 2:
            teles_sorted = sorted(teles)
            idx = teles_sorted.index((nx, ny))
            dest = teles_sorted[(idx + 1) % len(teles_sorted)]
            a.x, a.y = dest
            state.event_log.append(f"t={state.step}: {a.name} teleported.")
            return "moved: teleported"
    return "moved"
| |
|
def apply_action(state: WorldState, agent_name: str, action: str) -> str:
    """Apply one discrete action ("L"/"F"/"R") to the named agent; returns outcome text."""
    actor = state.agents[agent_name]
    if action == "F":
        return move_forward(state, actor)
    if action == "L":
        turn_left(actor)
        return "turned left"
    if action == "R":
        turn_right(actor)
        return "turned right"
    return "noop"
| |
|
| | |
| | |
| | |
def raycast_view(state: WorldState, observer: Agent) -> np.ndarray:
    """Render a pseudo-3D first-person view for `observer` (Wolfenstein-style).

    Returns a (VIEW_H, VIEW_W, 3) uint8 RGB array: sky + depth-shaded floor,
    one marched ray per RAY_W column for walls/doors, visible agents drawn as
    colored billboards, and an optional crosshair overlay.
    """
    img = np.zeros((VIEW_H, VIEW_W, 3), dtype=np.uint8)
    img[:, :] = SKY

    # Floor gradient: FLOOR_NEAR at the horizon blending to FLOOR_FAR at the bottom.
    for y in range(VIEW_H // 2, VIEW_H):
        t = (y - VIEW_H // 2) / (VIEW_H // 2 + 1e-6)
        col = (1 - t) * FLOOR_NEAR + t * FLOOR_FAR
        img[y, :] = col.astype(np.uint8)

    fov = math.radians(FOV_DEG)
    half_fov = fov / 2

    for rx in range(RAY_W):
        # cam_x in [-1, 1] sweeps the ray across the field of view.
        cam_x = (2 * rx / (RAY_W - 1)) - 1
        ray_ang = math.radians(ORI_DEG[observer.ori]) + cam_x * half_fov

        ox, oy = observer.x + 0.5, observer.y + 0.5  # march from tile center
        sin_a = math.sin(ray_ang)
        cos_a = math.cos(ray_ang)

        depth = 0.0
        hit = None
        side = 0

        # Fixed-step ray march until a wall/door, the map edge, or the depth cap.
        while depth < MAX_DEPTH:
            depth += 0.05
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            tile = state.grid[ty][tx]
            if tile == WALL:
                hit = "wall"
                # Rough face classification used for two-tone wall shading.
                side = 1 if abs(cos_a) > abs(sin_a) else 0
                break
            if tile == DOOR:
                hit = "door"
                break

        if hit is None:
            continue  # ray escaped: leave sky/floor for this column

        # Perpendicular-distance correction removes the fisheye distortion.
        depth *= math.cos(ray_ang - math.radians(ORI_DEG[observer.ori]))
        depth = max(depth, 0.001)

        # Projected wall-slice height, clamped to the screen.
        proj_h = int((VIEW_H * 0.9) / depth)
        y0 = max(0, VIEW_H // 2 - proj_h // 2)
        y1 = min(VIEW_H - 1, VIEW_H // 2 + proj_h // 2)

        if hit == "door":
            col = DOOR_COL.copy()
        else:
            col = WALL_BASE.copy() if side == 0 else WALL_SIDE.copy()

        # Distance fog; never dims below 25%.
        dim = max(0.25, 1.0 - (depth / MAX_DEPTH))
        col = (col * dim).astype(np.uint8)

        # Each ray paints a VIEW_W/RAY_W-pixel-wide column.
        x0 = int(rx * (VIEW_W / RAY_W))
        x1 = int((rx + 1) * (VIEW_W / RAY_W))
        img[y0:y1, x0:x1] = col

    # Billboard sprites for other agents passing the FOV + line-of-sight test.
    for nm, other in state.agents.items():
        if nm == observer.name:
            continue
        if visible(observer, other, state.grid):
            dx = other.x - observer.x
            dy = other.y - observer.y
            ang = (math.degrees(math.atan2(dy, dx)) % 360)
            facing = ORI_DEG[observer.ori]
            diff = (ang - facing + 540) % 360 - 180
            # Horizontal screen position proportional to the angular offset.
            sx = int((diff / (FOV_DEG / 2)) * (VIEW_W / 2) + (VIEW_W / 2))
            dist = math.sqrt(dx * dx + dy * dy)
            h = int((VIEW_H * 0.65) / max(dist, 0.75))  # closer => taller sprite
            w = max(10, h // 3)
            y_mid = VIEW_H // 2
            y0 = max(0, y_mid - h // 2)
            y1 = min(VIEW_H - 1, y_mid + h // 2)
            x0 = max(0, sx - w // 2)
            x1 = min(VIEW_W - 1, sx + w // 2)
            col = AGENT_COLORS.get(nm, (255, 200, 120))
            img[y0:y1, x0:x1] = np.array(col, dtype=np.uint8)

    # Optional crosshair reticle in the screen center.
    if state.overlay:
        cx, cy = VIEW_W // 2, VIEW_H // 2
        img[cy - 1:cy + 2, cx - 10:cx + 10] = np.array([120, 190, 255], dtype=np.uint8)
        img[cy - 10:cy + 10, cx - 1:cx + 2] = np.array([120, 190, 255], dtype=np.uint8)

    return img
| |
|
def render_topdown(grid: np.ndarray, agents: Dict[str, Agent], title: str, show_agents: bool = True) -> Image.Image:
    """Draw a 2-D map of `grid` (tile codes, or -1 for unseen) with a title bar.

    Shared by the truth map and the belief maps. Agents are drawn as colored
    dots with a short heading line when show_agents is True. Returns a PIL RGB
    image that is 28 pixels taller than the grid for the title strip.
    """
    w = grid.shape[1] * TILE
    h = grid.shape[0] * TILE
    im = Image.new("RGB", (w, h + 28), (10, 12, 18))  # +28px title strip on top
    draw = ImageDraw.Draw(im)

    # Tile fills; -1 marks "never observed" in belief maps.
    for y in range(grid.shape[0]):
        for x in range(grid.shape[1]):
            t = int(grid[y, x])
            if t == -1:
                col = (18, 20, 32)
            elif t == EMPTY:
                col = (26, 30, 44)
            elif t == WALL:
                col = (190, 190, 210)
            elif t == FOOD:
                col = (255, 210, 120)
            elif t == NOISE:
                col = (255, 120, 220)
            elif t == DOOR:
                col = (140, 210, 255)
            elif t == TELE:
                col = (120, 190, 255)
            else:
                col = (80, 80, 90)  # fallback for unknown codes

            x0, y0 = x * TILE, y * TILE + 28
            draw.rectangle([x0, y0, x0 + TILE - 1, y0 + TILE - 1], fill=col)

    # Thin separator lines between tiles.
    for x in range(grid.shape[1] + 1):
        xx = x * TILE
        draw.line([xx, 28, xx, h + 28], fill=(12, 14, 22))
    for y in range(grid.shape[0] + 1):
        yy = y * TILE + 28
        draw.line([0, yy, w, yy], fill=(12, 14, 22))

    # Agents: filled circle plus a short line showing facing direction.
    if show_agents:
        for nm, a in agents.items():
            cx = a.x * TILE + TILE // 2
            cy = a.y * TILE + 28 + TILE // 2
            col = AGENT_COLORS.get(nm, (220, 220, 220))
            r = TILE // 3
            draw.ellipse([cx - r, cy - r, cx + r, cy + r], fill=col)
            dx, dy = DIRS[a.ori]
            draw.line([cx, cy, cx + dx * r, cy + dy * r], fill=(10, 10, 10), width=3)

    # Title strip drawn last so it covers any spill-over.
    draw.rectangle([0, 0, w, 28], fill=(14, 16, 26))
    draw.text((8, 6), title, fill=(230, 230, 240))
    return im
| |
|
| | |
| | |
| | |
def update_belief_for_agent(state: WorldState, belief: np.ndarray, agent: Agent) -> None:
    """Reveal ground-truth tiles along the agent's view cone into its belief map.

    Marches a fan of rays across the FOV; every traversed tile is copied from
    the true grid into `belief`. A wall is itself recorded, then stops the ray.
    The scout sweeps more rays (45 vs 33) and so perceives more densely.
    """
    belief[agent.y, agent.x] = state.grid[agent.y][agent.x]  # own tile always known
    base = math.radians(ORI_DEG[agent.ori])
    half = math.radians(FOV_DEG / 2)
    rays = 33 if agent.name != "Scout" else 45

    for i in range(rays):
        t = i / (rays - 1)
        ang = base + (t * 2 - 1) * half  # sweep from -half to +half around facing
        sin_a, cos_a = math.sin(ang), math.cos(ang)
        ox, oy = agent.x + 0.5, agent.y + 0.5
        depth = 0.0
        while depth < MAX_DEPTH:
            depth += 0.2
            tx = int(ox + cos_a * depth)
            ty = int(oy + sin_a * depth)
            if not in_bounds(tx, ty):
                break
            belief[ty, tx] = state.grid[ty][tx]
            if state.grid[ty][tx] == WALL:
                break  # the wall is seen, but blocks anything behind it
| |
|
| | |
| | |
| | |
def bfs_distance(grid: List[List[int]], sx: int, sy: int, gx: int, gy: int) -> Optional[int]:
    """Shortest 4-connected path length from (sx, sy) to (gx, gy), or None if unreachable.

    Walls block; every other tile costs 1. Plain BFS, so the first time the goal
    is dequeued-adjacent its distance is optimal. (Replaces the hand-rolled
    list-plus-head-index queue with collections.deque.)
    """
    from collections import deque  # local import keeps module-level deps unchanged

    if (sx, sy) == (gx, gy):
        return 0
    frontier = deque([(sx, sy)])
    dist = {(sx, sy): 0}
    while frontier:
        x, y = frontier.popleft()
        for dx, dy in DIRS:
            nx, ny = x + dx, y + dy
            if not in_bounds(nx, ny):
                continue
            if grid[ny][nx] == WALL:
                continue
            if (nx, ny) in dist:
                continue
            dist[(nx, ny)] = dist[(x, y)] + 1
            if (nx, ny) == (gx, gy):
                return dist[(nx, ny)]
            frontier.append((nx, ny))
    return None
| |
|
| | |
| | |
| | |
def obs_key(state: WorldState, who: str) -> str:
    """Compact observation key for the tabular Q-learners.

    Predator/Prey keys encode own pose, the opponent offset clipped to [-6, 6],
    and a visibility bit (the prey additionally gets an energy bucket). The
    scout's key is just its own pose.
    """
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]

    if who == "Predator":
        dx_bin = int(np.clip(prey.x - pred.x, -6, 6))
        dy_bin = int(np.clip(prey.y - pred.y, -6, 6))
        vis = 1 if visible(pred, prey, state.grid) else 0
        return f"P|{pred.x},{pred.y},{pred.ori}|d{dx_bin},{dy_bin}|v{vis}"

    if who == "Prey":
        ddx_bin = int(np.clip(pred.x - prey.x, -6, 6))
        ddy_bin = int(np.clip(pred.y - prey.y, -6, 6))
        vis2 = 1 if visible(prey, pred, state.grid) else 0
        return f"R|{prey.x},{prey.y},{prey.ori}|d{ddx_bin},{ddy_bin}|v{vis2}|e{int(prey.energy//25)}"

    a = state.agents[who]
    return f"S|{a.x},{a.y},{a.ori}"
| |
|
def q_get(q: Dict[str, List[float]], key: str) -> List[float]:
    """Return the mutable Q-row [Q(L), Q(F), Q(R)] for `key`, creating zeros on first use.

    Uses dict.setdefault instead of the check-then-insert pattern; callers
    mutate the returned list in place, so identity with the stored row matters.
    """
    return q.setdefault(key, [0.0, 0.0, 0.0])
| |
|
def epsilon_greedy(qvals: List[float], eps: float, r: np.random.Generator) -> int:
    """Pick an action index: uniform-random with probability eps, else greedy argmax."""
    explore = r.random() < eps
    if explore:
        return int(r.integers(0, len(qvals)))
    return int(np.argmax(qvals))
| |
|
def q_update(q: Dict[str, List[float]], key: str, a_idx: int, reward: float, next_key: str, alpha: float, gamma: float) -> Tuple[float, float, float]:
    """One tabular Q-learning backup; returns (old, target, new) for tracing."""
    row = q_get(q, key)
    next_row = q_get(q, next_key)
    prev = row[a_idx]
    td_target = reward + gamma * float(np.max(next_row))
    updated = prev + alpha * (td_target - prev)
    row[a_idx] = updated
    return prev, td_target, updated
| |
|
| | |
| | |
| | |
def heuristic_pred_action(state: WorldState) -> str:
    """Chase policy: steer toward the prey when it is seen, otherwise act randomly."""
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    if not visible(pred, prey, state.grid):
        return rng_for(state.seed, state.step, stream=11).choice(ACTIONS)
    bearing = math.degrees(math.atan2(prey.y - pred.y, prey.x - pred.x)) % 360
    # Signed heading error in (-180, 180]; turn if more than 10 degrees off.
    off = (bearing - ORI_DEG[pred.ori] + 540) % 360 - 180
    if off < -10:
        return "L"
    if off > 10:
        return "R"
    return "F"
| |
|
def heuristic_prey_action(state: WorldState) -> str:
    """Flee policy: steer away from the predator when it is seen, otherwise random."""
    prey = state.agents["Prey"]
    pred = state.agents["Predator"]
    if not visible(prey, pred, state.grid):
        return rng_for(state.seed, state.step, stream=12).choice(ACTIONS)
    bearing = math.degrees(math.atan2(pred.y - prey.y, pred.x - prey.x)) % 360
    toward = (bearing - ORI_DEG[prey.ori] + 540) % 360 - 180
    # Target the opposite heading: offset the signed difference by 180 degrees.
    away = ((toward + 180) + 540) % 360 - 180
    if away < -10:
        return "L"
    if away > 10:
        return "R"
    return "F"
| |
|
def heuristic_scout_action(state: WorldState) -> str:
    """The scout just wanders: a uniform-random action from its own RNG stream."""
    return rng_for(state.seed, state.step, stream=13).choice(ACTIONS)
| |
|
| | |
| | |
| | |
def pred_reward(state_prev: "WorldState", state_now: "WorldState") -> float:
    """Predator reward: step penalty + shaping for Manhattan distance closed, plus catch bonus."""
    cfg = state_now.cfg
    before_p = state_prev.agents["Predator"]
    before_q = state_prev.agents["Prey"]
    after_p = state_now.agents["Predator"]
    after_q = state_now.agents["Prey"]
    gap_before = abs(before_p.x - before_q.x) + abs(before_p.y - before_q.y)
    gap_after = abs(after_p.x - after_q.x) + abs(after_p.y - after_q.y)
    reward = cfg.pred_step_penalty + cfg.pred_dist_coeff * (gap_before - gap_after)
    if state_now.caught:
        reward += cfg.pred_catch_reward
    return float(reward)
| |
|
def prey_reward(state_prev: "WorldState", state_now: "WorldState", ate_food: bool) -> float:
    """Prey reward: survival bonus minus step cost, plus food bonus, big penalty when caught."""
    cfg = state_now.cfg
    total = cfg.prey_step_penalty + cfg.prey_survive_reward
    total += cfg.prey_food_reward if ate_food else 0.0
    if state_now.caught:
        total += cfg.prey_caught_penalty
    return float(total)
| |
|
| | |
| | |
| | |
| | TRACE_MAX = 400 |
| |
|
def clone_shallow(state: WorldState) -> WorldState:
    """Copy the episode-scoped parts of the world for before/after comparison.

    grid, agents, logs and branches are duplicated; cfg, the Q-tables and
    metrics are intentionally SHARED with the original — learning state must
    not fork. Used by tick() to compute reward deltas against the pre-move state.
    """
    return WorldState(
        seed=state.seed,
        step=state.step,
        grid=[row[:] for row in state.grid],  # per-row copy: rows are mutated in place
        agents={k: Agent(**asdict(v)) for k, v in state.agents.items()},
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        caught=state.caught,
        branches=dict(state.branches),
        event_log=list(state.event_log),
        trace_log=list(state.trace_log),
        cfg=state.cfg,          # shared on purpose
        q_pred=state.q_pred,    # shared on purpose
        q_prey=state.q_prey,    # shared on purpose
        metrics=state.metrics,  # shared on purpose
    )
| |
|
def check_catch(state: "WorldState") -> None:
    """Flag the episode as caught when predator and prey occupy the same tile."""
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    if (pred.x, pred.y) == (prey.x, prey.y):
        state.caught = True
        state.event_log.append(f"t={state.step}: CAUGHT.")
| |
|
def consume_food(state: WorldState) -> bool:
    """If the prey stands on FOOD, eat it: +35 energy (capped at 200), tile cleared."""
    prey = state.agents["Prey"]
    if state.grid[prey.y][prey.x] != FOOD:
        return False
    prey.energy = min(200, prey.energy + 35)
    state.grid[prey.y][prey.x] = EMPTY
    state.event_log.append(f"t={state.step}: Prey ate food (+energy).")
    return True
| |
|
def choose_action(state: WorldState, who: str, stream: int) -> Tuple[str, str, Optional[Tuple[str, int]]]:
    """Pick an action for `who`.

    Returns (action, reason, q_info) where q_info is (obs_key, action_index)
    when the choice came from a Q-table, else None (heuristic/manual paths).
    """
    cfg = state.cfg
    r = rng_for(state.seed, state.step, stream=stream)

    q_enabled = (who == "Predator" and cfg.use_q_pred) or (who == "Prey" and cfg.use_q_prey)
    if q_enabled:
        table = state.q_pred if who == "Predator" else state.q_prey
        tag = "pred" if who == "Predator" else "prey"
        k = obs_key(state, who)
        qv = q_get(table, k)
        a_idx = epsilon_greedy(qv, state.metrics.epsilon, r)
        return ACTIONS[a_idx], f"Q({tag}) eps={state.metrics.epsilon:.3f} q={np.round(qv,3).tolist()}", (k, a_idx)

    # Heuristic fallbacks per agent.
    if who == "Predator":
        return heuristic_pred_action(state), "heuristic(pred)", None
    if who == "Prey":
        return heuristic_prey_action(state), "heuristic(prey)", None
    return heuristic_scout_action(state), "heuristic(scout)", None
| |
|
def tick(state: WorldState, manual_action: Optional[str] = None) -> None:
    """Advance the world by exactly one step.

    Order of operations: snapshot the pre-move state -> choose actions (manual
    override for the controlled agent, Q-policy or heuristic for the rest) ->
    apply all three actions -> food/catch checks -> Q-learning backups ->
    append an audit trace line -> increment the step counter.
    No-op once the prey has been caught.
    """
    if state.caught:
        return

    # Pre-move copy used for reward shaping (distance deltas).
    prev = clone_shallow(state)

    # BFS-optimal predator->prey distance, recorded for the trace diagnostics.
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    opt_dist = bfs_distance(state.grid, pred.x, pred.y, prey.x, prey.y)
    if opt_dist is None:
        opt_dist = 999  # sentinel: prey unreachable this step

    chosen = {}
    reasons = {}
    qinfo = {}

    # A manual action overrides policy for the controlled agent only.
    if manual_action:
        chosen[state.controlled] = manual_action
        reasons[state.controlled] = "manual"
        qinfo[state.controlled] = None

    # Fill in the remaining agents from their policies (fixed RNG stream each).
    for who in ["Predator", "Prey", "Scout"]:
        if who in chosen:
            continue
        act, reason, q_i = choose_action(state, who, stream={"Predator":21,"Prey":22,"Scout":23}[who])
        chosen[who] = act
        reasons[who] = reason
        qinfo[who] = q_i

    # Apply all actions in a fixed order (Predator moves first).
    outcomes = {}
    for who in ["Predator", "Prey", "Scout"]:
        outcomes[who] = apply_action(state, who, chosen[who])

    ate = consume_food(state)
    check_catch(state)

    # Rewards are computed from the prev/now pair after all moves resolved.
    pred_r = pred_reward(prev, state)
    prey_r = prey_reward(prev, state, ate_food=ate)

    # Q backups, only for agents whose action came from a Q-table this step.
    q_lines = []
    if qinfo["Predator"] is not None:
        k, a_idx = qinfo["Predator"]
        nk = obs_key(state, "Predator")
        old, target, new = q_update(state.q_pred, k, a_idx, pred_r, nk, state.cfg.alpha, state.cfg.gamma)
        q_lines.append(f"Qpred: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}")

    if qinfo["Prey"] is not None:
        k, a_idx = qinfo["Prey"]
        nk = obs_key(state, "Prey")
        old, target, new = q_update(state.q_prey, k, a_idx, prey_r, nk, state.cfg.alpha, state.cfg.gamma)
        q_lines.append(f"Qprey: {k} a={ACTIONS[a_idx]} old={old:.3f} tgt={target:.3f} new={new:.3f}")

    # Human-readable audit line for this step.
    dist_now = manhattan(state.agents["Predator"], state.agents["Prey"])
    eff = (opt_dist / max(1, dist_now)) if dist_now > 0 else 1.0  # NOTE(review): currently unused
    trace = (
        f"t={state.step} optDist~{opt_dist} distNow={dist_now} "
        f"| Pred:{chosen['Predator']} ({outcomes['Predator']}) [{reasons['Predator']}] r={pred_r:+.3f} "
        f"| Prey:{chosen['Prey']} ({outcomes['Prey']}) [{reasons['Prey']}] r={prey_r:+.3f} "
        f"| Scout:{chosen['Scout']} ({outcomes['Scout']}) [{reasons['Scout']}] "
        f"| ateFood={ate} caught={state.caught}"
    )
    if q_lines:
        trace += " | " + " ; ".join(q_lines)

    state.trace_log.append(trace)
    if len(state.trace_log) > TRACE_MAX:
        state.trace_log = state.trace_log[-TRACE_MAX:]  # keep memory bounded

    state.step += 1
| |
|
| | |
| | |
| | |
def reset_episode(state: WorldState, seed: Optional[int] = None) -> None:
    """Reset the episode state (map, agents, flags, logs) in place.

    cfg, the Q-tables and metrics on `state` are intentionally left untouched,
    so training progress survives the reset. Pass `seed` to reseed the world;
    None keeps the current seed.

    Fix: removed the dead assignments that copied cfg/q/metrics onto the
    throw-away `fresh` world (including the self-assignment of
    metrics.epsilon) — `fresh` is discarded, so they had no effect.
    """
    if seed is None:
        seed = state.seed
    fresh = init_state(seed)
    # Copy only the episode-scoped fields from the fresh world.
    state.seed = fresh.seed
    state.step = 0
    state.grid = fresh.grid
    state.agents = fresh.agents
    state.controlled = fresh.controlled
    state.pov = fresh.pov
    state.overlay = fresh.overlay
    state.caught = False
    state.branches = fresh.branches
    state.event_log = ["Episode reset."]
    state.trace_log = []
| |
|
def run_episode(state: WorldState, max_steps: int) -> Tuple[bool, int, float]:
    """Run automatic ticks until a catch or max_steps; returns (caught, steps, eff).

    Efficiency relates the initial BFS-optimal predator->prey distance to the
    steps actually used; it is 0.0 when the prey started out unreachable.
    """
    pred0 = state.agents["Predator"]
    prey0 = state.agents["Prey"]
    opt = bfs_distance(state.grid, pred0.x, pred0.y, prey0.x, prey0.y)
    if opt is None:
        opt = 999  # unreachable sentinel
    steps = 0
    while not state.caught and steps < max_steps:
        tick(state, manual_action=None)
        steps += 1
    eff = float(opt / max(1, steps)) if opt < 999 else 0.0
    return state.caught, steps, eff
| |
|
def train(state: WorldState, episodes: int, max_steps: int) -> None:
    """Run `episodes` self-play episodes, updating Q-tables, epsilon and metrics.

    Each episode gets a derived seed, resets the world (keeping learning state),
    runs up to max_steps ticks, and decays epsilon once. Aggregates land in
    state.metrics and a summary line is appended to the event log.

    Fix: with episodes <= 0 the original fell through the loop and read the
    loop variables `steps`/`eff` before assignment (NameError); return early.
    """
    if episodes <= 0:
        return

    m = state.metrics
    cfg = state.cfg
    catches = 0
    total_steps_catch = 0
    total_eff = 0.0

    for ep in range(episodes):
        # Per-episode seed derived from the base seed and the global episode index.
        ep_seed = (state.seed * 1_000_003 + (m.episodes + ep) * 97_531) & 0xFFFFFFFF
        reset_episode(state, seed=int(ep_seed))

        caught, steps, eff = run_episode(state, max_steps=max_steps)
        total_eff += eff

        if caught:
            catches += 1
            total_steps_catch += steps

        # Decay exploration once per episode, never below the floor.
        m.epsilon = max(cfg.epsilon_min, m.epsilon * cfg.epsilon_decay)

    # Batch aggregates; last_* reflect the final episode of this batch.
    m.episodes += episodes
    m.catches += catches
    m.last_episode_steps = steps
    m.last_episode_eff = eff
    if catches > 0:
        # Exponential moving averages (0.85 old / 0.15 new) once seeded.
        avg_steps = total_steps_catch / catches
        m.avg_steps_to_catch = (
            0.85 * m.avg_steps_to_catch + 0.15 * avg_steps
            if m.avg_steps_to_catch > 0 else avg_steps
        )
        avg_eff = total_eff / max(1, episodes)
        m.avg_path_efficiency = (
            0.85 * m.avg_path_efficiency + 0.15 * avg_eff
            if m.avg_path_efficiency > 0 else avg_eff
        )

    state.event_log.append(
        f"Training: +{episodes} eps | catches={catches}/{episodes} | "
        f"avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f} | eps={m.epsilon:.3f}"
    )
| |
|
| | |
| | |
| | |
| | MAX_HISTORY = 1200 |
| |
|
def snapshot_of(state: WorldState) -> Snapshot:
    """Capture a lightweight, rewindable snapshot of the current world."""
    agents_copy = {name: asdict(agent) for name, agent in state.agents.items()}
    grid_copy = [list(row) for row in state.grid]
    return Snapshot(
        step=state.step,
        agents=agents_copy,
        grid=grid_copy,
        caught=state.caught,
        event_log_tail=state.event_log[-20:],
        trace_tail=state.trace_log[-40:],
    )
| |
|
def restore_into(state: WorldState, snap: Snapshot) -> None:
    """Rewind the live world to a snapshot (step, grid, agents, caught flag).

    Logs are NOT rewound; a jump marker is appended instead so the audit trail
    stays continuous.
    """
    state.step = snap.step
    state.grid = [list(row) for row in snap.grid]
    for name, fields in snap.agents.items():
        state.agents[name] = Agent(**fields)
    state.caught = snap.caught
    state.event_log.append(f"Jumped to snapshot t={snap.step}.")
| |
|
| | |
| | |
| | |
def export_run(state: WorldState, history: List[Snapshot]) -> str:
    """Serialize the session (config, metrics, Q-tables, history, map) to pretty JSON."""
    payload = dict(
        seed=state.seed,
        controlled=state.controlled,
        pov=state.pov,
        overlay=state.overlay,
        cfg=asdict(state.cfg),
        metrics=asdict(state.metrics),
        q_pred=state.q_pred,
        q_prey=state.q_prey,
        history=[asdict(snap) for snap in history],
        grid=state.grid,
    )
    return json.dumps(payload, indent=2)
| |
|
def import_run(txt: str) -> Tuple[WorldState, List[Snapshot], Dict[str, np.ndarray], int]:
    """Rebuild (state, history, beliefs, rewind_index) from export_run() JSON.

    Beliefs are not serialized, so they restart blank. When history is present,
    the state is fast-forwarded to the last snapshot. Propagates
    json.JSONDecodeError / TypeError on malformed payloads.
    """
    data = json.loads(txt)
    st = init_state(int(data.get("seed", 1337)))
    st.controlled = data.get("controlled", st.controlled)
    st.pov = data.get("pov", st.pov)
    st.overlay = bool(data.get("overlay", False))
    st.grid = data.get("grid", st.grid)

    # Dataclass round-trips; missing sections fall back to freshly-built defaults.
    st.cfg = TrainConfig(**data.get("cfg", asdict(st.cfg)))
    st.metrics = Metrics(**data.get("metrics", asdict(st.metrics)))

    # Q-tables are plain dict[str, list[float]] and survive JSON as-is.
    st.q_pred = data.get("q_pred", {})
    st.q_prey = data.get("q_prey", {})

    hist = [Snapshot(**s) for s in data.get("history", [])]
    bel = init_belief()  # beliefs rebuilt from scratch
    r_idx = max(0, len(hist) - 1)

    if hist:
        restore_into(st, hist[-1])  # jump to the most recent snapshot
    st.event_log.append("Imported run.")
    return st, hist, bel, r_idx
| |
|
| | |
| | |
| | |
def build_views(state: WorldState, beliefs: Dict[str, np.ndarray]) -> Tuple[np.ndarray, Image.Image, Image.Image, Image.Image, str, str, str]:
    """Produce every UI output for the current state.

    Side effect: refreshes each agent's belief map from its current view cone.
    Returns (pov image, truth map, controlled-agent belief, other-agent belief,
    status text, event-log tail, trace tail).
    """
    for nm, a in state.agents.items():
        update_belief_for_agent(state, beliefs[nm], a)

    pov = raycast_view(state, state.agents[state.pov])
    truth_np = np.array(state.grid, dtype=np.int16)
    truth_img = render_topdown(truth_np, state.agents, f"Truth Map — t={state.step} seed={state.seed}", show_agents=True)

    # The "other" belief panel always shows the opposing hunter/hunted agent.
    ctrl = state.controlled
    other = "Prey" if ctrl == "Predator" else "Predator"
    b_ctrl = render_topdown(beliefs[ctrl], state.agents, f"{ctrl} Belief", show_agents=True)
    b_other = render_topdown(beliefs[other], state.agents, f"{other} Belief", show_agents=True)

    m = state.metrics
    pred = state.agents["Predator"]
    prey = state.agents["Prey"]
    scout = state.agents["Scout"]

    status = (
        f"Controlled={state.controlled} | POV={state.pov} | caught={state.caught} | eps={m.epsilon:.3f}\n"
        f"Episodes={m.episodes} | catches={m.catches} | avgStepsToCatch~{m.avg_steps_to_catch:.2f} | avgEff~{m.avg_path_efficiency:.2f}\n"
        f"Pred({pred.x},{pred.y}) o={pred.ori} | Prey({prey.x},{prey.y}) o={prey.ori} e={prey.energy} | Scout({scout.x},{scout.y}) o={scout.ori}"
    )
    events = "\n".join(state.event_log[-18:])
    trace = "\n".join(state.trace_log[-18:])
    return pov, truth_img, b_ctrl, b_other, status, events, trace
| |
|
def grid_click_to_tile(evt: gr.SelectData, selected_tile: int, state: WorldState) -> WorldState:
    """Paint the clicked truth-map tile with the currently selected tile type.

    Clicks inside the 28px title strip, outside the grid, or on the border
    walls are ignored.
    """
    x_px, y_px = evt.index
    y_px -= 28  # title strip sits above the grid
    if y_px < 0:
        return state
    gx, gy = int(x_px // TILE), int(y_px // TILE)
    if not in_bounds(gx, gy):
        return state
    on_border = gx in (0, GRID_W - 1) or gy in (0, GRID_H - 1)
    if on_border:
        return state
    state.grid[gy][gx] = selected_tile
    state.event_log.append(f"t={state.step}: Tile ({gx},{gy}) -> {TILE_NAMES.get(selected_tile)}")
    return state
| |
|
| | |
| | |
| | |
| | with gr.Blocks(title="Agent POV") as demo: |
| | gr.Markdown( |
| | "## Agent-POV by ZEN AI Co.\n" |
| | "Track every interaction, train policies, and audit why outcomes happened.\n" |
| | "No timers (compatibility). Use Tick/Run/Train for controlled experiments." |
| | ) |
| |
|
| | st = gr.State(init_state(1337)) |
| | history = gr.State([snapshot_of(init_state(1337))]) |
| | beliefs = gr.State(init_belief()) |
| | rewind_idx = gr.State(0) |
| |
|
| | with gr.Row(): |
| | pov_img = gr.Image(label="POV (Pseudo-3D)", type="numpy", width=VIEW_W, height=VIEW_H) |
| | with gr.Column(): |
| | status = gr.Textbox(label="Status + Metrics", lines=4) |
| | events = gr.Textbox(label="Event Log", lines=10) |
| | trace = gr.Textbox(label="Step Trace (why it happened)", lines=10) |
| |
|
| | with gr.Row(): |
| | truth = gr.Image(label="Truth Map (click to edit tiles)", type="pil") |
| | belief_a = gr.Image(label="Belief (Controlled)", type="pil") |
| | belief_b = gr.Image(label="Belief (Other)", type="pil") |
| |
|
| | with gr.Row(): |
| | with gr.Column(scale=2): |
| | gr.Markdown("### Manual Controls") |
# --- Manual action buttons (turn Left / step Forward / turn Right) ---
with gr.Row():
    btn_L = gr.Button("L")
    btn_F = gr.Button("F")
    btn_R = gr.Button("R")
# --- Simulation stepping: single tick or a bounded autonomous run ---
with gr.Row():
    btn_tick = gr.Button("Tick")
    run_steps = gr.Number(value=25, label="Run N steps", precision=0)
    btn_run = gr.Button("Run")
# --- View/agent toggles and POV reticle overlay ---
with gr.Row():
    btn_toggle_control = gr.Button("Toggle Controlled")
    btn_toggle_pov = gr.Button("Toggle POV")
    overlay = gr.Checkbox(False, label="Overlay reticle")

# Tile palette for click-painting the truth map; choices are (label, value)
# pairs so the handler receives the integer tile constant directly.
tile_pick = gr.Radio(
    choices=[(TILE_NAMES[k], k) for k in [EMPTY, WALL, FOOD, NOISE, DOOR, TELE]],
    value=WALL,
    label="Paint tile type"
)

# --- Q-learning hyper-parameters and training actions ---
with gr.Column(scale=3):
    gr.Markdown("### Training Controls (Q-learning)")
    use_q_pred = gr.Checkbox(True, label="Use Q-learning: Predator")
    use_q_prey = gr.Checkbox(True, label="Use Q-learning: Prey")
    alpha = gr.Slider(0.01, 0.5, value=0.15, step=0.01, label="alpha (learn rate)")
    gamma = gr.Slider(0.5, 0.99, value=0.95, step=0.01, label="gamma (discount)")
    eps = gr.Slider(0.0, 0.5, value=0.10, step=0.01, label="epsilon (exploration)")
    eps_decay = gr.Slider(0.90, 0.999, value=0.995, step=0.001, label="epsilon decay")
    eps_min = gr.Slider(0.0, 0.2, value=0.02, step=0.01, label="epsilon min")

    episodes = gr.Number(value=50, label="Train episodes", precision=0)
    max_steps = gr.Number(value=250, label="Max steps per episode", precision=0)
    btn_train = gr.Button("Train")

    btn_reset = gr.Button("Reset Episode")
    btn_reset_all = gr.Button("Reset ALL (wipe Q + metrics)")

# --- History rewind and JSON import/export ---
with gr.Row():
    with gr.Column():
        # Slider maximum is updated dynamically by refresh() as history grows.
        rewind = gr.Slider(0, 0, value=0, step=1, label="Rewind (history index)")
        btn_jump = gr.Button("Jump")
    with gr.Column():
        export_box = gr.Textbox(label="Export JSON", lines=10)
        btn_export = gr.Button("Export")
    with gr.Column():
        import_box = gr.Textbox(label="Import JSON", lines=10)
        btn_import = gr.Button("Import")
|
def refresh(state: WorldState, hist: List[Snapshot], bel: Dict[str, np.ndarray], r: int):
    """Re-render all view panels and clamp the rewind index to the history.

    Returns the seven rendered outputs, a slider update (new maximum and
    clamped value), and the clamped index itself.
    """
    last = max(0, len(hist) - 1)
    idx = min(max(int(r), 0), last)
    pov, truth_img, bel_a, bel_b, status_txt, events_txt, trace_txt = build_views(state, bel)
    slider_upd = gr.update(maximum=last, value=idx)
    return (pov, truth_img, bel_a, bel_b, status_txt, events_txt, trace_txt, slider_upd, idx)
| |
|
def push_hist(state: WorldState, hist: List[Snapshot]) -> List[Snapshot]:
    """Append a snapshot of the current state; drop the oldest entry when
    the history would exceed MAX_HISTORY. Mutates and returns `hist`."""
    hist.append(snapshot_of(state))
    if len(hist) > MAX_HISTORY:
        del hist[0]
    return hist
| |
|
def set_cfg(state: WorldState, uq_pred: bool, uq_prey: bool, a: float, g: float, e: float, ed: float, emin: float):
    """Copy the UI training controls into the world state's config.

    NOTE(review): the epsilon slider is written to state.metrics.epsilon
    rather than state.cfg.epsilon — presumably because the live (decaying)
    epsilon is tracked in metrics while cfg.epsilon is only the initial
    value. Confirm that the action-selection code reads metrics.epsilon.
    """
    cfg = state.cfg
    cfg.use_q_pred = bool(uq_pred)
    cfg.use_q_prey = bool(uq_prey)
    cfg.alpha = float(a)
    cfg.gamma = float(g)
    cfg.epsilon_decay = float(ed)
    cfg.epsilon_min = float(emin)
    state.metrics.epsilon = float(e)
    return state
| |
|
def do_manual(state, hist, bel, r, act):
    """Apply one manually chosen action ('L'/'F'/'R'), snapshot the result,
    and re-render with the rewind index pinned to the newest entry."""
    tick(state, manual_action=act)
    push_hist(state, hist)
    r = len(hist) - 1
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def do_tick(state, hist, bel, r):
    """Advance the simulation one autonomous step (no manual action).

    Delegates to do_manual with act=None: the previous body duplicated
    do_manual's tick/snapshot/refresh sequence line for line, and
    tick(state, manual_action=None) is exactly what that delegation does.
    """
    return do_manual(state, hist, bel, r, None)
| |
|
def do_run(state, hist, bel, r, n):
    """Run up to `n` autonomous ticks, snapshotting each, stopping early
    once the prey is caught.

    Bug fix: `n` arrives as None when the Gradio Number box is cleared,
    and the old `int(n)` raised TypeError; fall back to a single step.
    """
    steps = max(1, int(n)) if n is not None else 1
    for _ in range(steps):
        if state.caught:
            break
        tick(state, manual_action=None)
        hist = push_hist(state, hist)
    r = len(hist) - 1
    out = refresh(state, hist, bel, r)
    return out + (state, hist, bel, r)
| |
|
def toggle_control(state, hist, bel, r):
    """Cycle the manually controlled agent: Predator -> Prey -> Scout -> ..."""
    cycle = ["Predator", "Prey", "Scout"]
    nxt = cycle[(cycle.index(state.controlled) + 1) % len(cycle)]
    state.controlled = nxt
    state.event_log.append(f"Controlled -> {nxt}")
    push_hist(state, hist)
    r = len(hist) - 1
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def toggle_pov(state, hist, bel, r):
    """Cycle which agent's first-person view is rendered."""
    cycle = ["Predator", "Prey", "Scout"]
    nxt = cycle[(cycle.index(state.pov) + 1) % len(cycle)]
    state.pov = nxt
    state.event_log.append(f"POV -> {nxt}")
    push_hist(state, hist)
    r = len(hist) - 1
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def set_overlay(state, hist, bel, r, ov):
    """Persist the reticle-overlay checkbox and re-render.

    Deliberately does NOT push a history snapshot — a display toggle is
    not a simulation event.
    """
    state.overlay = bool(ov)
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def click_truth(tile, state, hist, bel, r, evt: gr.SelectData):
    """Paint the selected tile type at the clicked truth-map cell.

    The `evt: gr.SelectData` annotation is required — Gradio injects the
    click coordinates through it.
    """
    state = grid_click_to_tile(evt, int(tile), state)
    push_hist(state, hist)
    r = len(hist) - 1
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def jump(state, hist, bel, r, idx):
    """Restore the snapshot at history index `idx` (clamped into range).

    With an empty history this is a pure re-render — nothing to restore.
    """
    if hist:
        target = min(max(int(idx), 0), len(hist) - 1)
        restore_into(state, hist[target])
        r = target
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def reset_ep(state, hist, bel, r):
    """Restart the current episode (same seed), wiping history and beliefs
    but keeping learned Q-tables and metrics."""
    reset_episode(state, seed=state.seed)
    hist = [snapshot_of(state)]
    bel = init_belief()
    r = 0
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def reset_all(state, hist, bel, r):
    """Rebuild the entire world state from scratch (Q-tables and metrics
    included), preserving only the seed."""
    state = init_state(state.seed)
    hist = [snapshot_of(state)]
    bel = init_belief()
    r = 0
    return refresh(state, hist, bel, r) + (state, hist, bel, r)
| |
|
def do_train(state, hist, bel, r,
             uq_pred, uq_prey, a, g, e, ed, emin,
             eps_count, max_s):
    """Apply the UI hyper-parameters, run Q-learning training, then reset
    the episode so the freshly trained policies start clean.

    Bug fix: `eps_count` / `max_s` arrive as None when the Number boxes
    are cleared, and the old `int(...)` raised TypeError; fall back to
    their respective minimums (1 episode, 10 steps).
    """
    state = set_cfg(state, uq_pred, uq_prey, a, g, e, ed, emin)
    n_episodes = max(1, int(eps_count)) if eps_count is not None else 1
    n_steps = max(10, int(max_s)) if max_s is not None else 10
    train(state, episodes=n_episodes, max_steps=n_steps)

    reset_episode(state, seed=state.seed)
    hist = [snapshot_of(state)]
    bel = init_belief()
    r = 0
    out = refresh(state, hist, bel, r)
    return out + (state, hist, bel, r)
| |
|
def export_fn(state, hist):
    """Serialize the current run (state + history) to JSON for the export box."""
    return export_run(state, hist)
| |
|
def import_fn(txt):
    """Parse an exported JSON run and rebuild every output panel from it."""
    state, hist, bel, r = import_run(txt)
    pov, truth_img, bel_a, bel_b, status_txt, events_txt, trace_txt = build_views(state, bel)
    slider_upd = gr.update(maximum=max(0, len(hist) - 1), value=r)
    return (
        pov, truth_img, bel_a, bel_b, status_txt, events_txt, trace_txt,
        slider_upd,
        state, hist, bel, r
    )
| |
|
| | |
| | btn_L.click(lambda s,h,b,r: do_manual(s,h,b,r,"L"), |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_F.click(lambda s,h,b,r: do_manual(s,h,b,r,"F"), |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_R.click(lambda s,h,b,r: do_manual(s,h,b,r,"R"), |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_tick.click(do_tick, |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_run.click(do_run, |
| | inputs=[st, history, beliefs, rewind_idx, run_steps], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_toggle_control.click(toggle_control, |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_toggle_pov.click(toggle_pov, |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | overlay.change(set_overlay, |
| | inputs=[st, history, beliefs, rewind_idx, overlay], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | truth.select(click_truth, |
| | inputs=[tile_pick, st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_jump.click(jump, |
| | inputs=[st, history, beliefs, rewind_idx, rewind], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_reset.click(reset_ep, |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_reset_all.click(reset_all, |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_train.click(do_train, |
| | inputs=[st, history, beliefs, rewind_idx, |
| | use_q_pred, use_q_prey, alpha, gamma, eps, eps_decay, eps_min, |
| | episodes, max_steps], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | btn_export.click(export_fn, inputs=[st, history], outputs=[export_box], queue=True) |
| |
|
| | btn_import.click(import_fn, |
| | inputs=[import_box], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, st, history, beliefs, rewind_idx], |
| | queue=True) |
| |
|
| | demo.load(refresh, |
| | inputs=[st, history, beliefs, rewind_idx], |
| | outputs=[pov_img, truth, belief_a, belief_b, status, events, trace, rewind, rewind_idx], |
| | queue=True) |
| |
|
| | demo.queue().launch() |
| |
|