David Quarel committed on
Commit c5fd14a · 1 Parent(s): a4216e4
Add README.md and train.yaml
Browse files:
- README.md +35 -0
- train.yaml +49 -0
README.md
ADDED
@@ -0,0 +1,35 @@
# jaxgmg2_shared_init

A collection of RL agent checkpoints studying the effect of shared initialization. Two base models (run IDs 19 and 27 from `jaxgmg2_3phase_optim_state`) are each used as a shared starting point, then training is independently continued from checkpoint 0 (with a fresh optimizer state) at α=1.0 for 10 different random seeds per base model.

## Training Configuration

- **Environment**: JaxGMG open maze, cheese at any location, 9600 levels
- **Algorithm**: REINFORCE with value function baseline
- **Alpha (α)**: 1.0
- **Discount rate (γ)**: 0.98
- **Learning rate**: 5e-5
- **Total env steps**: 1,351,680,000 (~1.35B; roughly 21k gradient steps, see the sketch below)
- **Rollout steps**: 64
- **Base models**: `jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_19_seed_981019` and `...id_27_seed_981027`
- **Resume optimizer**: No (fresh optimizer at checkpoint 0)
- **Seeds per base model**: 30–39
- **Optimizer state saved**: Yes
## Naming Schema

Checkpoints are named `al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}`.
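
Combining the schema with the run IDs and seeds listed above gives the full set of 20 expected run names:

```python
# Enumerate the expected run names: 2 base models (ids 19, 27) x 10 seeds (30-39).
template = "al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}"
names = [
    template.format(run_id=run_id, seed=seed)
    for run_id in (19, 27)
    for seed in range(30, 40)
]
print(len(names))  # 20
print(names[0])    # al_1.0_g_0.98_id_19_shared_init_seed_30
```
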
## Reproduced with

See `train.yaml` in this repository. Run with:

```bash
make run projects/rl/experiments/shared_init/jobs/train.yaml
```

from the [timaeus monorepo](https://github.com/timaeus-research/timaeus).
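
Since the runs also push to Hugging Face (`use_hf: true` in `train.yaml`), a minimal sketch for fetching one run's checkpoints locally with `huggingface_hub` is shown below. The `repo_id` and the per-run directory pattern are assumptions, not taken from this README; substitute the actual repository and layout.

```python
# Fetch one run's checkpoint files from the Hugging Face Hub.
# ASSUMPTIONS: repo_id and the per-run directory layout are hypothetical.
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    repo_id="timaeus/jaxgmg2_shared_init",  # hypothetical repo id
    allow_patterns=["al_1.0_g_0.98_id_19_shared_init_seed_30/*"],
)
print(local_dir)
```
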
## WandB

Project: `jaxgmg2_shared_init`
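
For pulling run metadata or metrics programmatically, a minimal sketch using the public WandB API is given below; the entity (`timaeus`) is an assumption and should be replaced with the entity that actually owns the project.

```python
# List the runs logged to the WandB project named above.
# ASSUMPTION: the "timaeus" entity is hypothetical.
import wandb

api = wandb.Api()
for run in api.runs("timaeus/jaxgmg2_shared_init"):
    print(run.name, run.state)
```
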
train.yaml
ADDED
@@ -0,0 +1,49 @@
parameters:
  project_name: jaxgmg2_shared_init
  action: rl
  rl_action: train

  # Learning
  lr: 5e-5
  alpha: 1.0
  discount_rate: 0.98
  cheese_loc: any
  env_layout: open

  # Training scale
  num_total_env_steps: 1_351_680_000
  num_levels: 9600
  grad_acc_per_chunk: 4
  num_rollout_steps: 64

  # Resume from checkpoint 0 (shared initialisation, fresh optimizer)
  resume_id: 0
  resume_optim: false

  # Checkpointing
  ckpt_dir: jaxgmg2_shared_init
  f_str_ckpt: "al_1.0_g_0.98_id_{run_id}_shared_init_seed_{seed}"
  eval_schedule: "0:1,250:2,500:5,2000:10"
  log_optimizer_state: true

  # Logging
  use_wandb: true
  use_hf: true
  wandb_project: jaxgmg2_shared_init

sweep:
  - - resume: jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_19_seed_981019
      run_id: 19
    - resume: jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_27_seed_981027
      run_id: 27

  - - seed: 30
    - seed: 31
    - seed: 32
    - seed: 33
    - seed: 34
    - seed: 35
    - seed: 36
    - seed: 37
    - seed: 38
    - seed: 39
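
Assuming the nested `sweep` lists denote a Cartesian product of option groups (each inner list being one axis of choices), the config expands to 2 base models × 10 seeds = 20 runs; the actual expansion semantics are defined by the timaeus monorepo runner. A minimal sketch of that reading:

```python
# Expand the sweep as a Cartesian product of the two option groups above:
# 2 base models x 10 seeds = 20 runs. This mirrors a common sweep convention;
# the actual semantics are defined by the timaeus monorepo runner.
from itertools import product

base_models = [
    {"resume": "jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_19_seed_981019", "run_id": 19},
    {"resume": "jaxgmg2_3phase_optim_state/al_1.0_g_0.98_id_27_seed_981027", "run_id": 27},
]
seeds = [{"seed": s} for s in range(30, 40)]

runs = [{**base, **seed} for base, seed in product(base_models, seeds)]
print(len(runs))  # 20
print(runs[0])    # {'resume': '...id_19_seed_981019', 'run_id': 19, 'seed': 30}
```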