a2c-Walker2DBulletEnv-v0 / wrappers /sync_vector_env_render_compat.py
sgoodfriend's picture
A2C playing Walker2DBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/0760ef7d52b17f30219a27c18ba52c8895025ae3
3c98d8b
raw
history blame
1 kB
import numpy as np
from gym.vector.sync_vector_env import SyncVectorEnv
from stable_baselines3.common.vec_env.base_vec_env import tile_images
from typing import Optional
from wrappers.vectorable_wrapper import (
VecotarableWrapper,
)
class SyncVectorEnvRenderCompat(VecotarableWrapper):
    """Wrapper that gives a ``SyncVectorEnv`` a single-call ``render``.

    When the unwrapped environment is a ``SyncVectorEnv``, every
    sub-environment is rendered individually and the frames are combined
    into one image with ``tile_images``. For any other environment the
    call is forwarded unchanged to the wrapped env's own ``render``.
    """

    def __init__(self, env) -> None:
        """Wrap ``env``; no additional state is stored on this wrapper."""
        super().__init__(env)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render the wrapped environment.

        Args:
            mode: ``"human"`` shows the tiled frame in an OpenCV window
                named ``"vecenv"``; ``"rgb_array"`` returns the tiled
                frame. Any other value is only accepted when the
                unwrapped env is not a ``SyncVectorEnv`` (it is then
                passed through to the wrapped env).

        Returns:
            The tiled image for ``"rgb_array"``, ``None`` for ``"human"``,
            or whatever the wrapped env's ``render`` returns when the
            unwrapped env is not a ``SyncVectorEnv``.

        Raises:
            NotImplementedError: If the unwrapped env is a
                ``SyncVectorEnv`` and ``mode`` is neither ``"human"`` nor
                ``"rgb_array"``.
        """
        base_env = self.env.unwrapped
        if not isinstance(base_env, SyncVectorEnv):
            return self.env.render(mode=mode)

        # Reject unsupported modes before doing any rendering work.
        # NotImplementedError is the exception class; the NotImplemented
        # singleton is not callable and must not be raised.
        if mode not in ("human", "rgb_array"):
            raise NotImplementedError(f"Render mode {mode} is not supported")

        imgs = [env.render(mode="rgb_array") for env in base_env.envs]
        bigimg = tile_images(imgs)
        if mode == "rgb_array":
            return bigimg

        # Imported lazily so OpenCV is only required for on-screen display.
        import cv2

        # Reverse the last axis before display: frames were requested as
        # "rgb_array", while cv2.imshow interprets 3-channel data as BGR.
        cv2.imshow("vecenv", bigimg[:, :, ::-1])
        cv2.waitKey(1)
        return None