Kaushik Rajan
Phase 2.2: Implement game environments - TicTacToe and Kuhn Poker with Gymnasium interface, utilities, and tests
5e02c64
raw
history blame
7.56 kB
"""
TicTacToe Game Environment
A simple TicTacToe implementation using Gymnasium for SPIRAL training.
"""
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from typing import Tuple, Dict, Any, Optional
class TicTacToeEnv(gym.Env):
    """
    Two-player TicTacToe environment for SPIRAL training.

    - 3x3 grid
    - Players alternate turns (1 and -1); player 1 moves first
    - Action space: 9 positions (0-8), numbered row-major (action = row * 3 + col)
    - Observation space: 3x3 grid with values {-1, 0, 1}
    - Reward: +1 for win, -1 for loss or invalid move, 0 for draw/ongoing

    The reward returned by ``step`` is always from the point of view of the
    player who just moved. An invalid move (occupied cell or a position
    outside 0-8) ends the game immediately with a reward of -1.
    """

    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}

    def __init__(self, render_mode: Optional[str] = None):
        """
        Build the spaces and put the game in its initial state.

        Args:
            render_mode: "human" to print the board, "rgb_array" to get an
                image array from ``render``, or None to disable rendering.
        """
        super().__init__()
        # 3x3 grid, each cell can be -1 (player 2), 0 (empty), or 1 (player 1)
        self.observation_space = spaces.Box(
            low=-1, high=1, shape=(3, 3), dtype=np.int8
        )
        # 9 possible actions (positions 0-8)
        self.action_space = spaces.Discrete(9)
        self.render_mode = render_mode
        # Reset here so that board/current_player/etc. exist right after
        # construction, even before the caller's first reset().
        self.reset()

    def reset(
        self, seed: Optional[int] = None, options: Optional[Dict] = None
    ) -> Tuple[np.ndarray, Dict]:
        """
        Reset the game to its initial state.

        Args:
            seed: Forwarded to ``gym.Env.reset``.
            options: Accepted for interface compatibility; not used.

        Returns:
            observation, info
        """
        super().reset(seed=seed)
        # Initialize empty board
        self.board = np.zeros((3, 3), dtype=np.int8)
        self.current_player = 1  # Player 1 starts
        self.game_over = False
        self.winner = None
        self.move_count = 0
        return self._get_observation(), self._get_info()

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """
        Execute one step in the environment.

        Args:
            action: Position to place the current player's mark (0-8).

        Returns:
            observation, reward, terminated, truncated, info. ``truncated``
            is always False. On an invalid move ``info["invalid_move"]`` is
            set to True, the reward is -1.0 and the game ends.

        Raises:
            ValueError: If the game is already over.
        """
        if self.game_over:
            raise ValueError("Game is already over. Call reset() to start new game.")

        position = self._normalize_action(action)
        # An out-of-range position must be rejected *before* indexing the
        # board: negative indices would silently wrap around to another cell
        # and indices >= 9 would raise IndexError.
        if position is None or self.board[divmod(position, 3)] != 0:
            # Invalid move - penalize and end game
            self.game_over = True
            info = self._get_info()
            info["invalid_move"] = True
            return self._get_observation(), -1.0, True, False, info

        # Make the move
        row, col = divmod(position, 3)
        self.board[row, col] = self.current_player
        self.move_count += 1

        winner = self._check_winner()
        if winner is not None:
            self.game_over = True
            self.winner = winner
            reward = 1.0 if winner == self.current_player else -1.0
            terminated = True
        elif self.move_count >= 9:
            # Board is full with no winner: draw
            self.game_over = True
            reward = 0.0
            terminated = True
        else:
            # Game continues; only now does the turn pass to the other player,
            # so on a terminal step current_player is still the last mover.
            reward = 0.0
            terminated = False
            self.current_player *= -1

        return self._get_observation(), reward, terminated, False, self._get_info()

    def _normalize_action(self, action: Any) -> Optional[int]:
        """
        Convert ``action`` to a plain int board position.

        Accepts Python ints, NumPy integer scalars and single-element integer
        arrays. Returns None for anything else or for positions outside the
        action space, so the caller can treat it as an invalid move.
        """
        candidate = np.asarray(action)
        if candidate.size != 1 or not np.issubdtype(candidate.dtype, np.integer):
            return None
        position = int(candidate.item())
        if 0 <= position < int(self.action_space.n):
            return position
        return None

    def _get_observation(self) -> np.ndarray:
        """Return a copy of the board so callers cannot mutate game state."""
        return self.board.copy()

    def _get_info(self) -> Dict[str, Any]:
        """Return auxiliary information about the current game state."""
        return {
            "current_player": self.current_player,
            "game_over": self.game_over,
            "winner": self.winner,
            "move_count": self.move_count,
            "valid_actions": self._get_valid_actions(),
        }

    def _get_valid_actions(self) -> list:
        """Return the positions (0-8) of all empty cells, in ascending order."""
        # The board is row-major, so the flat index of a cell is its action id.
        return np.flatnonzero(self.board == 0).tolist()

    def _check_winner(self) -> Optional[int]:
        """
        Check if there's a winner.

        Returns:
            1 if player 1 wins, -1 if player 2 wins, None if no winner
        """
        lines = (
            *self.board,  # rows
            *self.board.T,  # columns
            self.board.diagonal(),
            np.fliplr(self.board).diagonal(),
        )
        for line in lines:
            # Cells are in {-1, 0, 1}, so a line sums to +/-3 only when all
            # three cells belong to the same player.
            total = int(line.sum())
            if abs(total) == 3:
                return 1 if total > 0 else -1
        return None

    def render(self) -> Optional[np.ndarray]:
        """
        Render the game state according to ``render_mode``.

        Returns:
            The RGB array in "rgb_array" mode, otherwise None ("human" mode
            prints to the console).
        """
        if self.render_mode == "human":
            self._render_human()
            return None
        if self.render_mode == "rgb_array":
            return self._render_rgb_array()
        return None

    def _render_human(self) -> None:
        """Print the board to the console; empty cells show their action id."""
        symbols = {1: "X", -1: "O"}
        print("\n" + "=" * 13)
        for row in range(3):
            cells = (
                symbols.get(int(self.board[row, col]), str(row * 3 + col))
                for col in range(3)
            )
            print("|" + "".join(f" {cell} |" for cell in cells))
        print("=" * 13)
        if self.game_over:
            if self.winner is not None:
                winner_symbol = "X" if self.winner == 1 else "O"
                print(f"Game Over! Winner: {winner_symbol}")
            else:
                print("Game Over! It's a draw!")

    def _render_rgb_array(self) -> np.ndarray:
        """
        Render the board as a (3, 3, 3) uint8 RGB array, one pixel per cell.

        Player 1 (X) = Red, Player 2 (O) = Blue, Empty = White.
        """
        rgb = np.full((3, 3, 3), 255, dtype=np.uint8)  # start all white
        rgb[self.board == 1] = (255, 0, 0)  # Red
        rgb[self.board == -1] = (0, 0, 255)  # Blue
        return rgb

    def get_action_mask(self) -> np.ndarray:
        """Return a length-9 int8 mask (1 for valid action, 0 for invalid)."""
        return (self.board.ravel() == 0).astype(np.int8)
def create_tictactoe_env(render_mode: Optional[str] = None) -> TicTacToeEnv:
    """
    Factory function to create a TicTacToe environment.

    Args:
        render_mode: Passed straight to ``TicTacToeEnv``; defaults to None
            so existing zero-argument callers get the same environment.

    Returns:
        A new ``TicTacToeEnv`` instance.
    """
    return TicTacToeEnv(render_mode=render_mode)
if __name__ == "__main__":
    # Manual smoke test: play a short scripted game and print every position.
    env = TicTacToeEnv(render_mode="human")
    env.reset()
    print("Initial state:")
    env.render()

    # X takes the top row (0, 1, 2) while O answers with 4 and 3, so X wins
    # on the fifth move.
    for position in (0, 4, 1, 3, 2):
        if env.game_over:
            continue
        _, reward, terminated, _, _ = env.step(position)
        print(f"\nMove: {position}, Reward: {reward}")
        env.render()
        if terminated:
            print(f"Game terminated! Final reward: {reward}")
            break