# NOTE: the six lines below are residue from the hosting page, not module code.
# zjowowen's picture
# init space
# 079c32c
# raw
# history blame
# 6.88 kB
import os
from dataclasses import dataclass
from typing import Any
import numpy as np
from graphviz import Digraph
def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int,
                                     reshape=False):
    """
    Overview:
        Draw ``num_actions`` independent groups of uniformly random discrete actions.
    Arguments:
        - num_actions (:obj:`int`): How many groups to draw, i.e. the length of the returned list.
        - action_space_size (:obj:`int`): Exclusive upper bound of the sampled integers (the lower bound is 0).
        - num_of_sampled_actions (:obj:`int`): How many actions each group contains.
        - reshape (:obj:`bool`): If True and each group holds more than one action, every group is returned \
            as a column of shape ``(num_of_sampled_actions, 1)`` instead of a flat array.
    Returns:
        - actions (:obj:`list`): One entry per group. Each entry is a scalar when \
            ``num_of_sampled_actions == 1``; otherwise it is a numpy array.
    """
    draws = []
    for _ in range(num_actions):
        group = np.random.randint(low=0, high=action_space_size, size=num_of_sampled_actions)
        draws.append(group.ravel())

    # A single sampled action per group collapses to a plain list of numbers.
    if num_of_sampled_actions == 1:
        return [group[0] for group in draws]

    # Optional column layout, only meaningful when there are several actions per group.
    if reshape and num_of_sampled_actions > 1:
        return [group.reshape(num_of_sampled_actions, 1) for group in draws]

    return draws
@dataclass
class BufferedData:
    """
    Overview:
        Plain record that bundles a payload with a string index and a metadata dict.
    """
    # Payload; any object is accepted.
    data: Any
    # String index associated with the payload.
    index: str
    # Metadata dictionary associated with the payload.
    meta: dict
def get_augmented_data(board_size, play_data):
    """
    Overview:
        Enlarge a set of board samples with rotated and mirrored copies. Every input sample yields eight
        output samples: for each of 1, 2, 3 and 4 counterclockwise quarter turns, the rotated sample and
        its left-right mirror image.
    Arguments:
        - board_size (:obj:`int`): Side length used to reshape each flat ``mcts_prob`` into a square board.
        - play_data: Iterable of dicts with the keys ``'state'`` (a sequence of 2-D planes), \
            ``'mcts_prob'`` (an array that can be reshaped to ``(board_size, board_size)``) and ``'winner'``.
    Returns:
        - extend_data (:obj:`list`): Dicts with the same three keys. ``'mcts_prob'`` is flattened again and \
            ``'winner'`` is copied through unchanged.
    """

    def _record(planes, prob_board, outcome):
        # The probability board is kept vertically flipped while transforming; flip it back on output.
        return {
            'state': planes,
            'mcts_prob': np.flipud(prob_board).flatten(),
            'winner': outcome
        }

    augmented = []
    for sample in play_data:
        planes = sample['state']
        outcome = sample['winner']
        # Vertically flipped square view of the move probabilities; identical for every rotation below.
        flipped_prob = np.flipud(sample['mcts_prob'].reshape(board_size, board_size))

        for quarter_turns in range(1, 5):
            # Rotate counterclockwise.
            turned_planes = np.array([np.rot90(plane, quarter_turns) for plane in planes])
            turned_prob = np.rot90(flipped_prob, quarter_turns)
            augmented.append(_record(turned_planes, turned_prob, outcome))

            # Mirror the rotated sample horizontally.
            mirrored_planes = np.array([np.fliplr(plane) for plane in turned_planes])
            mirrored_prob = np.fliplr(turned_prob)
            augmented.append(_record(mirrored_planes, mirrored_prob, outcome))
    return augmented
def prepare_observation(observation_list, model_type='conv'):
    """
    Overview:
        Convert a batch of stacked observations into the array layout expected by the model.
        if model_type='conv':
            [B, S, W, H, C] -> [B, S x C, W, H]
            where B is batch size, S is stack num, W is width, H is height, and C is the number of channels.
            A 3-dimensional input [B, S, O] is first expanded to [B, S, O, 1].
        if model_type='mlp':
            [B, S, O] -> [B, S x O]
            where B is batch size, S is stack num, O is obs shape.
    Arguments:
        - observation_list (:obj:`List`): list of observations.
        - model_type (:obj:`str`): type of the model, either 'conv' or 'mlp'. (default is 'conv')
    Returns:
        - observation_array (:obj:`np.ndarray`): the reshaped observations.
    """
    assert model_type in ['conv', 'mlp']
    obs = np.array(observation_list)

    if model_type == 'mlp':
        # Vector observations: keep the batch axis and flatten everything else.
        # [B, S, O] -> [B, S*O]
        return obs.reshape(obs.shape[0], -1)

    # model_type == 'conv' from here on.
    if obs.ndim == 3:
        # Vector obs input, e.g. classical control and box2d environments.
        # Add a trailing axis of size 1 to be compatible with the LightZero model/policy.
        # [B, S, O] -> [B, S, O, 1]
        batch, stack, obs_dim = obs.shape
        obs = obs.reshape(batch, stack, obs_dim, 1)
    elif obs.ndim == 5:
        # Image obs input, e.g. atari environments: move channels in front of the spatial axes.
        # [B, S, W, H, C] -> [B, S, C, W, H]
        obs = np.transpose(obs, (0, 1, 4, 2, 3))

    # Merge every axis between the batch axis and the last two axes.
    # [B, S, C, W, H] -> [B, S*C, W, H]
    dims = obs.shape
    return obs.reshape((dims[0], -1, dims[-2], dims[-1]))
def obtain_tree_topology(root, to_play=-1):
    """
    Overview:
        Walk the expanded part of a search tree depth-first with an explicit stack and collect its
        nodes and edges. While walking, every expanded child gets its ``parent_simulation_index``
        attribute set to the parent's ``simulation_index``.
    Arguments:
        - root: Root node. Each node must provide ``simulation_index``, ``visit_count``, ``prior``, \
            ``value``, ``legal_actions`` and ``get_child(action)``; children must provide ``expanded``.
        - to_play: Not used by this function.
    Returns:
        - edge_topology_list (:obj:`list`): Dicts with ``'parent_id'`` and ``'child_id'``, one per edge \
            leading to an expanded child.
        - node_id_list (:obj:`list`): ``simulation_index`` of every visited node, in visiting order.
        - node_topology_list (:obj:`list`): Dicts with ``'node_id'``, ``'visit_count'``, \
            ``'policy_prior'`` and ``'value'``, in visiting order.
    """
    edges = []
    node_records = []
    node_ids = []

    pending = [root]
    while pending:
        current = pending.pop()
        current_id = current.simulation_index

        node_records.append(
            {
                'node_id': current_id,
                'visit_count': current.visit_count,
                'policy_prior': current.prior,
                'value': current.value,
            }
        )
        node_ids.append(current_id)

        for action in current.legal_actions:
            child = current.get_child(action)
            if not child.expanded:
                # Unexpanded children are left out of the topology entirely.
                continue
            child.parent_simulation_index = current_id
            edges.append({'parent_id': current_id, 'child_id': child.simulation_index})
            pending.append(child)

    return edges, node_ids, node_records
def plot_simulation_graph(env_root, current_step, graph_directory=None):
    """
    Overview:
        Render the expanded part of a search tree as a PNG image with graphviz. Each node is labelled
        with its id, visit count, policy prior and value; each edge is labelled ``<parent_id>-<child_id>``.
    Arguments:
        - env_root: Root node of the tree, passed to ``obtain_tree_topology``.
        - current_step: Step identifier; it is converted with ``str`` and embedded in the output file name.
        - graph_directory: Directory the graph is written to. Defaults to ``'./data_visualize/'``. \
            The directory is created if it does not exist, and a trailing path separator is optional.
    """
    edge_topology_list, _, node_topology_list = obtain_tree_topology(env_root)

    dot = Digraph(comment='this is direction')
    for node_topology in node_topology_list:
        node_name = str(node_topology['node_id'])
        label = (
            f"node_id: {node_topology['node_id']}, \n "
            f"visit_count: {node_topology['visit_count']}, \n "
            f"policy_prior: {round(node_topology['policy_prior'], 4)}, \n "
            f"value: {round(node_topology['value'], 4)}"
        )
        dot.node(node_name, label=label)
    for edge_topology in edge_topology_list:
        parent_id = str(edge_topology['parent_id'])
        child_id = str(edge_topology['child_id'])
        dot.edge(parent_id, child_id, label=parent_id + '-' + child_id)

    if graph_directory is None:
        graph_directory = './data_visualize/'
    # exist_ok avoids the race between checking for the directory and creating it.
    os.makedirs(graph_directory, exist_ok=True)
    # Join instead of concatenating so a directory given without a trailing separator
    # still receives the file inside it rather than next to it.
    graph_path = os.path.join(graph_directory, 'simulation_visualize_' + str(current_step) + 'step.gv')
    dot.format = 'png'
    dot.render(graph_path, view=False)