# sgoodfriend's picture
# DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 3fd02ed
import copy
import random
from collections import deque
from typing import Any, Deque, Dict, List, Optional
import numpy as np
from rl_algo_impls.runner.config import Config
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnvObs,
VecEnvStepReturn,
VecotarableWrapper,
)
class SelfPlayWrapper(VecotarableWrapper):
    """Vectorized-env wrapper in which some env slots are played by policies.

    Every slot of the wrapped env has an entry in ``policy_assignments``:

    * ``None``: the slot belongs to the caller (the "learner"). Its action
      comes from the ``actions`` array passed to :meth:`step` and its
      observation/reward/done/info are returned to the caller.
    * a ``Policy``: the wrapper computes the slot's action itself inside
      :meth:`step` and hides the slot from the caller.

    Slots are handled in adjacent pairs ``(2k, 2k + 1)``; presumably these are
    the two players of a single match (verify against the wrapped env).
    The first ``2 * num_old_policies`` slots are reserved for pairs whose
    opponent is a checkpointed policy (see :meth:`checkpoint_policy`); pairs
    for pre-trained bots loaded from disk follow immediately after.

    ``self.num_envs`` is set to the number of slots left to the learner.

    Args:
        env: Wrapped vectorized environment. Must provide ``num_envs``,
            ``get_action_mask()``, ``step()`` and ``reset()``.
        config: Run configuration; used to pick the device and to build bot
            policies (``config.algo``, ``config.policy_hyperparams``).
        num_old_policies: Number of slots played by checkpointed policies.
            Must be even and a multiple of ``swap_window_size``.
        save_steps: Stored on the instance; not used inside this class.
        swap_steps: A group's policies are re-sampled once its step counter
            exceeds this value.
        window: Maximum number of checkpointed policies kept (older ones are
            dropped first).
        swap_window_size: Number of pairs that are re-sampled together.
        selfplay_bots: Optional mapping of model path to the number of slots
            that model should play.
        bot_always_player_2: If true, the non-learner always takes the odd
            slot of a pair; otherwise the side alternates from pair to pair.
    """

    # Latest full (all slots) observation / action masks from the wrapped env.
    # Both are populated by reset() and refreshed by step().
    next_obs: VecEnvObs
    next_action_masks: Optional[np.ndarray]

    def __init__(
        self,
        env,
        config: Config,
        num_old_policies: int = 0,
        save_steps: int = 20_000,
        swap_steps: int = 10_000,
        window: int = 10,
        swap_window_size: int = 2,
        selfplay_bots: Optional[Dict[str, Any]] = None,
        bot_always_player_2: bool = False,
    ) -> None:
        super().__init__(env)
        assert num_old_policies % 2 == 0, "num_old_policies must be even"
        assert (
            num_old_policies % swap_window_size == 0
        ), "num_old_policies must be a multiple of swap_window_size"
        self.config = config
        self.num_old_policies = num_old_policies
        self.save_steps = save_steps
        self.swap_steps = swap_steps
        self.swap_window_size = swap_window_size
        self.selfplay_bots = selfplay_bots
        self.bot_always_player_2 = bot_always_player_2

        # Rolling window of checkpointed policies to sample opponents from.
        self.policies: Deque[Policy] = deque(maxlen=window)
        # One entry per wrapped-env slot; None marks a learner slot.
        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
        self.steps_since_swap = np.zeros(env.num_envs)

        # Bot policies loaded from disk, keyed by model path.
        self.selfplay_policies: Dict[str, Policy] = {}
        # Only learner slots are visible to the caller.
        self.num_envs = env.num_envs - num_old_policies
        if self.selfplay_bots:
            self.num_envs -= sum(self.selfplay_bots.values())
            self.initialize_selfplay_bots()

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the wrapped env's action masks restricted to learner slots.

        Returns ``None`` when the wrapped env provides no action masks.
        """
        masks = self.env.get_action_mask()
        if masks is None:
            return None
        return masks[self.learner_indexes()]

    def learner_indexes(self) -> List[bool]:
        """Return a boolean mask over all slots; True marks a learner slot."""
        return [p is None for p in self.policy_assignments]

    def checkpoint_policy(self, copied_policy: Policy) -> None:
        """Add a policy snapshot to the pool of opponents.

        The policy is switched out of training mode and appended to the
        rolling window. If none of the first ``2 * num_old_policies`` slots
        has a policy yet, the snapshot is also assigned to one slot in each of
        the first ``num_old_policies`` pairs.

        Args:
            copied_policy: Snapshot to store. It is kept by reference, so the
                caller is expected to pass a copy it will not keep training.
        """
        copied_policy.train(False)
        self.policies.append(copied_policy)

        reserved = self.policy_assignments[: 2 * self.num_old_policies]
        if all(p is None for p in reserved):
            for i in range(self.num_old_policies):
                # Switch between player 1 and 2 from one pair to the next,
                # unless the opponent is pinned to the second slot.
                side = 1 if self.bot_always_player_2 else i % 2
                self.policy_assignments[2 * i + side] = copied_policy

    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
        """Re-sample the opponent for a group of pairs.

        One policy is drawn uniformly from the checkpoint window and written
        to every slot in the group that already holds a policy; learner slots
        are left untouched. The group's step counters are reset to zero.

        Does nothing if no policy has been checkpointed yet.

        Args:
            idx: Any slot index inside the first pair of the group; it is
                rounded down to the pair's even slot.
            swap_window_size: Number of consecutive pairs in the group.
        """
        if not self.policies:
            # Nothing to sample from. No slot can hold a checkpointed policy
            # yet either, so there is nothing to replace.
            return
        policy = random.choice(self.policies)
        idx = idx // 2 * 2
        group_size = swap_window_size * 2
        for j in range(group_size):
            if self.policy_assignments[idx + j] is not None:
                self.policy_assignments[idx + j] = policy
        self.steps_since_swap[idx : idx + group_size] = 0

    def initialize_selfplay_bots(self) -> None:
        """Load the configured bot policies and assign them to their slots.

        For each ``model_path -> n`` entry of ``selfplay_bots`` a policy is
        built from ``model_path``, put into eval mode, registered in
        ``selfplay_policies`` under its model path, and assigned to one slot
        in each of the next ``n`` pairs. Bot pairs start right after the
        ``2 * num_old_policies`` slots reserved for checkpointed policies.
        """
        if not self.selfplay_bots:
            return

        # Imported here rather than at module level; presumably to avoid a
        # circular import with the runner package.
        from rl_algo_impls.runner.running_utils import get_device, make_policy

        env = self.env  # type: ignore
        device = get_device(self.config, env)
        start_idx = 2 * self.num_old_policies
        for model_path, n in self.selfplay_bots.items():
            policy = make_policy(
                self.config.algo,
                env,
                device,
                load_path=model_path,
                **self.config.policy_hyperparams,
            ).eval()
            # Key by the actual path so that several bots can coexist.
            self.selfplay_policies[model_path] = policy
            for idx in range(start_idx, start_idx + 2 * n, 2):
                # idx is the even slot of a pair. The bot takes the odd slot
                # if pinned to player 2, otherwise the side alternates with
                # the pair number (idx // 2).
                bot_idx = (
                    (idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
                )
                self.policy_assignments[bot_idx] = policy
            start_idx += 2 * n

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env with learner and policy actions combined.

        Args:
            actions: Actions for the learner slots only, in slot order.

        Returns:
            ``(obs, rew, done, info)`` restricted to the learner slots.
        """
        env = self.env  # type: ignore
        all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
        orig_learner_indexes = self.learner_indexes()
        all_actions[orig_learner_indexes] = actions

        # Let every distinct assigned policy act once on all of its slots,
        # using the observations produced by the previous step/reset.
        masks = self.next_action_masks
        for policy in {p for p in self.policy_assignments if p is not None}:
            policy_indexes = [p is policy for p in self.policy_assignments]
            all_actions[policy_indexes] = policy.act(
                self.next_obs[policy_indexes],
                deterministic=False,
                action_masks=masks[policy_indexes] if masks is not None else None,
            )

        self.next_obs, rew, done, info = env.step(all_actions)
        self.next_action_masks = self.env.get_action_mask()

        rew = rew[orig_learner_indexes]
        info = [i for i, b in zip(info, orig_learner_indexes) if b]

        # Periodically re-sample the opponents of each checkpoint group. The
        # counter of a group's first slot stands for the whole group.
        self.steps_since_swap += 1
        for idx in range(0, self.num_old_policies * 2, 2 * self.swap_window_size):
            if self.steps_since_swap[idx] > self.swap_steps:
                self.swap_policy(idx, self.swap_window_size)

        new_learner_indexes = self.learner_indexes()
        return self.next_obs[new_learner_indexes], rew, done[new_learner_indexes], info

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return observations for learner slots.

        The full observation and action masks are cached so that assigned
        policies can act on them in the next :meth:`step`.
        """
        self.next_obs = super().reset()
        self.next_action_masks = self.env.get_action_mask()
        return self.next_obs[self.learner_indexes()]