sgoodfriend's picture
DQN playing PongNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b
3fd02ed
raw
history blame contribute delete
No virus
4.25 kB
from dataclasses import astuple
from typing import Optional
import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter
from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv
def make_microrts_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build a wrapped, vectorized MicroRTS environment.

    The base env is ``MicroRTSGridModeSharedMemVecEnvCompat`` when
    ``make_kwargs["map_paths"]`` has fewer than two entries (or is absent),
    otherwise ``MicroRTSGridModeVecEnvCompat``. It is then wrapped, in order,
    by ``HwcToChwObservation``, ``IsVectorEnv``, ``MicrortsMaskWrapper``,
    ``SelfPlayWrapper`` (only when ``self_play_kwargs`` is non-empty),
    ``gym.wrappers.RecordEpisodeStatistics``, ``MicrortsStatsRecorder`` and,
    when ``training`` is true, ``EpisodeStatsWriter``.

    Args:
        config: Supplies the seed (``config.seed(training=training)``), the
            ``gamma`` algo hyperparameter (default 0.99) passed to
            ``MicrortsStatsRecorder``, and ``additional_keys_to_log``.
        hparams: Env hyperparameters. Only ``n_envs``, ``make_kwargs``,
            ``rolling_length``, ``bots``, ``self_play_kwargs`` and
            ``selfplay_bots`` are used here.
        training: Selects the training seed and enables ``EpisodeStatsWriter``.
        render: Accepted for interface compatibility; not used here.
        normalize_load_path: Accepted for interface compatibility; not used
            here.
        tb_writer: TensorBoard writer; required when ``training`` is true.

    Returns:
        The fully wrapped vectorized environment.

    Raises:
        ValueError: If a bot name in ``bots`` is not found on
            ``gym_microrts.microrts_ai``.
        AssertionError: If ``training`` is true and ``tb_writer`` is falsy.
    """
    # Function-scope imports: gym_microrts is only loaded when a MicroRTS env
    # is actually built. NOTE(review): ``gym_microrts`` itself is not
    # referenced below; presumably kept for import side effects — confirm.
    import gym_microrts
    from gym_microrts import microrts_ai

    from rl_algo_impls.shared.vec_env.microrts_compat import (
        MicroRTSGridModeSharedMemVecEnvCompat,
        MicroRTSGridModeVecEnvCompat,
    )

    # Positional unpacking: this must match the field order of EnvHyperparams
    # exactly. astuple rebuilds dict/list fields, so mutating make_kwargs and
    # self_play_kwargs below does not modify hparams.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
        self_play_kwargs,
        selfplay_bots,
    ) = astuple(hparams)
    seed = config.seed(training=training)

    make_kwargs = make_kwargs or {}
    self_play_kwargs = self_play_kwargs or {}

    make_kwargs.setdefault("num_selfplay_envs", 0)
    if "num_bot_envs" not in make_kwargs:
        num_selfplay_envs = make_kwargs["num_selfplay_envs"]
        if num_selfplay_envs:
            # NOTE(review): presumably SelfPlayWrapper fills some self-play
            # slots with old policies / selfplay bots, so those counts are
            # added back to keep n_envs agent-facing envs — confirm against
            # SelfPlayWrapper.
            num_bot_envs = (
                n_envs
                - num_selfplay_envs
                + self_play_kwargs.get("num_old_policies", 0)
                + (len(selfplay_bots) if selfplay_bots else 0)
            )
        else:
            num_bot_envs = n_envs
        make_kwargs["num_bot_envs"] = num_bot_envs
    num_bot_envs = make_kwargs["num_bot_envs"]

    if "reward_weight" in make_kwargs:
        # Reward Weights:
        # WinLossRewardFunction
        # ResourceGatherRewardFunction
        # ProduceWorkerRewardFunction
        # ProduceBuildingRewardFunction
        # AttackRewardFunction
        # ProduceCombatUnitRewardFunction
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])

    if bots:
        # Fill bot slots in ``bots`` order, capped at num_bot_envs. Names past
        # the cap, or with a non-positive count, are never looked up.
        # NOTE(review): if the counts sum to fewer than num_bot_envs, ai2s is
        # shorter than num_bot_envs — confirm EnvClass accepts that.
        ai2s = []
        for ai_name, count in bots.items():
            remaining = num_bot_envs - len(ai2s)
            if remaining <= 0:
                break
            if count <= 0:
                continue
            # Default of None so an unknown name reaches the explicit error
            # below instead of surfacing as a bare AttributeError.
            ai = getattr(microrts_ai, ai_name, None)
            if ai is None:
                raise ValueError(f"{ai_name} not in microrts_ai")
            ai2s.extend([ai] * min(count, remaining))
    else:
        ai2s = [microrts_ai.randomAI] * num_bot_envs
    make_kwargs["ai2s"] = ai2s

    # Zero or one map -> shared-memory env; two or more -> regular vec env.
    # NOTE(review): presumably the shared-memory variant supports only a
    # single map — confirm. ``or []`` also tolerates an explicit None.
    map_paths = make_kwargs.get("map_paths") or []
    if len(map_paths) < 2:
        EnvClass = MicroRTSGridModeSharedMemVecEnvCompat
    else:
        EnvClass = MicroRTSGridModeVecEnvCompat

    envs = EnvClass(**make_kwargs)
    envs = HwcToChwObservation(envs)
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if self_play_kwargs:
        if selfplay_bots:
            self_play_kwargs["selfplay_bots"] = selfplay_bots
        envs = SelfPlayWrapper(envs, config, **self_play_kwargs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99), bots)
    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )
    return envs