File size: 1,015 Bytes
e9e96b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import numpy as np
from gym.vector.sync_vector_env import SyncVectorEnv
from stable_baselines3.common.vec_env.base_vec_env import tile_images
from typing import Optional
from rl_algo_impls.wrappers.vectorable_wrapper import (
VecotarableWrapper,
)
class SyncVectorEnvRenderCompat(VecotarableWrapper):
    """Add tiled rendering for a wrapped gym ``SyncVectorEnv``.

    If the innermost (``unwrapped``) env is a ``SyncVectorEnv``, each sub-env
    is rendered in ``"rgb_array"`` mode and the frames are tiled into a single
    image. Any other env is rendered by delegating to ``self.env.render``.
    """

    def __init__(self, env) -> None:
        super().__init__(env)

    def render(self, mode: str = "human") -> Optional[np.ndarray]:
        """Render the wrapped environment.

        Args:
            mode: For a ``SyncVectorEnv`` base, ``"human"`` shows the tiled
                frame in an OpenCV window and ``"rgb_array"`` returns it.
                For any other base env, ``mode`` is passed through unchanged
                to ``self.env.render``.

        Returns:
            For a ``SyncVectorEnv`` base: the tiled image in ``"rgb_array"``
            mode, ``None`` in ``"human"`` mode. Otherwise, whatever
            ``self.env.render(mode=mode)`` returns.

        Raises:
            NotImplementedError: if the base env is a ``SyncVectorEnv`` and
                ``mode`` is neither ``"human"`` nor ``"rgb_array"``.
        """
        base_env = self.env.unwrapped
        if not isinstance(base_env, SyncVectorEnv):
            return self.env.render(mode=mode)

        # Validate before rendering every sub-env so an unsupported mode fails
        # fast. Note: NotImplementedError is the exception; the builtin
        # NotImplemented is a non-callable sentinel and raising it would
        # surface as a confusing TypeError instead.
        if mode not in ("human", "rgb_array"):
            raise NotImplementedError(f"Render mode {mode} is not supported")

        imgs = [env.render(mode="rgb_array") for env in base_env.envs]
        bigimg = tile_images(imgs)
        if mode == "rgb_array":
            return bigimg

        # mode == "human". Import lazily so cv2 is only required for
        # on-screen display.
        import cv2

        # Reverse the channel axis because OpenCV's imshow expects BGR order
        # (presumably the sub-env frames are RGB; confirm against the envs).
        cv2.imshow("vecenv", bigimg[:, :, ::-1])
        # waitKey(1) lets OpenCV's GUI loop process events so the window
        # actually redraws.
        cv2.waitKey(1)
        return None
|