diff --git a/README.md b/README.md index 28e86d387f67c0d375a7d0305357c6bd285f212e..ef7e51198c551e67656d29cb63a72d3944572221 100644 --- a/README.md +++ b/README.md @@ -23,17 +23,17 @@ model-index: This is a trained model of a **VPG** agent playing **HalfCheetahBulletEnv-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo. -All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/7lx79bf0. +All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/ysd5gj7p. ## Training Results -This model was trained from 3 trainings of **VPG** agents using different initial seeds. These agents were trained by checking out [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std). +This model was trained from 3 trainings of **VPG** agents using different initial seeds. These agents were trained by checking out [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std). | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url | |:-------|:------------------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------| -| vpg | HalfCheetahBulletEnv-v0 | 1 | 1783.45 | 39.4267 | 10 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/0nbnf5de) | -| vpg | HalfCheetahBulletEnv-v0 | 2 | 1356.92 | 61.5334 | 10 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/1b644sa4) | -| vpg | HalfCheetahBulletEnv-v0 | 3 | 1012.63 | 1125.66 | 10 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/gpb5qxmi) | +| vpg | HalfCheetahBulletEnv-v0 | 1 | 1783.45 | 39.4267 | 10 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/bth9z4or) | +| vpg | HalfCheetahBulletEnv-v0 | 2 | 1356.92 | 61.5334 | 10 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/aick6t7d) | +| vpg | HalfCheetahBulletEnv-v0 | 3 | 1012.63 | 1125.66 | 10 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/6vlakoui) | ### Prerequisites: Weights & Biases (WandB) @@ -53,10 +53,10 @@ login`. Note: While the model state dictionary and hyperaparameters are saved, the latest implementation could be sufficiently different to not be able to reproduce similar results. You might need to checkout the commit the agent was trained on: -[0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). +[983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). ``` # Downloads the model, sets hyperparameters, and runs agent for 3 episodes -python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/0nbnf5de +python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/bth9z4or ``` Setup hasn't been completely worked out yet, so you might be best served by using Google @@ -68,7 +68,7 @@ notebook. 
## Training If you want the highest chance to reproduce these results, you'll want to checkout the -commit the agent was trained on: [0511de3](https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c). While +commit the agent was trained on: [983cb75](https://github.com/sgoodfriend/rl-algo-impls/tree/983cb75e43e51cf4ef57f177194ab9a4a1a8808b). While training is deterministic, different hardware will give different results. ``` @@ -83,7 +83,7 @@ notebook. ## Benchmarking (with Lambda Labs instance) -This and other models from https://api.wandb.ai/links/sgoodfriend/7lx79bf0 were generated by running a script on a Lambda +This and other models from https://api.wandb.ai/links/sgoodfriend/ysd5gj7p were generated by running a script on a Lambda Labs instance. In a Lambda Labs instance terminal: ``` git clone git@github.com:sgoodfriend/rl-algo-impls.git @@ -120,7 +120,8 @@ env: HalfCheetahBulletEnv-v0 env_hyperparams: normalize: true env_id: null -eval_params: {} +eval_hyperparams: {} +microrts_reward_decay_callback: false n_timesteps: 2000000 policy_hyperparams: hidden_sizes: @@ -132,9 +133,9 @@ wandb_entity: null wandb_group: null wandb_project_name: rl-algo-impls-benchmarks wandb_tags: -- benchmark_0511de3 -- host_152-67-249-42 +- benchmark_983cb75 +- host_129-159-43-75 - branch_main -- v0.0.8 +- v0.0.9 ``` diff --git a/environment.yml b/environment.yml index a5f2efb65d96b38ee9bdba44fb9b5dcd26857a72..79a1d26cd61bcc17c4ba8eff22712d8e1c278145 100644 --- a/environment.yml +++ b/environment.yml @@ -4,7 +4,7 @@ channels: - conda-forge - nodefaults dependencies: - - python>=3.8, <3.11 + - python>=3.8, <3.10 - mamba - pip - pytorch diff --git a/pyproject.toml b/pyproject.toml index dcfbed2c67b57c57c58f907284937021a4d716d7..fc2d58188564db3f68eaff9dd8479802ff4d3399 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "rl_algo_impls" -version = "0.0.8" +version = "0.0.9" description = "Implementations of reinforcement learning algorithms" authors = [ {name = "Scott Goodfriend", email = "goodfriend.scott@gmail.com"}, @@ -56,14 +56,17 @@ procgen = [ "glfw >= 1.12.0, < 1.13", "procgen; platform_machine=='x86_64'", ] -microrts-old = [ +microrts-ppo = [ "numpy < 1.24.0", # Support for gym-microrts < 0.6.0 "gym-microrts == 0.2.0", # Match ppo-implementation-details ] -microrts = [ +microrts-paper = [ "numpy < 1.24.0", # Support for gym-microrts < 0.6.0 "gym-microrts == 0.3.2", ] +microrts = [ + "gym-microrts", +] jupyter = [ "jupyter", "notebook" diff --git a/replay.meta.json b/replay.meta.json index 29f16aa8ccc8771c8c19006934fe34a228ffec46..fa6e522c665745c83881756c212f3184dcde6be2 100644 --- a/replay.meta.json +++ b/replay.meta.json @@ -1 +1 @@ -{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa 
--enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmpee7w32ac/vpg-HalfCheetahBulletEnv-v0/replay.mp4"]}, "episode": {"r": 1776.6221923828125, "l": 1000, "t": 28.162648}} \ No newline at end of file +{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 4.2.7-0ubuntu0.1 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with gcc 9 (Ubuntu 9.4.0-1ubuntu1~20.04.1)\\nconfiguration: --prefix=/usr --extra-version=0ubuntu0.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-avresample --disable-filter=resample --enable-avisynth --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librsvg --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwavpack --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-nvenc --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared\\nlibavutil 56. 31.100 / 56. 31.100\\nlibavcodec 58. 54.100 / 58. 54.100\\nlibavformat 58. 29.100 / 58. 29.100\\nlibavdevice 58. 8.100 / 58. 8.100\\nlibavfilter 7. 57.100 / 7. 57.100\\nlibavresample 4. 0. 0 / 4. 0. 0\\nlibswscale 5. 5.100 / 5. 5.100\\nlibswresample 3. 5.100 / 3. 5.100\\nlibpostproc 55. 5.100 / 55. 
5.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "320x240", "-pix_fmt", "rgb24", "-framerate", "60", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "60", "/tmp/tmpdbf1h42i/vpg-HalfCheetahBulletEnv-v0/replay.mp4"]}, "episodes": [{"r": 1776.6221923828125, "l": 1000, "t": 28.189788}]} \ No newline at end of file diff --git a/rl_algo_impls/a2c/a2c.py b/rl_algo_impls/a2c/a2c.py index 18e77c5845c9fd6149611ed9bc15cc451b1e3cd9..6a1bb842d0035ebdccb23ce66f5e3a9b53cd3010 100644 --- a/rl_algo_impls/a2c/a2c.py +++ b/rl_algo_impls/a2c/a2c.py @@ -1,23 +1,23 @@ import logging +from time import perf_counter +from typing import List, Optional, TypeVar + import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -from time import perf_counter from torch.utils.tensorboard.writer import SummaryWriter -from typing import Optional, TypeVar from rl_algo_impls.shared.algorithm import Algorithm -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.gae import compute_advantages -from rl_algo_impls.shared.policy.on_policy import ActorCritic +from rl_algo_impls.shared.policy.actor_critic import ActorCritic from rl_algo_impls.shared.schedule import schedule, update_learning_rate from rl_algo_impls.shared.stats import log_scalars from rl_algo_impls.wrappers.vectorable_wrapper import ( VecEnv, - single_observation_space, single_action_space, + single_observation_space, ) A2CSelf = TypeVar("A2CSelf", bound="A2C") @@ -70,7 +70,7 @@ class A2C(Algorithm): def learn( self: A2CSelf, train_timesteps: int, - callback: Optional[Callback] = None, + callbacks: Optional[List[Callback]] = None, total_timesteps: Optional[int] = None, start_timesteps: int = 0, ) -> A2CSelf: @@ -193,8 +193,10 @@ class A2C(Algorithm): timesteps_elapsed, ) - if callback: - if not callback.on_step(timesteps_elapsed=rollout_steps): + if callbacks: + if not all( + c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks + ): logging.info( f"Callback terminated training at {timesteps_elapsed} timesteps" ) diff --git a/rl_algo_impls/a2c/optimize.py b/rl_algo_impls/a2c/optimize.py index cd3cb807f8de22634dab26b1d525a484b32ae7d5..0a65f09359261de042a581e197c8dc763b5814c8 100644 --- a/rl_algo_impls/a2c/optimize.py +++ b/rl_algo_impls/a2c/optimize.py @@ -1,10 +1,10 @@ -import optuna - from copy import deepcopy -from rl_algo_impls.runner.config import Config, Hyperparams, EnvHyperparams -from rl_algo_impls.shared.vec_env import make_eval_env +import optuna + +from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams from rl_algo_impls.shared.policy.optimize_on_policy import sample_on_policy_hyperparams +from rl_algo_impls.shared.vec_env import make_eval_env from rl_algo_impls.tuning.optimize_env import sample_env_hyperparams @@ -16,7 +16,11 @@ def sample_params( hyperparams = deepcopy(base_hyperparams) base_env_hyperparams = EnvHyperparams(**hyperparams.env_hyperparams) - env = make_eval_env(base_config, base_env_hyperparams, override_n_envs=1) + env = make_eval_env( + base_config, + base_env_hyperparams, + override_hparams={"n_envs": 1}, + ) # env_hyperparams env_hyperparams = sample_env_hyperparams(trial, hyperparams.env_hyperparams, env) diff --git a/rl_algo_impls/dqn/dqn.py b/rl_algo_impls/dqn/dqn.py index 57cd3e074444352d003d6f60ca95a5add79467b4..1dafcda9f01661d61ea20ed392cdeebeff589c34 100644 --- 
a/rl_algo_impls/dqn/dqn.py +++ b/rl_algo_impls/dqn/dqn.py @@ -1,18 +1,19 @@ import copy -import numpy as np +import logging import random +from collections import deque +from typing import List, NamedTuple, Optional, TypeVar + +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F - -from collections import deque from torch.optim import Adam from torch.utils.tensorboard.writer import SummaryWriter -from typing import NamedTuple, Optional, TypeVar from rl_algo_impls.dqn.policy import DQNPolicy from rl_algo_impls.shared.algorithm import Algorithm -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.schedule import linear_schedule from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs @@ -118,7 +119,7 @@ class DQN(Algorithm): self.max_grad_norm = max_grad_norm def learn( - self: DQNSelf, total_timesteps: int, callback: Optional[Callback] = None + self: DQNSelf, total_timesteps: int, callbacks: Optional[List[Callback]] = None ) -> DQNSelf: self.policy.train(True) obs = self.env.reset() @@ -140,8 +141,14 @@ class DQN(Algorithm): if steps_since_target_update >= self.target_update_interval: self._update_target() steps_since_target_update = 0 - if callback: - callback.on_step(timesteps_elapsed=rollout_steps) + if callbacks: + if not all( + c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks + ): + logging.info( + f"Callback terminated training at {timesteps_elapsed} timesteps" + ) + break return self def train(self) -> None: diff --git a/rl_algo_impls/dqn/q_net.py b/rl_algo_impls/dqn/q_net.py index 4b2f556483673c5a428c6820f156cc63fce6a3f6..a32036b196e548d179e1941a84ef798ea9d2c0df 100644 --- a/rl_algo_impls/dqn/q_net.py +++ b/rl_algo_impls/dqn/q_net.py @@ -6,7 +6,7 @@ import torch.nn as nn from gym.spaces import Discrete from rl_algo_impls.shared.encoder import Encoder -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.module.utils import mlp class QNetwork(nn.Module): diff --git a/rl_algo_impls/huggingface_publish.py b/rl_algo_impls/huggingface_publish.py index c89a4eecde5d7043c43477312b9dc743a591f126..00f6a23ec6a68b5f0014ffc9d0f367f8deeeac5a 100644 --- a/rl_algo_impls/huggingface_publish.py +++ b/rl_algo_impls/huggingface_publish.py @@ -3,24 +3,23 @@ import os os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import argparse -import requests import shutil import subprocess import tempfile -import wandb -import wandb.apis.public - from typing import List, Optional +import requests +import wandb.apis.public from huggingface_hub.hf_api import HfApi, upload_folder from huggingface_hub.repocard import metadata_save from pyvirtualdisplay.display import Display +import wandb from rl_algo_impls.publish.markdown_format import EvalTableData, model_card_text from rl_algo_impls.runner.config import EnvHyperparams from rl_algo_impls.runner.evaluate import EvalArgs, evaluate_model -from rl_algo_impls.shared.vec_env import make_eval_env from rl_algo_impls.shared.callbacks.eval_callback import evaluate +from rl_algo_impls.shared.vec_env import make_eval_env from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder @@ -134,7 +133,7 @@ def publish( make_eval_env( config, EnvHyperparams(**config.env_hyperparams), - override_n_envs=1, + override_hparams={"n_envs": 1}, normalize_load_path=model_path, ), os.path.join(repo_dir_path, "replay"), @@ -144,7 +143,7 @@ def publish( video_env, policy, 1, - 
deterministic=config.eval_params.get("deterministic", True), + deterministic=config.eval_hyperparams.get("deterministic", True), ) api = HfApi() diff --git a/rl_algo_impls/hyperparams/a2c.yml b/rl_algo_impls/hyperparams/a2c.yml index a15f29300f71cfb7cf21419d444fc7df15c76092..33170bbf41ad1a7575ffd67cb78cd494a2aa19f7 100644 --- a/rl_algo_impls/hyperparams/a2c.yml +++ b/rl_algo_impls/hyperparams/a2c.yml @@ -101,31 +101,32 @@ HopperBulletEnv-v0: CarRacing-v0: n_timesteps: !!float 4e6 env_hyperparams: - n_envs: 16 + n_envs: 4 frame_stack: 4 normalize: true normalize_kwargs: norm_obs: false norm_reward: true policy_hyperparams: - use_sde: false - log_std_init: -1.3502584927786276 + use_sde: true + log_std_init: -4.839609092563 init_layers_orthogonal: true activation_fn: tanh share_features_extractor: false cnn_flatten_dim: 256 hidden_sizes: [256] algo_hyperparams: - n_steps: 16 - learning_rate: 0.000025630993245026736 - learning_rate_decay: linear - gamma: 0.99957617037542 - gae_lambda: 0.949455676599436 - ent_coef: !!float 1.707983205298309e-7 - vf_coef: 0.10428178193833336 - max_grad_norm: 0.5406643389792273 - normalize_advantage: true + n_steps: 64 + learning_rate: 0.000018971962220405576 + gamma: 0.9942776405534832 + gae_lambda: 0.9549244758833236 + ent_coef: 0.0000015666550584860516 + ent_coef_decay: linear + vf_coef: 0.12164696385898476 + max_grad_norm: 2.2574480552177127 + normalize_advantage: false use_rms_prop: false + sde_sample_freq: 16 _atari: &atari-defaults n_timesteps: !!float 1e7 diff --git a/rl_algo_impls/hyperparams/dqn.yml b/rl_algo_impls/hyperparams/dqn.yml index 4274deaf842e186873808dfb1f3da8da4d36440e..fbe8ab54cbc29d2d698081e6d5fca8563db7b085 100644 --- a/rl_algo_impls/hyperparams/dqn.yml +++ b/rl_algo_impls/hyperparams/dqn.yml @@ -15,7 +15,7 @@ CartPole-v1: &cartpole-defaults gradient_steps: 128 exploration_fraction: 0.16 exploration_final_eps: 0.04 - eval_params: + eval_hyperparams: step_freq: !!float 1e4 CartPole-v0: @@ -76,7 +76,7 @@ LunarLander-v2: exploration_fraction: 0.12 exploration_final_eps: 0.1 max_grad_norm: 0.5 - eval_params: + eval_hyperparams: step_freq: 25_000 _atari: &atari-defaults @@ -97,7 +97,7 @@ _atari: &atari-defaults gradient_steps: 2 exploration_fraction: 0.1 exploration_final_eps: 0.01 - eval_params: + eval_hyperparams: deterministic: false PongNoFrameskip-v4: diff --git a/rl_algo_impls/hyperparams/ppo.yml b/rl_algo_impls/hyperparams/ppo.yml index ec533e646a89540a958f8e57225f942ba6db875a..9d7f34943b7ea24d800b5c8f158724851e955003 100644 --- a/rl_algo_impls/hyperparams/ppo.yml +++ b/rl_algo_impls/hyperparams/ppo.yml @@ -13,7 +13,7 @@ CartPole-v1: &cartpole-defaults learning_rate_decay: linear clip_range: 0.2 clip_range_decay: linear - eval_params: + eval_hyperparams: step_freq: !!float 2.5e4 CartPole-v0: @@ -52,7 +52,7 @@ MountainCarContinuous-v0: gae_lambda: 0.9 max_grad_norm: 5 vf_coef: 0.19 - eval_params: + eval_hyperparams: step_freq: 5000 Acrobot-v1: @@ -162,7 +162,7 @@ _atari: &atari-defaults clip_range_decay: linear vf_coef: 0.5 ent_coef: 0.01 - eval_params: + eval_hyperparams: deterministic: false _norm-rewards-atari: &norm-rewards-atari-default @@ -228,7 +228,7 @@ _microrts: &microrts-defaults clip_range_decay: none clip_range_vf: 0.1 ppo2_vf_coef_halving: true - eval_params: + eval_hyperparams: &microrts-eval-defaults deterministic: false # Good idea because MultiCategorical mode isn't great _no-mask-microrts: &no-mask-microrts-defaults @@ -252,15 +252,15 @@ MicrortsRandomEnemyShapedReward3-v1-NoMask: _microrts_ai: &microrts-ai-defaults <<:
*microrts-defaults n_timesteps: !!float 100e6 - additional_keys_to_log: ["microrts_stats"] + additional_keys_to_log: ["microrts_stats", "microrts_results"] env_hyperparams: &microrts-ai-env-defaults n_envs: 24 env_type: microrts - make_kwargs: + make_kwargs: &microrts-ai-env-make-kwargs-defaults num_selfplay_envs: 0 - max_steps: 2000 + max_steps: 4000 render_theme: 2 - map_path: maps/16x16/basesWorkers16x16.xml + map_paths: [maps/16x16/basesWorkers16x16.xml] reward_weight: [10.0, 1.0, 1.0, 0.2, 1.0, 4.0] policy_hyperparams: &microrts-ai-policy-defaults <<: *microrts-policy-defaults @@ -278,6 +278,15 @@ _microrts_ai: &microrts-ai-defaults max_grad_norm: 0.5 clip_range: 0.1 clip_range_vf: 0.1 + eval_hyperparams: &microrts-ai-eval-defaults + <<: *microrts-eval-defaults + score_function: mean + max_video_length: 4000 + env_overrides: &microrts-ai-eval-env-overrides + make_kwargs: + <<: *microrts-ai-env-make-kwargs-defaults + max_steps: 4000 + reward_weight: [1.0, 0, 0, 0, 0, 0] MicrortsAttackPassiveEnemySparseReward-v3: <<: *microrts-ai-defaults @@ -305,6 +314,18 @@ enc-dec-MicrortsDefeatRandomEnemySparseReward-v3: actor_head_style: gridnet_decoder v_hidden_sizes: [128] +unet-MicrortsDefeatRandomEnemySparseReward-v3: + <<: *microrts-random-ai-defaults + # device: cpu + policy_hyperparams: + <<: *microrts-ai-policy-defaults + actor_head_style: unet + v_hidden_sizes: [256, 128] + algo_hyperparams: + <<: *microrts-ai-algo-defaults + learning_rate: !!float 2.5e-4 + learning_rate_decay: spike + MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults <<: *microrts-ai-defaults env_id: MicrortsDefeatCoacAIShaped-v3 # Workaround to keep model name simple @@ -313,6 +334,27 @@ MicrortsDefeatCoacAIShaped-v3: &microrts-coacai-defaults <<: *microrts-ai-env-defaults bots: coacAI: 24 + eval_hyperparams: &microrts-coacai-eval-defaults + <<: *microrts-ai-eval-defaults + step_freq: !!float 1e6 + n_episodes: 26 + env_overrides: &microrts-coacai-eval-env-overrides + <<: *microrts-ai-eval-env-overrides + n_envs: 26 + bots: + coacAI: 2 + randomBiasedAI: 2 + randomAI: 2 + passiveAI: 2 + workerRushAI: 2 + lightRushAI: 2 + naiveMCTSAI: 2 + mixedBot: 2 + rojo: 2 + izanagi: 2 + tiamat: 2 + droplet: 2 + guidedRojoA3N: 2 MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults <<: *microrts-coacai-defaults @@ -325,6 +367,7 @@ MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-diverse-defaults workerRushAI: 2 enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots: + &microrts-env-dec-diverse-defaults <<: *microrts-diverse-defaults policy_hyperparams: <<: *microrts-ai-policy-defaults @@ -332,6 +375,76 @@ enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots: actor_head_style: gridnet_decoder v_hidden_sizes: [128] +debug-enc-dec-MicrortsDefeatCoacAIShaped-v3-diverseBots: + <<: *microrts-env-dec-diverse-defaults + n_timesteps: !!float 1e6 + +unet-MicrortsDefeatCoacAIShaped-v3-diverseBots: &microrts-unet-defaults + <<: *microrts-diverse-defaults + policy_hyperparams: + <<: *microrts-ai-policy-defaults + actor_head_style: unet + v_hidden_sizes: [256, 128] + algo_hyperparams: &microrts-unet-algo-defaults + <<: *microrts-ai-algo-defaults + learning_rate: !!float 2.5e-4 + learning_rate_decay: spike + +Microrts-selfplay-unet: &microrts-selfplay-defaults + <<: *microrts-unet-defaults + env_hyperparams: &microrts-selfplay-env-defaults + <<: *microrts-ai-env-defaults + make_kwargs: &microrts-selfplay-env-make-kwargs-defaults + <<: *microrts-ai-env-make-kwargs-defaults + num_selfplay_envs: 36 + self_play_kwargs: + num_old_policies: 12 + save_steps: 300000 + swap_steps: 6000 + swap_window_size: 4 + window: 33 + 
eval_hyperparams: &microrts-selfplay-eval-defaults + <<: *microrts-coacai-eval-defaults + env_overrides: &microrts-selfplay-eval-env-overrides + <<: *microrts-coacai-eval-env-overrides + self_play_kwargs: {} + +Microrts-selfplay-unet-winloss: &microrts-selfplay-winloss-defaults + <<: *microrts-selfplay-defaults + env_hyperparams: + <<: *microrts-selfplay-env-defaults + make_kwargs: + <<: *microrts-selfplay-env-make-kwargs-defaults + reward_weight: [1.0, 0, 0, 0, 0, 0] + algo_hyperparams: &microrts-selfplay-winloss-algo-defaults + <<: *microrts-unet-algo-defaults + gamma: 0.999 + +Microrts-selfplay-unet-decay: &microrts-selfplay-decay-defaults + <<: *microrts-selfplay-defaults + microrts_reward_decay_callback: true + algo_hyperparams: + <<: *microrts-unet-algo-defaults + gamma_end: 0.999 + +Microrts-selfplay-unet-debug: &microrts-selfplay-debug-defaults + <<: *microrts-selfplay-decay-defaults + eval_hyperparams: + <<: *microrts-selfplay-eval-defaults + step_freq: !!float 1e5 + env_overrides: + <<: *microrts-selfplay-eval-env-overrides + n_envs: 24 + bots: + coacAI: 12 + randomBiasedAI: 4 + workerRushAI: 4 + lightRushAI: 4 + +Microrts-selfplay-unet-debug-mps: + <<: *microrts-selfplay-debug-defaults + device: mps + HalfCheetahBulletEnv-v0: &pybullet-defaults n_timesteps: !!float 2e6 env_hyperparams: &pybullet-env-defaults @@ -418,7 +531,7 @@ _procgen: &procgen-defaults learning_rate: !!float 5e-4 # learning_rate_decay: linear vf_coef: 0.5 - eval_params: &procgen-eval-defaults + eval_hyperparams: &procgen-eval-defaults ignore_first_episode: true # deterministic: false step_freq: !!float 1e5 @@ -466,7 +579,7 @@ _procgen-hard: &procgen-hard-defaults batch_size: 8192 clip_range_decay: linear learning_rate_decay: linear - eval_params: + eval_hyperparams: <<: *procgen-eval-defaults step_freq: !!float 5e5 diff --git a/rl_algo_impls/hyperparams/vpg.yml b/rl_algo_impls/hyperparams/vpg.yml index 0193dc1ae5e791edf2f72bb40cd4a8d143c29b79..e24f970cc904810c1516a519f3f0a8172e30ef45 100644 --- a/rl_algo_impls/hyperparams/vpg.yml +++ b/rl_algo_impls/hyperparams/vpg.yml @@ -7,7 +7,7 @@ CartPole-v1: &cartpole-defaults gae_lambda: 1 val_lr: 0.01 train_v_iters: 80 - eval_params: + eval_hyperparams: step_freq: !!float 2.5e4 CartPole-v0: @@ -52,7 +52,7 @@ MountainCarContinuous-v0: val_lr: !!float 1e-3 train_v_iters: 80 max_grad_norm: 5 - eval_params: + eval_hyperparams: step_freq: 5000 Acrobot-v1: @@ -78,7 +78,7 @@ LunarLander-v2: val_lr: 0.0001 train_v_iters: 80 max_grad_norm: 0.5 - eval_params: + eval_hyperparams: deterministic: false BipedalWalker-v3: @@ -96,7 +96,7 @@ BipedalWalker-v3: val_lr: !!float 1e-4 train_v_iters: 80 max_grad_norm: 0.5 - eval_params: + eval_hyperparams: deterministic: false CarRacing-v0: @@ -169,7 +169,7 @@ FrozenLake-v1: val_lr: 0.01 train_v_iters: 80 max_grad_norm: 0.5 - eval_params: + eval_hyperparams: step_freq: !!float 5e4 n_episodes: 10 save_best: true @@ -193,5 +193,5 @@ _atari: &atari-defaults train_v_iters: 80 max_grad_norm: 0.5 ent_coef: 0.01 - eval_params: + eval_hyperparams: deterministic: false diff --git a/rl_algo_impls/optimize.py b/rl_algo_impls/optimize.py index 6ea2a57b9c8f36405161c726f8cb1f582313a48f..1c3bbcb55fdb6f583ee323b8150316faa1468349 100644 --- a/rl_algo_impls/optimize.py +++ b/rl_algo_impls/optimize.py @@ -2,37 +2,44 @@ import dataclasses import gc import inspect import logging +import os +from dataclasses import asdict, dataclass +from typing import Callable, List, NamedTuple, Optional, Sequence, Union + import numpy as np import optuna -import os import torch -import wandb - -from
dataclasses import asdict, dataclass from optuna.pruners import HyperbandPruner from optuna.samplers import TPESampler from optuna.visualization import plot_optimization_history, plot_param_importances from torch.utils.tensorboard.writer import SummaryWriter -from typing import Callable, List, NamedTuple, Optional, Sequence, Union +import wandb from rl_algo_impls.a2c.optimize import sample_params as a2c_sample_params from rl_algo_impls.runner.config import Config, EnvHyperparams, RunArgs -from rl_algo_impls.shared.vec_env import make_env, make_eval_env from rl_algo_impls.runner.running_utils import ( + ALGOS, base_parser, - load_hyperparams, - set_seeds, get_device, - make_policy, - ALGOS, hparam_dict, + load_hyperparams, + make_policy, + set_seeds, +) +from rl_algo_impls.shared.callbacks import Callback +from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import ( + MicrortsRewardDecayCallback, ) from rl_algo_impls.shared.callbacks.optimize_callback import ( Evaluation, OptimizeCallback, evaluation, ) +from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback from rl_algo_impls.shared.stats import EpisodesStats +from rl_algo_impls.shared.vec_env import make_env, make_eval_env +from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper +from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper @dataclass @@ -195,29 +202,38 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) - config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer ) device = get_device(config, env) - policy = make_policy(args.algo, env, device, **config.policy_hyperparams) + policy_factory = lambda: make_policy( + args.algo, env, device, **config.policy_hyperparams + ) + policy = policy_factory() algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams) eval_env = make_eval_env( config, EnvHyperparams(**config.env_hyperparams), - override_n_envs=study_args.n_eval_envs, + override_hparams={"n_envs": study_args.n_eval_envs}, ) - callback = OptimizeCallback( + optimize_callback = OptimizeCallback( policy, eval_env, trial, tb_writer, step_freq=config.n_timesteps // study_args.n_evaluations, n_episodes=study_args.n_eval_episodes, - deterministic=config.eval_params.get("deterministic", True), + deterministic=config.eval_hyperparams.get("deterministic", True), ) + callbacks: List[Callback] = [optimize_callback] + if config.hyperparams.microrts_reward_decay_callback: + callbacks.append(MicrortsRewardDecayCallback(config, env)) + selfPlayWrapper = find_wrapper(env, SelfPlayWrapper) + if selfPlayWrapper: + callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper)) try: - algo.learn(config.n_timesteps, callback=callback) + algo.learn(config.n_timesteps, callbacks=callbacks) - if not callback.is_pruned: - callback.evaluate() - if not callback.is_pruned: + if not optimize_callback.is_pruned: + optimize_callback.evaluate() + if not optimize_callback.is_pruned: policy.save(config.model_dir_path(best=False)) eval_stat: EpisodesStats = callback.last_eval_stat # type: ignore @@ -230,8 +246,8 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) - "hparam/last_result": eval_stat.score.mean - eval_stat.score.std, "hparam/train_mean": train_stat.score.mean, "hparam/train_result": train_stat.score.mean - train_stat.score.std, - "hparam/score": callback.last_score, - "hparam/is_pruned": callback.is_pruned, + "hparam/score": optimize_callback.last_score, + "hparam/is_pruned": 
optimize_callback.is_pruned, }, None, config.run_name(), @@ -239,13 +255,15 @@ def simple_optimize(trial: optuna.Trial, args: RunArgs, study_args: StudyArgs) - tb_writer.close() if wandb_enabled: - wandb.run.summary["state"] = "Pruned" if callback.is_pruned else "Complete" + wandb.run.summary["state"] = ( # type: ignore + "Pruned" if optimize_callback.is_pruned else "Complete" + ) wandb.finish(quiet=True) - if callback.is_pruned: + if optimize_callback.is_pruned: raise optuna.exceptions.TrialPruned() - return callback.last_score + return optimize_callback.last_score except AssertionError as e: logging.warning(e) return np.nan @@ -299,7 +317,10 @@ def stepwise_optimize( tb_writer=tb_writer, ) device = get_device(config, env) - policy = make_policy(arg.algo, env, device, **config.policy_hyperparams) + policy_factory = lambda: make_policy( + arg.algo, env, device, **config.policy_hyperparams + ) + policy = policy_factory() if i > 0: policy.load(config.model_dir_path()) algo = ALGOS[arg.algo]( @@ -310,7 +331,7 @@ def stepwise_optimize( config, EnvHyperparams(**config.env_hyperparams), normalize_load_path=config.model_dir_path() if i > 0 else None, - override_n_envs=study_args.n_eval_envs, + override_hparams={"n_envs": study_args.n_eval_envs}, ) start_timesteps = int(i * config.n_timesteps / study_args.n_evaluations) @@ -319,10 +340,22 @@ def stepwise_optimize( - start_timesteps ) + callbacks = [] + if config.hyperparams.microrts_reward_decay_callback: + callbacks.append( + MicrortsRewardDecayCallback( + config, env, start_timesteps=start_timesteps + ) + ) + selfPlayWrapper = find_wrapper(env, SelfPlayWrapper) + if selfPlayWrapper: + callbacks.append( + SelfPlayCallback(policy, policy_factory, selfPlayWrapper) + ) try: algo.learn( train_timesteps, - callback=None, + callbacks=callbacks, total_timesteps=config.n_timesteps, start_timesteps=start_timesteps, ) @@ -333,7 +366,7 @@ def stepwise_optimize( eval_env, tb_writer, study_args.n_eval_episodes, - config.eval_params.get("deterministic", True), + config.eval_hyperparams.get("deterministic", True), start_timesteps + train_timesteps, ) ) @@ -379,7 +412,7 @@ def stepwise_optimize( def wandb_finish(state: str) -> None: - wandb.run.summary["state"] = state + wandb.run.summary["state"] = state # type: ignore wandb.finish(quiet=True) diff --git a/rl_algo_impls/ppo/ppo.py b/rl_algo_impls/ppo/ppo.py index cfa5975c52b725b3b5cc8046e6dac4a17ac844af..686300b35e4590ef6b634a8afefbf60664dc2d60 100644 --- a/rl_algo_impls/ppo/ppo.py +++ b/rl_algo_impls/ppo/ppo.py @@ -10,12 +10,16 @@ from torch.optim import Adam from torch.utils.tensorboard.writer import SummaryWriter from rl_algo_impls.shared.algorithm import Algorithm -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.gae import compute_advantages -from rl_algo_impls.shared.policy.on_policy import ActorCritic -from rl_algo_impls.shared.schedule import schedule, update_learning_rate +from rl_algo_impls.shared.policy.actor_critic import ActorCritic +from rl_algo_impls.shared.schedule import ( + constant_schedule, + linear_schedule, + schedule, + update_learning_rate, +) from rl_algo_impls.shared.stats import log_scalars -from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker from rl_algo_impls.wrappers.vectorable_wrapper import ( VecEnv, single_action_space, @@ -102,12 +106,17 @@ class PPO(Algorithm): sde_sample_freq: int = -1, update_advantage_between_epochs: bool = True, update_returns_between_epochs: 
bool = False, + gamma_end: Optional[float] = None, ) -> None: super().__init__(policy, env, device, tb_writer) self.policy = policy - self.action_masker = find_action_masker(env) + self.get_action_mask = getattr(env, "get_action_mask", None) - self.gamma = gamma + self.gamma_schedule = ( + linear_schedule(gamma, gamma_end) + if gamma_end is not None + else constant_schedule(gamma) + ) self.gae_lambda = gae_lambda self.optimizer = Adam(self.policy.parameters(), lr=learning_rate, eps=1e-7) self.lr_schedule = schedule(learning_rate_decay, learning_rate) @@ -138,7 +147,7 @@ class PPO(Algorithm): def learn( self: PPOSelf, train_timesteps: int, - callback: Optional[Callback] = None, + callbacks: Optional[List[Callback]] = None, total_timesteps: Optional[int] = None, start_timesteps: int = 0, ) -> PPOSelf: @@ -153,15 +162,13 @@ class PPO(Algorithm): act_shape = self.policy.action_shape next_obs = self.env.reset() - next_action_masks = ( - self.action_masker.action_masks() if self.action_masker else None - ) - next_episode_starts = np.full(step_dim, True, dtype=np.bool8) + next_action_masks = self.get_action_mask() if self.get_action_mask else None + next_episode_starts = np.full(step_dim, True, dtype=np.bool_) obs = np.zeros(epoch_dim + obs_space.shape, dtype=obs_space.dtype) # type: ignore actions = np.zeros(epoch_dim + act_shape, dtype=act_space.dtype) # type: ignore rewards = np.zeros(epoch_dim, dtype=np.float32) - episode_starts = np.zeros(epoch_dim, dtype=np.bool8) + episode_starts = np.zeros(epoch_dim, dtype=np.bool_) values = np.zeros(epoch_dim, dtype=np.float32) logprobs = np.zeros(epoch_dim, dtype=np.float32) action_masks = ( @@ -181,10 +188,12 @@ class PPO(Algorithm): learning_rate = self.lr_schedule(progress) update_learning_rate(self.optimizer, learning_rate) pi_clip = self.clip_range_schedule(progress) + gamma = self.gamma_schedule(progress) chart_scalars = { "learning_rate": self.optimizer.param_groups[0]["lr"], "ent_coef": ent_coef, "pi_clip": pi_clip, + "gamma": gamma, } if self.clip_range_vf_schedule: v_clip = self.clip_range_vf_schedule(progress) @@ -215,7 +224,7 @@ class PPO(Algorithm): clamped_action ) next_action_masks = ( - self.action_masker.action_masks() if self.action_masker else None + self.get_action_mask() if self.get_action_mask else None ) self.policy.train() @@ -251,7 +260,7 @@ class PPO(Algorithm): next_episode_starts, next_obs, self.policy, - self.gamma, + gamma, self.gae_lambda, ) b_advantages = torch.tensor(advantages.reshape(-1)).to(self.device) @@ -364,8 +373,10 @@ class PPO(Algorithm): timesteps_elapsed, ) - if callback: - if not callback.on_step(timesteps_elapsed=rollout_steps): + if callbacks: + if not all( + c.on_step(timesteps_elapsed=rollout_steps) for c in callbacks + ): logging.info( f"Callback terminated training at {timesteps_elapsed} timesteps" ) diff --git a/rl_algo_impls/runner/config.py b/rl_algo_impls/runner/config.py index d92758eaae6929e319038444bf10846ba001ad95..53b60d5b79e8d52daaa1cb82879fea1de5188630 100644 --- a/rl_algo_impls/runner/config.py +++ b/rl_algo_impls/runner/config.py @@ -51,6 +51,8 @@ class EnvHyperparams: normalize_type: Optional[str] = None mask_actions: bool = False bots: Optional[Dict[str, int]] = None + self_play_kwargs: Optional[Dict[str, Any]] = None + selfplay_bots: Optional[Dict[str, int]] = None HyperparamsSelf = TypeVar("HyperparamsSelf", bound="Hyperparams") @@ -63,9 +65,10 @@ class Hyperparams: env_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict) policy_hyperparams: Dict[str, Any] = 
dataclasses.field(default_factory=dict) algo_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict) - eval_params: Dict[str, Any] = dataclasses.field(default_factory=dict) + eval_hyperparams: Dict[str, Any] = dataclasses.field(default_factory=dict) env_id: Optional[str] = None additional_keys_to_log: List[str] = dataclasses.field(default_factory=list) + microrts_reward_decay_callback: bool = False @classmethod def from_dict_with_extra_fields( @@ -110,8 +113,14 @@ class Config: return self.hyperparams.algo_hyperparams @property - def eval_params(self) -> Dict[str, Any]: - return self.hyperparams.eval_params + def eval_hyperparams(self) -> Dict[str, Any]: + return self.hyperparams.eval_hyperparams + + def eval_callback_params(self) -> Dict[str, Any]: + eval_hyperparams = self.eval_hyperparams.copy() + if "env_overrides" in eval_hyperparams: + del eval_hyperparams["env_overrides"] + return eval_hyperparams @property def algo(self) -> str: diff --git a/rl_algo_impls/runner/evaluate.py b/rl_algo_impls/runner/evaluate.py index 41eb34bb183bc451be44a14ce770a32be196d51e..70e5d4207ea051147aa1017597b8fc1abc903715 100644 --- a/rl_algo_impls/runner/evaluate.py +++ b/rl_algo_impls/runner/evaluate.py @@ -1,20 +1,19 @@ import os import shutil - from dataclasses import dataclass from typing import NamedTuple, Optional -from rl_algo_impls.shared.vec_env import make_eval_env from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs from rl_algo_impls.runner.running_utils import ( - load_hyperparams, - set_seeds, get_device, + load_hyperparams, make_policy, + set_seeds, ) from rl_algo_impls.shared.callbacks.eval_callback import evaluate from rl_algo_impls.shared.policy.policy import Policy from rl_algo_impls.shared.stats import EpisodesStats +from rl_algo_impls.shared.vec_env import make_eval_env @dataclass @@ -71,7 +70,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation: env = make_eval_env( config, EnvHyperparams(**config.env_hyperparams), - override_n_envs=args.n_envs, + override_hparams={"n_envs": args.n_envs} if args.n_envs else None, render=args.render, normalize_load_path=model_path, ) @@ -87,7 +86,7 @@ def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation: deterministic = ( args.deterministic_eval if args.deterministic_eval is not None - else config.eval_params.get("deterministic", True) + else config.eval_hyperparams.get("deterministic", True) ) return Evaluation( policy, diff --git a/rl_algo_impls/runner/running_utils.py b/rl_algo_impls/runner/running_utils.py index 9a872448708c6435c223589b0ad94f2ba35f8c29..6509ab474e1b2c0a9ad34a872f03967d3ef40f79 100644 --- a/rl_algo_impls/runner/running_utils.py +++ b/rl_algo_impls/runner/running_utils.py @@ -22,7 +22,7 @@ from rl_algo_impls.ppo.ppo import PPO from rl_algo_impls.runner.config import Config, Hyperparams from rl_algo_impls.shared.algorithm import Algorithm from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback -from rl_algo_impls.shared.policy.on_policy import ActorCritic +from rl_algo_impls.shared.policy.actor_critic import ActorCritic from rl_algo_impls.shared.policy.policy import Policy from rl_algo_impls.shared.vec_env.utils import import_for_env_id, is_microrts from rl_algo_impls.vpg.policy import VPGActorCritic @@ -97,29 +97,21 @@ def get_device(config: Config, env: VecEnv) -> torch.device: # cuda by default if device == "auto": device = "cuda" - # Apple MPS is a second choice (sometimes) - if device == "cuda" and not torch.cuda.is_available(): - device = "mps" 
- # If no MPS, fallback to cpu - if device == "mps" and not torch.backends.mps.is_available(): - device = "cpu" - # Simple environments like Discreet and 1-D Boxes might also be better - # served with the CPU. - if device == "mps": - obs_space = single_observation_space(env) - if isinstance(obs_space, Discrete): + # Apple MPS is a second choice (sometimes) + if device == "cuda" and not torch.cuda.is_available(): + device = "mps" + # If no MPS, fallback to cpu + if device == "mps" and not torch.backends.mps.is_available(): device = "cpu" - elif isinstance(obs_space, Box) and len(obs_space.shape) == 1: - device = "cpu" - if is_microrts(config): - try: - from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv - - # Models that move more than one unit at a time should use mps - if not isinstance(env.unwrapped, MicroRTSGridModeVecEnv): - device = "cpu" - except ModuleNotFoundError: - # Likely on gym_microrts v0.0.2 to match ppo-implementation-details + # Simple environments like Discreet and 1-D Boxes might also be better + # served with the CPU. + if device == "mps": + obs_space = single_observation_space(env) + if isinstance(obs_space, Discrete): + device = "cpu" + elif isinstance(obs_space, Box) and len(obs_space.shape) == 1: + device = "cpu" + if is_microrts(config): device = "cpu" print(f"Device: {device}") return torch.device(device) diff --git a/rl_algo_impls/runner/selfplay_evaluate.py b/rl_algo_impls/runner/selfplay_evaluate.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa43f680e233e41345c265c4fb8ad9d60407751 --- /dev/null +++ b/rl_algo_impls/runner/selfplay_evaluate.py @@ -0,0 +1,142 @@ +import copy +import dataclasses +import os +import shutil +from dataclasses import dataclass +from typing import List, NamedTuple, Optional + +import numpy as np + +import wandb +from rl_algo_impls.runner.config import Config, EnvHyperparams, Hyperparams, RunArgs +from rl_algo_impls.runner.evaluate import Evaluation +from rl_algo_impls.runner.running_utils import ( + get_device, + load_hyperparams, + make_policy, + set_seeds, +) +from rl_algo_impls.shared.callbacks.eval_callback import evaluate +from rl_algo_impls.shared.vec_env import make_eval_env +from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder + + +@dataclass +class SelfplayEvalArgs(RunArgs): + # Either wandb_run_paths or model_file_paths must have 2 elements in it. 
+ wandb_run_paths: List[str] = dataclasses.field(default_factory=list) + model_file_paths: List[str] = dataclasses.field(default_factory=list) + render: bool = False + best: bool = True + n_envs: int = 1 + n_episodes: int = 1 + deterministic_eval: Optional[bool] = None + no_print_returns: bool = False + video_path: Optional[str] = None + + +def selfplay_evaluate(args: SelfplayEvalArgs, root_dir: str) -> Evaluation: + if args.wandb_run_paths: + api = wandb.Api() + args, config, player_1_model_path = load_player( + api, args.wandb_run_paths[0], args, root_dir + ) + _, _, player_2_model_path = load_player( + api, args.wandb_run_paths[1], args, root_dir + ) + elif args.model_file_paths: + hyperparams = load_hyperparams(args.algo, args.env) + + config = Config(args, hyperparams, root_dir) + player_1_model_path, player_2_model_path = args.model_file_paths + else: + raise ValueError("Must specify 2 wandb_run_paths or 2 model_file_paths") + + print(args) + + set_seeds(args.seed, args.use_deterministic_algorithms) + + env_make_kwargs = ( + config.eval_hyperparams.get("env_overrides", {}).get("make_kwargs", {}).copy() + ) + env_make_kwargs["num_selfplay_envs"] = args.n_envs * 2 + env = make_eval_env( + config, + EnvHyperparams(**config.env_hyperparams), + override_hparams={ + "n_envs": args.n_envs, + "selfplay_bots": { + player_2_model_path: args.n_envs, + }, + "self_play_kwargs": { + "num_old_policies": 0, + "save_steps": np.inf, + "swap_steps": np.inf, + "bot_always_player_2": True, + }, + "bots": None, + "make_kwargs": env_make_kwargs, + }, + render=args.render, + normalize_load_path=player_1_model_path, + ) + if args.video_path: + env = VecEpisodeRecorder( + env, args.video_path, max_video_length=18000, num_episodes=args.n_episodes + ) + device = get_device(config, env) + policy = make_policy( + args.algo, + env, + device, + load_path=player_1_model_path, + **config.policy_hyperparams, + ).eval() + + deterministic = ( + args.deterministic_eval + if args.deterministic_eval is not None + else config.eval_hyperparams.get("deterministic", True) + ) + return Evaluation( + policy, + evaluate( + env, + policy, + args.n_episodes, + render=args.render, + deterministic=deterministic, + print_returns=not args.no_print_returns, + ), + config, + ) + + +class PlayerData(NamedTuple): + args: SelfplayEvalArgs + config: Config + model_path: str + + +def load_player( + api: wandb.Api, run_path: str, args: SelfplayEvalArgs, root_dir: str +) -> PlayerData: + args = copy.copy(args) + + run = api.run(run_path) + params = run.config + args.algo = params["algo"] + args.env = params["env"] + args.seed = params.get("seed", None) + args.use_deterministic_algorithms = params.get("use_deterministic_algorithms", True) + config = Config(args, Hyperparams.from_dict_with_extra_fields(params), root_dir) + model_path = config.model_dir_path(best=args.best, downloaded=True) + + model_archive_name = config.model_dir_name(best=args.best, extension=".zip") + run.file(model_archive_name).download() + if os.path.isdir(model_path): + shutil.rmtree(model_path) + shutil.unpack_archive(model_archive_name, model_path) + os.remove(model_archive_name) + + return PlayerData(args, config, model_path) diff --git a/rl_algo_impls/runner/train.py b/rl_algo_impls/runner/train.py index eb7b94bc24c18855656f1caeb723565d7a84764e..15408596ef14fb582592422219979b68ff768d17 100644 --- a/rl_algo_impls/runner/train.py +++ b/rl_algo_impls/runner/train.py @@ -1,12 +1,17 @@ # Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html) import os 
+from rl_algo_impls.shared.callbacks import Callback +from rl_algo_impls.shared.callbacks.self_play_callback import SelfPlayCallback +from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper +from rl_algo_impls.wrappers.vectorable_wrapper import find_wrapper + os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" import dataclasses import shutil from dataclasses import asdict, dataclass -from typing import Any, Dict, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence import yaml from torch.utils.tensorboard.writer import SummaryWriter @@ -23,6 +28,9 @@ from rl_algo_impls.runner.running_utils import ( set_seeds, ) from rl_algo_impls.shared.callbacks.eval_callback import EvalCallback +from rl_algo_impls.shared.callbacks.microrts_reward_decay_callback import ( + MicrortsRewardDecayCallback, +) from rl_algo_impls.shared.stats import EpisodesStats from rl_algo_impls.shared.vec_env import make_env, make_eval_env @@ -41,7 +49,7 @@ def train(args: TrainArgs): print(hyperparams) config = Config(args, hyperparams, os.getcwd()) - wandb_enabled = args.wandb_project_name + wandb_enabled = bool(args.wandb_project_name) if wandb_enabled: wandb.tensorboard.patch( root_logdir=config.tensorboard_summary_path, pytorch=True @@ -66,14 +74,17 @@ def train(args: TrainArgs): config, EnvHyperparams(**config.env_hyperparams), tb_writer=tb_writer ) device = get_device(config, env) - policy = make_policy(args.algo, env, device, **config.policy_hyperparams) + policy_factory = lambda: make_policy( + args.algo, env, device, **config.policy_hyperparams + ) + policy = policy_factory() algo = ALGOS[args.algo](policy, env, device, tb_writer, **config.algo_hyperparams) num_parameters = policy.num_parameters() num_trainable_parameters = policy.num_trainable_parameters() if wandb_enabled: - wandb.run.summary["num_parameters"] = num_parameters - wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters + wandb.run.summary["num_parameters"] = num_parameters # type: ignore + wandb.run.summary["num_trainable_parameters"] = num_trainable_parameters # type: ignore else: print( f"num_parameters = {num_parameters} ; " @@ -81,40 +92,49 @@ def train(args: TrainArgs): ) eval_env = make_eval_env(config, EnvHyperparams(**config.env_hyperparams)) - record_best_videos = config.eval_params.get("record_best_videos", True) - callback = EvalCallback( + record_best_videos = config.eval_hyperparams.get("record_best_videos", True) + eval_callback = EvalCallback( policy, eval_env, tb_writer, best_model_path=config.model_dir_path(best=True), - **config.eval_params, + **config.eval_callback_params(), video_env=make_eval_env( - config, EnvHyperparams(**config.env_hyperparams), override_n_envs=1 + config, + EnvHyperparams(**config.env_hyperparams), + override_hparams={"n_envs": 1}, ) if record_best_videos else None, best_video_dir=config.best_videos_dir, additional_keys_to_log=config.additional_keys_to_log, + wandb_enabled=wandb_enabled, ) - algo.learn(config.n_timesteps, callback=callback) + callbacks: List[Callback] = [eval_callback] + if config.hyperparams.microrts_reward_decay_callback: + callbacks.append(MicrortsRewardDecayCallback(config, env)) + selfPlayWrapper = find_wrapper(env, SelfPlayWrapper) + if selfPlayWrapper: + callbacks.append(SelfPlayCallback(policy, policy_factory, selfPlayWrapper)) + algo.learn(config.n_timesteps, callbacks=callbacks) policy.save(config.model_dir_path(best=False)) - eval_stats = callback.evaluate(n_episodes=10, print_returns=True) + eval_stats = 
eval_callback.evaluate(n_episodes=10, print_returns=True) - plot_eval_callback(callback, tb_writer, config.run_name()) + plot_eval_callback(eval_callback, tb_writer, config.run_name()) log_dict: Dict[str, Any] = { "eval": eval_stats._asdict(), } - if callback.best: - log_dict["best_eval"] = callback.best._asdict() + if eval_callback.best: + log_dict["best_eval"] = eval_callback.best._asdict() log_dict.update(asdict(hyperparams)) log_dict.update(vars(args)) with open(config.logs_path, "a") as f: yaml.dump({config.run_name(): log_dict}, f) - best_eval_stats: EpisodesStats = callback.best # type: ignore + best_eval_stats: EpisodesStats = eval_callback.best # type: ignore tb_writer.add_hparams( hparam_dict(hyperparams, vars(args)), { @@ -132,13 +152,8 @@ def train(args: TrainArgs): if wandb_enabled: shutil.make_archive( - os.path.join(wandb.run.dir, config.model_dir_name()), + os.path.join(wandb.run.dir, config.model_dir_name()), # type: ignore "zip", config.model_dir_path(), ) - shutil.make_archive( - os.path.join(wandb.run.dir, config.model_dir_name(best=True)), - "zip", - config.model_dir_path(best=True), - ) wandb.finish() diff --git a/rl_algo_impls/selfplay_enjoy.py b/rl_algo_impls/selfplay_enjoy.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8dc4c006861dae5c359075a01f07e4074a77f9 --- /dev/null +++ b/rl_algo_impls/selfplay_enjoy.py @@ -0,0 +1,53 @@ +# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html) +import os + +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" + +from rl_algo_impls.runner.running_utils import base_parser +from rl_algo_impls.runner.selfplay_evaluate import SelfplayEvalArgs, selfplay_evaluate + + +def selfplay_enjoy() -> None: + parser = base_parser(multiple=False) + parser.add_argument( + "--wandb-run-paths", + type=str, + nargs="*", + help="WandB run paths to load players from. Must be 0 or 2", + ) + parser.add_argument( + "--model-file-paths", + type=str, + help="File paths to load players from. 
Must be 0 or 2", + ) + parser.add_argument("--render", action="store_true") + parser.add_argument("--n-envs", default=1, type=int) + parser.add_argument("--n-episodes", default=1, type=int) + parser.add_argument("--deterministic-eval", default=None, type=bool) + parser.add_argument( + "--no-print-returns", action="store_true", help="Limit printing" + ) + parser.add_argument( + "--video-path", type=str, help="Path to save video of all plays" + ) + # parser.set_defaults( + # algo=["ppo"], + # env=["Microrts-selfplay-unet-decay"], + # n_episodes=10, + # model_file_paths=[ + # "downloaded_models/ppo-Microrts-selfplay-unet-decay-S3-best", + # "downloaded_models/ppo-Microrts-selfplay-unet-decay-S2-best", + # ], + # video_path="/Users/sgoodfriend/Desktop/decay3-vs-decay2", + # ) + args = parser.parse_args() + args.algo = args.algo[0] + args.env = args.env[0] + args.seed = args.seed[0] + args = SelfplayEvalArgs(**vars(args)) + + selfplay_evaluate(args, os.getcwd()) + + +if __name__ == "__main__": + selfplay_enjoy() diff --git a/rl_algo_impls/shared/actor/__init__.py b/rl_algo_impls/shared/actor/__init__.py index f4b59b46ff97bdca05c14145e79b8e361c47eeec..bc4bc09fc5756f4d16baaa85f2cb3fcef94cae83 100644 --- a/rl_algo_impls/shared/actor/__init__.py +++ b/rl_algo_impls/shared/actor/__init__.py @@ -1,2 +1,2 @@ -from rl_algo_impls.shared.actor.actor import Actor, PiForward +from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward from rl_algo_impls.shared.actor.make_actor import actor_head diff --git a/rl_algo_impls/shared/actor/actor.py b/rl_algo_impls/shared/actor/actor.py index 2da077a0175a080dcb85af67029bae57d1553393..9e800c05a79587f0519c088b212decacdddc6694 100644 --- a/rl_algo_impls/shared/actor/actor.py +++ b/rl_algo_impls/shared/actor/actor.py @@ -31,12 +31,13 @@ class Actor(nn.Module, ABC): def action_shape(self) -> Tuple[int, ...]: ... 
- def pi_forward( - self, distribution: Distribution, actions: Optional[torch.Tensor] = None - ) -> PiForward: - logp_a = None - entropy = None - if actions is not None: - logp_a = distribution.log_prob(actions) - entropy = distribution.entropy() - return PiForward(distribution, logp_a, entropy) + +def pi_forward( + distribution: Distribution, actions: Optional[torch.Tensor] = None +) -> PiForward: + logp_a = None + entropy = None + if actions is not None: + logp_a = distribution.log_prob(actions) + entropy = distribution.entropy() + return PiForward(distribution, logp_a, entropy) diff --git a/rl_algo_impls/shared/actor/categorical.py b/rl_algo_impls/shared/actor/categorical.py index 6392ead5472148894f550b9c1208bb3670551db1..a91aeed2d020053a4e7c618df9d31c0fa82cd7ce 100644 --- a/rl_algo_impls/shared/actor/categorical.py +++ b/rl_algo_impls/shared/actor/categorical.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn from torch.distributions import Categorical -from rl_algo_impls.shared.actor import Actor, PiForward -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward +from rl_algo_impls.shared.module.utils import mlp class MaskedCategorical(Categorical): @@ -57,7 +57,7 @@ class CategoricalActorHead(Actor): ) -> PiForward: logits = self._fc(obs) pi = MaskedCategorical(logits=logits, mask=action_masks) - return self.pi_forward(pi, actions) + return pi_forward(pi, actions) @property def action_shape(self) -> Tuple[int, ...]: diff --git a/rl_algo_impls/shared/actor/gaussian.py b/rl_algo_impls/shared/actor/gaussian.py index 3867477ed7009442eb8d465c51b3420d97c99342..14cbc4739ed75510e34cbd26508ce2e937714214 100644 --- a/rl_algo_impls/shared/actor/gaussian.py +++ b/rl_algo_impls/shared/actor/gaussian.py @@ -4,8 +4,8 @@ import torch import torch.nn as nn from torch.distributions import Distribution, Normal -from rl_algo_impls.shared.actor.actor import Actor, PiForward -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward +from rl_algo_impls.shared.module.utils import mlp class GaussianDistribution(Normal): @@ -54,7 +54,7 @@ class GaussianActorHead(Actor): not action_masks ), f"{self.__class__.__name__} does not support action_masks" pi = self._distribution(obs) - return self.pi_forward(pi, actions) + return pi_forward(pi, actions) @property def action_shape(self) -> Tuple[int, ...]: diff --git a/rl_algo_impls/shared/actor/gridnet.py b/rl_algo_impls/shared/actor/gridnet.py index a6746428ccd9be1156de6f613b7ee365e9e01cfd..fd0d5a6db760f35ae67c1fff78e159f46dd64092 100644 --- a/rl_algo_impls/shared/actor/gridnet.py +++ b/rl_algo_impls/shared/actor/gridnet.py @@ -6,10 +6,10 @@ import torch.nn as nn from numpy.typing import NDArray from torch.distributions import Distribution, constraints -from rl_algo_impls.shared.actor import Actor, PiForward +from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward from rl_algo_impls.shared.actor.categorical import MaskedCategorical from rl_algo_impls.shared.encoder import EncoderOutDim -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.module.utils import mlp class GridnetDistribution(Distribution): @@ -25,7 +25,7 @@ class GridnetDistribution(Distribution): self.action_vec = action_vec masks = masks.view(-1, masks.shape[-1]) - split_masks = torch.split(masks[:, 1:], action_vec.tolist(), dim=1) + split_masks = torch.split(masks, action_vec.tolist(), dim=1) grid_logits = logits.reshape(-1, 
action_vec.sum()) split_logits = torch.split(grid_logits, action_vec.tolist(), dim=1) @@ -101,7 +101,7 @@ class GridnetActorHead(Actor): ), f"No mask case unhandled in {self.__class__.__name__}" logits = self._fc(obs) pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks) - return self.pi_forward(pi, actions) + return pi_forward(pi, actions) @property def action_shape(self) -> Tuple[int, ...]: diff --git a/rl_algo_impls/shared/actor/gridnet_decoder.py b/rl_algo_impls/shared/actor/gridnet_decoder.py index 21a83e92a84737ad10b4fd6d20fc3ea5d8f5edb7..efb823a6bc69ccde17700d9d74f9e6855756d51d 100644 --- a/rl_algo_impls/shared/actor/gridnet_decoder.py +++ b/rl_algo_impls/shared/actor/gridnet_decoder.py @@ -5,11 +5,10 @@ import torch import torch.nn as nn from numpy.typing import NDArray -from rl_algo_impls.shared.actor import Actor, PiForward -from rl_algo_impls.shared.actor.categorical import MaskedCategorical +from rl_algo_impls.shared.actor import Actor, PiForward, pi_forward from rl_algo_impls.shared.actor.gridnet import GridnetDistribution from rl_algo_impls.shared.encoder import EncoderOutDim -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import layer_init class Transpose(nn.Module): @@ -73,7 +72,7 @@ class GridnetDecoder(Actor): ), f"No mask case unhandled in {self.__class__.__name__}" logits = self.deconv(obs) pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks) - return self.pi_forward(pi, actions) + return pi_forward(pi, actions) @property def action_shape(self) -> Tuple[int, ...]: diff --git a/rl_algo_impls/shared/actor/make_actor.py b/rl_algo_impls/shared/actor/make_actor.py index 831e1f800dd46cb2db056bfb686aa10820c82666..ced98f4afd2756d1825aced5d169c94d9402e837 100644 --- a/rl_algo_impls/shared/actor/make_actor.py +++ b/rl_algo_impls/shared/actor/make_actor.py @@ -1,4 +1,4 @@ -from typing import Tuple, Type +from typing import Optional, Tuple, Type import gym import torch.nn as nn @@ -27,6 +27,7 @@ def actor_head( full_std: bool = True, squash_output: bool = False, actor_head_style: str = "single", + action_plane_space: Optional[bool] = None, ) -> Actor: assert not use_sde or isinstance( action_space, Box @@ -73,18 +74,20 @@ def actor_head( init_layers_orthogonal=init_layers_orthogonal, ) elif actor_head_style == "gridnet": + assert isinstance(action_plane_space, MultiDiscrete) return GridnetActorHead( - action_space.nvec[0], # type: ignore - action_space.nvec[1:], # type: ignore + len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore + action_plane_space.nvec, # type: ignore in_dim=in_dim, hidden_sizes=hidden_sizes, activation=activation, init_layers_orthogonal=init_layers_orthogonal, ) elif actor_head_style == "gridnet_decoder": + assert isinstance(action_plane_space, MultiDiscrete) return GridnetDecoder( - action_space.nvec[0], # type: ignore - action_space.nvec[1:], # type: ignore + len(action_space.nvec) // len(action_plane_space.nvec), # type: ignore + action_plane_space.nvec, # type: ignore in_dim=in_dim, activation=activation, init_layers_orthogonal=init_layers_orthogonal, diff --git a/rl_algo_impls/shared/actor/multi_discrete.py b/rl_algo_impls/shared/actor/multi_discrete.py index 26a60d6c90f2e0ac244f57432ec426493ebeefdf..9fcb0ffbc1d540866de0be780ee7336e8938bbd0 100644 --- a/rl_algo_impls/shared/actor/multi_discrete.py +++ b/rl_algo_impls/shared/actor/multi_discrete.py @@ -6,10 +6,10 @@ import torch.nn as nn from numpy.typing import NDArray from 
torch.distributions import Distribution, constraints -from rl_algo_impls.shared.actor.actor import Actor, PiForward +from rl_algo_impls.shared.actor.actor import Actor, PiForward, pi_forward from rl_algo_impls.shared.actor.categorical import MaskedCategorical from rl_algo_impls.shared.encoder import EncoderOutDim -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.module.utils import mlp class MultiCategorical(Distribution): @@ -94,7 +94,7 @@ class MultiDiscreteActorHead(Actor): ) -> PiForward: logits = self._fc(obs) pi = MultiCategorical(self.nvec, logits=logits, masks=action_masks) - return self.pi_forward(pi, actions) + return pi_forward(pi, actions) @property def action_shape(self) -> Tuple[int, ...]: diff --git a/rl_algo_impls/shared/actor/state_dependent_noise.py b/rl_algo_impls/shared/actor/state_dependent_noise.py index 333c2549d511537e02edb655f74f912cf054b6b2..2d8ee190f8acc949b73f8cc53181c5ec8967f484 100644 --- a/rl_algo_impls/shared/actor/state_dependent_noise.py +++ b/rl_algo_impls/shared/actor/state_dependent_noise.py @@ -5,7 +5,7 @@ import torch.nn as nn from torch.distributions import Distribution, Normal from rl_algo_impls.shared.actor.actor import Actor, PiForward -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.module.utils import mlp class TanhBijector: @@ -172,7 +172,7 @@ class StateDependentNoiseActorHead(Actor): not action_masks ), f"{self.__class__.__name__} does not support action_masks" pi = self._distribution(obs) - return self.pi_forward(pi, actions) + return pi_forward(pi, actions, self.bijector) def sample_weights(self, batch_size: int = 1) -> None: std = self._get_std() @@ -185,16 +185,15 @@ class StateDependentNoiseActorHead(Actor): def action_shape(self) -> Tuple[int, ...]: return (self.act_dim,) - def pi_forward( - self, distribution: Distribution, actions: Optional[torch.Tensor] = None - ) -> PiForward: - logp_a = None - entropy = None - if actions is not None: - logp_a = distribution.log_prob(actions) - entropy = ( - -logp_a - if self.bijector - else sum_independent_dims(distribution.entropy()) - ) - return PiForward(distribution, logp_a, entropy) + +def pi_forward( + distribution: Distribution, + actions: Optional[torch.Tensor] = None, + bijector: Optional[TanhBijector] = None, +) -> PiForward: + logp_a = None + entropy = None + if actions is not None: + logp_a = distribution.log_prob(actions) + entropy = -logp_a if bijector else sum_independent_dims(distribution.entropy()) + return PiForward(distribution, logp_a, entropy) diff --git a/rl_algo_impls/shared/algorithm.py b/rl_algo_impls/shared/algorithm.py index f70160aaaeb6a0fb92aaaef473fd0b665999d2f9..8ced97d5689d2729d9e41f0f3035a9e2455d7315 100644 --- a/rl_algo_impls/shared/algorithm.py +++ b/rl_algo_impls/shared/algorithm.py @@ -1,11 +1,11 @@ +from abc import ABC, abstractmethod +from typing import List, Optional, TypeVar + import gym import torch - -from abc import ABC, abstractmethod from torch.utils.tensorboard.writer import SummaryWriter -from typing import Optional, TypeVar -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.policy.policy import Policy from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv @@ -32,7 +32,7 @@ class Algorithm(ABC): def learn( self: AlgorithmSelf, train_timesteps: int, - callback: Optional[Callback] = None, + callbacks: Optional[List[Callback]] = None, total_timesteps: Optional[int] = None, start_timesteps: int = 0, ) 
-> AlgorithmSelf: diff --git a/rl_algo_impls/shared/callbacks/__init__.py b/rl_algo_impls/shared/callbacks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..45e23640fc0ab424598d0fdca84330952f690942 --- /dev/null +++ b/rl_algo_impls/shared/callbacks/__init__.py @@ -0,0 +1 @@ +from rl_algo_impls.shared.callbacks.callback import Callback diff --git a/rl_algo_impls/shared/callbacks/eval_callback.py b/rl_algo_impls/shared/callbacks/eval_callback.py index 04b8ee1b24667c3767e57820ac0dfd9d4b01b1ce..b662aba09d159443a574c5fd14e827b1e3f6bf1c 100644 --- a/rl_algo_impls/shared/callbacks/eval_callback.py +++ b/rl_algo_impls/shared/callbacks/eval_callback.py @@ -1,12 +1,13 @@ import itertools import os +import shutil from time import perf_counter from typing import Dict, List, Optional, Union import numpy as np from torch.utils.tensorboard.writer import SummaryWriter -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.policy.policy import Policy from rl_algo_impls.shared.stats import Episode, EpisodeAccumulator, EpisodesStats from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker @@ -80,6 +81,7 @@ def evaluate( print_returns: bool = True, ignore_first_episode: bool = False, additional_keys_to_log: Optional[List[str]] = None, + score_function: str = "mean-std", ) -> EpisodesStats: policy.sync_normalization(env) policy.eval() @@ -93,18 +95,21 @@ def evaluate( ) obs = env.reset() - action_masker = find_action_masker(env) + get_action_mask = getattr(env, "get_action_mask", None) while not episodes.is_done(): act = policy.act( obs, deterministic=deterministic, - action_masks=action_masker.action_masks() if action_masker else None, + action_masks=get_action_mask() if get_action_mask else None, ) obs, rew, done, info = env.step(act) episodes.step(rew, done, info) if render: env.render() - stats = EpisodesStats(episodes.episodes) + stats = EpisodesStats( + episodes.episodes, + score_function=score_function, + ) if print_returns: print(stats) return stats @@ -127,6 +132,8 @@ class EvalCallback(Callback): max_video_length: int = 3600, ignore_first_episode: bool = False, additional_keys_to_log: Optional[List[str]] = None, + score_function: str = "mean-std", + wandb_enabled: bool = False, ) -> None: super().__init__() self.policy = policy @@ -151,6 +158,8 @@ class EvalCallback(Callback): self.best_video_base_path = None self.ignore_first_episode = ignore_first_episode self.additional_keys_to_log = additional_keys_to_log + self.score_function = score_function + self.wandb_enabled = wandb_enabled def on_step(self, timesteps_elapsed: int = 1) -> bool: super().on_step(timesteps_elapsed) @@ -170,6 +179,7 @@ class EvalCallback(Callback): print_returns=print_returns or False, ignore_first_episode=self.ignore_first_episode, additional_keys_to_log=self.additional_keys_to_log, + score_function=self.score_function, ) end_time = perf_counter() self.tb_writer.add_scalar( @@ -189,6 +199,15 @@ class EvalCallback(Callback): assert self.best_model_path self.policy.save(self.best_model_path) print("Saved best model") + if self.wandb_enabled: + import wandb + + best_model_name = os.path.split(self.best_model_path)[-1] + shutil.make_archive( + os.path.join(wandb.run.dir, best_model_name), # type: ignore + "zip", + self.best_model_path, + ) self.best.write_to_tensorboard( self.tb_writer, "best_eval", self.timesteps_elapsed ) @@ -208,6 +227,7 @@ class EvalCallback(Callback): 1, 
deterministic=self.deterministic, print_returns=False, + score_function=self.score_function, ) print(f"Saved best video: {video_stats}") diff --git a/rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py b/rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2f0446cd23309c2b9526f70ada5dd6d6e106e4 --- /dev/null +++ b/rl_algo_impls/shared/callbacks/microrts_reward_decay_callback.py @@ -0,0 +1,36 @@ +import numpy as np + +from rl_algo_impls.runner.config import Config +from rl_algo_impls.shared.callbacks import Callback +from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv + + +class MicrortsRewardDecayCallback(Callback): + def __init__( + self, + config: Config, + env: VecEnv, + start_timesteps: int = 0, + ) -> None: + super().__init__() + from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv + + unwrapped = env.unwrapped + assert isinstance(unwrapped, MicroRTSGridModeVecEnv) + self.microrts_env = unwrapped + self.base_reward_weights = self.microrts_env.reward_weight + + self.total_train_timesteps = config.n_timesteps + self.timesteps_elapsed = start_timesteps + + def on_step(self, timesteps_elapsed: int = 1) -> bool: + super().on_step(timesteps_elapsed) + + progress = self.timesteps_elapsed / self.total_train_timesteps + # Decay all rewards except WinLoss + reward_weights = self.base_reward_weights * np.array( + [1] + [1 - progress] * (len(self.base_reward_weights) - 1) + ) + self.microrts_env.reward_weight = reward_weights + + return True diff --git a/rl_algo_impls/shared/callbacks/optimize_callback.py b/rl_algo_impls/shared/callbacks/optimize_callback.py index 75f1cbcb79b30e905f78ecf9fc38c00f79bf9207..9af15f32f861a2a4cc136f5855d5490d5b95f983 100644 --- a/rl_algo_impls/shared/callbacks/optimize_callback.py +++ b/rl_algo_impls/shared/callbacks/optimize_callback.py @@ -5,7 +5,7 @@ from time import perf_counter from torch.utils.tensorboard.writer import SummaryWriter from typing import NamedTuple, Union -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.callbacks.eval_callback import evaluate from rl_algo_impls.shared.policy.policy import Policy from rl_algo_impls.shared.stats import EpisodesStats diff --git a/rl_algo_impls/shared/callbacks/self_play_callback.py b/rl_algo_impls/shared/callbacks/self_play_callback.py new file mode 100644 index 0000000000000000000000000000000000000000..2457e09665196de3960daf99a59f3b4e1383b0d3 --- /dev/null +++ b/rl_algo_impls/shared/callbacks/self_play_callback.py @@ -0,0 +1,34 @@ +from typing import Callable + +from rl_algo_impls.shared.callbacks import Callback +from rl_algo_impls.shared.policy.policy import Policy +from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper + + +class SelfPlayCallback(Callback): + def __init__( + self, + policy: Policy, + policy_factory: Callable[[], Policy], + selfPlayWrapper: SelfPlayWrapper, + ) -> None: + super().__init__() + self.policy = policy + self.policy_factory = policy_factory + self.selfPlayWrapper = selfPlayWrapper + self.checkpoint_policy() + + def on_step(self, timesteps_elapsed: int = 1) -> bool: + super().on_step(timesteps_elapsed) + if ( + self.timesteps_elapsed + >= self.last_checkpoint_step + self.selfPlayWrapper.save_steps + ): + self.checkpoint_policy() + return True + + def checkpoint_policy(self): + self.selfPlayWrapper.checkpoint_policy( + self.policy_factory().load_from(self.policy) + ) + 
self.last_checkpoint_step = self.timesteps_elapsed diff --git a/rl_algo_impls/shared/encoder/cnn.py b/rl_algo_impls/shared/encoder/cnn.py index b4e324bdedc4d4ce8d82231e6b2b218cdca9b5ba..88e0c6a2a469c62180e3e521d69a077f31bfc753 100644 --- a/rl_algo_impls/shared/encoder/cnn.py +++ b/rl_algo_impls/shared/encoder/cnn.py @@ -6,7 +6,7 @@ import numpy as np import torch import torch.nn as nn -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import layer_init EncoderOutDim = Union[int, Tuple[int, ...]] diff --git a/rl_algo_impls/shared/encoder/encoder.py b/rl_algo_impls/shared/encoder/encoder.py index 39dd5ba192f593c79ebe470ead65fae0ed83dd47..cb9106270bac7b54843e269a0a2a4825876ca791 100644 --- a/rl_algo_impls/shared/encoder/encoder.py +++ b/rl_algo_impls/shared/encoder/encoder.py @@ -12,7 +12,7 @@ from rl_algo_impls.shared.encoder.gridnet_encoder import GridnetEncoder from rl_algo_impls.shared.encoder.impala_cnn import ImpalaCnn from rl_algo_impls.shared.encoder.microrts_cnn import MicrortsCnn from rl_algo_impls.shared.encoder.nature_cnn import NatureCnn -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import layer_init CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnEncoder]] = { "nature": NatureCnn, diff --git a/rl_algo_impls/shared/encoder/gridnet_encoder.py b/rl_algo_impls/shared/encoder/gridnet_encoder.py index 5930894e98992c90d78768f68844ba9abe9644b8..9388ce575b12e19414d926454638c7fa91a5424b 100644 --- a/rl_algo_impls/shared/encoder/gridnet_encoder.py +++ b/rl_algo_impls/shared/encoder/gridnet_encoder.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import layer_init class GridnetEncoder(CnnEncoder): diff --git a/rl_algo_impls/shared/encoder/impala_cnn.py b/rl_algo_impls/shared/encoder/impala_cnn.py index d14a8a51776792eea647ccc727f77fa9c0991e2f..41c5a42eb5b18e26a4960f57036ca47ec2346d66 100644 --- a/rl_algo_impls/shared/encoder/impala_cnn.py +++ b/rl_algo_impls/shared/encoder/impala_cnn.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import layer_init class ResidualBlock(nn.Module): diff --git a/rl_algo_impls/shared/encoder/microrts_cnn.py b/rl_algo_impls/shared/encoder/microrts_cnn.py index 29f18af52350308923bf9c51db1d25e4cbb49601..0f61a23458f9b38ebc2b8371620220ae2cdef29e 100644 --- a/rl_algo_impls/shared/encoder/microrts_cnn.py +++ b/rl_algo_impls/shared/encoder/microrts_cnn.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import layer_init class MicrortsCnn(FlattenedCnnEncoder): diff --git a/rl_algo_impls/shared/encoder/nature_cnn.py b/rl_algo_impls/shared/encoder/nature_cnn.py index 21a77f9cdb0a4b1029b10a150ec81990bbfaff3a..f031729b012f73c1e9917503fadcc31e10c4ab80 100644 --- a/rl_algo_impls/shared/encoder/nature_cnn.py +++ b/rl_algo_impls/shared/encoder/nature_cnn.py @@ -4,7 +4,7 @@ import gym import torch.nn as nn from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder -from rl_algo_impls.shared.module.module import layer_init +from rl_algo_impls.shared.module.utils import 
layer_init class NatureCnn(FlattenedCnnEncoder): diff --git a/rl_algo_impls/shared/gae.py b/rl_algo_impls/shared/gae.py index 7b5cbacc2d82e7c262c91b52afa4a6929f4c439e..f1b82fbed1a5b880e5ccf858a9ea257bfcc81c65 100644 --- a/rl_algo_impls/shared/gae.py +++ b/rl_algo_impls/shared/gae.py @@ -3,7 +3,7 @@ import torch from typing import NamedTuple, Sequence -from rl_algo_impls.shared.policy.on_policy import OnPolicy +from rl_algo_impls.shared.policy.actor_critic import OnPolicy from rl_algo_impls.shared.trajectory import Trajectory from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvObs diff --git a/rl_algo_impls/shared/module/module.py b/rl_algo_impls/shared/module/utils.py similarity index 100% rename from rl_algo_impls/shared/module/module.py rename to rl_algo_impls/shared/module/utils.py diff --git a/rl_algo_impls/shared/policy/on_policy.py b/rl_algo_impls/shared/policy/actor_critic.py similarity index 55% rename from rl_algo_impls/shared/policy/on_policy.py rename to rl_algo_impls/shared/policy/actor_critic.py index 4484c053eda3a17a2a575961cad235ba09e5bff7..a77fdf0bf3c05617611af5f63ed94c9bc1bb3733 100644 --- a/rl_algo_impls/shared/policy/on_policy.py +++ b/rl_algo_impls/shared/policy/actor_critic.py @@ -4,12 +4,14 @@ from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar import gym import numpy as np import torch -from gym.spaces import Box, Discrete, Space +from gym.spaces import Box, Space -from rl_algo_impls.shared.actor import PiForward, actor_head -from rl_algo_impls.shared.encoder import Encoder -from rl_algo_impls.shared.policy.critic import CriticHead -from rl_algo_impls.shared.policy.policy import ACTIVATION, Policy +from rl_algo_impls.shared.policy.actor_critic_network import ( + ConnectedTrioActorCriticNetwork, + SeparateActorCriticNetwork, + UNetActorCriticNetwork, +) +from rl_algo_impls.shared.policy.policy import Policy from rl_algo_impls.wrappers.vectorable_wrapper import ( VecEnv, VecEnvObs, @@ -52,21 +54,6 @@ def clamp_actions( return actions -def default_hidden_sizes(obs_space: Space) -> Sequence[int]: - if isinstance(obs_space, Box): - if len(obs_space.shape) == 3: - # By default feature extractor to output has no hidden layers - return [] - elif len(obs_space.shape) == 1: - return [64, 64] - else: - raise ValueError(f"Unsupported observation space: {obs_space}") - elif isinstance(obs_space, Discrete): - return [64] - else: - raise ValueError(f"Unsupported observation space: {obs_space}") - - class OnPolicy(Policy): @abstractmethod def value(self, obs: VecEnvObs) -> np.ndarray: @@ -106,78 +93,59 @@ class ActorCritic(OnPolicy): observation_space = single_observation_space(env) action_space = single_action_space(env) + action_plane_space = getattr(env, "action_plane_space", None) - pi_hidden_sizes = ( - pi_hidden_sizes - if pi_hidden_sizes is not None - else default_hidden_sizes(observation_space) - ) - v_hidden_sizes = ( - v_hidden_sizes - if v_hidden_sizes is not None - else default_hidden_sizes(observation_space) - ) - - activation = ACTIVATION[activation_fn] self.action_space = action_space self.squash_output = squash_output - self.share_features_extractor = share_features_extractor - self._feature_extractor = Encoder( - observation_space, - activation, - init_layers_orthogonal=init_layers_orthogonal, - cnn_flatten_dim=cnn_flatten_dim, - cnn_style=cnn_style, - cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, - impala_channels=impala_channels, - ) - self._pi = actor_head( - self.action_space, - self._feature_extractor.out_dim, - 
tuple(pi_hidden_sizes), - init_layers_orthogonal, - activation, - log_std_init=log_std_init, - use_sde=use_sde, - full_std=full_std, - squash_output=squash_output, - actor_head_style=actor_head_style, - ) - if not share_features_extractor: - self._v_feature_extractor = Encoder( + if actor_head_style == "unet": + self.network = UNetActorCriticNetwork( observation_space, - activation, + action_space, + action_plane_space, + v_hidden_sizes=v_hidden_sizes, init_layers_orthogonal=init_layers_orthogonal, + activation_fn=activation_fn, + cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, + ) + elif share_features_extractor: + self.network = ConnectedTrioActorCriticNetwork( + observation_space, + action_space, + pi_hidden_sizes=pi_hidden_sizes, + v_hidden_sizes=v_hidden_sizes, + init_layers_orthogonal=init_layers_orthogonal, + activation_fn=activation_fn, + log_std_init=log_std_init, + use_sde=use_sde, + full_std=full_std, + squash_output=squash_output, cnn_flatten_dim=cnn_flatten_dim, cnn_style=cnn_style, cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, + impala_channels=impala_channels, + actor_head_style=actor_head_style, + action_plane_space=action_plane_space, ) - critic_in_dim = self._v_feature_extractor.out_dim else: - self._v_feature_extractor = None - critic_in_dim = self._feature_extractor.out_dim - self._v = CriticHead( - in_dim=critic_in_dim, - hidden_sizes=v_hidden_sizes, - activation=activation, - init_layers_orthogonal=init_layers_orthogonal, - ) - - def _pi_forward( - self, - obs: torch.Tensor, - action_masks: Optional[torch.Tensor], - action: Optional[torch.Tensor] = None, - ) -> Tuple[PiForward, torch.Tensor]: - p_fe = self._feature_extractor(obs) - pi_forward = self._pi(p_fe, actions=action, action_masks=action_masks) - - return pi_forward, p_fe - - def _v_forward(self, obs: torch.Tensor, p_fc: torch.Tensor) -> torch.Tensor: - v_fe = self._v_feature_extractor(obs) if self._v_feature_extractor else p_fc - return self._v(v_fe) + self.network = SeparateActorCriticNetwork( + observation_space, + action_space, + pi_hidden_sizes=pi_hidden_sizes, + v_hidden_sizes=v_hidden_sizes, + init_layers_orthogonal=init_layers_orthogonal, + activation_fn=activation_fn, + log_std_init=log_std_init, + use_sde=use_sde, + full_std=full_std, + squash_output=squash_output, + cnn_flatten_dim=cnn_flatten_dim, + cnn_style=cnn_style, + cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, + impala_channels=impala_channels, + actor_head_style=actor_head_style, + action_plane_space=action_plane_space, + ) def forward( self, @@ -185,8 +153,7 @@ class ActorCritic(OnPolicy): action: torch.Tensor, action_masks: Optional[torch.Tensor] = None, ) -> ACForward: - (_, logp_a, entropy), p_fc = self._pi_forward(obs, action_masks, action=action) - v = self._v_forward(obs, p_fc) + (_, logp_a, entropy), v = self.network(obs, action, action_masks=action_masks) assert logp_a is not None assert entropy is not None @@ -195,24 +162,17 @@ class ActorCritic(OnPolicy): def value(self, obs: VecEnvObs) -> np.ndarray: o = self._as_tensor(obs) with torch.no_grad(): - fe = ( - self._v_feature_extractor(o) - if self._v_feature_extractor - else self._feature_extractor(o) - ) - v = self._v(fe) + v = self.network.value(o) return v.cpu().numpy() def step(self, obs: VecEnvObs, action_masks: Optional[np.ndarray] = None) -> Step: o = self._as_tensor(obs) a_masks = self._as_tensor(action_masks) if action_masks is not None else None with torch.no_grad(): - (pi, _, _), p_fc = self._pi_forward(o, action_masks=a_masks) + (pi, _, _), v = 
self.network.distribution_and_value(o, action_masks=a_masks) a = pi.sample() logp_a = pi.log_prob(a) - v = self._v_forward(o, p_fc) - a_np = a.cpu().numpy() clamped_a_np = clamp_actions(a_np, self.action_space, self.squash_output) return Step(a_np, v.cpu().numpy(), logp_a.cpu().numpy(), clamped_a_np) @@ -231,7 +191,9 @@ class ActorCritic(OnPolicy): self._as_tensor(action_masks) if action_masks is not None else None ) with torch.no_grad(): - (pi, _, _), _ = self._pi_forward(o, action_masks=a_masks) + (pi, _, _), _ = self.network.distribution_and_value( + o, action_masks=a_masks + ) a = pi.mode return clamp_actions(a.cpu().numpy(), self.action_space, self.squash_output) @@ -239,11 +201,16 @@ class ActorCritic(OnPolicy): super().load(path) self.reset_noise() + def load_from(self: ActorCriticSelf, policy: ActorCriticSelf) -> ActorCriticSelf: + super().load_from(policy) + self.reset_noise() + return self + def reset_noise(self, batch_size: Optional[int] = None) -> None: - self._pi.sample_weights( + self.network.reset_noise( batch_size=batch_size if batch_size else self.env.num_envs ) @property def action_shape(self) -> Tuple[int, ...]: - return self._pi.action_shape + return self.network.action_shape diff --git a/rl_algo_impls/shared/policy/actor_critic_network/__init__.py b/rl_algo_impls/shared/policy/actor_critic_network/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..47a2ba0d8155a1f505cf82f030d0e17d28345c53 --- /dev/null +++ b/rl_algo_impls/shared/policy/actor_critic_network/__init__.py @@ -0,0 +1,11 @@ +from rl_algo_impls.shared.policy.actor_critic_network.connected_trio import ( + ConnectedTrioActorCriticNetwork, +) +from rl_algo_impls.shared.policy.actor_critic_network.network import ( + ActorCriticNetwork, + default_hidden_sizes, +) +from rl_algo_impls.shared.policy.actor_critic_network.separate_actor_critic import ( + SeparateActorCriticNetwork, +) +from rl_algo_impls.shared.policy.actor_critic_network.unet import UNetActorCriticNetwork diff --git a/rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py b/rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py new file mode 100644 index 0000000000000000000000000000000000000000..9b954ac8bafe96f42df51500f5ff4533e3fdadec --- /dev/null +++ b/rl_algo_impls/shared/policy/actor_critic_network/connected_trio.py @@ -0,0 +1,118 @@ +from typing import Optional, Sequence, Tuple + +import torch +from gym.spaces import Space + +from rl_algo_impls.shared.actor import actor_head +from rl_algo_impls.shared.encoder import Encoder +from rl_algo_impls.shared.policy.actor_critic_network.network import ( + ACNForward, + ActorCriticNetwork, + default_hidden_sizes, +) +from rl_algo_impls.shared.policy.critic import CriticHead +from rl_algo_impls.shared.policy.policy import ACTIVATION + + +class ConnectedTrioActorCriticNetwork(ActorCriticNetwork): + """Encode (feature extractor), decoder (actor head), critic head networks""" + + def __init__( + self, + observation_space: Space, + action_space: Space, + pi_hidden_sizes: Optional[Sequence[int]] = None, + v_hidden_sizes: Optional[Sequence[int]] = None, + init_layers_orthogonal: bool = True, + activation_fn: str = "tanh", + log_std_init: float = -0.5, + use_sde: bool = False, + full_std: bool = True, + squash_output: bool = False, + cnn_flatten_dim: int = 512, + cnn_style: str = "nature", + cnn_layers_init_orthogonal: Optional[bool] = None, + impala_channels: Sequence[int] = (16, 32, 32), + actor_head_style: str = "single", + action_plane_space: 
Optional[Space] = None, + ) -> None: + super().__init__() + + pi_hidden_sizes = ( + pi_hidden_sizes + if pi_hidden_sizes is not None + else default_hidden_sizes(observation_space) + ) + v_hidden_sizes = ( + v_hidden_sizes + if v_hidden_sizes is not None + else default_hidden_sizes(observation_space) + ) + + activation = ACTIVATION[activation_fn] + self._feature_extractor = Encoder( + observation_space, + activation, + init_layers_orthogonal=init_layers_orthogonal, + cnn_flatten_dim=cnn_flatten_dim, + cnn_style=cnn_style, + cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, + impala_channels=impala_channels, + ) + self._pi = actor_head( + action_space, + self._feature_extractor.out_dim, + tuple(pi_hidden_sizes), + init_layers_orthogonal, + activation, + log_std_init=log_std_init, + use_sde=use_sde, + full_std=full_std, + squash_output=squash_output, + actor_head_style=actor_head_style, + action_plane_space=action_plane_space, + ) + + self._v = CriticHead( + in_dim=self._feature_extractor.out_dim, + hidden_sizes=v_hidden_sizes, + activation=activation, + init_layers_orthogonal=init_layers_orthogonal, + ) + + def forward( + self, + obs: torch.Tensor, + action: torch.Tensor, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + return self._distribution_and_value( + obs, action=action, action_masks=action_masks + ) + + def distribution_and_value( + self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None + ) -> ACNForward: + return self._distribution_and_value(obs, action_masks=action_masks) + + def _distribution_and_value( + self, + obs: torch.Tensor, + action: Optional[torch.Tensor] = None, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + encoded = self._feature_extractor(obs) + pi_forward = self._pi(encoded, actions=action, action_masks=action_masks) + v = self._v(encoded) + return ACNForward(pi_forward, v) + + def value(self, obs: torch.Tensor) -> torch.Tensor: + encoded = self._feature_extractor(obs) + return self._v(encoded) + + def reset_noise(self, batch_size: int) -> None: + self._pi.sample_weights(batch_size=batch_size) + + @property + def action_shape(self) -> Tuple[int, ...]: + return self._pi.action_shape diff --git a/rl_algo_impls/shared/policy/actor_critic_network/network.py b/rl_algo_impls/shared/policy/actor_critic_network/network.py new file mode 100644 index 0000000000000000000000000000000000000000..4517fa08229dc03c8aebf9d5d81138b643513508 --- /dev/null +++ b/rl_algo_impls/shared/policy/actor_critic_network/network.py @@ -0,0 +1,57 @@ +from abc import ABC, abstractmethod +from typing import NamedTuple, Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from gym.spaces import Box, Discrete, Space + +from rl_algo_impls.shared.actor import PiForward + + +class ACNForward(NamedTuple): + pi_forward: PiForward + v: torch.Tensor + + +class ActorCriticNetwork(nn.Module, ABC): + @abstractmethod + def forward( + self, + obs: torch.Tensor, + action: torch.Tensor, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + ... + + @abstractmethod + def distribution_and_value( + self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None + ) -> ACNForward: + ... + + @abstractmethod + def value(self, obs: torch.Tensor) -> torch.Tensor: + ... + + @abstractmethod + def reset_noise(self, batch_size: Optional[int] = None) -> None: + ... + + @property + def action_shape(self) -> Tuple[int, ...]: + ... 
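The abstract `ActorCriticNetwork` interface above is what the refactored `ActorCritic` policy delegates to later in this diff (`forward` for training updates, `distribution_and_value` for rollouts). A minimal usage sketch, assuming a concrete `ActorCriticNetwork` instance; the helper function names below are illustrative and not part of the repo:

```
import torch

from rl_algo_impls.shared.policy.actor_critic_network import ActorCriticNetwork


def evaluate_actions(network: ActorCriticNetwork, obs, actions, action_masks=None):
    # forward() returns ACNForward((distribution, logp_a, entropy), value)
    (_, logp_a, entropy), value = network(obs, actions, action_masks=action_masks)
    return logp_a, entropy, value


def sample_step(network: ActorCriticNetwork, obs, action_masks=None):
    # distribution_and_value() takes no action, so logp_a and entropy are None
    with torch.no_grad():
        (pi, _, _), value = network.distribution_and_value(obs, action_masks=action_masks)
        action = pi.sample()
    return action, pi.log_prob(action), value
```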
+ + +def default_hidden_sizes(obs_space: Space) -> Sequence[int]: + if isinstance(obs_space, Box): + if len(obs_space.shape) == 3: # type: ignore + # By default feature extractor to output has no hidden layers + return [] + elif len(obs_space.shape) == 1: # type: ignore + return [64, 64] + else: + raise ValueError(f"Unsupported observation space: {obs_space}") + elif isinstance(obs_space, Discrete): + return [64] + else: + raise ValueError(f"Unsupported observation space: {obs_space}") diff --git a/rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py b/rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py new file mode 100644 index 0000000000000000000000000000000000000000..1c272bcee05863236e5f69c7cb24b3fe71e89971 --- /dev/null +++ b/rl_algo_impls/shared/policy/actor_critic_network/separate_actor_critic.py @@ -0,0 +1,128 @@ +from typing import Optional, Sequence, Tuple + +import torch +import torch.nn as nn +from gym.spaces import Space + +from rl_algo_impls.shared.actor import actor_head +from rl_algo_impls.shared.encoder import Encoder +from rl_algo_impls.shared.policy.actor_critic_network.network import ( + ACNForward, + ActorCriticNetwork, + default_hidden_sizes, +) +from rl_algo_impls.shared.policy.critic import CriticHead +from rl_algo_impls.shared.policy.policy import ACTIVATION + + +class SeparateActorCriticNetwork(ActorCriticNetwork): + def __init__( + self, + observation_space: Space, + action_space: Space, + pi_hidden_sizes: Optional[Sequence[int]] = None, + v_hidden_sizes: Optional[Sequence[int]] = None, + init_layers_orthogonal: bool = True, + activation_fn: str = "tanh", + log_std_init: float = -0.5, + use_sde: bool = False, + full_std: bool = True, + squash_output: bool = False, + cnn_flatten_dim: int = 512, + cnn_style: str = "nature", + cnn_layers_init_orthogonal: Optional[bool] = None, + impala_channels: Sequence[int] = (16, 32, 32), + actor_head_style: str = "single", + action_plane_space: Optional[Space] = None, + ) -> None: + super().__init__() + + pi_hidden_sizes = ( + pi_hidden_sizes + if pi_hidden_sizes is not None + else default_hidden_sizes(observation_space) + ) + v_hidden_sizes = ( + v_hidden_sizes + if v_hidden_sizes is not None + else default_hidden_sizes(observation_space) + ) + + activation = ACTIVATION[activation_fn] + self._feature_extractor = Encoder( + observation_space, + activation, + init_layers_orthogonal=init_layers_orthogonal, + cnn_flatten_dim=cnn_flatten_dim, + cnn_style=cnn_style, + cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, + impala_channels=impala_channels, + ) + self._pi = actor_head( + action_space, + self._feature_extractor.out_dim, + tuple(pi_hidden_sizes), + init_layers_orthogonal, + activation, + log_std_init=log_std_init, + use_sde=use_sde, + full_std=full_std, + squash_output=squash_output, + actor_head_style=actor_head_style, + action_plane_space=action_plane_space, + ) + + v_encoder = Encoder( + observation_space, + activation, + init_layers_orthogonal=init_layers_orthogonal, + cnn_flatten_dim=cnn_flatten_dim, + cnn_style=cnn_style, + cnn_layers_init_orthogonal=cnn_layers_init_orthogonal, + ) + self._v = nn.Sequential( + v_encoder, + CriticHead( + in_dim=v_encoder.out_dim, + hidden_sizes=v_hidden_sizes, + activation=activation, + init_layers_orthogonal=init_layers_orthogonal, + ), + ) + + def forward( + self, + obs: torch.Tensor, + action: torch.Tensor, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + return self._distribution_and_value( + obs, action=action, 
action_masks=action_masks + ) + + def distribution_and_value( + self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None + ) -> ACNForward: + return self._distribution_and_value(obs, action_masks=action_masks) + + def _distribution_and_value( + self, + obs: torch.Tensor, + action: Optional[torch.Tensor] = None, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + pi_forward = self._pi( + self._feature_extractor(obs), actions=action, action_masks=action_masks + ) + v = self._v(obs) + return ACNForward(pi_forward, v) + + def value(self, obs: torch.Tensor) -> torch.Tensor: + return self._v(obs) + + def reset_noise(self, batch_size: int) -> None: + self._pi.sample_weights(batch_size=batch_size) + + @property + def action_shape(self) -> Tuple[int, ...]: + return self._pi.action_shape diff --git a/rl_algo_impls/shared/policy/actor_critic_network/unet.py b/rl_algo_impls/shared/policy/actor_critic_network/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..c15dec8996578c98ffb472822795302621e0308d --- /dev/null +++ b/rl_algo_impls/shared/policy/actor_critic_network/unet.py @@ -0,0 +1,196 @@ +from typing import Optional, Sequence, Tuple, Type + +import numpy as np +import torch +import torch.nn as nn +from gym.spaces import MultiDiscrete, Space + +from rl_algo_impls.shared.actor import pi_forward +from rl_algo_impls.shared.actor.gridnet import GridnetDistribution +from rl_algo_impls.shared.actor.gridnet_decoder import Transpose +from rl_algo_impls.shared.module.utils import layer_init +from rl_algo_impls.shared.policy.actor_critic_network.network import ( + ACNForward, + ActorCriticNetwork, + default_hidden_sizes, +) +from rl_algo_impls.shared.policy.critic import CriticHead +from rl_algo_impls.shared.policy.policy import ACTIVATION + + +class UNetActorCriticNetwork(ActorCriticNetwork): + def __init__( + self, + observation_space: Space, + action_space: Space, + action_plane_space: Space, + v_hidden_sizes: Optional[Sequence[int]] = None, + init_layers_orthogonal: bool = True, + activation_fn: str = "tanh", + cnn_layers_init_orthogonal: Optional[bool] = None, + ) -> None: + if cnn_layers_init_orthogonal is None: + cnn_layers_init_orthogonal = True + super().__init__() + assert isinstance(action_space, MultiDiscrete) + assert isinstance(action_plane_space, MultiDiscrete) + self.range_size = np.max(observation_space.high) - np.min(observation_space.low) # type: ignore + self.map_size = len(action_space.nvec) // len(action_plane_space.nvec) # type: ignore + self.action_vec = action_plane_space.nvec # type: ignore + + activation = ACTIVATION[activation_fn] + + def conv_relu( + in_channels: int, out_channels: int, kernel_size: int = 3, padding: int = 1 + ) -> nn.Module: + return nn.Sequential( + layer_init( + nn.Conv2d( + in_channels, + out_channels, + kernel_size=kernel_size, + padding=padding, + ), + cnn_layers_init_orthogonal, + ), + activation(), + ) + + def up_conv_relu(in_channels: int, out_channels: int) -> nn.Module: + return nn.Sequential( + layer_init( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size=3, + stride=2, + padding=1, + output_padding=1, + ), + cnn_layers_init_orthogonal, + ), + activation(), + ) + + in_channels = observation_space.shape[0] # type: ignore + self.enc1 = conv_relu(in_channels, 32) + self.enc2 = nn.Sequential(max_pool(), conv_relu(32, 64)) + self.enc3 = nn.Sequential(max_pool(), conv_relu(64, 128)) + self.enc4 = nn.Sequential(max_pool(), conv_relu(128, 256)) + self.enc5 = nn.Sequential( + max_pool(), 
conv_relu(256, 512, kernel_size=1, padding=0) + ) + + self.dec4 = up_conv_relu(512, 256) + self.dec3 = nn.Sequential(conv_relu(512, 256), up_conv_relu(256, 128)) + self.dec2 = nn.Sequential(conv_relu(256, 128), up_conv_relu(128, 64)) + self.dec1 = nn.Sequential(conv_relu(128, 64), up_conv_relu(64, 32)) + self.out = nn.Sequential( + conv_relu(64, 32), + layer_init( + nn.Conv2d(32, self.action_vec.sum(), kernel_size=1, padding=0), + cnn_layers_init_orthogonal, + std=0.01, + ), + Transpose((0, 2, 3, 1)), + ) + + with torch.no_grad(): + cnn_out = torch.flatten( + self.enc5( + self.enc4( + self.enc3( + self.enc2( + self.enc1( + self._preprocess( + torch.as_tensor(observation_space.sample()) + ) + ) + ) + ) + ) + ), + start_dim=1, + ) + + v_hidden_sizes = ( + v_hidden_sizes + if v_hidden_sizes is not None + else default_hidden_sizes(observation_space) + ) + self.critic_head = CriticHead( + in_dim=cnn_out.shape[1:], + hidden_sizes=v_hidden_sizes, + activation=activation, + init_layers_orthogonal=init_layers_orthogonal, + ) + + def _preprocess(self, obs: torch.Tensor) -> torch.Tensor: + if len(obs.shape) == 3: + obs = obs.unsqueeze(0) + return obs.float() / self.range_size + + def forward( + self, + obs: torch.Tensor, + action: torch.Tensor, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + return self._distribution_and_value( + obs, action=action, action_masks=action_masks + ) + + def distribution_and_value( + self, obs: torch.Tensor, action_masks: Optional[torch.Tensor] = None + ) -> ACNForward: + return self._distribution_and_value(obs, action_masks=action_masks) + + def _distribution_and_value( + self, + obs: torch.Tensor, + action: Optional[torch.Tensor] = None, + action_masks: Optional[torch.Tensor] = None, + ) -> ACNForward: + assert ( + action_masks is not None + ), f"No mask case unhandled in {self.__class__.__name__}" + + obs = self._preprocess(obs) + e1 = self.enc1(obs) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + e5 = self.enc5(e4) + + v = self.critic_head(e5) + + d4 = self.dec4(e5) + d3 = self.dec3(torch.cat((d4, e4), dim=1)) + d2 = self.dec2(torch.cat((d3, e3), dim=1)) + d1 = self.dec1(torch.cat((d2, e2), dim=1)) + logits = self.out(torch.cat((d1, e1), dim=1)) + + pi = GridnetDistribution(self.map_size, self.action_vec, logits, action_masks) + + return ACNForward(pi_forward(pi, action), v) + + def value(self, obs: torch.Tensor) -> torch.Tensor: + obs = self._preprocess(obs) + e1 = self.enc1(obs) + e2 = self.enc2(e1) + e3 = self.enc3(e2) + e4 = self.enc4(e3) + e5 = self.enc5(e4) + + return self.critic_head(e5) + + def reset_noise(self, batch_size: Optional[int] = None) -> None: + pass + + @property + def action_shape(self) -> Tuple[int, ...]: + return (self.map_size, len(self.action_vec)) + + +def max_pool() -> nn.MaxPool2d: + return nn.MaxPool2d(3, stride=2, padding=1) diff --git a/rl_algo_impls/shared/policy/critic.py b/rl_algo_impls/shared/policy/critic.py index ffb0752eeeab6de71e24a8ea5f716ae8921fa543..4a47dfcb22e915ccf18b3ae747d731dcc71410eb 100644 --- a/rl_algo_impls/shared/policy/critic.py +++ b/rl_algo_impls/shared/policy/critic.py @@ -5,7 +5,7 @@ import torch import torch.nn as nn from rl_algo_impls.shared.encoder import EncoderOutDim -from rl_algo_impls.shared.module.module import mlp +from rl_algo_impls.shared.module.utils import mlp class CriticHead(nn.Module): diff --git a/rl_algo_impls/shared/policy/policy.py b/rl_algo_impls/shared/policy/policy.py index 
d84a9a5ec2f386ea540b9a05e733d5c29138fe4c..f5dbd313785ddd4fbb9149a7b38e4a85fb106673 100644 --- a/rl_algo_impls/shared/policy/policy.py +++ b/rl_algo_impls/shared/policy/policy.py @@ -1,13 +1,13 @@ -import numpy as np import os -import torch -import torch.nn as nn - from abc import ABC, abstractmethod from copy import deepcopy +from typing import Dict, Optional, Type, TypeVar, Union + +import numpy as np +import torch +import torch.nn as nn from stable_baselines3.common.vec_env import unwrap_vec_normalize from stable_baselines3.common.vec_env.vec_normalize import VecNormalize -from typing import Dict, Optional, Type, TypeVar, Union from rl_algo_impls.wrappers.normalize import NormalizeObservation, NormalizeReward from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv, VecEnvObs, find_wrapper @@ -82,6 +82,16 @@ class Policy(nn.Module, ABC): if self.norm_reward: self.norm_reward.load(os.path.join(path, NORMALIZE_REWARD_FILENAME)) + def load_from(self: PolicySelf, policy: PolicySelf) -> PolicySelf: + self.load_state_dict(policy.state_dict()) + if self.norm_observation: + assert policy.norm_observation + self.norm_observation.load_from(policy.norm_observation) + if self.norm_reward: + assert policy.norm_reward + self.norm_reward.load_from(policy.norm_reward) + return self + def reset_noise(self) -> None: pass diff --git a/rl_algo_impls/shared/stats.py b/rl_algo_impls/shared/stats.py index fe020c6d930ae4ed9b1cc98bb32fd72a35f0fd90..189a4229650f27a224172ab20cb9401d8999fd21 100644 --- a/rl_algo_impls/shared/stats.py +++ b/rl_algo_impls/shared/stats.py @@ -21,6 +21,7 @@ StatisticSelf = TypeVar("StatisticSelf", bound="Statistic") class Statistic: values: np.ndarray round_digits: int = 2 + score_function: str = "mean-std" @property def mean(self) -> float: @@ -44,8 +45,18 @@ def __len__(self) -> int: return len(self.values) + def score(self) -> float: + if self.score_function == "mean-std": + return self.mean - self.std + elif self.score_function == "mean": + return self.mean + else: + raise NotImplementedError( + f"Only mean-std and mean score_functions supported ({self.score_function})" + ) + def _diff(self: StatisticSelf, o: StatisticSelf) -> float: - return (self.mean - self.std) - (o.mean - o.std) + return self.score() - o.score() def __gt__(self: StatisticSelf, o: StatisticSelf) -> bool: return self._diff(o) > 0 @@ -55,9 +66,13 @@ def __repr__(self) -> str: mean = round(self.mean, self.round_digits) - std = round(self.std, self.round_digits) if self.round_digits == 0: mean = int(mean) + if self.score_function == "mean": + return f"{mean}" + + std = round(self.std, self.round_digits) + if self.round_digits == 0: std = int(std) return f"{mean} +/- {std}" @@ -74,16 +89,17 @@ EpisodesStatsSelf = TypeVar("EpisodesStatsSelf", bound="EpisodesStats") class EpisodesStats: - episodes: Sequence[Episode] - simple: bool - score: Statistic - length: Statistic - additional_stats: Dict[str, Statistic] - - def __init__(self, episodes: Sequence[Episode], simple: bool = False) -> None: + def __init__( + self, + episodes: Sequence[Episode], + simple: bool = False, + score_function: str = "mean-std", + ) -> None: + self.episodes = episodes self.simple = simple - self.score = Statistic(np.array([e.score for e in episodes])) + self.score = Statistic( + np.array([e.score for e in episodes]), score_function=score_function + ) self.length = Statistic(np.array([e.length for e in episodes]), round_digits=0) additional_values = defaultdict(list) for e in self.episodes: @@ -97,6 +113,7 @@
class EpisodesStats: self.additional_stats = { k: Statistic(np.array(values)) for k, values in additional_values.items() } + self.score_function = score_function def __gt__(self: EpisodesStatsSelf, o: EpisodesStatsSelf) -> bool: return self.score > o.score @@ -105,10 +122,12 @@ class EpisodesStats: return self.score >= o.score def __repr__(self) -> str: - return ( - f"Score: {self.score} ({round(self.score.mean - self.score.std, 2)}) | " - f"Length: {self.length}" - ) + mean = self.score.mean + score = self.score.score() + if mean != score: + return f"Score: {self.score} ({round(score)}) | Length: {self.length}" + else: + return f"Score: {self.score} | Length: {self.length}" def __len__(self) -> int: return len(self.episodes) @@ -129,7 +148,7 @@ class EpisodesStats: { "min": self.score.min, "max": self.score.max, - "result": self.score.mean - self.score.std, + "result": self.score.score(), "n_episodes": len(self.episodes), "length": self.length.mean, } diff --git a/rl_algo_impls/shared/vec_env/make_env.py b/rl_algo_impls/shared/vec_env/make_env.py index f14baf2db982bf7e689313ed55d5d6d88a8eabfb..3d67e7fc5d5574581b4fd4bd062e57b86ab666ce 100644 --- a/rl_algo_impls/shared/vec_env/make_env.py +++ b/rl_algo_impls/shared/vec_env/make_env.py @@ -1,5 +1,5 @@ from dataclasses import asdict -from typing import Optional +from typing import Any, Dict, Optional from torch.utils.tensorboard.writer import SummaryWriter @@ -52,15 +52,21 @@ def make_env( def make_eval_env( config: Config, hparams: EnvHyperparams, - override_n_envs: Optional[int] = None, + override_hparams: Optional[Dict[str, Any]] = None, **kwargs, ) -> VecEnv: kwargs = kwargs.copy() kwargs["training"] = False - if override_n_envs is not None: + env_overrides = config.eval_hyperparams.get("env_overrides") + if env_overrides: hparams_kwargs = asdict(hparams) - hparams_kwargs["n_envs"] = override_n_envs - if override_n_envs == 1: - hparams_kwargs["vec_env_class"] = "sync" + hparams_kwargs.update(env_overrides) hparams = EnvHyperparams(**hparams_kwargs) - return make_env(config, hparams, **kwargs) \ No newline at end of file + if override_hparams: + hparams_kwargs = asdict(hparams) + for k, v in override_hparams.items(): + hparams_kwargs[k] = v + if k == "n_envs" and v == 1: + hparams_kwargs["vec_env_class"] = "sync" + hparams = EnvHyperparams(**hparams_kwargs) + return make_env(config, hparams, **kwargs) diff --git a/rl_algo_impls/shared/vec_env/microrts.py b/rl_algo_impls/shared/vec_env/microrts.py index 8d43dd0e0c299c58074679d671ccd3a15b08d8dd..65a1d5937b3d879ba7ddfbbaaad51374da510e3e 100644 --- a/rl_algo_impls/shared/vec_env/microrts.py +++ b/rl_algo_impls/shared/vec_env/microrts.py @@ -11,6 +11,7 @@ from rl_algo_impls.wrappers.episode_stats_writer import EpisodeStatsWriter from rl_algo_impls.wrappers.hwc_to_chw_observation import HwcToChwObservation from rl_algo_impls.wrappers.is_vector_env import IsVectorEnv from rl_algo_impls.wrappers.microrts_stats_recorder import MicrortsStatsRecorder +from rl_algo_impls.wrappers.self_play_wrapper import SelfPlayWrapper from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv @@ -26,6 +27,7 @@ def make_microrts_env( from gym_microrts import microrts_ai from rl_algo_impls.shared.vec_env.microrts_compat import ( + MicroRTSGridModeSharedMemVecEnvCompat, MicroRTSGridModeVecEnvCompat, ) @@ -47,16 +49,36 @@ def make_microrts_env( _, # normalize_type _, # mask_actions bots, + self_play_kwargs, + selfplay_bots, ) = astuple(hparams) seed = config.seed(training=training) make_kwargs = make_kwargs or {} + 
self_play_kwargs = self_play_kwargs or {} if "num_selfplay_envs" not in make_kwargs: make_kwargs["num_selfplay_envs"] = 0 if "num_bot_envs" not in make_kwargs: - make_kwargs["num_bot_envs"] = n_envs - make_kwargs["num_selfplay_envs"] + num_selfplay_envs = make_kwargs["num_selfplay_envs"] + if num_selfplay_envs: + num_bot_envs = ( + n_envs + - make_kwargs["num_selfplay_envs"] + + self_play_kwargs.get("num_old_policies", 0) + + (len(selfplay_bots) if selfplay_bots else 0) + ) + else: + num_bot_envs = n_envs + make_kwargs["num_bot_envs"] = num_bot_envs if "reward_weight" in make_kwargs: + # Reward Weights: + # WinLossRewardFunction + # ResourceGatherRewardFunction + # ProduceWorkerRewardFunction + # ProduceBuildingRewardFunction + # AttackRewardFunction + # ProduceCombatUnitRewardFunction make_kwargs["reward_weight"] = np.array(make_kwargs["reward_weight"]) if bots: ai2s = [] @@ -68,19 +90,28 @@ def make_microrts_env( assert ai, f"{ai_name} not in microrts_ai" ai2s.append(ai) else: - ai2s = [microrts_ai.randomAI for _ in make_kwargs["num_bot_envs"]] + ai2s = [microrts_ai.randomAI for _ in range(make_kwargs["num_bot_envs"])] make_kwargs["ai2s"] = ai2s - envs = MicroRTSGridModeVecEnvCompat(**make_kwargs) + if len(make_kwargs.get("map_paths", [])) < 2: + EnvClass = MicroRTSGridModeSharedMemVecEnvCompat + else: + EnvClass = MicroRTSGridModeVecEnvCompat + envs = EnvClass(**make_kwargs) envs = HwcToChwObservation(envs) envs = IsVectorEnv(envs) envs = MicrortsMaskWrapper(envs) + if self_play_kwargs: + if selfplay_bots: + self_play_kwargs["selfplay_bots"] = selfplay_bots + envs = SelfPlayWrapper(envs, config, **self_play_kwargs) + if seed is not None: envs.action_space.seed(seed) envs.observation_space.seed(seed) envs = gym.wrappers.RecordEpisodeStatistics(envs) - envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99)) + envs = MicrortsStatsRecorder(envs, config.algo_hyperparams.get("gamma", 0.99), bots) if training: assert tb_writer envs = EpisodeStatsWriter( diff --git a/rl_algo_impls/shared/vec_env/microrts_compat.py b/rl_algo_impls/shared/vec_env/microrts_compat.py index c49908771dda12ec666615b8bec62312dc5b0580..b8c7abe7c302f7f042ac02f0a7b286e05268cde0 100644 --- a/rl_algo_impls/shared/vec_env/microrts_compat.py +++ b/rl_algo_impls/shared/vec_env/microrts_compat.py @@ -1,10 +1,9 @@ from typing import TypeVar -import numpy as np -from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv -from jpype.types import JArray, JInt - -from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvStepReturn +from gym_microrts.envs.vec_env import ( + MicroRTSGridModeSharedMemVecEnv, + MicroRTSGridModeVecEnv, +) MicroRTSGridModeVecEnvCompatSelf = TypeVar( "MicroRTSGridModeVecEnvCompatSelf", bound="MicroRTSGridModeVecEnvCompat" @@ -12,38 +11,22 @@ MicroRTSGridModeVecEnvCompatSelf = TypeVar( class MicroRTSGridModeVecEnvCompat(MicroRTSGridModeVecEnv): - def step(self, action: np.ndarray) -> VecEnvStepReturn: - indexed_actions = np.concatenate( - [ - np.expand_dims( - np.stack( - [np.arange(0, action.shape[1]) for i in range(self.num_envs)] - ), - axis=2, - ), - action, - ], - axis=2, - ) - action_mask = np.array(self.vec_client.getMasks(0), dtype=np.bool8).reshape( - indexed_actions.shape[:-1] + (-1,) - ) - valid_action_mask = action_mask[:, :, 0] - valid_actions_counts = valid_action_mask.sum(1) - valid_actions = indexed_actions[valid_action_mask] - valid_actions_idx = 0 - - all_valid_actions = [] - for env_act_cnt in valid_actions_counts: - env_valid_actions = [] - for _ in 
range(env_act_cnt): - env_valid_actions.append(JArray(JInt)(valid_actions[valid_actions_idx])) - valid_actions_idx += 1 - all_valid_actions.append(JArray(JArray(JInt))(env_valid_actions)) - return super().step(JArray(JArray(JArray(JInt)))(all_valid_actions)) # type: ignore - @property def unwrapped( self: MicroRTSGridModeVecEnvCompatSelf, ) -> MicroRTSGridModeVecEnvCompatSelf: return self + + +MicroRTSGridModeSharedMemVecEnvCompatSelf = TypeVar( + "MicroRTSGridModeSharedMemVecEnvCompatSelf", + bound="MicroRTSGridModeSharedMemVecEnvCompat", +) + + +class MicroRTSGridModeSharedMemVecEnvCompat(MicroRTSGridModeSharedMemVecEnv): + @property + def unwrapped( + self: MicroRTSGridModeSharedMemVecEnvCompatSelf, + ) -> MicroRTSGridModeSharedMemVecEnvCompatSelf: + return self diff --git a/rl_algo_impls/shared/vec_env/procgen.py b/rl_algo_impls/shared/vec_env/procgen.py index d339799b75f8ec8724e3ef6b345a7e290d731c49..023844df643f271ebfeb842fb29179aa1e62f7f3 100644 --- a/rl_algo_impls/shared/vec_env/procgen.py +++ b/rl_algo_impls/shared/vec_env/procgen.py @@ -41,6 +41,8 @@ def make_procgen_env( _, # normalize_type _, # mask_actions _, # bots + _, # self_play_kwargs + _, # selfplay_bots ) = astuple(hparams) seed = config.seed(training=training) diff --git a/rl_algo_impls/shared/vec_env/vec_env.py b/rl_algo_impls/shared/vec_env/vec_env.py index 68079fc08cb3b45afd2de4710e3daa506d911c7b..f1b57bd0c7360908e6934a063957cb82abc2c61a 100644 --- a/rl_algo_impls/shared/vec_env/vec_env.py +++ b/rl_algo_impls/shared/vec_env/vec_env.py @@ -73,6 +73,8 @@ def make_vec_env( normalize_type, mask_actions, _, # bots + _, # self_play_kwargs + _, # selfplay_bots ) = astuple(hparams) import_for_env_id(config.env_id) diff --git a/rl_algo_impls/vpg/policy.py b/rl_algo_impls/vpg/policy.py index 65d29e7d0bc4cb90866cb4c293d4092e313c765a..6827921922fb8851cd6441d05e0c399e0f749b7c 100644 --- a/rl_algo_impls/vpg/policy.py +++ b/rl_algo_impls/vpg/policy.py @@ -6,13 +6,9 @@ import torch.nn as nn from rl_algo_impls.shared.actor import Actor, PiForward, actor_head from rl_algo_impls.shared.encoder import Encoder +from rl_algo_impls.shared.policy.actor_critic import OnPolicy, Step, clamp_actions +from rl_algo_impls.shared.policy.actor_critic_network import default_hidden_sizes from rl_algo_impls.shared.policy.critic import CriticHead -from rl_algo_impls.shared.policy.on_policy import ( - OnPolicy, - Step, - clamp_actions, - default_hidden_sizes, -) from rl_algo_impls.shared.policy.policy import ACTIVATION from rl_algo_impls.wrappers.vectorable_wrapper import ( VecEnv, diff --git a/rl_algo_impls/vpg/vpg.py b/rl_algo_impls/vpg/vpg.py index 9a61a860a4a4be66ed2747bb799fa36903832bae..27fe601d6dfc10eb4a6677126068cc61898acbde 100644 --- a/rl_algo_impls/vpg/vpg.py +++ b/rl_algo_impls/vpg/vpg.py @@ -1,15 +1,16 @@ +import logging +from collections import defaultdict +from dataclasses import asdict, dataclass +from typing import List, Optional, Sequence, TypeVar + import numpy as np import torch import torch.nn as nn - -from collections import defaultdict -from dataclasses import dataclass, asdict from torch.optim import Adam from torch.utils.tensorboard.writer import SummaryWriter -from typing import Optional, Sequence, TypeVar from rl_algo_impls.shared.algorithm import Algorithm -from rl_algo_impls.shared.callbacks.callback import Callback +from rl_algo_impls.shared.callbacks import Callback from rl_algo_impls.shared.gae import compute_rtg_and_advantage_from_trajectories from rl_algo_impls.shared.trajectory import Trajectory, TrajectoryAccumulator 
from rl_algo_impls.vpg.policy import VPGActorCritic @@ -78,7 +79,7 @@ class VanillaPolicyGradient(Algorithm): def learn( self: VanillaPolicyGradientSelf, total_timesteps: int, - callback: Optional[Callback] = None, + callbacks: Optional[List[Callback]] = None, ) -> VanillaPolicyGradientSelf: timesteps_elapsed = 0 epoch_cnt = 0 @@ -104,8 +105,12 @@ class VanillaPolicyGradient(Algorithm): ] ) ) - if callback: - callback.on_step(timesteps_elapsed=epoch_steps) + if callbacks: + if not all(c.on_step(timesteps_elapsed=epoch_steps) for c in callbacks): + logging.info( + f"Callback terminated training at {timesteps_elapsed} timesteps" + ) + break return self def train(self, trajectories: Sequence[Trajectory]) -> TrainEpochStats: diff --git a/rl_algo_impls/wrappers/action_mask_wrapper.py b/rl_algo_impls/wrappers/action_mask_wrapper.py index fda16444611a9cc6f348638bd193ce36bddb7db0..d2b73649660018a31dd22c97d09e6ae9f74d4735 100644 --- a/rl_algo_impls/wrappers/action_mask_wrapper.py +++ b/rl_algo_impls/wrappers/action_mask_wrapper.py @@ -1,6 +1,7 @@ from typing import Optional, Union import numpy as np +from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv from rl_algo_impls.wrappers.vectorable_wrapper import ( VecEnv, @@ -14,24 +15,19 @@ class IncompleteArrayError(Exception): class SingleActionMaskWrapper(VecotarableWrapper): - def action_masks(self) -> Optional[np.ndarray]: - envs = getattr(self.env.unwrapped, "envs") + def get_action_mask(self) -> Optional[np.ndarray]: + envs = getattr(self.env.unwrapped, "envs", None) # type: ignore assert ( envs ), f"{self.__class__.__name__} expects to wrap synchronous vectorized env" - masks = [getattr(e.unwrapped, "action_mask") for e in envs] + masks = [getattr(e.unwrapped, "action_mask", None) for e in envs] assert all(m is not None for m in masks) - return np.array(masks, dtype=np.bool8) + return np.array(masks, dtype=np.bool_) class MicrortsMaskWrapper(VecotarableWrapper): - def action_masks(self) -> np.ndarray: - microrts_env = self.env.unwrapped # type: ignore - vec_client = getattr(microrts_env, "vec_client") - assert ( - vec_client - ), f"{microrts_env.__class__.__name__} must have vec_client property (as MicroRTSVecEnv does)" - return np.array(vec_client.getMasks(0), dtype=np.bool8) + def get_action_mask(self) -> np.ndarray: + return self.env.get_action_mask().astype(bool) # type: ignore def find_action_masker( diff --git a/rl_algo_impls/wrappers/microrts_stats_recorder.py b/rl_algo_impls/wrappers/microrts_stats_recorder.py index 7e90a845eaf09420f05e82551e55267eac08792d..e680172ab7f6adde193a92808de2e83cf0c19797 100644 --- a/rl_algo_impls/wrappers/microrts_stats_recorder.py +++ b/rl_algo_impls/wrappers/microrts_stats_recorder.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import numpy as np @@ -10,10 +10,19 @@ from rl_algo_impls.wrappers.vectorable_wrapper import ( class MicrortsStatsRecorder(VecotarableWrapper): - def __init__(self, env, gamma: float) -> None: + def __init__( + self, env, gamma: float, bots: Optional[Dict[str, int]] = None + ) -> None: super().__init__(env) self.gamma = gamma self.raw_rewards = [[] for _ in range(self.num_envs)] + self.bots = bots + if self.bots: + self.bot_at_index = [None] * (env.num_envs - sum(self.bots.values())) + for b, n in self.bots.items(): + self.bot_at_index.extend([b] * n) + else: + self.bot_at_index = [None] * env.num_envs def reset(self) -> VecEnvObs: obs = super().reset() @@ -33,4 +42,19 @@ class MicrortsStatsRecorder(VecotarableWrapper): 
diff --git a/rl_algo_impls/wrappers/microrts_stats_recorder.py b/rl_algo_impls/wrappers/microrts_stats_recorder.py
index 7e90a845eaf09420f05e82551e55267eac08792d..e680172ab7f6adde193a92808de2e83cf0c19797 100644
--- a/rl_algo_impls/wrappers/microrts_stats_recorder.py
+++ b/rl_algo_impls/wrappers/microrts_stats_recorder.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 import numpy as np
 
@@ -10,10 +10,19 @@ from rl_algo_impls.wrappers.vectorable_wrapper import (
 
 
 class MicrortsStatsRecorder(VecotarableWrapper):
-    def __init__(self, env, gamma: float) -> None:
+    def __init__(
+        self, env, gamma: float, bots: Optional[Dict[str, int]] = None
+    ) -> None:
         super().__init__(env)
         self.gamma = gamma
         self.raw_rewards = [[] for _ in range(self.num_envs)]
+        self.bots = bots
+        if self.bots:
+            self.bot_at_index = [None] * (env.num_envs - sum(self.bots.values()))
+            for b, n in self.bots.items():
+                self.bot_at_index.extend([b] * n)
+        else:
+            self.bot_at_index = [None] * env.num_envs
 
     def reset(self) -> VecEnvObs:
         obs = super().reset()
@@ -33,4 +42,19 @@
                 raw_rewards = np.array(self.raw_rewards[idx]).sum(0)
                 raw_names = [str(rf) for rf in self.env.unwrapped.rfs]
                 info["microrts_stats"] = dict(zip(raw_names, raw_rewards))
+
+                winloss = raw_rewards[raw_names.index("WinLossRewardFunction")]
+                microrts_results = {
+                    "win": int(winloss == 1),
+                    "draw": int(winloss == 0),
+                    "loss": int(winloss == -1),
+                }
+                bot = self.bot_at_index[idx]
+                if bot:
+                    microrts_results.update(
+                        {f"{k}_{bot}": v for k, v in microrts_results.items()}
+                    )
+
+                info["microrts_results"] = microrts_results
+
                 self.raw_rewards[idx] = []
diff --git a/rl_algo_impls/wrappers/normalize.py b/rl_algo_impls/wrappers/normalize.py
index e48288f450b0ec284b405261ddbd22d8ff3bbe10..2186280aff07eac80b4109b920cbe31f90665a6a 100644
--- a/rl_algo_impls/wrappers/normalize.py
+++ b/rl_algo_impls/wrappers/normalize.py
@@ -1,14 +1,16 @@
+from typing import Tuple, TypeVar
+
 import gym
 import numpy as np
-
 from numpy.typing import NDArray
-from typing import Tuple
 
 from rl_algo_impls.wrappers.vectorable_wrapper import (
     VecotarableWrapper,
     single_observation_space,
 )
 
+RunningMeanStdSelf = TypeVar("RunningMeanStdSelf", bound="RunningMeanStd")
+
 
 class RunningMeanStd:
     def __init__(self, episilon: float = 1e-4, shape: Tuple[int, ...] = ()) -> None:
@@ -32,6 +34,30 @@ class RunningMeanStd:
         self.var = M2 / total_count
         self.count = total_count
 
+    def save(self, path: str) -> None:
+        np.savez_compressed(
+            path,
+            mean=self.mean,
+            var=self.var,
+            count=self.count,
+        )
+
+    def load(self, path: str) -> None:
+        data = np.load(path)
+        self.mean = data["mean"]
+        self.var = data["var"]
+        self.count = data["count"]
+
+    def load_from(self: RunningMeanStdSelf, existing: RunningMeanStdSelf) -> None:
+        self.mean = np.copy(existing.mean)
+        self.var = np.copy(existing.var)
+        self.count = np.copy(existing.count)
+
+
+NormalizeObservationSelf = TypeVar(
+    "NormalizeObservationSelf", bound="NormalizeObservation"
+)
+
 
 class NormalizeObservation(VecotarableWrapper):
     def __init__(
@@ -67,18 +93,18 @@ class NormalizeObservation(VecotarableWrapper):
         return normalized[0] if not self.is_vector_env else normalized
 
     def save(self, path: str) -> None:
-        np.savez_compressed(
-            path,
-            mean=self.rms.mean,
-            var=self.rms.var,
-            count=self.rms.count,
-        )
+        self.rms.save(path)
 
     def load(self, path: str) -> None:
-        data = np.load(path)
-        self.rms.mean = data["mean"]
-        self.rms.var = data["var"]
-        self.rms.count = data["count"]
+        self.rms.load(path)
+
+    def load_from(
+        self: NormalizeObservationSelf, existing: NormalizeObservationSelf
+    ) -> None:
+        self.rms.load_from(existing.rms)
+
+
+NormalizeRewardSelf = TypeVar("NormalizeRewardSelf", bound="NormalizeReward")
 
 
 class NormalizeReward(VecotarableWrapper):
@@ -126,15 +152,10 @@ class NormalizeReward(VecotarableWrapper):
         )
 
     def save(self, path: str) -> None:
-        np.savez_compressed(
-            path,
-            mean=self.rms.mean,
-            var=self.rms.var,
-            count=self.rms.count,
-        )
+        self.rms.save(path)
 
     def load(self, path: str) -> None:
-        data = np.load(path)
-        self.rms.mean = data["mean"]
-        self.rms.var = data["var"]
-        self.rms.count = data["count"]
+        self.rms.load(path)
+
+    def load_from(self: NormalizeRewardSelf, existing: NormalizeRewardSelf) -> None:
+        self.rms.load_from(existing.rms)
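Note: a usage sketch for the `RunningMeanStd` persistence helpers added above, relying only on the attributes shown in this diff (`mean`, `var`, `count`); the shape and file name are illustrative:

```
from rl_algo_impls.wrappers.normalize import RunningMeanStd

rms = RunningMeanStd(shape=(8,))       # e.g. statistics over an 8-dim observation
rms.save("obs_rms.npz")                # np.savez_compressed of mean/var/count

restored = RunningMeanStd(shape=(8,))
restored.load("obs_rms.npz")           # reads mean/var/count back from the .npz

clone = RunningMeanStd(shape=(8,))
clone.load_from(restored)              # copies statistics from another instance
```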
diff --git a/rl_algo_impls/wrappers/self_play_wrapper.py b/rl_algo_impls/wrappers/self_play_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f43ea8fdb3793c5841609eefa5356cfb2c1372d4
--- /dev/null
+++ b/rl_algo_impls/wrappers/self_play_wrapper.py
@@ -0,0 +1,144 @@
+import copy
+import random
+from collections import deque
+from typing import Any, Deque, Dict, List, Optional
+
+import numpy as np
+
+from rl_algo_impls.runner.config import Config
+from rl_algo_impls.shared.policy.policy import Policy
+from rl_algo_impls.wrappers.action_mask_wrapper import find_action_masker
+from rl_algo_impls.wrappers.vectorable_wrapper import (
+    VecEnvObs,
+    VecEnvStepReturn,
+    VecotarableWrapper,
+)
+
+
+class SelfPlayWrapper(VecotarableWrapper):
+    next_obs: VecEnvObs
+    next_action_masks: Optional[np.ndarray]
+
+    def __init__(
+        self,
+        env,
+        config: Config,
+        num_old_policies: int = 0,
+        save_steps: int = 20_000,
+        swap_steps: int = 10_000,
+        window: int = 10,
+        swap_window_size: int = 2,
+        selfplay_bots: Optional[Dict[str, Any]] = None,
+        bot_always_player_2: bool = False,
+    ) -> None:
+        super().__init__(env)
+        assert num_old_policies % 2 == 0, f"num_old_policies must be even"
+        assert (
+            num_old_policies % swap_window_size == 0
+        ), f"num_old_policies must be a multiple of swap_window_size"
+
+        self.config = config
+        self.num_old_policies = num_old_policies
+        self.save_steps = save_steps
+        self.swap_steps = swap_steps
+        self.swap_window_size = swap_window_size
+        self.selfplay_bots = selfplay_bots
+        self.bot_always_player_2 = bot_always_player_2
+
+        self.policies: Deque[Policy] = deque(maxlen=window)
+        self.policy_assignments: List[Optional[Policy]] = [None] * env.num_envs
+        self.steps_since_swap = np.zeros(env.num_envs)
+
+        self.selfplay_policies: Dict[str, Policy] = {}
+
+        self.num_envs = env.num_envs - num_old_policies
+
+        if self.selfplay_bots:
+            self.num_envs -= sum(self.selfplay_bots.values())
+            self.initialize_selfplay_bots()
+
+    def get_action_mask(self) -> Optional[np.ndarray]:
+        return self.env.get_action_mask()[self.learner_indexes()]
+
+    def learner_indexes(self) -> List[int]:
+        return [p is None for p in self.policy_assignments]
+
+    def checkpoint_policy(self, copied_policy: Policy) -> None:
+        copied_policy.train(False)
+        self.policies.append(copied_policy)
+
+        if all(p is None for p in self.policy_assignments[: 2 * self.num_old_policies]):
+            for i in range(self.num_old_policies):
+                # Switch between player 1 and 2
+                self.policy_assignments[
+                    2 * i + (i % 2 if not self.bot_always_player_2 else 1)
+                ] = copied_policy
+
+    def swap_policy(self, idx: int, swap_window_size: int = 1) -> None:
+        policy = random.choice(self.policies)
+        idx = idx // 2 * 2
+        for j in range(swap_window_size * 2):
+            if self.policy_assignments[idx + j]:
+                self.policy_assignments[idx + j] = policy
+        self.steps_since_swap[idx : idx + swap_window_size * 2] = np.zeros(
+            swap_window_size * 2
+        )
+
+    def initialize_selfplay_bots(self) -> None:
+        if not self.selfplay_bots:
+            return
+        from rl_algo_impls.runner.running_utils import get_device, make_policy
+
+        env = self.env  # type: ignore
+        device = get_device(self.config, env)
+        start_idx = 2 * self.num_old_policies
+        for model_path, n in self.selfplay_bots.items():
+            policy = make_policy(
+                self.config.algo,
+                env,
+                device,
+                load_path=model_path,
+                **self.config.policy_hyperparams,
+            ).eval()
+            self.selfplay_policies[model_path] = policy
+            for idx in range(start_idx, start_idx + 2 * n, 2):
+                bot_idx = (
+                    (idx + 1) if self.bot_always_player_2 else (idx + idx // 2 % 2)
+                )
+                self.policy_assignments[bot_idx] = policy
+            start_idx += 2 * n
+
+    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
+        env = self.env  # type: ignore
+        all_actions = np.zeros((env.num_envs,) + actions.shape[1:], dtype=actions.dtype)
+        orig_learner_indexes = self.learner_indexes()
+
+        all_actions[orig_learner_indexes] = actions
+        for policy in set(p for p in self.policy_assignments if p):
+            policy_indexes = [policy == p for p in self.policy_assignments]
+            if any(policy_indexes):
+                all_actions[policy_indexes] = policy.act(
+                    self.next_obs[policy_indexes],
+                    deterministic=False,
+                    action_masks=self.next_action_masks[policy_indexes]
+                    if self.next_action_masks is not None
+                    else None,
+                )
+        self.next_obs, rew, done, info = env.step(all_actions)
+        self.next_action_masks = self.env.get_action_mask()
+
+        rew = rew[orig_learner_indexes]
+        info = [i for i, b in zip(info, orig_learner_indexes) if b]
+
+        self.steps_since_swap += 1
+        for idx in range(0, self.num_old_policies * 2, 2 * self.swap_window_size):
+            if self.steps_since_swap[idx] > self.swap_steps:
+                self.swap_policy(idx, self.swap_window_size)
+
+        new_learner_indexes = self.learner_indexes()
+        return self.next_obs[new_learner_indexes], rew, done[new_learner_indexes], info
+
+    def reset(self) -> VecEnvObs:
+        self.next_obs = super().reset()
+        self.next_action_masks = self.env.get_action_mask()
+        return self.next_obs[self.learner_indexes()]
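Note: in `checkpoint_policy` above, frozen opponents are placed so they alternate between player 1 and player 2 of each paired environment (unless `bot_always_player_2` is set), and the learner keeps the complementary slots reported by `learner_indexes()`. A small sketch of that slot arithmetic; the `num_old_policies` value is chosen purely for illustration:

```
# Slots taken by the frozen policy for num_old_policies = 4 and
# bot_always_player_2 = False.
num_old_policies = 4
frozen_slots = [2 * i + (i % 2) for i in range(num_old_policies)]
print(frozen_slots)  # [0, 3, 4, 7] -> the learner keeps slots 1, 2, 5 and 6
```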
diff --git a/rl_algo_impls/wrappers/vec_episode_recorder.py b/rl_algo_impls/wrappers/vec_episode_recorder.py
index d86907ab71c3c9930a04fc9c1c6d51fbc083cf54..89c3d5713be22fe43f12cf1ac4043670b92fcb8a 100644
--- a/rl_algo_impls/wrappers/vec_episode_recorder.py
+++ b/rl_algo_impls/wrappers/vec_episode_recorder.py
@@ -1,21 +1,24 @@
 import numpy as np
-
 from gym.wrappers.monitoring.video_recorder import VideoRecorder
 
 from rl_algo_impls.wrappers.vectorable_wrapper import (
-    VecotarableWrapper,
     VecEnvObs,
     VecEnvStepReturn,
+    VecotarableWrapper,
 )
 
 
 class VecEpisodeRecorder(VecotarableWrapper):
-    def __init__(self, env, base_path: str, max_video_length: int = 3600):
+    def __init__(
+        self, env, base_path: str, max_video_length: int = 3600, num_episodes: int = 1
+    ):
         super().__init__(env)
         self.base_path = base_path
         self.max_video_length = max_video_length
+        self.num_episodes = num_episodes
         self.video_recorder = None
         self.recorded_frames = 0
+        self.num_completed = 0
 
     def step(self, actions: np.ndarray) -> VecEnvStepReturn:
         obs, rew, dones, infos = self.env.step(actions)
@@ -23,13 +26,21 @@ class VecEpisodeRecorder(VecotarableWrapper):
         if self.video_recorder:
             self.video_recorder.capture_frame()
             self.recorded_frames += 1
+            if dones[0]:
+                self.num_completed += 1
             if dones[0] and infos[0].get("episode"):
                 episode_info = {
                     k: v.item() if hasattr(v, "item") else v
                     for k, v in infos[0]["episode"].items()
                 }
-                self.video_recorder.metadata["episode"] = episode_info
-            if dones[0] or self.recorded_frames > self.max_video_length:
+
+                if "episodes" not in self.video_recorder.metadata:
+                    self.video_recorder.metadata["episodes"] = []
+                self.video_recorder.metadata["episodes"].append(episode_info)
+            if (
+                self.num_completed == self.num_episodes
+                or self.recorded_frames > self.max_video_length
+            ):
                 self._close_video_recorder()
         return obs, rew, dones, infos
 
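Note: with `num_episodes`, the recorder above keeps capturing frames until that many episodes of env index 0 have finished (or `max_video_length` frames are hit), appending each episode's stats to the video metadata under `"episodes"`. A hedged usage sketch; the wrapped env and output path are placeholders:

```
from rl_algo_impls.wrappers.vec_episode_recorder import VecEpisodeRecorder

# Given an already-constructed vectorized env `venv` (not created here):
# recording_env = VecEpisodeRecorder(venv, "videos/eval", num_episodes=3)
# The video is finalized after 3 completed episodes in env index 0, or once
# max_video_length frames (default 3600) have been captured.
```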
diff --git a/selfplay_enjoy.py b/selfplay_enjoy.py
new file mode 100644
index 0000000000000000000000000000000000000000..a681b272dd53e31b2fd2cb811f1480a423e3064f
--- /dev/null
+++ b/selfplay_enjoy.py
@@ -0,0 +1,4 @@
+from rl_algo_impls.selfplay_enjoy import selfplay_enjoy
+
+if __name__ == "__main__":
+    selfplay_enjoy()
diff --git a/tests/shared/policy/test_on_policy.py b/tests/shared/policy/test_actor_critic.py
similarity index 89%
rename from tests/shared/policy/test_on_policy.py
rename to tests/shared/policy/test_actor_critic.py
index b54300e1378b59b4ade73046cba5f97063e9561e..cdb4bdcae4054ab0daa2047d5f4831dd5a2c9b5a 100644
--- a/tests/shared/policy/test_on_policy.py
+++ b/tests/shared/policy/test_actor_critic.py
@@ -3,7 +3,7 @@ import pytest
 import gym.spaces
 import numpy as np
 
-from rl_algo_impls.shared.policy.on_policy import clamp_actions
+from rl_algo_impls.shared.policy.actor_critic import clamp_actions
 
 
 def test_clamp_actions():