#!/usr/bin/env python3
"""
Play any Atari game using a Vision-Language Model via the Hugging Face Router API.

The script:
1. Connects to a hosted Atari environment for the selected game (by default the
   Hugging Face Space ``https://burtenshaw-{game}.hf.space``; override it with the
   ``ATARI_ENV_BASE_URL`` environment variable, which may contain ``{game}``)
2. Sends recent screen frames to a vision-language model
3. Parses the model's integer response into an Atari action id
4. Streams frames and a bounded per-step action log to a Gradio UI, ending
   with a short summary (total reward and step count)

Notes:
- Frames get no overlays or cropping; they are only upscaled with
  nearest-neighbour resampling (``MODEL_SCALE`` for the model,
  ``DISPLAY_SCALE`` for the UI)
- The model receives the legal action ids each step and must return one integer

Usage:
    export API_KEY=your_hf_token_here   # HF_TOKEN is accepted as a fallback
    python examples/atari_pong_inference.py
    # then choose the game, model and prompt in the web UI and click Start
"""

import base64
import os
import re
from collections import deque
from contextlib import suppress
from io import BytesIO
from typing import Deque, List, Optional

import gradio as gr
import numpy as np
from PIL import Image
from openai import OpenAI

from envs.atari_env import AtariEnv, AtariAction

# API Configuration
# Hugging Face Router endpoint (OpenAI-compatible chat completions API).
API_BASE_URL = "https://router.huggingface.co/v1"
# Hugging Face access token; API_KEY takes precedence, HF_TOKEN is the fallback.
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
# Optional: URL of a remote Atari env. May contain "{game}", which is replaced
# by the selected game name. When unset, DEFAULT_ENV_URL_TEMPLATE is used.
ATARI_ENV_BASE_URL = os.getenv("ATARI_ENV_BASE_URL")
DEFAULT_ENV_URL_TEMPLATE = "https://burtenshaw-{game}.hf.space"

# Vision-Language Model (Hugging Face Router compatible)
MODEL = "Qwen/Qwen3-VL-8B-Instruct:novita"

# Configuration
TEMPERATURE = 0.7
MAX_STEPS_PER_GAME = 10000
MAX_TOKENS = 16
VERBOSE = True
FRAME_HISTORY_LENGTH = 4
DISPLAY_SCALE = 3  # Scale factor for enlarging frames sent to UI
MODEL_SCALE = 3  # Scale factor for enlarging frames sent to the model
LOG_MAX_LINES = 200  # Most recent per-step log lines kept for the UI textbox

# Generic game prompt for the vision model
VISION_PROMPT = (
    "You are playing an Atari-style game. You will be given recent frames "
    "and the list of legal action ids for the current step. "
    "Respond with a single integer that is exactly one of the legal action ids. "
    "Do not include any words or punctuation — only the integer."
)

ACTIONS_LOOKUP = {
    0: "NOOP",
    1: "FIRE",
    2: "UP",
    3: "RIGHT",
    4: "LEFT",
    5: "DOWN",
    6: "UPRIGHT",
    7: "UPLEFT",
    8: "DOWNRIGHT",
    9: "DOWNLEFT",
    10: "UPFIRE",
    11: "RIGHTFIRE",
    12: "LEFTFIRE",
    13: "DOWNFIRE",
    14: "UPRIGHTFIRE",
    15: "UPLEFTFIRE",
    16: "DOWNRIGHTFIRE",
    17: "DOWNLEFTFIRE",
}

# Any run of digits delimited by word boundaries (multi-digit ids included).
_ACTION_NUMBER_RE = re.compile(r"\b\d+\b")


def resolve_env_url(game: str) -> str:
    """Return the Atari env URL for *game*.

    Uses ``ATARI_ENV_BASE_URL`` when set, otherwise the per-game Hugging Face
    Space. A literal ``{game}`` in the URL is replaced by the game name;
    ``str.replace`` is used instead of ``str.format`` so other braces in a URL
    cannot raise.
    """
    template = ATARI_ENV_BASE_URL or DEFAULT_ENV_URL_TEMPLATE
    return template.replace("{game}", game)


def _screen_to_image(screen: List[int], screen_shape: List[int], scale) -> Image.Image:
    """Rebuild a PIL image from a flattened uint8 screen and upscale it by *scale*.

    The PIL mode is inferred from the array (2-D -> "L", HxWx3 -> "RGB").
    Upscaling uses nearest-neighbour to keep pixel edges crisp; sizes are
    rounded to ints so a fractional scale also works. A scale <= 1 (or falsy)
    leaves the image at its native size.
    """
    array = np.array(screen, dtype=np.uint8).reshape(screen_shape)
    image = Image.fromarray(array)
    if scale and scale > 1:
        new_size = (
            max(1, round(image.width * scale)),
            max(1, round(image.height * scale)),
        )
        image = image.resize(new_size, Image.NEAREST)
    return image


def screen_to_base64(screen: List[int], screen_shape: List[int]) -> str:
    """Encode a flattened screen as a base64 PNG, upscaled by ``MODEL_SCALE``."""
    image = _screen_to_image(screen, screen_shape, MODEL_SCALE)
    buffer = BytesIO()
    image.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


def screen_to_numpy(screen: List[int], screen_shape: List[int]) -> np.ndarray:
    """Convert a flattened screen to an RGB array for ``gr.Image``, upscaled by ``DISPLAY_SCALE``."""
    image = _screen_to_image(screen, screen_shape, DISPLAY_SCALE)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return np.array(image)


def content_text(text: str) -> dict:
    """Build an OpenAI-style text content part."""
    return {"type": "text", "text": text}


def content_image_b64(b64_png: str) -> dict:
    """Build an OpenAI-style image content part from a base64 PNG (data URL)."""
    return {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{b64_png}"}}


def build_messages(
    prompt: str,
    frame_history_b64: Deque[str],
    current_b64: str,
    legal_actions: List[int],
) -> List[dict]:
    """Assemble the chat messages for one decision step.

    *frame_history_b64* is expected to END with the current frame (the caller
    appends it before calling), so every earlier entry is a past frame. Past
    frames are labelled by their offset from the current frame: ``Frame -1``
    is the frame immediately before the current one.
    """
    messages: List[dict] = [{"role": "system", "content": [content_text(prompt)]}]

    # Drop the trailing entry: it is the current frame, sent separately below.
    past_frames = list(frame_history_b64)[:-1]
    oldest_offset = len(past_frames)
    for idx, frame_b64 in enumerate(past_frames):
        messages.append({
            "role": "user",
            "content": [
                content_text(f"Frame -{oldest_offset - idx}"),
                content_image_b64(frame_b64),
            ],
        })

    messages.append({
        "role": "user",
        "content": [content_text("Current frame:"), content_image_b64(current_b64)],
    })

    # Include mapping of action ids to human-readable names for the model
    action_pairs = ", ".join(f"{aid}:{ACTIONS_LOOKUP.get(aid, 'UNK')}" for aid in legal_actions)
    messages.append({
        "role": "user",
        "content": [content_text(f"Legal actions (id:name): {action_pairs}. Respond with exactly one INTEGER id.")],
    })
    return messages


def _default_action(legal_actions: List[int]) -> int:
    """Fallback action: NOOP (0) if legal or if no legal actions are known, else the first legal id."""
    if not legal_actions or 0 in legal_actions:
        return 0
    return legal_actions[0]


def parse_action(text: str, legal_actions: List[int]) -> int:
    """
    Parse action from model output.

    Handles chain-of-thought style output by taking the LAST integer in the
    text that is also a legal action id.

    Args:
        text: Model's text response (may include reasoning); None is treated as "".
        legal_actions: List of valid action IDs

    Returns:
        Selected action ID; NOOP (0) if parsing fails and NOOP is legal (or the
        legal list is empty), otherwise the first legal action.
    """
    # The last number is most likely the final answer after any reasoning.
    for num_str in reversed(_ACTION_NUMBER_RE.findall(text or "")):
        action_id = int(num_str)
        if action_id in legal_actions:
            return action_id
    return _default_action(legal_actions)


class GameSession:
    """Holds environment/model state and advances one step per tick."""

    def __init__(self, game: str, model_name: str, prompt_text: str):
        if not API_KEY:
            raise RuntimeError("Missing API_KEY for HF Router")
        self.client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        self.env: Optional[AtariEnv] = None
        self.model_name = model_name
        self.game = game
        self.prompt = (prompt_text or "").strip() or VISION_PROMPT
        self.frame_history_base64: Deque[str] = deque(maxlen=FRAME_HISTORY_LENGTH)
        self.total_reward = 0.0
        self.steps = 0
        self.done = False
        # Header line is always shown; only the most recent step lines are kept.
        self._log_header = f"Game: {self.game} started"
        self._log_lines: Deque[str] = deque(maxlen=LOG_MAX_LINES)
        self.log_message = self._log_header

        # Start environment
        env = AtariEnv(base_url=resolve_env_url(game))
        try:
            result = env.reset()
        except Exception:
            # The session never becomes usable, so nothing else would close this env.
            with suppress(Exception):
                env.close()
            raise
        self.env = env
        self.obs = result.observation

    def _append_log(self, line: str) -> None:
        """Append *line* to the bounded log and refresh ``log_message``."""
        self._log_lines.append(line)
        self.log_message = "\n".join([self._log_header, *self._log_lines])

    def _summary(self, reason: str) -> str:
        """One-line end-of-game summary."""
        return f"{reason}. Total reward: {self.total_reward}, steps: {self.steps}"

    def close(self):
        """Mark the session finished and release the env (idempotent, best-effort)."""
        env, self.env = self.env, None
        self.done = True
        if env is None:
            return
        try:
            env.close()
        except Exception as exc:
            # The session is over either way; don't let a failed close crash a UI tick.
            self._append_log(f"Warning: failed to close environment: {exc}")

    def _choose_action(self, messages: List[dict], legal_actions: List[int]) -> int:
        """Ask the model for an action id; fall back to the default action on API errors."""
        try:
            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=messages,
                temperature=TEMPERATURE,
                max_tokens=MAX_TOKENS,
            )
            response_text = completion.choices[0].message.content or ""
        except Exception as exc:
            # Best-effort: a failed model call should not end the game, but record why.
            self._append_log(f"Model error: {exc}")
            return _default_action(legal_actions)
        return parse_action(response_text, legal_actions)

    def next_frame(self) -> Optional[np.ndarray]:
        """Advance the game one step and return the new (upscaled RGB) frame, or None when finished."""
        # Snapshot env reference to avoid race if another thread closes it mid-tick
        env = self.env
        if self.done or env is None:
            return None
        if self.steps >= MAX_STEPS_PER_GAME:
            self._append_log(self._summary("Step limit reached"))
            self.close()
            return None

        legal_actions = self.obs.legal_actions

        # Prepare images; the history must end with the current frame (see build_messages).
        image_data = screen_to_base64(self.obs.screen, self.obs.screen_shape)
        if FRAME_HISTORY_LENGTH > 0:
            self.frame_history_base64.append(image_data)
        messages = build_messages(self.prompt, self.frame_history_base64, image_data, legal_actions)

        action_id = self._choose_action(messages, legal_actions)

        try:
            result = env.step(AtariAction(action_id=action_id))
        except Exception as exc:
            # Network/server errors, or AttributeError from an env closed concurrently
            # by Stop: either way the session cannot continue.
            self._append_log(f"Environment error: {exc}")
            self.close()
            return None

        self.obs = result.observation
        self.total_reward += result.reward or 0.0
        self.steps += 1

        action_name = ACTIONS_LOOKUP.get(action_id, str(action_id))
        self._append_log(f"Action: {action_name} ({action_id}) Reward: {result.reward}")
        if result.done:
            self._append_log(self._summary("Game over"))
            self.close()
        return screen_to_numpy(self.obs.screen, self.obs.screen_shape)


# Gradio's Image.every drives stepping via GameSession.next_frame


def start_session(
    game: str,
    model_name: str,
    prompt_text: str,
    previous: Optional[GameSession] = None,
) -> Optional[GameSession]:
    """Create a new session, closing *previous* first so its env is not leaked."""
    if isinstance(previous, GameSession):
        previous.close()
    try:
        return GameSession(game=game, model_name=model_name, prompt_text=prompt_text)
    except Exception as e:
        raise gr.Error(str(e)) from e


def stop_session(session: Optional[GameSession]) -> Optional[GameSession]:
    """Close *session* (if any) and clear the stored state."""
    if isinstance(session, GameSession):
        session.close()
    return None


def frame_tick(session: Optional[GameSession]) -> Optional[np.ndarray]:
    """Periodic UI callback: step the session and return the frame to display."""
    if not isinstance(session, GameSession):
        return None
    frame = session.next_frame()
    if frame is None:
        # Auto-stop when done (close() is idempotent)
        session.close()
        return None
    return frame


def log_tick(session: Optional[GameSession]) -> str:
    """Periodic UI callback: return the session's current log text."""
    if not isinstance(session, GameSession):
        return ""
    return session.log_message


def launch_gradio_app():
    """Build and launch the Gradio UI."""
    games = [
        "pong",
        "breakout",
        "pacman",
    ]
    models = [
        "Qwen/Qwen3-VL-8B-Instruct",
        # NOTE(review): confirm this model id is actually served by the HF Router.
        "Qwen/Qwen3-VL-72B-A14B-Instruct",
        "Qwen/Qwen3-VL-235B-A22B-Instruct",
    ]

    with gr.Blocks() as demo:
        gr.Markdown("""
        ### Atari Vision-Language Control
        - Select a game and model, edit the prompt, then click Start.
        - Frames are streamed directly from the environment without modification.

        - There are a limited number of environment spaces via `"https://burtenshaw-{game}.hf.space"`
        - Duplicate the space and change environment variables if you want to use a different game.
        """)

        session_state = gr.State()

        with gr.Row():
            with gr.Column():
                game_dd = gr.Dropdown(choices=games, value="pong", label="Game")
                model_dd = gr.Dropdown(choices=models, value=models[0], label="Model")
                prompt_tb = gr.Textbox(label="Prompt", value=VISION_PROMPT, lines=6)
                with gr.Row():
                    start_btn = gr.Button("Start", variant="primary")
                    stop_btn = gr.Button("Stop")

            with gr.Column():
                out_image = gr.Image(
                    label="Game Stream",
                    type="numpy",
                    value=frame_tick,
                    inputs=[session_state],
                    every=0.1,
                    height=480,
                    width=640,
                )
                out_text = gr.Textbox(
                    label="Game Logs",
                    value=log_tick,
                    inputs=[session_state],
                    lines=10,
                    every=0.5,
                )

        # Controls; the current session is passed to Start so it can be closed first.
        start_btn.click(
            start_session,
            inputs=[game_dd, model_dd, prompt_tb, session_state],
            outputs=[session_state],
        )
        stop_btn.click(stop_session, inputs=[session_state], outputs=[session_state])

    demo.queue()
    demo.launch()


if __name__ == "__main__":
    launch_gradio_app()