# Web-page header captured together with this file (not part of the code):
# sgoodfriend's picture
# PPO playing LunarLander-v2 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
# 479abbd
# raw
# history blame contribute delete
# No virus
# 5.55 kB
import copy
import random
from collections import deque
from typing import Any, Deque, Dict, List, Optional
import numpy as np
from rl_algo_impls.runner.config import Config
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnvObs,
VecEnvStepReturn,
VecotarableWrapper,
)
class SelfPlayWrapper(VecotarableWrapper):
    """Vectorized-env wrapper that lets fixed policies play some of the env slots.

    Every slot of the wrapped env has an entry in ``policy_assignments``:
    ``None`` means the slot is driven by the learner (its action comes from the
    ``actions`` argument of ``step``); a ``Policy`` means the slot is driven by
    that policy (an earlier checkpoint or a loaded bot). Only learner slots are
    exposed to the caller, so ``num_envs`` is the wrapped env's slot count minus
    the slots handed over to other policies.

    Slots are handled in consecutive pairs (presumably the two players of one
    match - confirm against the wrapped env):

    * ``[0, 2 * num_old_policies)``: pairs in which one slot is played by a
      checkpointed policy (see ``checkpoint_policy`` and ``swap_policy``).
    * the next ``2 * n`` slots for each ``selfplay_bots`` entry: pairs in which
      one slot is played by the bot loaded from that entry's model path.
    """

    # Observations/action masks for ALL wrapped slots from the latest
    # reset()/step(); they are only set once reset() has been called.
    next_obs: VecEnvObs
    next_action_masks: Optional[np.ndarray]

    def __init__(
        self,
        env,
        config: Config,
        num_old_policies: int = 0,
        save_steps: int = 20_000,
        swap_steps: int = 10_000,
        window: int = 10,
        swap_window_size: int = 2,
        selfplay_bots: Optional[Dict[str, Any]] = None,
        bot_always_player_2: bool = False,
    ) -> None:
        """Set up slot bookkeeping and load any configured self-play bots.

        Args:
            env: Wrapped vectorized env; ``num_envs``, ``step``, ``reset`` and
                ``get_action_mask`` are used by this class.
            config: Run configuration, used to build bot policies in
                ``initialize_selfplay_bots``.
            num_old_policies: Number of slots played by checkpointed policies.
                Must be even and a multiple of ``swap_window_size``.
            save_steps: Stored on the instance; not read inside this class.
            swap_steps: A window of old-policy slots is resampled once its
                step counter exceeds this value.
            window: Maximum number of checkpointed policies kept; older ones
                are dropped by the bounded deque.
            swap_window_size: Number of slot pairs resampled together.
            selfplay_bots: Mapping of model path to the number of slots that
                bot should play.
            bot_always_player_2: If true, the non-learner always takes the
                second slot of a pair; otherwise it alternates between the
                first and the second slot from pair to pair.
        """
        super().__init__(env)
        assert num_old_policies % 2 == 0, "num_old_policies must be even"
        assert (
            num_old_policies % swap_window_size == 0
        ), "num_old_policies must be a multiple of swap_window_size"
        self.config = config
        self.num_old_policies = num_old_policies
        self.save_steps = save_steps
        self.swap_steps = swap_steps
        self.swap_window_size = swap_window_size
        self.selfplay_bots = selfplay_bots
        self.bot_always_player_2 = bot_always_player_2

        self.policies: Deque[Policy] = deque(maxlen=window)
        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
        self.steps_since_swap = np.zeros(env.num_envs)

        self.selfplay_policies: Dict[str, Policy] = {}
        # Slots taken by old policies and bots are hidden from the caller.
        self.num_envs = env.num_envs - num_old_policies
        if self.selfplay_bots:
            self.num_envs -= sum(self.selfplay_bots.values())
            self.initialize_selfplay_bots()

    def get_action_mask(self) -> Optional[np.ndarray]:
        """Return the wrapped env's action mask restricted to learner slots.

        Returns ``None`` when the wrapped env reports no action mask.
        """
        action_masks = self.env.get_action_mask()
        if action_masks is None:
            return None
        return action_masks[self.learner_indexes()]

    def learner_indexes(self) -> List[bool]:
        """Return a boolean mask over wrapped slots; True marks a learner slot."""
        return [p is None for p in self.policy_assignments]

    def checkpoint_policy(self, copied_policy: Policy) -> None:
        """Add a policy snapshot to the pool of old policies.

        The policy is switched out of training mode and appended to
        ``policies``. If none of the old-policy slots has a policy yet, the
        snapshot is also assigned to one slot of each of the first
        ``num_old_policies`` pairs.
        """
        copied_policy.train(False)
        self.policies.append(copied_policy)
        old_policy_slots = self.policy_assignments[: 2 * self.num_old_policies]
        if all(p is None for p in old_policy_slots):
            for i in range(self.num_old_policies):
                # Switch between player 1 and 2 unless pinned to player 2.
                offset = 1 if self.bot_always_player_2 else i % 2
                self.policy_assignments[2 * i + offset] = copied_policy

    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
        """Replace the old policies in a window of slots with a random checkpoint.

        Args:
            idx: Slot index inside the first pair of the window; rounded down
                to the pair's first slot.
            swap_window_size: Number of consecutive pairs to update.

        Only slots that already hold a policy are reassigned, so learner slots
        stay learner slots. The step counters of the window are reset to zero.
        Does nothing while no checkpoint exists, because there is nothing to
        sample from; the counters are left untouched so the swap is retried.
        """
        if not self.policies:
            return
        policy = random.choice(self.policies)
        idx = idx // 2 * 2
        window_slots = swap_window_size * 2
        for j in range(window_slots):
            if self.policy_assignments[idx + j] is not None:
                self.policy_assignments[idx + j] = policy
        self.steps_since_swap[idx : idx + window_slots] = 0

    def initialize_selfplay_bots(self) -> None:
        """Load every configured bot and assign it to its slots.

        For each ``model_path -> n`` entry of ``selfplay_bots`` a policy is
        built via ``make_policy`` with ``load_path=model_path``, put in eval
        mode, recorded in ``selfplay_policies`` under its model path and
        assigned to one slot in each of the next ``n`` pairs following the
        old-policy slots.
        """
        if not self.selfplay_bots:
            return

        # Imported here rather than at module level (presumably to avoid a
        # circular import - confirm before moving).
        from rl_algo_impls.runner.running_utils import get_device, make_policy

        env = self.env  # type: ignore
        device = get_device(self.config, env)
        start_idx = 2 * self.num_old_policies
        for model_path, n in self.selfplay_bots.items():
            policy = make_policy(
                self.config.algo,
                env,
                device,
                load_path=model_path,
                **self.config.policy_hyperparams,
            ).eval()
            # Key by the actual path so several bots do not overwrite each other.
            self.selfplay_policies[model_path] = policy
            for idx in range(start_idx, start_idx + 2 * n, 2):
                if self.bot_always_player_2:
                    bot_idx = idx + 1
                else:
                    # Alternate between the first and second slot of the pair.
                    bot_idx = idx + idx // 2 % 2
                self.policy_assignments[bot_idx] = policy
            start_idx += 2 * n

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env with learner actions plus non-learner actions.

        Args:
            actions: Actions for the learner slots only, in slot order.

        Returns:
            ``(obs, rew, done, info)`` restricted to the learner slots.
        """
        env = self.env  # type: ignore
        all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
        orig_learner_indexes = self.learner_indexes()
        all_actions[orig_learner_indexes] = actions

        # dict.fromkeys de-duplicates while keeping slot order, so policies act
        # in the same order on every run (set order depends on object hashes).
        assigned_policies = dict.fromkeys(
            p for p in self.policy_assignments if p is not None
        )
        for policy in assigned_policies:
            policy_indexes = [p is policy for p in self.policy_assignments]
            all_actions[policy_indexes] = policy.act(
                self.next_obs[policy_indexes],
                deterministic=False,
                action_masks=self.next_action_masks[policy_indexes]
                if self.next_action_masks is not None
                else None,
            )

        self.next_obs, rew, done, info = env.step(all_actions)
        self.next_action_masks = self.env.get_action_mask()

        rew = rew[orig_learner_indexes]
        info = [i for i, b in zip(info, orig_learner_indexes) if b]

        self.steps_since_swap += 1
        # One counter check per window; swap_policy refreshes the whole window.
        for idx in range(0, self.num_old_policies * 2, 2 * self.swap_window_size):
            if self.steps_since_swap[idx] > self.swap_steps:
                self.swap_policy(idx, self.swap_window_size)

        new_learner_indexes = self.learner_indexes()
        return self.next_obs[new_learner_indexes], rew, done[new_learner_indexes], info

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env, cache obs/masks, and return learner-slot obs."""
        self.next_obs = super().reset()
        self.next_action_masks = self.env.get_action_mask()
        return self.next_obs[self.learner_indexes()]