Commit a986605 by agaralon (unverified) · 1 Parent(s): 9118e6e

Parameter counts and some explanation

Files changed (2)
  1. README.md +44 -3
  2. count_params.py +28 -0
README.md CHANGED
@@ -1,3 +1,44 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language: en
+ tags:
+ - machine-learning
+ - reinforcement-learning
+ - sokoban
+ - planning
+ license: apache-2.0
+ ---
+
+ # Trained learned planners
+
+ This repository contains the trained networks from the paper ["Planning behavior in a recurrent neural network that
+ plays Sokoban"](https://openreview.net/forum?id=T9sB3S2hok), presented at the ICML 2024 Mechanistic Interpretability
+ Workshop.
+
+ To load and use the NNs, please refer to the [learned-planner
+ repository](http://github.com/alignmentresearch/learned-planner), and possibly to the [training code
+ ](https://github.com/AlignmentResearch/train-learned-planner).
+
+ # Model details
+
+ **Hyperparameters:** see `model/*/cp_*/cfg.json` for the hyperparameters that were used to train a particular run.
+
+
+ ## Parameter counts:
+
+ - DRC(3, 3): 1,285,125 (1.29M)
+ - DRC(1, 1): 987,525 (0.99M)
+ - ResNet: 3,068,421 (3.07M)
+
+ # Citation
+
+ If you use these neural networks, please cite our work:
+
+ ```bibtex
+ @inproceedings{TODO: add your citation here,
+ title={Planning behavior in a recurrent neural network that plays Sokoban},
+ author={Your Authors},
+ booktitle={ICML 2024 Mechanistic Interpretability Workshop},
+ year={2024},
+ url={https://openreview.net/forum?id=T9sB3S2hok}
+ }
+ ```
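Each entry in the README's parameter-count list pairs an exact count with a value rounded to millions. As a quick sanity check that the two figures agree, here is a minimal sketch that reuses the same format string as `count_params.py`. The helper name `format_count` is hypothetical and not part of the repository; the counts are copied from the README, not recomputed from the networks.

```python
def format_count(num: int) -> str:
    """Format a parameter count as 'exact (rounded-in-millions M)'.

    Hypothetical helper: it only mirrors the f-string used in count_params.py.
    """
    return f"{num:,} ({num / 1_000_000:.2f}M)"


# Counts as listed in the README (copied by hand, not computed here).
README_COUNTS = {
    "DRC(3, 3)": 1_285_125,
    "DRC(1, 1)": 987_525,
    "ResNet": 3_068_421,
}

for name, num in README_COUNTS.items():
    print(f"- {name}: {format_count(num)}")
```

Running it should print the three list lines exactly as they appear in the README, which confirms that the rounded values are consistent with the exact counts.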
count_params.py ADDED
@@ -0,0 +1,28 @@
+ import json
+ import os
+ from pathlib import Path
+
+ import farconf
+ from cleanba.config import Args
+ from cleanba.environments import SokobanConfig
+
+ soko_env = SokobanConfig(
+     max_episode_steps=100, num_envs=1, dim_room=(10, 10), num_boxes=1, asynchronous=False, tinyworld_obs=True
+ ).make()
+
+
+ def parameter_count(root: Path) -> str:
+     model_dir = os.listdir(root)[0]
+     cp_dir = os.listdir(root / model_dir)[0]
+
+     with open(root / model_dir / cp_dir / "cfg.json", "r") as f:
+         cfg = json.load(f)
+
+     args = farconf.from_dict(cfg["cfg"], Args)
+     num = args.net.count_params(soko_env)
+     return f"{num:,} ({num/1_000_000:.2f}M)"
+
+
+ print("- DRC(3, 3): ", parameter_count(Path("drc33")))
+ print("- DRC(1, 1): ", parameter_count(Path("drc11")))
+ print("- ResNet: ", parameter_count(Path("resnet")))
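`parameter_count` takes the first entry that `os.listdir` returns at each level, and `os.listdir` returns entries in arbitrary order. If a model directory ever holds more than one run or checkpoint, sorting the entries makes the choice reproducible. The following is a minimal sketch, assuming the `<root>/<run>/<checkpoint>/cfg.json` layout implied by the script. The helper `first_cfg_path`, the `demo` function, and the directory names `run_a`, `cp_100`, and `cp_200` are hypothetical and exist only for illustration.

```python
import json
import tempfile
from pathlib import Path


def first_cfg_path(root: Path) -> Path:
    """Return <root>/<run>/<checkpoint>/cfg.json for the first run and checkpoint.

    Hypothetical helper: "first" means lexicographically smallest directory name,
    so the result does not depend on filesystem listing order.
    """
    run_dir = sorted(p for p in root.iterdir() if p.is_dir())[0]
    cp_dir = sorted(p for p in run_dir.iterdir() if p.is_dir())[0]
    return cp_dir / "cfg.json"


def demo() -> Path:
    """Build a throwaway directory tree and show which cfg.json gets selected."""
    with tempfile.TemporaryDirectory() as tmp:
        root = Path(tmp) / "drc33"
        # Create the checkpoints out of order on purpose.
        for cp_name in ("cp_200", "cp_100"):
            cp_dir = root / "run_a" / cp_name
            cp_dir.mkdir(parents=True)
            (cp_dir / "cfg.json").write_text(json.dumps({"cfg": {}}))
        return first_cfg_path(root).relative_to(root)


print(demo().as_posix())
```

Note that the sort is lexicographic, so a name like `cp_1000` would sort before `cp_200`. Whether that matters depends on how the checkpoints are actually named, which this sketch does not assume.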