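"""
Unit test for ``SampledEfficientZeroMCTSCtree`` driven by a fake MuZero-style
model. The fake model emits constant tensors of the correct shapes, so the
C++ tree search can be exercised without a trained network.
"""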
import numpy as np
import pytest
import torch
from easydict import EasyDict

from lzero.policy import inverse_scalar_transform


class MuZeroModelFake(torch.nn.Module):
    """
    Overview:
        Fake MuZero-style model used to test SampledEfficientZeroMCTSCtree. It
        returns constant tensors of the correct shapes, so the tree search can
        run without a trained network.
    Interfaces:
        __init__, initial_inference, recurrent_inference
    """

    def __init__(self, action_num):
        super().__init__()
        self.action_num = action_num

    def initial_inference(self, observation):
        encoded_state = observation
        batch_size = encoded_state.shape[0]

        # Categorical value head: 601 bins = 2 * support_scale + 1 for support_scale=300.
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = [0. for _ in range(batch_size)]
        # Constant (uniform) policy logits over all candidate actions.
        policy_logits = 0.1 * torch.ones(size=(batch_size, self.action_num))

        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        # LSTM (h, c) pair that EfficientZero uses to predict value prefixes.
        reward_hidden_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))

        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state
        }

        return EasyDict(output)

    def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
        batch_size = hidden_states.shape[0]
        latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
        reward_hidden_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
        # Unlike initial_inference, value_prefix here is also categorical (601 bins).
        value = torch.zeros(size=(batch_size, 601))
        value_prefix = torch.zeros(size=(batch_size, 601))
        policy_logits = 0.1 * torch.ones(size=(batch_size, self.action_num))

        output = {
            'value': value,
            'value_prefix': value_prefix,
            'policy_logits': policy_logits,
            'latent_state': latent_state,
            'reward_hidden_state': reward_hidden_state
        }

        return EasyDict(output)
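

@pytest.mark.unittest
def test_fake_model_contract():
    """
    A minimal sanity check of the fake model's output contract (an added sketch,
    not part of the original suite). It shows the shapes the MCTS test relies
    on, e.g. 601 value bins = 2 * support_scale + 1 for support_scale=300.
    """
    model = MuZeroModelFake(action_num=4)
    out = model.initial_inference(torch.zeros(2, 4))
    assert out.value.shape == (2, 601)
    assert out.latent_state.shape == (2, 12, 3, 3)
    # reward_hidden_state is an LSTM (h, c) pair.
    assert out.reward_hidden_state[0].shape == (1, 2, 16)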


@pytest.mark.unittest
def test_mcts():
    # Imported lazily inside the test so the compiled ctree extension is only
    # loaded when the test actually runs.
    from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree

    # Minimal policy config for the sampled MCTS; values are test stand-ins,
    # not tuned defaults.
    policy_config = EasyDict(
        dict(
            lstm_horizon_len=5,
            num_of_sampled_actions=6,
            num_simulations=100,
            batch_size=5,
            pb_c_base=1,
            pb_c_init=1,
            discount_factor=0.9,
            root_dirichlet_alpha=0.3,
            root_noise_weight=0.2,
            dirichlet_alpha=0.3,
            exploration_fraction=1,
            device='cpu',
            value_delta_max=0,
            model=dict(
                continuous_action_space=True,
                support_scale=300,
                action_space_size=2,
                categorical_distribution=True,
            ),
        )
    )

    batch_size = env_nums = policy_config.batch_size
    # For a continuous action space the policy head outputs (mu, sigma) per
    # action dimension, hence action_space_size * 2 logits.
    model = MuZeroModelFake(action_num=policy_config.model.action_space_size * 2)
    # Fake stacked observations; only the batch dimension matters to the fake model.
    stack_obs = torch.zeros(
        size=(
            batch_size,
            policy_config.model.action_space_size * 2,
        ), dtype=torch.float
    )

    network_output = model.initial_inference(stack_obs.float())

    latent_state_roots = network_output['latent_state']
    reward_hidden_state_roots = network_output['reward_hidden_state']
    pred_values_pool = network_output['value']
    value_prefix_pool = network_output['value_prefix']
    policy_logits_pool = network_output['policy_logits']

    # Post-process the network outputs: convert the categorical value back to a
    # scalar and move all tensors to numpy for the C++ tree search.
    pred_values_pool = inverse_scalar_transform(pred_values_pool,
                                                policy_config.model.support_scale).detach().cpu().numpy()
    latent_state_roots = latent_state_roots.detach().cpu().numpy()
    reward_hidden_state_roots = (
        reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy()
    )
    policy_logits_pool = policy_logits_pool.detach().cpu().numpy().tolist()

    # Placeholder legal actions: -1 marks a continuous action space, where no
    # discrete legal-action mask applies.
    legal_actions_list = [[-1 for _ in range(5)] for _ in range(env_nums)]
    roots = MCTSCtree.roots(
        env_nums,
        legal_actions_list,
        policy_config.model.action_space_size,
        policy_config.num_of_sampled_actions,
        continuous_action_space=True
    )

    # Dirichlet exploration noise for each root, one vector per environment.
    noises = [
        np.random.dirichlet([policy_config.root_dirichlet_alpha] * policy_config.num_of_sampled_actions
                            ).astype(np.float32).tolist() for _ in range(env_nums)
    ]
    # Note: np.random.randint(1, 2, 1) always yields 1, since the upper bound is exclusive.
    to_play_batch = [int(np.random.randint(1, 2, 1)) for _ in range(env_nums)]
    roots.prepare(policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch)

    MCTSCtree(policy_config).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play_batch)
    roots_distributions = roots.get_distributions()
    # Each root should yield a visit-count distribution over its sampled actions.
    assert np.array(roots_distributions).shape == (batch_size, policy_config.num_of_sampled_actions)
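

if __name__ == "__main__":
    # Convenience entry point for running this file directly (an added sketch;
    # pytest remains the intended runner).
    test_mcts()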