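"""
Overview:
    Benchmark the EfficientZero MCTS tree implementations in LZero: the C++ tree (ctree)
    versus the pure-Python tree (ptree). A fake model with zero-valued outputs is used so
    that only the tree-search code is timed. For each (action_space_size, batch_size)
    pair, build/prepare/search/total times are measured over a range of simulation counts
    and saved as comparison plots.
"""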
import torch
from easydict import EasyDict
from lzero.policy import inverse_scalar_transform, select_action
import numpy as np
import random

from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree
from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
import time


class MuZeroModelFake(torch.nn.Module):
    """
    Overview:
        A fake MuZero-style model used only to benchmark EfficientZeroMCTSPtree and
        EfficientZeroMCTSCtree: it returns zero tensors of the expected shapes instead
        of running real networks.
    Interfaces:
        __init__, initial_inference, recurrent_inference
    """

    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def initial_inference(self, observation):
        encoded_state = observation
        batch_size = encoded_state.shape[0]

        # Categorical value logits over 601 support bins (2 * support_scale + 1 with support_scale=300).
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = [0. for _ in range(batch_size)]
        policy_logits = torch.zeros(size=(batch_size, self.action_num))
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        # (h, c) hidden state of the value-prefix LSTM.
        reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))

        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state_state
        }

        return EasyDict(output)

    def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
        batch_size = hidden_states.shape[0]
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = torch.zeros(size=(batch_size, 601))
        policy_logits = torch.zeros(size=(batch_size, self.action_num))

        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state_state
        }

        return EasyDict(output)
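# A minimal sketch of how the benchmark below queries the fake model (illustrative
# numbers only; the shapes simply mirror the constants hard-coded above):
#
#     model = MuZeroModelFake(action_num=4)
#     out = model.initial_inference(torch.zeros(8, 4))
#     out.value.shape                   # torch.Size([8, 601])
#     out.latent_state.shape            # torch.Size([8, 12, 3, 3])
#     out.reward_hidden_state[0].shape  # torch.Size([1, 8, 16])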
def ptree_func(policy_config, num_simulations):
    """
    Overview:
        Run MCTS with the Python tree implementation (ptree) and record the time spent
        at each stage of the search.
    Arguments:
        - policy_config: policy config of the game.
        - num_simulations: list of simulation counts to benchmark.
    Returns:
        - build_time: time spent building the roots.
        - prepare_time: time spent preparing the roots (exploration noise, priors, value prefixes).
        - search_time: time spent in the tree search itself.
        - total_time: total time per simulation count.
    """
    batch_size = env_nums = policy_config.batch_size
    action_space_size = policy_config.action_space_size

    build_time = []
    prepare_time = []
    search_time = []
    total_time = []

    for n_s in num_simulations:
        t0 = time.time()
        model = MuZeroModelFake(action_num=action_space_size)
        stack_obs = torch.zeros(
            size=(
                batch_size,
                n_s,
            ), dtype=torch.float
        )

        policy_config.num_simulations = n_s
        network_output = model.initial_inference(stack_obs.float())

        latent_state_roots = network_output['latent_state']
        reward_hidden_state_state = network_output['reward_hidden_state']
        pred_values_pool = network_output['value']
        value_prefix_pool = network_output['value_prefix']
        policy_logits_pool = network_output['policy_logits']

        # Convert network outputs to the numpy structures expected by the tree.
        pred_values_pool = inverse_scalar_transform(pred_values_pool,
                                                    policy_config.model.support_scale).detach().cpu().numpy()
        latent_state_roots = latent_state_roots.detach().cpu().numpy()
        reward_hidden_state_state = (
            reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy()
        )
        policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

        # Randomly mask out actions to emulate environments with varying legal-action sets.
        action_mask = [[random.randint(0, 1) for _ in range(action_space_size)] for _ in range(env_nums)]
        assert len(action_mask) == batch_size
        assert len(action_mask[0]) == action_space_size

        action_num = [int(np.array(action_mask[i]).sum()) for i in range(env_nums)]
        legal_actions_list = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(env_nums)]
        to_play = [np.random.randint(1, 3) for i in range(env_nums)]
        assert len(to_play) == batch_size

        for i in range(env_nums):
            assert action_num[i] == len(legal_actions_list[i])

        # Stage 1: build the roots.
        t1 = time.time()
        roots = MCTSPtree.roots(env_nums, legal_actions_list)
        build_time.append(time.time() - t1)

        noises = [
            np.random.dirichlet([policy_config.root_dirichlet_alpha] * int(sum(action_mask[j]))
                                ).astype(np.float32).tolist() for j in range(env_nums)
        ]
        # Stage 2: prepare the roots with exploration noise, value prefixes and policy priors.
        t1 = time.time()
        roots.prepare(policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
        prepare_time.append(time.time() - t1)

        # Stage 3: run the search itself.
        t1 = time.time()
        MCTSPtree(policy_config).search(roots, model, latent_state_roots, reward_hidden_state_state, to_play)
        search_time.append(time.time() - t1)
        total_time.append(time.time() - t0)

        roots_distributions = roots.get_distributions()
        roots_values = roots.get_values()
        assert len(roots_values) == env_nums
        for i in range(env_nums):
            assert len(roots_distributions[i]) == action_num[i]

        # Sanity check: sampled actions must be legal under the action mask.
        temperature = [1 for _ in range(env_nums)]
        for i in range(env_nums):
            distributions = roots_distributions[i]
            action_index, visit_count_distribution_entropy = select_action(
                distributions, temperature=temperature[i], deterministic=False
            )
            action = np.where(np.array(action_mask[i]) == 1.0)[0][action_index]
            assert action_index < action_num[i]
            assert action == legal_actions_list[i][action_index]
            print('\n action_index={}, legal_action={}, action={}'.format(action_index, legal_actions_list[i], action))

    return build_time, prepare_time, search_time, total_time
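# Hypothetical standalone call (the full policy_config lives in __main__ below;
# the simulation counts here are illustrative):
#
#     build, prepare, search, total = ptree_func(policy_config, num_simulations=[25, 50])
#     print(list(zip([25, 50], search)))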
def ctree_func(policy_config, num_simulations):
    """
    Overview:
        Run MCTS with the C++ tree implementation (ctree) and record the time spent
        at each stage of the search. The stages mirror ptree_func.
    Arguments:
        - policy_config: policy config of the game.
        - num_simulations: list of simulation counts to benchmark.
    Returns:
        - build_time: time spent building the roots.
        - prepare_time: time spent preparing the roots (exploration noise, priors, value prefixes).
        - search_time: time spent in the tree search itself.
        - total_time: total time per simulation count.
    """
    batch_size = env_nums = policy_config.batch_size
    action_space_size = policy_config.action_space_size

    build_time = []
    prepare_time = []
    search_time = []
    total_time = []

    for n_s in num_simulations:
        t0 = time.time()
        model = MuZeroModelFake(action_num=action_space_size)
        stack_obs = torch.zeros(
            size=(
                batch_size,
                n_s,
            ), dtype=torch.float
        )
        policy_config.num_simulations = n_s

        network_output = model.initial_inference(stack_obs.float())

        latent_state_roots = network_output['latent_state']
        reward_hidden_state_state = network_output['reward_hidden_state']
        pred_values_pool = network_output['value']
        value_prefix_pool = network_output['value_prefix']
        policy_logits_pool = network_output['policy_logits']

        # Convert network outputs to the numpy structures expected by the tree.
        pred_values_pool = inverse_scalar_transform(pred_values_pool,
                                                    policy_config.model.support_scale).detach().cpu().numpy()
        latent_state_roots = latent_state_roots.detach().cpu().numpy()
        reward_hidden_state_state = (
            reward_hidden_state_state[0].detach().cpu().numpy(), reward_hidden_state_state[1].detach().cpu().numpy()
        )
        policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

        # Randomly mask out actions to emulate environments with varying legal-action sets.
        action_mask = [[random.randint(0, 1) for _ in range(action_space_size)] for _ in range(env_nums)]
        assert len(action_mask) == batch_size
        assert len(action_mask[0]) == action_space_size

        action_num = [int(np.array(action_mask[i]).sum()) for i in range(env_nums)]
        legal_actions_list = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(env_nums)]
        to_play = [np.random.randint(1, 3) for i in range(env_nums)]
        assert len(to_play) == batch_size

        for i in range(env_nums):
            assert action_num[i] == len(legal_actions_list[i])

        # Stage 1: build the roots.
        t1 = time.time()
        roots = MCTSCtree.roots(env_nums, legal_actions_list)
        build_time.append(time.time() - t1)

        noises = [
            np.random.dirichlet([policy_config.root_dirichlet_alpha] * int(sum(action_mask[j]))
                                ).astype(np.float32).tolist() for j in range(env_nums)
        ]
        # Stage 2: prepare the roots with exploration noise, value prefixes and policy priors.
        t1 = time.time()
        roots.prepare(policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
        prepare_time.append(time.time() - t1)

        # Stage 3: run the search itself.
        t1 = time.time()
        MCTSCtree(policy_config).search(roots, model, latent_state_roots, reward_hidden_state_state, to_play)
        search_time.append(time.time() - t1)
        total_time.append(time.time() - t0)

        roots_distributions = roots.get_distributions()
        roots_values = roots.get_values()
        assert len(roots_values) == env_nums
        for i in range(env_nums):
            assert len(roots_distributions[i]) == action_num[i]

        # Sanity check: sampled actions must be legal under the action mask.
        temperature = [1 for _ in range(env_nums)]
        for i in range(env_nums):
            distributions = roots_distributions[i]
            action_index, visit_count_distribution_entropy = select_action(
                distributions, temperature=temperature[i], deterministic=False
            )
            action = np.where(np.array(action_mask[i]) == 1.0)[0][action_index]
            assert action_index < action_num[i]
            assert action == legal_actions_list[i][action_index]
            print('\n action_index={}, legal_action={}, action={}'.format(action_index, legal_actions_list[i], action))

    return build_time, prepare_time, search_time, total_time
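# Note (not part of the original measurement code): time.time() can have coarse resolution
# on some platforms, so for the very short build/prepare stages a higher-resolution clock
# such as time.perf_counter() would be the safer choice, e.g.:
#
#     t1 = time.perf_counter()
#     roots = MCTSCtree.roots(env_nums, legal_actions_list)
#     build_time.append(time.perf_counter() - t1)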
def plot(ctree_time, ptree_time, iters, label):
    """
    Overview:
        Plot the mean +/- std band of ctree vs. ptree timings over the swept simulation
        counts and save the figure as '<label>-time.png'.
    """
    import matplotlib.pyplot as plt
    from matplotlib import pyplot
    plt.style.use('seaborn-whitegrid')
    palette = pyplot.get_cmap('Set1')
    font1 = {
        'family': 'Times New Roman',
        'weight': 'normal',
        'size': 18,
    }

    plt.figure(figsize=(20, 10))

    ctree_time = np.array(ctree_time)
    color = palette(0)
    avg = np.mean(ctree_time, axis=0)
    std = np.std(ctree_time, axis=0)
    r1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))
    r2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))
    plt.plot(iters, avg, color=color, label="ctree", linewidth=3.0)
    plt.fill_between(iters, r1, r2, color=color, alpha=0.2)

    ptree_time = np.array(ptree_time)
    color = palette(1)
    avg = np.mean(ptree_time, axis=0)
    std = np.std(ptree_time, axis=0)
    r1 = list(map(lambda x: x[0] - x[1], zip(avg, std)))
    r2 = list(map(lambda x: x[0] + x[1], zip(avg, std)))
    plt.plot(iters, avg, color=color, label="ptree", linewidth=3.0)
    plt.fill_between(iters, r1, r2, color=color, alpha=0.2)

    plt.legend(loc='lower right', prop=font1)
    plt.title('{}'.format(label))
    plt.xlabel('simulations', fontsize=22)
    plt.ylabel('time', fontsize=22)
    plt.savefig('{}-time.png'.format(label))
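# Note: the 'seaborn-whitegrid' style used above was renamed in matplotlib >= 3.6; a
# defensive variant of that style call (an assumption about the installed matplotlib,
# not part of the original script) would be:
#
#     if 'seaborn-whitegrid' in plt.style.available:
#         plt.style.use('seaborn-whitegrid')
#     else:
#         plt.style.use('seaborn-v0_8-whitegrid')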
if __name__ == "__main__":
    policy_config = EasyDict(
        dict(
            lstm_horizon_len=5,
            model=dict(
                support_scale=300,
                categorical_distribution=True,
            ),
            action_space_size=100,
            num_simulations=100,
            batch_size=512,
            pb_c_base=1,
            pb_c_init=1,
            discount_factor=0.9,
            root_dirichlet_alpha=0.3,
            root_noise_weight=0.2,
            dirichlet_alpha=0.3,
            exploration_fraction=1,
            device='cpu',
            value_delta_max=0.01,
        )
    )

    # Sweep settings: action-space sizes, batch sizes and simulation counts to benchmark.
    ACTION_SPACE_SIZE = [16, 50]
    BATCH_SIZE = [8, 64, 512]
    NUM_SIMULATIONS = [i for i in range(20, 200, 20)]
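    # The nested loops below cover 2 action-space sizes x 3 batch sizes; for each pair,
    # every simulation count in NUM_SIMULATIONS (20, 40, ..., 180) is timed 3 times per
    # implementation, so each plotted curve averages 3 runs per point.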
    for action_space_size in ACTION_SPACE_SIZE:
        for batch_size in BATCH_SIZE:
            policy_config.batch_size = batch_size
            policy_config.action_space_size = action_space_size
            ctree_build_time = []
            ctree_prepare_time = []
            ctree_search_time = []
            ptree_build_time = []
            ptree_prepare_time = []
            ptree_search_time = []
            ctree_total_time = []
            ptree_total_time = []
            num_simulations = NUM_SIMULATIONS
            for i in range(3):
                build_time, prepare_time, search_time, total_time = ctree_func(
                    policy_config, num_simulations=num_simulations
                )
                ctree_build_time.append(build_time)
                ctree_prepare_time.append(prepare_time)
                ctree_search_time.append(search_time)
                ctree_total_time.append(total_time)

            for i in range(3):
                build_time, prepare_time, search_time, total_time = ptree_func(
                    policy_config, num_simulations=num_simulations
                )
                ptree_build_time.append(build_time)
                ptree_prepare_time.append(prepare_time)
                ptree_search_time.append(search_time)
                ptree_total_time.append(total_time)

            label = 'action_space_size_{}_batch_size_{}'.format(action_space_size, batch_size)
            plot(ctree_build_time, ptree_build_time, iters=num_simulations, label=label + '_build_time')
            plot(ctree_prepare_time, ptree_prepare_time, iters=num_simulations, label=label + '_prepare_time')
            plot(ctree_search_time, ptree_search_time, iters=num_simulations, label=label + '_search_time')
            plot(ctree_total_time, ptree_total_time, iters=num_simulations, label=label + '_total_time')