sgoodfriend commited on
Commit
85e4a43
1 Parent(s): 0ca6846

PPO playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4

Browse files
README.md CHANGED
@@ -10,7 +10,7 @@ model-index:
10
  results:
11
  - metrics:
12
  - type: mean_reward
13
- value: 621.48 +/- 140.74
14
  name: mean_reward
15
  task:
16
  type: reinforcement-learning
@@ -23,17 +23,17 @@ model-index:
23
 
24
  This is a trained model of a **PPO** agent playing **CarRacing-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
- All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/6p2sjqtn.
27
 
28
  ## Training Results
29
 
30
- This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:-------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
- | ppo | CarRacing-v0 | 4 | 635.901 | 267.357 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/1af869b6) |
35
- | ppo | CarRacing-v0 | 5 | 621.48 | 140.74 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/2isthqpm) |
36
- | ppo | CarRacing-v0 | 6 | 663.161 | 184.276 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/rhymy2k8) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
@@ -53,10 +53,10 @@ login`.
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
- [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
- python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/2isthqpm
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -68,11 +68,11 @@ notebook.
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
- commit the agent was trained on: [5598ebc](https://github.com/sgoodfriend/rl-algo-impls/tree/5598ebc4b03054f16eebe76792486ba7bcacfc5c). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
- python train.py --algo ppo --env CarRacing-v0 --seed 5
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
@@ -83,7 +83,7 @@ notebook.
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
- This and other models from https://api.wandb.ai/links/sgoodfriend/6p2sjqtn were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone git@github.com:sgoodfriend/rl-algo-impls.git
@@ -127,16 +127,18 @@ n_timesteps: 4000000
127
  policy_hyperparams:
128
  activation_fn: relu
129
  cnn_feature_dim: 256
 
 
130
  init_layers_orthogonal: false
131
  log_std_init: -2
132
  share_features_extractor: false
133
  use_sde: true
134
- seed: 5
135
  use_deterministic_algorithms: true
136
  wandb_entity: null
137
  wandb_project_name: rl-algo-impls-benchmarks
138
  wandb_tags:
139
- - benchmark_5598ebc
140
- - host_192-9-145-26
141
 
142
  ```
 
10
  results:
11
  - metrics:
12
  - type: mean_reward
13
+ value: 865.72 +/- 58.15
14
  name: mean_reward
15
  task:
16
  type: reinforcement-learning
 
23
 
24
  This is a trained model of a **PPO** agent playing **CarRacing-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
25
 
26
+ All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/448odm37.
27
 
28
  ## Training Results
29
 
30
+ This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [fbc943f](https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4). The best and last models were kept from each training. This submission has loaded the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
31
 
32
  | algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
33
  |:-------|:-------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
34
+ | ppo | CarRacing-v0 | 1 | 865.725 | 58.1454 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/8vyb0q44) |
35
+ | ppo | CarRacing-v0 | 2 | 693.464 | 236.712 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/a3ld38qf) |
36
+ | ppo | CarRacing-v0 | 3 | 815.26 | 141.502 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/zah43or2) |
37
 
38
 
39
  ### Prerequisites: Weights & Biases (WandB)
 
53
  Note: While the model state dictionary and hyperaparameters are saved, the latest
54
  implementation could be sufficiently different to not be able to reproduce similar
55
  results. You might need to checkout the commit the agent was trained on:
56
+ [fbc943f](https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4).
57
  ```
58
  # Downloads the model, sets hyperparameters, and runs agent for 3 episodes
59
+ python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/8vyb0q44
60
  ```
61
 
62
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
68
 
69
  ## Training
70
  If you want the highest chance to reproduce these results, you'll want to checkout the
71
+ commit the agent was trained on: [fbc943f](https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4). While
72
  training is deterministic, different hardware will give different results.
73
 
74
  ```
75
+ python train.py --algo ppo --env CarRacing-v0 --seed 1
76
  ```
77
 
78
  Setup hasn't been completely worked out yet, so you might be best served by using Google
 
83
 
84
 
85
  ## Benchmarking (with Lambda Labs instance)
86
+ This and other models from https://api.wandb.ai/links/sgoodfriend/448odm37 were generated by running a script on a Lambda
87
  Labs instance. In a Lambda Labs instance terminal:
88
  ```
89
  git clone git@github.com:sgoodfriend/rl-algo-impls.git
 
127
  policy_hyperparams:
128
  activation_fn: relu
129
  cnn_feature_dim: 256
130
+ hidden_sizes:
131
+ - 256
132
  init_layers_orthogonal: false
133
  log_std_init: -2
134
  share_features_extractor: false
135
  use_sde: true
136
+ seed: 1
137
  use_deterministic_algorithms: true
138
  wandb_entity: null
139
  wandb_project_name: rl-algo-impls-benchmarks
140
  wandb_tags:
141
+ - benchmark_fbc943f
142
+ - host_150-230-44-105
143
 
144
  ```
benchmark_publish.py CHANGED
@@ -44,11 +44,10 @@ if __name__ == "__main__":
44
  default=3,
45
  help="How many publish jobs can run in parallel",
46
  )
47
- parser.set_defaults(
48
- wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
49
- wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
50
- envs=["CartPole-v1", "Acrobot-v1"],
51
- )
52
  args = parser.parse_args()
53
  print(args)
54
 
 
44
  default=3,
45
  help="How many publish jobs can run in parallel",
46
  )
47
+ # parser.set_defaults(
48
+ # wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
49
+ # wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
50
+ # )
 
51
  args = parser.parse_args()
52
  print(args)
53
 
benchmarks/benchmark_test.sh ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ source benchmarks/train_loop.sh
2
+
3
+ export WANDB_PROJECT_NAME="rl-algo-impls"
4
+
5
+ BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
6
+
7
+ ALGOS=(
8
+ # "vpg"
9
+ "dqn"
10
+ # "ppo"
11
+ )
12
+ ENVS=(
13
+ # Basic
14
+ "CartPole-v1"
15
+ "MountainCar-v0"
16
+ # "MountainCarContinuous-v0"
17
+ "Acrobot-v1"
18
+ "LunarLander-v2"
19
+ # # PyBullet
20
+ # "HalfCheetahBulletEnv-v0"
21
+ # "AntBulletEnv-v0"
22
+ # "HopperBulletEnv-v0"
23
+ # "Walker2DBulletEnv-v0"
24
+ # # CarRacing
25
+ # "CarRacing-v0"
26
+ # Atari
27
+ "PongNoFrameskip-v4"
28
+ "BreakoutNoFrameskip-v4"
29
+ "SpaceInvadersNoFrameskip-v4"
30
+ "QbertNoFrameskip-v4"
31
+ )
32
+ train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/colab_pybullet.sh CHANGED
@@ -1,5 +1,5 @@
1
  source benchmarks/train_loop.sh
2
  ALGOS="ppo"
3
- ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 Walker2DBulletEnv-v0 HopperBulletEnv-v0"
4
  BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
  train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
 
1
  source benchmarks/train_loop.sh
2
  ALGOS="ppo"
3
+ ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
4
  BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
5
  train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
benchmarks/train_loop.sh CHANGED
@@ -4,13 +4,11 @@ train_loop () {
4
  local env
5
  local seed
6
  local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
7
- local args=()
8
- (( VIRTUAL_DISPLAY == 1)) && args+=("--virtual-display")
9
  local SEEDS="${SEEDS:-1 2 3}"
10
  for algo in $(echo $1); do
11
  for env in $(echo $2); do
12
  for seed in $SEEDS; do
13
- echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME ${args[@]}
14
  done
15
  done
16
  done
 
4
  local env
5
  local seed
6
  local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
 
 
7
  local SEEDS="${SEEDS:-1 2 3}"
8
  for algo in $(echo $1); do
9
  for env in $(echo $2); do
10
  for seed in $SEEDS; do
11
+ echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME
12
  done
13
  done
14
  done
colab_requirements.txt CHANGED
@@ -6,4 +6,5 @@ wandb >= 0.13.9, < 0.14
6
  pyvirtualdisplay == 3.0
7
  pybullet >= 3.2.5, < 3.3
8
  tabulate >= 0.9.0, < 0.10
9
- huggingface-hub >= 0.12.0, < 0.13
 
 
6
  pyvirtualdisplay == 3.0
7
  pybullet >= 3.2.5, < 3.3
8
  tabulate >= 0.9.0, < 0.10
9
+ huggingface-hub >= 0.12.0, < 0.13
10
+ numexpr >= 2.8.4, < 2.9
compare_runs.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import itertools
3
+ import numpy as np
4
+ import pandas as pd
5
+ import wandb
6
+ import wandb.apis.public
7
+
8
+ from collections import defaultdict
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Iterable, List, TypeVar
11
+
12
+ from benchmark_publish import RunGroup
13
+
14
+
15
+ @dataclass
16
+ class Comparison:
17
+ control_values: List[float]
18
+ experiment_values: List[float]
19
+
20
+ def mean_diff_percentage(self) -> float:
21
+ return self._diff_percentage(
22
+ np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
23
+ )
24
+
25
+ def median_diff_percentage(self) -> float:
26
+ return self._diff_percentage(
27
+ np.median(self.control_values).item(),
28
+ np.median(self.experiment_values).item(),
29
+ )
30
+
31
+ def _diff_percentage(self, c: float, e: float) -> float:
32
+ if c == e:
33
+ return 0
34
+ elif c == 0:
35
+ return float("inf") if e > 0 else float("-inf")
36
+ return 100 * (e - c) / c
37
+
38
+ def score(self) -> float:
39
+ return (
40
+ np.sum(
41
+ np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
42
+ ).item()
43
+ / 2
44
+ )
45
+
46
+
47
+ RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")
48
+
49
+
50
+ class RunGroupRuns:
51
+ def __init__(
52
+ self,
53
+ run_group: RunGroup,
54
+ control: List[str],
55
+ experiment: List[str],
56
+ summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
57
+ summary_metrics: List[str] = ["mean", "result"],
58
+ ) -> None:
59
+ self.algo = run_group.algo
60
+ self.env = run_group.env_id
61
+ self.control = set(control)
62
+ self.experiment = set(experiment)
63
+
64
+ self.summary_stats = summary_stats
65
+ self.summary_metrics = summary_metrics
66
+
67
+ self.control_runs = []
68
+ self.experiment_runs = []
69
+
70
+ def add_run(self, run: wandb.apis.public.Run) -> None:
71
+ wandb_tags = set(run.config.get("wandb_tags", []))
72
+ if self.control & wandb_tags:
73
+ self.control_runs.append(run)
74
+ elif self.experiment & wandb_tags:
75
+ self.experiment_runs.append(run)
76
+
77
+ def comparisons_by_metric(self) -> Dict[str, Comparison]:
78
+ c_by_m = {}
79
+ for metric in (
80
+ f"{s}_{m}"
81
+ for s, m in itertools.product(self.summary_stats, self.summary_metrics)
82
+ ):
83
+ c_by_m[metric] = Comparison(
84
+ [c.summary[metric] for c in self.control_runs],
85
+ [e.summary[metric] for e in self.experiment_runs],
86
+ )
87
+ return c_by_m
88
+
89
+ @staticmethod
90
+ def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
91
+ results = defaultdict(list)
92
+ for r in rows:
93
+ results["algo"].append(r.algo)
94
+ results["env"].append(r.env)
95
+ results["control"].append(r.control)
96
+ results["expierment"].append(r.experiment)
97
+ c_by_m = r.comparisons_by_metric()
98
+ results["score"].append(
99
+ sum(m.score() for m in c_by_m.values()) / len(c_by_m)
100
+ )
101
+ for m, c in c_by_m.items():
102
+ results[f"{m}_mean"].append(c.mean_diff_percentage())
103
+ results[f"{m}_median"].append(c.median_diff_percentage())
104
+ return pd.DataFrame(results)
105
+
106
+
107
+ if __name__ == "__main__":
108
+ parser = argparse.ArgumentParser()
109
+ parser.add_argument(
110
+ "-p",
111
+ "--wandb-project-name",
112
+ type=str,
113
+ default="rl-algo-impls-benchmarks",
114
+ help="WandB project name to load runs from",
115
+ )
116
+ parser.add_argument(
117
+ "--wandb-entity",
118
+ type=str,
119
+ default=None,
120
+ help="WandB team. None uses default entity",
121
+ )
122
+ parser.add_argument(
123
+ "-n",
124
+ "--wandb-hostname-tag",
125
+ type=str,
126
+ help="WandB tag for hostname (i.e. host_192-9-145-26)",
127
+ )
128
+ parser.add_argument(
129
+ "-c",
130
+ "--wandb-control-tag",
131
+ type=str,
132
+ nargs="+",
133
+ help="WandB tag for control commit (i.e. benchmark_5598ebc)",
134
+ )
135
+ parser.add_argument(
136
+ "-e",
137
+ "--wandb-experiment-tag",
138
+ type=str,
139
+ nargs="+",
140
+ help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
141
+ )
142
+ parser.add_argument(
143
+ "--exclude_envs",
144
+ type=str,
145
+ nargs="*",
146
+ help="Environments to exclude from comparison",
147
+ )
148
+ parser.set_defaults(
149
+ wandb_hostname_tag="host_192-9-145-26",
150
+ wandb_control_tag=["benchmark_e4d1ed6", "benchmark_5598ebc"],
151
+ wandb_experiment_tag=["benchmark_680043d", "benchmark_5540e1f"],
152
+ exclude_envs=["CarRacing-v0"]
153
+ )
154
+ args = parser.parse_args()
155
+ print(args)
156
+
157
+ api = wandb.Api()
158
+ all_runs = api.runs(
159
+ path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
160
+ order="+created_at",
161
+ )
162
+
163
+ runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
164
+ for r in all_runs:
165
+ wandb_tags = r.config.get("wandb_tags", [])
166
+ if not wandb_tags or not args.wandb_hostname_tag in wandb_tags:
167
+ continue
168
+ rg = RunGroup(r.config["algo"], r.config["env"])
169
+ if args.exclude_envs and rg.env_id in args.exclude_envs:
170
+ continue
171
+ if rg not in runs_by_run_group:
172
+ runs_by_run_group[rg] = RunGroupRuns(
173
+ rg, args.wandb_control_tag, args.wandb_experiment_tag
174
+ )
175
+ runs_by_run_group[rg].add_run(r)
176
+ df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
177
+ print(f"**Total Score: {sum(df.score)}**")
178
+ df.loc["mean"] = df.mean(numeric_only=True)
179
+ print(df.to_markdown())
dqn/policy.py CHANGED
@@ -15,7 +15,7 @@ class DQNPolicy(Policy):
15
  def __init__(
16
  self,
17
  env: VecEnv,
18
- hidden_sizes: Sequence[int],
19
  **kwargs,
20
  ) -> None:
21
  super().__init__(env, **kwargs)
 
15
  def __init__(
16
  self,
17
  env: VecEnv,
18
+ hidden_sizes: Sequence[int] = [],
19
  **kwargs,
20
  ) -> None:
21
  super().__init__(env, **kwargs)
dqn/q_net.py CHANGED
@@ -13,7 +13,7 @@ class QNetwork(nn.Module):
13
  self,
14
  observation_space: gym.Space,
15
  action_space: gym.Space,
16
- hidden_sizes: Sequence[int],
17
  activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
18
  ) -> None:
19
  super().__init__()
 
13
  self,
14
  observation_space: gym.Space,
15
  action_space: gym.Space,
16
+ hidden_sizes: Sequence[int] = [],
17
  activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
18
  ) -> None:
19
  super().__init__()
huggingface_publish.py CHANGED
@@ -14,6 +14,8 @@ from typing import List, Optional
14
 
15
  from huggingface_hub.hf_api import HfApi, upload_folder
16
  from huggingface_hub.repocard import metadata_save
 
 
17
  from publish.markdown_format import EvalTableData, model_card_text
18
  from runner.evaluate import EvalArgs, evaluate_model
19
  from runner.env import make_eval_env
@@ -27,6 +29,9 @@ def publish(
27
  huggingface_user: Optional[str] = None,
28
  huggingface_token: Optional[str] = None,
29
  ) -> None:
 
 
 
30
  api = wandb.Api()
31
  runs = [api.run(rp) for rp in wandb_run_paths]
32
  algo = runs[0].config["algo"]
 
14
 
15
  from huggingface_hub.hf_api import HfApi, upload_folder
16
  from huggingface_hub.repocard import metadata_save
17
+ from pyvirtualdisplay.display import Display
18
+
19
  from publish.markdown_format import EvalTableData, model_card_text
20
  from runner.evaluate import EvalArgs, evaluate_model
21
  from runner.env import make_eval_env
 
29
  huggingface_user: Optional[str] = None,
30
  huggingface_token: Optional[str] = None,
31
  ) -> None:
32
+ virtual_display = Display(visible=False, size=(1400, 900))
33
+ virtual_display.start()
34
+
35
  api = wandb.Api()
36
  runs = [api.run(rp) for rp in wandb_run_paths]
37
  algo = runs[0].config["algo"]
hyperparams/dqn.yml CHANGED
@@ -1,7 +1,6 @@
1
  CartPole-v1: &cartpole-defaults
2
  n_timesteps: !!float 5e4
3
  env_hyperparams:
4
- n_envs: 1
5
  rolling_length: 50
6
  policy_hyperparams:
7
  hidden_sizes: [256, 256]
@@ -18,8 +17,6 @@ CartPole-v1: &cartpole-defaults
18
  exploration_final_eps: 0.04
19
  eval_params:
20
  step_freq: !!float 1e4
21
- n_episodes: 10
22
- save_best: true
23
 
24
  CartPole-v0:
25
  <<: *cartpole-defaults
@@ -46,7 +43,7 @@ MountainCar-v0:
46
  Acrobot-v1:
47
  n_timesteps: !!float 1e5
48
  env_hyperparams:
49
- rolling_length: 10
50
  policy_hyperparams:
51
  hidden_sizes: [256, 256]
52
  algo_hyperparams:
@@ -64,7 +61,7 @@ Acrobot-v1:
64
  LunarLander-v2:
65
  n_timesteps: !!float 5e5
66
  env_hyperparams:
67
- rolling_length: 10
68
  policy_hyperparams:
69
  hidden_sizes: [256, 256]
70
  algo_hyperparams:
@@ -81,19 +78,15 @@ LunarLander-v2:
81
  max_grad_norm: 0.5
82
  eval_params:
83
  step_freq: 25_000
84
- n_episodes: 10
85
- save_best: true
86
 
87
- SpaceInvadersNoFrameskip-v4: &atari-defaults
88
  n_timesteps: !!float 1e7
89
  env_hyperparams:
90
  frame_stack: 4
91
  no_reward_timeout_steps: 1_000
 
92
  n_envs: 8
93
  vec_env_class: "subproc"
94
- rolling_length: 20
95
- policy_hyperparams:
96
- hidden_sizes: [512]
97
  algo_hyperparams:
98
  buffer_size: 100000
99
  learning_rate: !!float 1e-4
@@ -105,12 +98,7 @@ SpaceInvadersNoFrameskip-v4: &atari-defaults
105
  exploration_fraction: 0.1
106
  exploration_final_eps: 0.01
107
  eval_params:
108
- step_freq: 100_000
109
- n_episodes: 10
110
- save_best: true
111
-
112
- BreakoutNoFrameskip-v4:
113
- <<: *atari-defaults
114
 
115
  PongNoFrameskip-v4:
116
  <<: *atari-defaults
 
1
  CartPole-v1: &cartpole-defaults
2
  n_timesteps: !!float 5e4
3
  env_hyperparams:
 
4
  rolling_length: 50
5
  policy_hyperparams:
6
  hidden_sizes: [256, 256]
 
17
  exploration_final_eps: 0.04
18
  eval_params:
19
  step_freq: !!float 1e4
 
 
20
 
21
  CartPole-v0:
22
  <<: *cartpole-defaults
 
43
  Acrobot-v1:
44
  n_timesteps: !!float 1e5
45
  env_hyperparams:
46
+ rolling_length: 50
47
  policy_hyperparams:
48
  hidden_sizes: [256, 256]
49
  algo_hyperparams:
 
61
  LunarLander-v2:
62
  n_timesteps: !!float 5e5
63
  env_hyperparams:
64
+ rolling_length: 50
65
  policy_hyperparams:
66
  hidden_sizes: [256, 256]
67
  algo_hyperparams:
 
78
  max_grad_norm: 0.5
79
  eval_params:
80
  step_freq: 25_000
 
 
81
 
82
+ atari: &atari-defaults
83
  n_timesteps: !!float 1e7
84
  env_hyperparams:
85
  frame_stack: 4
86
  no_reward_timeout_steps: 1_000
87
+ no_reward_fire_steps: 500
88
  n_envs: 8
89
  vec_env_class: "subproc"
 
 
 
90
  algo_hyperparams:
91
  buffer_size: 100000
92
  learning_rate: !!float 1e-4
 
98
  exploration_fraction: 0.1
99
  exploration_final_eps: 0.01
100
  eval_params:
101
+ deterministic: false
 
 
 
 
 
102
 
103
  PongNoFrameskip-v4:
104
  <<: *atari-defaults
hyperparams/ppo.yml CHANGED
@@ -15,8 +15,6 @@ CartPole-v1: &cartpole-defaults
15
  clip_range_decay: linear
16
  eval_params:
17
  step_freq: !!float 2.5e4
18
- n_episodes: 10
19
- save_best: true
20
 
21
  CartPole-v0:
22
  <<: *cartpole-defaults
@@ -39,9 +37,10 @@ MountainCarContinuous-v0:
39
  env_hyperparams:
40
  normalize: true
41
  n_envs: 4
42
- policy_hyperparams:
43
- init_layers_orthogonal: false
44
- # log_std_init: -3.29
 
45
  algo_hyperparams:
46
  n_steps: 512
47
  batch_size: 256
@@ -53,11 +52,8 @@ MountainCarContinuous-v0:
53
  gae_lambda: 0.9
54
  max_grad_norm: 5
55
  vf_coef: 0.19
56
- # use_sde: true
57
  eval_params:
58
  step_freq: 5000
59
- n_episodes: 10
60
- save_best: true
61
 
62
  Acrobot-v1:
63
  n_timesteps: !!float 1e6
@@ -84,10 +80,6 @@ LunarLander-v2:
84
  ent_coef: 0.01
85
  ent_coef_decay: linear
86
  normalize_advantage: false
87
- eval_params:
88
- step_freq: !!float 5e4
89
- n_episodes: 10
90
- save_best: true
91
 
92
  CarRacing-v0:
93
  n_timesteps: !!float 4e6
@@ -101,6 +93,7 @@ CarRacing-v0:
101
  activation_fn: relu
102
  share_features_extractor: false
103
  cnn_feature_dim: 256
 
104
  algo_hyperparams:
105
  n_steps: 512
106
  batch_size: 128
 
15
  clip_range_decay: linear
16
  eval_params:
17
  step_freq: !!float 2.5e4
 
 
18
 
19
  CartPole-v0:
20
  <<: *cartpole-defaults
 
37
  env_hyperparams:
38
  normalize: true
39
  n_envs: 4
40
+ # policy_hyperparams:
41
+ # init_layers_orthogonal: false
42
+ # log_std_init: -3.29
43
+ # use_sde: true
44
  algo_hyperparams:
45
  n_steps: 512
46
  batch_size: 256
 
52
  gae_lambda: 0.9
53
  max_grad_norm: 5
54
  vf_coef: 0.19
 
55
  eval_params:
56
  step_freq: 5000
 
 
57
 
58
  Acrobot-v1:
59
  n_timesteps: !!float 1e6
 
80
  ent_coef: 0.01
81
  ent_coef_decay: linear
82
  normalize_advantage: false
 
 
 
 
83
 
84
  CarRacing-v0:
85
  n_timesteps: !!float 4e6
 
93
  activation_fn: relu
94
  share_features_extractor: false
95
  cnn_feature_dim: 256
96
+ hidden_sizes: [256]
97
  algo_hyperparams:
98
  n_steps: 512
99
  batch_size: 128
lambda_labs/benchmark.sh CHANGED
@@ -1,7 +1,6 @@
1
  source benchmarks/train_loop.sh
2
 
3
  # export WANDB_PROJECT_NAME="rl-algo-impls"
4
- export VIRTUAL_DISPLAY=1
5
 
6
  BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
7
 
@@ -20,8 +19,8 @@ ENVS=(
20
  # PyBullet
21
  "HalfCheetahBulletEnv-v0"
22
  "AntBulletEnv-v0"
23
- "Walker2DBulletEnv-v0"
24
  "HopperBulletEnv-v0"
 
25
  # CarRacing
26
  "CarRacing-v0"
27
  # Atari
 
1
  source benchmarks/train_loop.sh
2
 
3
  # export WANDB_PROJECT_NAME="rl-algo-impls"
 
4
 
5
  BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
6
 
 
19
  # PyBullet
20
  "HalfCheetahBulletEnv-v0"
21
  "AntBulletEnv-v0"
 
22
  "HopperBulletEnv-v0"
23
+ "Walker2DBulletEnv-v0"
24
  # CarRacing
25
  "CarRacing-v0"
26
  # Atari
lambda_labs/lambda_requirements.txt CHANGED
@@ -8,4 +8,5 @@ wandb >= 0.13.9, < 0.14
8
  pyvirtualdisplay == 3.0
9
  pybullet >= 3.2.5, < 3.3
10
  tabulate >= 0.9.0, < 0.10
11
- huggingface-hub >= 0.12.0, < 0.13
 
 
8
  pyvirtualdisplay == 3.0
9
  pybullet >= 3.2.5, < 3.3
10
  tabulate >= 0.9.0, < 0.10
11
+ huggingface-hub >= 0.12.0, < 0.13
12
+ numexpr >= 2.8.4, < 2.9
poetry.lock CHANGED
@@ -787,47 +787,47 @@ files = [
787
 
788
  [[package]]
789
  name = "cryptography"
790
- version = "39.0.0"
791
  description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
792
  category = "main"
793
  optional = false
794
  python-versions = ">=3.6"
795
  files = [
796
- {file = "cryptography-39.0.0-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:c52a1a6f81e738d07f43dab57831c29e57d21c81a942f4602fac7ee21b27f288"},
797
- {file = "cryptography-39.0.0-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:80ee674c08aaef194bc4627b7f2956e5ba7ef29c3cc3ca488cf15854838a8f72"},
798
- {file = "cryptography-39.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:887cbc1ea60786e534b00ba8b04d1095f4272d380ebd5f7a7eb4cc274710fad9"},
799
- {file = "cryptography-39.0.0-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6f97109336df5c178ee7c9c711b264c502b905c2d2a29ace99ed761533a3460f"},
800
- {file = "cryptography-39.0.0-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a6915075c6d3a5e1215eab5d99bcec0da26036ff2102a1038401d6ef5bef25b"},
801
- {file = "cryptography-39.0.0-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:76c24dd4fd196a80f9f2f5405a778a8ca132f16b10af113474005635fe7e066c"},
802
- {file = "cryptography-39.0.0-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:bae6c7f4a36a25291b619ad064a30a07110a805d08dc89984f4f441f6c1f3f96"},
803
- {file = "cryptography-39.0.0-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:875aea1039d78557c7c6b4db2fe0e9d2413439f4676310a5f269dd342ca7a717"},
804
- {file = "cryptography-39.0.0-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:f6c0db08d81ead9576c4d94bbb27aed8d7a430fa27890f39084c2d0e2ec6b0df"},
805
- {file = "cryptography-39.0.0-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f3ed2d864a2fa1666e749fe52fb8e23d8e06b8012e8bd8147c73797c506e86f1"},
806
- {file = "cryptography-39.0.0-cp36-abi3-win32.whl", hash = "sha256:f671c1bb0d6088e94d61d80c606d65baacc0d374e67bf895148883461cd848de"},
807
- {file = "cryptography-39.0.0-cp36-abi3-win_amd64.whl", hash = "sha256:e324de6972b151f99dc078defe8fb1b0a82c6498e37bff335f5bc6b1e3ab5a1e"},
808
- {file = "cryptography-39.0.0-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:754978da4d0457e7ca176f58c57b1f9de6556591c19b25b8bcce3c77d314f5eb"},
809
- {file = "cryptography-39.0.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1ee1fd0de9851ff32dbbb9362a4d833b579b4a6cc96883e8e6d2ff2a6bc7104f"},
810
- {file = "cryptography-39.0.0-pp38-pypy38_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:fec8b932f51ae245121c4671b4bbc030880f363354b2f0e0bd1366017d891458"},
811
- {file = "cryptography-39.0.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:407cec680e811b4fc829de966f88a7c62a596faa250fc1a4b520a0355b9bc190"},
812
- {file = "cryptography-39.0.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:7dacfdeee048814563eaaec7c4743c8aea529fe3dd53127313a792f0dadc1773"},
813
- {file = "cryptography-39.0.0-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:ad04f413436b0781f20c52a661660f1e23bcd89a0e9bb1d6d20822d048cf2856"},
814
- {file = "cryptography-39.0.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:50386acb40fbabbceeb2986332f0287f50f29ccf1497bae31cf5c3e7b4f4b34f"},
815
- {file = "cryptography-39.0.0-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:e5d71c5d5bd5b5c3eebcf7c5c2bb332d62ec68921a8c593bea8c394911a005ce"},
816
- {file = "cryptography-39.0.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:844ad4d7c3850081dffba91cdd91950038ee4ac525c575509a42d3fc806b83c8"},
817
- {file = "cryptography-39.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e0a05aee6a82d944f9b4edd6a001178787d1546ec7c6223ee9a848a7ade92e39"},
818
- {file = "cryptography-39.0.0.tar.gz", hash = "sha256:f964c7dcf7802d133e8dbd1565914fa0194f9d683d82411989889ecd701e8adf"},
819
  ]
820
 
821
  [package.dependencies]
822
  cffi = ">=1.12"
823
 
824
  [package.extras]
825
- docs = ["sphinx (>=1.6.5,!=1.8.0,!=3.1.0,!=3.1.1,!=5.2.0,!=5.2.0.post0)", "sphinx-rtd-theme"]
826
  docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
827
- pep8test = ["black", "ruff"]
828
  sdist = ["setuptools-rust (>=0.11.4)"]
829
  ssh = ["bcrypt (>=3.1.5)"]
830
- test = ["hypothesis (>=1.11.4,!=3.79.2)", "iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-subtests", "pytest-xdist", "pytz"]
 
 
831
 
832
  [[package]]
833
  name = "cycler"
@@ -2250,6 +2250,49 @@ jupyter-server = ">=1.8,<3"
2250
  [package.extras]
2251
  test = ["pytest", "pytest-console-scripts", "pytest-tornasync"]
2252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2253
  [[package]]
2254
  name = "numpy"
2255
  version = "1.24.1"
@@ -3002,6 +3045,18 @@ files = [
3002
  {file = "pytz-2022.7.tar.gz", hash = "sha256:7ccfae7b4b2c067464a6733c6261673fdb8fd1be905460396b97a073e9fa683a"},
3003
  ]
3004
 
 
 
 
 
 
 
 
 
 
 
 
 
3005
  [[package]]
3006
  name = "pywin32"
3007
  version = "305"
@@ -4198,4 +4253,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
4198
  [metadata]
4199
  lock-version = "2.0"
4200
  python-versions = "~3.10"
4201
- content-hash = "89d4861857be881d3c6cb591d17fb98396b8c117b24a8d4ce4b6593ac8048670"
 
787
 
788
  [[package]]
789
  name = "cryptography"
790
+ version = "39.0.1"
791
  description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
792
  category = "main"
793
  optional = false
794
  python-versions = ">=3.6"
795
  files = [
796
+ {file = "cryptography-39.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:6687ef6d0a6497e2b58e7c5b852b53f62142cfa7cd1555795758934da363a965"},
797
+ {file = "cryptography-39.0.1-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:706843b48f9a3f9b9911979761c91541e3d90db1ca905fd63fee540a217698bc"},
798
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:5d2d8b87a490bfcd407ed9d49093793d0f75198a35e6eb1a923ce1ee86c62b41"},
799
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83e17b26de248c33f3acffb922748151d71827d6021d98c70e6c1a25ddd78505"},
800
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e124352fd3db36a9d4a21c1aa27fd5d051e621845cb87fb851c08f4f75ce8be6"},
801
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:5aa67414fcdfa22cf052e640cb5ddc461924a045cacf325cd164e65312d99502"},
802
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:35f7c7d015d474f4011e859e93e789c87d21f6f4880ebdc29896a60403328f1f"},
803
+ {file = "cryptography-39.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f24077a3b5298a5a06a8e0536e3ea9ec60e4c7ac486755e5fb6e6ea9b3500106"},
804
+ {file = "cryptography-39.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:f0c64d1bd842ca2633e74a1a28033d139368ad959872533b1bab8c80e8240a0c"},
805
+ {file = "cryptography-39.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:0f8da300b5c8af9f98111ffd512910bc792b4c77392a9523624680f7956a99d4"},
806
+ {file = "cryptography-39.0.1-cp36-abi3-win32.whl", hash = "sha256:fe913f20024eb2cb2f323e42a64bdf2911bb9738a15dba7d3cce48151034e3a8"},
807
+ {file = "cryptography-39.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:ced4e447ae29ca194449a3f1ce132ded8fcab06971ef5f618605aacaa612beac"},
808
+ {file = "cryptography-39.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:807ce09d4434881ca3a7594733669bd834f5b2c6d5c7e36f8c00f691887042ad"},
809
+ {file = "cryptography-39.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:96f1157a7c08b5b189b16b47bc9db2332269d6680a196341bf30046330d15388"},
810
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e422abdec8b5fa8462aa016786680720d78bdce7a30c652b7fadf83a4ba35336"},
811
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:b0afd054cd42f3d213bf82c629efb1ee5f22eba35bf0eec88ea9ea7304f511a2"},
812
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:6f8ba7f0328b79f08bdacc3e4e66fb4d7aab0c3584e0bd41328dce5262e26b2e"},
813
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:ef8b72fa70b348724ff1218267e7f7375b8de4e8194d1636ee60510aae104cd0"},
814
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:aec5a6c9864be7df2240c382740fcf3b96928c46604eaa7f3091f58b878c0bb6"},
815
+ {file = "cryptography-39.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdd188c8a6ef8769f148f88f859884507b954cc64db6b52f66ef199bb9ad660a"},
816
+ {file = "cryptography-39.0.1.tar.gz", hash = "sha256:d1f6198ee6d9148405e49887803907fe8962a23e6c6f83ea7d98f1c0de375695"},
 
 
817
  ]
818
 
819
  [package.dependencies]
820
  cffi = ">=1.12"
821
 
822
  [package.extras]
823
+ docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
824
  docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
825
+ pep8test = ["black", "check-manifest", "mypy", "ruff", "types-pytz", "types-requests"]
826
  sdist = ["setuptools-rust (>=0.11.4)"]
827
  ssh = ["bcrypt (>=3.1.5)"]
828
+ test = ["hypothesis (>=1.11.4,!=3.79.2)", "iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-shard (>=0.1.2)", "pytest-subtests", "pytest-xdist", "pytz"]
829
+ test-randomorder = ["pytest-randomly"]
830
+ tox = ["tox"]
831
 
832
  [[package]]
833
  name = "cycler"
 
2250
  [package.extras]
2251
  test = ["pytest", "pytest-console-scripts", "pytest-tornasync"]
2252
 
2253
+ [[package]]
2254
+ name = "numexpr"
2255
+ version = "2.8.4"
2256
+ description = "Fast numerical expression evaluator for NumPy"
2257
+ category = "main"
2258
+ optional = false
2259
+ python-versions = ">=3.7"
2260
+ files = [
2261
+ {file = "numexpr-2.8.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a75967d46b6bd56455dd32da6285e5ffabe155d0ee61eef685bbfb8dafb2e484"},
2262
+ {file = "numexpr-2.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db93cf1842f068247de631bfc8af20118bf1f9447cd929b531595a5e0efc9346"},
2263
+ {file = "numexpr-2.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bca95f4473b444428061d4cda8e59ac564dc7dc6a1dea3015af9805c6bc2946"},
2264
+ {file = "numexpr-2.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e34931089a6bafc77aaae21f37ad6594b98aa1085bb8b45d5b3cd038c3c17d9"},
2265
+ {file = "numexpr-2.8.4-cp310-cp310-win32.whl", hash = "sha256:f3a920bfac2645017110b87ddbe364c9c7a742870a4d2f6120b8786c25dc6db3"},
2266
+ {file = "numexpr-2.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:6931b1e9d4f629f43c14b21d44f3f77997298bea43790cfcdb4dd98804f90783"},
2267
+ {file = "numexpr-2.8.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9400781553541f414f82eac056f2b4c965373650df9694286b9bd7e8d413f8d8"},
2268
+ {file = "numexpr-2.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6ee9db7598dd4001138b482342b96d78110dd77cefc051ec75af3295604dde6a"},
2269
+ {file = "numexpr-2.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff5835e8af9a212e8480003d731aad1727aaea909926fd009e8ae6a1cba7f141"},
2270
+ {file = "numexpr-2.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:655d84eb09adfee3c09ecf4a89a512225da153fdb7de13c447404b7d0523a9a7"},
2271
+ {file = "numexpr-2.8.4-cp311-cp311-win32.whl", hash = "sha256:5538b30199bfc68886d2be18fcef3abd11d9271767a7a69ff3688defe782800a"},
2272
+ {file = "numexpr-2.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:3f039321d1c17962c33079987b675fb251b273dbec0f51aac0934e932446ccc3"},
2273
+ {file = "numexpr-2.8.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c867cc36cf815a3ec9122029874e00d8fbcef65035c4a5901e9b120dd5d626a2"},
2274
+ {file = "numexpr-2.8.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:059546e8f6283ccdb47c683101a890844f667fa6d56258d48ae2ecf1b3875957"},
2275
+ {file = "numexpr-2.8.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:845a6aa0ed3e2a53239b89c1ebfa8cf052d3cc6e053c72805e8153300078c0b1"},
2276
+ {file = "numexpr-2.8.4-cp37-cp37m-win32.whl", hash = "sha256:a38664e699526cb1687aefd9069e2b5b9387da7feac4545de446141f1ef86f46"},
2277
+ {file = "numexpr-2.8.4-cp37-cp37m-win_amd64.whl", hash = "sha256:eaec59e9bf70ff05615c34a8b8d6c7bd042bd9f55465d7b495ea5436f45319d0"},
2278
+ {file = "numexpr-2.8.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b318541bf3d8326682ebada087ba0050549a16d8b3fa260dd2585d73a83d20a7"},
2279
+ {file = "numexpr-2.8.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b076db98ca65eeaf9bd224576e3ac84c05e451c0bd85b13664b7e5f7b62e2c70"},
2280
+ {file = "numexpr-2.8.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90f12cc851240f7911a47c91aaf223dba753e98e46dff3017282e633602e76a7"},
2281
+ {file = "numexpr-2.8.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c368aa35ae9b18840e78b05f929d3a7b3abccdba9630a878c7db74ca2368339"},
2282
+ {file = "numexpr-2.8.4-cp38-cp38-win32.whl", hash = "sha256:b96334fc1748e9ec4f93d5fadb1044089d73fb08208fdb8382ed77c893f0be01"},
2283
+ {file = "numexpr-2.8.4-cp38-cp38-win_amd64.whl", hash = "sha256:a6d2d7740ae83ba5f3531e83afc4b626daa71df1ef903970947903345c37bd03"},
2284
+ {file = "numexpr-2.8.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:77898fdf3da6bb96aa8a4759a8231d763a75d848b2f2e5c5279dad0b243c8dfe"},
2285
+ {file = "numexpr-2.8.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df35324666b693f13a016bc7957de7cc4d8801b746b81060b671bf78a52b9037"},
2286
+ {file = "numexpr-2.8.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ac9cfe6d0078c5fc06ba1c1bbd20b8783f28c6f475bbabd3cad53683075cab"},
2287
+ {file = "numexpr-2.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df3a1f6b24214a1ab826e9c1c99edf1686c8e307547a9aef33910d586f626d01"},
2288
+ {file = "numexpr-2.8.4-cp39-cp39-win32.whl", hash = "sha256:7d71add384adc9119568d7e9ffa8a35b195decae81e0abf54a2b7779852f0637"},
2289
+ {file = "numexpr-2.8.4-cp39-cp39-win_amd64.whl", hash = "sha256:9f096d707290a6a00b6ffdaf581ee37331109fb7b6c8744e9ded7c779a48e517"},
2290
+ {file = "numexpr-2.8.4.tar.gz", hash = "sha256:d5432537418d18691b9115d615d6daa17ee8275baef3edf1afbbf8bc69806147"},
2291
+ ]
2292
+
2293
+ [package.dependencies]
2294
+ numpy = ">=1.13.3"
2295
+
2296
  [[package]]
2297
  name = "numpy"
2298
  version = "1.24.1"
 
3045
  {file = "pytz-2022.7.tar.gz", hash = "sha256:7ccfae7b4b2c067464a6733c6261673fdb8fd1be905460396b97a073e9fa683a"},
3046
  ]
3047
 
3048
+ [[package]]
3049
+ name = "pyvirtualdisplay"
3050
+ version = "3.0"
3051
+ description = "python wrapper for Xvfb, Xephyr and Xvnc"
3052
+ category = "main"
3053
+ optional = false
3054
+ python-versions = "*"
3055
+ files = [
3056
+ {file = "PyVirtualDisplay-3.0-py3-none-any.whl", hash = "sha256:40d4b8dfe4b8de8552e28eb367647f311f88a130bf837fe910e7f180d5477f0e"},
3057
+ {file = "PyVirtualDisplay-3.0.tar.gz", hash = "sha256:09755bc3ceb6eb725fb07eca5425f43f2358d3bf08e00d2a9b792a1aedd16159"},
3058
+ ]
3059
+
3060
  [[package]]
3061
  name = "pywin32"
3062
  version = "305"
 
4253
  [metadata]
4254
  lock-version = "2.0"
4255
  python-versions = "~3.10"
4256
+ content-hash = "8301ee1f2321a6c23370a61466fd3b45096291c2cc63326bbe4701774edf1d94"
ppo/policy.py CHANGED
@@ -2,7 +2,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
2
  from typing import Optional, Sequence
3
 
4
  from gym.spaces import Box, Discrete
5
- from shared.policy.on_policy import ActorCritic
6
 
7
 
8
  class PPOActorCritic(ActorCritic):
@@ -13,21 +13,16 @@ class PPOActorCritic(ActorCritic):
13
  v_hidden_sizes: Optional[Sequence[int]] = None,
14
  **kwargs,
15
  ) -> None:
16
- obs_space = env.observation_space
17
- if isinstance(obs_space, Box):
18
- if len(obs_space.shape) == 3:
19
- pi_hidden_sizes = pi_hidden_sizes or []
20
- v_hidden_sizes = v_hidden_sizes or []
21
- elif len(obs_space.shape) == 1:
22
- pi_hidden_sizes = pi_hidden_sizes or [64, 64]
23
- v_hidden_sizes = v_hidden_sizes or [64, 64]
24
- else:
25
- raise ValueError(f"Unsupported observation space: {obs_space}")
26
- elif isinstance(obs_space, Discrete):
27
- pi_hidden_sizes = pi_hidden_sizes or [64]
28
- v_hidden_sizes = v_hidden_sizes or [64]
29
- else:
30
- raise ValueError(f"Unsupported observation space: {obs_space}")
31
  super().__init__(
32
  env,
33
  pi_hidden_sizes,
 
2
  from typing import Optional, Sequence
3
 
4
  from gym.spaces import Box, Discrete
5
+ from shared.policy.on_policy import ActorCritic, default_hidden_sizes
6
 
7
 
8
  class PPOActorCritic(ActorCritic):
 
13
  v_hidden_sizes: Optional[Sequence[int]] = None,
14
  **kwargs,
15
  ) -> None:
16
+ pi_hidden_sizes = (
17
+ pi_hidden_sizes
18
+ if pi_hidden_sizes is not None
19
+ else default_hidden_sizes(env.observation_space)
20
+ )
21
+ v_hidden_sizes = (
22
+ v_hidden_sizes
23
+ if v_hidden_sizes is not None
24
+ else default_hidden_sizes(env.observation_space)
25
+ )
 
 
 
 
 
26
  super().__init__(
27
  env,
28
  pi_hidden_sizes,
pyproject.toml CHANGED
@@ -23,6 +23,9 @@ torch-tb-profiler = "^0.4.1"
23
  jupyter = "^1.0.0"
24
  tabulate = "^0.9.0"
25
  huggingface-hub = "^0.12.0"
 
 
 
26
 
27
  [build-system]
28
  requires = ["poetry-core"]
 
23
  jupyter = "^1.0.0"
24
  tabulate = "^0.9.0"
25
  huggingface-hub = "^0.12.0"
26
+ cryptography = "39.0.1"
27
+ pyvirtualdisplay = "^3.0"
28
+ numexpr = "^2.8.4"
29
 
30
  [build-system]
31
  requires = ["poetry-core"]
replay.meta.json CHANGED
@@ -1 +1 @@
1
- {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpdv8g18sc/ppo-CarRacing-v0/replay.mp4"]}, "episode": {"r": 442.7605895996094, "l": 1000, "t": 12.301982}}
 
1
+ {"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpcfp_i4ww/ppo-CarRacing-v0/replay.mp4"]}, "episode": {"r": 910.102783203125, "l": 899, "t": 11.729042}}
replay.mp4 CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
 
runner/running_utils.py CHANGED
@@ -119,6 +119,8 @@ def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
119
  torch.backends.cudnn.benchmark = False
120
  torch.use_deterministic_algorithms(use_deterministic_algorithms)
121
  os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
 
 
122
 
123
 
124
  def make_policy(
 
119
  torch.backends.cudnn.benchmark = False
120
  torch.use_deterministic_algorithms(use_deterministic_algorithms)
121
  os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
122
+ # Stop warning and it would introduce stochasticity if I was using TF
123
+ os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
124
 
125
 
126
  def make_policy(
saved_models/ppo-CarRacing-v0-S1-best/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e252b9b261d9afa502ed7c5079655ec0e34990e64aaa8ec133e91ce72aa5fb34
3
+ size 2737400
shared/policy/on_policy.py CHANGED
@@ -2,7 +2,7 @@ import gym
2
  import numpy as np
3
  import torch
4
 
5
- from gym.spaces import Box
6
  from pathlib import Path
7
  from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
8
  from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
@@ -47,6 +47,21 @@ def clamp_actions(
47
  return actions
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  class ActorCritic(Policy):
51
  def __init__(
52
  self,
 
2
  import numpy as np
3
  import torch
4
 
5
+ from gym.spaces import Box, Discrete, Space
6
  from pathlib import Path
7
  from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
8
  from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
 
47
  return actions
48
 
49
 
50
+ def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
51
+ if isinstance(obs_space, Box):
52
+ if len(obs_space.shape) == 3:
53
+ # By default feature extractor to output has no hidden layers
54
+ return []
55
+ elif len(obs_space.shape) == 1:
56
+ return [64, 64]
57
+ else:
58
+ raise ValueError(f"Unsupported observation space: {obs_space}")
59
+ elif isinstance(obs_space, Discrete):
60
+ return [64]
61
+ else:
62
+ raise ValueError(f"Unsupported observation space: {obs_space}")
63
+
64
+
65
  class ActorCritic(Policy):
66
  def __init__(
67
  self,
shared/policy/policy.py CHANGED
@@ -51,7 +51,6 @@ class Policy(nn.Module, ABC):
51
  os.path.join(path, MODEL_FILENAME),
52
  )
53
 
54
- @abstractmethod
55
  def load(self, path: str) -> None:
56
  # VecNormalize load occurs in env.py
57
  self.load_state_dict(
 
51
  os.path.join(path, MODEL_FILENAME),
52
  )
53
 
 
54
  def load(self, path: str) -> None:
55
  # VecNormalize load occurs in env.py
56
  self.load_state_dict(
train.py CHANGED
@@ -45,21 +45,20 @@ if __name__ == "__main__":
45
  parser.add_argument(
46
  "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
47
  )
48
- parser.add_argument(
49
- "--virtual-display",
50
- action="store_true",
51
- help="Whether to create a virtual display for video rendering",
 
52
  )
53
- parser.set_defaults(algo="ppo", env="CartPole-v1", seed=1)
54
  args = parser.parse_args()
55
  print(args)
56
 
57
- if args.virtual_display:
58
- from pyvirtualdisplay import Display
59
 
60
- virtual_display = Display(visible=0, size=(1400, 900))
61
  virtual_display.start()
62
- delattr(args, "virtual_display")
63
 
64
  # pool_size isn't a TrainArg so must be removed from args
65
  pool_size = args.pool_size
 
45
  parser.add_argument(
46
  "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
47
  )
48
+ parser.set_defaults(
49
+ algo="ppo",
50
+ env="MountainCarContinuous-v0",
51
+ seed=[1, 2, 3],
52
+ pool_size=3,
53
  )
 
54
  args = parser.parse_args()
55
  print(args)
56
 
57
+ if args.pool_size == 1:
58
+ from pyvirtualdisplay.display import Display
59
 
60
+ virtual_display = Display(visible=False, size=(1400, 900))
61
  virtual_display.start()
 
62
 
63
  # pool_size isn't a TrainArg so must be removed from args
64
  pool_size = args.pool_size
vpg/policy.py CHANGED
@@ -15,7 +15,7 @@ from shared.policy.actor import (
15
  actor_head,
16
  )
17
  from shared.policy.critic import CriticHead
18
- from shared.policy.on_policy import Step, clamp_actions
19
  from shared.policy.policy import ACTIVATION, Policy
20
 
21
  PI_FILE_NAME = "pi.pt"
@@ -37,7 +37,7 @@ class VPGActorCritic(Policy):
37
  def __init__(
38
  self,
39
  env: VecEnv,
40
- hidden_sizes: Sequence[int],
41
  init_layers_orthogonal: bool = True,
42
  activation_fn: str = "tanh",
43
  log_std_init: float = -0.5,
@@ -53,6 +53,12 @@ class VPGActorCritic(Policy):
53
  self.use_sde = use_sde
54
  self.squash_output = squash_output
55
 
 
 
 
 
 
 
56
  pi_feature_extractor = FeatureExtractor(
57
  obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
58
  )
 
15
  actor_head,
16
  )
17
  from shared.policy.critic import CriticHead
18
+ from shared.policy.on_policy import Step, clamp_actions, default_hidden_sizes
19
  from shared.policy.policy import ACTIVATION, Policy
20
 
21
  PI_FILE_NAME = "pi.pt"
 
37
  def __init__(
38
  self,
39
  env: VecEnv,
40
+ hidden_sizes: Optional[Sequence[int]] = None,
41
  init_layers_orthogonal: bool = True,
42
  activation_fn: str = "tanh",
43
  log_std_init: float = -0.5,
 
53
  self.use_sde = use_sde
54
  self.squash_output = squash_output
55
 
56
+ hidden_sizes = (
57
+ hidden_sizes
58
+ if hidden_sizes is not None
59
+ else default_hidden_sizes(obs_space)
60
+ )
61
+
62
  pi_feature_extractor = FeatureExtractor(
63
  obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
64
  )