File size: 5,546 Bytes
1cde088 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import copy
import random
from collections import deque
from typing import Any, Deque, Dict, List, Optional
import numpy as np
from rl_algo_impls.runner.config import Config
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnvObs,
VecEnvStepReturn,
VecotarableWrapper,
)
class SelfPlayWrapper(VecotarableWrapper):
    """Vectorized-env wrapper that plays some env slots with non-learner policies.

    Every slot of the wrapped env is either controlled by the learner
    (``policy_assignments[i] is None``) or by an opponent ``Policy``
    (a checkpointed copy of the learner or a pre-trained bot).  Callers only
    see the learner-controlled slots: observations, rewards, dones, infos and
    action masks are filtered down to those slots, and ``num_envs`` is reduced
    accordingly.

    NOTE(review): the indexing throughout assumes slots come in consecutive
    (player 1, player 2) pairs, i.e. slots ``2k`` and ``2k + 1`` belong to the
    same match -- confirm against the wrapped env.
    """

    # Last full (unfiltered) observation returned by the wrapped env.
    next_obs: VecEnvObs
    # Last full (unfiltered) action mask from the wrapped env, if it has one.
    next_action_masks: Optional[np.ndarray]

    def __init__(
        self,
        env,
        config: Config,
        num_old_policies: int = 0,
        save_steps: int = 20_000,
        swap_steps: int = 10_000,
        window: int = 10,
        swap_window_size: int = 2,
        selfplay_bots: Optional[Dict[str, Any]] = None,
        bot_always_player_2: bool = False,
    ) -> None:
        """Set up slot bookkeeping and load any configured bots.

        Args:
            env: Wrapped vectorized env; must expose ``num_envs`` and
                ``get_action_mask()``.
            config: Run configuration, used to build bot policies.
            num_old_policies: Number of slots played by checkpointed policies.
                Must be even and a multiple of ``swap_window_size``.
            save_steps: Stored on the instance; not read by this class.
            swap_steps: A window of opponents is resampled once its step
                counter exceeds this value.
            window: Maximum number of checkpointed policies kept; older ones
                are dropped first.
            swap_window_size: Number of slot pairs resampled together.
            selfplay_bots: Maps a model path to the number of slots that bot
                should control.
            bot_always_player_2: If True, opponents always take the second
                slot of a pair instead of alternating sides.
        """
        super().__init__(env)
        assert num_old_policies % 2 == 0, "num_old_policies must be even"
        assert (
            num_old_policies % swap_window_size == 0
        ), "num_old_policies must be a multiple of swap_window_size"
        self.config = config
        self.num_old_policies = num_old_policies
        self.save_steps = save_steps
        self.swap_steps = swap_steps
        self.swap_window_size = swap_window_size
        self.selfplay_bots = selfplay_bots
        self.bot_always_player_2 = bot_always_player_2

        # Pool of checkpointed learner copies to sample opponents from.
        self.policies: Deque[Policy] = deque(maxlen=window)
        # One entry per wrapped-env slot; None means "learner plays this slot".
        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
        self.steps_since_swap = np.zeros(env.num_envs)

        # Loaded bot policies keyed by the model path they were loaded from.
        self.selfplay_policies: Dict[str, Policy] = {}
        # Number of slots exposed to the learner.
        self.num_envs = env.num_envs - num_old_policies
        if self.selfplay_bots:
            self.num_envs -= sum(self.selfplay_bots.values())
            self.initialize_selfplay_bots()

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the wrapped env's action mask restricted to learner slots.

        Returns None when the wrapped env reports no action mask.
        """
        masks = self.env.get_action_mask()
        if masks is None:
            return None
        return masks[self.learner_indexes()]

    def learner_indexes(self) -> List[bool]:
        """Return a boolean mask over wrapped-env slots; True = learner slot."""
        return [p is None for p in self.policy_assignments]

    def checkpoint_policy(self, copied_policy: Policy) -> None:
        """Add a frozen copy of the learner to the opponent pool.

        If none of the first ``2 * num_old_policies`` slots has an opponent
        yet, this policy is also assigned to one slot in each of those pairs.
        """
        copied_policy.train(False)
        self.policies.append(copied_policy)

        old_policy_slots = self.policy_assignments[: 2 * self.num_old_policies]
        if all(p is None for p in old_policy_slots):
            for i in range(self.num_old_policies):
                # Alternate the opponent between player 1 and player 2 from
                # pair to pair, unless it is pinned to player 2.
                side = 1 if self.bot_always_player_2 else i % 2
                self.policy_assignments[2 * i + side] = copied_policy

    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
        """Replace the opponents in a window of slots with a random checkpoint.

        Args:
            idx: Slot index inside the window; rounded down to an even index.
            swap_window_size: Number of slot pairs covered by the window.

        Does nothing if no policy has been checkpointed yet.
        """
        if not self.policies:
            # Nothing to sample from; random.choice would raise IndexError.
            return
        policy = random.choice(self.policies)
        idx = idx // 2 * 2
        window_end = idx + swap_window_size * 2
        for slot in range(idx, window_end):
            # Only opponent slots are replaced; learner slots stay None.
            if self.policy_assignments[slot] is not None:
                self.policy_assignments[slot] = policy
        self.steps_since_swap[idx:window_end] = 0

    def initialize_selfplay_bots(self) -> None:
        """Load each configured bot and assign it to its slots.

        Bot slots start right after the ``2 * num_old_policies`` slots
        reserved for checkpointed policies; each bot copy takes one slot of a
        pair.
        """
        if not self.selfplay_bots:
            return

        # Imported lazily, presumably to avoid a circular import -- confirm.
        from rl_algo_impls.runner.running_utils import get_device, make_policy

        env = self.env  # type: ignore
        device = get_device(self.config, env)
        start_idx = 2 * self.num_old_policies
        for model_path, n in self.selfplay_bots.items():
            policy = make_policy(
                self.config.algo,
                env,
                device,
                load_path=model_path,
                **self.config.policy_hyperparams,
            ).eval()
            # Key by the actual path so several bots do not overwrite each
            # other (previously every bot was stored under "model_path").
            self.selfplay_policies[model_path] = policy
            for idx in range(start_idx, start_idx + 2 * n, 2):
                # Either always the second slot of the pair, or alternate
                # sides from one pair to the next.
                bot_idx = (
                    (idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
                )
                self.policy_assignments[bot_idx] = policy
            start_idx += 2 * n

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env with learner actions plus opponent actions.

        Args:
            actions: Actions for the learner-controlled slots only, in slot
                order.

        Returns:
            ``(obs, rew, done, info)`` restricted to learner-controlled slots.
        """
        env = self.env  # type: ignore
        all_actions = np.zeros(
            (env.num_envs,) + actions.shape[1:], dtype=actions.dtype
        )
        orig_learner_indexes = self.learner_indexes()
        all_actions[orig_learner_indexes] = actions

        # Visit each distinct opponent once (by identity) and let it act on
        # all of its slots in a single batched call.
        opponents = {
            id(p): p for p in self.policy_assignments if p is not None
        }
        for policy in opponents.values():
            policy_indexes = [p is policy for p in self.policy_assignments]
            all_actions[policy_indexes] = policy.act(
                self.next_obs[policy_indexes],
                deterministic=False,
                action_masks=self.next_action_masks[policy_indexes]
                if self.next_action_masks is not None
                else None,
            )

        self.next_obs, rew, done, info = env.step(all_actions)
        self.next_action_masks = self.env.get_action_mask()

        rew = rew[orig_learner_indexes]
        info = [i for i, b in zip(info, orig_learner_indexes) if b]

        # Resample opponents for every window whose timer has expired.
        self.steps_since_swap += 1
        for idx in range(0, self.num_old_policies * 2, 2 * self.swap_window_size):
            if self.steps_since_swap[idx] > self.swap_steps:
                self.swap_policy(idx, self.swap_window_size)

        new_learner_indexes = self.learner_indexes()
        return (
            self.next_obs[new_learner_indexes],
            rew,
            done[new_learner_indexes],
            info,
        )

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return observations for learner slots."""
        self.next_obs = super().reset()
        self.next_action_masks = self.env.get_action_mask()
        return self.next_obs[self.learner_indexes()]
|