dquarel commited on
Commit
77cc104
·
1 Parent(s): ed6e3a3

Update train.yaml with full alpha/discount_rate grid, add train_missing.yaml

Browse files
Files changed (3) hide show
  1. README.md +22 -8
  2. train.yaml +11 -2
  3. train_missing.yaml +32 -0
README.md CHANGED
@@ -1,22 +1,30 @@
1
  # jaxgmg2_3phase_unique
2
 
3
- 15 RL agent checkpoints trained on the JaxGMG maze environment with alpha=0.6 and discount_rate=0.98.
4
- First batch of the "3-phase" training runs, without optimizer state saved (see jaxgmg2_3phase_optim_state for the version with optimizer state).
 
5
 
6
  **WandB:** https://wandb.ai/devinterp/jaxgmg2_3phase_unique
7
 
8
  ## Sweep
9
 
10
- run_id sweep: 0-14. Seed is derived from run_id via:
 
 
 
 
 
 
 
 
 
 
11
  `seed = int(discount_rate*100)*10000 + int(alpha*10)*100 + run_id`
12
- e.g. run_id=0 -> seed=980600, run_id=14 -> seed=980614.
13
 
14
  ## Shared Hyperparams
15
 
16
  ```
17
  rl_action=train
18
- alpha=0.6
19
- discount_rate=0.98
20
  lr=5e-05
21
  num_total_env_steps=10000000000
22
  num_rollout_steps=64
@@ -39,14 +47,20 @@ use_hf=True
39
 
40
  ## Naming Schema
41
 
42
- Checkpoints are named `al_0.6_g_0.98_id_{run_id}_seed_{seed}`.
43
 
44
  ## Reproduced with
45
 
46
  See `train.yaml` in this repository. Run with:
47
 
48
  ```bash
49
- make run projects/rl/experiments/al_0.6_g_0.98/jobs/train_unique.yaml
 
 
 
 
 
 
50
  ```
51
 
52
  from the [timaeus monorepo](https://github.com/timaeus-research/timaeus).
 
1
  # jaxgmg2_3phase_unique
2
 
3
+ 224 RL agent checkpoints trained on the JaxGMG maze environment across a grid of
4
+ alpha and discount_rate values. Without optimizer state saved (see
5
+ jaxgmg2_3phase_optim_state for the version with optimizer state).
6
 
7
  **WandB:** https://wandb.ai/devinterp/jaxgmg2_3phase_unique
8
 
9
  ## Sweep
10
 
11
+ Grid over alpha x discount_rate x run_id (0-14):
12
+
13
+ ```
14
+ alpha: {0.4, 0.5, 0.6, 0.7, 1.0}
15
+ discount_rate: {0.97, 0.98, 0.99}
16
+ run_id: 0-14
17
+ ```
18
+
19
+ 5 x 3 x 15 = 225 combinations. 1 run missing: `al_0.7_g_0.98_id_14_seed_980714`.
20
+
21
+ Seed is derived from run_id via:
22
  `seed = int(discount_rate*100)*10000 + int(alpha*10)*100 + run_id`
 
23
 
24
  ## Shared Hyperparams
25
 
26
  ```
27
  rl_action=train
 
 
28
  lr=5e-05
29
  num_total_env_steps=10000000000
30
  num_rollout_steps=64
 
47
 
48
  ## Naming Schema
49
 
50
+ Checkpoints are named `al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}`.
51
 
52
  ## Reproduced with
53
 
54
  See `train.yaml` in this repository. Run with:
55
 
56
  ```bash
57
+ timaeus run train.yaml
58
+ ```
59
+
60
+ To fill the 1 missing run:
61
+
62
+ ```bash
63
+ timaeus run train_missing.yaml
64
  ```
65
 
66
  from the [timaeus monorepo](https://github.com/timaeus-research/timaeus).
train.yaml CHANGED
@@ -4,12 +4,11 @@ parameters:
4
  rl_action: train
5
 
6
  lr: 5e-5
7
- alpha: 0.6
8
- discount_rate: 0.98
9
  cheese_loc: any
10
  env_layout: open
11
  mask_type: first_episode
12
  use_prev_action: false
 
13
 
14
  num_total_env_steps: 10_000_000_000
15
  num_levels: 9600
@@ -28,6 +27,16 @@ parameters:
28
  ntfy: david_jaxgmg
29
 
30
  sweep:
 
 
 
 
 
 
 
 
 
 
31
  - - run_id: 0
32
  - run_id: 1
33
  - run_id: 2
 
4
  rl_action: train
5
 
6
  lr: 5e-5
 
 
7
  cheese_loc: any
8
  env_layout: open
9
  mask_type: first_episode
10
  use_prev_action: false
11
+ log_optimizer_state: false
12
 
13
  num_total_env_steps: 10_000_000_000
14
  num_levels: 9600
 
27
  ntfy: david_jaxgmg
28
 
29
  sweep:
30
+ - - alpha: 0.4
31
+ - alpha: 0.5
32
+ - alpha: 0.6
33
+ - alpha: 0.7
34
+ - alpha: 1.0
35
+
36
+ - - discount_rate: 0.97
37
+ - discount_rate: 0.98
38
+ - discount_rate: 0.99
39
+
40
  - - run_id: 0
41
  - run_id: 1
42
  - run_id: 2
train_missing.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ parameters:
2
+ project_name: jaxgmg2_3phase_unique
3
+ action: rl
4
+ rl_action: train
5
+
6
+ lr: 5e-5
7
+ alpha: 0.7
8
+ discount_rate: 0.98
9
+ cheese_loc: any
10
+ env_layout: open
11
+ mask_type: first_episode
12
+ use_prev_action: false
13
+ log_optimizer_state: false
14
+
15
+ num_total_env_steps: 10_000_000_000
16
+ num_levels: 9600
17
+ grad_acc_per_chunk: 5
18
+ num_rollout_steps: 64
19
+
20
+ seed_formula: "{int(discount_rate*100):02d}{int(alpha*10):02d}{run_id:02d}"
21
+ ckpt_dir: jaxgmg2_3phase_unique
22
+ f_str_ckpt: "al_{alpha}_g_{discount_rate}_id_{run_id}_seed_{seed}"
23
+ eval_schedule: "0:1,250:2,500:5,2000:10"
24
+
25
+ wandb_project: jaxgmg2_3phase_unique
26
+ use_wandb: true
27
+ use_hf: true
28
+ no_tqdm: true
29
+ ntfy: david_jaxgmg
30
+
31
+ sweep:
32
+ - - run_id: 14