DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
2a30b4a
| import numpy as np | |
| from gym.vector.sync_vector_env import SyncVectorEnv | |
| from stable_baselines3.common.vec_env.base_vec_env import tile_images | |
| from typing import Optional | |
| from rl_algo_impls.wrappers.vectorable_wrapper import ( | |
| VecotarableWrapper, | |
| ) | |
class SyncVectorEnvRenderCompat(VecotarableWrapper):
    """Wrapper that renders a ``SyncVectorEnv`` as one tiled image.

    When the unwrapped environment is a ``SyncVectorEnv``, ``render`` asks
    every sub-environment for an ``rgb_array`` frame and combines the frames
    with ``tile_images``. For any other environment the call is forwarded
    unchanged to the wrapped env's own ``render``.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render the wrapped environment.

        Args:
            mode: ``"human"`` shows the tiled image in an OpenCV window named
                ``"vecenv"``; ``"rgb_array"`` returns the tiled image. For a
                non-``SyncVectorEnv`` base env, ``mode`` is passed through to
                the wrapped env's ``render``.

        Returns:
            The tiled image for ``"rgb_array"``, ``None`` for ``"human"``, or
            whatever the wrapped env's ``render`` returns when the base env is
            not a ``SyncVectorEnv``.

        Raises:
            NotImplementedError: If the base env is a ``SyncVectorEnv`` and
                ``mode`` is neither ``"human"`` nor ``"rgb_array"``.
        """
        base_env = self.env.unwrapped
        if not isinstance(base_env, SyncVectorEnv):
            return self.env.render(mode=mode)

        # Sub-envs are always asked for raw frames; the requested mode only
        # decides what happens to the combined image afterwards.
        imgs = [env.render(mode="rgb_array") for env in base_env.envs]
        bigimg = tile_images(imgs)

        if mode == "human":
            # Imported lazily so OpenCV is only needed for on-screen display.
            import cv2

            # Reverse the channel axis (presumably RGB -> BGR, the channel
            # order OpenCV uses for display).
            cv2.imshow("vecenv", bigimg[:, :, ::-1])
            cv2.waitKey(1)
            return None
        if mode == "rgb_array":
            return bigimg
        # NotImplemented is a non-callable sentinel constant; the exception
        # type that carries a message is NotImplementedError.
        raise NotImplementedError(f"Render mode {mode} is not supported")