# File size: 618 Bytes
# Revision: be9c115
import gym
import gym.spaces.dict
import numpy as np

from gym import ObservationWrapper


class GetRgbObservation(ObservationWrapper):
    """Reduce a Dict observation to its ``"rgb"`` entry.

    The wrapped environment must expose a ``gym.spaces.dict.Dict`` observation
    space containing an ``"rgb"`` key. This wrapper's ``observation_space``
    becomes that ``"rgb"`` sub-space, and every observation returned by the
    environment is replaced by its ``"rgb"`` value.

    If the wrapped environment reports a truthy ``is_vector_env`` attribute,
    ``single_observation_space`` is narrowed to its ``"rgb"`` sub-space as well.
    Environments without that attribute are treated as non-vectorized.
    """

    def __init__(self, env) -> None:
        """Wrap ``env`` and narrow its observation space(s) to ``"rgb"``.

        Args:
            env: Environment whose ``observation_space`` is a Dict space with
                an ``"rgb"`` key.

        Raises:
            AssertionError: If ``env.observation_space`` is not a Dict space.
            KeyError: If the Dict space has no ``"rgb"`` entry.
        """
        super().__init__(env)
        assert isinstance(env.observation_space, gym.spaces.dict.Dict), (
            "GetRgbObservation requires a Dict observation space, got "
            f"{type(env.observation_space).__name__}"
        )

        self.observation_space = env.observation_space["rgb"]  # type: ignore
        # Plain (non-vector) envs usually don't define `is_vector_env` at all;
        # default to False instead of raising AttributeError.
        if getattr(env, "is_vector_env", False):
            self.single_observation_space = env.single_observation_space["rgb"]  # type: ignore

    def observation(self, observation: dict) -> np.ndarray:
        """Return the ``"rgb"`` component of a dict-like observation."""
        return observation["rgb"]  # type: ignore