File size: 380 Bytes
be9c115
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
import gym
import numpy as np


class VideoCompatWrapper(gym.Wrapper):
    """Gym wrapper whose ``render`` returns ``uint8`` frames in ``rgb_array`` mode.

    Every other render mode, and any non-ndarray result, is passed through
    from the wrapped environment untouched.

    NOTE(review): presumably exists so video writers that require 8-bit
    frames can consume the output -- confirm against the callers.
    """

    def __init__(self, env: gym.Env) -> None:
        """Wrap *env*; no additional state is kept."""
        super().__init__(env)

    def render(self, mode="human", **kwargs):
        """Render via the wrapped env, casting ``rgb_array`` frames to uint8.

        Args:
            mode: Render mode forwarded to the wrapped environment.
            **kwargs: Extra keyword arguments forwarded unchanged.

        Returns:
            Whatever the wrapped ``render`` returned, except that a NumPy
            array produced in ``"rgb_array"`` mode with a dtype other than
            ``uint8`` is replaced by a ``uint8`` copy.
        """
        frame = super().render(mode=mode, **kwargs)

        # Only array frames requested in rgb_array mode are candidates.
        if mode != "rgb_array" or not isinstance(frame, np.ndarray):
            return frame

        # Already 8-bit: hand back the very same object, no copy.
        if frame.dtype == np.uint8:
            return frame

        # Plain dtype cast: values are neither rescaled nor clipped here.
        return frame.astype(np.uint8)