# Source path: a2c-MountainCar-v0 / rl_algo_impls / wrappers / episode_stats_writer.py
# A2C playing MountainCar-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
# Revision 9a64f31 (sgoodfriend), 2.83 kB
from collections import deque
from typing import Any, Dict, List, Optional
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.shared.stats import Episode, EpisodesStats
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecEnvObs,
VecEnvStepReturn,
VecotarableWrapper,
)
class EpisodeStatsWriter(VecotarableWrapper):
    """Wrapper that reports finished-episode statistics to TensorBoard.

    After every step the ``infos`` returned by the wrapped environment are
    scanned for a truthy ``"episode"`` entry. Each such entry becomes an
    ``Episode`` that is written out both as per-step stats and as part of a
    rolling window holding the last ``rolling_length`` episodes. A summary
    line is printed roughly once per ``rolling_length`` recorded episodes.
    """

    def __init__(
        self,
        env,
        tb_writer: SummaryWriter,
        training: bool = True,
        rolling_length=100,
        additional_keys_to_log: Optional[List[str]] = None,
    ):
        """Set up counters and the rolling episode buffer.

        Args:
            env: Environment to wrap; forwarded to the parent constructor.
            tb_writer: Writer handed to ``EpisodesStats.write_to_tensorboard``.
            training: Selects the ``"train"`` tag when true, ``"eval"``
                otherwise.
            rolling_length: Maximum size of the rolling episode buffer and
                the episode interval between printed summaries.
            additional_keys_to_log: Keys copied from each finished episode's
                ``info`` dict into the ``Episode``; ``None`` means no keys.
        """
        super().__init__(env)
        self.training = training
        self.tb_writer = tb_writer
        self.rolling_length = rolling_length
        # Bounded buffer: the oldest episode is dropped once it is full.
        self.episodes = deque(maxlen=rolling_length)
        self.total_steps = 0
        self.episode_cnt = 0
        self.last_episode_cnt_print = 0
        if additional_keys_to_log is None:
            additional_keys_to_log = []
        self.additional_keys_to_log = additional_keys_to_log

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env with ``actions`` and record episode stats."""
        return self._record_and_forward(self.env.step(actions))

    # Support for stable_baselines3.common.vec_env.VecEnvWrapper
    def step_wait(self) -> VecEnvStepReturn:
        """Call the wrapped env's ``step_wait`` and record episode stats."""
        return self._record_and_forward(self.env.step_wait())

    def _record_and_forward(self, step_return) -> VecEnvStepReturn:
        """Record stats from a 4-item step result and return it as a tuple."""
        obs, rewards, dones, infos = step_return
        self._record_stats(infos)
        return obs, rewards, dones, infos

    def _record_stats(self, infos: List[Dict[str, Any]]) -> None:
        """Advance the step counter and log any episodes found in ``infos``.

        ``total_steps`` grows by the wrapped env's ``num_envs`` (1 when the
        attribute is absent) on every call. Nothing is written or printed
        unless at least one info dict carries a truthy ``"episode"`` entry.
        """
        self.total_steps += getattr(self.env, "num_envs", 1)

        finished = []
        for info in infos:
            ep_info = info.get("episode")
            if not ep_info:
                continue
            # Every configured key must be present in ``info``; a missing
            # key raises KeyError here.
            extras = {key: info[key] for key in self.additional_keys_to_log}
            # NOTE(review): "r" / "l" are presumably the episode return and
            # length - confirm against the producer of ``info["episode"]``.
            completed = Episode(ep_info["r"], ep_info["l"], info=extras)
            finished.append(completed)
            self.episodes.append(completed)

        if not finished:
            return

        tag = "train" if self.training else "eval"
        EpisodesStats(finished, simple=True).write_to_tensorboard(
            self.tb_writer, tag, self.total_steps
        )
        rolling_stats = EpisodesStats(self.episodes)
        rolling_stats.write_to_tensorboard(
            self.tb_writer, f"{tag}_rolling", self.total_steps
        )

        self.episode_cnt += len(finished)
        next_print_at = self.last_episode_cnt_print + self.rolling_length
        if self.episode_cnt >= next_print_at:
            print(
                f"Episode: {self.episode_cnt} | "
                f"Steps: {self.total_steps} | "
                f"{rolling_stats}"
            )
            # Advance by a fixed stride rather than snapping to episode_cnt.
            self.last_episode_cnt_print += self.rolling_length

    def reset(self) -> VecEnvObs:
        """Reset the wrapped environment and return its result unchanged."""
        return self.env.reset()