File size: 1,473 Bytes
e81ed1e
 
32c4441
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
from gym import Env, Space, Wrapper
from stable_baselines3.common.vec_env import VecEnv as SB3VecEnv

# Observation payload of a vectorized env: a single array, a dict of arrays
# keyed by string, or a tuple of arrays.
VecEnvObs = Union[np.ndarray, Dict[str, np.ndarray], Tuple[np.ndarray, ...]]
# Return type of a vectorized ``step``.
# NOTE(review): presumably (observations, rewards, dones, infos) following the
# classic gym 4-tuple convention — confirm against callers.
VecEnvStepReturn = Tuple[VecEnvObs, np.ndarray, np.ndarray, List[Dict]]


class VecotarableWrapper(Wrapper):
    """Gym wrapper that exposes vector-env style attributes on the wrapped env.

    After construction the wrapper always carries ``num_envs``,
    ``is_vector_env``, ``single_observation_space`` and
    ``single_action_space``, whether or not the wrapped env defines them.
    ``step`` and ``reset`` delegate straight to the wrapped env.
    """

    def __init__(self, env: Env) -> None:
        """Wrap ``env`` and record its vectorization metadata.

        Args:
            env: The environment to wrap.
        """
        super().__init__(env)
        # Envs that do not advertise these attributes are treated as a
        # single, non-vectorized environment.
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.single_observation_space = single_observation_space(env)
        self.single_action_space = single_action_space(env)

    def step(self, action) -> VecEnvStepReturn:
        """Forward ``action`` to the wrapped env and return its step result."""
        return self.env.step(action)

    def reset(self, **kwargs) -> VecEnvObs:
        """Reset the wrapped env and return what its ``reset`` returns.

        Keyword arguments are passed through unchanged so this override stays
        call-compatible with ``Wrapper.reset(**kwargs)``; calling with no
        arguments behaves exactly as before.
        """
        return self.env.reset(**kwargs)


# Either flavour of environment accepted by the helpers in this module: this
# file's own wrapper class or a Stable-Baselines3 ``VecEnv``.
VecEnv = Union[VecotarableWrapper, SB3VecEnv]


def single_observation_space(env: Union[VecEnv, Env]) -> Space:
    """Return ``env.single_observation_space``, else ``env.observation_space``.

    ``env.observation_space`` is always read first (it is the fallback value),
    so an env lacking that attribute raises ``AttributeError`` even when it
    defines ``single_observation_space``.
    """
    fallback_space = env.observation_space
    return getattr(env, "single_observation_space", fallback_space)


def single_action_space(env: Union[VecEnv, Env]) -> Space:
    """Return ``env.single_action_space``, else ``env.action_space``.

    ``env.action_space`` is always read first (it is the fallback value), so
    an env lacking that attribute raises ``AttributeError`` even when it
    defines ``single_action_space``.
    """
    fallback_space = env.action_space
    return getattr(env, "single_action_space", fallback_space)


# Generic placeholder for a concrete gym ``Wrapper`` subclass; lets
# ``find_wrapper`` return the same type it was asked to look for.
W = TypeVar("W", bound=Wrapper)


def find_wrapper(env: VecEnv, wrapper_class: Type[W]) -> Optional[W]:
    """Return the outermost wrapper in ``env``'s chain of type ``wrapper_class``.

    The chain is walked from ``env`` inward through each layer's ``env``
    attribute. The walk stops at the innermost environment (the object that is
    its own ``unwrapped``), which is never itself tested.

    Args:
        env: Outermost environment of the wrapper chain to search.
        wrapper_class: Wrapper type to look for (subclasses match too).

    Returns:
        The first layer that is an instance of ``wrapper_class``, or ``None``
        when no layer matches or the chain cannot be followed any further.
    """
    current = env
    # Compare against None and by identity: truthiness or ``!=`` could be
    # hijacked by an env that defines ``__len__``/``__bool__``/``__eq__``.
    while current is not None and getattr(current, "unwrapped", current) is not current:
        if isinstance(current, wrapper_class):
            return current
        # Layers without an ``env`` attribute end the search with ``None``
        # instead of raising AttributeError.
        current = getattr(current, "env", None)
    return None