# sgoodfriend's picture
# DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
# e588054
from collections import deque
from typing import Any, Dict, List, Optional
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.stats import Episode, EpisodesStats
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnvObs,
VecEnvStepReturn,
VecotarableWrapper,
)
class EpisodeStatsWriter(VecotarableWrapper):
    """Environment wrapper that logs statistics for finished episodes.

    After each ``step``/``step_wait`` the returned ``infos`` are scanned for
    a truthy ``"episode"`` entry. Every entry found becomes an ``Episode``.
    The episodes finished on that step, and a rolling window of the most
    recent ``rolling_length`` episodes, are written to TensorBoard. A summary
    line is printed once the episode count has advanced at least
    ``rolling_length`` past the last print mark.
    """

    def __init__(
        self,
        env,
        tb_writer: SummaryWriter,
        training: bool = True,
        rolling_length=100,
        additional_keys_to_log: Optional[List[str]] = None,
    ):
        """Wrap ``env`` and initialise the logging counters.

        Args:
            env: Environment to wrap; passed on to the parent wrapper.
            tb_writer: TensorBoard writer that receives the statistics.
            training: Chooses the tag prefix: "train" if true, else "eval".
            rolling_length: Maximum number of episodes kept in the rolling
                window, and the episode-count increment between prints.
            additional_keys_to_log: Keys copied out of an ``info`` dict into
                each recorded episode. ``None`` is treated as no extra keys.
        """
        super().__init__(env)
        self.tb_writer = tb_writer
        self.training = training
        self.rolling_length = rolling_length
        # Bounded deque: once full, appending drops the oldest episode.
        self.episodes = deque(maxlen=rolling_length)
        self.total_steps = 0
        self.episode_cnt = 0
        self.last_episode_cnt_print = 0
        if additional_keys_to_log is None:
            self.additional_keys_to_log = []
        else:
            self.additional_keys_to_log = additional_keys_to_log

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env, record episode stats, and return its result."""
        observations, rewards, terminals, infos = self.env.step(actions)
        self._record_stats(infos)
        return observations, rewards, terminals, infos

    # Support for stable_baselines3.common.vec_env.VecEnvWrapper
    def step_wait(self) -> VecEnvStepReturn:
        """Same as ``step`` but driven through the wrapped env's ``step_wait``."""
        observations, rewards, terminals, infos = self.env.step_wait()
        self._record_stats(infos)
        return observations, rewards, terminals, infos

    def _record_stats(self, infos: List[Dict[str, Any]]) -> None:
        """Update counters and log any episodes reported in ``infos``.

        Args:
            infos: Info dicts returned by the wrapped env's step. A dict with
                a truthy ``"episode"`` entry is read for its ``"r"`` and
                ``"l"`` values (presumably episode return and length; confirm
                against the env that fills them in) plus every key listed in
                ``additional_keys_to_log``.

        Raises:
            KeyError: If a dict reporting an episode lacks one of the
                ``additional_keys_to_log`` keys, or ``"r"``/``"l"``.
        """
        # Advance by the wrapped env's num_envs, or by 1 if it has none.
        self.total_steps += getattr(self.env, "num_envs", 1)

        finished = []
        for env_info in infos:
            episode_info = env_info.get("episode")
            if not episode_info:
                continue
            extras = {key: env_info[key] for key in self.additional_keys_to_log}
            finished_episode = Episode(
                episode_info["r"], episode_info["l"], info=extras
            )
            finished.append(finished_episode)
            self.episodes.append(finished_episode)

        # Nothing finished on this step: nothing to write or print.
        if not finished:
            return

        if self.training:
            tag = "train"
        else:
            tag = "eval"

        # Stats for just the episodes that ended on this step.
        EpisodesStats(finished, simple=True).write_to_tensorboard(
            self.tb_writer, tag, self.total_steps
        )

        # Stats over the rolling window of recent episodes.
        rolling_stats = EpisodesStats(self.episodes)
        rolling_tag = f"{tag}_rolling"
        rolling_stats.write_to_tensorboard(
            self.tb_writer, rolling_tag, self.total_steps
        )

        self.episode_cnt += len(finished)

        next_print_at = self.last_episode_cnt_print + self.rolling_length
        if self.episode_cnt < next_print_at:
            return

        summary = (
            f"Episode: {self.episode_cnt} | "
            f"Steps: {self.total_steps} | "
            f"{rolling_stats}"
        )
        print(summary)
        # The mark moves by exactly rolling_length, not up to episode_cnt.
        self.last_episode_cnt_print = next_print_at

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and return its observation unchanged."""
        return self.env.reset()