sgoodfriend's picture
A2C playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
233c511
import itertools
import os
import shutil
from time import perf_counter
from typing import Dict, List, Optional, Union
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.callbacks import Callback
from rl_algo_impls.shared.policy.policy import Policy
from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats
from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
class EvaluateAccumulator(EpisodeAccumulator):
    """Collect a bounded number of finished evaluation episodes per environment.

    The requested ``goal_episodes`` is divided across ``num_envs`` and rounded
    up, and every environment index contributes at most that many episodes.
    The total number of recorded episodes can therefore exceed
    ``goal_episodes`` when it is not a multiple of ``num_envs``.
    """

    def __init__(
        self,
        num_envs: int,
        goal_episodes: int,
        print_returns: bool = True,
        ignore_first_episode: bool = False,
        additional_keys_to_log: Optional[List[str]] = None,
    ):
        """Set up per-environment episode buckets.

        Args:
            num_envs: Number of environment indices to track.
            goal_episodes: Total number of episodes wanted across all envs.
            print_returns: Print a line for every recorded episode.
            ignore_first_episode: Discard the first finished episode of each
                environment index and only record the following ones.
            additional_keys_to_log: Keys copied from the ``info`` dict passed
                to ``on_done`` into ``episode.info``.
        """
        super().__init__(num_envs)
        self.completed_episodes_by_env_idx: List[List[Episode]] = [
            [] for _ in range(num_envs)
        ]
        self.goal_episodes_per_env = int(np.ceil(goal_episodes / num_envs))
        self.print_returns = print_returns
        if ignore_first_episode:
            # Env indices whose first episode has already finished.
            first_done = set()

            def should_record_done(idx: int) -> bool:
                # False exactly once per index (its first finished episode),
                # True for every later episode of that index.
                has_done_first_episode = idx in first_done
                first_done.add(idx)
                return has_done_first_episode

            self.should_record_done = should_record_done
        else:
            self.should_record_done = lambda idx: True
        self.additional_keys_to_log = additional_keys_to_log

    def on_done(self, ep_idx: int, episode: Episode, info: Dict) -> None:
        """Record a finished episode unless it is ignored or the quota is full.

        NOTE(review): presumably invoked by ``EpisodeAccumulator`` whenever an
        environment reports ``done`` -- confirm against the base class.

        Args:
            ep_idx: Index of the environment whose episode finished.
            episode: The finished episode.
            info: Info dict for that environment; must contain every key in
                ``additional_keys_to_log`` (a missing key raises ``KeyError``).
        """
        if self.additional_keys_to_log:
            episode.info = {k: info[k] for k in self.additional_keys_to_log}
        # should_record_done is stateful, so call it exactly once per finished
        # episode. A False result means this is the first episode of the env
        # and the caller asked for it to be ignored.
        if not self.should_record_done(ep_idx):
            return
        # This environment has already delivered its share of episodes.
        if (
            len(self.completed_episodes_by_env_idx[ep_idx])
            >= self.goal_episodes_per_env
        ):
            return
        self.completed_episodes_by_env_idx[ep_idx].append(episode)
        if self.print_returns:
            print(
                f"Episode {len(self)} | "
                f"Score {episode.score} | "
                f"Length {episode.length}"
            )

    def __len__(self) -> int:
        """Return the total number of recorded episodes across all envs."""
        return sum(len(ce) for ce in self.completed_episodes_by_env_idx)

    @property
    def episodes(self) -> List[Episode]:
        """Recorded episodes, concatenated in environment-index order."""
        return list(itertools.chain(*self.completed_episodes_by_env_idx))

    def is_done(self) -> bool:
        """Return True once every environment has reached its episode quota."""
        return all(
            len(ce) == self.goal_episodes_per_env
            for ce in self.completed_episodes_by_env_idx
        )
def evaluate(
    env: VecEnv,
    policy: Policy,
    n_episodes: int,
    render: bool = False,
    deterministic: bool = True,
    print_returns: bool = True,
    ignore_first_episode: bool = False,
    additional_keys_to_log: Optional[List[str]] = None,
    score_function: str = "mean-std",
) -> EpisodesStats:
    """Roll out ``policy`` in ``env`` until enough episodes finish; summarize.

    The policy's normalization is synced with ``env`` and the policy is put
    into eval mode before stepping. This function does not switch the policy
    back afterwards; callers that continue training must do so themselves.

    Args:
        env: Vectorized environment to step.
        policy: Policy queried for actions.
        n_episodes: Total number of episodes to collect across all envs.
        render: Call ``env.render()`` after every step.
        deterministic: Forwarded to ``policy.act``.
        print_returns: Print per-episode lines and the final stats.
        ignore_first_episode: Forwarded to ``EvaluateAccumulator``.
        additional_keys_to_log: Forwarded to ``EvaluateAccumulator``.
        score_function: Forwarded to ``EpisodesStats``.

    Returns:
        ``EpisodesStats`` built from the collected episodes.
    """
    policy.sync_normalization(env)
    policy.eval()

    accumulator = EvaluateAccumulator(
        env.num_envs,
        n_episodes,
        print_returns,
        ignore_first_episode,
        additional_keys_to_log=additional_keys_to_log,
    )

    observation = env.reset()
    # Not every env exposes action masks; fall back to None when absent.
    mask_fn = getattr(env, "get_action_mask", None)
    while True:
        if accumulator.is_done():
            break
        masks = mask_fn() if mask_fn else None
        actions = policy.act(
            observation,
            deterministic=deterministic,
            action_masks=masks,
        )
        observation, rewards, dones, infos = env.step(actions)
        accumulator.step(rewards, dones, infos)
        if render:
            env.render()

    summary = EpisodesStats(accumulator.episodes, score_function=score_function)
    if print_returns:
        print(summary)
    return summary
class EvalCallback(Callback):
    """Callback that periodically evaluates a policy and tracks the best result.

    Each evaluation is logged to TensorBoard. When an evaluation is at least
    as good as the best so far, the model can be saved (and archived into the
    wandb run directory); when it is strictly better, a video can be recorded.
    """

    def __init__(
        self,
        policy: Policy,
        env: VecEnv,
        tb_writer: SummaryWriter,
        best_model_path: Optional[str] = None,
        step_freq: Union[int, float] = 50_000,
        n_episodes: int = 10,
        save_best: bool = True,
        deterministic: bool = True,
        record_best_videos: bool = True,
        video_env: Optional[VecEnv] = None,
        best_video_dir: Optional[str] = None,
        max_video_length: int = 3600,
        ignore_first_episode: bool = False,
        additional_keys_to_log: Optional[List[str]] = None,
        score_function: str = "mean-std",
        wandb_enabled: bool = False,
    ) -> None:
        """Store the evaluation configuration.

        Args:
            policy: Policy to evaluate.
            env: Vectorized environment used for evaluation.
            tb_writer: TensorBoard writer for evaluation scalars.
            best_model_path: Where the best model is saved; required when
                ``save_best`` is true and a new best is found.
            step_freq: Evaluation interval in timesteps (truncated to int).
            n_episodes: Default number of episodes per evaluation.
            save_best: Save the policy whenever it matches or beats the best.
            deterministic: Forwarded to the module-level ``evaluate``.
            record_best_videos: Record a video on each strictly better result;
                requires ``video_env`` and ``best_video_dir``.
            video_env: Environment wrapped for video recording.
            best_video_dir: Directory for videos; created if it is given.
            max_video_length: Forwarded to ``VecEpisodeRecorder``.
            ignore_first_episode: Forwarded to the module-level ``evaluate``.
            additional_keys_to_log: Forwarded to the module-level ``evaluate``.
            score_function: Forwarded to the module-level ``evaluate``.
            wandb_enabled: Also archive the best model into ``wandb.run.dir``.
        """
        super().__init__()
        self.policy = policy
        self.env = env
        self.tb_writer = tb_writer
        self.best_model_path = best_model_path
        self.step_freq = int(step_freq)
        self.n_episodes = n_episodes
        self.save_best = save_best
        self.deterministic = deterministic
        # One entry per evaluation run; its length also drives the schedule
        # in on_step.
        self.stats: List[EpisodesStats] = []
        self.best = None

        self.record_best_videos = record_best_videos
        assert video_env or not record_best_videos
        self.video_env = video_env
        assert best_video_dir or not record_best_videos
        self.best_video_dir = best_video_dir
        if best_video_dir:
            os.makedirs(best_video_dir, exist_ok=True)
        self.max_video_length = max_video_length
        self.best_video_base_path = None

        self.ignore_first_episode = ignore_first_episode
        self.additional_keys_to_log = additional_keys_to_log
        self.score_function = score_function
        self.wandb_enabled = wandb_enabled

    def on_step(self, timesteps_elapsed: int = 1) -> bool:
        """Advance the step counter and evaluate when an interval is reached.

        An evaluation runs whenever the number of completed ``step_freq``
        intervals is at least the number of evaluations stored in
        ``self.stats``; this includes the very first call.

        Returns:
            Always True.
        """
        super().on_step(timesteps_elapsed)
        if self.timesteps_elapsed // self.step_freq >= len(self.stats):
            self.evaluate()
        return True

    def evaluate(
        self, n_episodes: Optional[int] = None, print_returns: Optional[bool] = None
    ) -> EpisodesStats:
        """Evaluate the policy, log the result and update the best-so-far state.

        Args:
            n_episodes: Number of episodes; falls back to ``self.n_episodes``
                when None or 0.
            print_returns: Print per-episode returns; treated as False when
                None.

        Returns:
            The stats of this evaluation run (also appended to ``self.stats``).
        """
        start_time = perf_counter()
        eval_stat = evaluate(
            self.env,
            self.policy,
            n_episodes or self.n_episodes,
            deterministic=self.deterministic,
            print_returns=print_returns or False,
            ignore_first_episode=self.ignore_first_episode,
            additional_keys_to_log=self.additional_keys_to_log,
            score_function=self.score_function,
        )
        end_time = perf_counter()
        self.tb_writer.add_scalar(
            "eval/steps_per_second",
            eval_stat.length.sum() / (end_time - start_time),
            self.timesteps_elapsed,
        )
        # The module-level evaluate() put the policy into eval mode.
        self.policy.train(True)
        print(f"Eval Timesteps: {self.timesteps_elapsed} | {eval_stat}")

        self.stats.append(eval_stat)

        if not self.best or eval_stat >= self.best:
            # Ties refresh the saved model, but only a strict improvement
            # triggers a new video.
            strictly_better = not self.best or eval_stat > self.best
            self.best = eval_stat
            if self.save_best:
                assert self.best_model_path
                self.policy.save(self.best_model_path)
                print("Saved best model")
                if self.wandb_enabled:
                    import wandb

                    best_model_name = os.path.split(self.best_model_path)[-1]
                    shutil.make_archive(
                        os.path.join(wandb.run.dir, best_model_name),  # type: ignore
                        "zip",
                        self.best_model_path,
                    )
            self.best.write_to_tensorboard(
                self.tb_writer, "best_eval", self.timesteps_elapsed
            )
            if strictly_better and self.record_best_videos:
                assert self.video_env and self.best_video_dir
                self.best_video_base_path = os.path.join(
                    self.best_video_dir, str(self.timesteps_elapsed)
                )
                video_wrapped = VecEpisodeRecorder(
                    self.video_env,
                    self.best_video_base_path,
                    max_video_length=self.max_video_length,
                )
                video_stats = evaluate(
                    video_wrapped,
                    self.policy,
                    1,
                    deterministic=self.deterministic,
                    print_returns=False,
                    score_function=self.score_function,
                )
                # The video rollout called policy.eval() again; restore
                # training mode so training does not resume in eval mode.
                self.policy.train(True)
                print(f"Saved best video: {video_stats}")

        eval_stat.write_to_tensorboard(self.tb_writer, "eval", self.timesteps_elapsed)

        return eval_stat