PPO playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
20d9758
| import gym | |
| import numpy as np | |
| from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper | |
class VideoCompatWrapper(VecotarableWrapper):
    """Wrapper that normalizes ``rgb_array`` render output to ``uint8``.

    Every render call is forwarded to the wrapped env unchanged. The only
    adjustment is made when ``mode == "rgb_array"`` and the wrapped env
    returned a numpy array whose dtype is not ``uint8``. That frame is cast
    with ``astype(np.uint8)``. Presumably this exists so frames are accepted
    by video recording code (per the class name) — confirm against callers.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

    def render(self, mode="human", **kwargs):
        """Render through the wrapped env, casting non-uint8 rgb frames.

        Args:
            mode: Render mode, forwarded to the wrapped env's ``render``.
            **kwargs: Additional keyword arguments, forwarded unchanged.

        Returns:
            The wrapped env's render result. In ``rgb_array`` mode, a numpy
            array with a dtype other than ``uint8`` is returned as a
            ``uint8`` copy. All other results are returned as-is.
        """
        frame = super().render(mode=mode, **kwargs)
        # NOTE(review): astype does no clipping/rescaling; float frames in
        # [0, 1] or values outside [0, 255] would not map to valid pixels.
        should_cast = (
            mode == "rgb_array"
            and isinstance(frame, np.ndarray)
            and frame.dtype != np.uint8
        )
        if not should_cast:
            return frame
        return frame.astype(np.uint8)