sgoodfriend
committed on
Commit
•
85e4a43
1
Parent(s):
0ca6846
PPO playing CarRacing-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4
Browse files- README.md +16 -14
- benchmark_publish.py +4 -5
- benchmarks/benchmark_test.sh +32 -0
- benchmarks/colab_pybullet.sh +1 -1
- benchmarks/train_loop.sh +1 -3
- colab_requirements.txt +2 -1
- compare_runs.py +179 -0
- dqn/policy.py +1 -1
- dqn/q_net.py +1 -1
- huggingface_publish.py +5 -0
- hyperparams/dqn.yml +5 -17
- hyperparams/ppo.yml +5 -12
- lambda_labs/benchmark.sh +1 -2
- lambda_labs/lambda_requirements.txt +2 -1
- poetry.lock +83 -28
- ppo/policy.py +11 -16
- pyproject.toml +3 -0
- replay.meta.json +1 -1
- replay.mp4 +0 -0
- runner/running_utils.py +2 -0
- saved_models/ppo-CarRacing-v0-S1-best/model.pth +3 -0
- shared/policy/on_policy.py +16 -1
- shared/policy/policy.py +0 -1
- train.py +8 -9
- vpg/policy.py +8 -2
README.md
CHANGED
@@ -10,7 +10,7 @@ model-index:
|
|
10 |
results:
|
11 |
- metrics:
|
12 |
- type: mean_reward
|
13 |
-
value:
|
14 |
name: mean_reward
|
15 |
task:
|
16 |
type: reinforcement-learning
|
@@ -23,17 +23,17 @@ model-index:
|
|
23 |
|
24 |
This is a trained model of a **PPO** agent playing **CarRacing-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
|
26 |
-
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
-
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:-------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
-
| ppo | CarRacing-v0 |
|
35 |
-
| ppo | CarRacing-v0 |
|
36 |
-
| ppo | CarRacing-v0 |
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
@@ -53,10 +53,10 @@ login`.
|
|
53 |
Note: While the model state dictionary and hyperparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
-
[
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
-
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
@@ -68,11 +68,11 @@ notebook.
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
-
commit the agent was trained on: [
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
75 |
-
python train.py --algo ppo --env CarRacing-v0 --seed
|
76 |
```
|
77 |
|
78 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
@@ -83,7 +83,7 @@ notebook.
|
|
83 |
|
84 |
|
85 |
## Benchmarking (with Lambda Labs instance)
|
86 |
-
This and other models from https://api.wandb.ai/links/sgoodfriend/
|
87 |
Labs instance. In a Lambda Labs instance terminal:
|
88 |
```
|
89 |
git clone git@github.com:sgoodfriend/rl-algo-impls.git
|
@@ -127,16 +127,18 @@ n_timesteps: 4000000
|
|
127 |
policy_hyperparams:
|
128 |
activation_fn: relu
|
129 |
cnn_feature_dim: 256
|
|
|
|
|
130 |
init_layers_orthogonal: false
|
131 |
log_std_init: -2
|
132 |
share_features_extractor: false
|
133 |
use_sde: true
|
134 |
-
seed:
|
135 |
use_deterministic_algorithms: true
|
136 |
wandb_entity: null
|
137 |
wandb_project_name: rl-algo-impls-benchmarks
|
138 |
wandb_tags:
|
139 |
-
-
|
140 |
-
-
|
141 |
|
142 |
```
|
|
|
10 |
results:
|
11 |
- metrics:
|
12 |
- type: mean_reward
|
13 |
+
value: 865.72 +/- 58.15
|
14 |
name: mean_reward
|
15 |
task:
|
16 |
type: reinforcement-learning
|
|
|
23 |
|
24 |
This is a trained model of a **PPO** agent playing **CarRacing-v0** using the [/sgoodfriend/rl-algo-impls](https://github.com/sgoodfriend/rl-algo-impls) repo.
|
25 |
|
26 |
+
All models trained at this commit can be found at https://api.wandb.ai/links/sgoodfriend/448odm37.
|
27 |
|
28 |
## Training Results
|
29 |
|
30 |
+
This model was trained from 3 trainings of **PPO** agents using different initial seeds. These agents were trained by checking out [fbc943f](https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4). The best and last models were kept from each training. This submission loads the best models from each training, reevaluates them, and selects the best model from these latest evaluations (mean - std).
|
31 |
|
32 |
| algo | env | seed | reward_mean | reward_std | eval_episodes | best | wandb_url |
|
33 |
|:-------|:-------------|-------:|--------------:|-------------:|----------------:|:-------|:-----------------------------------------------------------------------------|
|
34 |
+
| ppo | CarRacing-v0 | 1 | 865.725 | 58.1454 | 16 | * | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/8vyb0q44) |
|
35 |
+
| ppo | CarRacing-v0 | 2 | 693.464 | 236.712 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/a3ld38qf) |
|
36 |
+
| ppo | CarRacing-v0 | 3 | 815.26 | 141.502 | 16 | | [wandb](https://wandb.ai/sgoodfriend/rl-algo-impls-benchmarks/runs/zah43or2) |
|
37 |
|
38 |
|
39 |
### Prerequisites: Weights & Biases (WandB)
|
|
|
53 |
Note: While the model state dictionary and hyperparameters are saved, the latest
|
54 |
implementation could be sufficiently different to not be able to reproduce similar
|
55 |
results. You might need to checkout the commit the agent was trained on:
|
56 |
+
[fbc943f](https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4).
|
57 |
```
|
58 |
# Downloads the model, sets hyperparameters, and runs agent for 3 episodes
|
59 |
+
python enjoy.py --wandb-run-path=sgoodfriend/rl-algo-impls-benchmarks/8vyb0q44
|
60 |
```
|
61 |
|
62 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
|
|
68 |
|
69 |
## Training
|
70 |
If you want the highest chance to reproduce these results, you'll want to checkout the
|
71 |
+
commit the agent was trained on: [fbc943f](https://github.com/sgoodfriend/rl-algo-impls/tree/fbc943f151b95afc4905a67a3835fb6b18c6a5e4). While
|
72 |
training is deterministic, different hardware will give different results.
|
73 |
|
74 |
```
|
75 |
+
python train.py --algo ppo --env CarRacing-v0 --seed 1
|
76 |
```
|
77 |
|
78 |
Setup hasn't been completely worked out yet, so you might be best served by using Google
|
|
|
83 |
|
84 |
|
85 |
## Benchmarking (with Lambda Labs instance)
|
86 |
+
This and other models from https://api.wandb.ai/links/sgoodfriend/448odm37 were generated by running a script on a Lambda
|
87 |
Labs instance. In a Lambda Labs instance terminal:
|
88 |
```
|
89 |
git clone git@github.com:sgoodfriend/rl-algo-impls.git
|
|
|
127 |
policy_hyperparams:
|
128 |
activation_fn: relu
|
129 |
cnn_feature_dim: 256
|
130 |
+
hidden_sizes:
|
131 |
+
- 256
|
132 |
init_layers_orthogonal: false
|
133 |
log_std_init: -2
|
134 |
share_features_extractor: false
|
135 |
use_sde: true
|
136 |
+
seed: 1
|
137 |
use_deterministic_algorithms: true
|
138 |
wandb_entity: null
|
139 |
wandb_project_name: rl-algo-impls-benchmarks
|
140 |
wandb_tags:
|
141 |
+
- benchmark_fbc943f
|
142 |
+
- host_150-230-44-105
|
143 |
|
144 |
```
|
benchmark_publish.py
CHANGED
@@ -44,11 +44,10 @@ if __name__ == "__main__":
|
|
44 |
default=3,
|
45 |
help="How many publish jobs can run in parallel",
|
46 |
)
|
47 |
-
parser.set_defaults(
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
)
|
52 |
args = parser.parse_args()
|
53 |
print(args)
|
54 |
|
|
|
44 |
default=3,
|
45 |
help="How many publish jobs can run in parallel",
|
46 |
)
|
47 |
+
# parser.set_defaults(
|
48 |
+
# wandb_tags=["benchmark_5598ebc", "host_192-9-145-26"],
|
49 |
+
# wandb_report_url="https://api.wandb.ai/links/sgoodfriend/6p2sjqtn",
|
50 |
+
# )
|
|
|
51 |
args = parser.parse_args()
|
52 |
print(args)
|
53 |
|
benchmarks/benchmark_test.sh
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Smoke-test benchmark: emits one training command per (algo, env, seed)
# via train_loop and runs them in parallel with xargs.
# train_loop is defined in the sourced file; the relative path means this
# script is expected to be run from the repository root.
source benchmarks/train_loop.sh

# Log to the regular project instead of train_loop's benchmark default.
export WANDB_PROJECT_NAME="rl-algo-impls"

# Number of training jobs xargs runs concurrently (overridable by the caller).
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"

ALGOS=(
    # "vpg"
    "dqn"
    # "ppo"
)
ENVS=(
    # Basic
    "CartPole-v1"
    "MountainCar-v0"
    # "MountainCarContinuous-v0"
    "Acrobot-v1"
    "LunarLander-v2"
    # # PyBullet
    # "HalfCheetahBulletEnv-v0"
    # "AntBulletEnv-v0"
    # "HopperBulletEnv-v0"
    # "Walker2DBulletEnv-v0"
    # # CarRacing
    # "CarRacing-v0"
    # Atari
    "PongNoFrameskip-v4"
    "BreakoutNoFrameskip-v4"
    "SpaceInvadersNoFrameskip-v4"
    "QbertNoFrameskip-v4"
)
# "${ARR[*]}" joins each array into one space-separated argument, which is the
# form train_loop iterates over. The process count is quoted so an unexpected
# value cannot be word-split into extra xargs arguments.
train_loop "${ALGOS[*]}" "${ENVS[*]}" | xargs -I CMD -P "$BENCHMARK_MAX_PROCS" bash -c CMD
|
benchmarks/colab_pybullet.sh
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
source benchmarks/train_loop.sh
|
2 |
ALGOS="ppo"
|
3 |
-
ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0
|
4 |
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
|
|
1 |
source benchmarks/train_loop.sh
|
2 |
ALGOS="ppo"
|
3 |
+
ENVS="HalfCheetahBulletEnv-v0 AntBulletEnv-v0 HopperBulletEnv-v0 Walker2DBulletEnv-v0"
|
4 |
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-3}"
|
5 |
train_loop $ALGOS "$ENVS" | xargs -I CMD -P $BENCHMARK_MAX_PROCS bash -c CMD
|
benchmarks/train_loop.sh
CHANGED
@@ -4,13 +4,11 @@ train_loop () {
|
|
4 |
local env
|
5 |
local seed
|
6 |
local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
|
7 |
-
local args=()
|
8 |
-
(( VIRTUAL_DISPLAY == 1)) && args+=("--virtual-display")
|
9 |
local SEEDS="${SEEDS:-1 2 3}"
|
10 |
for algo in $(echo $1); do
|
11 |
for env in $(echo $2); do
|
12 |
for seed in $SEEDS; do
|
13 |
-
echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME
|
14 |
done
|
15 |
done
|
16 |
done
|
|
|
4 |
local env
|
5 |
local seed
|
6 |
local WANDB_PROJECT_NAME="${WANDB_PROJECT_NAME:-rl-algo-impls-benchmarks}"
|
|
|
|
|
7 |
local SEEDS="${SEEDS:-1 2 3}"
|
8 |
for algo in $(echo $1); do
|
9 |
for env in $(echo $2); do
|
10 |
for seed in $SEEDS; do
|
11 |
+
echo python train.py --algo $algo --env $env --seed $seed --pool-size 1 --wandb-tags $WANDB_TAGS --wandb-project-name $WANDB_PROJECT_NAME
|
12 |
done
|
13 |
done
|
14 |
done
|
colab_requirements.txt
CHANGED
@@ -6,4 +6,5 @@ wandb >= 0.13.9, < 0.14
|
|
6 |
pyvirtualdisplay == 3.0
|
7 |
pybullet >= 3.2.5, < 3.3
|
8 |
tabulate >= 0.9.0, < 0.10
|
9 |
-
huggingface-hub >= 0.12.0, < 0.13
|
|
|
|
6 |
pyvirtualdisplay == 3.0
|
7 |
pybullet >= 3.2.5, < 3.3
|
8 |
tabulate >= 0.9.0, < 0.10
|
9 |
+
huggingface-hub >= 0.12.0, < 0.13
|
10 |
+
numexpr >= 2.8.4, < 2.9
|
compare_runs.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import itertools
|
3 |
+
import numpy as np
|
4 |
+
import pandas as pd
|
5 |
+
import wandb
|
6 |
+
import wandb.apis.public
|
7 |
+
|
8 |
+
from collections import defaultdict
|
9 |
+
from dataclasses import dataclass
|
10 |
+
from typing import Dict, Iterable, List, TypeVar
|
11 |
+
|
12 |
+
from benchmark_publish import RunGroup
|
13 |
+
|
14 |
+
|
15 |
+
@dataclass
class Comparison:
    """Values of one summary metric for control runs and experiment runs.

    Percent changes are computed relative to the magnitude of the control
    statistic, so a positive percentage always means the experiment statistic
    is numerically higher than the control statistic, including when the
    control statistic is negative.
    """

    # One value per control run.
    control_values: List[float]
    # One value per experiment run.
    experiment_values: List[float]

    def mean_diff_percentage(self) -> float:
        """Return the percent change of the experiment mean vs. the control mean.

        Returns NaN when either side has no values.
        """
        if self._has_empty_side():
            return float("nan")
        return self._diff_percentage(
            np.mean(self.control_values).item(), np.mean(self.experiment_values).item()
        )

    def median_diff_percentage(self) -> float:
        """Return the percent change of the experiment median vs. the control median.

        Returns NaN when either side has no values.
        """
        if self._has_empty_side():
            return float("nan")
        return self._diff_percentage(
            np.median(self.control_values).item(),
            np.median(self.experiment_values).item(),
        )

    def _has_empty_side(self) -> bool:
        """Return True when there are no control values or no experiment values."""
        # np.mean/np.median of an empty sequence yield NaN plus a RuntimeWarning;
        # short-circuiting gives the same NaN result without the warning noise.
        return len(self.control_values) == 0 or len(self.experiment_values) == 0

    def _diff_percentage(self, c: float, e: float) -> float:
        """Return the percent change from control ``c`` to experiment ``e``.

        Returns 0.0 when the values are equal and +/-inf when the control is
        zero and the experiment is not.
        """
        if c == e:
            return 0.0
        if c == 0:
            return float("inf") if e > 0 else float("-inf")
        # Divide by |c|: with a negative control, dividing by c itself would
        # flip the sign and report an increase as a decrease.
        return 100 * (e - c) / abs(c)

    def score(self) -> float:
        """Return the average sign of the mean and median percent changes.

        The result is 1.0 when both are positive, -1.0 when both are negative,
        and in between when they disagree or are zero. NaN inputs yield NaN.
        """
        signs = np.sign((self.mean_diff_percentage(), self.median_diff_percentage()))
        return np.sum(signs).item() / 2
|
45 |
+
|
46 |
+
|
47 |
+
RunGroupRunsSelf = TypeVar("RunGroupRunsSelf", bound="RunGroupRuns")


class RunGroupRuns:
    """Control and experiment runs collected for a single run group.

    Runs are assigned to a side by intersecting their ``wandb_tags`` config
    entry with the control and experiment tag sets given at construction.
    """

    def __init__(
        self,
        run_group: "RunGroup",
        control: List[str],
        experiment: List[str],
        summary_stats: List[str] = ["best_eval", "eval", "train_rolling"],
        summary_metrics: List[str] = ["mean", "result"],
    ) -> None:
        """Set up empty control/experiment buckets for ``run_group``.

        Args:
            run_group: Object providing the ``algo`` and ``env_id`` attributes.
            control: Tags that mark a run as a control run.
            experiment: Tags that mark a run as an experiment run.
            summary_stats: Summary key prefixes to compare.
            summary_metrics: Summary key suffixes to compare; each compared
                key is ``f"{stat}_{metric}"``.
        """
        self.algo = run_group.algo
        self.env = run_group.env_id
        self.control = set(control)
        self.experiment = set(experiment)

        # Copy so instances never alias the shared default list objects (or a
        # caller's list); mutating one instance must not affect the others.
        self.summary_stats = list(summary_stats)
        self.summary_metrics = list(summary_metrics)

        self.control_runs = []
        self.experiment_runs = []

    def add_run(self, run: "wandb.apis.public.Run") -> None:
        """File ``run`` under control or experiment based on its tags.

        A run carrying both a control and an experiment tag is treated as a
        control run; a run matching neither is ignored.
        """
        # "or []" also covers a config that stores wandb_tags as null.
        wandb_tags = set(run.config.get("wandb_tags") or [])
        if self.control & wandb_tags:
            self.control_runs.append(run)
        elif self.experiment & wandb_tags:
            self.experiment_runs.append(run)

    def comparisons_by_metric(self) -> Dict[str, "Comparison"]:
        """Return a Comparison per ``{stat}_{metric}`` summary key.

        Raises whatever ``run.summary[key]`` raises when a collected run lacks
        one of the compared keys.
        """
        c_by_m = {}
        for stat, metric_name in itertools.product(
            self.summary_stats, self.summary_metrics
        ):
            metric = f"{stat}_{metric_name}"
            c_by_m[metric] = Comparison(
                [c.summary[metric] for c in self.control_runs],
                [e.summary[metric] for e in self.experiment_runs],
            )
        return c_by_m

    @staticmethod
    def data_frame(rows: Iterable[RunGroupRunsSelf]) -> pd.DataFrame:
        """Build one DataFrame row per run group.

        Columns are ``algo``, ``env``, ``control``, ``experiment``, ``score``
        (mean of the per-metric scores) and, for every compared metric, its
        ``_mean`` and ``_median`` percent changes.
        """
        results = defaultdict(list)
        for r in rows:
            results["algo"].append(r.algo)
            results["env"].append(r.env)
            results["control"].append(r.control)
            # Column name was previously misspelled "expierment".
            results["experiment"].append(r.experiment)
            c_by_m = r.comparisons_by_metric()
            # No configured metrics means there is nothing to average.
            results["score"].append(
                sum(m.score() for m in c_by_m.values()) / len(c_by_m)
                if c_by_m
                else float("nan")
            )
            for m, c in c_by_m.items():
                results[f"{m}_mean"].append(c.mean_diff_percentage())
                results[f"{m}_median"].append(c.median_diff_percentage())
        return pd.DataFrame(results)
|
105 |
+
|
106 |
+
|
107 |
+
if __name__ == "__main__":
    # Compare control vs. experiment benchmark runs stored in a WandB project
    # and print a per-run-group markdown table of relative metric changes.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-p",
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls-benchmarks",
        help="WandB project name to load runs from",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team. None uses default entity",
    )
    parser.add_argument(
        "-n",
        "--wandb-hostname-tag",
        type=str,
        help="WandB tag for hostname (i.e. host_192-9-145-26)",
    )
    parser.add_argument(
        "-c",
        "--wandb-control-tag",
        type=str,
        nargs="+",
        help="WandB tag for control commit (i.e. benchmark_5598ebc)",
    )
    parser.add_argument(
        "-e",
        "--wandb-experiment-tag",
        type=str,
        nargs="+",
        help="WandB tag for experiment commit (i.e. benchmark_5540e1f)",
    )
    parser.add_argument(
        "--exclude_envs",
        type=str,
        nargs="*",
        help="Environments to exclude from comparison",
    )
    # NOTE(review): these defaults are tied to one specific benchmark host and
    # set of commits; pass the flags explicitly for any other comparison.
    parser.set_defaults(
        wandb_hostname_tag="host_192-9-145-26",
        wandb_control_tag=["benchmark_e4d1ed6", "benchmark_5598ebc"],
        wandb_experiment_tag=["benchmark_680043d", "benchmark_5540e1f"],
        exclude_envs=["CarRacing-v0"],
    )
    args = parser.parse_args()
    print(args)

    api = wandb.Api()
    all_runs = api.runs(
        path=f"{args.wandb_entity or api.default_entity}/{args.wandb_project_name}",
        order="+created_at",
    )

    # One RunGroupRuns bucket per RunGroup(algo, env) seen on the chosen host.
    runs_by_run_group: Dict[RunGroup, RunGroupRuns] = {}
    for r in all_runs:
        wandb_tags = r.config.get("wandb_tags", [])
        # Only runs from the requested host are comparable.
        if not wandb_tags or args.wandb_hostname_tag not in wandb_tags:
            continue
        rg = RunGroup(r.config["algo"], r.config["env"])
        if args.exclude_envs and rg.env_id in args.exclude_envs:
            continue
        if rg not in runs_by_run_group:
            runs_by_run_group[rg] = RunGroupRuns(
                rg, args.wandb_control_tag, args.wandb_experiment_tag
            )
        runs_by_run_group[rg].add_run(r)

    # With no matching runs the DataFrame would be empty and lack the "score"
    # column, so fail with a clear message instead of an AttributeError.
    if not runs_by_run_group:
        raise SystemExit(
            f"No runs tagged {args.wandb_hostname_tag!r} found in project "
            f"{args.wandb_project_name!r} after applying exclusions"
        )

    df = RunGroupRuns.data_frame(runs_by_run_group.values()).round(decimals=2)
    print(f"**Total Score: {sum(df.score)}**")
    # Append a summary row holding the mean of every numeric column.
    df.loc["mean"] = df.mean(numeric_only=True)
    print(df.to_markdown())
|
dqn/policy.py
CHANGED
@@ -15,7 +15,7 @@ class DQNPolicy(Policy):
|
|
15 |
def __init__(
|
16 |
self,
|
17 |
env: VecEnv,
|
18 |
-
hidden_sizes: Sequence[int],
|
19 |
**kwargs,
|
20 |
) -> None:
|
21 |
super().__init__(env, **kwargs)
|
|
|
15 |
def __init__(
|
16 |
self,
|
17 |
env: VecEnv,
|
18 |
+
hidden_sizes: Sequence[int] = [],
|
19 |
**kwargs,
|
20 |
) -> None:
|
21 |
super().__init__(env, **kwargs)
|
dqn/q_net.py
CHANGED
@@ -13,7 +13,7 @@ class QNetwork(nn.Module):
|
|
13 |
self,
|
14 |
observation_space: gym.Space,
|
15 |
action_space: gym.Space,
|
16 |
-
hidden_sizes: Sequence[int],
|
17 |
activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
|
18 |
) -> None:
|
19 |
super().__init__()
|
|
|
13 |
self,
|
14 |
observation_space: gym.Space,
|
15 |
action_space: gym.Space,
|
16 |
+
hidden_sizes: Sequence[int] = [],
|
17 |
activation: Type[nn.Module] = nn.ReLU, # Used by stable-baselines3
|
18 |
) -> None:
|
19 |
super().__init__()
|
huggingface_publish.py
CHANGED
@@ -14,6 +14,8 @@ from typing import List, Optional
|
|
14 |
|
15 |
from huggingface_hub.hf_api import HfApi, upload_folder
|
16 |
from huggingface_hub.repocard import metadata_save
|
|
|
|
|
17 |
from publish.markdown_format import EvalTableData, model_card_text
|
18 |
from runner.evaluate import EvalArgs, evaluate_model
|
19 |
from runner.env import make_eval_env
|
@@ -27,6 +29,9 @@ def publish(
|
|
27 |
huggingface_user: Optional[str] = None,
|
28 |
huggingface_token: Optional[str] = None,
|
29 |
) -> None:
|
|
|
|
|
|
|
30 |
api = wandb.Api()
|
31 |
runs = [api.run(rp) for rp in wandb_run_paths]
|
32 |
algo = runs[0].config["algo"]
|
|
|
14 |
|
15 |
from huggingface_hub.hf_api import HfApi, upload_folder
|
16 |
from huggingface_hub.repocard import metadata_save
|
17 |
+
from pyvirtualdisplay.display import Display
|
18 |
+
|
19 |
from publish.markdown_format import EvalTableData, model_card_text
|
20 |
from runner.evaluate import EvalArgs, evaluate_model
|
21 |
from runner.env import make_eval_env
|
|
|
29 |
huggingface_user: Optional[str] = None,
|
30 |
huggingface_token: Optional[str] = None,
|
31 |
) -> None:
|
32 |
+
virtual_display = Display(visible=False, size=(1400, 900))
|
33 |
+
virtual_display.start()
|
34 |
+
|
35 |
api = wandb.Api()
|
36 |
runs = [api.run(rp) for rp in wandb_run_paths]
|
37 |
algo = runs[0].config["algo"]
|
hyperparams/dqn.yml
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
CartPole-v1: &cartpole-defaults
|
2 |
n_timesteps: !!float 5e4
|
3 |
env_hyperparams:
|
4 |
-
n_envs: 1
|
5 |
rolling_length: 50
|
6 |
policy_hyperparams:
|
7 |
hidden_sizes: [256, 256]
|
@@ -18,8 +17,6 @@ CartPole-v1: &cartpole-defaults
|
|
18 |
exploration_final_eps: 0.04
|
19 |
eval_params:
|
20 |
step_freq: !!float 1e4
|
21 |
-
n_episodes: 10
|
22 |
-
save_best: true
|
23 |
|
24 |
CartPole-v0:
|
25 |
<<: *cartpole-defaults
|
@@ -46,7 +43,7 @@ MountainCar-v0:
|
|
46 |
Acrobot-v1:
|
47 |
n_timesteps: !!float 1e5
|
48 |
env_hyperparams:
|
49 |
-
rolling_length:
|
50 |
policy_hyperparams:
|
51 |
hidden_sizes: [256, 256]
|
52 |
algo_hyperparams:
|
@@ -64,7 +61,7 @@ Acrobot-v1:
|
|
64 |
LunarLander-v2:
|
65 |
n_timesteps: !!float 5e5
|
66 |
env_hyperparams:
|
67 |
-
rolling_length:
|
68 |
policy_hyperparams:
|
69 |
hidden_sizes: [256, 256]
|
70 |
algo_hyperparams:
|
@@ -81,19 +78,15 @@ LunarLander-v2:
|
|
81 |
max_grad_norm: 0.5
|
82 |
eval_params:
|
83 |
step_freq: 25_000
|
84 |
-
n_episodes: 10
|
85 |
-
save_best: true
|
86 |
|
87 |
-
|
88 |
n_timesteps: !!float 1e7
|
89 |
env_hyperparams:
|
90 |
frame_stack: 4
|
91 |
no_reward_timeout_steps: 1_000
|
|
|
92 |
n_envs: 8
|
93 |
vec_env_class: "subproc"
|
94 |
-
rolling_length: 20
|
95 |
-
policy_hyperparams:
|
96 |
-
hidden_sizes: [512]
|
97 |
algo_hyperparams:
|
98 |
buffer_size: 100000
|
99 |
learning_rate: !!float 1e-4
|
@@ -105,12 +98,7 @@ SpaceInvadersNoFrameskip-v4: &atari-defaults
|
|
105 |
exploration_fraction: 0.1
|
106 |
exploration_final_eps: 0.01
|
107 |
eval_params:
|
108 |
-
|
109 |
-
n_episodes: 10
|
110 |
-
save_best: true
|
111 |
-
|
112 |
-
BreakoutNoFrameskip-v4:
|
113 |
-
<<: *atari-defaults
|
114 |
|
115 |
PongNoFrameskip-v4:
|
116 |
<<: *atari-defaults
|
|
|
1 |
CartPole-v1: &cartpole-defaults
|
2 |
n_timesteps: !!float 5e4
|
3 |
env_hyperparams:
|
|
|
4 |
rolling_length: 50
|
5 |
policy_hyperparams:
|
6 |
hidden_sizes: [256, 256]
|
|
|
17 |
exploration_final_eps: 0.04
|
18 |
eval_params:
|
19 |
step_freq: !!float 1e4
|
|
|
|
|
20 |
|
21 |
CartPole-v0:
|
22 |
<<: *cartpole-defaults
|
|
|
43 |
Acrobot-v1:
|
44 |
n_timesteps: !!float 1e5
|
45 |
env_hyperparams:
|
46 |
+
rolling_length: 50
|
47 |
policy_hyperparams:
|
48 |
hidden_sizes: [256, 256]
|
49 |
algo_hyperparams:
|
|
|
61 |
LunarLander-v2:
|
62 |
n_timesteps: !!float 5e5
|
63 |
env_hyperparams:
|
64 |
+
rolling_length: 50
|
65 |
policy_hyperparams:
|
66 |
hidden_sizes: [256, 256]
|
67 |
algo_hyperparams:
|
|
|
78 |
max_grad_norm: 0.5
|
79 |
eval_params:
|
80 |
step_freq: 25_000
|
|
|
|
|
81 |
|
82 |
+
atari: &atari-defaults
|
83 |
n_timesteps: !!float 1e7
|
84 |
env_hyperparams:
|
85 |
frame_stack: 4
|
86 |
no_reward_timeout_steps: 1_000
|
87 |
+
no_reward_fire_steps: 500
|
88 |
n_envs: 8
|
89 |
vec_env_class: "subproc"
|
|
|
|
|
|
|
90 |
algo_hyperparams:
|
91 |
buffer_size: 100000
|
92 |
learning_rate: !!float 1e-4
|
|
|
98 |
exploration_fraction: 0.1
|
99 |
exploration_final_eps: 0.01
|
100 |
eval_params:
|
101 |
+
deterministic: false
|
|
|
|
|
|
|
|
|
|
|
102 |
|
103 |
PongNoFrameskip-v4:
|
104 |
<<: *atari-defaults
|
hyperparams/ppo.yml
CHANGED
@@ -15,8 +15,6 @@ CartPole-v1: &cartpole-defaults
|
|
15 |
clip_range_decay: linear
|
16 |
eval_params:
|
17 |
step_freq: !!float 2.5e4
|
18 |
-
n_episodes: 10
|
19 |
-
save_best: true
|
20 |
|
21 |
CartPole-v0:
|
22 |
<<: *cartpole-defaults
|
@@ -39,9 +37,10 @@ MountainCarContinuous-v0:
|
|
39 |
env_hyperparams:
|
40 |
normalize: true
|
41 |
n_envs: 4
|
42 |
-
policy_hyperparams:
|
43 |
-
|
44 |
-
|
|
|
45 |
algo_hyperparams:
|
46 |
n_steps: 512
|
47 |
batch_size: 256
|
@@ -53,11 +52,8 @@ MountainCarContinuous-v0:
|
|
53 |
gae_lambda: 0.9
|
54 |
max_grad_norm: 5
|
55 |
vf_coef: 0.19
|
56 |
-
# use_sde: true
|
57 |
eval_params:
|
58 |
step_freq: 5000
|
59 |
-
n_episodes: 10
|
60 |
-
save_best: true
|
61 |
|
62 |
Acrobot-v1:
|
63 |
n_timesteps: !!float 1e6
|
@@ -84,10 +80,6 @@ LunarLander-v2:
|
|
84 |
ent_coef: 0.01
|
85 |
ent_coef_decay: linear
|
86 |
normalize_advantage: false
|
87 |
-
eval_params:
|
88 |
-
step_freq: !!float 5e4
|
89 |
-
n_episodes: 10
|
90 |
-
save_best: true
|
91 |
|
92 |
CarRacing-v0:
|
93 |
n_timesteps: !!float 4e6
|
@@ -101,6 +93,7 @@ CarRacing-v0:
|
|
101 |
activation_fn: relu
|
102 |
share_features_extractor: false
|
103 |
cnn_feature_dim: 256
|
|
|
104 |
algo_hyperparams:
|
105 |
n_steps: 512
|
106 |
batch_size: 128
|
|
|
15 |
clip_range_decay: linear
|
16 |
eval_params:
|
17 |
step_freq: !!float 2.5e4
|
|
|
|
|
18 |
|
19 |
CartPole-v0:
|
20 |
<<: *cartpole-defaults
|
|
|
37 |
env_hyperparams:
|
38 |
normalize: true
|
39 |
n_envs: 4
|
40 |
+
# policy_hyperparams:
|
41 |
+
# init_layers_orthogonal: false
|
42 |
+
# log_std_init: -3.29
|
43 |
+
# use_sde: true
|
44 |
algo_hyperparams:
|
45 |
n_steps: 512
|
46 |
batch_size: 256
|
|
|
52 |
gae_lambda: 0.9
|
53 |
max_grad_norm: 5
|
54 |
vf_coef: 0.19
|
|
|
55 |
eval_params:
|
56 |
step_freq: 5000
|
|
|
|
|
57 |
|
58 |
Acrobot-v1:
|
59 |
n_timesteps: !!float 1e6
|
|
|
80 |
ent_coef: 0.01
|
81 |
ent_coef_decay: linear
|
82 |
normalize_advantage: false
|
|
|
|
|
|
|
|
|
83 |
|
84 |
CarRacing-v0:
|
85 |
n_timesteps: !!float 4e6
|
|
|
93 |
activation_fn: relu
|
94 |
share_features_extractor: false
|
95 |
cnn_feature_dim: 256
|
96 |
+
hidden_sizes: [256]
|
97 |
algo_hyperparams:
|
98 |
n_steps: 512
|
99 |
batch_size: 128
|
lambda_labs/benchmark.sh
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
source benchmarks/train_loop.sh
|
2 |
|
3 |
# export WANDB_PROJECT_NAME="rl-algo-impls"
|
4 |
-
export VIRTUAL_DISPLAY=1
|
5 |
|
6 |
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
|
7 |
|
@@ -20,8 +19,8 @@ ENVS=(
|
|
20 |
# PyBullet
|
21 |
"HalfCheetahBulletEnv-v0"
|
22 |
"AntBulletEnv-v0"
|
23 |
-
"Walker2DBulletEnv-v0"
|
24 |
"HopperBulletEnv-v0"
|
|
|
25 |
# CarRacing
|
26 |
"CarRacing-v0"
|
27 |
# Atari
|
|
|
1 |
source benchmarks/train_loop.sh
|
2 |
|
3 |
# export WANDB_PROJECT_NAME="rl-algo-impls"
|
|
|
4 |
|
5 |
BENCHMARK_MAX_PROCS="${BENCHMARK_MAX_PROCS:-6}"
|
6 |
|
|
|
19 |
# PyBullet
|
20 |
"HalfCheetahBulletEnv-v0"
|
21 |
"AntBulletEnv-v0"
|
|
|
22 |
"HopperBulletEnv-v0"
|
23 |
+
"Walker2DBulletEnv-v0"
|
24 |
# CarRacing
|
25 |
"CarRacing-v0"
|
26 |
# Atari
|
lambda_labs/lambda_requirements.txt
CHANGED
@@ -8,4 +8,5 @@ wandb >= 0.13.9, < 0.14
|
|
8 |
pyvirtualdisplay == 3.0
|
9 |
pybullet >= 3.2.5, < 3.3
|
10 |
tabulate >= 0.9.0, < 0.10
|
11 |
-
huggingface-hub >= 0.12.0, < 0.13
|
|
|
|
8 |
pyvirtualdisplay == 3.0
|
9 |
pybullet >= 3.2.5, < 3.3
|
10 |
tabulate >= 0.9.0, < 0.10
|
11 |
+
huggingface-hub >= 0.12.0, < 0.13
|
12 |
+
numexpr >= 2.8.4, < 2.9
|
poetry.lock
CHANGED
@@ -787,47 +787,47 @@ files = [
|
|
787 |
|
788 |
[[package]]
|
789 |
name = "cryptography"
|
790 |
-
version = "39.0.
|
791 |
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
792 |
category = "main"
|
793 |
optional = false
|
794 |
python-versions = ">=3.6"
|
795 |
files = [
|
796 |
-
{file = "cryptography-39.0.
|
797 |
-
{file = "cryptography-39.0.
|
798 |
-
{file = "cryptography-39.0.
|
799 |
-
{file = "cryptography-39.0.
|
800 |
-
{file = "cryptography-39.0.
|
801 |
-
{file = "cryptography-39.0.
|
802 |
-
{file = "cryptography-39.0.
|
803 |
-
{file = "cryptography-39.0.
|
804 |
-
{file = "cryptography-39.0.
|
805 |
-
{file = "cryptography-39.0.
|
806 |
-
{file = "cryptography-39.0.
|
807 |
-
{file = "cryptography-39.0.
|
808 |
-
{file = "cryptography-39.0.
|
809 |
-
{file = "cryptography-39.0.
|
810 |
-
{file = "cryptography-39.0.
|
811 |
-
{file = "cryptography-39.0.
|
812 |
-
{file = "cryptography-39.0.
|
813 |
-
{file = "cryptography-39.0.
|
814 |
-
{file = "cryptography-39.0.
|
815 |
-
{file = "cryptography-39.0.
|
816 |
-
{file = "cryptography-39.0.
|
817 |
-
{file = "cryptography-39.0.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:e0a05aee6a82d944f9b4edd6a001178787d1546ec7c6223ee9a848a7ade92e39"},
|
818 |
-
{file = "cryptography-39.0.0.tar.gz", hash = "sha256:f964c7dcf7802d133e8dbd1565914fa0194f9d683d82411989889ecd701e8adf"},
|
819 |
]
|
820 |
|
821 |
[package.dependencies]
|
822 |
cffi = ">=1.12"
|
823 |
|
824 |
[package.extras]
|
825 |
-
docs = ["sphinx (>=
|
826 |
docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
|
827 |
-
pep8test = ["black", "ruff"]
|
828 |
sdist = ["setuptools-rust (>=0.11.4)"]
|
829 |
ssh = ["bcrypt (>=3.1.5)"]
|
830 |
-
test = ["hypothesis (>=1.11.4,!=3.79.2)", "iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-subtests", "pytest-xdist", "pytz"]
|
|
|
|
|
831 |
|
832 |
[[package]]
|
833 |
name = "cycler"
|
@@ -2250,6 +2250,49 @@ jupyter-server = ">=1.8,<3"
|
|
2250 |
[package.extras]
|
2251 |
test = ["pytest", "pytest-console-scripts", "pytest-tornasync"]
|
2252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2253 |
[[package]]
|
2254 |
name = "numpy"
|
2255 |
version = "1.24.1"
|
@@ -3002,6 +3045,18 @@ files = [
|
|
3002 |
{file = "pytz-2022.7.tar.gz", hash = "sha256:7ccfae7b4b2c067464a6733c6261673fdb8fd1be905460396b97a073e9fa683a"},
|
3003 |
]
|
3004 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3005 |
[[package]]
|
3006 |
name = "pywin32"
|
3007 |
version = "305"
|
@@ -4198,4 +4253,4 @@ testing = ["flake8 (<5)", "func-timeout", "jaraco.functools", "jaraco.itertools"
|
|
4198 |
[metadata]
|
4199 |
lock-version = "2.0"
|
4200 |
python-versions = "~3.10"
|
4201 |
-
content-hash = "
|
|
|
787 |
|
788 |
[[package]]
|
789 |
name = "cryptography"
|
790 |
+
version = "39.0.1"
|
791 |
description = "cryptography is a package which provides cryptographic recipes and primitives to Python developers."
|
792 |
category = "main"
|
793 |
optional = false
|
794 |
python-versions = ">=3.6"
|
795 |
files = [
|
796 |
+
{file = "cryptography-39.0.1-cp36-abi3-macosx_10_12_universal2.whl", hash = "sha256:6687ef6d0a6497e2b58e7c5b852b53f62142cfa7cd1555795758934da363a965"},
|
797 |
+
{file = "cryptography-39.0.1-cp36-abi3-macosx_10_12_x86_64.whl", hash = "sha256:706843b48f9a3f9b9911979761c91541e3d90db1ca905fd63fee540a217698bc"},
|
798 |
+
{file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:5d2d8b87a490bfcd407ed9d49093793d0f75198a35e6eb1a923ce1ee86c62b41"},
|
799 |
+
{file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:83e17b26de248c33f3acffb922748151d71827d6021d98c70e6c1a25ddd78505"},
|
800 |
+
{file = "cryptography-39.0.1-cp36-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e124352fd3db36a9d4a21c1aa27fd5d051e621845cb87fb851c08f4f75ce8be6"},
|
801 |
+
{file = "cryptography-39.0.1-cp36-abi3-manylinux_2_24_x86_64.whl", hash = "sha256:5aa67414fcdfa22cf052e640cb5ddc461924a045cacf325cd164e65312d99502"},
|
802 |
+
{file = "cryptography-39.0.1-cp36-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:35f7c7d015d474f4011e859e93e789c87d21f6f4880ebdc29896a60403328f1f"},
|
803 |
+
{file = "cryptography-39.0.1-cp36-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f24077a3b5298a5a06a8e0536e3ea9ec60e4c7ac486755e5fb6e6ea9b3500106"},
|
804 |
+
{file = "cryptography-39.0.1-cp36-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:f0c64d1bd842ca2633e74a1a28033d139368ad959872533b1bab8c80e8240a0c"},
|
805 |
+
{file = "cryptography-39.0.1-cp36-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:0f8da300b5c8af9f98111ffd512910bc792b4c77392a9523624680f7956a99d4"},
|
806 |
+
{file = "cryptography-39.0.1-cp36-abi3-win32.whl", hash = "sha256:fe913f20024eb2cb2f323e42a64bdf2911bb9738a15dba7d3cce48151034e3a8"},
|
807 |
+
{file = "cryptography-39.0.1-cp36-abi3-win_amd64.whl", hash = "sha256:ced4e447ae29ca194449a3f1ce132ded8fcab06971ef5f618605aacaa612beac"},
|
808 |
+
{file = "cryptography-39.0.1-pp38-pypy38_pp73-macosx_10_12_x86_64.whl", hash = "sha256:807ce09d4434881ca3a7594733669bd834f5b2c6d5c7e36f8c00f691887042ad"},
|
809 |
+
{file = "cryptography-39.0.1-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:96f1157a7c08b5b189b16b47bc9db2332269d6680a196341bf30046330d15388"},
|
810 |
+
{file = "cryptography-39.0.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:e422abdec8b5fa8462aa016786680720d78bdce7a30c652b7fadf83a4ba35336"},
|
811 |
+
{file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.manylinux_2_24_aarch64.whl", hash = "sha256:b0afd054cd42f3d213bf82c629efb1ee5f22eba35bf0eec88ea9ea7304f511a2"},
|
812 |
+
{file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_24_x86_64.whl", hash = "sha256:6f8ba7f0328b79f08bdacc3e4e66fb4d7aab0c3584e0bd41328dce5262e26b2e"},
|
813 |
+
{file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:ef8b72fa70b348724ff1218267e7f7375b8de4e8194d1636ee60510aae104cd0"},
|
814 |
+
{file = "cryptography-39.0.1-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:aec5a6c9864be7df2240c382740fcf3b96928c46604eaa7f3091f58b878c0bb6"},
|
815 |
+
{file = "cryptography-39.0.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:fdd188c8a6ef8769f148f88f859884507b954cc64db6b52f66ef199bb9ad660a"},
|
816 |
+
{file = "cryptography-39.0.1.tar.gz", hash = "sha256:d1f6198ee6d9148405e49887803907fe8962a23e6c6f83ea7d98f1c0de375695"},
|
|
|
|
|
817 |
]
|
818 |
|
819 |
[package.dependencies]
|
820 |
cffi = ">=1.12"
|
821 |
|
822 |
[package.extras]
|
823 |
+
docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=1.1.1)"]
|
824 |
docstest = ["pyenchant (>=1.6.11)", "sphinxcontrib-spelling (>=4.0.1)", "twine (>=1.12.0)"]
|
825 |
+
pep8test = ["black", "check-manifest", "mypy", "ruff", "types-pytz", "types-requests"]
|
826 |
sdist = ["setuptools-rust (>=0.11.4)"]
|
827 |
ssh = ["bcrypt (>=3.1.5)"]
|
828 |
+
test = ["hypothesis (>=1.11.4,!=3.79.2)", "iso8601", "pretend", "pytest (>=6.2.0)", "pytest-benchmark", "pytest-cov", "pytest-shard (>=0.1.2)", "pytest-subtests", "pytest-xdist", "pytz"]
|
829 |
+
test-randomorder = ["pytest-randomly"]
|
830 |
+
tox = ["tox"]
|
831 |
|
832 |
[[package]]
|
833 |
name = "cycler"
|
|
|
2250 |
[package.extras]
|
2251 |
test = ["pytest", "pytest-console-scripts", "pytest-tornasync"]
|
2252 |
|
2253 |
+
[[package]]
|
2254 |
+
name = "numexpr"
|
2255 |
+
version = "2.8.4"
|
2256 |
+
description = "Fast numerical expression evaluator for NumPy"
|
2257 |
+
category = "main"
|
2258 |
+
optional = false
|
2259 |
+
python-versions = ">=3.7"
|
2260 |
+
files = [
|
2261 |
+
{file = "numexpr-2.8.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a75967d46b6bd56455dd32da6285e5ffabe155d0ee61eef685bbfb8dafb2e484"},
|
2262 |
+
{file = "numexpr-2.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:db93cf1842f068247de631bfc8af20118bf1f9447cd929b531595a5e0efc9346"},
|
2263 |
+
{file = "numexpr-2.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7bca95f4473b444428061d4cda8e59ac564dc7dc6a1dea3015af9805c6bc2946"},
|
2264 |
+
{file = "numexpr-2.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9e34931089a6bafc77aaae21f37ad6594b98aa1085bb8b45d5b3cd038c3c17d9"},
|
2265 |
+
{file = "numexpr-2.8.4-cp310-cp310-win32.whl", hash = "sha256:f3a920bfac2645017110b87ddbe364c9c7a742870a4d2f6120b8786c25dc6db3"},
|
2266 |
+
{file = "numexpr-2.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:6931b1e9d4f629f43c14b21d44f3f77997298bea43790cfcdb4dd98804f90783"},
|
2267 |
+
{file = "numexpr-2.8.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:9400781553541f414f82eac056f2b4c965373650df9694286b9bd7e8d413f8d8"},
|
2268 |
+
{file = "numexpr-2.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6ee9db7598dd4001138b482342b96d78110dd77cefc051ec75af3295604dde6a"},
|
2269 |
+
{file = "numexpr-2.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ff5835e8af9a212e8480003d731aad1727aaea909926fd009e8ae6a1cba7f141"},
|
2270 |
+
{file = "numexpr-2.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:655d84eb09adfee3c09ecf4a89a512225da153fdb7de13c447404b7d0523a9a7"},
|
2271 |
+
{file = "numexpr-2.8.4-cp311-cp311-win32.whl", hash = "sha256:5538b30199bfc68886d2be18fcef3abd11d9271767a7a69ff3688defe782800a"},
|
2272 |
+
{file = "numexpr-2.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:3f039321d1c17962c33079987b675fb251b273dbec0f51aac0934e932446ccc3"},
|
2273 |
+
{file = "numexpr-2.8.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c867cc36cf815a3ec9122029874e00d8fbcef65035c4a5901e9b120dd5d626a2"},
|
2274 |
+
{file = "numexpr-2.8.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:059546e8f6283ccdb47c683101a890844f667fa6d56258d48ae2ecf1b3875957"},
|
2275 |
+
{file = "numexpr-2.8.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:845a6aa0ed3e2a53239b89c1ebfa8cf052d3cc6e053c72805e8153300078c0b1"},
|
2276 |
+
{file = "numexpr-2.8.4-cp37-cp37m-win32.whl", hash = "sha256:a38664e699526cb1687aefd9069e2b5b9387da7feac4545de446141f1ef86f46"},
|
2277 |
+
{file = "numexpr-2.8.4-cp37-cp37m-win_amd64.whl", hash = "sha256:eaec59e9bf70ff05615c34a8b8d6c7bd042bd9f55465d7b495ea5436f45319d0"},
|
2278 |
+
{file = "numexpr-2.8.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b318541bf3d8326682ebada087ba0050549a16d8b3fa260dd2585d73a83d20a7"},
|
2279 |
+
{file = "numexpr-2.8.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b076db98ca65eeaf9bd224576e3ac84c05e451c0bd85b13664b7e5f7b62e2c70"},
|
2280 |
+
{file = "numexpr-2.8.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90f12cc851240f7911a47c91aaf223dba753e98e46dff3017282e633602e76a7"},
|
2281 |
+
{file = "numexpr-2.8.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c368aa35ae9b18840e78b05f929d3a7b3abccdba9630a878c7db74ca2368339"},
|
2282 |
+
{file = "numexpr-2.8.4-cp38-cp38-win32.whl", hash = "sha256:b96334fc1748e9ec4f93d5fadb1044089d73fb08208fdb8382ed77c893f0be01"},
|
2283 |
+
{file = "numexpr-2.8.4-cp38-cp38-win_amd64.whl", hash = "sha256:a6d2d7740ae83ba5f3531e83afc4b626daa71df1ef903970947903345c37bd03"},
|
2284 |
+
{file = "numexpr-2.8.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:77898fdf3da6bb96aa8a4759a8231d763a75d848b2f2e5c5279dad0b243c8dfe"},
|
2285 |
+
{file = "numexpr-2.8.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:df35324666b693f13a016bc7957de7cc4d8801b746b81060b671bf78a52b9037"},
|
2286 |
+
{file = "numexpr-2.8.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:17ac9cfe6d0078c5fc06ba1c1bbd20b8783f28c6f475bbabd3cad53683075cab"},
|
2287 |
+
{file = "numexpr-2.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:df3a1f6b24214a1ab826e9c1c99edf1686c8e307547a9aef33910d586f626d01"},
|
2288 |
+
{file = "numexpr-2.8.4-cp39-cp39-win32.whl", hash = "sha256:7d71add384adc9119568d7e9ffa8a35b195decae81e0abf54a2b7779852f0637"},
|
2289 |
+
{file = "numexpr-2.8.4-cp39-cp39-win_amd64.whl", hash = "sha256:9f096d707290a6a00b6ffdaf581ee37331109fb7b6c8744e9ded7c779a48e517"},
|
2290 |
+
{file = "numexpr-2.8.4.tar.gz", hash = "sha256:d5432537418d18691b9115d615d6daa17ee8275baef3edf1afbbf8bc69806147"},
|
2291 |
+
]
|
2292 |
+
|
2293 |
+
[package.dependencies]
|
2294 |
+
numpy = ">=1.13.3"
|
2295 |
+
|
2296 |
[[package]]
|
2297 |
name = "numpy"
|
2298 |
version = "1.24.1"
|
|
|
3045 |
{file = "pytz-2022.7.tar.gz", hash = "sha256:7ccfae7b4b2c067464a6733c6261673fdb8fd1be905460396b97a073e9fa683a"},
|
3046 |
]
|
3047 |
|
3048 |
+
[[package]]
|
3049 |
+
name = "pyvirtualdisplay"
|
3050 |
+
version = "3.0"
|
3051 |
+
description = "python wrapper for Xvfb, Xephyr and Xvnc"
|
3052 |
+
category = "main"
|
3053 |
+
optional = false
|
3054 |
+
python-versions = "*"
|
3055 |
+
files = [
|
3056 |
+
{file = "PyVirtualDisplay-3.0-py3-none-any.whl", hash = "sha256:40d4b8dfe4b8de8552e28eb367647f311f88a130bf837fe910e7f180d5477f0e"},
|
3057 |
+
{file = "PyVirtualDisplay-3.0.tar.gz", hash = "sha256:09755bc3ceb6eb725fb07eca5425f43f2358d3bf08e00d2a9b792a1aedd16159"},
|
3058 |
+
]
|
3059 |
+
|
3060 |
[[package]]
|
3061 |
name = "pywin32"
|
3062 |
version = "305"
|
|
|
4253 |
[metadata]
|
4254 |
lock-version = "2.0"
|
4255 |
python-versions = "~3.10"
|
4256 |
+
content-hash = "8301ee1f2321a6c23370a61466fd3b45096291c2cc63326bbe4701774edf1d94"
|
ppo/policy.py
CHANGED
@@ -2,7 +2,7 @@ from stable_baselines3.common.vec_env.base_vec_env import VecEnv
|
|
2 |
from typing import Optional, Sequence
|
3 |
|
4 |
from gym.spaces import Box, Discrete
|
5 |
-
from shared.policy.on_policy import ActorCritic
|
6 |
|
7 |
|
8 |
class PPOActorCritic(ActorCritic):
|
@@ -13,21 +13,16 @@ class PPOActorCritic(ActorCritic):
|
|
13 |
v_hidden_sizes: Optional[Sequence[int]] = None,
|
14 |
**kwargs,
|
15 |
) -> None:
|
16 |
-
|
17 |
-
|
18 |
-
if
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
else
|
25 |
-
|
26 |
-
elif isinstance(obs_space, Discrete):
|
27 |
-
pi_hidden_sizes = pi_hidden_sizes or [64]
|
28 |
-
v_hidden_sizes = v_hidden_sizes or [64]
|
29 |
-
else:
|
30 |
-
raise ValueError(f"Unsupported observation space: {obs_space}")
|
31 |
super().__init__(
|
32 |
env,
|
33 |
pi_hidden_sizes,
|
|
|
2 |
from typing import Optional, Sequence
|
3 |
|
4 |
from gym.spaces import Box, Discrete
|
5 |
+
from shared.policy.on_policy import ActorCritic, default_hidden_sizes
|
6 |
|
7 |
|
8 |
class PPOActorCritic(ActorCritic):
|
|
|
13 |
v_hidden_sizes: Optional[Sequence[int]] = None,
|
14 |
**kwargs,
|
15 |
) -> None:
|
16 |
+
pi_hidden_sizes = (
|
17 |
+
pi_hidden_sizes
|
18 |
+
if pi_hidden_sizes is not None
|
19 |
+
else default_hidden_sizes(env.observation_space)
|
20 |
+
)
|
21 |
+
v_hidden_sizes = (
|
22 |
+
v_hidden_sizes
|
23 |
+
if v_hidden_sizes is not None
|
24 |
+
else default_hidden_sizes(env.observation_space)
|
25 |
+
)
|
|
|
|
|
|
|
|
|
|
|
26 |
super().__init__(
|
27 |
env,
|
28 |
pi_hidden_sizes,
|
pyproject.toml
CHANGED
@@ -23,6 +23,9 @@ torch-tb-profiler = "^0.4.1"
|
|
23 |
jupyter = "^1.0.0"
|
24 |
tabulate = "^0.9.0"
|
25 |
huggingface-hub = "^0.12.0"
|
|
|
|
|
|
|
26 |
|
27 |
[build-system]
|
28 |
requires = ["poetry-core"]
|
|
|
23 |
jupyter = "^1.0.0"
|
24 |
tabulate = "^0.9.0"
|
25 |
huggingface-hub = "^0.12.0"
|
26 |
+
cryptography = "39.0.1"
|
27 |
+
pyvirtualdisplay = "^3.0"
|
28 |
+
numexpr = "^2.8.4"
|
29 |
|
30 |
[build-system]
|
31 |
requires = ["poetry-core"]
|
replay.meta.json
CHANGED
@@ -1 +1 @@
|
|
1 |
-
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/
|
|
|
1 |
+
{"content_type": "video/mp4", "encoder_version": {"backend": "ffmpeg", "version": "b'ffmpeg version 5.1.2 Copyright (c) 2000-2022 the FFmpeg developers\\nbuilt with clang version 14.0.6\\nconfiguration: --prefix=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_h_env_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_placehold_pl --cc=arm64-apple-darwin20.0.0-clang --cxx=arm64-apple-darwin20.0.0-clang++ --nm=arm64-apple-darwin20.0.0-nm --ar=arm64-apple-darwin20.0.0-ar --disable-doc --disable-openssl --enable-demuxer=dash --enable-hardcoded-tables --enable-libfreetype --enable-libfontconfig --enable-libopenh264 --enable-cross-compile --arch=arm64 --target-os=darwin --cross-prefix=arm64-apple-darwin20.0.0- --host-cc=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/x86_64-apple-darwin13.4.0-clang --enable-neon --enable-gnutls --enable-libmp3lame --enable-libvpx --enable-pthreads --enable-gpl --enable-libx264 --enable-libx265 --enable-libaom --enable-libsvtav1 --enable-libxml2 --enable-pic --enable-shared --disable-static --enable-version3 --enable-zlib --pkg-config=/Users/runner/miniforge3/conda-bld/ffmpeg_1671040513231/_build_env/bin/pkg-config\\nlibavutil 57. 28.100 / 57. 28.100\\nlibavcodec 59. 37.100 / 59. 37.100\\nlibavformat 59. 27.100 / 59. 27.100\\nlibavdevice 59. 7.100 / 59. 7.100\\nlibavfilter 8. 44.100 / 8. 44.100\\nlibswscale 6. 7.100 / 6. 7.100\\nlibswresample 4. 7.100 / 4. 7.100\\nlibpostproc 56. 6.100 / 56. 
6.100\\n'", "cmdline": ["ffmpeg", "-nostats", "-loglevel", "error", "-y", "-f", "rawvideo", "-s:v", "600x400", "-pix_fmt", "rgb24", "-framerate", "50", "-i", "-", "-vf", "scale=trunc(iw/2)*2:trunc(ih/2)*2", "-vcodec", "libx264", "-pix_fmt", "yuv420p", "-r", "50", "/var/folders/9g/my5557_91xddp6lx00nkzly80000gn/T/tmpcfp_i4ww/ppo-CarRacing-v0/replay.mp4"]}, "episode": {"r": 910.102783203125, "l": 899, "t": 11.729042}}
|
replay.mp4
CHANGED
Binary files a/replay.mp4 and b/replay.mp4 differ
|
|
runner/running_utils.py
CHANGED
@@ -119,6 +119,8 @@ def set_seeds(seed: Optional[int], use_deterministic_algorithms: bool) -> None:
|
|
119 |
torch.backends.cudnn.benchmark = False
|
120 |
torch.use_deterministic_algorithms(use_deterministic_algorithms)
|
121 |
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
|
|
|
|
122 |
|
123 |
|
124 |
def make_policy(
|
|
|
119 |
torch.backends.cudnn.benchmark = False
|
120 |
torch.use_deterministic_algorithms(use_deterministic_algorithms)
|
121 |
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
122 |
+
# Stop warning and it would introduce stochasticity if I was using TF
|
123 |
+
os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
|
124 |
|
125 |
|
126 |
def make_policy(
|
saved_models/ppo-CarRacing-v0-S1-best/model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e252b9b261d9afa502ed7c5079655ec0e34990e64aaa8ec133e91ce72aa5fb34
|
3 |
+
size 2737400
|
shared/policy/on_policy.py
CHANGED
@@ -2,7 +2,7 @@ import gym
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
|
5 |
-
from gym.spaces import Box
|
6 |
from pathlib import Path
|
7 |
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
|
8 |
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
|
@@ -47,6 +47,21 @@ def clamp_actions(
|
|
47 |
return actions
|
48 |
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
class ActorCritic(Policy):
|
51 |
def __init__(
|
52 |
self,
|
|
|
2 |
import numpy as np
|
3 |
import torch
|
4 |
|
5 |
+
from gym.spaces import Box, Discrete, Space
|
6 |
from pathlib import Path
|
7 |
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
|
8 |
from typing import NamedTuple, Optional, Sequence, Tuple, TypeVar
|
|
|
47 |
return actions
|
48 |
|
49 |
|
50 |
+
def default_hidden_sizes(obs_space: Space) -> Sequence[int]:
|
51 |
+
if isinstance(obs_space, Box):
|
52 |
+
if len(obs_space.shape) == 3:
|
53 |
+
# By default feature extractor to output has no hidden layers
|
54 |
+
return []
|
55 |
+
elif len(obs_space.shape) == 1:
|
56 |
+
return [64, 64]
|
57 |
+
else:
|
58 |
+
raise ValueError(f"Unsupported observation space: {obs_space}")
|
59 |
+
elif isinstance(obs_space, Discrete):
|
60 |
+
return [64]
|
61 |
+
else:
|
62 |
+
raise ValueError(f"Unsupported observation space: {obs_space}")
|
63 |
+
|
64 |
+
|
65 |
class ActorCritic(Policy):
|
66 |
def __init__(
|
67 |
self,
|
shared/policy/policy.py
CHANGED
@@ -51,7 +51,6 @@ class Policy(nn.Module, ABC):
|
|
51 |
os.path.join(path, MODEL_FILENAME),
|
52 |
)
|
53 |
|
54 |
-
@abstractmethod
|
55 |
def load(self, path: str) -> None:
|
56 |
# VecNormalize load occurs in env.py
|
57 |
self.load_state_dict(
|
|
|
51 |
os.path.join(path, MODEL_FILENAME),
|
52 |
)
|
53 |
|
|
|
54 |
def load(self, path: str) -> None:
|
55 |
# VecNormalize load occurs in env.py
|
56 |
self.load_state_dict(
|
train.py
CHANGED
@@ -45,21 +45,20 @@ if __name__ == "__main__":
|
|
45 |
parser.add_argument(
|
46 |
"--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
|
47 |
)
|
48 |
-
parser.
|
49 |
-
"
|
50 |
-
|
51 |
-
|
|
|
52 |
)
|
53 |
-
parser.set_defaults(algo="ppo", env="CartPole-v1", seed=1)
|
54 |
args = parser.parse_args()
|
55 |
print(args)
|
56 |
|
57 |
-
if args.
|
58 |
-
from pyvirtualdisplay import Display
|
59 |
|
60 |
-
virtual_display = Display(visible=
|
61 |
virtual_display.start()
|
62 |
-
delattr(args, "virtual_display")
|
63 |
|
64 |
# pool_size isn't a TrainArg so must be removed from args
|
65 |
pool_size = args.pool_size
|
|
|
45 |
parser.add_argument(
|
46 |
"--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
|
47 |
)
|
48 |
+
parser.set_defaults(
|
49 |
+
algo="ppo",
|
50 |
+
env="MountainCarContinuous-v0",
|
51 |
+
seed=[1, 2, 3],
|
52 |
+
pool_size=3,
|
53 |
)
|
|
|
54 |
args = parser.parse_args()
|
55 |
print(args)
|
56 |
|
57 |
+
if args.pool_size == 1:
|
58 |
+
from pyvirtualdisplay.display import Display
|
59 |
|
60 |
+
virtual_display = Display(visible=False, size=(1400, 900))
|
61 |
virtual_display.start()
|
|
|
62 |
|
63 |
# pool_size isn't a TrainArg so must be removed from args
|
64 |
pool_size = args.pool_size
|
vpg/policy.py
CHANGED
@@ -15,7 +15,7 @@ from shared.policy.actor import (
|
|
15 |
actor_head,
|
16 |
)
|
17 |
from shared.policy.critic import CriticHead
|
18 |
-
from shared.policy.on_policy import Step, clamp_actions
|
19 |
from shared.policy.policy import ACTIVATION, Policy
|
20 |
|
21 |
PI_FILE_NAME = "pi.pt"
|
@@ -37,7 +37,7 @@ class VPGActorCritic(Policy):
|
|
37 |
def __init__(
|
38 |
self,
|
39 |
env: VecEnv,
|
40 |
-
hidden_sizes: Sequence[int],
|
41 |
init_layers_orthogonal: bool = True,
|
42 |
activation_fn: str = "tanh",
|
43 |
log_std_init: float = -0.5,
|
@@ -53,6 +53,12 @@ class VPGActorCritic(Policy):
|
|
53 |
self.use_sde = use_sde
|
54 |
self.squash_output = squash_output
|
55 |
|
|
|
|
|
|
|
|
|
|
|
|
|
56 |
pi_feature_extractor = FeatureExtractor(
|
57 |
obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
|
58 |
)
|
|
|
15 |
actor_head,
|
16 |
)
|
17 |
from shared.policy.critic import CriticHead
|
18 |
+
from shared.policy.on_policy import Step, clamp_actions, default_hidden_sizes
|
19 |
from shared.policy.policy import ACTIVATION, Policy
|
20 |
|
21 |
PI_FILE_NAME = "pi.pt"
|
|
|
37 |
def __init__(
|
38 |
self,
|
39 |
env: VecEnv,
|
40 |
+
hidden_sizes: Optional[Sequence[int]] = None,
|
41 |
init_layers_orthogonal: bool = True,
|
42 |
activation_fn: str = "tanh",
|
43 |
log_std_init: float = -0.5,
|
|
|
53 |
self.use_sde = use_sde
|
54 |
self.squash_output = squash_output
|
55 |
|
56 |
+
hidden_sizes = (
|
57 |
+
hidden_sizes
|
58 |
+
if hidden_sizes is not None
|
59 |
+
else default_hidden_sizes(obs_space)
|
60 |
+
)
|
61 |
+
|
62 |
pi_feature_extractor = FeatureExtractor(
|
63 |
obs_space, activation, init_layers_orthogonal=init_layers_orthogonal
|
64 |
)
|