David Quarel committed
Commit c5fd14a · 1 Parent(s): a4216e4

Add README.md and train.yaml

Files changed (2):
  1. README.md (+35 −0)
  2. train.yaml (+49 −0)
README.md ADDED
# jaxgmg2_shared_init

A collection of RL agent checkpoints studying the effect of shared initialization. Two base models (run IDs 19 and 27 from `jaxgmg2_3phase_optim_state`) are each used as a shared starting point, then independently continued from checkpoint 0 (fresh optimizer state) with α=1.0 across 10 different random seeds each.

## Training Configuration

- **Environment**: JaxGMG open maze, cheese at any location, 9600 levels
- **Algorithm**: REINFORCE with value function baseline
- **Alpha (α)**: 1.0
- **Discount rate (γ)**: 0.98
- **Learning rate**: 5e-5
- **Total env steps**: 1,351,680,000 (~1.35B, 21k gradient steps)
- **Rollout steps**: 64
- **Base models**: `jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_19_seed_981019` and `...id_27_seed_981027`
- **Resume optimizer**: No (fresh optimizer at checkpoint 0)
- **Seeds per base model**: 10 (seeds 30–39)
- **Optimizer state saved**: Yes

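The training code itself lives in the monorepo, not in this repository. As a rough illustration of the algorithm named above, here is a minimal Python sketch of REINFORCE with a value-function baseline using the γ=0.98 discount from the config; the function names and shapes are illustrative, not the jaxgmg2 API:

```python
# Minimal sketch of REINFORCE with a value-function baseline
# (illustrative only; not the actual jaxgmg2 training code).

def discounted_returns(rewards, gamma=0.98):
    """Compute discounted returns G_t = r_t + gamma * G_{t+1} over one rollout."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def reinforce_loss(log_probs, values, rewards, gamma=0.98):
    """Policy-gradient loss: -sum_t log pi(a_t|s_t) * (G_t - V(s_t)).

    Subtracting the value baseline V(s_t) reduces gradient variance
    without introducing bias, since the advantage (G_t - V) is treated
    as a constant with respect to the policy parameters.
    """
    returns = discounted_returns(rewards, gamma)
    advantages = [g - v for g, v in zip(returns, values)]
    return -sum(lp * adv for lp, adv in zip(log_probs, advantages))
```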
## Naming Schema

Checkpoints are named `al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}`.

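Combined with the sweep in `train.yaml` (run IDs 19 and 27, each over seeds 30–39), this schema expands to 20 checkpoint names; a small sketch:

```python
# Enumerate the 20 checkpoint names produced by the sweep:
# 2 base models (run IDs 19 and 27) x 10 seeds (30-39).
schema = "al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}"

names = [
    schema.format(run_id=run_id, seed=seed)
    for run_id in (19, 27)
    for seed in range(30, 40)
]

print(len(names))  # 20 checkpoints in total
print(names[0])    # al_1.0_g_0.98_id_19_shared_init_seed_30
```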
## Reproduced with

See `train.yaml` in this repository. Run with:

```bash
make run projects/rl/experiments/shared_init/jobs/train.yaml
```

from the [timaeus monorepo](https://github.com/timaeus-research/timaeus).

## WandB

Project: `jaxgmg2_shared_init`
train.yaml ADDED
```yaml
parameters:
  project_name: jaxgmg2_shared_init
  action: rl
  rl_action: train

  # Learning
  lr: 5e-5
  alpha: 1.0
  discount_rate: 0.98
  cheese_loc: any
  env_layout: open

  # Training scale
  num_total_env_steps: 1_351_680_000
  num_levels: 9600
  grad_acc_per_chunk: 4
  num_rollout_steps: 64

  # Resume from checkpoint 0 (shared initialisation, fresh optimizer)
  resume_id: 0
  resume_optim: false

  # Checkpointing
  ckpt_dir: jaxgmg2_shared_init
  f_str_ckpt: "al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}"
  eval_schedule: "0:1,250:2,500:5,2000:10"
  log_optimizer_state: true

  # Logging
  use_wandb: true
  use_hf: true
  wandb_project: jaxgmg2_shared_init

sweep:
  - - resume: jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_19_seed_981019
      run_id: 19
    - resume: jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_27_seed_981027
      run_id: 27

  - - seed: 30
    - seed: 31
    - seed: 32
    - seed: 33
    - seed: 34
    - seed: 35
    - seed: 36
    - seed: 37
    - seed: 38
    - seed: 39
```
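The exact semantics of `eval_schedule` are not documented here; one plausible reading, assumed purely for illustration, is that each `start:interval` pair means "from step `start` onward, evaluate every `interval` steps". Under that assumption, a hypothetical parser could look like:

```python
# Hypothetical reading of the eval_schedule string. Each "start:interval"
# pair is ASSUMED to mean "from `start` onward, evaluate every `interval`
# steps"; the actual jaxgmg semantics may differ.

def parse_eval_schedule(spec):
    """Parse "0:1,250:2,500:5,2000:10" into [(start, interval), ...]."""
    pairs = []
    for chunk in spec.split(","):
        start, interval = chunk.split(":")
        pairs.append((int(start), int(interval)))
    return pairs

def interval_at(schedule, step):
    """Return the interval in effect at `step` (last pair with start <= step)."""
    current = schedule[0][1]
    for start, interval in schedule:
        if step >= start:
            current = interval
    return current
```

For example, `interval_at(parse_eval_schedule("0:1,250:2,500:5,2000:10"), 300)` would give 2 under this reading: every step up to 250, every 2nd step from there, every 5th from 500, and every 10th from 2000 onward.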