File size: 2,526 Bytes
b05d1d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from dataclasses import astuple
from typing import Optional
import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
def make_procgen_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Create a wrapped, vectorized Procgen environment.

    The environment is built as a ``ProcgenGym3Env`` with ``hparams.n_envs``
    copies of ``config.env_id``, reduced to its ``"rgb"`` observation, adapted
    with ``ToBaselinesVecEnv``, and then wrapped with the project's
    ``IsVectorEnv`` and ``HwcToChwObservation`` wrappers plus gym's
    ``RecordEpisodeStatistics``.

    Args:
        config: Supplies ``env_id`` and the seed (``config.seed(training=...)``).
        hparams: Environment hyperparameters. Only ``n_envs``, ``make_kwargs``,
            ``normalize``, ``normalize_kwargs`` and ``rolling_length`` are used
            here; the remaining fields are unpacked and ignored.
        training: When true, episode statistics are written through
            ``tb_writer`` and reward normalization may be applied.
        render: When true, the gym3 ``ViewerWrapper`` is added.
        normalize_load_path: Accepted but not used by this function.
        tb_writer: TensorBoard writer; asserted to be truthy when ``training``.

    Returns:
        The fully wrapped environment.
    """
    # Imported inside the function, presumably so gym3/procgen are only
    # required when this builder is actually called.
    from gym3 import ExtractDictObWrapper, ViewerWrapper
    from procgen.env import ProcgenGym3Env, ToBaselinesVecEnv

    # astuple() returns copies of the dataclass fields, so editing the
    # unpacked make_kwargs below does not modify `hparams`.
    (
        _env_type,
        n_envs,
        _frame_stack,
        make_kwargs,
        _no_reward_timeout_steps,
        _no_reward_fire_steps,
        _vec_env_class,
        normalize,
        normalize_kwargs,
        rolling_length,
        _train_record_video,
        _video_step_interval,
        _initial_steps_to_truncate,
        _clip_atari_rewards,
        _normalize_type,
        _mask_actions,
        _bots,
    ) = astuple(hparams)

    rng_seed = config.seed(training=training)
    has_seed = rng_seed is not None

    env_kwargs = make_kwargs if make_kwargs else {}
    env_kwargs["render_mode"] = "rgb_array"
    if has_seed:
        env_kwargs["rand_seed"] = rng_seed

    venv = ExtractDictObWrapper(
        ProcgenGym3Env(n_envs, config.env_id, **env_kwargs),
        key="rgb",
    )
    if render:
        venv = ViewerWrapper(venv, info_key="rgb")
    venv = IsVectorEnv(ToBaselinesVecEnv(venv))
    # TODO: Handle Grayscale and/or FrameStack
    venv = HwcToChwObservation(venv)
    venv = gym.wrappers.RecordEpisodeStatistics(venv)

    if has_seed:
        venv.action_space.seed(rng_seed)
        venv.observation_space.seed(rng_seed)

    if training:
        assert tb_writer
        venv = EpisodeStatsWriter(
            venv,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
        )

    # Only rewards are normalized here, and only while training.
    if normalize and training:
        # The clipping bound comes from the "clip_reward" key (default 10.0).
        reward_limit = (normalize_kwargs or {}).get("clip_reward", 10.0)
        venv = gym.wrappers.NormalizeReward(venv)

        def _clip_reward(reward):
            return np.clip(reward, -reward_limit, reward_limit)

        venv = gym.wrappers.TransformReward(venv, _clip_reward)

    return venv  # type: ignore
|