# sgoodfriend's picture
# VPG playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
# 88739bd
import numpy as np
from gym import Env, Space, Wrapper
from stable_baselines3.common.vec_env import VecEnv as SB3VecEnv
from typing import Dict, List, Optional, Type, TypeVar, Tuple, Union
# Observation payload produced by a (vectorized) env: a single array,
# a mapping of str -> array, or a tuple of arrays.
VecEnvObs = Union[
    np.ndarray,
    Dict[str, np.ndarray],
    Tuple[np.ndarray, ...],
]
# Step result: an observation payload, two arrays and a list of info dicts.
# NOTE(review): presumably (obs, rewards, dones, infos), mirroring SB3's
# VecEnvStepReturn — confirm against callers.
VecEnvStepReturn = Tuple[
    VecEnvObs,
    np.ndarray,
    np.ndarray,
    List[Dict],
]
class VecotarableWrapper(Wrapper):
    """Gym ``Wrapper`` that exposes vector-env style attributes on any env.

    On construction it copies ``num_envs`` and ``is_vector_env`` from the
    wrapped env, defaulting to ``1`` and ``False`` when the wrapped env does
    not define them. It also stores the results of this module's
    ``single_observation_space`` / ``single_action_space`` helpers for the
    wrapped env.

    NOTE(review): "Vecotarable" looks like a misspelling of "Vectorable". The
    name is kept as-is because other modules may import it.
    """

    def __init__(self, env: Env) -> None:
        super().__init__(env)
        self.num_envs = getattr(env, "num_envs", 1)
        self.is_vector_env = getattr(env, "is_vector_env", False)
        self.single_observation_space = single_observation_space(env)
        self.single_action_space = single_action_space(env)

    def step(self, action) -> VecEnvStepReturn:
        """Forward ``action`` to the wrapped env and return its result unchanged."""
        return self.env.step(action)

    def reset(self, **kwargs) -> VecEnvObs:
        """Reset the wrapped env and return its result unchanged.

        Keyword arguments (e.g. ``seed``) are forwarded, which keeps this
        override compatible with gym's ``Wrapper.reset(**kwargs)`` signature.
        Calling ``reset()`` with no arguments behaves exactly as
        ``self.env.reset()``.
        """
        return self.env.reset(**kwargs)
# Environment types the helpers below accept as "vectorized": either this
# module's wrapper or a stable-baselines3 VecEnv.
VecEnv = Union[
    VecotarableWrapper,
    SB3VecEnv,
]
def single_observation_space(env: Union[VecEnv, Env]) -> Space:
    """Return the per-environment observation space of ``env``.

    If ``env`` has a ``single_observation_space`` attribute, that is returned.
    Otherwise ``env.observation_space`` is returned.

    The fallback is looked up only when the first attribute is missing.
    Passing it as a ``getattr`` default would evaluate it unconditionally,
    which fails for envs that define only ``single_observation_space``.

    Raises:
        AttributeError: if ``env`` has neither attribute.
    """
    try:
        return env.single_observation_space
    except AttributeError:
        return env.observation_space
def single_action_space(env: Union[VecEnv, Env]) -> Space:
    """Return the per-environment action space of ``env``.

    If ``env`` has a ``single_action_space`` attribute, that is returned.
    Otherwise ``env.action_space`` is returned.

    The fallback is looked up only when the first attribute is missing.
    Passing it as a ``getattr`` default would evaluate it unconditionally,
    which fails for envs that define only ``single_action_space``.

    Raises:
        AttributeError: if ``env`` has neither attribute.
    """
    try:
        return env.single_action_space
    except AttributeError:
        return env.action_space
# Type variable (restricted to gym Wrapper subclasses) that lets
# find_wrapper return the same type it was asked to search for.
W = TypeVar(
    "W",
    bound=Wrapper,
)
def find_wrapper(env: VecEnv, wrapper_class: Type[W]) -> Optional[W]:
    """Return the outermost wrapper around ``env`` that is a ``wrapper_class``.

    The search starts at ``env`` and follows each wrapper's ``env`` attribute
    inward. It stops once it reaches an object that is its own ``unwrapped``;
    that innermost object is not itself checked.

    Args:
        env: Outermost environment to start searching from.
        wrapper_class: Wrapper type to look for. Subclasses also match.

    Returns:
        The first matching wrapper found, or ``None`` if the chain ends
        without a match.
    """
    current = env
    # Identity (not equality) is the intended check: the walk ends at the
    # object that reports itself as its own unwrapped env.
    while current is not None and current is not current.unwrapped:
        if isinstance(current, wrapper_class):
            return current
        current = current.env
    return None