PPO playing procgen-bigfish-easy from https://github.com/sgoodfriend/rl-algo-impls/tree/21ee1ab96a186676e5ed2f8c3185902f7c7bca7a
a678795
| import numpy as np | |
| from gym.wrappers.monitoring.video_recorder import VideoRecorder | |
| from stable_baselines3.common.vec_env.base_vec_env import ( | |
| VecEnv, | |
| VecEnvStepReturn, | |
| VecEnvWrapper, | |
| VecEnvObs, | |
| ) | |
| from typing import Optional | |
class VecSingleImageWrapper(VecEnvWrapper):
    """Vec-env wrapper whose ``render`` yields a single image.

    Stepping and resetting are passed straight through to the wrapped
    environment; only rendering differs, returning the first entry of the
    wrapped environment's ``get_images()`` result.
    """

    def __init__(self, venv: VecEnv):
        super().__init__(venv)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Return the first image reported by the wrapped environment.

        ``mode`` is accepted for interface compatibility but is not consulted.
        """
        return self.venv.get_images()[0]

    def step_wait(self) -> VecEnvStepReturn:
        """Delegate to the wrapped environment unchanged."""
        step_result = self.venv.step_wait()
        return step_result

    def reset(self) -> VecEnvObs:
        """Delegate to the wrapped environment unchanged."""
        initial_obs = self.venv.reset()
        return initial_obs
class VecSingleRecorder(VecEnvWrapper):
    """Periodically record videos of the first sub-environment of a vec env.

    A recording is started when the first sub-environment finishes an episode
    (or the wrapper is reset) and ``total_steps`` has reached
    ``next_record_video_step``. It is stopped when that sub-environment's
    episode ends, when more than ``max_video_length`` frames have been
    captured, on ``reset``, or when the wrapper is closed.

    Args:
        venv: Vectorized environment to wrap.
        video_path_prefix: Prefix of the output path; the scheduled step count
            is appended as ``"{prefix}-{step}"``.
        video_step_interval: Amount added to the next scheduled recording step
            each time a recording starts.
        max_video_length: Frame count after which an ongoing recording is
            closed even if the episode has not finished.
    """

    def __init__(
        self,
        venv: VecEnv,
        video_path_prefix: str,
        video_step_interval: int = 1_000_000,
        max_video_length: int = 3600,
    ):
        super().__init__(venv)
        # Rendering goes through this wrapper so the recorder only ever sees
        # the first sub-environment's image.
        self.single_image_wrapper = VecSingleImageWrapper(venv)

        self.video_path_prefix = video_path_prefix
        self.video_step_interval = video_step_interval
        self.max_video_length = max_video_length

        # Steps summed over all sub-environments (num_envs per step_wait call).
        self.total_steps = 0
        self.next_record_video_step = 0
        self.video_recorder: Optional[VideoRecorder] = None
        self.recorded_frames = 0

    def step_wait(self) -> VecEnvStepReturn:
        """Step the wrapped env and update the recording state.

        Returns:
            The ``(obs, rew, dones, infos)`` tuple from the wrapped env,
            unmodified.
        """
        obs, rew, dones, infos = self.venv.step_wait()
        self.total_steps += self.venv.num_envs
        # Using first env to record episodes
        if self.video_recorder is not None:
            self.video_recorder.capture_frame()
            self.recorded_frames += 1
            if dones[0] and infos[0].get("episode"):
                # Convert array/scalar-like values (anything with ``.item()``)
                # to plain Python values before storing them as metadata.
                episode_info = {
                    k: v.item() if hasattr(v, "item") else v
                    for k, v in infos[0]["episode"].items()
                }
                self.video_recorder.metadata["episode"] = episode_info
            if dones[0] or self.recorded_frames > self.max_video_length:
                self._close_video_recorder()
        elif dones[0] and self.total_steps >= self.next_record_video_step:
            # Only start at an episode boundary of the first sub-environment.
            self._start_video_recorder()
        return obs, rew, dones, infos

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env; end an active recording or start a due one.

        Returns:
            The observation returned by the wrapped env's ``reset``.
        """
        obs = self.venv.reset()
        if self.video_recorder is not None:
            self._close_video_recorder()
        elif self.total_steps >= self.next_record_video_step:
            self._start_video_recorder()
        return obs

    def close(self) -> None:
        """Finalize any in-progress recording, then close the wrapped env.

        Without this override a recording that is still open when the wrapper
        is closed would never have its recorder closed.
        """
        try:
            self._close_video_recorder()
        finally:
            # Always release the wrapped env, even if finalizing the video fails.
            super().close()

    def _start_video_recorder(self) -> None:
        """Open a new recorder, capture its first frame, schedule the next one."""
        self._close_video_recorder()

        video_path = f"{self.video_path_prefix}-{self.next_record_video_step}"
        self.video_recorder = VideoRecorder(
            self.single_image_wrapper,
            base_path=video_path,
            metadata={"step": self.total_steps},
        )

        self.video_recorder.capture_frame()
        self.recorded_frames = 1
        self.next_record_video_step += self.video_step_interval

    def _close_video_recorder(self) -> None:
        """Close the active recorder, if any, and clear the reference."""
        if self.video_recorder is not None:
            self.video_recorder.close()
        self.video_recorder = None