"""
Overview:

Adapt the connect4 environment in PettingZoo (https://github.com/Farama-Foundation/PettingZoo) to the BaseEnv interface.

Connect Four is a two-player, turn-based game in which players must connect four of their tokens vertically, horizontally, or diagonally.

The players drop their respective tokens into a column of a standing grid, where each token falls until it reaches the bottom of the column or lands on an existing token.

Players cannot place a token in a full column, and the game ends either when a player has made a sequence of 4 tokens or when all 7 columns have been filled.

Mode:

- ``self_play_mode``: In ``self_play_mode``, two players take turns to play. This mode is used for data generation in AlphaZero.

- ``play_with_bot_mode``: In this mode, the environment has a built-in bot that takes the role of player 2, so the agent plays against the bot.

Bot:

- MCTSBot: A bot that selects actions through Monte Carlo Tree Search and achieves high performance.

- RuleBot: A bot that selects actions according to simple hand-crafted rules and achieves moderate performance.

Note: Currently the RuleBot can only exclude actions that would lead to losing the game within three moves. One possible improvement is to further enhance the bot's long-term planning capabilities.
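
As a minimal sketch, switching to ``play_with_bot_mode`` with the MCTS bot only requires overriding the corresponding config keys (the import path is an assumption based on this repository's layout; see the imports below):

    from easydict import EasyDict
    from zoo.board_games.connect4.envs.connect4_env import Connect4Env

    cfg = EasyDict(Connect4Env.default_config())
    cfg.battle_mode = 'play_with_bot_mode'
    cfg.bot_action_type = 'mcts'  # or 'rule' for the RuleBot
    env = Connect4Env(cfg)
    obs = env.reset()
    timestep = env.step(2)  # the agent plays column 3; the bot then replies internally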

Observation Space:

The observation in the Connect4 environment is a dictionary with five elements that contain key information about the current state.

- observation (:obj:`array`): An array that represents information about the current state, with a shape of (3, 6, 7).

The length of the first dimension is 3: it stacks three two-dimensional game boards of shape (6, 7).

These boards represent the positions occupied by the current player, the positions occupied by the opponent player, and the identity of the current player, respectively.

- action_mask (:obj:`array`): A mask for the actions, indicating which actions are executable. It is a one-dimensional array of length 7, corresponding to columns 1 to 7 of the game board.

It has a value of 1 for the columns where a move can be made, and a value of 0 for the other positions.

- board (:obj:`array`): A visual representation of the current game board, represented as a 6x7 array, in which the positions where player 1 and player 2 have placed their tokens are marked with the values 1 and 2, respectively.

- current_player_index (:obj:`int`): The index of the current player, with player 1 having an index of 0 and player 2 having an index of 1.

- to_play (:obj:`int`): The player who needs to take an action in the current state, with a value of 1 or 2. In ``play_with_bot_mode`` and ``eval_mode`` it is set to -1, since the environment is then treated as a single-agent one.
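
As an illustration, a minimal sketch that inspects the observation returned by ``reset()`` in the default ``self_play_mode`` (the import path is an assumption based on this repository's layout):

    from easydict import EasyDict
    from zoo.board_games.connect4.envs.connect4_env import Connect4Env

    env = Connect4Env(EasyDict(Connect4Env.default_config()))
    obs = env.reset()
    assert sorted(obs.keys()) == ['action_mask', 'board', 'current_player_index', 'observation', 'to_play']
    assert obs['observation'].shape == (3, 6, 7)   # three stacked 6x7 planes
    assert obs['action_mask'].tolist() == [1] * 7  # every column is playable on an empty board
    assert obs['to_play'] == 1                     # player 1 moves first by default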

Action Space:

A set of integers from 0 to 6 (inclusive), where the action represents the column in which a token should be dropped.

Reward Space:

For ``self_play_mode``, a reward of 1 is returned at the time step when one of the players wins, and a reward of 0 is returned at all other time steps, including the terminal step of a draw.

For ``play_with_bot_mode``, at the time step when the game terminates, the reward is -1 if the bot wins and 1 if the agent wins; in all other cases, the reward is 0.
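
Putting it together, a minimal self-play rollout with random legal moves looks like the following sketch (again assuming the import path above):

    from easydict import EasyDict
    from zoo.board_games.connect4.envs.connect4_env import Connect4Env

    env = Connect4Env(EasyDict(Connect4Env.default_config()))
    obs = env.reset()
    done = False
    while not done:
        action = env.random_action()  # a random legal column in 0..6
        obs, reward, done, info = env.step(action)
    # At the terminal step, reward is 1 if the last mover won and 0 for a draw.
    print(env.action_to_string(action), float(reward))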

"""

import copy
import os
import sys
from typing import List, Any, Tuple, Optional

import imageio
import matplotlib.pyplot as plt
import numpy as np
import pygame
from ding.envs import BaseEnv, BaseEnvTimestep
from ding.utils import ENV_REGISTRY
from ditk import logging
from easydict import EasyDict
from gymnasium import spaces

from zoo.board_games.connect4.envs.rule_bot import Connect4RuleBot
from zoo.board_games.mcts_bot import MCTSBot


@ENV_REGISTRY.register('connect4')
class Connect4Env(BaseEnv):

    config = dict(
        # (str) The name of the environment registered in the environment registry.
        env_name="Connect4",
        # (str) The mode of the environment: 'self_play_mode', 'play_with_bot_mode' or 'eval_mode'.
        battle_mode='self_play_mode',
        # (str) The battle mode used inside the simulation environment; always 'self_play_mode'.
        battle_mode_in_simulation_env='self_play_mode',
        # (str or None) The rendering mode: None, 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'.
        render_mode=None,
        # (str or None) The directory in which to save replay files. If None, a default filename is used.
        replay_path=None,
        # (str) The type of the built-in bot: 'rule' or 'mcts'.
        bot_action_type='rule',
        # (bool) Whether a human plays against the agent in 'eval_mode'.
        agent_vs_human=False,
        # (float) The probability that the agent's action is replaced by a random action in 'self_play_mode'.
        prob_random_agent=0,
        # (float) The probability that the agent's action is replaced by a bot action in 'self_play_mode'.
        prob_expert_agent=0,
        # (float) The probability that the bot plays a random action instead of its own action.
        prob_random_action_in_bot=0.,
        # (int) The scaling factor of the rendered image.
        screen_scaling=9,
        # (bool) Whether the observation planes are channel-last, i.e. shape (6, 7, 3) instead of (3, 6, 7).
        channel_last=False,
        # (bool) Whether to scale the observation values into [0, 1].
        scale=False,
        # (int) The stop value of the evaluator: training stops once the evaluation return reaches it.
        stop_value=2,
    )

    @classmethod
    def default_config(cls: type) -> EasyDict:
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(self, cfg: dict = None) -> None:
        self.cfg = cfg
        # Options for the observation format.
        self.channel_last = cfg.channel_last
        self.scale = cfg.scale
        # Options for rendering and replay saving.
        self.screen_scaling = cfg.screen_scaling
        self.render_mode = cfg.render_mode
        self.replay_name_suffix = "test"
        self.replay_path = cfg.replay_path
        self.replay_format = 'gif'
        self.screen = None
        self.frames = []

        self.battle_mode = cfg.battle_mode
        assert self.battle_mode in ['self_play_mode', 'play_with_bot_mode', 'eval_mode']
        # The mode used inside the simulation environment is always 'self_play_mode'.
        self.battle_mode_in_simulation_env = 'self_play_mode'

        self.agent_vs_human = cfg.agent_vs_human

        # At most one of prob_random_agent and prob_expert_agent may be positive.
        self.prob_random_agent = cfg.prob_random_agent
        self.prob_expert_agent = cfg.prob_expert_agent
        assert (self.prob_random_agent >= 0 and self.prob_expert_agent == 0) or (
                self.prob_random_agent == 0 and self.prob_expert_agent >= 0), \
            f'self.prob_random_agent:{self.prob_random_agent}, self.prob_expert_agent:{self.prob_expert_agent}'

        # The board is stored as a flat list of 42 cells in row-major order
        # (index = row * 7 + column, with row 0 at the top of the grid).
        self.board = [0] * (6 * 7)

        self.players = [1, 2]
        self._current_player = 1
        self._env = self

        # Set up the built-in bot.
        self.bot_action_type = cfg.bot_action_type
        self.prob_random_action_in_bot = cfg.prob_random_action_in_bot
        if self.bot_action_type == 'mcts':
            cfg_temp = EasyDict(cfg.copy())
            cfg_temp.save_replay = False
            cfg_temp.bot_action_type = None
            env_mcts = Connect4Env(EasyDict(cfg_temp))
            self.mcts_bot = MCTSBot(env_mcts, 'mcts_player', 50)
        elif self.bot_action_type == 'rule':
            self.rule_bot = Connect4RuleBot(self, self._current_player)

        if self.render_mode is not None:
            self.render(self.render_mode)

    def _player_step(self, action: int, flag: str) -> BaseEnvTimestep:
        """
        Overview:
            A function that implements the transition of the environment's state. \
            After taking an action in the environment, the function transitions the environment to the next state \
            and returns the relevant information for the next time step.
        Arguments:
            - action (:obj:`int`): A value from 0 to 6 indicating the column to drop a token into on the connect4 board.
            - flag (:obj:`str`): A marker indicating the source of the action, for debugging convenience.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): A namedtuple that records the observation and obtained reward after taking the action, \
                whether the game is terminated, and some other information.
        """
        if action in self.legal_actions:
            piece = self.players.index(self._current_player) + 1
            # Drop the piece into the chosen column: scan the column from the bottom row
            # upwards and place the piece in the first empty cell.
            for i in list(filter(lambda x: x % 7 == action, list(range(41, -1, -1)))):
                if self.board[i] == 0:
                    self.board[i] = piece
                    break
        else:
            print(np.array(self.board).reshape(6, 7))
            logging.warning(
                f"You input an illegal action: {action}, the legal_actions are {self.legal_actions}. "
                f"flag is {flag}. "
                f"Now we randomly choose an action from self.legal_actions."
            )
            action = self.random_action()
            print("the random action is", action)
            piece = self.players.index(self._current_player) + 1
            for i in list(filter(lambda x: x % 7 == action, list(range(41, -1, -1)))):
                if self.board[i] == 0:
                    self.board[i] = piece
                    break

        # The reward is 1 for the player who just moved if that move wins the game, and 0 otherwise.
        done, winner = self.get_done_winner()
        if winner != -1:
            reward = np.array(1).astype(np.float32)
        else:
            reward = np.array(0).astype(np.float32)

        info = {}

        self._current_player = self.next_player

        obs = self.observe()

        if self.render_mode is not None:
            self.render(self.render_mode)
        if done:
            info['eval_episode_return'] = reward
            if self.render_mode == 'image_savefile_mode':
                self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path,
                                        format=self.replay_format)

        return BaseEnvTimestep(obs, reward, done, info)

    def step(self, action: int) -> BaseEnvTimestep:
        """
        Overview:
            The step function of the environment. It receives an action from the player and returns the state of the environment after that action is performed. \
            In ``self_play_mode``, this function calls ``_player_step()`` only once, since the agent plays against itself and takes the roles of both player 1 and player 2. \
            In ``play_with_bot_mode``, this function first calls ``_player_step()`` with the received ``action`` and then calls it again with the bot's action, \
            returning the result of taking these two actions sequentially in the environment. \
            In ``eval_mode``, this function also calls ``_player_step()`` twice, with the second action coming from either a human player or the bot.
        Arguments:
            - action (:obj:`int`): A value from 0 to 6 indicating the column to drop a token into on the connect4 board.
        Returns:
            - timestep (:obj:`BaseEnvTimestep`): A namedtuple that records the observation and obtained reward after taking the action, \
                whether the game is terminated, and some other information.
        """
        if self.battle_mode == 'self_play_mode':
            # With some probability, replace the agent's action with a random or an expert (bot) action.
            if self.prob_random_agent > 0:
                if np.random.rand() < self.prob_random_agent:
                    action = self.random_action()
            elif self.prob_expert_agent > 0:
                if np.random.rand() < self.prob_expert_agent:
                    action = self.bot_action()

            flag = "agent"
            timestep = self._player_step(action, flag)

            if timestep.done:
                # The ``eval_episode_return`` is calculated from the perspective of player 1.
                timestep.info['eval_episode_return'] = -timestep.reward if timestep.obs[
                    'to_play'] == 1 else timestep.reward

            return timestep

        elif self.battle_mode == 'play_with_bot_mode':
            # Player 1's turn: the agent takes the action.
            flag = "bot_agent"
            timestep_player1 = self._player_step(action, flag)

            if timestep_player1.done:
                # Set ``to_play`` to -1, because in this mode the environment is treated as a single-agent one.
                timestep_player1.obs['to_play'] = -1

                return timestep_player1

            # Player 2's turn: the built-in bot takes the action.
            bot_action = self.bot_action()
            flag = "bot_bot"
            timestep_player2 = self._player_step(bot_action, flag)

            # The reward and the ``eval_episode_return`` are converted to the perspective of player 1.
            timestep_player2.info['eval_episode_return'] = -timestep_player2.reward
            timestep_player2 = timestep_player2._replace(reward=-timestep_player2.reward)

            timestep = timestep_player2

            # Set ``to_play`` to -1, because in this mode the environment is treated as a single-agent one.
            timestep.obs['to_play'] = -1

            return timestep

        elif self.battle_mode == 'eval_mode':
            # Player 1's turn: the agent takes the action.
            flag = "eval_agent"
            timestep_player1 = self._player_step(action, flag)

            if timestep_player1.done:
                # Set ``to_play`` to -1, because in this mode the environment is treated as a single-agent one.
                timestep_player1.obs['to_play'] = -1

                return timestep_player1

            # Player 2's turn: either a human player or the bot takes the action.
            if self.agent_vs_human:
                bot_action = self.human_to_action()
            else:
                bot_action = self.bot_action()

            flag = "eval_bot"
            timestep_player2 = self._player_step(bot_action, flag)

            # The reward and the ``eval_episode_return`` are converted to the perspective of player 1.
            timestep_player2.info['eval_episode_return'] = -timestep_player2.reward
            timestep_player2 = timestep_player2._replace(reward=-timestep_player2.reward)

            timestep = timestep_player2

            # Set ``to_play`` to -1, because in this mode the environment is treated as a single-agent one.
            timestep.obs['to_play'] = -1

            return timestep

    def reset(self, start_player_index: int = 0, init_state: Optional[np.ndarray] = None,
              replay_name_suffix: Optional[str] = None) -> dict:
        """
        Overview:
            Reset the environment. If ``init_state`` is given, start from that custom state.
        Arguments:
            - start_player_index (:obj:`int`): The index of the starting player; players = [1, 2], player_index = [0, 1].
            - init_state (:obj:`array`): A custom state to start from.
            - replay_name_suffix (:obj:`str`): The suffix used for the replay filename, if given.
        """
        if replay_name_suffix is not None:
            self.replay_name_suffix = replay_name_suffix
        if init_state is None:
            self.board = [0] * (6 * 7)
        else:
            self.board = init_state
        self.players = [1, 2]
        self.start_player_index = start_player_index
        self._current_player = self.players[self.start_player_index]

        self._action_space = spaces.Discrete(7)
        self._reward_space = spaces.Box(low=-1, high=1, shape=(1,), dtype=np.float32)
        self._observation_space = spaces.Dict(
            {
                "observation": spaces.Box(low=0, high=1, shape=(3, 6, 7), dtype=np.int8),
                "action_mask": spaces.Box(low=0, high=1, shape=(7,), dtype=np.int8),
                "board": spaces.Box(low=0, high=2, shape=(6, 7), dtype=np.int8),
                "current_player_index": spaces.Discrete(2),
                "to_play": spaces.Discrete(2),
            }
        )

        obs = self.observe()
        return obs

    def current_state(self) -> Tuple[np.ndarray, np.ndarray]:
        """
        Overview:
            Obtain the state from the view of the current player. \
            ``self.board`` is an array in which 0 indicates that no stone is placed at a position, \
            1 indicates that player 1's stone is placed there, and 2 indicates that player 2's stone is placed there.
        Returns:
            - current_state (:obj:`array`):
                dim 0 indicates which positions are occupied by ``self.current_player``, \
                dim 1 indicates which positions are occupied by ``self.next_player``, \
                dim 2 indicates which player is the to_play player: 1 means player 1, 2 means player 2.
        """
        board_vals = np.array(self.board).reshape(6, 7)
        board_curr_player = np.where(board_vals == self.current_player, 1, 0)
        board_opponent_player = np.where(board_vals == self.next_player, 1, 0)
        board_to_play = np.full((6, 7), self.current_player)
        raw_obs = np.array([board_curr_player, board_opponent_player, board_to_play], dtype=np.float32)
        if self.scale:
            scale_obs = copy.deepcopy(raw_obs / 2)
        else:
            scale_obs = copy.deepcopy(raw_obs)
        if self.channel_last:
            # Move the channel dimension to the last axis: (C, H, W) -> (H, W, C).
            return np.transpose(raw_obs, [1, 2, 0]), np.transpose(scale_obs, [1, 2, 0])
        else:
            return raw_obs, scale_obs

    def observe(self) -> dict:
        legal_moves = self.legal_actions

        # Build the action mask: 1 for columns where a move can be made, 0 elsewhere.
        action_mask = np.zeros(7, "int8")
        for i in legal_moves:
            action_mask[i] = 1

        if self.battle_mode == 'play_with_bot_mode' or self.battle_mode == 'eval_mode':
            return {
                "observation": self.current_state()[1],
                "action_mask": action_mask,
                "board": copy.deepcopy(self.board),
                "current_player_index": self.players.index(self._current_player),
                "to_play": -1
            }
        elif self.battle_mode == 'self_play_mode':
            return {
                "observation": self.current_state()[1],
                "action_mask": action_mask,
                "board": copy.deepcopy(self.board),
                "current_player_index": self.players.index(self._current_player),
                "to_play": self._current_player
            }

    @property
    def legal_actions(self) -> List[int]:
        # A column is playable as long as its top cell (indices 0-6) is empty.
        return [i for i in range(7) if self.board[i] == 0]

    def render(self, mode: str = None) -> None:
        """
        Overview:
            Render the Connect Four game environment.
        Arguments:
            - mode (:obj:`str`): The rendering mode. Options are None, 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. \
                When set to None, the game state is not rendered. \
                In 'state_realtime_mode', the game state is illustrated in a text-based format directly in the console. \
                'image_realtime_mode' displays the game as an RGB image in real time. \
                With 'image_savefile_mode', the game is rendered as an RGB image but not displayed in real time; instead, the image is saved to a designated file. \
                Please note that the default rendering mode is set to None.
        """
        if mode == "state_realtime_mode":
            print(np.array(self.board).reshape(6, 7))
            return
        else:
            # Render an RGB image of the board with pygame.
            screen_width = 99 * self.screen_scaling
            screen_height = 86 / 99 * screen_width
            pygame.init()
            self.screen = pygame.Surface((screen_width, screen_height))

            # Each of the 7 board columns occupies (91 / 99) of the screen width.
            tile_size = (screen_width * (91 / 99)) / 7

            # Load and scale the chip and board images.
            red_chip = self.get_image(os.path.join("img", "C4RedPiece.png"))
            red_chip = pygame.transform.scale(
                red_chip, (int(tile_size * (9 / 13)), int(tile_size * (9 / 13)))
            )

            black_chip = self.get_image(os.path.join("img", "C4BlackPiece.png"))
            black_chip = pygame.transform.scale(
                black_chip, (int(tile_size * (9 / 13)), int(tile_size * (9 / 13)))
            )

            board_img = self.get_image(os.path.join("img", "Connect4Board.png"))
            board_img = pygame.transform.scale(
                board_img, ((int(screen_width)), int(screen_height))
            )

            self.screen.blit(board_img, (0, 0))

            # Blit the chips onto their positions on the board.
            for i in range(0, 42):
                if self.board[i] == 1:
                    self.screen.blit(
                        red_chip,
                        (
                            (i % 7) * (tile_size) + (tile_size * (6 / 13)),
                            int(i / 7) * (tile_size) + (tile_size * (6 / 13)),
                        ),
                    )
                elif self.board[i] == 2:
                    self.screen.blit(
                        black_chip,
                        (
                            (i % 7) * (tile_size) + (tile_size * (6 / 13)),
                            int(i / 7) * (tile_size) + (tile_size * (6 / 13)),
                        ),
                    )
            if mode == "image_realtime_mode":
                surface_array = pygame.surfarray.pixels3d(self.screen)
                surface_array = np.transpose(surface_array, (1, 0, 2))
                plt.imshow(surface_array)
                plt.draw()
                plt.pause(0.001)
            elif mode == "image_savefile_mode":
                # Record the frame so that it can be saved later by ``save_render_output()``.
                observation = np.array(pygame.surfarray.pixels3d(self.screen))
                self.frames.append(np.transpose(observation, axes=(1, 0, 2)))

            self.screen = None

            return None

    def save_render_output(self, replay_name_suffix: str = '', replay_path: str = None, format: str = 'gif') -> None:
        """
        Overview:
            Save the rendered frames as an output file.
        Arguments:
            - replay_name_suffix (:obj:`str`): The suffix to be added to the replay filename.
            - replay_path (:obj:`str`): The path to save the replay file. If None, the default filename will be used.
            - format (:obj:`str`): The format of the output file. Options are 'gif' or 'mp4'.
        """
        if replay_path is None:
            filename = f'connect4_{replay_name_suffix}.{format}'
        else:
            if not os.path.exists(replay_path):
                os.makedirs(replay_path)
            filename = replay_path + f'/connect4_{replay_name_suffix}.{format}'

        if format == 'gif':
            imageio.mimsave(filename, self.frames, 'GIF', duration=0.1)
        elif format == 'mp4':
            imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4')
        else:
            raise ValueError("Unsupported format: {}".format(format))
        logging.info("Saved output to {}".format(filename))
        self.frames = []

    def get_done_winner(self) -> Tuple[bool, int]:
        """
        Overview:
            Check if the game is done and find the winner.
        Returns:
            - outputs (:obj:`Tuple`): Tuple containing 'done' and 'winner',
                - if player 1 wins, 'done' = True, 'winner' = 1
                - if player 2 wins, 'done' = True, 'winner' = 2
                - if it is a draw, 'done' = True, 'winner' = -1
                - if the game is not over, 'done' = False, 'winner' = -1
        """
        board = copy.deepcopy(np.array(self.board)).reshape(6, 7)
        for piece in [1, 2]:
            column_count = 7
            row_count = 6

            # Check for four in a row in the horizontal direction.
            for c in range(column_count - 3):
                for r in range(row_count):
                    if (
                            board[r][c] == piece
                            and board[r][c + 1] == piece
                            and board[r][c + 2] == piece
                            and board[r][c + 3] == piece
                    ):
                        return True, piece

            # Check for four in a row in the vertical direction.
            for c in range(column_count):
                for r in range(row_count - 3):
                    if (
                            board[r][c] == piece
                            and board[r + 1][c] == piece
                            and board[r + 2][c] == piece
                            and board[r + 3][c] == piece
                    ):
                        return True, piece

            # Check for four in a row on diagonals going down and to the right.
            for c in range(column_count - 3):
                for r in range(row_count - 3):
                    if (
                            board[r][c] == piece
                            and board[r + 1][c + 1] == piece
                            and board[r + 2][c + 2] == piece
                            and board[r + 3][c + 3] == piece
                    ):
                        return True, piece

            # Check for four in a row on diagonals going up and to the right.
            for c in range(column_count - 3):
                for r in range(3, row_count):
                    if (
                            board[r][c] == piece
                            and board[r - 1][c + 1] == piece
                            and board[r - 2][c + 2] == piece
                            and board[r - 3][c + 3] == piece
                    ):
                        return True, piece

        # The game is a draw when every cell is filled and nobody has won.
        if all(x in [1, 2] for x in self.board):
            return True, -1

        return False, -1

    def get_done_reward(self) -> Tuple[bool, Optional[int]]:
        """
        Overview:
            Check if the game is over and compute the reward from the perspective of player 1. \
            Return 'done' and 'reward'.
        Returns:
            - outputs (:obj:`Tuple`): Tuple containing 'done' and 'reward',
                - if player 1 wins, 'done' = True, 'reward' = 1
                - if player 2 wins, 'done' = True, 'reward' = -1
                - if it is a draw, 'done' = True, 'reward' = 0
                - if the game is not over, 'done' = False, 'reward' = None
        """
        done, winner = self.get_done_winner()
        if winner == 1:
            reward = 1
        elif winner == 2:
            reward = -1
        elif winner == -1 and done:
            reward = 0
        elif winner == -1 and not done:
            reward = None
        return done, reward

    def random_action(self) -> int:
        action_list = self.legal_actions
        return np.random.choice(action_list)

    def bot_action(self) -> int:
        # With probability ``prob_random_action_in_bot``, the bot plays a random legal action instead.
        if np.random.rand() < self.prob_random_action_in_bot:
            return self.random_action()
        else:
            if self.bot_action_type == 'rule':
                return self.rule_bot.get_rule_bot_action(self.board, self._current_player)
            elif self.bot_action_type == 'mcts':
                return self.mcts_bot.get_actions(self.board, player_index=self.current_player_index)

    def action_to_string(self, action: int) -> str:
        """
        Overview:
            Convert an action number to a string representing the action.
        Arguments:
            - action: an integer from the action space.
        Returns:
            - String representing the action.
        """
        return f"Play column {action + 1}"

    def human_to_action(self) -> int:
        """
        Overview:
            For multiplayer games, ask the user for a legal action \
            and return the corresponding action number.
        Returns:
            An integer from the action space.
        """
        print(np.array(self.board).reshape(6, 7))
        while True:
            try:
                column = int(
                    input(
                        f"Enter the column to play for the player {self.current_player}: "
                    )
                )
                action = column - 1
                if action in self.legal_actions:
                    break
                else:
                    print("Wrong input, try again")
            except KeyboardInterrupt:
                print("exit")
                sys.exit(0)
            except Exception:
                print("Wrong input, try again")
        return action

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def __repr__(self) -> str:
        return "LightZero Connect4 Env"

    @property
    def current_player(self) -> int:
        return self._current_player

    @property
    def current_player_index(self) -> int:
        """
        Overview:
            current_player_index = 0, current_player = 1; \
            current_player_index = 1, current_player = 2
        """
        return 0 if self._current_player == 1 else 1

    @property
    def next_player(self) -> int:
        return self.players[0] if self._current_player == self.players[1] else self.players[1]

    @property
    def observation_space(self) -> spaces.Space:
        return self._observation_space

    @property
    def action_space(self) -> spaces.Space:
        return self._action_space

    @property
    def reward_space(self) -> spaces.Space:
        return self._reward_space

    def simulate_action(self, action: int) -> Any:
        """
        Overview:
            Execute the action and return the next simulator environment. Used in AlphaZero.
        Arguments:
            - action: an integer from the action space.
        Returns:
            - next_simulator_env: the next simulator environment after executing the action.
        """
        if action not in self.legal_actions:
            raise ValueError("action {0} on board {1} is not legal".format(action, self.board))
        new_board = copy.deepcopy(self.board)
        piece = self.players.index(self._current_player) + 1
        # Drop the piece into the chosen column of the copied board.
        for i in list(filter(lambda x: x % 7 == action, list(range(41, -1, -1)))):
            if new_board[i] == 0:
                new_board[i] = piece
                break
        # The other player moves next, so the start player index is flipped.
        if self.start_player_index == 0:
            start_player_index = 1
        else:
            start_player_index = 0
        next_simulator_env = copy.deepcopy(self)
        next_simulator_env.reset(start_player_index, init_state=new_board)
        return next_simulator_env

    @staticmethod
    def create_collector_env_cfg(cfg: dict) -> List[dict]:
        collector_env_num = cfg.pop('collector_env_num')
        cfg = copy.deepcopy(cfg)
        return [cfg for _ in range(collector_env_num)]

    @staticmethod
    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
        evaluator_env_num = cfg.pop('evaluator_env_num')
        cfg = copy.deepcopy(cfg)
        # In 'eval_mode', the environment steps both the agent and its opponent internally.
        cfg.battle_mode = 'eval_mode'
        return [cfg for _ in range(evaluator_env_num)]

    def close(self) -> None:
        pass

    def get_image(self, path: str) -> Any:
        from os import path as os_path
        import pygame

        # Load an image from a path relative to this file and copy it onto a surface with an alpha channel.
        cwd = os_path.dirname(__file__)
        image = pygame.image.load(cwd + "/" + path)
        sfc = pygame.Surface(image.get_size(), flags=pygame.SRCALPHA)
        sfc.blit(image, (0, 0))
        return sfc