ksridhar commited on
Commit
5d5751b
1 Parent(s): 504aa44

Upload folder using huggingface_hub

Browse files
Files changed (44) hide show
  1. .gitattributes +1 -0
  2. .summary/0/events.out.tfevents.1718465455.koa03 +3 -0
  3. README.md +56 -0
  4. checkpoint_p0/best_000820672_1680736256_reward_1255.320.pth +3 -0
  5. checkpoint_p0/checkpoint_000976608_2000093184.pth +3 -0
  6. checkpoint_p0/checkpoint_000976624_2000125952.pth +3 -0
  7. checkpoint_p0/milestones/checkpoint_000025328_51871744.pth +3 -0
  8. checkpoint_p0/milestones/checkpoint_000052832_108199936.pth +3 -0
  9. checkpoint_p0/milestones/checkpoint_000080576_165019648.pth +3 -0
  10. checkpoint_p0/milestones/checkpoint_000108128_221446144.pth +3 -0
  11. checkpoint_p0/milestones/checkpoint_000136192_278921216.pth +3 -0
  12. checkpoint_p0/milestones/checkpoint_000164160_336199680.pth +3 -0
  13. checkpoint_p0/milestones/checkpoint_000191744_392691712.pth +3 -0
  14. checkpoint_p0/milestones/checkpoint_000220096_450756608.pth +3 -0
  15. checkpoint_p0/milestones/checkpoint_000248000_507904000.pth +3 -0
  16. checkpoint_p0/milestones/checkpoint_000276096_565444608.pth +3 -0
  17. checkpoint_p0/milestones/checkpoint_000303856_622297088.pth +3 -0
  18. checkpoint_p0/milestones/checkpoint_000331888_679706624.pth +3 -0
  19. checkpoint_p0/milestones/checkpoint_000359664_736591872.pth +3 -0
  20. checkpoint_p0/milestones/checkpoint_000387616_793837568.pth +3 -0
  21. checkpoint_p0/milestones/checkpoint_000415744_851443712.pth +3 -0
  22. checkpoint_p0/milestones/checkpoint_000444096_909508608.pth +3 -0
  23. checkpoint_p0/milestones/checkpoint_000472512_967704576.pth +3 -0
  24. checkpoint_p0/milestones/checkpoint_000500608_1025245184.pth +3 -0
  25. checkpoint_p0/milestones/checkpoint_000528672_1082720256.pth +3 -0
  26. checkpoint_p0/milestones/checkpoint_000556928_1140588544.pth +3 -0
  27. checkpoint_p0/milestones/checkpoint_000585408_1198915584.pth +3 -0
  28. checkpoint_p0/milestones/checkpoint_000614144_1257766912.pth +3 -0
  29. checkpoint_p0/milestones/checkpoint_000642240_1315307520.pth +3 -0
  30. checkpoint_p0/milestones/checkpoint_000670720_1373634560.pth +3 -0
  31. checkpoint_p0/milestones/checkpoint_000698816_1431175168.pth +3 -0
  32. checkpoint_p0/milestones/checkpoint_000726976_1488846848.pth +3 -0
  33. checkpoint_p0/milestones/checkpoint_000755328_1546911744.pth +3 -0
  34. checkpoint_p0/milestones/checkpoint_000783136_1603862528.pth +3 -0
  35. checkpoint_p0/milestones/checkpoint_000811584_1662124032.pth +3 -0
  36. checkpoint_p0/milestones/checkpoint_000839936_1720188928.pth +3 -0
  37. checkpoint_p0/milestones/checkpoint_000868144_1777958912.pth +3 -0
  38. checkpoint_p0/milestones/checkpoint_000896240_1835499520.pth +3 -0
  39. checkpoint_p0/milestones/checkpoint_000924672_1893728256.pth +3 -0
  40. checkpoint_p0/milestones/checkpoint_000953008_1951760384.pth +3 -0
  41. config.json +167 -0
  42. git.diff +712 -0
  43. replay.mp4 +3 -0
  44. sf_log.txt +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
.summary/0/events.out.tfevents.1718465455.koa03 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e94cf53b8f6360348849b631b4c38a3b6f7f36c0e9a04db91a35f5b7b9ccd542
3
+ size 18285797
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: sample-factory
3
+ tags:
4
+ - deep-reinforcement-learning
5
+ - reinforcement-learning
6
+ - sample-factory
7
+ model-index:
8
+ - name: APPO
9
+ results:
10
+ - task:
11
+ type: reinforcement-learning
12
+ name: reinforcement-learning
13
+ dataset:
14
+ name: atari_airraid
15
+ type: atari_airraid
16
+ metrics:
17
+ - type: mean_reward
18
+ value: 465.00 +/- 182.76
19
+ name: mean_reward
20
+ verified: false
21
+ ---
22
+
23
+ A(n) **APPO** model trained on the **atari_airraid** environment.
24
+
25
+ This model was trained using Sample-Factory 2.0: https://github.com/alex-petrenko/sample-factory.
26
+ Documentation for how to use Sample-Factory can be found at https://www.samplefactory.dev/
27
+
28
+
29
+ ## Downloading the model
30
+
31
+ After installing Sample-Factory, download the model with:
32
+ ```
33
+ python -m sample_factory.huggingface.load_from_hub -r ksridhar/atari_2B_atari_airraid_1111
34
+ ```
35
+
36
+
37
+ ## Using the model
38
+
39
+ To run the model after download, use the `enjoy` script corresponding to this environment:
40
+ ```
41
+ python -m <path.to.enjoy.module> --algo=APPO --env=atari_airraid --train_dir=./train_dir --experiment=atari_2B_atari_airraid_1111
42
+ ```
43
+
44
+
45
+ You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
46
+ See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details
47
+
48
+ ## Training with this model
49
+
50
+ To continue training with this model, use the `train` script corresponding to this environment:
51
+ ```
52
+ python -m <path.to.train.module> --algo=APPO --env=atari_airraid --train_dir=./train_dir --experiment=atari_2B_atari_airraid_1111 --restart_behavior=resume --train_for_env_steps=10000000000
53
+ ```
54
+
55
+ Note, you may have to adjust `--train_for_env_steps` to a suitably high number as the experiment will resume at the number of steps it concluded at.
56
+
checkpoint_p0/best_000820672_1680736256_reward_1255.320.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04ea5891f52bccc613ae420f2383c694c8a4583d64cce037fe061a08240982b3
3
+ size 20722280
checkpoint_p0/checkpoint_000976608_2000093184.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f5b7f2c104719944467b95f32e263dff2734fcfce0a7940f053de310ca23fb9
3
+ size 20722628
checkpoint_p0/checkpoint_000976624_2000125952.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e83c5aba7700b73f5f86d2763f21eadcaee8e1daabfd945d07def7a62c2fdfab
3
+ size 20722628
checkpoint_p0/milestones/checkpoint_000025328_51871744.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c70ea6c746b1099395d74f65b51a553989d2373949cda55bdb936f6d339f163a
3
+ size 20723568
checkpoint_p0/milestones/checkpoint_000052832_108199936.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:90153c065eaf8a16774217a7ce5a67685dd5fc0af1d6f8ce55b475a50deea4fe
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000080576_165019648.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48af7f9d07dc3e13cd2d01cff53a5d4c142afb6e5d5cb46c3f8fcbf78c529f52
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000108128_221446144.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed2dbe5193dcea57fbc0d9dce3e6d3bd959bb251e232176daa539bfc93a24530
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000136192_278921216.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ba07de6c1b32632354f3d77b63a47c28e412bfed43ca5ef82c3e33fa2616785e
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000164160_336199680.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf9e64c7241e64da17332f5779f522e2efd38d87cf5847eeb914f1b86bb90eeb
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000191744_392691712.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fcce6b301593984a3c055b3fbe8f21f3823b1b3233c9a405a5c2c62168f8d45f
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000220096_450756608.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:60e5d3544717f7983b4abe5a622035a4949d20a3a07b0a6d2732d54e15004567
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000248000_507904000.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04bef3f4845d5afb30cb49f9bc48dd421b1ae398e2e249b56dfeda5b1f089906
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000276096_565444608.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4eb1a15333413a30b47229e6e8e1236096c8fb60dc20560f0762f5d5a7ee491
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000303856_622297088.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8ed4b042c87e40ee5edcc31499c946d83b3206d68b572009c7f0b9b1bb9e51e8
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000331888_679706624.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:506f2656ecd9f241a5175470d7502dea2df70e4b5f9c0fb911219fe041c906e8
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000359664_736591872.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e56a0ed00e07cb870607d3ac6727ae2d1c952cfff4156fbabf41954a816271fe
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000387616_793837568.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c4206c3d4735f6aada46266d53b372afef0ecf8aee1516538fdde4128e73ca35
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000415744_851443712.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdbeb1c1d4facca0691907133f028f7a395d7a4631e024bb962fd5eae9e0473e
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000444096_909508608.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:908e6f59f7f70604da8a49c0effa07c915a4c8181b2229bbf5060448eab34237
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000472512_967704576.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3acfe6effa0eb87d9be6c64ded5e46f73548ab35f2b9324282a9fab72952cdd9
3
+ size 20723626
checkpoint_p0/milestones/checkpoint_000500608_1025245184.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c0562aa2dd28df15165bbc10ac172196e667e69694225bb419790a0204872b89
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000528672_1082720256.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a359cb85026316b46cec52c2b3072bb3f4c6d0c8f8c04b81652d53c320962ab
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000556928_1140588544.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67e509d0cba38cab4cb931ce1b06b3c0dcd997802b1cd0e2cf62bf42bba8c964
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000585408_1198915584.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:123f1825e68d6007a227a2809108a82997c6140e78cd4b441124d5de77909c72
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000614144_1257766912.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58fed7f6964ad4469525aa4840aab5182a39bd3fa981c5d0c85131980dff948b
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000642240_1315307520.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2c2b24a591716ea07f740d151b7b912eb8da99d438b7421c1c90e458950ff079
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000670720_1373634560.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06d76f4a0427517572114f56ad297b0820be092421150f56f0b045289b88fba0
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000698816_1431175168.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:67eedd32ef45079fb6063fa17b3573d4cfa3756ff2b850d2ed9a7cfa3ff2e9f8
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000726976_1488846848.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f799ebc4f68d484712629e50e5e50b61a5cf1c228487f35565bef8fad96877ef
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000755328_1546911744.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:acf482fd0d35a9e9466747e0902a0113179aa3153d6457d9069b62d5006a9af1
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000783136_1603862528.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a72d7d19a447c00346e260ca9efe1737505c1bdd124bc713d2fb5f600d758db2
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000811584_1662124032.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d129fb12044a642bcf9cce81c2a72773e839a07e7d83b400c28ffa6f6e77dda
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000839936_1720188928.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:04aa797855e355aa77fc037b122b93c36ef03add6458a5964950c787ec16b547
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000868144_1777958912.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6f2e77ad1f2d084ca4b8a0560788be4a2c00b2215895a53f0b1ce94ee419dc2f
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000896240_1835499520.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d45adf6d55ca09b5f89e4178e6eae4c7844bf2e42f0c19985040c7e348a4f6a
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000924672_1893728256.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8057f3255e306fda08291c8e367d5adc5bf1baea5f0da09fb14f76e3d035666
3
+ size 20723684
checkpoint_p0/milestones/checkpoint_000953008_1951760384.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a07fc48118e729c1125f70c78b65c102fff6c9714e8cc37b8455fca3da50fa3
3
+ size 20723684
config.json ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "help": false,
3
+ "algo": "APPO",
4
+ "env": "atari_airraid",
5
+ "experiment": "atari_2B_atari_airraid_1111",
6
+ "train_dir": "train_dir",
7
+ "restart_behavior": "resume",
8
+ "device": "gpu",
9
+ "seed": 1111,
10
+ "num_policies": 1,
11
+ "async_rl": true,
12
+ "serial_mode": false,
13
+ "batched_sampling": true,
14
+ "num_batches_to_accumulate": 2,
15
+ "worker_num_splits": 1,
16
+ "policy_workers_per_policy": 1,
17
+ "max_policy_lag": 1000,
18
+ "num_workers": 4,
19
+ "num_envs_per_worker": 1,
20
+ "batch_size": 1024,
21
+ "num_batches_per_epoch": 8,
22
+ "num_epochs": 2,
23
+ "rollout": 64,
24
+ "recurrence": 1,
25
+ "shuffle_minibatches": false,
26
+ "gamma": 0.99,
27
+ "reward_scale": 1.0,
28
+ "reward_clip": 1000.0,
29
+ "value_bootstrap": false,
30
+ "normalize_returns": true,
31
+ "exploration_loss_coeff": 0.0004677351413,
32
+ "value_loss_coeff": 0.5,
33
+ "kl_loss_coeff": 0.0,
34
+ "exploration_loss": "entropy",
35
+ "gae_lambda": 0.95,
36
+ "ppo_clip_ratio": 0.1,
37
+ "ppo_clip_value": 1.0,
38
+ "with_vtrace": false,
39
+ "vtrace_rho": 1.0,
40
+ "vtrace_c": 1.0,
41
+ "optimizer": "adam",
42
+ "adam_eps": 1e-05,
43
+ "adam_beta1": 0.9,
44
+ "adam_beta2": 0.999,
45
+ "max_grad_norm": 0.0,
46
+ "learning_rate": 0.0003033891184,
47
+ "lr_schedule": "linear_decay",
48
+ "lr_schedule_kl_threshold": 0.008,
49
+ "lr_adaptive_min": 1e-06,
50
+ "lr_adaptive_max": 0.01,
51
+ "obs_subtract_mean": 0.0,
52
+ "obs_scale": 255.0,
53
+ "normalize_input": true,
54
+ "normalize_input_keys": [
55
+ "obs"
56
+ ],
57
+ "decorrelate_experience_max_seconds": 1,
58
+ "decorrelate_envs_on_one_worker": true,
59
+ "actor_worker_gpus": [],
60
+ "set_workers_cpu_affinity": true,
61
+ "force_envs_single_thread": false,
62
+ "default_niceness": 0,
63
+ "log_to_file": true,
64
+ "experiment_summaries_interval": 3,
65
+ "flush_summaries_interval": 30,
66
+ "stats_avg": 100,
67
+ "summaries_use_frameskip": true,
68
+ "heartbeat_interval": 20,
69
+ "heartbeat_reporting_interval": 180,
70
+ "train_for_env_steps": 2000000000,
71
+ "train_for_seconds": 3600000,
72
+ "save_every_sec": 120,
73
+ "keep_checkpoints": 2,
74
+ "load_checkpoint_kind": "latest",
75
+ "save_milestones_sec": 1200,
76
+ "save_best_every_sec": 5,
77
+ "save_best_metric": "reward",
78
+ "save_best_after": 100000,
79
+ "benchmark": false,
80
+ "encoder_mlp_layers": [
81
+ 512,
82
+ 512
83
+ ],
84
+ "encoder_conv_architecture": "convnet_atari",
85
+ "encoder_conv_mlp_layers": [
86
+ 512
87
+ ],
88
+ "use_rnn": false,
89
+ "rnn_size": 512,
90
+ "rnn_type": "gru",
91
+ "rnn_num_layers": 1,
92
+ "decoder_mlp_layers": [],
93
+ "nonlinearity": "relu",
94
+ "policy_initialization": "orthogonal",
95
+ "policy_init_gain": 1.0,
96
+ "actor_critic_share_weights": true,
97
+ "adaptive_stddev": false,
98
+ "continuous_tanh_scale": 0.0,
99
+ "initial_stddev": 1.0,
100
+ "use_env_info_cache": false,
101
+ "env_gpu_actions": false,
102
+ "env_gpu_observations": true,
103
+ "env_frameskip": 4,
104
+ "env_framestack": 4,
105
+ "pixel_format": "CHW",
106
+ "use_record_episode_statistics": true,
107
+ "episode_counter": false,
108
+ "with_wandb": false,
109
+ "wandb_user": null,
110
+ "wandb_project": "sample_factory",
111
+ "wandb_group": null,
112
+ "wandb_job_type": "SF",
113
+ "wandb_tags": [],
114
+ "with_pbt": false,
115
+ "pbt_mix_policies_in_one_env": true,
116
+ "pbt_period_env_steps": 5000000,
117
+ "pbt_start_mutation": 20000000,
118
+ "pbt_replace_fraction": 0.3,
119
+ "pbt_mutation_rate": 0.15,
120
+ "pbt_replace_reward_gap": 0.1,
121
+ "pbt_replace_reward_gap_absolute": 1e-06,
122
+ "pbt_optimize_gamma": false,
123
+ "pbt_target_objective": "true_objective",
124
+ "pbt_perturb_min": 1.1,
125
+ "pbt_perturb_max": 1.5,
126
+ "env_agents": 512,
127
+ "command_line": "--seed=1111 --experiment=atari_2B_atari_airraid_1111 --env=atari_airraid --train_for_seconds=3600000 --algo=APPO --gamma=0.99 --num_workers=4 --num_envs_per_worker=1 --worker_num_splits=1 --env_agents=512 --benchmark=False --max_grad_norm=0.0 --decorrelate_experience_max_seconds=1 --encoder_conv_architecture=convnet_atari --encoder_conv_mlp_layers 512 --nonlinearity=relu --num_policies=1 --normalize_input=True --normalize_input_keys obs --normalize_returns=True --async_rl=True --batched_sampling=True --train_for_env_steps=2000000000 --save_milestones_sec=1200 --train_dir train_dir --rollout 64 --exploration_loss_coeff 0.0004677351413 --num_epochs 2 --batch_size 1024 --num_batches_per_epoch 8 --learning_rate 0.0003033891184",
128
+ "cli_args": {
129
+ "algo": "APPO",
130
+ "env": "atari_airraid",
131
+ "experiment": "atari_2B_atari_airraid_1111",
132
+ "train_dir": "train_dir",
133
+ "seed": 1111,
134
+ "num_policies": 1,
135
+ "async_rl": true,
136
+ "batched_sampling": true,
137
+ "worker_num_splits": 1,
138
+ "num_workers": 4,
139
+ "num_envs_per_worker": 1,
140
+ "batch_size": 1024,
141
+ "num_batches_per_epoch": 8,
142
+ "num_epochs": 2,
143
+ "rollout": 64,
144
+ "gamma": 0.99,
145
+ "normalize_returns": true,
146
+ "exploration_loss_coeff": 0.0004677351413,
147
+ "max_grad_norm": 0.0,
148
+ "learning_rate": 0.0003033891184,
149
+ "normalize_input": true,
150
+ "normalize_input_keys": [
151
+ "obs"
152
+ ],
153
+ "decorrelate_experience_max_seconds": 1,
154
+ "train_for_env_steps": 2000000000,
155
+ "train_for_seconds": 3600000,
156
+ "save_milestones_sec": 1200,
157
+ "benchmark": false,
158
+ "encoder_conv_architecture": "convnet_atari",
159
+ "encoder_conv_mlp_layers": [
160
+ 512
161
+ ],
162
+ "nonlinearity": "relu",
163
+ "env_agents": 512
164
+ },
165
+ "git_hash": "e259c57b8c7aa9c7f541e9efd1316f8e6f97a6db",
166
+ "git_repo_name": "https://github.com/kaustubhsridhar/jat_regent.git"
167
+ }
git.diff ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/README.md b/README.md
2
+ index e51a12b..a6e1ca1 100644
3
+ --- a/README.md
4
+ +++ b/README.md
5
+ @@ -21,6 +21,21 @@ conda activate jat
6
+ pip install -e .[dev]
7
+ ```
8
+
9
+ +## REGENT fork of sample-factory: Installation
10
+ +Following [this install ink](https://www.samplefactory.dev/01-get-started/installation/) but for the fork:
11
+ +```shell
12
+ +git clone https://github.com/kaustubhsridhar/sample-factory.git
13
+ +cd sample-factory
14
+ +pip install -e .[dev,mujoco,atari,envpool,vizdoom]
15
+ +```
16
+ +
17
+ +# Regent fork of sample-factory: Train Unseen Env Policies and Generate Datasets
18
+ +Train policies using envpool's atari:
19
+ +```shell
20
+ +bash scripts_sample-factory/train_unseen_atari.sh
21
+ +```
22
+ +Note that the training command inside the above script was obtained from the config files of Ed Beeching's Atari 57 models on Huggingface. An example is [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/blob/main/cfg.json#L124). See my discussion [here](https://huggingface.co/edbeeching/atari_2B_atari_mspacman_1111/discussions/2).
23
+ +
24
+ ## PREV Installation
25
+
26
+ To get started with JAT, follow these steps:
27
+ @@ -155,12 +170,21 @@ python -u scripts_jat_regent/eval_RandP.py --task ${TASK} &> outputs/RandP/${TAS
28
+ ```
29
+
30
+ ### REGENT Analyze data
31
+ +Necessary:
32
+ ```shell
33
+ -python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
34
+ -
35
+ python -u examples_regent/analyze_rows_tokenized.py &> examples_regent/analyze_rows_tokenized.txt &
36
+ +```
37
+
38
+ +Already ran and output dict in code:
39
+ +```shell
40
+ python -u examples_regent/get_dim_all_vector_tasks.py &> examples_regent/get_dim_all_vector_tasks.txt &
41
+ +
42
+ +python -u examples_regent/count_rows_to_consider.py &> examples_regent/count_rows_to_consider.txt &
43
+ +```
44
+ +
45
+ +Optional:
46
+ +```shell
47
+ +python -u examples_regent/compare_datasets.py &> examples_regent/compare_datasets.txt &
48
+ ```
49
+
50
+ ## PREV Dataset
51
+ diff --git a/jat_regent/RandP.py b/jat_regent/RandP.py
52
+ deleted file mode 100644
53
+ index b2bd8bf..0000000
54
+ --- a/jat_regent/RandP.py
55
+ +++ /dev/null
56
+ @@ -1,38 +0,0 @@
57
+ -import warnings
58
+ -from dataclasses import dataclass
59
+ -from typing import List, Optional, Tuple, Union
60
+ -
61
+ -import numpy as np
62
+ -import torch
63
+ -import torch.nn.functional as F
64
+ -from gymnasium import spaces
65
+ -from torch import BoolTensor, FloatTensor, LongTensor, Tensor, nn
66
+ -from transformers import GPTNeoModel, GPTNeoPreTrainedModel
67
+ -from transformers.modeling_outputs import ModelOutput
68
+ -from transformers.models.vit.modeling_vit import ViTPatchEmbeddings
69
+ -
70
+ -from jat.configuration_jat import JatConfig
71
+ -from jat.processing_jat import JatProcessor
72
+ -
73
+ -
74
+ -class RandP():
75
+ - def __init__(self, dataset) -> None:
76
+ - self.steps = 0
77
+ - # create an index for retrieval in vector obs envs (OR) collect all images in Atari
78
+ -
79
+ - def reset_rl(self):
80
+ - self.steps = 0
81
+ -
82
+ - def get_next_action(
83
+ - self,
84
+ - processor: JatProcessor,
85
+ - continuous_observation: Optional[List[float]] = None,
86
+ - discrete_observation: Optional[List[int]] = None,
87
+ - text_observation: Optional[str] = None,
88
+ - image_observation: Optional[np.ndarray] = None,
89
+ - action_space: Union[spaces.Box, spaces.Discrete] = None,
90
+ - reward: Optional[float] = None,
91
+ - deterministic: bool = False,
92
+ - context_window: Optional[int] = None,
93
+ - ):
94
+ - pass
95
+
96
+ diff --git a/jat_regent/modelling_jat_regent.py b/jat_regent/modelling_jat_regent.py
97
+ deleted file mode 100644
98
+ index e69de29..0000000
99
+ diff --git a/jat_regent/utils.py b/jat_regent/utils.py
100
+ index 56bfb44..36f6cca 100644
101
+ --- a/jat_regent/utils.py
102
+ +++ b/jat_regent/utils.py
103
+ @@ -8,23 +8,35 @@ from tqdm import tqdm
104
+ from autofaiss import build_index
105
+
106
+
107
+ +UNSEEN_TASK_NAMES = { # Total -- atari: 57, metaworld: 50, babyai: 39, mujoco: 11
108
+ +
109
+ +}
110
+ +
111
+ def myprint(str):
112
+ - # check if first character of string is a newline character
113
+ - if str[0] == '\n':
114
+ - str_without_newline = str[1:]
115
+ + # check if first characters of string are newline character
116
+ + num_newlines = 0
117
+ + while str[num_newlines] == '\n':
118
+ print()
119
+ - else:
120
+ - str_without_newline = str
121
+ + num_newlines += 1
122
+ + str_without_newline = str[num_newlines:]
123
+ print(f'{datetime.now().strftime("%Y-%m-%d %H:%M:%S")}: {str_without_newline}')
124
+
125
+ def is_png_img(item):
126
+ return isinstance(item, PngImagePlugin.PngImageFile)
127
+
128
+ +def get_last_row_for_1M_states(task):
129
+ + last_row_idx = {'atari-alien': 14134, 'atari-amidar': 14319, 'atari-assault': 14427, 'atari-asterix': 14456, 'atari-asteroids': 14348, 'atari-atlantis': 14325, 'atari-bankheist': 14167, 'atari-battlezone': 13981, 'atari-beamrider': 13442, 'atari-berzerk': 13534, 'atari-bowling': 14110, 'atari-boxing': 14542, 'atari-breakout': 13474, 'atari-centipede': 14196, 'atari-choppercommand': 13397, 'atari-crazyclimber': 14026, 'atari-defender': 13504, 'atari-demonattack': 13499, 'atari-doubledunk': 14292, 'atari-enduro': 13260, 'atari-fishingderby': 14073, 'atari-freeway': 14016, 'atari-frostbite': 14075, 'atari-gopher': 13143, 'atari-gravitar': 14405, 'atari-hero': 14044, 'atari-icehockey': 14017, 'atari-jamesbond': 12678, 'atari-kangaroo': 14248, 'atari-krull': 14204, 'atari-kungfumaster': 14030, 'atari-montezumarevenge': 14219, 'atari-mspacman': 14120, 'atari-namethisgame': 13575, 'atari-phoenix': 13539, 'atari-pitfall': 14287, 'atari-pong': 14151, 'atari-privateeye': 14105, 'atari-qbert': 14026, 'atari-riverraid': 14275, 'atari-roadrunner': 14127, 'atari-robotank': 14079, 'atari-seaquest': 14097, 'atari-skiing': 14708, 'atari-solaris': 14199, 'atari-spaceinvaders': 12652, 'atari-stargunner': 13822, 'atari-surround': 13840, 'atari-tennis': 14062, 'atari-timepilot': 13896, 'atari-tutankham': 13121, 'atari-upndown': 13504, 'atari-venture': 14260, 'atari-videopinball': 14272, 'atari-wizardofwor': 13920, 'atari-yarsrevenge': 13981, 'atari-zaxxon': 13833, 'babyai-action-obj-door': 95000, 'babyai-blocked-unlock-pickup': 29279, 'babyai-boss-level-no-unlock': 12087, 'babyai-boss-level': 12101, 'babyai-find-obj-s5': 32974, 'babyai-go-to-door': 95000, 'babyai-go-to-imp-unlock': 9286, 'babyai-go-to-local': 95000, 'babyai-go-to-obj-door': 95000, 'babyai-go-to-obj': 95000, 'babyai-go-to-red-ball-grey': 95000, 'babyai-go-to-red-ball-no-dists': 95000, 'babyai-go-to-red-ball': 95000, 'babyai-go-to-red-blue-ball': 95000, 'babyai-go-to-seq': 13744, 'babyai-go-to': 18974, 'babyai-key-corridor': 9014, 'babyai-mini-boss-level': 38119, 'babyai-move-two-across-s8n9': 24505, 'babyai-one-room-s8': 95000, 'babyai-open-door': 95000, 'babyai-open-doors-order-n4': 95000, 'babyai-open-red-door': 95000, 'babyai-open-two-doors': 73291, 'babyai-open': 32559, 'babyai-pickup-above': 34084, 'babyai-pickup-dist': 89640, 'babyai-pickup-loc': 95000, 'babyai-pickup': 18670, 'babyai-put-next-local': 83187, 'babyai-put-next': 56986, 'babyai-synth-loc': 21605, 'babyai-synth-seq': 13049, 'babyai-synth': 19409, 'babyai-unblock-pickup': 17881, 'babyai-unlock-local': 71186, 'babyai-unlock-pickup': 50883, 'babyai-unlock-to-unlock': 23062, 'babyai-unlock': 11734, 'metaworld-assembly': 10000, 'metaworld-basketball': 10000, 'metaworld-bin-picking': 10000, 'metaworld-box-close': 10000, 'metaworld-button-press-topdown-wall': 10000, 'metaworld-button-press-topdown': 10000, 'metaworld-button-press-wall': 10000, 'metaworld-button-press': 10000, 'metaworld-coffee-button': 10000, 'metaworld-coffee-pull': 10000, 'metaworld-coffee-push': 10000, 'metaworld-dial-turn': 10000, 'metaworld-disassemble': 10000, 'metaworld-door-close': 10000, 'metaworld-door-lock': 10000, 'metaworld-door-open': 10000, 'metaworld-door-unlock': 10000, 'metaworld-drawer-close': 10000, 'metaworld-drawer-open': 10000, 'metaworld-faucet-close': 10000, 'metaworld-faucet-open': 10000, 'metaworld-hammer': 10000, 'metaworld-hand-insert': 10000, 'metaworld-handle-press-side': 10000, 'metaworld-handle-press': 10000, 'metaworld-handle-pull-side': 10000, 'metaworld-handle-pull': 10000, 'metaworld-lever-pull': 10000, 'metaworld-peg-insert-side': 10000, 'metaworld-peg-unplug-side': 10000, 'metaworld-pick-out-of-hole': 10000, 'metaworld-pick-place-wall': 10000, 'metaworld-pick-place': 10000, 'metaworld-plate-slide-back-side': 10000, 'metaworld-plate-slide-back': 10000, 'metaworld-plate-slide-side': 10000, 'metaworld-plate-slide': 10000, 'metaworld-push-back': 10000, 'metaworld-push-wall': 10000, 'metaworld-push': 10000, 'metaworld-reach-wall': 10000, 'metaworld-reach': 10000, 'metaworld-shelf-place': 10000, 'metaworld-soccer': 10000, 'metaworld-stick-pull': 10000, 'metaworld-stick-push': 10000, 'metaworld-sweep-into': 10000, 'metaworld-sweep': 10000, 'metaworld-window-close': 10000, 'metaworld-window-open': 10000, 'mujoco-ant': 4023, 'mujoco-doublependulum': 4002, 'mujoco-halfcheetah': 4000, 'mujoco-hopper': 4931, 'mujoco-humanoid': 4119, 'mujoco-pendulum': 4959, 'mujoco-pusher': 9000, 'mujoco-reacher': 9000, 'mujoco-standup': 4000, 'mujoco-swimmer': 4000, 'mujoco-walker': 4101}
130
+ + return last_row_idx[task]
131
+ +
132
+ +def get_last_row_for_100k_states(task):
133
+ + last_row_idx = {'atari-alien': 3135, 'atari-amidar': 3142, 'atari-assault': 3132, 'atari-asterix': 3181, 'atari-asteroids': 3127, 'atari-atlantis': 3128, 'atari-bankheist': 3156, 'atari-battlezone': 3136, 'atari-beamrider': 3131, 'atari-berzerk': 3127, 'atari-bowling': 3148, 'atari-boxing': 3227, 'atari-breakout': 3128, 'atari-centipede': 3176, 'atari-choppercommand': 3144, 'atari-crazyclimber': 3134, 'atari-defender': 3127, 'atari-demonattack': 3127, 'atari-doubledunk': 3175, 'atari-enduro': 3126, 'atari-fishingderby': 3155, 'atari-freeway': 3131, 'atari-frostbite': 3146, 'atari-gopher': 3128, 'atari-gravitar': 3202, 'atari-hero': 3144, 'atari-icehockey': 3138, 'atari-jamesbond': 3131, 'atari-kangaroo': 3160, 'atari-krull': 3162, 'atari-kungfumaster': 3143, 'atari-montezumarevenge': 3168, 'atari-mspacman': 3143, 'atari-namethisgame': 3131, 'atari-phoenix': 3127, 'atari-pitfall': 3131, 'atari-pong': 3160, 'atari-privateeye': 3158, 'atari-qbert': 3136, 'atari-riverraid': 3157, 'atari-roadrunner': 3150, 'atari-robotank': 3133, 'atari-seaquest': 3138, 'atari-skiing': 3271, 'atari-solaris': 3129, 'atari-spaceinvaders': 3128, 'atari-stargunner': 3129, 'atari-surround': 3143, 'atari-tennis': 3129, 'atari-timepilot': 3132, 'atari-tutankham': 3127, 'atari-upndown': 3127, 'atari-venture': 3148, 'atari-videopinball': 3130, 'atari-wizardofwor': 3138, 'atari-yarsrevenge': 3129, 'atari-zaxxon': 3133, 'babyai-action-obj-door': 15923, 'babyai-blocked-unlock-pickup': 2919, 'babyai-boss-level-no-unlock': 1217, 'babyai-boss-level': 1159, 'babyai-find-obj-s5': 3345, 'babyai-go-to-door': 18875, 'babyai-go-to-imp-unlock': 923, 'babyai-go-to-local': 18724, 'babyai-go-to-obj-door': 16472, 'babyai-go-to-obj': 20197, 'babyai-go-to-red-ball-grey': 16953, 'babyai-go-to-red-ball-no-dists': 20165, 'babyai-go-to-red-ball': 18730, 'babyai-go-to-red-blue-ball': 16934, 'babyai-go-to-seq': 1439, 'babyai-go-to': 1964, 'babyai-key-corridor': 900, 'babyai-mini-boss-level': 3789, 'babyai-move-two-across-s8n9': 2462, 'babyai-one-room-s8': 16994, 'babyai-open-door': 13565, 'babyai-open-doors-order-n4': 9706, 'babyai-open-red-door': 21185, 'babyai-open-two-doors': 7348, 'babyai-open': 3331, 'babyai-pickup-above': 3392, 'babyai-pickup-dist': 19693, 'babyai-pickup-loc': 16405, 'babyai-pickup': 1806, 'babyai-put-next-local': 8303, 'babyai-put-next': 5703, 'babyai-synth-loc': 2183, 'babyai-synth-seq': 1316, 'babyai-synth': 1964, 'babyai-unblock-pickup': 1886, 'babyai-unlock-local': 7118, 'babyai-unlock-pickup': 5107, 'babyai-unlock-to-unlock': 2309, 'babyai-unlock': 1177, 'metaworld-assembly': 1000, 'metaworld-basketball': 1000, 'metaworld-bin-picking': 1000, 'metaworld-box-close': 1000, 'metaworld-button-press-topdown-wall': 1000, 'metaworld-button-press-topdown': 1000, 'metaworld-button-press-wall': 1000, 'metaworld-button-press': 1000, 'metaworld-coffee-button': 1000, 'metaworld-coffee-pull': 1000, 'metaworld-coffee-push': 1000, 'metaworld-dial-turn': 1000, 'metaworld-disassemble': 1000, 'metaworld-door-close': 1000, 'metaworld-door-lock': 1000, 'metaworld-door-open': 1000, 'metaworld-door-unlock': 1000, 'metaworld-drawer-close': 1000, 'metaworld-drawer-open': 1000, 'metaworld-faucet-close': 1000, 'metaworld-faucet-open': 1000, 'metaworld-hammer': 1000, 'metaworld-hand-insert': 1000, 'metaworld-handle-press-side': 1000, 'metaworld-handle-press': 1000, 'metaworld-handle-pull-side': 1000, 'metaworld-handle-pull': 1000, 'metaworld-lever-pull': 1000, 'metaworld-peg-insert-side': 1000, 'metaworld-peg-unplug-side': 1000, 'metaworld-pick-out-of-hole': 1000, 'metaworld-pick-place-wall': 1000, 'metaworld-pick-place': 1000, 'metaworld-plate-slide-back-side': 1000, 'metaworld-plate-slide-back': 1000, 'metaworld-plate-slide-side': 1000, 'metaworld-plate-slide': 1000, 'metaworld-push-back': 1000, 'metaworld-push-wall': 1000, 'metaworld-push': 1000, 'metaworld-reach-wall': 1000, 'metaworld-reach': 1000, 'metaworld-shelf-place': 1000, 'metaworld-soccer': 1000, 'metaworld-stick-pull': 1000, 'metaworld-stick-push': 1000, 'metaworld-sweep-into': 1000, 'metaworld-sweep': 1000, 'metaworld-window-close': 1000, 'metaworld-window-open': 1000, 'mujoco-ant': 401, 'mujoco-doublependulum': 401, 'mujoco-halfcheetah': 400, 'mujoco-hopper': 491, 'mujoco-humanoid': 415, 'mujoco-pendulum': 495, 'mujoco-pusher': 1000, 'mujoco-reacher': 2000, 'mujoco-standup': 400, 'mujoco-swimmer': 400, 'mujoco-walker': 407}
134
+ + return last_row_idx[task]
135
+ +
136
+ def get_obs_dim(task):
137
+ assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
138
+
139
+ all_obs_dims={'babyai-action-obj-door': 212, 'babyai-blocked-unlock-pickup': 212, 'babyai-boss-level-no-unlock': 212, 'babyai-boss-level': 212, 'babyai-find-obj-s5': 212, 'babyai-go-to-door': 212, 'babyai-go-to-imp-unlock': 212, 'babyai-go-to-local': 212, 'babyai-go-to-obj-door': 212, 'babyai-go-to-obj': 212, 'babyai-go-to-red-ball-grey': 212, 'babyai-go-to-red-ball-no-dists': 212, 'babyai-go-to-red-ball': 212, 'babyai-go-to-red-blue-ball': 212, 'babyai-go-to-seq': 212, 'babyai-go-to': 212, 'babyai-key-corridor': 212, 'babyai-mini-boss-level': 212, 'babyai-move-two-across-s8n9': 212, 'babyai-one-room-s8': 212, 'babyai-open-door': 212, 'babyai-open-doors-order-n4': 212, 'babyai-open-red-door': 212, 'babyai-open-two-doors': 212, 'babyai-open': 212, 'babyai-pickup-above': 212, 'babyai-pickup-dist': 212, 'babyai-pickup-loc': 212, 'babyai-pickup': 212, 'babyai-put-next-local': 212, 'babyai-put-next': 212, 'babyai-synth-loc': 212, 'babyai-synth-seq': 212, 'babyai-synth': 212, 'babyai-unblock-pickup': 212, 'babyai-unlock-local': 212, 'babyai-unlock-pickup': 212, 'babyai-unlock-to-unlock': 212, 'babyai-unlock': 212, 'metaworld-assembly': 39, 'metaworld-basketball': 39, 'metaworld-bin-picking': 39, 'metaworld-box-close': 39, 'metaworld-button-press-topdown-wall': 39, 'metaworld-button-press-topdown': 39, 'metaworld-button-press-wall': 39, 'metaworld-button-press': 39, 'metaworld-coffee-button': 39, 'metaworld-coffee-pull': 39, 'metaworld-coffee-push': 39, 'metaworld-dial-turn': 39, 'metaworld-disassemble': 39, 'metaworld-door-close': 39, 'metaworld-door-lock': 39, 'metaworld-door-open': 39, 'metaworld-door-unlock': 39, 'metaworld-drawer-close': 39, 'metaworld-drawer-open': 39, 'metaworld-faucet-close': 39, 'metaworld-faucet-open': 39, 'metaworld-hammer': 39, 'metaworld-hand-insert': 39, 'metaworld-handle-press-side': 39, 'metaworld-handle-press': 39, 'metaworld-handle-pull-side': 39, 'metaworld-handle-pull': 39, 'metaworld-lever-pull': 39, 'metaworld-peg-insert-side': 39, 'metaworld-peg-unplug-side': 39, 'metaworld-pick-out-of-hole': 39, 'metaworld-pick-place-wall': 39, 'metaworld-pick-place': 39, 'metaworld-plate-slide-back-side': 39, 'metaworld-plate-slide-back': 39, 'metaworld-plate-slide-side': 39, 'metaworld-plate-slide': 39, 'metaworld-push-back': 39, 'metaworld-push-wall': 39, 'metaworld-push': 39, 'metaworld-reach-wall': 39, 'metaworld-reach': 39, 'metaworld-shelf-place': 39, 'metaworld-soccer': 39, 'metaworld-stick-pull': 39, 'metaworld-stick-push': 39, 'metaworld-sweep-into': 39, 'metaworld-sweep': 39, 'metaworld-window-close': 39, 'metaworld-window-open': 39, 'mujoco-ant': 27, 'mujoco-doublependulum': 11, 'mujoco-halfcheetah': 17, 'mujoco-hopper': 11, 'mujoco-humanoid': 376, 'mujoco-pendulum': 4, 'mujoco-pusher': 23, 'mujoco-reacher': 11, 'mujoco-standup': 376, 'mujoco-swimmer': 8, 'mujoco-walker': 17}
140
+ - return all_obs_dims[task]
141
+ + return (all_obs_dims[task],)
142
+
143
+ def get_act_dim(task):
144
+ assert task.startswith("babyai") or task.startswith("metaworld") or task.startswith("mujoco")
145
+ @@ -36,141 +48,188 @@ def get_act_dim(task):
146
+ elif task.startswith("mujoco"):
147
+ all_act_dims={'mujoco-ant': 8, 'mujoco-doublependulum': 1, 'mujoco-halfcheetah': 6, 'mujoco-hopper': 3, 'mujoco-humanoid': 17, 'mujoco-pendulum': 1, 'mujoco-pusher': 7, 'mujoco-reacher': 2, 'mujoco-standup': 17, 'mujoco-swimmer': 2, 'mujoco-walker': 6}
148
+ return all_act_dims[task]
149
+ -
150
+ -def process_row_atari(attn_mask, row_of_obs, task):
151
+ - """
152
+ - Example for selection with bools:
153
+ - >>> a = np.array([0,1,2,3,4,5])
154
+ - >>> b = np.array([1,0,0,0,0,1]).astype(bool)
155
+ - >>> a[b]
156
+ - array([0, 5])
157
+ - """
158
+ - attn_mask = np.array(attn_mask).astype(bool)
159
+
160
+ - row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
161
+ - row_of_obs = row_of_obs[attn_mask]
162
+ +def get_task_info(task):
163
+ + rew_key = 'rewards'
164
+ + attn_key = 'attention_mask'
165
+ + if task.startswith("atari"):
166
+ + obs_key = 'image_observations'
167
+ + act_key = 'discrete_actions'
168
+ + B = 32 # half of 54
169
+ + obs_dim = (3, 4*84, 84)
170
+ + elif task.startswith("babyai"):
171
+ + obs_key = 'discrete_observations' # also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
172
+ + act_key = 'discrete_actions'
173
+ + B = 256 # half of 512
174
+ + obs_dim = get_obs_dim(task)
175
+ + elif task.startswith("metaworld") or task.startswith("mujoco"):
176
+ + obs_key = 'continuous_observations'
177
+ + act_key = 'continuous_actions'
178
+ + B = 256
179
+ + obs_dim = get_obs_dim(task)
180
+ +
181
+ + return rew_key, attn_key, obs_key, act_key, B, obs_dim
182
+ +
183
+ +def process_row_of_obs_atari_full_without_mask(row_of_obs):
184
+ +
185
+ + if not isinstance(row_of_obs, torch.Tensor):
186
+ + row_of_obs = torch.stack([to_tensor(np.array(img)) for img in row_of_obs])
187
+ row_of_obs = row_of_obs * 0.5 + 0.5 # denormalize from [-1, 1] to [0, 1]
188
+ - assert row_of_obs.shape == (sum(attn_mask), 84, 4, 84)
189
+ + assert row_of_obs.shape == (len(row_of_obs), 84, 4, 84)
190
+ row_of_obs = row_of_obs.permute(0, 2, 1, 3) # (*, 4, 84, 84)
191
+ - row_of_obs = row_of_obs.reshape(sum(attn_mask), 4*84, 84) # put side-by-side
192
+ + row_of_obs = row_of_obs.reshape(len(row_of_obs), 4*84, 84) # put side-by-side
193
+ row_of_obs = row_of_obs.unsqueeze(1).repeat(1, 3, 1, 1) # repeat for 3 channels
194
+ - assert row_of_obs.shape == (sum(attn_mask), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
195
+ -
196
+ - return attn_mask, row_of_obs
197
+ + assert row_of_obs.shape == (len(row_of_obs), 3, 4*84, 84) # sum(attn_mask) is the batch size dimension
198
+ +
199
+ + return row_of_obs
200
+
201
+ -def process_row_vector(attn_mask, row_of_obs, task, return_numpy=False):
202
+ - attn_mask = np.array(attn_mask).astype(bool)
203
+ +def collect_all_atari_data(dataset, all_row_idxs=None):
204
+ + if all_row_idxs is None:
205
+ + all_row_idxs = list(range(len(dataset['train'])))
206
+
207
+ - row_of_obs = np.array(row_of_obs)
208
+ - if not return_numpy:
209
+ - row_of_obs = torch.tensor(row_of_obs)
210
+ - row_of_obs = row_of_obs[attn_mask]
211
+ - assert row_of_obs.shape == (sum(attn_mask), get_obs_dim(task))
212
+ -
213
+ - return attn_mask, row_of_obs
214
+ -
215
+ -def retrieve_atari(row_of_obs, # query: (row_B, 3, 4*84, 84)
216
+ - dataset, # to retrieve from
217
+ - all_rows_to_consider, # rows to consider
218
+ - num_to_retrieve, # top-k
219
+ + all_rows_of_obs = []
220
+ + all_attn_masks = []
221
+ + for row_idx in tqdm(all_row_idxs):
222
+ + datarow = dataset['train'][row_idx]
223
+ + row_of_obs = process_row_of_obs_atari_full_without_mask(datarow['image_observations'])
224
+ + attn_mask = np.array(datarow['attention_mask']).astype(bool)
225
+ + all_rows_of_obs.append(row_of_obs) # appending tensor
226
+ + all_attn_masks.append(attn_mask) # appending np array
227
+ + all_rows_of_obs = torch.stack(all_rows_of_obs, dim=0) # stacking tensors
228
+ + all_attn_masks = np.stack(all_attn_masks, axis=0) # concatenating np arrays
229
+ + assert (all_rows_of_obs.shape == (len(all_row_idxs), 32, 3, 4*84, 84) and
230
+ + all_attn_masks.shape == (len(all_row_idxs), 32))
231
+ + return all_attn_masks, all_rows_of_obs
232
+ +
233
+ +def collect_all_data(dataset, task, obs_key):
234
+ + last_row_idx = get_last_row_for_100k_states(task)
235
+ + all_row_idxs = list(range(last_row_idx))
236
+ + if task.startswith("atari"):
237
+ + myprint("Collecting all Atari images and Atari attention masks...")
238
+ + all_attn_masks_OG, all_rows_of_obs_OG = collect_all_atari_data(dataset, all_row_idxs)
239
+ + else:
240
+ + datarows = dataset['train'][all_row_idxs]
241
+ + all_rows_of_obs_OG = np.array(datarows[obs_key])
242
+ + all_attn_masks_OG = np.array(datarows['attention_mask']).astype(bool)
243
+ + return all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs
244
+ +
245
+ +def collect_subset(all_rows_of_obs_OG,
246
+ + all_attn_masks_OG,
247
+ + all_rows_to_consider,
248
+ + kwargs
249
+ + ):
250
+ + """
251
+ + Function to collect subset of data given all_rows_to_consider, reshape it, create all_indices and return.
252
+ + Used in both retrieve_atari() and retrieve_vector() --> build_index_vector().
253
+ + """
254
+ + myprint(f'\n\n\n' + ('-'*100) + f'Collecting subset...')
255
+ + # read kwargs
256
+ + B, task, obs_dim = kwargs['B'], kwargs['task'], kwargs['obs_dim']
257
+ +
258
+ + # take subset based on all_rows_to_consider
259
+ + myprint(f'Taking subset of data based on all_rows_to_consider...')
260
+ + all_processed_rows_of_obs = all_rows_of_obs_OG[all_rows_to_consider]
261
+ + all_attn_masks = all_attn_masks_OG[all_rows_to_consider]
262
+ + assert (all_processed_rows_of_obs.shape == (len(all_rows_to_consider), B, *obs_dim) and
263
+ + all_attn_masks.shape == (len(all_rows_to_consider), B))
264
+ +
265
+ + # reshape
266
+ + myprint(f'Reshaping data...')
267
+ + all_attn_masks = all_attn_masks.reshape(-1)
268
+ + all_processed_rows_of_obs = all_processed_rows_of_obs.reshape(-1, *obs_dim)
269
+ + all_processed_rows_of_obs = all_processed_rows_of_obs[all_attn_masks]
270
+ + assert (all_attn_masks.shape == (len(all_rows_to_consider) * B,) and
271
+ + all_processed_rows_of_obs.shape == (np.sum(all_attn_masks), *obs_dim))
272
+ +
273
+ + # collect indices of data
274
+ + myprint(f'Collecting indices of data...')
275
+ + all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
276
+ + all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
277
+ + assert all_indices.shape == (np.sum(all_attn_masks), 2)
278
+ +
279
+ + myprint(f'{all_indices.shape=}, {all_processed_rows_of_obs.shape=}')
280
+ + myprint(('-'*100) + '\n\n\n')
281
+ + return all_indices, all_processed_rows_of_obs
282
+ +
283
+ +def retrieve_atari(row_of_obs, # query: (xbdim, 3, 4*84, 84) / (xdim *obs_dim)
284
+ + all_processed_rows_of_obs,
285
+ + all_indices,
286
+ + num_to_retrieve,
287
+ kwargs
288
+ - ):
289
+ + ):
290
+ + """
291
+ + Retrieval for Atari with images, ssim distance, and on GPU.
292
+ + """
293
+ assert isinstance(row_of_obs, torch.Tensor)
294
+
295
+ # read kwargs # Note: B = len of row
296
+ - B, attn_key, obs_key, device, task, batch_size_retrieval = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval']
297
+ + B, device, batch_size_retrieval = kwargs['B'], kwargs['device'], kwargs['batch_size_retrieval']
298
+
299
+ # batch size of row_of_obs which can be <= B since we process before calling this function
300
+ - row_B = row_of_obs.shape[0]
301
+ -
302
+ + xbdim = row_of_obs.shape[0]
303
+ +
304
+ + # collect subset of data that we can retrieve from
305
+ + ydim = all_processed_rows_of_obs.shape[0]
306
+ +
307
+ # first argument for ssim
308
+ - repeated_row_og = row_of_obs.repeat_interleave(B, dim=0).to(device)
309
+ - assert repeated_row_og.shape == (row_B*B, 3, 4*84, 84)
310
+ + xbatch = row_of_obs.repeat_interleave(batch_size_retrieval, dim=0).to(device)
311
+ + assert xbatch.shape == (xbdim * batch_size_retrieval, 3, 4*84, 84)
312
+
313
+ - # iterate over all other rows
314
+ + # iterate over data that we can retrieve from in batches
315
+ all_ssim = []
316
+ - all_indices = []
317
+ - total = 0
318
+ - for other_row_idx in tqdm(all_rows_to_consider):
319
+ - other_attn_mask, other_row_of_obs = process_row_atari(dataset['train'][other_row_idx][attn_key], dataset['train'][other_row_idx][obs_key])
320
+ -
321
+ - # batch size of other_row_of_obs
322
+ - other_row_B = other_row_of_obs.shape[0]
323
+ - total += other_row_B
324
+ -
325
+ - # first argument for ssim: RECHECK
326
+ - if other_row_B < B: # when other row has less observations than expected
327
+ - repeated_row = row_of_obs.repeat_interleave(other_row_B, dim=0).to(device)
328
+ - elif other_row_B == B: # otherwise just use the one created before the for loop
329
+ - repeated_row = repeated_row_og
330
+ - assert repeated_row.shape == (row_B*other_row_B, 3, 4*84, 84)
331
+ -
332
+ + for j in range(0, ydim, batch_size_retrieval):
333
+ # second argument for ssim
334
+ - repeated_other_row = other_row_of_obs.repeat(row_B, 1, 1, 1).to(device)
335
+ - assert repeated_other_row.shape == (row_B*other_row_B, 3, 4*84, 84)
336
+ + ybatch = all_processed_rows_of_obs[j:j+batch_size_retrieval]
337
+ + ybdim = ybatch.shape[0]
338
+ + ybatch = ybatch.repeat(xbdim, 1, 1, 1).to(device)
339
+ + assert ybatch.shape == (ybdim * xbdim, 3, 4*84, 84)
340
+ +
341
+ + if ybdim < batch_size_retrieval: # for last batch
342
+ + xbatch = row_of_obs.repeat_interleave(ybdim, dim=0).to(device)
343
+ + assert xbatch.shape == (xbdim * ybdim, 3, 4*84, 84)
344
+
345
+ # compare via ssim and updated all_ssim
346
+ - ssim_score = ssim(repeated_row, repeated_other_row, data_range=1.0, size_average=False)
347
+ - ssim_score = ssim_score.reshape(row_B, other_row_B)
348
+ + ssim_score = ssim(xbatch, ybatch, data_range=1.0, size_average=False)
349
+ + ssim_score = ssim_score.reshape(xbdim, ybdim)
350
+ all_ssim.append(ssim_score)
351
+
352
+ - # update all_indices
353
+ - all_indices.extend([[other_row_idx, i] for i in range(other_row_B)])
354
+ -
355
+ # concat
356
+ all_ssim = torch.cat(all_ssim, dim=1)
357
+ - assert all_ssim.shape == (row_B, total)
358
+ + assert all_ssim.shape == (xbdim, ydim)
359
+
360
+ - all_indices = np.array(all_indices)
361
+ - assert all_indices.shape == (total, 2)
362
+ + assert all_indices.shape == (ydim, 2)
363
+
364
+ # get top-k indices
365
+ topk_values, topk_indices = torch.topk(all_ssim, num_to_retrieve, dim=1, largest=True)
366
+ topk_indices = topk_indices.cpu().numpy()
367
+ - assert topk_indices.shape == (row_B, num_to_retrieve)
368
+ + assert topk_indices.shape == (xbdim, num_to_retrieve)
369
+
370
+ # convert topk indices to indices in the dataset
371
+ - retrieved_indices = np.array(all_indices[topk_indices])
372
+ - assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
373
+ -
374
+ - # pad the above to expected B
375
+ - if row_B < B:
376
+ - retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
377
+ - assert retrieved_indices.shape == (B, num_to_retrieve, 2)
378
+ + retrieved_indices = all_indices[topk_indices]
379
+ + assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
380
+
381
+ return retrieved_indices
382
+
383
+ -def build_index_vector(all_rows_of_obs_og,
384
+ - all_attn_masks_og,
385
+ +def build_index_vector(all_rows_of_obs_OG,
386
+ + all_attn_masks_OG,
387
+ all_rows_to_consider,
388
+ kwargs
389
+ - ):
390
+ + ):
391
+ + """
392
+ + Builds FAISS index for vector observation environments.
393
+ + """
394
+ # read kwargs # Note: B = len of row
395
+ - B, attn_key, obs_key, device, task, batch_size_retrieval, nb_cores_autofaiss = kwargs['B'], kwargs['attn_key'], kwargs['obs_key'], kwargs['device'], kwargs['task'], kwargs['batch_size_retrieval'], kwargs['nb_cores_autofaiss']
396
+ - obs_dim = get_obs_dim(task)
397
+ + nb_cores_autofaiss = kwargs['nb_cores_autofaiss']
398
+
399
+ - # take subset based on all_rows_to_consider
400
+ - myprint(f'Taking subset')
401
+ - all_rows_of_obs = all_rows_of_obs_og[all_rows_to_consider]
402
+ - all_attn_masks = all_attn_masks_og[all_rows_to_consider]
403
+ - assert (all_rows_of_obs.shape == (len(all_rows_to_consider), B, obs_dim) and
404
+ - all_attn_masks.shape == (len(all_rows_to_consider), B))
405
+ -
406
+ - # reshape
407
+ - all_attn_masks = all_attn_masks.reshape(-1)
408
+ - all_rows_of_obs = all_rows_of_obs.reshape(-1, obs_dim)
409
+ - all_rows_of_obs = all_rows_of_obs[all_attn_masks]
410
+ - assert all_rows_of_obs.shape == (np.sum(all_attn_masks), obs_dim)
411
+ + # take subset based on all_rows_to_consider, reshape, and save indices of data
412
+ + all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG, all_attn_masks_OG, all_rows_to_consider, kwargs)
413
+
414
+ - # save indices of data to retrieve from
415
+ - myprint(f'Saving indices of data to retrieve from')
416
+ - all_indices = np.array([[row_idx, i] for row_idx in all_rows_to_consider for i in range(B)])
417
+ - all_indices = all_indices[all_attn_masks] # this is fine because all attn masks have 0s that only come after 1s
418
+ - assert all_indices.shape == (np.sum(all_attn_masks), 2)
419
+ + # make sure input to build_index is float, otherwise you will get reading temp file error
420
+ + all_processed_rows_of_obs = all_processed_rows_of_obs.astype(float)
421
+
422
+ # build index
423
+ - myprint(f'Building index...')
424
+ - knn_index, knn_index_infos = build_index(embeddings=all_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
425
+ + myprint(('-'*100) + 'Building index...')
426
+ + knn_index, knn_index_infos = build_index(embeddings=all_processed_rows_of_obs, # Note: embeddings have to be float to avoid errors in autofaiss / embedding_reader!
427
+ save_on_disk=False,
428
+ min_nearest_neighbors_to_retrieve=20, # default: 20
429
+ max_index_query_time_ms=10, # default: 10
430
+ @@ -179,34 +238,32 @@ def build_index_vector(all_rows_of_obs_og,
431
+ metric_type='l2',
432
+ nb_cores=nb_cores_autofaiss, # default: None # "The number of cores to use, by default will use all cores" as seen in https://criteo.github.io/autofaiss/getting_started/quantization.html#the-build-index-command
433
+ )
434
+ + myprint(('-'*100) + '\n\n\n')
435
+
436
+ - return knn_index, all_indices
437
+ + return all_indices, knn_index
438
+
439
+ -def retrieve_vector(row_of_obs, # query: (row_B, dim)
440
+ - dataset, # to retrieve from
441
+ - all_rows_to_consider, # rows to consider
442
+ - num_to_retrieve, # top-k
443
+ +def retrieve_vector(row_of_obs, # query: (xbdim, *obs_dim)
444
+ + knn_index,
445
+ + all_indices,
446
+ + num_to_retrieve,
447
+ kwargs
448
+ - ):
449
+ + ):
450
+ + """
451
+ + Retrieval for vector observation environments.
452
+ + """
453
+ assert isinstance(row_of_obs, np.ndarray)
454
+
455
+ # read few kwargs
456
+ B = kwargs['B']
457
+
458
+ # batch size of row_of_obs which can be <= B since we process before calling this function
459
+ - row_B = row_of_obs.shape[0]
460
+ + xbdim = row_of_obs.shape[0]
461
+
462
+ - # read dataset_tuple
463
+ - all_rows_of_obs, all_attn_masks = dataset
464
+ -
465
+ - # create index and all_indices
466
+ - knn_index, all_indices = build_index_vector(all_rows_of_obs, all_attn_masks, all_rows_to_consider, kwargs)
467
+ -
468
+ # retrieve
469
+ myprint(f'Retrieving...')
470
+ topk_indices, _ = knn_index.search(row_of_obs, 10 * num_to_retrieve)
471
+ topk_indices = topk_indices.astype(int)
472
+ - assert topk_indices.shape == (row_B, 10 * num_to_retrieve)
473
+ + assert topk_indices.shape == (xbdim, 10 * num_to_retrieve)
474
+
475
+ # remove -1s and crop to num_to_retrieve
476
+ try:
477
+ @@ -219,16 +276,10 @@ def retrieve_vector(row_of_obs, # query: (row_B, dim)
478
+ print(f'-------------------------------------------------------------------------------------------------------------------------------------------')
479
+ print(f'Leaving some -1s in topk_indices and continuing')
480
+ topk_indices = np.array([indices[:num_to_retrieve] for indices in topk_indices])
481
+ - assert topk_indices.shape == (row_B, num_to_retrieve)
482
+ + assert topk_indices.shape == (xbdim, num_to_retrieve)
483
+
484
+ # convert topk indices to indices in the dataset
485
+ retrieved_indices = all_indices[topk_indices]
486
+ - assert retrieved_indices.shape == (row_B, num_to_retrieve, 2)
487
+ -
488
+ - # pad the above to expected B
489
+ - if row_B < B:
490
+ - retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-row_B, num_to_retrieve, 2), dtype=int)], axis=0)
491
+ - assert retrieved_indices.shape == (B, num_to_retrieve, 2)
492
+ + assert retrieved_indices.shape == (xbdim, num_to_retrieve, 2)
493
+
494
+ - myprint(f'Returning')
495
+ return retrieved_indices
496
+
497
+ diff --git a/scripts_regent/eval_RandP.py b/scripts_regent/eval_RandP.py
498
+ index 07e545c..146b347 100755
499
+ --- a/scripts_regent/eval_RandP.py
500
+ +++ b/scripts_regent/eval_RandP.py
501
+ @@ -15,9 +15,10 @@ from transformers import AutoModelForCausalLM, AutoProcessor, HfArgumentParser
502
+
503
+ from jat.eval.rl import TASK_NAME_TO_ENV_ID, make
504
+ from jat.utils import normalize, push_to_hub, save_video_grid
505
+ -from jat_regent.RandP import RandP
506
+ +from jat_regent.modeling_RandP import RandP
507
+ from datasets import load_from_disk
508
+ from datasets.config import HF_DATASETS_CACHE
509
+ +from jat_regent.utils import myprint
510
+
511
+
512
+ @dataclass
513
+ @@ -70,6 +71,7 @@ def eval_rl(model, processor, task, eval_args):
514
+ scores = []
515
+ frames = []
516
+ for episode in tqdm(range(eval_args.num_episodes), desc=task, unit="episode", leave=False):
517
+ + myprint(('-'*100) + f'{episode=}')
518
+ observation, _ = env.reset()
519
+ reward = None
520
+ rewards = []
521
+ @@ -96,6 +98,7 @@ def eval_rl(model, processor, task, eval_args):
522
+ frames.append(np.array(env.render(), dtype=np.uint8))
523
+
524
+ scores.append(sum(rewards))
525
+ + myprint(('-'*100) + '\n\n\n')
526
+ env.close()
527
+
528
+ raw_mean, raw_std = np.mean(scores), np.std(scores)
529
+ @@ -145,7 +148,9 @@ def main():
530
+ tasks.extend([env_id for env_id in TASK_NAME_TO_ENV_ID.keys() if env_id.startswith(domain)])
531
+
532
+ device = torch.device("cpu") if eval_args.use_cpu else get_default_device()
533
+ - processor = None
534
+ + processor = AutoProcessor.from_pretrained(
535
+ + 'jat-project/jat', cache_dir=None, trust_remote_code=True
536
+ + )
537
+
538
+ evaluations = {}
539
+ video_list = []
540
+ @@ -153,14 +158,18 @@ def main():
541
+
542
+ for task in tqdm(tasks, desc="Evaluation", unit="task", leave=True):
543
+ if task in TASK_NAME_TO_ENV_ID.keys():
544
+ + myprint(('-'*100) + f'{task=}')
545
+ dataset = load_from_disk(f'{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}')
546
+ - model = RandP(dataset)
547
+ + model = RandP(task,
548
+ + dataset,
549
+ + device,)
550
+ scores, frames, fps = eval_rl(model, processor, task, eval_args)
551
+ evaluations[task] = scores
552
+ # Save the video
553
+ if eval_args.save_video:
554
+ video_list.append(frames)
555
+ input_fps.append(fps)
556
+ + myprint(('-'*100) + '\n\n\n')
557
+ else:
558
+ warnings.warn(f"Task {task} is not supported.")
559
+
560
+ diff --git a/scripts_regent/offline_retrieval_jat_regent.py b/scripts_regent/offline_retrieval_jat_regent.py
561
+ index c83d259..aad678a 100644
562
+ --- a/scripts_regent/offline_retrieval_jat_regent.py
563
+ +++ b/scripts_regent/offline_retrieval_jat_regent.py
564
+ @@ -8,7 +8,7 @@ import time
565
+ from datetime import datetime
566
+ from datasets import load_from_disk
567
+ from datasets.config import HF_DATASETS_CACHE
568
+ -from jat_regent.utils import myprint, process_row_atari, process_row_vector, retrieve_atari, retrieve_vector
569
+ +from jat_regent.utils import myprint, get_task_info, collect_all_data, process_row_of_obs_atari_full_without_mask, retrieve_atari, retrieve_vector, collect_subset, build_index_vector
570
+ import logging
571
+ logging.basicConfig(level=logging.DEBUG)
572
+
573
+ @@ -17,7 +17,8 @@ def main():
574
+ parser = argparse.ArgumentParser(description='Build RAAGENT sequence indices')
575
+ parser.add_argument('--task', type=str, default='atari-alien', help='Task name')
576
+ parser.add_argument('--num_to_retrieve', type=int, default=100, help='Number of states/windows to retrieve')
577
+ - parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector observation environments')
578
+ + parser.add_argument('--nb_cores_autofaiss', type=int, default=16, help='Number of cores to use for faiss in vector obs envs')
579
+ + parser.add_argument('--batch_size_retrieval', type=int, default=1024, help='Batch size for retrieval in atari')
580
+ args = parser.parse_args()
581
+
582
+ # load dataset, map, device, for task
583
+ @@ -25,77 +26,83 @@ def main():
584
+ dataset_path = f"{HF_DATASETS_CACHE}/jat-project/jat-dataset-tokenized/{task}"
585
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
586
+
587
+ - rew_key = 'rewards'
588
+ - attn_key = 'attention_mask'
589
+ - if task.startswith("atari"):
590
+ - obs_key = 'image_observations'
591
+ - act_key = 'discrete_actions'
592
+ - len_row_tokenized_known = 32 # half of 54
593
+ - process_row_fn = process_row_atari
594
+ - retrieve_fn = retrieve_atari
595
+ - elif task.startswith("babyai"):
596
+ - obs_key = 'discrete_observations'# also has 'text_observations' only for raw dataset not for tokenized dataset (as it is combined into discrete_observation in tokenized dataset)
597
+ - act_key = 'discrete_actions'
598
+ - len_row_tokenized_known = 256 # half of 512
599
+ - process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
600
+ - retrieve_fn = retrieve_vector
601
+ - elif task.startswith("metaworld") or task.startswith("mujoco"):
602
+ - obs_key = 'continuous_observations'
603
+ - act_key = 'continuous_actions'
604
+ - len_row_tokenized_known = 256
605
+ - process_row_fn = lambda attn_mask, row_of_obs, task: process_row_vector(attn_mask, row_of_obs, task, return_numpy=True)
606
+ - retrieve_fn = retrieve_vector
607
+ + rew_key, attn_key, obs_key, act_key, B, obs_dim = get_task_info(task)
608
+
609
+ dataset = load_from_disk(dataset_path)
610
+ with open(f"{dataset_path}/map_from_rows_to_episodes_for_tokenized.json", 'r') as f:
611
+ map_from_rows_to_episodes_for_tokenized = json.load(f)
612
+
613
+ # setup kwargs
614
+ - len_dataset = len(dataset['train'])
615
+ - B = len_row_tokenized_known
616
+ kwargs = {'B': B,
617
+ - 'attn_key':attn_key,
618
+ - 'obs_key':obs_key,
619
+ - 'device':device,
620
+ - 'task':task,
621
+ - 'batch_size_retrieval':None,
622
+ - 'nb_cores_autofaiss':None if task.startswith("atari") else args.nb_cores_autofaiss,
623
+ - }
624
+ + 'obs_dim': obs_dim,
625
+ + 'attn_key': attn_key,
626
+ + 'obs_key': obs_key,
627
+ + 'device': device,
628
+ + 'task': task,
629
+ + 'batch_size_retrieval': args.batch_size_retrieval,
630
+ + 'nb_cores_autofaiss': None if task.startswith("atari") else args.nb_cores_autofaiss,
631
+ + }
632
+
633
+ # collect all observations in a single array (this takes some time) for vector observation environments
634
+ - if not task.startswith("atari"):
635
+ - myprint("Collecting all observations/attn_masks in a single array")
636
+ - all_rows_of_obs = np.array(dataset['train'][obs_key])
637
+ - all_attn_masks = np.array(dataset['train'][attn_key]).astype(bool)
638
+ + myprint("Collecting all observations/attn_masks")
639
+ + all_rows_of_obs_OG, all_attn_masks_OG, all_row_idxs = collect_all_data(dataset, task, obs_key)
640
+
641
+ # iterate over rows
642
+ all_retrieved_indices = []
643
+ - for row_idx in range(len_dataset):
644
+ - myprint(f"\nProcessing row {row_idx}/{len_dataset}")
645
+ + for row_idx in all_row_idxs:
646
+ + myprint(f"\nProcessing row {row_idx}/{len(all_row_idxs)}")
647
+ current_ep = map_from_rows_to_episodes_for_tokenized[str(row_idx)]
648
+
649
+ - attn_mask, row_of_obs = process_row_fn(dataset['train'][row_idx][attn_key], dataset['train'][row_idx][obs_key], task)
650
+ + # get row_of_obs and attn_mask
651
+ + datarow = dataset['train'][row_idx]
652
+ + attn_mask = np.array(datarow[attn_key]).astype(bool)
653
+ + if task.startswith("atari"):
654
+ + row_of_obs = process_row_of_obs_atari_full_without_mask(datarow[obs_key])
655
+ + else:
656
+ + row_of_obs = np.array(datarow[obs_key])
657
+ + row_of_obs = row_of_obs[attn_mask]
658
+ + assert row_of_obs.shape == (np.sum(attn_mask), *obs_dim)
659
+
660
+ # compare with rows from all but the current episode
661
+ - all_other_rows = [idx for idx in range(len_dataset) if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
662
+ + all_other_row_idxs = [idx for idx in all_row_idxs if map_from_rows_to_episodes_for_tokenized[str(idx)] != current_ep]
663
+
664
+ # do the retrieval
665
+ - retrieved_indices = retrieve_fn(row_of_obs=row_of_obs,
666
+ - dataset=dataset if task.startswith("atari") else (all_rows_of_obs, all_attn_masks),
667
+ - all_rows_to_consider=all_other_rows,
668
+ - num_to_retrieve=args.num_to_retrieve,
669
+ - kwargs=kwargs,
670
+ - )
671
+ + if task.startswith("atari"):
672
+ + all_indices, all_processed_rows_of_obs = collect_subset(all_rows_of_obs_OG=all_rows_of_obs_OG,
673
+ + all_attn_masks_OG=all_attn_masks_OG,
674
+ + all_rows_to_consider=all_row_idxs,
675
+ + kwargs=kwargs)
676
+ + retrieved_indices = retrieve_atari(row_of_obs=row_of_obs,
677
+ + all_processed_rows_of_obs=all_processed_rows_of_obs,
678
+ + all_indices=all_indices,
679
+ + num_to_retrieve=args.num_to_retrieve,
680
+ + kwargs=kwargs)
681
+ + else:
682
+ + all_indices, knn_index = build_index_vector(all_rows_of_obs_OG=all_rows_of_obs_OG,
683
+ + all_attn_masks_OG=all_attn_masks_OG,
684
+ + all_rows_to_consider=all_other_row_idxs,
685
+ + kwargs=kwargs)
686
+ + retrieved_indices = retrieve_vector(row_of_obs=row_of_obs,
687
+ + knn_index=knn_index,
688
+ + all_indices=all_indices,
689
+ + num_to_retrieve=args.num_to_retrieve,
690
+ + kwargs=kwargs)
691
+ +
692
+ + # pad the above to expected B
693
+ + xbdim = row_of_obs.shape[0]
694
+ + if xbdim < B:
695
+ + retrieved_indices = np.concatenate([retrieved_indices, np.zeros((B-xbdim, args.num_to_retrieve, 2), dtype=int)], axis=0)
696
+ + assert retrieved_indices.shape == (B, args.num_to_retrieve, 2)
697
+
698
+ # collect retrieved indices
699
+ all_retrieved_indices.append(retrieved_indices)
700
+
701
+ # concat
702
+ all_retrieved_indices = np.stack(all_retrieved_indices, axis=0)
703
+ - assert all_retrieved_indices.shape == (len_dataset, B, args.num_to_retrieve, 2)
704
+ + assert all_retrieved_indices.shape == (len(all_row_idxs), B, args.num_to_retrieve, 2)
705
+
706
+ # save arrays as bin for easy memmap access and faster loading
707
+ - all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len_dataset}_{B}_{args.num_to_retrieve}_2.bin")
708
+ + all_retrieved_indices.tofile(f"{dataset_path}/retrieved_indices_{len(all_row_idxs)}_{B}_{args.num_to_retrieve}_2.bin")
709
+
710
+ if __name__ == "__main__":
711
+ main()
712
+
replay.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07be059a9429a473b5b17baa3708722b914b49c1e81c2c57e350ea5acb4339b7
3
+ size 1295354
sf_log.txt ADDED
The diff for this file is too large to render. See raw diff