|
import os |
|
from dataclasses import dataclass |
|
from typing import Any |
|
|
|
import numpy as np |
|
from graphviz import Digraph |
|
|
|
|
|
def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int,
                                     reshape=False):
    """
    Overview:
        Draw ``num_actions`` independent groups of uniformly random discrete actions.
    Arguments:
        - num_actions (:obj:`int`): How many groups to draw, i.e. the length of the returned list.
        - action_space_size (:obj:`int`): Exclusive upper bound of the sampled integers (lower bound is 0).
        - num_of_sampled_actions (:obj:`int`): How many actions are drawn per group.
        - reshape (:obj:`bool`): If True and more than one action is drawn per group, each group is \
            returned as a column of shape ``(num_of_sampled_actions, 1)`` instead of a flat array.
    Returns:
        A list with ``num_actions`` entries. Each entry is a scalar when ``num_of_sampled_actions == 1`` \
        (``reshape`` has no effect in that case), otherwise a numpy array.
    """
    sampled = []
    for _ in range(num_actions):
        draw = np.random.randint(0, action_space_size, num_of_sampled_actions)
        sampled.append(draw.reshape(-1))

    # A single sampled action per group is unwrapped to a bare scalar.
    if num_of_sampled_actions == 1:
        return [draw[0] for draw in sampled]

    # Optionally turn every flat group into a column vector.
    if reshape and num_of_sampled_actions > 1:
        return [draw.reshape(num_of_sampled_actions, 1) for draw in sampled]

    return sampled
|
|
|
|
|
@dataclass
class BufferedData:
    """
    Overview:
        Plain record bundling a payload with an index string and a metadata dict.
    """
    # Payload object; no type restriction is imposed here.
    data: Any
    # String identifier attached to this record.
    index: str
    # Free-form metadata dictionary attached to this record.
    meta: dict
|
|
|
|
|
def get_augmented_data(board_size, play_data):
    """
    Overview:
        Augment a data set with the rotations and horizontal mirrors of every sample.
        Each input sample yields eight output samples: four quarter-turn rotations (the fourth
        being a full turn, i.e. the unrotated sample) and the left-right mirror of each of them.
    Arguments:
        - board_size: Side length used to view the flat ``mcts_prob`` as a ``board_size x board_size`` grid.
        - play_data: Iterable of dicts with the keys ``'state'`` (iterable of 2-D planes), \
            ``'mcts_prob'`` (array with ``board_size * board_size`` entries) and ``'winner'``.
    Returns:
        A list of dicts with the same three keys; ``'mcts_prob'`` is flattened again and ``'winner'``
        is copied through untouched.
    """

    def _make_entry(planes, prob_grid, outcome):
        # The probability grid is kept vertically flipped while it is being transformed, so it is
        # flipped back right before flattening.
        # NOTE(review): presumably the flat policy indexes rows in the opposite vertical order to
        # the state planes - confirm against the environment's action encoding.
        return {
            'state': planes,
            'mcts_prob': np.flipud(prob_grid).flatten(),
            'winner': outcome
        }

    augmented = []
    for sample in play_data:
        planes, probs, outcome = sample['state'], sample['mcts_prob'], sample['winner']
        for quarter_turns in range(1, 5):
            # Counter-clockwise rotation of every state plane and of the probability grid.
            turned_planes = np.array([np.rot90(plane, quarter_turns) for plane in planes])
            turned_probs = np.rot90(np.flipud(probs.reshape(board_size, board_size)), quarter_turns)
            augmented.append(_make_entry(turned_planes, turned_probs, outcome))

            # Left-right mirror of the rotated sample.
            mirrored_planes = np.array([np.fliplr(plane) for plane in turned_planes])
            mirrored_probs = np.fliplr(turned_probs)
            augmented.append(_make_entry(mirrored_planes, mirrored_probs, outcome))
    return augmented
|
|
|
|
|
def prepare_observation(observation_list, model_type='conv'):
    """
    Overview:
        Convert a batch of stacked observations into the array layout expected by the model.
        if model_type='conv':
            5-dimensional input: [B, S, W, H, C] -> [B, S x C, W, H]
            3-dimensional input: [B, S, O] -> [B, S, O, 1]
            input of any other rank is returned as-is
            where B is batch size, S is stack num, W is width, H is height, C is the number of
            channels and O is obs shape.
        if model_type='mlp':
            [B, S, O] -> [B, S x O] (everything after the batch axis is flattened)
    Arguments:
        - observation_list (:obj:`List`): list of observations; it is converted with ``np.array``.
        - model_type (:obj:`str`): type of the model, either 'conv' or 'mlp'. (default is 'conv')
    Returns:
        The rearranged observations as a numpy array.
    """
    assert model_type in ['conv', 'mlp']
    stacked = np.array(observation_list)

    if model_type == 'conv':
        rank = stacked.ndim
        if rank == 3:
            # Vector observations: append a trailing singleton axis, [B, S, O] -> [B, S, O, 1].
            stacked = stacked[..., np.newaxis]
        elif rank == 5:
            # Image observations: bring channels in front of the spatial axes,
            # [B, S, W, H, C] -> [B, S, C, W, H] ...
            channels_first = stacked.transpose(0, 1, 4, 2, 3)
            batch_size, width, height = (
                channels_first.shape[0], channels_first.shape[-2], channels_first.shape[-1]
            )
            # ... then merge the stack and channel axes, [B, S, C, W, H] -> [B, S x C, W, H].
            stacked = channels_first.reshape((batch_size, -1, width, height))
    elif model_type == 'mlp':
        # Keep the batch axis and flatten the rest, [B, S, O] -> [B, S x O].
        stacked = stacked.reshape(stacked.shape[0], -1)

    return stacked
|
|
|
|
|
def obtain_tree_topology(root, to_play=-1):
    """
    Overview:
        Walk the expanded part of a search tree depth-first (explicit stack, no recursion) and
        collect its nodes and edges.
        Side effect: every expanded child visited gets its ``parent_simulation_index`` attribute
        set to the ``simulation_index`` of its parent.
    Arguments:
        - root: Root node. Nodes must expose ``simulation_index``, ``visit_count``, ``prior``, \
            ``value``, ``legal_actions``, ``expanded`` and ``get_child(action)``.
        - to_play: Not used by this function.
    Returns:
        A tuple ``(edge_topology_list, node_id_list, node_topology_list)``:
        - edge_topology_list: dicts with the keys ``'parent_id'`` and ``'child_id'``.
        - node_id_list: ``simulation_index`` of every visited node, in visiting order.
        - node_topology_list: dicts with the keys ``'node_id'``, ``'visit_count'``, \
            ``'policy_prior'`` and ``'value'``, in visiting order.
    """
    edges, visited_ids, node_records = [], [], []
    pending = [root]

    while pending:
        current = pending.pop()
        current_id = current.simulation_index

        node_records.append(
            {
                'node_id': current_id,
                'visit_count': current.visit_count,
                'policy_prior': current.prior,
                'value': current.value,
            }
        )
        visited_ids.append(current_id)

        for action in current.legal_actions:
            child = current.get_child(action)
            # Children that were never expanded are not part of the reported topology.
            if not child.expanded:
                continue
            child.parent_simulation_index = current_id
            edges.append({'parent_id': current_id, 'child_id': child.simulation_index})
            pending.append(child)

    return edges, visited_ids, node_records
|
|
|
|
|
def plot_simulation_graph(env_root, current_step, graph_directory=None):
    """
    Overview:
        Render the expanded search tree below ``env_root`` with graphviz and save it as a PNG.
        The tree is read through ``obtain_tree_topology``; each node is labelled with its id,
        visit count, policy prior and value, and each edge with ``<parent_id>-<child_id>``.
    Arguments:
        - env_root: Root node of the tree, passed as-is to ``obtain_tree_topology``.
        - current_step: Step identifier embedded in the output file name.
        - graph_directory: Directory the files are written to; created if missing. \
            Defaults to ``'./data_visualize/'``. A trailing path separator is optional.
    Returns:
        None. The graph source is rendered to ``simulation_visualize_<current_step>step.gv``
        inside ``graph_directory`` with the output format set to ``'png'``.
    """
    edge_topology_list, _, node_topology_list = obtain_tree_topology(env_root)

    dot = Digraph(comment='this is direction')
    for node_topology in node_topology_list:
        node_name = str(node_topology['node_id'])
        label = (
            f"node_id: {node_topology['node_id']}, \n "
            f"visit_count: {node_topology['visit_count']}, \n "
            f"policy_prior: {round(node_topology['policy_prior'], 4)}, \n "
            f"value: {round(node_topology['value'], 4)}"
        )
        dot.node(node_name, label=label)
    for edge_topology in edge_topology_list:
        parent_id = str(edge_topology['parent_id'])
        child_id = str(edge_topology['child_id'])
        label = parent_id + '-' + child_id
        dot.edge(parent_id, child_id, label=label)

    if graph_directory is None:
        graph_directory = './data_visualize/'
    # exist_ok avoids the check-then-create race of a separate os.path.exists() test.
    os.makedirs(graph_directory, exist_ok=True)
    # Join instead of concatenating so that a directory given without a trailing separator
    # still receives the file inside it rather than a sibling named '<dir>simulation_...'.
    graph_path = os.path.join(graph_directory, 'simulation_visualize_' + str(current_step) + 'step.gv')
    dot.format = 'png'
    dot.render(graph_path, view=False)
|
|