import pytest
import torch
from easydict import EasyDict

from lzero.policy import inverse_scalar_transform


class MuZeroModelFake(torch.nn.Module):
    """
    Overview:
        Fake MuZero-style model, used only for testing SampledEfficientZeroMCTSCtree.
    Interfaces:
        __init__, initial_inference, recurrent_inference
    """

    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def initial_inference(self, observation):
        encoded_state = observation
        batch_size = encoded_state.shape[0]
        # Dummy outputs with the shapes a real EfficientZero-style model would produce:
        # a categorical value head with 601 bins (support_scale=300) and uniform,
        # non-zero policy logits so every sampled action gets the same prior.
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = [0. for _ in range(batch_size)]
        policy_logits = 0.1 * torch.ones(size=(batch_size, self.action_num))
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state
        }
        return EasyDict(output)

    def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
        batch_size = hidden_states.shape[0]
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = torch.zeros(size=(batch_size, 601))
        policy_logits = 0.1 * torch.ones(size=(batch_size, self.action_num))
        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state
        }
        return EasyDict(output)


@pytest.mark.unittest
def test_mcts():
    import numpy as np
    from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree
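
    # Minimal policy config for the sampled-ctree MCTS; the values are arbitrary
    # test settings, not tuned hyperparameters.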
    policy_config = EasyDict(
        dict(
            lstm_horizon_len=5,
            num_of_sampled_actions=6,
            num_simulations=100,
            batch_size=5,
            pb_c_base=1,
            pb_c_init=1,
            discount_factor=0.9,
            root_dirichlet_alpha=0.3,
            root_noise_weight=0.2,
            dirichlet_alpha=0.3,
            exploration_fraction=1,
            device='cpu',
            value_delta_max=0,
            model=dict(
                continuous_action_space=True,
                support_scale=300,
                action_space_size=2,
                categorical_distribution=True,
            ),
        )
    )
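
    # For a continuous action space, a sampled-policy head typically outputs
    # 2 * action_space_size values (e.g. a mean and a std per action dimension);
    # that interpretation is assumed here, hence action_num below.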
    batch_size = env_nums = policy_config.batch_size
    model = MuZeroModelFake(action_num=policy_config.model.action_space_size * 2)
    stack_obs = torch.zeros(
        size=(
            batch_size,
            policy_config.model.action_space_size * 2,
        ), dtype=torch.float
    )
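
    # Root inference: a single forward pass of the fake model yields the roots'
    # values, value prefixes, policy logits and latent states.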
    network_output = model.initial_inference(stack_obs.float())
    latent_state_roots = network_output['latent_state']
    reward_hidden_state = network_output['reward_hidden_state']
    pred_values_pool = network_output['value']
    value_prefix_pool = network_output['value_prefix']
    policy_logits_pool = network_output['policy_logits']

    # Post-process the network output: map the categorical value logits back to
    # scalars and convert everything to numpy/lists, as the C++ tree expects.
    pred_values_pool = inverse_scalar_transform(pred_values_pool,
                                                policy_config.model.support_scale).detach().cpu().numpy()
    latent_state_roots = latent_state_roots.detach().cpu().numpy()
    reward_hidden_state = (
        reward_hidden_state[0].detach().cpu().numpy(), reward_hidden_state[1].detach().cpu().numpy()
    )
    policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()
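
    # The ctree API still expects per-env legal-action lists; with a continuous
    # action space the test passes placeholder lists of -1.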
    legal_actions_list = [[-1 for _ in range(5)] for _ in range(env_nums)]
    roots = MCTSCtree.roots(
        env_nums,
        legal_actions_list,
        policy_config.model.action_space_size,
        policy_config.num_of_sampled_actions,
        continuous_action_space=True
    )
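
    # Dirichlet exploration noise: one noise vector per root, over the sampled actions.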
    noises = [
        np.random.dirichlet([policy_config.root_dirichlet_alpha] * policy_config.num_of_sampled_actions
                            ).astype(np.float32).tolist() for _ in range(env_nums)
    ]
    # np.random.randint(1, 2, 1) always returns 1, so every root gets to_play == 1.
    to_play_batch = [int(np.random.randint(1, 2, 1)) for _ in range(env_nums)]
    roots.prepare(policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch)
    MCTSCtree(policy_config).search(roots, model, latent_state_roots, reward_hidden_state, to_play_batch)

    # Each root should expose a visit-count distribution over its sampled actions.
    roots_distributions = roots.get_distributions()
    assert np.array(roots_distributions).shape == (batch_size, policy_config.num_of_sampled_actions)