Kaushik Rajan commited on
Commit
5e02c64
·
1 Parent(s): 06c8d18

Phase 2.2: Implement game environments - TicTacToe and Kuhn Poker with Gymnasium interface, utilities, and tests

Browse files
src/games/__init__.py CHANGED
@@ -5,8 +5,12 @@ This module contains implementations of zero-sum games used for
5
  self-play training, including Kuhn Poker and TicTacToe.
6
  """
7
 
8
- from .kuhn_poker import KuhnPokerEnv
9
- from .tictactoe import TicTacToeEnv
10
- from .base_game import BaseGameEnv
11
 
12
- __all__ = ["KuhnPokerEnv", "TicTacToeEnv", "BaseGameEnv"]
 
 
 
 
 
 
5
  self-play training, including Kuhn Poker and TicTacToe.
6
  """
7
 
8
+ from .tictactoe import TicTacToeEnv, create_tictactoe_env
9
+ from .kuhn_poker import KuhnPokerEnv, create_kuhn_poker_env
 
10
 
11
+ __all__ = [
12
+ "TicTacToeEnv",
13
+ "KuhnPokerEnv",
14
+ "create_tictactoe_env",
15
+ "create_kuhn_poker_env"
16
+ ]
src/games/game_utils.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Game utility functions for SPIRAL training.
3
+
4
+ This module contains helper functions for game environments,
5
+ including multi-turn logic and game state management.
6
+ """
7
+
8
+ import gymnasium as gym
9
+ from typing import Dict, Any, Type, Union
10
+ import numpy as np
11
+
12
+ from .tictactoe import TicTacToeEnv
13
+ from .kuhn_poker import KuhnPokerEnv
14
+
15
+
16
+ # Game registry
17
+ GAMES_REGISTRY: Dict[str, Type[gym.Env]] = {
18
+ "tictactoe": TicTacToeEnv,
19
+ "kuhn_poker": KuhnPokerEnv,
20
+ }
21
+
22
+
23
+ def create_game_env(game_name: str, **kwargs) -> gym.Env:
24
+ """
25
+ Create a game environment by name.
26
+
27
+ Args:
28
+ game_name: Name of the game ("tictactoe", "kuhn_poker")
29
+ **kwargs: Additional arguments for the environment
30
+
31
+ Returns:
32
+ Game environment instance
33
+
34
+ Raises:
35
+ ValueError: If game_name is not recognized
36
+ """
37
+ if game_name not in GAMES_REGISTRY:
38
+ available_games = list(GAMES_REGISTRY.keys())
39
+ raise ValueError(f"Unknown game: {game_name}. Available games: {available_games}")
40
+
41
+ game_class = GAMES_REGISTRY[game_name]
42
+ return game_class(**kwargs)
43
+
44
+
45
+ def get_game_info(game_name: str) -> Dict[str, Any]:
46
+ """
47
+ Get information about a game environment.
48
+
49
+ Args:
50
+ game_name: Name of the game
51
+
52
+ Returns:
53
+ Dictionary with game information
54
+ """
55
+ env = create_game_env(game_name)
56
+
57
+ info = {
58
+ "name": game_name,
59
+ "action_space": env.action_space,
60
+ "observation_space": env.observation_space,
61
+ "max_episode_steps": getattr(env, "_max_episode_steps", None),
62
+ "render_modes": env.metadata.get("render_modes", []),
63
+ }
64
+
65
+ # Add game-specific information
66
+ if game_name == "tictactoe":
67
+ info.update({
68
+ "description": "3x3 TicTacToe game with alternating turns",
69
+ "players": 2,
70
+ "zero_sum": True,
71
+ "perfect_information": True,
72
+ })
73
+ elif game_name == "kuhn_poker":
74
+ info.update({
75
+ "description": "Simplified poker with 3 cards (J, Q, K)",
76
+ "players": 2,
77
+ "zero_sum": True,
78
+ "perfect_information": False,
79
+ })
80
+
81
+ env.close()
82
+ return info
83
+
84
+
85
+ def get_available_games() -> list:
86
+ """Get list of available game names."""
87
+ return list(GAMES_REGISTRY.keys())
88
+
89
+
90
+ def is_game_over(env: gym.Env) -> bool:
91
+ """
92
+ Check if the game is over.
93
+
94
+ Args:
95
+ env: Game environment
96
+
97
+ Returns:
98
+ True if game is over, False otherwise
99
+ """
100
+ if hasattr(env, 'game_over'):
101
+ return env.game_over
102
+ return False
103
+
104
+
105
+ def get_valid_actions(env: gym.Env) -> list:
106
+ """
107
+ Get valid actions for the current state.
108
+
109
+ Args:
110
+ env: Game environment
111
+
112
+ Returns:
113
+ List of valid actions
114
+ """
115
+ if hasattr(env, '_get_valid_actions'):
116
+ return env._get_valid_actions()
117
+ elif hasattr(env, 'get_valid_actions'):
118
+ return env.get_valid_actions()
119
+ else:
120
+ # Fallback: assume all actions are valid
121
+ return list(range(env.action_space.n))
122
+
123
+
124
+ def get_action_mask(env: gym.Env) -> np.ndarray:
125
+ """
126
+ Get action mask for the current state.
127
+
128
+ Args:
129
+ env: Game environment
130
+
131
+ Returns:
132
+ Boolean mask where True indicates valid actions
133
+ """
134
+ if hasattr(env, 'get_action_mask'):
135
+ return env.get_action_mask()
136
+ else:
137
+ # Fallback: create mask from valid actions
138
+ valid_actions = get_valid_actions(env)
139
+ mask = np.zeros(env.action_space.n, dtype=bool)
140
+ for action in valid_actions:
141
+ mask[action] = True
142
+ return mask
143
+
144
+
145
+ def play_random_game(game_name: str, render: bool = False, seed: int = None) -> Dict[str, Any]:
146
+ """
147
+ Play a random game to completion.
148
+
149
+ Args:
150
+ game_name: Name of the game to play
151
+ render: Whether to render the game
152
+ seed: Random seed for reproducibility
153
+
154
+ Returns:
155
+ Dictionary with game results
156
+ """
157
+ env = create_game_env(game_name, render_mode="human" if render else None)
158
+
159
+ if seed is not None:
160
+ env.reset(seed=seed)
161
+ else:
162
+ env.reset()
163
+
164
+ if render:
165
+ env.render()
166
+
167
+ total_reward = 0
168
+ step_count = 0
169
+ actions_taken = []
170
+
171
+ while not is_game_over(env):
172
+ valid_actions = get_valid_actions(env)
173
+ action = np.random.choice(valid_actions)
174
+
175
+ obs, reward, terminated, truncated, info = env.step(action)
176
+ actions_taken.append(action)
177
+ total_reward += reward
178
+ step_count += 1
179
+
180
+ if render:
181
+ print(f"Step {step_count}: Action {action}, Reward: {reward}")
182
+ env.render()
183
+
184
+ if terminated or truncated:
185
+ break
186
+
187
+ results = {
188
+ "game_name": game_name,
189
+ "total_reward": total_reward,
190
+ "step_count": step_count,
191
+ "actions_taken": actions_taken,
192
+ "winner": getattr(env, 'winner', None),
193
+ "final_info": info
194
+ }
195
+
196
+ env.close()
197
+ return results
198
+
199
+
200
+ if __name__ == "__main__":
201
+ # Test the utilities
202
+ print("Available games:", get_available_games())
203
+
204
+ for game_name in get_available_games():
205
+ print(f"\n{game_name.upper()} Info:")
206
+ info = get_game_info(game_name)
207
+ for key, value in info.items():
208
+ print(f" {key}: {value}")
209
+
210
+ # Play a random game
211
+ print("\nPlaying random TicTacToe game:")
212
+ result = play_random_game("tictactoe", render=True, seed=42)
src/games/kuhn_poker.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Kuhn Poker Game Environment
3
+
4
+ A simple Kuhn Poker implementation using Gymnasium for SPIRAL training.
5
+ Kuhn Poker is a simplified poker variant with 3 cards (J, Q, K).
6
+ """
7
+
8
+ import gymnasium as gym
9
+ import numpy as np
10
+ from gymnasium import spaces
11
+ from typing import Tuple, Dict, Any, Optional, List
12
+ import random
13
+
14
+
15
+ class KuhnPokerEnv(gym.Env):
16
+ """
17
+ Kuhn Poker environment for SPIRAL training.
18
+
19
+ Rules:
20
+ - 3 cards: Jack (0), Queen (1), King (2)
21
+ - Each player gets 1 card
22
+ - Each player antes 1 chip
23
+ - Player 1 acts first: Check or Bet
24
+ - Player 2 then acts: Check, Call, or Fold
25
+ - If both check, high card wins
26
+ - If one bets and other calls, high card wins
27
+ - If one bets and other folds, bettor wins
28
+
29
+ Action space: [Check/Call=0, Bet=1, Fold=2]
30
+ Observation space: [player_card, opponent_action, betting_round]
31
+ """
32
+
33
+ metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
34
+
35
+ # Card values: Jack=0, Queen=1, King=2
36
+ JACK, QUEEN, KING = 0, 1, 2
37
+ CARDS = [JACK, QUEEN, KING]
38
+ CARD_NAMES = ["J", "Q", "K"]
39
+
40
+ # Actions
41
+ CHECK_CALL, BET, FOLD = 0, 1, 2
42
+ ACTION_NAMES = ["Check/Call", "Bet", "Fold"]
43
+
44
+ def __init__(self, render_mode: Optional[str] = None):
45
+ super().__init__()
46
+
47
+ # Observation: [player_card, opponent_last_action, betting_round, pot_size]
48
+ self.observation_space = spaces.Box(
49
+ low=0, high=10, shape=(4,), dtype=np.int8
50
+ )
51
+
52
+ # Actions: Check/Call, Bet, Fold
53
+ self.action_space = spaces.Discrete(3)
54
+
55
+ self.render_mode = render_mode
56
+ self.reset()
57
+
58
+ def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
59
+ """Reset the game to initial state."""
60
+ super().reset(seed=seed)
61
+
62
+ # Deal cards
63
+ cards = self.CARDS.copy()
64
+ random.shuffle(cards)
65
+ self.player1_card = cards[0]
66
+ self.player2_card = cards[1]
67
+
68
+ # Game state
69
+ self.current_player = 1 # Player 1 starts
70
+ self.pot = 2 # Each player antes 1
71
+ self.player1_bet = 1 # Ante
72
+ self.player2_bet = 1 # Ante
73
+ self.game_over = False
74
+ self.winner = None
75
+ self.betting_round = 0
76
+ self.actions_history = []
77
+
78
+ observation = self._get_observation()
79
+ info = self._get_info()
80
+
81
+ return observation, info
82
+
83
+ def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
84
+ """
85
+ Execute one step in the environment.
86
+
87
+ Args:
88
+ action: 0=Check/Call, 1=Bet, 2=Fold
89
+
90
+ Returns:
91
+ observation, reward, terminated, truncated, info
92
+ """
93
+ if self.game_over:
94
+ raise ValueError("Game is already over. Call reset() to start new game.")
95
+
96
+ # Record action
97
+ self.actions_history.append((self.current_player, action))
98
+
99
+ # Process action
100
+ if action == self.FOLD:
101
+ # Current player folds, opponent wins
102
+ self.game_over = True
103
+ self.winner = 2 if self.current_player == 1 else 1
104
+ reward = self._calculate_reward()
105
+
106
+ elif action == self.BET:
107
+ # Current player bets
108
+ if self.current_player == 1:
109
+ self.player1_bet += 1
110
+ self.pot += 1
111
+ else:
112
+ self.player2_bet += 1
113
+ self.pot += 1
114
+
115
+ # Check if this ends the betting round
116
+ if self.betting_round == 0:
117
+ # First bet, opponent gets to act
118
+ self.current_player = 2
119
+ self.betting_round = 1
120
+ reward = 0.0
121
+ else:
122
+ # Second bet (raise), go to showdown
123
+ self.game_over = True
124
+ self.winner = self._determine_winner_by_cards()
125
+ reward = self._calculate_reward()
126
+
127
+ else: # CHECK_CALL
128
+ if self.betting_round == 0:
129
+ # First action is check
130
+ if self.current_player == 1:
131
+ # Player 1 checks, player 2 acts
132
+ self.current_player = 2
133
+ self.betting_round = 1
134
+ reward = 0.0
135
+ else:
136
+ # Player 2 checks after player 1 checked, showdown
137
+ self.game_over = True
138
+ self.winner = self._determine_winner_by_cards()
139
+ reward = self._calculate_reward()
140
+ else:
141
+ # This is a call
142
+ if self.current_player == 2:
143
+ # Player 2 calls player 1's bet
144
+ self.player2_bet = self.player1_bet
145
+ self.pot = self.player1_bet + self.player2_bet
146
+ self.game_over = True
147
+ self.winner = self._determine_winner_by_cards()
148
+ reward = self._calculate_reward()
149
+ else:
150
+ # Player 1 calls player 2's bet
151
+ self.player1_bet = self.player2_bet
152
+ self.pot = self.player1_bet + self.player2_bet
153
+ self.game_over = True
154
+ self.winner = self._determine_winner_by_cards()
155
+ reward = self._calculate_reward()
156
+
157
+ observation = self._get_observation()
158
+ info = self._get_info()
159
+
160
+ return observation, reward, self.game_over, False, info
161
+
162
+ def _get_observation(self) -> np.ndarray:
163
+ """Get current observation for the current player."""
164
+ # Get current player's card
165
+ player_card = self.player1_card if self.current_player == 1 else self.player2_card
166
+
167
+ # Get opponent's last action (if any)
168
+ opponent_last_action = -1
169
+ if self.actions_history:
170
+ for player, action in reversed(self.actions_history):
171
+ if player != self.current_player:
172
+ opponent_last_action = action
173
+ break
174
+
175
+ # Observation: [player_card, opponent_last_action, betting_round, pot_size]
176
+ observation = np.array([
177
+ player_card,
178
+ opponent_last_action + 1, # -1 becomes 0, 0 becomes 1, etc.
179
+ self.betting_round,
180
+ self.pot
181
+ ], dtype=np.int8)
182
+
183
+ return observation
184
+
185
+ def _get_info(self) -> Dict[str, Any]:
186
+ """Get additional info about the game state."""
187
+ return {
188
+ "current_player": self.current_player,
189
+ "game_over": self.game_over,
190
+ "winner": self.winner,
191
+ "player1_card": self.player1_card,
192
+ "player2_card": self.player2_card,
193
+ "pot": self.pot,
194
+ "betting_round": self.betting_round,
195
+ "actions_history": self.actions_history.copy(),
196
+ "valid_actions": self._get_valid_actions()
197
+ }
198
+
199
+ def _get_valid_actions(self) -> List[int]:
200
+ """Get list of valid actions."""
201
+ if self.game_over:
202
+ return []
203
+
204
+ # All actions are always valid in Kuhn Poker
205
+ return [self.CHECK_CALL, self.BET, self.FOLD]
206
+
207
+ def _determine_winner_by_cards(self) -> int:
208
+ """Determine winner by comparing cards."""
209
+ if self.player1_card > self.player2_card:
210
+ return 1
211
+ else:
212
+ return 2
213
+
214
+ def _calculate_reward(self) -> float:
215
+ """Calculate reward for the current player."""
216
+ if not self.game_over:
217
+ return 0.0
218
+
219
+ if self.winner == self.current_player:
220
+ # Won - get the pot minus what you put in
221
+ if self.current_player == 1:
222
+ return float(self.pot - self.player1_bet)
223
+ else:
224
+ return float(self.pot - self.player2_bet)
225
+ else:
226
+ # Lost - lose what you put in
227
+ if self.current_player == 1:
228
+ return float(-self.player1_bet)
229
+ else:
230
+ return float(-self.player2_bet)
231
+
232
+ def render(self) -> Optional[np.ndarray]:
233
+ """Render the game state."""
234
+ if self.render_mode == "human":
235
+ self._render_human()
236
+ elif self.render_mode == "rgb_array":
237
+ return self._render_rgb_array()
238
+
239
+ def _render_human(self):
240
+ """Print the game state to console."""
241
+ print("\n" + "="*40)
242
+ print("KUHN POKER")
243
+ print("="*40)
244
+ print(f"Player 1 Card: {self.CARD_NAMES[self.player1_card]}")
245
+ print(f"Player 2 Card: {self.CARD_NAMES[self.player2_card]}")
246
+ print(f"Pot: {self.pot}")
247
+ print(f"Current Player: {self.current_player}")
248
+ print(f"Betting Round: {self.betting_round}")
249
+
250
+ if self.actions_history:
251
+ print("Actions:")
252
+ for player, action in self.actions_history:
253
+ print(f" Player {player}: {self.ACTION_NAMES[action]}")
254
+
255
+ if self.game_over:
256
+ print(f"Game Over! Winner: Player {self.winner}")
257
+ print("="*40)
258
+
259
+ def _render_rgb_array(self) -> np.ndarray:
260
+ """Render as RGB array for visualization."""
261
+ # Simple RGB representation (placeholder)
262
+ rgb = np.zeros((100, 100, 3), dtype=np.uint8)
263
+
264
+ # Color based on current player's card
265
+ if self.current_player == 1:
266
+ card_value = self.player1_card
267
+ else:
268
+ card_value = self.player2_card
269
+
270
+ # Different colors for different cards
271
+ if card_value == self.JACK:
272
+ rgb[:, :] = [255, 0, 0] # Red for Jack
273
+ elif card_value == self.QUEEN:
274
+ rgb[:, :] = [0, 255, 0] # Green for Queen
275
+ else: # King
276
+ rgb[:, :] = [0, 0, 255] # Blue for King
277
+
278
+ return rgb
279
+
280
+ def get_action_mask(self) -> np.ndarray:
281
+ """Get mask of valid actions (1 for valid, 0 for invalid)."""
282
+ mask = np.zeros(3, dtype=np.int8)
283
+ for action in self._get_valid_actions():
284
+ mask[action] = 1
285
+ return mask
286
+
287
+
288
+ def create_kuhn_poker_env() -> KuhnPokerEnv:
289
+ """Factory function to create a Kuhn Poker environment."""
290
+ return KuhnPokerEnv()
291
+
292
+
293
+ if __name__ == "__main__":
294
+ # Test the environment
295
+ env = KuhnPokerEnv(render_mode="human")
296
+
297
+ # Play a simple game
298
+ obs, info = env.reset()
299
+ print("Initial state:")
300
+ env.render()
301
+
302
+ # Simulate some moves
303
+ while not env.game_over:
304
+ valid_actions = env._get_valid_actions()
305
+ action = random.choice(valid_actions)
306
+
307
+ obs, reward, terminated, truncated, info = env.step(action)
308
+ print(f"\nPlayer {env.current_player if not env.game_over else 'Previous'} action: {env.ACTION_NAMES[action]}")
309
+ print(f"Reward: {reward}")
310
+ env.render()
311
+
312
+ if terminated:
313
+ print(f"Game terminated! Final reward: {reward}")
314
+ break
src/games/tictactoe.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TicTacToe Game Environment
3
+
4
+ A simple TicTacToe implementation using Gymnasium for SPIRAL training.
5
+ """
6
+
7
+ import gymnasium as gym
8
+ import numpy as np
9
+ from gymnasium import spaces
10
+ from typing import Tuple, Dict, Any, Optional
11
+
12
+
13
+ class TicTacToeEnv(gym.Env):
14
+ """
15
+ TicTacToe environment for SPIRAL training.
16
+
17
+ - 3x3 grid
18
+ - Players alternate turns (1 and -1)
19
+ - Action space: 9 positions (0-8)
20
+ - Observation space: 3x3 grid with values {-1, 0, 1}
21
+ - Reward: +1 for win, -1 for loss, 0 for draw/ongoing
22
+ """
23
+
24
+ metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 1}
25
+
26
+ def __init__(self, render_mode: Optional[str] = None):
27
+ super().__init__()
28
+
29
+ # 3x3 grid, each cell can be -1 (player 2), 0 (empty), or 1 (player 1)
30
+ self.observation_space = spaces.Box(
31
+ low=-1, high=1, shape=(3, 3), dtype=np.int8
32
+ )
33
+
34
+ # 9 possible actions (positions 0-8)
35
+ self.action_space = spaces.Discrete(9)
36
+
37
+ self.render_mode = render_mode
38
+ self.reset()
39
+
40
+ def reset(self, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]:
41
+ """Reset the game to initial state."""
42
+ super().reset(seed=seed)
43
+
44
+ # Initialize empty board
45
+ self.board = np.zeros((3, 3), dtype=np.int8)
46
+ self.current_player = 1 # Player 1 starts
47
+ self.game_over = False
48
+ self.winner = None
49
+ self.move_count = 0
50
+
51
+ observation = self._get_observation()
52
+ info = self._get_info()
53
+
54
+ return observation, info
55
+
56
+ def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
57
+ """
58
+ Execute one step in the environment.
59
+
60
+ Args:
61
+ action: Position to place mark (0-8)
62
+
63
+ Returns:
64
+ observation, reward, terminated, truncated, info
65
+ """
66
+ if self.game_over:
67
+ raise ValueError("Game is already over. Call reset() to start new game.")
68
+
69
+ # Convert action to row, col
70
+ row, col = divmod(action, 3)
71
+
72
+ # Check if move is valid
73
+ if self.board[row, col] != 0:
74
+ # Invalid move - penalize and end game
75
+ reward = -1.0
76
+ terminated = True
77
+ self.game_over = True
78
+ info = self._get_info()
79
+ info["invalid_move"] = True
80
+ return self._get_observation(), reward, terminated, False, info
81
+
82
+ # Make the move
83
+ self.board[row, col] = self.current_player
84
+ self.move_count += 1
85
+
86
+ # Check for win
87
+ winner = self._check_winner()
88
+ if winner is not None:
89
+ self.game_over = True
90
+ self.winner = winner
91
+ reward = 1.0 if winner == self.current_player else -1.0
92
+ terminated = True
93
+ elif self.move_count >= 9:
94
+ # Draw
95
+ self.game_over = True
96
+ reward = 0.0
97
+ terminated = True
98
+ else:
99
+ # Game continues
100
+ reward = 0.0
101
+ terminated = False
102
+ self.current_player *= -1 # Switch player
103
+
104
+ observation = self._get_observation()
105
+ info = self._get_info()
106
+
107
+ return observation, reward, terminated, False, info
108
+
109
+ def _get_observation(self) -> np.ndarray:
110
+ """Get current board state."""
111
+ return self.board.copy()
112
+
113
+ def _get_info(self) -> Dict[str, Any]:
114
+ """Get additional info about the game state."""
115
+ return {
116
+ "current_player": self.current_player,
117
+ "game_over": self.game_over,
118
+ "winner": self.winner,
119
+ "move_count": self.move_count,
120
+ "valid_actions": self._get_valid_actions()
121
+ }
122
+
123
+ def _get_valid_actions(self) -> list:
124
+ """Get list of valid actions (empty positions)."""
125
+ valid_actions = []
126
+ for i in range(9):
127
+ row, col = divmod(i, 3)
128
+ if self.board[row, col] == 0:
129
+ valid_actions.append(i)
130
+ return valid_actions
131
+
132
+ def _check_winner(self) -> Optional[int]:
133
+ """
134
+ Check if there's a winner.
135
+
136
+ Returns:
137
+ 1 if player 1 wins, -1 if player 2 wins, None if no winner
138
+ """
139
+ # Check rows
140
+ for row in range(3):
141
+ if abs(self.board[row, :].sum()) == 3:
142
+ return self.board[row, 0]
143
+
144
+ # Check columns
145
+ for col in range(3):
146
+ if abs(self.board[:, col].sum()) == 3:
147
+ return self.board[0, col]
148
+
149
+ # Check diagonals
150
+ if abs(self.board.diagonal().sum()) == 3:
151
+ return self.board[0, 0]
152
+
153
+ if abs(np.fliplr(self.board).diagonal().sum()) == 3:
154
+ return self.board[0, 2]
155
+
156
+ return None
157
+
158
+ def render(self) -> Optional[np.ndarray]:
159
+ """Render the game state."""
160
+ if self.render_mode == "human":
161
+ self._render_human()
162
+ elif self.render_mode == "rgb_array":
163
+ return self._render_rgb_array()
164
+
165
+ def _render_human(self):
166
+ """Print the board to console."""
167
+ print("\n" + "="*13)
168
+ for row in range(3):
169
+ print("|", end="")
170
+ for col in range(3):
171
+ cell = self.board[row, col]
172
+ if cell == 1:
173
+ print(" X ", end="|")
174
+ elif cell == -1:
175
+ print(" O ", end="|")
176
+ else:
177
+ print(f" {row*3 + col} ", end="|")
178
+ print()
179
+ print("="*13)
180
+
181
+ if self.game_over:
182
+ if self.winner is not None:
183
+ winner_symbol = "X" if self.winner == 1 else "O"
184
+ print(f"Game Over! Winner: {winner_symbol}")
185
+ else:
186
+ print("Game Over! It's a draw!")
187
+
188
+ def _render_rgb_array(self) -> np.ndarray:
189
+ """Render as RGB array for visualization."""
190
+ # Simple RGB representation
191
+ rgb = np.zeros((3, 3, 3), dtype=np.uint8)
192
+
193
+ # Player 1 (X) = Red, Player 2 (O) = Blue, Empty = White
194
+ for row in range(3):
195
+ for col in range(3):
196
+ if self.board[row, col] == 1:
197
+ rgb[row, col] = [255, 0, 0] # Red
198
+ elif self.board[row, col] == -1:
199
+ rgb[row, col] = [0, 0, 255] # Blue
200
+ else:
201
+ rgb[row, col] = [255, 255, 255] # White
202
+
203
+ return rgb
204
+
205
+ def get_action_mask(self) -> np.ndarray:
206
+ """Get mask of valid actions (1 for valid, 0 for invalid)."""
207
+ mask = np.zeros(9, dtype=np.int8)
208
+ for action in self._get_valid_actions():
209
+ mask[action] = 1
210
+ return mask
211
+
212
+
213
+ def create_tictactoe_env() -> TicTacToeEnv:
214
+ """Factory function to create a TicTacToe environment."""
215
+ return TicTacToeEnv()
216
+
217
+
218
+ if __name__ == "__main__":
219
+ # Test the environment
220
+ env = TicTacToeEnv(render_mode="human")
221
+
222
+ # Play a simple game
223
+ obs, info = env.reset()
224
+ print("Initial state:")
225
+ env.render()
226
+
227
+ # Make some moves
228
+ moves = [0, 4, 1, 3, 2] # X wins
229
+ for move in moves:
230
+ if not env.game_over:
231
+ obs, reward, terminated, truncated, info = env.step(move)
232
+ print(f"\nMove: {move}, Reward: {reward}")
233
+ env.render()
234
+
235
+ if terminated:
236
+ print(f"Game terminated! Final reward: {reward}")
237
+ break
test_games.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Test script for game environments.
4
+ """
5
+
6
+ import sys
7
+ import os
8
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
9
+
10
+ from games import TicTacToeEnv, KuhnPokerEnv, create_tictactoe_env, create_kuhn_poker_env
11
+ from games.game_utils import get_available_games, get_game_info, play_random_game
12
+
13
+ def test_tictactoe():
14
+ """Test TicTacToe environment."""
15
+ print("Testing TicTacToe...")
16
+ env = create_tictactoe_env()
17
+ obs, info = env.reset()
18
+ print(f"Initial observation shape: {obs.shape}")
19
+ print(f"Action space: {env.action_space}")
20
+ print(f"Observation space: {env.observation_space}")
21
+
22
+ # Test a few moves
23
+ action = 0
24
+ obs, reward, terminated, truncated, info = env.step(action)
25
+ print(f"After move {action}: reward={reward}, terminated={terminated}")
26
+
27
+ env.close()
28
+ print("TicTacToe test passed!\n")
29
+
30
+
31
+ def test_kuhn_poker():
32
+ """Test Kuhn Poker environment."""
33
+ print("Testing Kuhn Poker...")
34
+ env = create_kuhn_poker_env()
35
+ obs, info = env.reset()
36
+ print(f"Initial observation: {obs}")
37
+ print(f"Action space: {env.action_space}")
38
+ print(f"Observation space: {env.observation_space}")
39
+
40
+ # Test a move
41
+ action = 0 # Check/Call
42
+ obs, reward, terminated, truncated, info = env.step(action)
43
+ print(f"After action {action}: reward={reward}, terminated={terminated}")
44
+
45
+ env.close()
46
+ print("Kuhn Poker test passed!\n")
47
+
48
+
49
+ def test_game_utils():
50
+ """Test game utility functions."""
51
+ print("Testing game utilities...")
52
+
53
+ # Test available games
54
+ games = get_available_games()
55
+ print(f"Available games: {games}")
56
+
57
+ # Test game info
58
+ for game_name in games:
59
+ info = get_game_info(game_name)
60
+ print(f"{game_name} info: {info['description']}")
61
+
62
+ print("Game utilities test passed!\n")
63
+
64
+
65
+ def main():
66
+ """Run all tests."""
67
+ print("Running game environment tests...\n")
68
+
69
+ try:
70
+ test_tictactoe()
71
+ test_kuhn_poker()
72
+ test_game_utils()
73
+ print("All tests passed! ✅")
74
+ except Exception as e:
75
+ print(f"Test failed: {e}")
76
+ return 1
77
+
78
+ return 0
79
+
80
+
81
+ if __name__ == "__main__":
82
+ sys.exit(main())