sgoodfriend's picture
PPO playing QbertNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
a861765
from dataclasses import astuple
from typing import Optional
import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
def make_microrts_env(
config: Config,
hparams: EnvHyperparams,
training: bool = True,
render: bool = False,
normalize_load_path: Optional[str] = None,
tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
import gym_microrts
from gym_microrts import microrts_ai
from rl_algo_impls.shared.vec_env.microrts_compat import (
MicroRTSGridModeSharedMemVecEnvCompat,
MicroRTSGridModeVecEnvCompat,
)
(
_, # env_type
n_envs,
_, # frame_stack
make_kwargs,
_, # no_reward_timeout_steps
_, # no_reward_fire_steps
_, # vec_env_class
_, # normalize
_, # normalize_kwargs,
rolling_length,
_, # train_record_video
_, # video_step_interval
_, # initial_steps_to_truncate
_, # clip_atari_rewards
_, # normalize_type
_, # mask_actions
bots,
self_play_kwargs,
selfplay_bots,
) = astuple(hparams)
seed = config.seed(training=training)
make_kwargs = make_kwargs or {}
self_play_kwargs = self_play_kwargs or {}
if "num_selfplay_envs" not in make_kwargs:
make_kwargs["num_selfplay_envs"] = 0
if "num_bot_envs" not in make_kwargs:
num_selfplay_envs = make_kwargs["num_selfplay_envs"]
if num_selfplay_envs:
num_bot_envs = (
n_envs
- make_kwargs["num_selfplay_envs"]
+ self_play_kwargs.get("num_old_policies", 0)
+ (len(selfplay_bots) if selfplay_bots else 0)
)
else:
num_bot_envs = n_envs
make_kwargs["num_bot_envs"] = num_bot_envs
if "reward_weight" in make_kwargs:
# Reward Weights:
# WinLossRewardFunction
# ResourceGatherRewardFunction
# ProduceWorkerRewardFunction
# ProduceBuildingRewardFunction
# AttackRewardFunction
# ProduceCombatUnitRewardFunction
make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])
if bots:
ai2s = []
for ai_name, n in bots.items():
for _ in range(n):
if len(ai2s) >= make_kwargs["num_bot_envs"]:
break
ai = getattr(microrts_ai, ai_name)
assert ai, f"{ai_name} not in microrts_ai"
ai2s.append(ai)
else:
ai2s = [microrts_ai.randomAI for _ in range(make_kwargs["num_bot_envs"])]
make_kwargs["ai2s"] = ai2s
if len(make_kwargs.get("map_paths", [])) < 2:
EnvClass = MicroRTSGridModeSharedMemVecEnvCompat
else:
EnvClass = MicroRTSGridModeVecEnvCompat
envs = EnvClass(**make_kwargs)
envs = HwcToChwObservation(envs)
envs = IsVectorEnv(envs)
envs = MicrortsMaskWrapper(envs)
if self_play_kwargs:
if selfplay_bots:
self_play_kwargs["selfplay_bots"] = selfplay_bots
envs = SelfPlayWrapper(envs, config, **self_play_kwargs)
if seed is not None:
envs.action_space.seed(seed)
envs.observation_space.seed(seed)
envs = gym.wrappers.RecordEpisodeStatistics(envs)
envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99), bots)
if training:
assert tb_writer
envs = EpisodeStatsWriter(
envs,
tb_writer,
training=training,
rolling_length=rolling_length,
additional_keys_to_log=config.additional_keys_to_log,
)
return envs