sgoodfriend committed on
Commit
df508c1
1 Parent(s): 341188c

PPO playing Acrobot-v1 from https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c

README.md CHANGED
@@ -42,7 +42,7 @@ By default training goes to a rl-algo-impls project while benchmarks go to
 rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
 models and the model weights are uploaded to WandB.
 
-Before doing any of the runs below, you'll need to create a wandb account and run `wandb
+Before doing anything below, you'll need to create a wandb account and run `wandb
 login`.
 
 
@@ -50,7 +50,7 @@ login`.
 ## Usage
 /sgoodfriend/rl-algo-impls: https://github.com/sgoodfriend/rl-algo-impls
 
-Note: While the model state dictionary and hyperaparameters are saved, the
+Note: While the model state dictionary and hyperaparameters are saved, the latest
 implementation could be sufficiently different to not be able to reproduce similar
 results. You might need to checkout the commit the agent was trained on:
 [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
@@ -68,7 +68,8 @@ notebook.
 
 ## Training
 If you want the highest chance to reproduce these results, you'll want to checkout the
-commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
+commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). While
+training is deterministic, different hardware will give different results.
 
 ```
 python train.py --algo ppo --env Acrobot-v1 --seed 4
benchmark_publish.py ADDED
@@ -0,0 +1,91 @@
+import argparse
+import subprocess
+import wandb
+import wandb.apis.public
+
+from collections import defaultdict
+from multiprocessing.pool import ThreadPool
+from typing import List, NamedTuple
+
+
+class RunGroup(NamedTuple):
+    algo: str
+    env_id: str
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wandb-project-name",
+        type=str,
+        default="rl-algo-impls-benchmarks",
+        help="WandB project name to load runs from",
+    )
+    parser.add_argument(
+        "--wandb-entity",
+        type=str,
+        default=None,
+        help="WandB team of project. None uses default entity",
+    )
+    parser.add_argument("--wandb-tags", type=str, nargs="+", help="WandB tags")
+    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
+    parser.add_argument(
+        "--envs", type=str, nargs="*", help="Optional filter down to these envs"
+    )
+    parser.add_argument(
+        "--huggingface-user",
+        type=str,
+        default=None,
+        help="Huggingface user or team to upload model cards. Defaults to huggingface-cli login user",
+    )
+    parser.add_argument(
+        "--pool-size",
+        type=int,
+        default=3,
+        help="How many publish jobs can run in parallel",
+    )
+    parser.set_defaults(
+        wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
+        wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
+        envs=["CartPole-v1", "Acrobot-v1"],
+    )
+    args = parser.parse_args()
+    print(args)
+
+    api = wandb.Api()
+    all_runs = api.runs(
+        f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}"
+    )
+
+    required_tags = set(args.wandb_tags)
+    runs: List[wandb.apis.public.Run] = [
+        r
+        for r in all_runs
+        if required_tags.issubset(set(r.config.get("wandb_tags", [])))
+    ]
+
+    runs_paths_by_group = defaultdict(list)
+    for r in runs:
+        algo = r.config["algo"]
+        env = r.config["env"]
+        if args.envs and env not in args.envs:
+            continue
+        run_group = RunGroup(algo, env)
+        runs_paths_by_group[run_group].append("/".join(r.path))
+
+    def run(run_paths: List[str]) -> None:
+        publish_args = ["python", "huggingface_publish.py"]
+        publish_args.append("--wandb-run-paths")
+        publish_args.extend(run_paths)
+        publish_args.append("--wandb-report-url")
+        publish_args.append(args.wandb_report_url)
+        if args.huggingface_user:
+            publish_args.append("--huggingface-user")
+            publish_args.append(args.huggingface_user)
+        subprocess.run(publish_args)
+
+    tp = ThreadPool(args.pool_size)
+    for run_paths in runs_paths_by_group.values():
+        tp.apply_async(run, (run_paths,))
+    tp.close()
+    tp.join()
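benchmark_publish.py groups the tagged benchmark runs by algo/env and shells out to huggingface_publish.py for each group. For reference, a typical invocation might look like the following (the tags shown are this benchmark's own defaults from `set_defaults`; substitute your own tags and envs):

```
python benchmark_publish.py --wandb-tags benchmark_5598ebc host_192-9-145-26 --envs Acrobot-v1 --pool-size 3
```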
colab_requirements.txt CHANGED
@@ -4,4 +4,6 @@ gym[box2d] >= 0.21.0, < 0.22
 pyglet == 1.5.27
 wandb >= 0.13.9, < 0.14
 pyvirtualdisplay == 3.0
-pybullet >= 3.2.5, < 3.3
+pybullet >= 3.2.5, < 3.3
+tabulate >= 0.9.0, < 0.10
+huggingface-hub >= 0.12.0, < 0.13
enjoy.py CHANGED
@@ -3,103 +3,28 @@ import os
 
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 
-import shutil
-import yaml
-
-from dataclasses import dataclass
-from typing import Optional
-
-from runner.env import make_eval_env
-from runner.config import Config, RunArgs
-from runner.running_utils import (
-    base_parser,
-    load_hyperparams,
-    set_seeds,
-    get_device,
-    make_policy,
-)
-from shared.callbacks.eval_callback import evaluate
-
-
-@dataclass
-class EvalArgs(RunArgs):
-    render: bool = True
-    best: bool = True
-    n_envs: int = 1
-    n_episodes: int = 3
-    deterministic: Optional[bool] = None
-    wandb_run_path: Optional[str] = None
+from runner.evaluate import EvalArgs, evaluate_model
+from runner.running_utils import base_parser
 
 
 if __name__ == "__main__":
-    parser = base_parser()
+    parser = base_parser(multiple=False)
     parser.add_argument("--render", default=True, type=bool)
     parser.add_argument("--best", default=True, type=bool)
     parser.add_argument("--n_envs", default=1, type=int)
     parser.add_argument("--n_episodes", default=3, type=int)
-    parser.add_argument("--deterministic", default=None, type=bool)
+    parser.add_argument("--deterministic-eval", default=None, type=bool)
+    parser.add_argument(
+        "--no-print-returns", action="store_true", help="Limit printing"
+    )
+    # wandb-run-path overrides base RunArgs
    parser.add_argument("--wandb-run-path", default=None, type=str)
    parser.set_defaults(
-        wandb_run_path="sgoodfriend/rl-algo-impls/sfi78a3t",
+        algo=["ppo"],
    )
+    args = parser.parse_args()
+    args.algo = args.algo[0]
+    args.env = args.env[0]
    args = EvalArgs(**vars(parser.parse_args()))
 
-    if args.wandb_run_path:
-        import wandb
-
-        api = wandb.Api()
-        run = api.run(args.wandb_run_path)
-        hyperparams = run.config
-
-        args.algo = hyperparams["algo"]
-        args.env = hyperparams["env"]
-        args.use_deterministic_algorithms = hyperparams.get(
-            "use_deterministic_algorithms", True
-        )
-
-        config = Config(args, hyperparams, os.path.dirname(__file__))
-        model_path = config.model_dir_path(best=args.best, downloaded=True)
-
-        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
-        run.file(model_archive_name).download()
-        if os.path.isdir(model_path):
-            shutil.rmtree(model_path)
-        shutil.unpack_archive(model_archive_name, model_path)
-        os.remove(model_archive_name)
-    else:
-        hyperparams = load_hyperparams(args.algo, args.env, os.path.dirname(__file__))
-
-        config = Config(args, hyperparams, os.path.dirname(__file__))
-        model_path = config.model_dir_path(best=args.best)
-
-    print(args)
-
-    set_seeds(args.seed, args.use_deterministic_algorithms)
-
-    env = make_eval_env(
-        config,
-        override_n_envs=args.n_envs,
-        render=args.render,
-        normalize_load_path=model_path,
-        **config.env_hyperparams,
-    )
-    device = get_device(config.device, env)
-    policy = make_policy(
-        args.algo,
-        env,
-        device,
-        load_path=model_path,
-        **config.policy_hyperparams,
-    ).eval()
-
-    if args.deterministic is None:
-        deterministic = config.eval_params.get("deterministic", True)
-    else:
-        deterministic = args.deterministic
-    evaluate(
-        env,
-        policy,
-        args.n_episodes,
-        render=args.render,
-        deterministic=deterministic,
-    )
+    evaluate_model(args, os.path.dirname(__file__))
huggingface_publish.py ADDED
@@ -0,0 +1,177 @@
+import os
+
+os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+
+import argparse
+import requests
+import shutil
+import subprocess
+import tempfile
+import wandb
+import wandb.apis.public
+
+from typing import List, Optional
+
+from huggingface_hub.hf_api import HfApi, upload_folder
+from huggingface_hub.repocard import metadata_save
+from publish.markdown_format import EvalTableData, model_card_text
+from runner.evaluate import EvalArgs, evaluate_model
+from runner.env import make_eval_env
+from shared.callbacks.eval_callback import evaluate
+from wrappers.vec_episode_recorder import VecEpisodeRecorder
+
+
+def publish(
+    wandb_run_paths: List[str],
+    wandb_report_url: str,
+    huggingface_user: Optional[str] = None,
+    huggingface_token: Optional[str] = None,
+) -> None:
+    api = wandb.Api()
+    runs = [api.run(rp) for rp in wandb_run_paths]
+    algo = runs[0].config["algo"]
+    env = runs[0].config["env"]
+    evaluations = [
+        evaluate_model(
+            EvalArgs(
+                algo,
+                env,
+                seed=r.config.get("seed", None),
+                render=False,
+                best=True,
+                n_envs=None,
+                n_episodes=10,
+                no_print_returns=True,
+                wandb_run_path="/".join(r.path),
+            ),
+            os.path.dirname(__file__),
+        )
+        for r in runs
+    ]
+    run_metadata = requests.get(runs[0].file("wandb-metadata.json").url).json()
+    table_data = list(EvalTableData(r, e) for r, e in zip(runs, evaluations))
+    best_eval = sorted(
+        table_data, key=lambda d: d.evaluation.stats.score, reverse=True
+    )[0]
+
+    with tempfile.TemporaryDirectory() as tmpdirname:
+        _, (policy, stats, config) = best_eval
+
+        repo_name = config.model_name(include_seed=False)
+        repo_dir_path = os.path.join(tmpdirname, repo_name)
+        # Locally clone this repo to a temp directory
+        subprocess.run(["git", "clone", ".", repo_dir_path])
+        shutil.rmtree(os.path.join(repo_dir_path, ".git"))
+        model_path = config.model_dir_path(best=True, downloaded=True)
+        shutil.copytree(
+            model_path,
+            os.path.join(
+                repo_dir_path, "saved_models", config.model_dir_name(best=True)
+            ),
+        )
+
+        github_url = "https://github.com/sgoodfriend/rl-algo-impls"
+        commit_hash = run_metadata.get("git", {}).get("commit", None)
+        card_text = model_card_text(
+            algo,
+            env,
+            github_url,
+            commit_hash,
+            wandb_report_url,
+            table_data,
+            best_eval,
+        )
+        readme_filepath = os.path.join(repo_dir_path, "README.md")
+        os.remove(readme_filepath)
+        with open(readme_filepath, "w") as f:
+            f.write(card_text)
+
+        metadata = {
+            "library_name": "rl-algo-impls",
+            "tags": [
+                env,
+                algo,
+                "deep-reinforcement-learning",
+                "reinforcement-learning",
+            ],
+            "model-index": [
+                {
+                    "name": algo,
+                    "results": [
+                        {
+                            "metrics": [
+                                {
+                                    "type": "mean_reward",
+                                    "value": str(stats.score),
+                                    "name": "mean_reward",
+                                }
+                            ],
+                            "task": {
+                                "type": "reinforcement-learning",
+                                "name": "reinforcement-learning",
+                            },
+                            "dataset": {
+                                "name": env,
+                                "type": env,
+                            },
+                        }
+                    ],
+                }
+            ],
+        }
+        metadata_save(readme_filepath, metadata)
+
+        video_env = VecEpisodeRecorder(
+            make_eval_env(
+                config,
+                override_n_envs=1,
+                normalize_load_path=model_path,
+                **config.env_hyperparams,
+            ),
+            os.path.join(repo_dir_path, "replay"),
+            max_video_length=3600,
+        )
+        evaluate(
+            video_env,
+            policy,
+            1,
+            deterministic=config.eval_params.get("deterministic", True),
+        )
+
+        api = HfApi()
+        huggingface_user = huggingface_user or api.whoami()["name"]
+        huggingface_repo = f"{huggingface_user}/{repo_name}"
+        api.create_repo(
+            token=huggingface_token,
+            repo_id=huggingface_repo,
+            private=True,
+            exist_ok=True,
+        )
+        repo_url = upload_folder(
+            repo_id=huggingface_repo,
+            folder_path=repo_dir_path,
+            path_in_repo="",
+            commit_message=f"{algo.upper()} playing {env} from {github_url}/tree/{commit_hash}",
+            token=huggingface_token,
+        )
+        print(f"Pushed model to the hub: {repo_url}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--wandb-run-paths",
+        type=str,
+        nargs="+",
+        help="Run paths of the form entity/project/run_id",
+    )
+    parser.add_argument("--wandb-report-url", type=str, help="Link to WandB report")
+    parser.add_argument(
+        "--huggingface-user",
+        type=str,
+        help="Huggingface user or team to upload model cards",
+        default=None,
+    )
+    args = parser.parse_args()
+    print(args)
+    publish(**vars(args))
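huggingface_publish.py re-evaluates each run's best model, picks the top performer, writes the model card, records a replay video, and uploads the folder to the Hub. A rough sketch of calling it directly instead of via benchmark_publish.py (the run path below is a placeholder, not a real run):

```python
# Sketch only: the wandb run path is a placeholder.
from huggingface_publish import publish

publish(
    wandb_run_paths=["<entity>/rl-algo-impls-benchmarks/<run_id>"],
    wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
    huggingface_user=None,  # None falls back to the `huggingface-cli login` user
)
```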
lambda_labs/lambda_requirements.txt CHANGED
@@ -6,4 +6,6 @@ gym[box2d] >= 0.21.0, < 0.22
 pyglet == 1.5.27
 wandb >= 0.13.9, < 0.14
 pyvirtualdisplay == 3.0
-pybullet >= 3.2.5, < 3.3
+pybullet >= 3.2.5, < 3.3
+tabulate >= 0.9.0, < 0.10
+huggingface-hub >= 0.12.0, < 0.13
poetry.lock CHANGED
@@ -1217,6 +1217,37 @@ chardet = ["chardet (>=2.2)"]
 genshi = ["genshi"]
 lxml = ["lxml"]
 
+[[package]]
+name = "huggingface-hub"
+version = "0.12.0"
+description = "Client library to download and publish models, datasets and other repos on the huggingface.co hub"
+category = "main"
+optional = false
+python-versions = ">=3.7.0"
+files = [
+    {file = "huggingface_hub-0.12.0-py3-none-any.whl", hash = "sha256:93809eabbfb2058a808bddf8b2a70f645de3f9df73ce87ddf5163d4c74b71c0c"},
+    {file = "huggingface_hub-0.12.0.tar.gz", hash = "sha256:da82c9ec8f9d8f976ffd3fd8249d20bb35c2dd3145a9f7ca1106f0ebefd9afa0"},
+]
+
+[package.dependencies]
+filelock = "*"
+packaging = ">=20.9"
+pyyaml = ">=5.1"
+requests = "*"
+tqdm = ">=4.42.1"
+typing-extensions = ">=3.7.4.3"
+
+[package.extras]
+all = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (==22.3)", "flake8 (>=3.8.3)", "flake8-bugbear", "isort (>=5.5.4)", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+cli = ["InquirerPy (==0.3.4)"]
+dev = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "black (==22.3)", "flake8 (>=3.8.3)", "flake8-bugbear", "isort (>=5.5.4)", "jedi", "mypy (==0.982)", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile", "types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+fastai = ["fastai (>=2.4)", "fastcore (>=1.3.27)", "toml"]
+quality = ["black (==22.3)", "flake8 (>=3.8.3)", "flake8-bugbear", "isort (>=5.5.4)", "mypy (==0.982)"]
+tensorflow = ["graphviz", "pydot", "tensorflow"]
+testing = ["InquirerPy (==0.3.4)", "Jinja2", "Pillow", "isort (>=5.5.4)", "jedi", "pytest", "pytest-cov", "pytest-env", "pytest-xdist", "soundfile"]
+torch = ["torch"]
+typing = ["types-PyYAML", "types-requests", "types-simplejson", "types-toml", "types-tqdm", "types-urllib3"]
+
 [[package]]
 name = "idna"
 version = "3.4"
@@ -3687,6 +3718,21 @@ pure-eval = "*"
 [package.extras]
 tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"]
 
+[[package]]
+name = "tabulate"
+version = "0.9.0"
+description = "Pretty-print tabular data"
+category = "main"
+optional = false
+python-versions = ">=3.7"
+files = [
+    {file = "tabulate-0.9.0-py3-none-any.whl", hash = "sha256:024ca478df22e9340661486f85298cff5f6dcdba14f3813e8830015b9ed1948f"},
+    {file = "tabulate-0.9.0.tar.gz", hash = "sha256:0095b12bf5966de529c0feb1fa08671671b3368eec77d7ef7ab114be2c068b3c"},
+]
+
+[package.extras]
+widechars = ["wcwidth"]
+
 [[package]]
 name = "tensorboard"
 version = "2.11.0"
@@ -4152,4 +4198,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
 [metadata]
 lock-version = "2.0"
 python-versions = "~3.10"
-content-hash = "c017f434016a4a1e42e01a10957b1de2f2596b1720d79d992a3d794ad8760ae3"
+content-hash = "89d4861857be881d3c6cb591d17fb98396b8c117b24a8d4ce4b6593ac8048670"
publish/markdown_format.py ADDED
@@ -0,0 +1,210 @@
+import os
+import pandas as pd
+import wandb.apis.public
+import yaml
+
+from collections import defaultdict
+from dataclasses import dataclass, asdict
+from typing import Any, Dict, Iterable, List, NamedTuple, Optional, TypeVar
+from urllib.parse import urlparse
+
+from runner.evaluate import Evaluation
+
+EvaluationRowSelf = TypeVar("EvaluationRowSelf", bound="EvaluationRow")
+
+
+@dataclass
+class EvaluationRow:
+    algo: str
+    env: str
+    seed: Optional[int]
+    reward_mean: float
+    reward_std: float
+    eval_episodes: int
+    best: str
+    wandb_url: str
+
+    @staticmethod
+    def data_frame(rows: List[EvaluationRowSelf]) -> pd.DataFrame:
+        results = defaultdict(list)
+        for r in rows:
+            for k, v in asdict(r).items():
+                results[k].append(v)
+        return pd.DataFrame(results)
+
+
+class EvalTableData(NamedTuple):
+    run: wandb.apis.public.Run
+    evaluation: Evaluation
+
+
+def evaluation_table(table_data: Iterable[EvalTableData]) -> str:
+    best_stats = sorted(
+        [d.evaluation.stats for d in table_data], key=lambda r: r.score, reverse=True
+    )[0]
+    table_data = sorted(table_data, key=lambda d: d.evaluation.config.seed() or 0)
+    rows = [
+        EvaluationRow(
+            config.algo,
+            config.env_id,
+            config.seed(),
+            stats.score.mean,
+            stats.score.std,
+            len(stats),
+            "*" if stats == best_stats else "",
+            f"[wandb]({r.url})",
+        )
+        for (r, (_, stats, config)) in table_data
+    ]
+    df = EvaluationRow.data_frame(rows)
+    return df.to_markdown(index=False)
+
+
+def github_project_link(github_url: str) -> str:
+    return f"[{urlparse(github_url).path}]({github_url})"
+
+
+def header_section(algo: str, env: str, github_url: str, wandb_report_url: str) -> str:
+    algo_caps = algo.upper()
+    lines = [
+        f"# **{algo_caps}** Agent playing **{env}**",
+        f"This is a trained model of a **{algo_caps}** agent playing **{env}** using "
+        f"the {github_project_link(github_url)} repo.",
+        f"All models trained at this commit can be found at {wandb_report_url}.",
+    ]
+    return "\n\n".join(lines)
+
+
+def github_tree_link(github_url: str, commit_hash: Optional[str]) -> str:
+    if not commit_hash:
+        return github_project_link(github_url)
+    return f"[{commit_hash[:7]}]({github_url}/tree/{commit_hash})"
+
+
+def results_section(
+    table_data: List[EvalTableData], algo: str, github_url: str, commit_hash: str
+) -> str:
+    # type: ignore
+    lines = [
+        "## Training Results",
+        f"This model was trained from {len(table_data)} trainings of **{algo.upper()}** "
+        + "agents using different initial seeds. "
+        + f"These agents were trained by checking out "
+        + f"{github_tree_link(github_url, commit_hash)}. "
+        + "The best and last models were kept from each training. "
+        + "This submission has loaded the best models from each training, reevaluates "
+        + "them, and selects the best model from these latest evaluations (mean - std).",
+    ]
+    lines.append(evaluation_table(table_data))
+    return "\n\n".join(lines)
+
+
+def prerequisites_section() -> str:
+    return """
+### Prerequisites: Weights & Biases (WandB)
+Training and benchmarking assumes you have a Weights & Biases project to upload runs to.
+By default training goes to a rl-algo-impls project while benchmarks go to
+rl-algo-impls-benchmarks. During training and benchmarking runs, videos of the best
+models and the model weights are uploaded to WandB.
+
+Before doing anything below, you'll need to create a wandb account and run `wandb
+login`.
+"""
+
+
+def usage_section(github_url: str, run_path: str, commit_hash: str) -> str:
+    return f"""
+## Usage
+{urlparse(github_url).path}: {github_url}
+
+Note: While the model state dictionary and hyperaparameters are saved, the latest
+implementation could be sufficiently different to not be able to reproduce similar
+results. You might need to checkout the commit the agent was trained on:
+{github_tree_link(github_url, commit_hash)}.
+```
+# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
+python enjoy.py --wandb-run-path={run_path}
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_enjoy.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_enjoy.ipynb)
+notebook.
+"""
+
+
+def training_setion(
+    github_url: str, commit_hash: str, algo: str, env: str, seed: Optional[int]
+) -> str:
+    return f"""
+## Training
+If you want the highest chance to reproduce these results, you'll want to checkout the
+commit the agent was trained on: {github_tree_link(github_url, commit_hash)}. While
+training is deterministic, different hardware will give different results.
+
+```
+python train.py --algo {algo} --env {env} {'--seed ' + str(seed) if seed is not None else ''}
+```
+
+Setup hasn't been completely worked out yet, so you might be best served by using Google
+Colab starting from the
+[colab_train.ipynb](https://github.com/sgoodfriend/rl-algo-impls/blob/main/colab_train.ipynb)
+notebook.
+"""
+
+
+def benchmarking_section(report_url: str) -> str:
+    return f"""
+## Benchmarking (with Lambda Labs instance)
+This and other models from {report_url} were generated by running a script on a Lambda
+Labs instance. In a Lambda Labs instance terminal:
+```
+git clone git@github.com:sgoodfriend/rl-algo-impls.git
+cd rl-algo-impls
+bash ./lambda_labs/setup.sh
+wandb login
+bash ./lambda_labs/benchmark.sh
+```
+
+### Alternative: Google Colab Pro+
+As an alternative,
+[colab_benchmark.ipynb](https://github.com/sgoodfriend/rl-algo-impls/tree/main/benchmarks#:~:text=colab_benchmark.ipynb),
+can be used. However, this requires a Google Colab Pro+ subscription and running across
+4 separate instances because otherwise running all jobs will exceed the 24-hour limit.
+"""
+
+
+def hyperparams_section(run_config: Dict[str, Any]) -> str:
+    return f"""
+## Hyperparameters
+This isn't exactly the format of hyperparams in {os.path.join("hyperparams",
+run_config["algo"] + ".yml")}, but instead the Wandb Run Config. However, it's very
+close and has some additional data:
+```
+{yaml.dump(run_config)}
+```
+"""
+
+
+def model_card_text(
+    algo: str,
+    env: str,
+    github_url: str,
+    commit_hash: str,
+    wandb_report_url: str,
+    table_data: List[EvalTableData],
+    best_eval: EvalTableData,
+) -> str:
+    run, (_, _, config) = best_eval
+    run_path = "/".join(run.path)
+    return "\n\n".join(
+        [
+            header_section(algo, env, github_url, wandb_report_url),
+            results_section(table_data, algo, github_url, commit_hash),
+            prerequisites_section(),
+            usage_section(github_url, run_path, commit_hash),
+            training_setion(github_url, commit_hash, algo, env, config.seed()),
+            benchmarking_section(wandb_report_url),
+            hyperparams_section(run.config),
+        ]
+    )
pyproject.toml CHANGED
@@ -21,6 +21,8 @@ wandb = "^0.13.9"
 conda-lock = "^1.3.0"
 torch-tb-profiler = "^0.4.1"
 jupyter = "^1.0.0"
+tabulate = "^0.9.0"
+huggingface-hub = "^0.12.0"
 
 [build-system]
 requires = ["poetry-core"]
replay.meta.json CHANGED
@@ -1 +1 @@
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmp7t2v9jcd/ppo-Acrobot-v1/replay.mp4"]}, "episode": {"r": -73.0, "l": 74, "t": 1.272925}}
 
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "500x500", "-pix_fmt", "rgb24", "-framerate", "30", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "30", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpz2flad47/ppo-Acrobot-v1/replay.mp4"]}, "episode": {"r": -73.0, "l": 74, "t": 1.297341}}
runner/config.py CHANGED
@@ -59,14 +59,17 @@ class Config:
     def eval_params(self) -> Dict[str, Any]:
         return self.hyperparams.get("eval_params", {})
 
+    @property
+    def algo(self) -> str:
+        return self.args.algo
+
     @property
     def env_id(self) -> str:
         return self.args.env
 
-    @property
-    def model_name(self) -> str:
-        parts = [self.args.algo, self.env_id]
-        if self.args.seed is not None:
+    def model_name(self, include_seed: bool = True) -> str:
+        parts = [self.algo, self.env_id]
+        if include_seed and self.args.seed is not None:
             parts.append(f"S{self.args.seed}")
         make_kwargs = self.env_hyperparams.get("make_kwargs", {})
         if make_kwargs:
@@ -81,7 +84,7 @@ class Config:
 
     @property
     def run_name(self) -> str:
-        parts = [self.model_name, self.run_id]
+        parts = [self.model_name(), self.run_id]
         return "-".join(parts)
 
     @property
@@ -97,7 +100,7 @@ class Config:
         best: bool = False,
         extension: str = "",
     ) -> str:
-        return self.model_name + ("-best" if best else "") + extension
+        return self.model_name() + ("-best" if best else "") + extension
 
     def model_dir_path(self, best: bool = False, downloaded: bool = False) -> str:
         return os.path.join(
@@ -123,8 +126,8 @@ class Config:
 
     @property
     def video_prefix(self) -> str:
-        return os.path.join(self.videos_dir, self.model_name)
+        return os.path.join(self.videos_dir, self.model_name())
 
     @property
     def best_videos_dir(self) -> str:
-        return os.path.join(self.videos_dir, f"{self.model_name}-best")
+        return os.path.join(self.videos_dir, f"{self.model_name()}-best")
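model_name changes from a property to a method so the publish flow can drop the seed suffix when naming the Hugging Face repo. Roughly (the values below are illustrative; the exact string also folds in any make_kwargs):

```python
# Illustrative only: assumes algo="ppo", env="Acrobot-v1", seed=4, no make_kwargs.
config.model_name()                    # e.g. "ppo-Acrobot-v1-S4" (run/video names)
config.model_name(include_seed=False)  # e.g. "ppo-Acrobot-v1" (Hugging Face repo name)
```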
runner/evaluate.py ADDED
@@ -0,0 +1,103 @@
+import os
+import shutil
+
+from dataclasses import dataclass
+from typing import NamedTuple, Optional
+
+from runner.env import make_eval_env
+from runner.config import Config, RunArgs
+from runner.running_utils import (
+    load_hyperparams,
+    set_seeds,
+    get_device,
+    make_policy,
+)
+from shared.callbacks.eval_callback import evaluate
+from shared.policy.policy import Policy
+from shared.stats import EpisodesStats
+
+
+@dataclass
+class EvalArgs(RunArgs):
+    render: bool = True
+    best: bool = True
+    n_envs: Optional[int] = 1
+    n_episodes: int = 3
+    deterministic_eval: Optional[bool] = None
+    no_print_returns: bool = False
+    wandb_run_path: Optional[str] = None
+
+
+class Evaluation(NamedTuple):
+    policy: Policy
+    stats: EpisodesStats
+    config: Config
+
+
+def evaluate_model(args: EvalArgs, root_dir: str) -> Evaluation:
+    if args.wandb_run_path:
+        import wandb
+
+        api = wandb.Api()
+        run = api.run(args.wandb_run_path)
+        hyperparams = run.config
+
+        args.algo = hyperparams["algo"]
+        args.env = hyperparams["env"]
+        args.seed = hyperparams.get("seed", None)
+        args.use_deterministic_algorithms = hyperparams.get(
+            "use_deterministic_algorithms", True
+        )
+
+        config = Config(args, hyperparams, root_dir)
+        model_path = config.model_dir_path(best=args.best, downloaded=True)
+
+        model_archive_name = config.model_dir_name(best=args.best, extension=".zip")
+        run.file(model_archive_name).download()
+        if os.path.isdir(model_path):
+            shutil.rmtree(model_path)
+        shutil.unpack_archive(model_archive_name, model_path)
+        os.remove(model_archive_name)
+    else:
+        hyperparams = load_hyperparams(args.algo, args.env, root_dir)
+
+        config = Config(args, hyperparams, root_dir)
+        model_path = config.model_dir_path(best=args.best)
+
+    print(args)
+
+    set_seeds(args.seed, args.use_deterministic_algorithms)
+
+    env = make_eval_env(
+        config,
+        override_n_envs=args.n_envs,
+        render=args.render,
+        normalize_load_path=model_path,
+        **config.env_hyperparams,
+    )
+    device = get_device(config.device, env)
+    policy = make_policy(
+        args.algo,
+        env,
+        device,
+        load_path=model_path,
+        **config.policy_hyperparams,
+    ).eval()
+
+    deterministic = (
+        args.deterministic_eval
+        if args.deterministic_eval is not None
+        else config.eval_params.get("deterministic", True)
+    )
+    return Evaluation(
+        policy,
+        evaluate(
+            env,
+            policy,
+            args.n_episodes,
+            render=args.render,
+            deterministic=deterministic,
+            print_returns=not args.no_print_returns,
+        ),
+        config,
+    )
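runner/evaluate.py extracts the old enjoy.py logic into a reusable evaluate_model that both enjoy.py and huggingface_publish.py now call. A minimal sketch (the wandb run path is a placeholder, not a real run):

```python
# Minimal sketch; wandb_run_path is a placeholder.
from runner.evaluate import EvalArgs, evaluate_model

policy, stats, config = evaluate_model(
    EvalArgs(
        algo="ppo",
        env="Acrobot-v1",
        render=False,
        n_episodes=10,
        no_print_returns=True,
        wandb_run_path="<entity>/<project>/<run_id>",
    ),
    root_dir=".",
)
print(stats)  # EpisodesStats over the 10 evaluation episodes
```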
runner/running_utils.py CHANGED
@@ -40,28 +40,28 @@ POLICIES: Dict[str, Type[Policy]] = {
 HYPERPARAMS_PATH = "hyperparams"
 
 
-def base_parser() -> argparse.ArgumentParser:
+def base_parser(multiple: bool = True) -> argparse.ArgumentParser:
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--algo",
         default="dqn",
         type=str,
         choices=list(ALGOS.keys()),
-        nargs="+",
+        nargs="+" if multiple else 1,
         help="Abbreviation(s) of algorithm(s)",
     )
     parser.add_argument(
         "--env",
         default="CartPole-v1",
         type=str,
-        nargs="+",
+        nargs="+" if multiple else 1,
         help="Name of environment(s) in gym",
     )
     parser.add_argument(
         "--seed",
         default=1,
         type=int,
-        nargs="*",
+        nargs="*" if multiple else "?",
         help="Seeds to run experiment. Unset will do one run with no set seed",
     )
     parser.add_argument(
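With multiple=False, --algo and --env still come back as one-element lists (nargs=1), which is why enjoy.py unwraps them after parsing. A small illustration (the values are hypothetical):

```python
# Hypothetical values; nargs=1 means argparse returns single-element lists.
parser = base_parser(multiple=False)
args = parser.parse_args(["--algo", "ppo", "--env", "Acrobot-v1"])
assert args.algo == ["ppo"] and args.env == ["Acrobot-v1"]
args.algo, args.env = args.algo[0], args.env[0]
```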
shared/callbacks/eval_callback.py CHANGED
@@ -22,7 +22,10 @@ class EvaluateAccumulator(EpisodeAccumulator):
         self.print_returns = print_returns
 
     def on_done(self, ep_idx: int, episode: Episode) -> None:
-        if len(self.completed_episodes_by_env_idx[ep_idx]) >= self.goal_episodes_per_env:
+        if (
+            len(self.completed_episodes_by_env_idx[ep_idx])
+            >= self.goal_episodes_per_env
+        ):
             return
         self.completed_episodes_by_env_idx[ep_idx].append(episode)
         if self.print_returns:
@@ -36,11 +39,14 @@ class EvaluateAccumulator(EpisodeAccumulator):
         return sum(len(ce) for ce in self.completed_episodes_by_env_idx)
 
     @property
-    def episodes(self) -> bool:
-        return list(itertools.chain(*self.completed_episodes_by_env_idx))
+    def episodes(self) -> List[Episode]:
+        return list(itertools.chain(*self.completed_episodes_by_env_idx))
 
     def is_done(self) -> bool:
-        return all(len(ce) == self.goal_episodes_per_env for ce in self.completed_episodes_by_env_idx)
+        return all(
+            len(ce) == self.goal_episodes_per_env
+            for ce in self.completed_episodes_by_env_idx
+        )
 
 
 def evaluate(
@@ -108,7 +114,7 @@ class EvalCallback(Callback):
     def on_step(self, timesteps_elapsed: int = 1) -> bool:
         super().on_step(timesteps_elapsed)
         if self.timesteps_elapsed // self.step_freq >= len(self.stats):
-            self.sync_vec_normalize(self.env)
+            sync_vec_normalize(self.policy.vec_normalize, self.env)
             self.evaluate()
         return True
 
@@ -134,10 +140,12 @@ class EvalCallback(Callback):
                 assert self.best_model_path
                 self.policy.save(self.best_model_path)
                 print("Saved best model")
-                self.best.write_to_tensorboard(self.tb_writer, "best_eval", self.timesteps_elapsed)
+                self.best.write_to_tensorboard(
+                    self.tb_writer, "best_eval", self.timesteps_elapsed
+                )
             if strictly_better and self.record_best_videos:
                 assert self.video_env and self.best_video_dir
-                self.sync_vec_normalize(self.video_env)
+                sync_vec_normalize(self.policy.vec_normalize, self.video_env)
                 self.best_video_base_path = os.path.join(
                     self.best_video_dir, str(self.timesteps_elapsed)
                 )
@@ -159,16 +167,15 @@ class EvalCallback(Callback):
 
         return eval_stat
 
-    def sync_vec_normalize(self, destination_env: VecEnv) -> None:
-        if self.policy.vec_normalize is not None:
-            eval_env_wrapper = destination_env
-            while isinstance(eval_env_wrapper, VecEnvWrapper):
-                if isinstance(eval_env_wrapper, VecNormalize):
-                    if hasattr(self.policy.vec_normalize, "obs_rms"):
-                        eval_env_wrapper.obs_rms = deepcopy(
-                            self.policy.vec_normalize.obs_rms
-                        )
-                    eval_env_wrapper.ret_rms = deepcopy(
-                        self.policy.vec_normalize.ret_rms
-                    )
-                eval_env_wrapper = eval_env_wrapper.venv
+
+def sync_vec_normalize(
+    origin_vec_normalize: Optional[VecNormalize], destination_env: VecEnv
+) -> None:
+    if origin_vec_normalize is not None:
+        eval_env_wrapper = destination_env
+        while isinstance(eval_env_wrapper, VecEnvWrapper):
+            if isinstance(eval_env_wrapper, VecNormalize):
+                if hasattr(origin_vec_normalize, "obs_rms"):
+                    eval_env_wrapper.obs_rms = deepcopy(origin_vec_normalize.obs_rms)
+                eval_env_wrapper.ret_rms = deepcopy(origin_vec_normalize.ret_rms)
+            eval_env_wrapper = eval_env_wrapper.venv
shared/policy/policy.py CHANGED
@@ -54,7 +54,9 @@ class Policy(nn.Module, ABC):
     @abstractmethod
     def load(self, path: str) -> None:
         # VecNormalize load occurs in env.py
-        self.load_state_dict(torch.load(os.path.join(path, MODEL_FILENAME)))
+        self.load_state_dict(
+            torch.load(os.path.join(path, MODEL_FILENAME), map_location=self.device)
+        )
 
     def reset_noise(self) -> None:
         pass
shared/stats.py CHANGED
@@ -94,6 +94,9 @@ class EpisodesStats:
             f"Length: {self.length}"
         )
 
+    def __len__(self) -> int:
+        return len(self.episodes)
+
     def _asdict(self) -> dict:
         return {
             "n_episodes": len(self.episodes),
@@ -147,27 +150,3 @@ class EpisodeAccumulator:
 
     def stats(self) -> EpisodesStats:
        return EpisodesStats(self.episodes)
-
-
-class RolloutStats(EpisodeAccumulator):
-    def __init__(self, num_envs: int, print_n_episodes: int, tb_writer: SummaryWriter):
-        super().__init__(num_envs)
-        self.print_n_episodes = print_n_episodes
-        self.epochs: List[EpisodesStats] = []
-        self.tb_writer = tb_writer
-
-    def on_done(self, ep_idx: int, episode: Episode) -> None:
-        if (
-            self.print_n_episodes >= 0
-            and len(self.episodes) % self.print_n_episodes == 0
-        ):
-            sample = self.episodes[-self.print_n_episodes :]
-            epoch = EpisodesStats(sample)
-            self.epochs.append(epoch)
-            total_steps = np.sum([e.length for e in self.episodes])
-            print(
-                f"Episode: {len(self.episodes)} | "
-                f"{epoch} | "
-                f"Total Steps: {total_steps}"
-            )
-            epoch.write_to_tensorboard(self.tb_writer, "train", global_step=total_steps)