dqn-CartPole-v1/rl_algo_impls/wrappers/sync_vector_env_render_compat.py
sgoodfriend's picture
DQN playing CartPole-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
e491716
import numpy as np
from gym.vector.sync_vector_env import SyncVectorEnv
from stable_baselines3.common.vec_env.base_vec_env import tile_images
from typing import Optional
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecotarableWrapper,
)
class SyncVectorEnvRenderCompat(VecotarableWrapper):
    """Render-compatibility wrapper for environments built on ``SyncVectorEnv``.

    If the unwrapped environment is a ``SyncVectorEnv``, every sub-environment
    is rendered as an RGB array. The frames are combined into one image with
    stable-baselines3's ``tile_images``. Any other environment is rendered by
    delegating directly to the wrapped env.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render the wrapped environment.

        Args:
            mode: For a ``SyncVectorEnv``, ``"human"`` shows the tiled frame in
                an OpenCV window and ``"rgb_array"`` returns the tiled frame.
                For any other environment, ``mode`` is passed through unchanged
                to the wrapped env's ``render``.

        Returns:
            The tiled image in ``"rgb_array"`` mode, ``None`` in ``"human"``
            mode, or whatever the wrapped env's ``render`` returns when the
            unwrapped env is not a ``SyncVectorEnv``.

        Raises:
            NotImplementedError: If the unwrapped env is a ``SyncVectorEnv`` and
                ``mode`` is neither ``"human"`` nor ``"rgb_array"``.
        """
        base_env = self.env.unwrapped
        if not isinstance(base_env, SyncVectorEnv):
            return self.env.render(mode=mode)

        # Reject unsupported modes before rendering every sub-environment.
        # NotImplementedError is the exception class; the NotImplemented
        # constant is not callable and would surface as a TypeError.
        if mode not in ("human", "rgb_array"):
            raise NotImplementedError(f"Render mode {mode} is not supported")

        imgs = [env.render(mode="rgb_array") for env in base_env.envs]
        bigimg = tile_images(imgs)
        if mode == "rgb_array":
            return bigimg

        # Imported lazily so cv2 is only required for on-screen display.
        import cv2

        # Frames were requested as "rgb_array"; reversing the channel axis
        # converts RGB -> BGR, the channel order OpenCV's imshow expects.
        cv2.imshow("vecenv", bigimg[:, :, ::-1])
        # A short waitKey lets the OpenCV window process events and repaint.
        cv2.waitKey(1)
        return None