"""
Overview:
    This module implements an MCTSBot that uses Monte Carlo Tree Search (MCTS) to make decisions.
    MCTSNode is an abstract base class that specifies the basic methods that a Monte Carlo Tree node should have.
    The TwoPlayersMCTSNode class inherits from this base class and implements the specific methods for two-player games.
    MCTS implements the search procedure, which takes a root node and performs a search to obtain the optimal action.
    MCTSBot integrates the above components: it creates a root node from the current game environment,
    then calls MCTS to perform a search and make a decision.
For more details, you can refer to: https://github.com/int8/monte-carlo-tree-search.
"""
import time
from abc import ABC, abstractmethod
from collections import defaultdict
import numpy as np
import copy
class MCTSNode(ABC):
"""
Overview:
This is an abstract base class that outlines the fundamental methods for a Monte Carlo Tree node.
Each specific method must be implemented in the subclasses for specific use-cases.
"""
def __init__(self, env, parent=None):
"""
Arguments:
- env (:obj:`BaseEnv`): The game environment of the current node.
The properties of this object contain information about the current game environment.
For instance, in a game of tictactoe:
- env.board: A (3,3) array representing the game board, e.g.,
[[0,2,0],
[1,1,0],
[2,0,0]]
Here, 0 denotes an unplayed position, 1 represents a position occupied by player 1, and 2 indicates a position taken by player 2.
- env.players: A list [1,2] representing the two players, player 1 and player 2 respectively.
                - env._current_player: Denotes the player who is to move in the current turn; it alternates between the two players every turn, rather than being set only at reset.
The methods of this object implement functionalities such as game state transitions and retrieving game results.
- parent (:obj:`MCTSNode`): The parent node of the current node. The parent node is primarily used for backpropagation during the Monte Carlo Tree Search.
For the root node, this parent returns None as it does not have a parent node.
"""
self.env = env
self.parent = parent
self.children = []
self.expanded_actions = []
self.best_action = -1
@property
@abstractmethod
def legal_actions(self):
pass
@property
@abstractmethod
def value(self):
pass
@property
@abstractmethod
def visit_count(self):
pass
@abstractmethod
def expand(self):
pass
@abstractmethod
def is_terminal_node(self):
pass
@abstractmethod
def rollout(self):
pass
@abstractmethod
def backpropagate(self, reward):
pass
def is_fully_expanded(self):
"""
Overview:
This method checks if the node is fully expanded.
A node is considered fully expanded when all of its child nodes have been selected at least once.
Whenever a new child node is selected for the first time, a corresponding action is removed from the list of legal actions.
Once the list of legal actions is depleted, it signifies that all child nodes have been selected,
thereby indicating that the parent node is fully expanded.
"""
return len(self.legal_actions) == 0
def best_child(self, c_param=1.4):
"""
Overview:
This function finds the best child node which has the highest UCB (Upper Confidence Bound) score.
The UCB formula is:
                {UCT}(v_i, v) = \frac{Q(v_i)}{N(v_i)} + c \sqrt{\frac{2 \log N(v)}{N(v_i)}}
- Q(v_i) is the estimated value of the child node v_i.
- N(v_i) is a counter of how many times the child node v_i has been on the backpropagation path.
- N(v) is a counter of how many times the parent (current) node v has been on the backpropagation path.
- c is a parameter which balances exploration and exploitation.
Arguments:
- c_param (:obj:`float`): a parameter which controls the balance between exploration and exploitation. Default value is 1.4.
Returns:
            - node (:obj:`MCTSNode`): The child node which has the highest UCB score.
"""
        # Calculate the UCB score for every child node in the list.
        choices_weights = [
            (child_node.value / child_node.visit_count) +
            c_param * np.sqrt(2 * np.log(self.visit_count) / child_node.visit_count)
            for child_node in self.children
        ]
        # Compute the argmax once and reuse it for both the action and the child.
        best_index = int(np.argmax(choices_weights))
        # Save the best action based on the highest UCB score.
        self.best_action = self.expanded_actions[best_index]
        # Choose the child node which has the highest UCB score.
        return self.children[best_index]
def rollout_policy(self, possible_actions):
"""
Overview:
This method implements the rollout policy for a node during the Monte Carlo Tree Search.
The rollout policy is used to determine the action taken during the simulation phase of the MCTS.
In this case, the policy is to randomly choose an action from the list of possible actions.
Arguments:
- possible_actions(:obj:`list`): A list of all possible actions that can be taken from the current state.
        Returns:
- action(:obj:`int`): A randomly chosen action from the list of possible actions.
"""
return possible_actions[np.random.randint(len(possible_actions))]
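    # For example (editorial note, not in the original source), rollout_policy([2, 5, 7])
    # returns 2, 5, or 7 with equal probability; np.random.choice(possible_actions)
    # would be an equivalent one-liner for a flat list of integer actions.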
class TwoPlayersMCTSNode(MCTSNode):
"""
Overview:
        This subclass inherits from the abstract base class and implements the specific methods required for a two-player Monte Carlo Tree node.
"""
def __init__(self, env, parent=None):
"""
Overview:
This function initializes a new instance of the class. It sets the parent node, environment, and initializes the number of visits, results, and legal actions.
Arguments:
- env (:obj:`BaseEnv`): the environment object which contains information about the current game state.
- parent (:obj:`MCTSNode`): the parent node of this node. If None, then this node is the root node.
"""
super().__init__(env, parent)
self._number_of_visits = 0.
# A default dictionary which sets the value to 0 for undefined keys.
self._results = defaultdict(int)
self._legal_actions = None
# Get all legal actions in current state from the environment object.
@property
def legal_actions(self):
if self._legal_actions is None:
self._legal_actions = copy.deepcopy(self.env.legal_actions)
return self._legal_actions
@property
def value(self):
"""
Overview:
This property represents the estimated value (Q-value) of the current node.
self._results[1] represents the number of wins for player 1.
self._results[-1] represents the number of wins for player 2.
            The Q-value calculation depends on which player is the current player at the parent node,
            and it is computed as the difference between the wins of that player and those of the opponent.
If the parent's current player is player 1, Q-value is the difference of player 1's wins and player 2's wins.
If the parent's current player is player 2, Q-value is the difference of player 2's wins and player 1's wins.
For example, if self._results[1] = 10 (player 1's wins) and self._results[-1] = 5 (player 2's wins):
- If the parent's current player is player 1, then Q-value = 10 - 5 = 5.
- If the parent's current player is player 2, then Q-value = 5 - 10 = -5.
This way, a higher Q-value for a node indicates a higher win rate for the parent's current player.
"""
# Determine the number of wins and losses based on the current player at the parent node.
wins, loses = (self._results[1], self._results[-1]) if self.parent.env.current_player == 1 else (
self._results[-1], self._results[1])
# Calculate and return the Q-value as the difference between wins and losses.
return wins - loses
@property
def visit_count(self):
"""
Overview:
This property represents the number of times the node has been visited during the search.
"""
return self._number_of_visits
def expand(self):
"""
Overview:
This method expands the current node by creating a new child node.
It pops an action from the list of legal actions, simulates the action to get the next game state,
and creates a new child node with that state. The new child node is then added to the list of children nodes.
Returns:
- node(:obj:`TwoPlayersMCTSNode`): The child node object that has been created.
"""
        # Pop an untried action from the list of legal actions, so that only untried actions remain in the list.
action = self.legal_actions.pop()
        # The simulate_action() function returns a new environment in which the board and the current-player flag have been updated.
next_simulator_env = self.env.simulate_action(action)
# Create a new node object for the child node and append it to the children list.
child_node = TwoPlayersMCTSNode(next_simulator_env, parent=self)
self.children.append(child_node)
# Add the action that has been tried to the expanded_actions list.
self.expanded_actions.append(action)
# Return the child node object.
return child_node
def is_terminal_node(self):
"""
Overview:
This function checks whether the current node is a terminal node.
It uses the game environment's get_done_reward method to check if the game has ended.
Returns:
- A bool flag representing whether the game is over.
"""
# The get_done_reward() returns a tuple (done, reward).
# reward = ±1 when player 1 wins/loses the game.
# reward = 0 when it is a tie.
        # reward = None when the current node is not a terminal node.
# done is the bool flag representing whether the game is over.
return self.env.get_done_reward()[0]
def rollout(self):
"""
Overview:
This method performs a rollout (simulation) from the current node.
It repeatedly selects an action based on the rollout policy and simulates the action until the game ends.
The method then returns the reward of the game's final state.
Returns:
            - reward (:obj:`int`): reward = ±1 when player 1 wins/loses the game, reward = 0 when it is a tie, and reward = None when the current node is not a terminal node.
"""
# print('simulation begin')
current_rollout_env = self.env
# print(current_rollout_env.board)
while not current_rollout_env.get_done_reward()[0]:
possible_actions = current_rollout_env.legal_actions
action = self.rollout_policy(possible_actions)
current_rollout_env = current_rollout_env.simulate_action(action)
# print('\n')
# print(current_rollout_env.board)
# print('simulation end \n')
return current_rollout_env.get_done_reward()[1]
def backpropagate(self, result):
"""
Overview:
This method performs backpropagation from the current node.
It increments the number of times the node has been visited and the number of wins for the result.
If the current node has a parent, the method recursively backpropagates the result to the parent.
"""
self._number_of_visits += 1.
        # result is the key into the self._results dict.
        # result = ±1 when player 1 wins/loses the game.
self._results[result] += 1.
if self.parent:
self.parent.backpropagate(result)
class MCTS(object):
"""
Overview:
This class implements Monte Carlo Tree Search from the root node, whose environment is the real environment of the game at the current moment.
        After the tree search and rollout simulations, every child of the root node has a value estimate and a visit count.
        The decision at the root is then to choose the child with either the highest UCB score or the highest visit count, depending on the best_action_type argument.
"""
def __init__(self, node):
"""
Overview:
This function initializes a new instance of the MCTS class with the given root node object.
        Arguments:
- node (:obj:`TwoPlayersMCTSNode`): The root node object for the MCTS.
"""
self.root = node
def best_action(self, simulations_number=None, total_simulation_seconds=None, best_action_type="UCB"):
"""
Overview:
This function simulates the game by constantly selecting the best child node and backpropagating the result.
Arguments:
- simulations_number (:obj:`int`): The number of simulations performed to get the best action.
- total_simulation_seconds (:obj:`float`): The amount of time the algorithm has to run. Specified in seconds.
- best_action_type (:obj:`str`): The type of best action selection to use. Either "UCB" or "most visited".
Returns:
            - node(:obj:`TwoPlayersMCTSNode`): The best child node object, which contains the best action to take.
"""
# The search cost is determined by either the maximum number of simulations or the longest simulation time.
# If no simulations number is provided, run simulations for the specified time.
if simulations_number is None:
assert (total_simulation_seconds is not None)
end_time = time.time() + total_simulation_seconds
while True:
# Get the leaf node.
leaf_node = self._tree_policy()
# Rollout from the leaf node.
reward = leaf_node.rollout()
# Backpropagate from the leaf node to the root node.
leaf_node.backpropagate(reward)
if time.time() > end_time:
break
# If a simulation number is provided, run the specified number of simulations.
else:
            for i in range(simulations_number):
                # print('****simulation-{}****'.format(i))
# Get the leaf node.
leaf_node = self._tree_policy()
# Rollout from the leaf node.
reward = leaf_node.rollout()
# print('reward={}\n'.format(reward))
# Backpropagate from the leaf node to the root node.
leaf_node.backpropagate(reward)
        # To select the best child, go for exploitation only: set the exploration parameter c to 0.
if best_action_type == "UCB":
return self.root.best_child(c_param=0.)
else:
children_visit_counts = [child_node.visit_count for child_node in self.root.children]
self.root.best_action = self.root.expanded_actions[np.argmax(children_visit_counts)]
return self.root.children[np.argmax(children_visit_counts)]
def _tree_policy(self):
"""
Overview:
            This function implements the tree search from the root node down to a leaf node, which is either a newly expanded child (visited for the first time) or a terminal node.
At each step, if the current node is not fully expanded, it expands.
If it is fully expanded, it moves to the best child according to the tree policy.
Returns:
- node(:obj:`TwoPlayersMCTSNode`): The leaf node object that has been reached by the tree search.
"""
current_node = self.root
while not current_node.is_terminal_node():
if not current_node.is_fully_expanded():
# choose a child node which has not been visited before
return current_node.expand()
else:
current_node = current_node.best_child()
return current_node
class MCTSBot:
"""
Overview:
        A bot which uses MCTS to make decisions, i.e., to choose an action to take.
"""
def __init__(self, env, bot_name, num_simulation=50):
"""
Overview:
This function initializes a new instance of the MCTSBot class.
Arguments:
- env (:obj:`BaseEnv`): The environment object for the game.
- bot_name (:obj:`str`): The name of the MCTS Bot.
- num_simulation (:obj:`int`): The number of simulations to perform during the MCTS.
"""
self.name = bot_name
self.num_simulation = num_simulation
self.simulator_env = env
def get_actions(self, state, player_index, best_action_type="UCB"):
"""
Overview:
            This function gets the action that the MCTS Bot will take.
The environment is reset to the given state.
Then, MCTS is performed with the specified number of simulations to find the best action.
Arguments:
- state (:obj:`list`): The current game state.
- player_index (:obj:`int`): The index of the current player.
- best_action_type (:obj:`str`): The type of best action selection to use. Either "UCB" or "most visited".
Returns:
- action (:obj:`int`): The best action that the MCTS Bot will take.
"""
        # Before making each decision, reset the simulator environment to the current state of the game.
self.simulator_env.reset(start_player_index=player_index, init_state=state)
root = TwoPlayersMCTSNode(self.simulator_env)
# Do the MCTS to find the best action to take.
mcts = MCTS(root)
mcts.best_action(self.num_simulation, best_action_type=best_action_type)
return root.best_action
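

# The block below is an illustrative usage sketch, not part of the original module.
# SimpleNimEnv is a hypothetical minimal two-player environment that implements the
# interface the nodes above rely on (legal_actions, simulate_action, get_done_reward,
# current_player, reset): two players alternately remove 1-3 stones from a pile,
# and whoever takes the last stone wins.
class SimpleNimEnv:

    def __init__(self, stones=10, current_player=1):
        self.stones = stones
        self.current_player = current_player
        self.players = [1, 2]
        self._winner = None

    @property
    def legal_actions(self):
        # A player may remove 1, 2, or 3 stones, but never more than remain.
        return list(range(1, min(3, self.stones) + 1))

    def simulate_action(self, action):
        # Return a new environment describing the state after taking `action`.
        next_env = SimpleNimEnv(self.stones - action, 3 - self.current_player)
        if next_env.stones == 0:
            # The player who removed the last stone wins.
            next_env._winner = self.current_player
        return next_env

    def get_done_reward(self):
        # Returns (done, reward); reward = ±1 when player 1 wins/loses,
        # reward = None when the game is not over. Nim has no ties.
        if self._winner is None:
            return False, None
        return True, 1 if self._winner == 1 else -1

    def reset(self, start_player_index=0, init_state=None):
        self.stones = init_state if init_state is not None else 10
        self.current_player = self.players[start_player_index]
        self._winner = None


if __name__ == "__main__":
    # Let the bot pick an opening move for player 1 from a pile of 10 stones.
    env = SimpleNimEnv()
    bot = MCTSBot(env, "nim_bot", num_simulation=200)
    action = bot.get_actions(state=10, player_index=0)
    print("MCTS bot removes {} stone(s)".format(action))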