| """ | |
| World Model Demo - Interactive AI Planning Visualization | |
| Educational demonstration of model-based reinforcement learning concepts | |
| """ | |
| import gradio as gr | |
| import random | |
| import time | |

# ============================================================================
# World Model Core Classes
# ============================================================================

class GridWorld:
    """Simple grid environment for world model demonstration"""

    def __init__(self, size=6):
        self.size = size
        self.reset()

    def reset(self):
        self.agent_pos = [0, 0]
        self.goal_pos = [self.size - 1, self.size - 1]
        self.obstacles = self._generate_obstacles()
        self.steps = 0
        return self._get_state()

    def _generate_obstacles(self):
        obstacles = set()
        num_obstacles = self.size - 1
        attempts = 0
        while len(obstacles) < num_obstacles and attempts < 100:
            x, y = random.randint(0, self.size - 1), random.randint(0, self.size - 1)
            if [x, y] != self.agent_pos and [x, y] != self.goal_pos:
                # Keep the start's two neighbouring cells open so the agent
                # always has a legal first move
                if not (x == 0 and y == 1) and not (x == 1 and y == 0):
                    obstacles.add((x, y))
            attempts += 1
        return obstacles
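
    # Caveat: the check above only protects the agent's first move; it does
    # not guarantee the goal stays reachable, so a reset can occasionally
    # produce an unsolvable layout.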

    def _get_state(self):
        return {
            'agent': self.agent_pos.copy(),
            'goal': self.goal_pos,
            'obstacles': list(self.obstacles),
            'size': self.size,
            'steps': self.steps
        }

    def step(self, action):
        moves = {'up': (0, -1), 'down': (0, 1), 'left': (-1, 0), 'right': (1, 0)}
        dx, dy = moves.get(action, (0, 0))
        new_x = max(0, min(self.size - 1, self.agent_pos[0] + dx))
        new_y = max(0, min(self.size - 1, self.agent_pos[1] + dy))
        if (new_x, new_y) not in self.obstacles:
            self.agent_pos = [new_x, new_y]
        self.steps += 1
        done = self.agent_pos == self.goal_pos
        return self._get_state(), done
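
    # Movement note: moves are clamped to the grid, and bumping an obstacle
    # leaves the agent in place; in either case the step counter still
    # advances.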

    def copy(self):
        new_world = GridWorld(self.size)
        new_world.agent_pos = self.agent_pos.copy()
        new_world.goal_pos = self.goal_pos.copy()
        new_world.obstacles = self.obstacles.copy()
        new_world.steps = self.steps
        return new_world
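
# Illustrative usage of GridWorld (not executed by the app):
#   env = GridWorld(6)
#   state = env.reset()              # agent at [0, 0], goal at [5, 5]
#   state, done = env.step('right')  # agent at [1, 0] (this cell is kept clear)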

class WorldModelAgent:
    """Agent that uses a world model to plan ahead"""

    def __init__(self):
        self.imagination_steps = []
        self.action_values = {}

    def imagine_action(self, world, action):
        """Use world model to predict outcome without actually taking action"""
        imagined_world = world.copy()
        imagined_state, done = imagined_world.step(action)
        return imagined_state, done, imagined_world
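
    # Note: here the "world model" is an exact copy of the real environment,
    # so imagined outcomes are always correct. Learned world models (e.g. the
    # MuZero and Dreamer systems mentioned in the UI below) replace this copy
    # with a trained dynamics model whose predictions can be wrong.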

    def evaluate_position(self, pos, goal):
        """Simple heuristic: negative Manhattan distance to goal"""
        return -(abs(pos[0] - goal[0]) + abs(pos[1] - goal[1]))
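
    # Worked example: with the agent at [2, 3] and the goal at [5, 5], this
    # returns -(|2-5| + |3-5|) = -5; scores closer to 0 mean the agent is
    # nearer the goal.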

    def plan(self, world):
        """
        Plan ahead by imagining future states.
        This is what makes world models special - we can "think" before acting.
        """
        self.imagination_steps = []
        self.action_values = {}
        actions = ['up', 'down', 'left', 'right']

        for action in actions:
            # Imagine taking this action
            imagined_state, done, imagined_world = self.imagine_action(world, action)

            # Record what we imagined
            self.imagination_steps.append({
                'action': action,
                'predicted_pos': imagined_state['agent'].copy(),
                'depth': 1
            })

            if done:
                # This action reaches the goal directly
                self.action_values[action] = 100
                continue

            # Score the imagined position, then look one step further ahead
            value = self.evaluate_position(imagined_state['agent'], imagined_state['goal'])

            best_future_value = float('-inf')
            for next_action in actions:
                future_state, future_done, _ = self.imagine_action(imagined_world, next_action)
                self.imagination_steps.append({
                    'action': f"{action}→{next_action}",
                    'predicted_pos': future_state['agent'].copy(),
                    'depth': 2
                })
                if future_done:
                    best_future_value = 100
                    break
                future_value = self.evaluate_position(future_state['agent'], future_state['goal'])
                best_future_value = max(best_future_value, future_value)

            self.action_values[action] = value + 0.9 * best_future_value

        # Return the action with the highest imagined value
        best_action = max(self.action_values, key=self.action_values.get)
        return best_action, self.action_values, self.imagination_steps
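
# Scoring note: each candidate first action is scored as
#     value(state after 1 step) + 0.9 * best value(state after 2 steps)
# i.e. a depth-2 lookahead with a 0.9 discount factor, so immediate progress
# counts slightly more than imagined later progress; an imagined goal state
# short-circuits the search with a fixed bonus of 100.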

# ============================================================================
# Visualization
# ============================================================================

def render_grid(state, phase="observe", imagined_positions=None):
    """Render the grid as HTML"""
    agent = state['agent']
    goal = state['goal']
    obstacles = set(tuple(o) if isinstance(o, list) else o for o in state['obstacles'])
    size = state['size']

    phase_info = {
        'observe': ('🔍 OBSERVE', '#3b82f6', 'Perceiving current state...'),
        'imagine': ('💭 IMAGINE', '#f59e0b', 'Simulating possible futures...'),
        'evaluate': ('⚖️ EVALUATE', '#8b5cf6', 'Scoring each path...'),
        'act': ('⚡ ACT', '#10b981', 'Executing best action!'),
    }
    phase_name, phase_color, phase_desc = phase_info.get(phase, ('', '#6b7280', ''))

    html = f'''
    <div style="text-align: center; font-family: system-ui, sans-serif;">
      <div style="display: inline-block; background: linear-gradient(135deg, #1e293b 0%, #0f172a 100%);
                  padding: 24px; border-radius: 16px; box-shadow: 0 8px 32px rgba(0,0,0,0.4);">
        <div style="margin-bottom: 8px; color: {phase_color}; font-weight: bold; font-size: 22px;
                    text-shadow: 0 0 20px {phase_color}40;">
          {phase_name}
        </div>
        <div style="margin-bottom: 16px; color: #94a3b8; font-size: 14px;">
          {phase_desc}
        </div>
        <table style="border-collapse: collapse; margin: auto; border-radius: 8px; overflow: hidden;">
    '''

    # Convert imagined positions to a set for fast lookup
    imagined_set = set()
    if imagined_positions:
        for pos in imagined_positions:
            imagined_set.add(tuple(pos))

    for y in range(size):
        html += '<tr>'
        for x in range(size):
            bg = '#334155'
            content = ''
            border = '2px solid #475569'
            if (x, y) in obstacles:
                bg = '#991b1b'
                content = '🧱'
            elif [x, y] == goal:
                bg = '#166534'
                content = '⭐'
            elif [x, y] == agent:
                bg = '#1d4ed8'
                content = '🤖'
            elif (x, y) in imagined_set:
                # Show imagined positions as ghost agents
                bg = '#475569'
                content = '👻'
                border = f'2px dashed {phase_color}'
            html += f'''
            <td style="width: 50px; height: 50px; background: {bg};
                       border: {border}; text-align: center; font-size: 24px;
                       transition: all 0.3s ease;">
                {content}
            </td>
            '''
        html += '</tr>'

    html += '''
        </table>
        <div style="margin-top: 16px; color: #64748b; font-size: 13px;">
          🤖 Agent | ⭐ Goal | 🧱 Wall | 👻 Imagined Position
        </div>
      </div>
    </div>
    '''
    return html

def render_thinking(action_values, best_action):
    """Render the agent's thinking process"""
    if not action_values:
        return "<div style='color: #64748b; text-align: center; padding: 20px;'>Click 'Think & Move' to see the agent plan!</div>"

    html = '''
    <div style="font-family: system-ui, sans-serif; padding: 16px; background: #1e293b; border-radius: 12px;">
      <h3 style="color: #f59e0b; margin-top: 0;">🧠 Agent's Reasoning</h3>
      <p style="color: #94a3b8; font-size: 14px;">The agent imagined taking each action and predicted the outcomes:</p>
      <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 12px; margin-top: 12px;">
    '''

    action_symbols = {'up': '⬆️', 'down': '⬇️', 'left': '⬅️', 'right': '➡️'}
    # List actions from best to worst imagined score
    for action, value in sorted(action_values.items(), key=lambda x: -x[1]):
        is_best = action == best_action
        border_color = '#10b981' if is_best else '#475569'
        bg = '#064e3b' if is_best else '#334155'
        label = ' ✓ BEST' if is_best else ''
        html += f'''
        <div style="background: {bg}; border: 2px solid {border_color}; border-radius: 8px; padding: 12px; text-align: center;">
          <div style="font-size: 24px;">{action_symbols.get(action, '?')}</div>
          <div style="color: #e2e8f0; font-weight: bold; margin-top: 4px;">{action.upper()}{label}</div>
          <div style="color: #94a3b8; font-size: 13px;">Score: {value:.1f}</div>
        </div>
        '''

    html += '''
      </div>
      <div style="margin-top: 16px; padding: 12px; background: #0f172a; border-radius: 8px;">
        <div style="color: #10b981; font-weight: bold;">💡 Why this works:</div>
        <div style="color: #94a3b8; font-size: 13px; margin-top: 8px;">
          The agent <b>imagined</b> each possible action, <b>predicted</b> where it would end up,
          and <b>evaluated</b> how close that gets to the goal. It can even imagine 2 steps ahead!
          <br><br>
          This is different from trial-and-error learning — the agent "thinks" before acting.
        </div>
      </div>
    </div>
    '''
    return html

# ============================================================================
# Global State
# ============================================================================

world = GridWorld(6)
agent = WorldModelAgent()
current_state = world.reset()


def reset_game():
    global world, agent, current_state
    world = GridWorld(6)
    agent = WorldModelAgent()
    current_state = world.reset()
    grid_html = render_grid(current_state, phase="observe")
    thinking_html = "<div style='color: #64748b; text-align: center; padding: 20px;'>Click <b>'Think & Move'</b> to watch the agent plan!</div>"
    status = "🔄 New environment! Click 'Think & Move' to see the world model in action."
    return grid_html, thinking_html, status
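
# Note: module-level globals like these are shared across all visitors of a
# hosted Space; per-session state would need gr.State instead. They are kept
# global here for simplicity.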

def think_and_move():
    """Main function: the agent plans with its world model, then acts.

    Implemented as a generator so the UI can show the imagination phase
    (with ghost agents) before the final move is rendered.
    """
    global current_state, world, agent

    # If we are already at the goal, start a fresh puzzle
    if current_state['agent'] == current_state['goal']:
        yield reset_game()
        return

    # Phase 1: Observe (current_state already holds the observation)
    # Phase 2: Imagine & Evaluate - plan using the world model
    best_action, action_values, imagination_steps = agent.plan(world)
    imagined_positions = [step['predicted_pos'] for step in imagination_steps if step['depth'] == 1]
    thinking_html = render_thinking(action_values, best_action)

    # Show the imagination phase briefly before acting
    grid_html = render_grid(current_state, phase="imagine", imagined_positions=imagined_positions)
    yield grid_html, thinking_html, "💭 Imagining possible futures..."
    time.sleep(1.0)

    # Phase 3: Act - execute the best action in the real environment
    current_state, done = world.step(best_action)
    grid_html = render_grid(current_state, phase="observe" if done else "act")

    if done:
        status = f"🎉 Goal reached in {current_state['steps']} steps! Click 'Reset' for a new puzzle."
    else:
        status = f"Step {current_state['steps']}: Chose {best_action.upper()} (score: {action_values[best_action]:.1f})"
    yield grid_html, thinking_html, status

def manual_move(action):
    """Let user move manually to compare with agent"""
    global current_state, world
    if current_state['agent'] == current_state['goal']:
        return reset_game()
    current_state, done = world.step(action)
    grid_html = render_grid(current_state, phase="observe")
    thinking_html = "<div style='color: #64748b; text-align: center; padding: 20px;'>You moved manually. Click 'Think & Move' to see how the agent would plan!</div>"
    if done:
        status = f"🎉 You reached the goal in {current_state['steps']} steps!"
    else:
        status = f"You moved {action}. Steps: {current_state['steps']}"
    return grid_html, thinking_html, status

# ============================================================================
# Gradio Interface
# ============================================================================

with gr.Blocks(title="World Model Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🧠 World Model Demo

    **Watch an AI agent "think" before it acts!**

    Unlike reactive AI that just responds to inputs, this agent uses a **world model** to:

    1. **Imagine** what would happen if it took each action
    2. **Evaluate** which imagined future is best
    3. **Act** based on its mental simulation

    👉 **Click "Think & Move"** to watch the agent plan its path to the ⭐ goal!
    """)

    with gr.Row():
        with gr.Column(scale=3):
            grid_display = gr.HTML()
            status_display = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            thinking_display = gr.HTML()
            gr.Markdown("### 🎮 Controls")
            think_btn = gr.Button("🧠 Think & Move", variant="primary", size="lg")
            reset_btn = gr.Button("🔄 Reset", variant="secondary")
            gr.Markdown("---")
            gr.Markdown("**Manual controls** (to compare with the agent):")
            with gr.Row():
                up_btn = gr.Button("⬆️")
            with gr.Row():
                left_btn = gr.Button("⬅️")
                down_btn = gr.Button("⬇️")
                right_btn = gr.Button("➡️")

    with gr.Accordion("📖 What makes this different from ChatGPT/Claude?", open=False):
        gr.Markdown("""
        | Aspect | Language Model (GPT, Claude) | World Model (This Demo) |
        |--------|------------------------------|-------------------------|
        | **Predicts** | Next *word* in text | Next *state* given action |
        | **"Thinking"** | Generates plausible text | Simulates physical outcomes |
        | **Planning** | Implicit (chain-of-thought) | Explicit (tree search) |

        **The key insight:** This agent can "imagine" taking actions and see the results
        *before* committing to them in the real world. It's like planning your route
        on a map before driving.

        **Real examples:** MuZero (mastered Chess and Go without being given the rules),
        Dreamer (robot control), IRIS (Atari games)
        """)

    with gr.Accordion("🔬 Why does this matter for AI Safety?", open=False):
        gr.Markdown("""
        World models are important for AI safety because:

        - **Predictability**: We can inspect what futures the agent is considering
        - **Interpretability**: The agent's "reasoning" is explicit, not hidden
        - **Control**: We can verify the agent isn't planning harmful actions
        - **Corrigibility**: Planning agents can incorporate constraints like "avoid irreversible actions"

        Understanding how AI systems model the world helps us build systems we can trust.
        """)

    # Connect buttons
    think_btn.click(think_and_move, outputs=[grid_display, thinking_display, status_display])
    reset_btn.click(reset_game, outputs=[grid_display, thinking_display, status_display])
    up_btn.click(lambda: manual_move("up"), outputs=[grid_display, thinking_display, status_display])
    down_btn.click(lambda: manual_move("down"), outputs=[grid_display, thinking_display, status_display])
    left_btn.click(lambda: manual_move("left"), outputs=[grid_display, thinking_display, status_display])
    right_btn.click(lambda: manual_move("right"), outputs=[grid_display, thinking_display, status_display])

    # Initialize the display on page load
    demo.load(reset_game, outputs=[grid_display, thinking_display, status_display])

if __name__ == "__main__":
    demo.queue()  # generator handlers need the queue enabled on older Gradio versions
    demo.launch()
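
# To run locally (assuming this file is saved as app.py and Gradio is
# installed, e.g. `pip install gradio`):
#   python app.py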