from dataclasses import astuple
from typing import Optional

import gym
import numpy as np
from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.wrappers.action_mask_wrapper import MicrortsMaskWrapper
from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter
from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation
from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv
from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv


def _resolve_bot_ais(microrts_ai, bots, num_bot_envs):
    """Expand a ``{ai_name: count}`` mapping into a list of ``microrts_ai`` attributes.

    Each name is repeated ``count`` times, in mapping order, and the result is
    capped at ``num_bot_envs`` entries. Names past the cap are never looked up.

    Raises:
        AttributeError: if a requested name is not an attribute of ``microrts_ai``.
    """
    ai2s = []
    for ai_name, count in bots.items():
        for _ in range(count):
            if len(ai2s) >= num_bot_envs:
                # Enough opponents collected; ignore the rest of the mapping.
                return ai2s
            # Look up with a default so the descriptive error below is
            # reachable; a bare getattr would raise before any check ran.
            ai = getattr(microrts_ai, ai_name, None)
            if ai is None:
                raise AttributeError(f"{ai_name} not in microrts_ai")
            ai2s.append(ai)
    return ai2s


def make_microrts_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Build the vectorized MicroRTS environment together with its wrapper stack.

    The base ``MicroRTSGridModeVecEnvCompat`` is wrapped, in order, by
    ``HwcToChwObservation``, ``IsVectorEnv``, ``MicrortsMaskWrapper``,
    ``gym.wrappers.RecordEpisodeStatistics`` and ``MicrortsStatsRecorder``.
    When ``training`` is true, ``EpisodeStatsWriter`` is added last.

    Before construction, ``make_kwargs`` receives these defaults:
    ``num_selfplay_envs`` = 0, and ``num_bot_envs`` = ``n_envs - num_selfplay_envs``.
    ``reward_weight`` (if present) is converted to an ``np.ndarray``.
    ``ai2s`` is built from ``bots``, or filled with ``microrts_ai.randomAI``
    when no bots are configured.

    Args:
        config: Run configuration. Supplies the seed (``config.seed``),
            ``algo_hyperparams["gamma"]`` (default 0.99) for
            ``MicrortsStatsRecorder``, and ``additional_keys_to_log``.
        hparams: Environment hyperparameters. Only ``n_envs``,
            ``make_kwargs``, ``rolling_length`` and ``bots`` are used here.
        training: Selects the training vs. evaluation seed. When true, also
            adds ``EpisodeStatsWriter``.
        render: Not used by this function.
        normalize_load_path: Not used by this function.
        tb_writer: TensorBoard writer; required when ``training`` is true.

    Returns:
        The fully wrapped vectorized environment.

    Raises:
        AttributeError: If a name in ``bots`` is not found in
            ``gym_microrts.microrts_ai``.
        AssertionError: If ``training`` is true and ``tb_writer`` is missing.
    """
    # NOTE(review): gym_microrts is otherwise unused here; presumably imported
    # for its import-time side effects — confirm before removing.
    import gym_microrts
    from gym_microrts import microrts_ai

    from rl_algo_impls.shared.vec_env.microrts_compat import (
        MicroRTSGridModeVecEnvCompat,
    )

    # Positional unpack: this tuple must list every EnvHyperparams field, in
    # declaration order, or the assignment fails / binds the wrong values.
    (
        _,  # env_type
        n_envs,
        _,  # frame_stack
        make_kwargs,
        _,  # no_reward_timeout_steps
        _,  # no_reward_fire_steps
        _,  # vec_env_class
        _,  # normalize
        _,  # normalize_kwargs,
        rolling_length,
        _,  # train_record_video
        _,  # video_step_interval
        _,  # initial_steps_to_truncate
        _,  # clip_atari_rewards
        _,  # normalize_type
        _,  # mask_actions
        bots,
    ) = astuple(hparams)

    seed = config.seed(training=training)

    # Copy so that filling in defaults below never mutates caller-owned state.
    make_kwargs = dict(make_kwargs or {})
    make_kwargs.setdefault("num_selfplay_envs", 0)
    make_kwargs.setdefault("num_bot_envs", n_envs - make_kwargs["num_selfplay_envs"])
    if "reward_weight" in make_kwargs:
        make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"])

    num_bot_envs = make_kwargs["num_bot_envs"]
    if bots:
        # NOTE(review): if bots supply fewer than num_bot_envs entries, ai2s is
        # shorter than num_bot_envs — confirm the vec env rejects or handles it.
        ai2s = _resolve_bot_ais(microrts_ai, bots, num_bot_envs)
    else:
        # num_bot_envs is an int count, so it must go through range().
        ai2s = [microrts_ai.randomAI for _ in range(num_bot_envs)]
    make_kwargs["ai2s"] = ai2s

    envs = MicroRTSGridModeVecEnvCompat(**make_kwargs)
    envs = HwcToChwObservation(envs)
    envs = IsVectorEnv(envs)
    envs = MicrortsMaskWrapper(envs)

    if seed is not None:
        envs.action_space.seed(seed)
        envs.observation_space.seed(seed)

    envs = gym.wrappers.RecordEpisodeStatistics(envs)
    envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99))
    if training:
        assert tb_writer
        envs = EpisodeStatsWriter(
            envs,
            tb_writer,
            training=training,
            rolling_length=rolling_length,
            additional_keys_to_log=config.additional_keys_to_log,
        )

    return envs