File size: 28,564 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
import copy
from collections import namedtuple
from typing import List, Dict, Tuple

import numpy as np
import torch.distributions
import torch.nn.functional as F
import torch.optim as optim
from ding.policy.base_policy import Policy
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
from easydict import EasyDict

from lzero.policy import configure_optimizers
from lzero.policy.utils import pad_and_get_lengths, compute_entropy


@POLICY_REGISTRY.register('sampled_alphazero')
class SampledAlphaZeroPolicy(Policy):
    """
    Overview:
        The policy class for Sampled AlphaZero.
    """

    # The default_config for AlphaZero policy.
    config = dict(
        # (str) The type of policy, as the key of the policy registry.
        type='alphazero',
        # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled AlphaZero)
        # this variable is used in ``collector``.
        sampled_algo=False,
        normalize_prob_of_sampled_actions=False,
        policy_loss_type='cross_entropy',  # options={'cross_entropy', 'KL'}
        # (bool) Whether to use torch.compile method to speed up our model, which required torch>=2.0.
        torch_compile=False,
        # (bool) Whether to use TF32 for our model.
        tensor_float_32=False,
        model=dict(
            # (tuple) The stacked obs shape.
            observation_shape=(3, 6, 6),
            # (int) The number of res blocks in AlphaZero model.
            num_res_blocks=1,
            # (int) The number of channels of hidden states in AlphaZero model.
            num_channels=32,
        ),
        # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation.
        mcts_ctree=True,
        # (bool) Whether to use cuda for network.
        cuda=False,
        # (int) How many updates(iterations) to train after collector's one collection.
        # Bigger "update_per_collect" means bigger off-policy.
        # collect data -> update policy-> collect data -> ...
        # For different env, we have different episode_length,
        # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor.
        # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically.
        update_per_collect=None,
        # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None.
        model_update_ratio=0.1,
        # (int) Minibatch size for one gradient descent.
        batch_size=256,
        # (str) Optimizer for training policy network. ['SGD', 'Adam', 'AdamW']
        optim_type='SGD',
        # (float) Learning rate for training policy network. Initial lr for manually decay schedule.
        learning_rate=0.2,
        # (float) Weight decay for training policy network.
        weight_decay=1e-4,
        # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction).
        momentum=0.9,
        # (float) The maximum constraint value of gradient norm clipping.
        grad_clip_value=10,
        # (float) The weight of value loss.
        value_weight=1.0,
        # (int) The number of environments used in collecting data.
        collector_env_num=8,
        # (int) The number of environments used in evaluating policy.
        evaluator_env_num=3,
        # (bool) Whether to use piecewise constant learning rate decay.
        # i.e. lr: 0.2 -> 0.02 -> 0.002
        lr_piecewise_constant_decay=True,
        # (int) The number of final training iterations to control lr decay, which is only used for manually decay.
        threshold_training_steps_for_final_lr=int(5e5),
        # (bool) Whether to use manually temperature decay.
        # i.e. temperature: 1 -> 0.5 -> 0.25
        manual_temperature_decay=False,
        # (int) The number of final training iterations to control temperature, which is only used for manually decay.
        threshold_training_steps_for_final_temperature=int(1e5),
        # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
        # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
        fixed_temperature_value=0.25,
        mcts=dict(
            # (int) The number of simulations to perform at each move.
            num_simulations=50,
            # (int) The maximum number of moves to make in a game.
            max_moves=512,  # for chess and shogi, 722 for Go.
            # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of the search tree.
            root_dirichlet_alpha=0.3,
            # (float) The noise weight at the root node of the search tree.
            root_noise_weight=0.25,
            # (int) The base constant used in the PUCT formula for balancing exploration and exploitation during tree search.
            pb_c_base=19652,
            # (float) The initialization constant used in the PUCT formula for balancing exploration and exploitation during tree search.
            pb_c_init=1.25,
            #
            legal_actions=None,
            # (int) The action space size.
            action_space_size=9,
            # (int) The number of sampled actions for each state.
            num_of_sampled_actions=2,
            #
            continuous_action_space=False,
        ),
        other=dict(replay_buffer=dict(
            replay_buffer_size=int(1e6),
            save_episode=False,
        )),
    )

    def default_model(self) -> Tuple[str, List[str]]:
        """
        Overview:
            Return this algorithm default model setting for demonstration.
        Returns:
            - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry.
            - import_names (:obj:`List[str]`): The model class path list used in this algorithm.
        """
        return 'AlphaZeroModel', ['lzero.model.alphazero_model']

    def _init_learn(self) -> None:
        assert self._cfg.optim_type in ['SGD', 'Adam', 'AdamW'], self._cfg.optim_type
        if self._cfg.optim_type == 'SGD':
            self._optimizer = optim.SGD(
                self._model.parameters(),
                lr=self._cfg.learning_rate,
                momentum=self._cfg.momentum,
                weight_decay=self._cfg.weight_decay,
            )
        elif self._cfg.optim_type == 'Adam':
            self._optimizer = optim.Adam(
                self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay
            )
        elif self._cfg.optim_type == 'AdamW':
            self._optimizer = configure_optimizers(
                model=self._model,
                weight_decay=self._cfg.weight_decay,
                learning_rate=self._cfg.learning_rate,
                device_type=self._cfg.device
            )

        if self._cfg.lr_piecewise_constant_decay:
            from torch.optim.lr_scheduler import LambdaLR
            max_step = self._cfg.threshold_training_steps_for_final_lr
            # NOTE: the 1, 0.1, 0.01 is the decay rate, not the lr.
            # lr_lambda = lambda step: 1 if step < max_step * 0.5 else (0.1 if step < max_step else 0.01)  # noqa
            lr_lambda = lambda step: 1 if step < max_step * 0.33 else (0.1 if step < max_step * 0.66 else 0.01)  # noqa

            self.lr_scheduler = LambdaLR(self._optimizer, lr_lambda=lr_lambda)

        # Algorithm config
        self._value_weight = self._cfg.value_weight
        self._entropy_weight = self._cfg.entropy_weight
        # Main and target models
        self._learn_model = self._model

        # TODO(pu): test the effect of torch 2.0
        if self._cfg.torch_compile:
            self._learn_model = torch.compile(self._learn_model)

    def _forward_learn(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, float]:
        for input_dict in inputs:
            # Check and remove 'katago_game_state' from 'obs' if it exists
            if 'katago_game_state' in input_dict['obs']:
                del input_dict['obs']['katago_game_state']

            # Check and remove 'katago_game_state' from 'next_obs' if it exists
            if 'katago_game_state' in input_dict['next_obs']:
                del input_dict['next_obs']['katago_game_state']

        # list of dict -> dict of list
        # only for env with variable legal actions
        inputs = pad_and_get_lengths(inputs, self._cfg.mcts.num_of_sampled_actions)
        inputs = default_collate(inputs)
        valid_action_length = inputs['action_length']

        if self._cuda:
            inputs = to_device(inputs, self._device)
        self._learn_model.train()

        state_batch = inputs['obs']['observation']
        mcts_visit_count_probs = inputs['probs']
        reward = inputs['reward']
        root_sampled_actions = inputs['root_sampled_actions']

        if len(root_sampled_actions.shape) == 1:
            print(f"root_sampled_actions.shape: {root_sampled_actions.shape}")
        state_batch = state_batch.to(device=self._device, dtype=torch.float)
        mcts_visit_count_probs = mcts_visit_count_probs.to(device=self._device, dtype=torch.float)
        reward = reward.to(device=self._device, dtype=torch.float)

        policy_probs, values = self._learn_model.compute_policy_value(state_batch)
        policy_log_probs = torch.log(policy_probs)

        # calculate policy entropy, for monitoring only
        entropy = compute_entropy(policy_probs)
        entropy_loss = -entropy

        # ==============================================================
        # policy loss
        # ==============================================================
        policy_loss = self._calculate_policy_loss_disc(policy_probs, mcts_visit_count_probs, root_sampled_actions,
                                                       valid_action_length)

        # ==============================================================
        # value loss
        # ==============================================================
        value_loss = F.mse_loss(values.view(-1), reward)

        total_loss = self._value_weight * value_loss + policy_loss + self._entropy_weight * entropy_loss
        self._optimizer.zero_grad()
        total_loss.backward()

        total_grad_norm_before_clip = torch.nn.utils.clip_grad_norm_(
            list(self._model.parameters()),
            max_norm=self._cfg.grad_clip_value,
        )
        self._optimizer.step()
        if self._cfg.lr_piecewise_constant_decay is True:
            self.lr_scheduler.step()

        # =============
        # after update
        # =============
        return {
            'cur_lr': self._optimizer.param_groups[0]['lr'],
            'total_loss': total_loss.item(),
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy_loss': entropy_loss.item(),
            'total_grad_norm_before_clip': total_grad_norm_before_clip.item(),
            'collect_mcts_temperature': self.collect_mcts_temperature,
        }

    def _calculate_policy_loss_disc(
            self, policy_probs: torch.Tensor, target_policy: torch.Tensor,
            target_sampled_actions: torch.Tensor, valid_action_lengths: torch.Tensor
    ) -> torch.Tensor:

        # For each batch and each sampled action, get the corresponding probability
        # from policy_probs and target_policy, and put it into sampled_policy_probs and
        # sampled_target_policy at the same position.
        sampled_policy_probs = policy_probs.gather(1, target_sampled_actions)
        sampled_target_policy = target_policy.gather(1, target_sampled_actions)

        # Create a mask for valid actions
        max_length = target_sampled_actions.size(1)
        mask = torch.arange(max_length).expand(len(valid_action_lengths), max_length) < valid_action_lengths.unsqueeze(
            1)
        mask = mask.to(device=self._device)

        # Apply the mask to sampled_policy_probs and sampled_target_policy
        sampled_policy_probs = sampled_policy_probs * mask.float()
        sampled_target_policy = sampled_target_policy * mask.float()

        # Normalize sampled_policy_probs and sampled_target_policy
        sampled_policy_probs = sampled_policy_probs / (sampled_policy_probs.sum(dim=1, keepdim=True) + 1e-6)
        sampled_target_policy = sampled_target_policy / (sampled_target_policy.sum(dim=1, keepdim=True) + 1e-6)

        # after normalization, the sum of each row should be 1, but the prob corresponding to valid action becomes a small non-zero value
        # Use torch.where to prevent gradients for invalid actions
        sampled_policy_probs = torch.where(mask, sampled_policy_probs, torch.zeros_like(sampled_policy_probs))
        sampled_target_policy = torch.where(mask, sampled_target_policy, torch.zeros_like(sampled_target_policy))

        if self._cfg.policy_loss_type == 'KL':
            # Calculate the KL divergence between sampled_policy_probs and sampled_target_policy
            # The KL divergence between 2 probability distributions P and Q is defined as:
            # KL(P || Q) = sum(P(i) * log(P(i) / Q(i)))
            # We use the PyTorch function kl_div to calculate it.
            loss = torch.nn.functional.kl_div(
                sampled_policy_probs.log(), sampled_target_policy, reduction='none'
            )

            loss = torch.nan_to_num(loss)

            # Apply the mask to the loss
            loss = loss * mask.float()
            # Calculate the mean loss over the batch
            loss = loss.sum() / mask.sum()

        elif self._cfg.policy_loss_type == 'cross_entropy':
            # Calculate the cross entropy loss between sampled_policy_probs and sampled_target_policy
            # The cross entropy between 2 probability distributions P and Q is defined as:
            # H(P, Q) = -sum(P(i) * log(Q(i)))
            # We use the PyTorch function cross_entropy to calculate it.
            loss = torch.nn.functional.cross_entropy(
                sampled_policy_probs, torch.argmax(sampled_target_policy, dim=1), reduction='none'
            )

            # 使用 nan_to_num 将 loss 中的 nan 值设置为0
            loss = torch.nan_to_num(loss)

            # Apply the mask to the loss
            loss = loss * mask.float()
            # Calculate the mean loss over the batch
            loss = loss.sum() / mask.sum()

        else:
            raise ValueError(f"Invalid policy_loss_type: {self._cfg.policy_loss_type}")

        return loss

    def _init_collect(self) -> None:
        """
        Overview:
            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
        """
        self._get_simulation_env()

        self._collect_model = self._model
        if self._cfg.mcts_ctree:
            import sys
            sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
            import mcts_alphazero
            self._collect_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves, self._cfg.mcts.num_simulations,
                                                     self._cfg.mcts.pb_c_base,
                                                     self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
                                                     self._cfg.mcts.root_noise_weight, self.simulate_env)
        else:
            if self._cfg.sampled_algo:
                from lzero.mcts.ptree.ptree_az_sampled import MCTS
            else:
                from lzero.mcts.ptree.ptree_az import MCTS
            self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env)

        self.collect_mcts_temperature = 1

    @torch.no_grad()
    def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:

        """
        Overview:
            The forward function for collecting data in collect mode. Use real env to execute MCTS search.
        Arguments:
            - obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
                corresponding obs in this timestep.
            - temperature (:obj:`float`): The temperature for MCTS search.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \
                the corresponding policy output in this timestep, including action, probs and so on.
        """
        self.collect_mcts_temperature = temperature
        ready_env_id = list(obs.keys())
        init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
        try:
            katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
        except Exception as e:
            katago_game_state = {env_id: None for env_id in ready_env_id}

        start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
        output = {}
        self._policy_model = self._collect_model
        for env_id in ready_env_id:
            # print('[collect] start_player_index={}'.format(start_player_index[env_id]))
            # print('[collect] init_state=\n{}'.format(init_state[env_id]))
            state_config_for_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
                                                       init_state=init_state[env_id],
                                                       katago_policy_init=True,
                                                       katago_game_state=katago_game_state[env_id]))

            action, mcts_visit_count_probs = self._collect_mcts.get_next_action(
                state_config_for_env_reset,
                self._policy_value_func,
                self.collect_mcts_temperature,
                True,
            )

            # if np.array_equal(self._collect_mcts.get_sampled_actions(), np.array([2, 2, 3])):
            #     print('debug')
            output[env_id] = {
                'action': action,
                'probs': mcts_visit_count_probs,
                'root_sampled_actions': self._collect_mcts.get_sampled_actions(),
            }

        return output

    def _init_eval(self) -> None:
        """
        Overview:
            Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
        """
        self._get_simulation_env()
        if self._cfg.mcts_ctree:
            import sys
            sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
            import mcts_alphazero
            # TODO(pu): how to set proper num_simulations for evaluation
            self._eval_mcts = mcts_alphazero.MCTS(self._cfg.mcts.max_moves,
                                                  min(800, self._cfg.mcts.num_simulations * 4),
                                                  self._cfg.mcts.pb_c_base,
                                                  self._cfg.mcts.pb_c_init, self._cfg.mcts.root_dirichlet_alpha,
                                                  self._cfg.mcts.root_noise_weight, self.simulate_env)
        else:
            if self._cfg.sampled_algo:
                from lzero.mcts.ptree.ptree_az_sampled import MCTS
            else:
                from lzero.mcts.ptree.ptree_az import MCTS
            mcts_eval_config = copy.deepcopy(self._cfg.mcts)
            # TODO(pu): how to set proper num_simulations for evaluation
            mcts_eval_config.num_simulations = min(800, mcts_eval_config.num_simulations * 4)
            self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)

        self._eval_model = self._model

    def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:

        """
        Overview:
            The forward function for evaluating the current policy in eval mode, similar to ``self._forward_collect``.
        Arguments:
            - obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
                corresponding obs in this timestep.
        Returns:
            - output (:obj:`Dict[str, torch.Tensor]`): The dict of output, the key is env_id and the value is the \
                the corresponding policy output in this timestep, including action, probs and so on.
        """
        ready_env_id = list(obs.keys())
        init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
        try:
            katago_game_state = {env_id: obs[env_id]['katago_game_state'] for env_id in ready_env_id}
        except Exception as e:
            katago_game_state = {env_id: None for env_id in ready_env_id}

        start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
        output = {}
        self._policy_model = self._eval_model
        for env_id in ready_env_id:
            # print('[eval] start_player_index={}'.format(start_player_index[env_id]))
            # print('[eval] init_state=\n {}'.format(init_state[env_id]))

            state_config_for_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
                                                       init_state=init_state[env_id],
                                                       katago_policy_init=False,
                                                       katago_game_state=katago_game_state[env_id]))

            # try:
            action, mcts_visit_count_probs = self._eval_mcts.get_next_action(state_config_for_env_reset,
                                                                             self._policy_value_func,
                                                                             1.0, False)
            # except Exception as e:
            #     print(f"Exception occurred: {e}")
            #     print(f"Is self._policy_value_func callable? {callable(self._policy_value_func)}")
            #     raise  # re-raise the exception
            # print("="*20)
            # print(action, mcts_visit_count_probs)
            # print("="*20)
            output[env_id] = {
                'action': action,
                'probs': mcts_visit_count_probs,
            }
        return output

    def _get_simulation_env(self):
        assert self._cfg.simulation_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_name
        assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league',
                                                        'sampled_play_with_bot'], self._cfg.simulation_env_config_type
        if self._cfg.simulation_env_name == 'tictactoe':
            from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
            if self._cfg.simulation_env_config_type == 'play_with_bot':
                from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
                    tictactoe_alphazero_config
            elif self._cfg.simulation_env_config_type == 'self_play':
                from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
                    tictactoe_alphazero_config
            elif self._cfg.simulation_env_config_type == 'league':
                from zoo.board_games.tictactoe.config.tictactoe_alphazero_league_config import \
                    tictactoe_alphazero_config
            elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
                from zoo.board_games.tictactoe.config.tictactoe_sampled_alphazero_bot_mode_config import \
                    tictactoe_sampled_alphazero_config as tictactoe_alphazero_config

            self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

        elif self._cfg.simulation_env_name == 'gomoku':
            from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
            if self._cfg.simulation_env_config_type == 'play_with_bot':
                from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
            elif self._cfg.simulation_env_config_type == 'self_play':
                from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
            elif self._cfg.simulation_env_config_type == 'league':
                from zoo.board_games.gomoku.config.gomoku_alphazero_league_config import gomoku_alphazero_config
            elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
                from zoo.board_games.gomoku.config.gomoku_sampled_alphazero_bot_mode_config import \
                    gomoku_sampled_alphazero_config as gomoku_alphazero_config

            self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
        elif self._cfg.simulation_env_name == 'go':
            from zoo.board_games.go.envs.go_env import GoEnv
            if self._cfg.simulation_env_config_type == 'play_with_bot':
                from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config
            elif self._cfg.simulation_env_config_type == 'self_play':
                from zoo.board_games.go.config.go_alphazero_sp_mode_config import go_alphazero_config
            elif self._cfg.simulation_env_config_type == 'league':
                from zoo.board_games.go.config.go_alphazero_league_config import go_alphazero_config
            elif self._cfg.simulation_env_config_type == 'sampled_play_with_bot':
                from zoo.board_games.go.config.go_sampled_alphazero_bot_mode_config import \
                    go_sampled_alphazero_config as go_alphazero_config

            self.simulate_env = GoEnv(go_alphazero_config.env)

    @torch.no_grad()
    def _policy_value_func(self, environment: 'Environment') -> Tuple[Dict[int, np.ndarray], float]:
        # Retrieve the legal actions in the current environment
        legal_actions = environment.legal_actions

        # Retrieve the current state and its scale from the environment
        current_state, state_scale = environment.current_state()

        # Convert the state scale to a PyTorch FloatTensor, adding a dimension to match the model's input requirements
        state_scale_tensor = torch.from_numpy(state_scale).to(
            device=self._device, dtype=torch.float
        ).unsqueeze(0)

        # Compute action probabilities and state value for the current state using the policy model, without gradient computation
        with torch.no_grad():
            action_probabilities, state_value = self._policy_model.compute_policy_value(state_scale_tensor)

        # Extract the probabilities of the legal actions from the action probabilities, and convert the result to a numpy array
        legal_action_probabilities = dict(
            zip(legal_actions, action_probabilities.squeeze(0)[legal_actions].detach().cpu().numpy()))

        # Return probabilities of the legal actions and the state value
        return legal_action_probabilities, state_value.item()

    def _monitor_vars_learn(self) -> List[str]:
        """
        Overview:
            Register the variables to be monitored in learn mode. The registered variables will be logged in
            tensorboard according to the return value ``_forward_learn``.
        """
        return super()._monitor_vars_learn() + [
            'cur_lr', 'total_loss', 'policy_loss', 'value_loss', 'entropy_loss', 'total_grad_norm_before_clip',
            'collect_mcts_temperature'
        ]

    def _process_transition(self, obs: Dict, model_output: Dict[str, torch.Tensor], timestep: namedtuple) -> Dict:
        """
        Overview:
            Generate the dict type transition (one timestep) data from policy learning.
        """
        if 'katago_game_state' in obs.keys():
            del obs['katago_game_state']
        # if 'katago_game_state' in timestep.obs.keys():
        #     del timestep.obs['katago_game_state']
        # Note: used in _foward_collect  in alphazero_collector now

        return {
            'obs': obs,
            'next_obs': timestep.obs,
            'action': model_output['action'],
            'root_sampled_actions': model_output['root_sampled_actions'],
            'probs': model_output['probs'],
            'reward': timestep.reward,
            'done': timestep.done,
        }

    def _get_train_sample(self, data):
        # be compatible with DI-engine Policy class
        pass