import os
from dataclasses import astuple
from typing import Callable, Optional
import gym
from gym.vector.async_vector_env import AsyncVectorEnv
from gym.vector.sync_vector_env import SyncVectorEnv
from gym.wrappers.frame_stack import FrameStack
from gym.wrappers.gray_scale_observation import GrayScaleObservation
from gym.wrappers.resize_observation import ResizeObservation
from stable_baselines3.common.atari_wrappers import MaxAndSkipEnv, NoopResetEnv
from stable_baselines3.common.vec_env.dummy_vec_env import DummyVecEnv
from stable_baselines3.common.vec_env.subproc_vec_env import SubprocVecEnv
from stable_baselines3.common.vec_env.vec_normalize import VecNormalize
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.shared.policy.policy import VEC_NORMALIZE_FILENAME
from rl_algo_impls.shared.vec_env.utils import (
import_for_env_id,
is_atari,
is_bullet_env,
is_car_racing,
is_gym_procgen,
is_microrts,
)
from rl_algo_impls.wrappers.action_mask_wrapper import SingleActionMaskWrapper
from rl_algo_impls.wrappers.atari_wrappers import (
ClipRewardEnv,
EpisodicLifeEnv,
FireOnLifeStarttEnv,
)
from rl_algo_impls.wrappers.episode_record_video import EpisodeRecordVideo
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.initial_step_truncate_wrapper import (
InitialStepTruncateWrapper,
)
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.no_reward_timeout import NoRewardTimeout
from rl_algo_impls.wrappers.noop_env_seed import NoopEnvSeed
from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward
from rl_algo_impls.wrappers.sync_vector_env_render_compat import (
SyncVectorEnvRenderCompat,
)
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
from rl_algo_impls.wrappers.video_compat_wrapper import VideoCompatWrapper


def make_vec_env(
config: Config,
hparams: EnvHyperparams,
training: bool = True,
render: bool = False,
normalize_load_path: Optional[str] = None,
tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
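    """
    Build a vectorized environment from the run Config and EnvHyperparams.

    :param config: run Config providing env_id, seed, and video prefix
    :param hparams: EnvHyperparams controlling env type, wrappers, vectorization, and normalization
    :param training: whether to apply training-only wrappers (episode stats, reward normalization, video recording)
    :param render: whether to enable rendering for envs that accept a render flag
    :param normalize_load_path: directory with saved VecNormalize statistics to restore (sb3 normalization only)
    :param tb_writer: SummaryWriter for episode statistics; required when training
    :return: the fully wrapped vectorized environment
    """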
(
env_type,
n_envs,
frame_stack,
make_kwargs,
no_reward_timeout_steps,
no_reward_fire_steps,
vec_env_class,
normalize,
normalize_kwargs,
rolling_length,
train_record_video,
video_step_interval,
initial_steps_to_truncate,
clip_atari_rewards,
normalize_type,
mask_actions,
_, # bots
_, # self_play_kwargs
_, # selfplay_bots
) = astuple(hparams)
import_for_env_id(config.env_id)
seed = config.seed(training=training)
make_kwargs = make_kwargs.copy() if make_kwargs is not None else {}
if is_bullet_env(config) and render:
make_kwargs["render"] = True
if is_car_racing(config):
make_kwargs["verbose"] = 0
if is_gym_procgen(config) and not render:
make_kwargs["render_mode"] = "rgb_array"
def make(idx: int) -> Callable[[], gym.Env]:
def _make() -> gym.Env:
env = gym.make(config.env_id, **make_kwargs)
env = gym.wrappers.RecordEpisodeStatistics(env)
env = VideoCompatWrapper(env)
if training and train_record_video and idx == 0:
env = EpisodeRecordVideo(
env,
config.video_prefix,
step_increment=n_envs,
video_step_interval=int(video_step_interval),
)
if training and initial_steps_to_truncate:
env = InitialStepTruncateWrapper(
env, idx * initial_steps_to_truncate // n_envs
)
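            # Standard Atari preprocessing: noop resets, frame skip, episodic lives,
            # FIRE on life start, optional reward clipping, 84x84 grayscale, frame stack.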
if is_atari(config): # type: ignore
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)
env = EpisodicLifeEnv(env, training=training)
action_meanings = env.unwrapped.get_action_meanings()
if "FIRE" in action_meanings: # type: ignore
env = FireOnLifeStarttEnv(env, action_meanings.index("FIRE"))
if clip_atari_rewards:
env = ClipRewardEnv(env, training=training)
env = ResizeObservation(env, (84, 84))
env = GrayScaleObservation(env, keep_dim=False)
env = FrameStack(env, frame_stack)
elif is_car_racing(config):
env = ResizeObservation(env, (64, 64))
env = GrayScaleObservation(env, keep_dim=False)
env = FrameStack(env, frame_stack)
elif is_gym_procgen(config):
# env = GrayScaleObservation(env, keep_dim=False)
env = NoopEnvSeed(env)
env = HwcToChwObservation(env)
if frame_stack > 1:
env = FrameStack(env, frame_stack)
elif is_microrts(config):
env = HwcToChwObservation(env)
if no_reward_timeout_steps:
env = NoRewardTimeout(
env, no_reward_timeout_steps, n_fire_steps=no_reward_fire_steps
)
if seed is not None:
env.seed(seed + idx)
env.action_space.seed(seed + idx)
env.observation_space.seed(seed + idx)
return env
return _make
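
    # Select the vectorization backend: Stable-Baselines3 VecEnvs ("sb3vec") or
    # gym's vector envs ("gymvec"), each with a sync or async (subprocess) variant.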
if env_type == "sb3vec":
VecEnvClass = {"sync": DummyVecEnv, "async": SubprocVecEnv}[vec_env_class]
elif env_type == "gymvec":
VecEnvClass = {"sync": SyncVectorEnv, "async": AsyncVectorEnv}[vec_env_class]
else:
raise ValueError(f"env_type {env_type} unsupported")
envs = VecEnvClass([make(i) for i in range(n_envs)])
if env_type == "gymvec" and vec_env_class == "sync":
envs = SyncVectorEnvRenderCompat(envs)
if env_type == "sb3vec":
envs = IsVectorEnv(envs)
if mask_actions:
envs = SingleActionMaskWrapper(envs)
if training:
assert tb_writer
envs = EpisodeStatsWriter(
envs, tb_writer, training=training, rolling_length=rolling_length
)
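
    # Observation/reward normalization: SB3's VecNormalize ("sb3") or the gym-style
    # NormalizeObservation/NormalizeReward wrappers ("gymlike").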
if normalize:
if normalize_type is None:
normalize_type = "sb3" if env_type == "sb3vec" else "gymlike"
normalize_kwargs = normalize_kwargs or {}
if normalize_type == "sb3":
if normalize_load_path:
envs = VecNormalize.load(
os.path.join(normalize_load_path, VEC_NORMALIZE_FILENAME),
envs, # type: ignore
)
else:
envs = VecNormalize(
envs, # type: ignore
training=training,
**normalize_kwargs,
)
if not training:
envs.norm_reward = False
elif normalize_type == "gymlike":
if normalize_kwargs.get("norm_obs", True):
envs = NormalizeObservation(
envs, training=training, clip=normalize_kwargs.get("clip_obs", 10.0)
)
if training and normalize_kwargs.get("norm_reward", True):
envs = NormalizeReward(
envs,
training=training,
clip=normalize_kwargs.get("clip_reward", 10.0),
)
else:
raise ValueError(
f"normalize_type {normalize_type} not supported (sb3 or gymlike)"
)
return envs
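

# --- Hypothetical usage sketch (illustrative, not part of this module) ---
# Shows how a runner might call make_vec_env. The EnvHyperparams keyword names
# below are assumptions inferred from the astuple unpacking above; the Config
# construction is elided because its signature is not defined in this file.
#
#   hparams = EnvHyperparams(env_type="gymvec", n_envs=4, frame_stack=4)
#   tb_writer = SummaryWriter(...)  # required because training=True
#   envs = make_vec_env(config, hparams, training=True, tb_writer=tb_writer)
#   obs = envs.reset()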