qgallouedec HF Staff committed on
Commit
ab0c658
·
1 Parent(s): ad413f1

Upload folder using huggingface_hub

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ replay.mp4 filter=lfs diff=lfs merge=lfs -text
.summary/0/events.out.tfevents.1688740849.qgallouedec-MS-7C84 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6ff0eaa13758cff09bdd058ae223b8b01908d06e46c5955d72f79ec1d5483e6
3
+ size 700655
README.md ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: sample-factory
3
+ tags:
4
+ - deep-reinforcement-learning
5
+ - reinforcement-learning
6
+ - sample-factory
7
+ model-index:
8
+ - name: APPO
9
+ results:
10
+ - task:
11
+ type: reinforcement-learning
12
+ name: reinforcement-learning
13
+ dataset:
14
+ name: bin-picking-v2
15
+ type: bin-picking-v2
16
+ metrics:
17
+ - type: mean_reward
18
+ value: 168.59 +/- 207.20
19
+ name: mean_reward
20
+ verified: false
21
+ ---
22
+
23
+ An **APPO** model trained on the **bin-picking-v2** environment.
24
+
25
+ This model was trained using Sample-Factory 2.0: https://github.com/alex-petrenko/sample-factory.
26
+ Documentation for how to use Sample-Factory can be found at https://www.samplefactory.dev/
27
+
28
+
29
+ ## Downloading the model
30
+
31
+ After installing Sample-Factory, download the model with:
32
+ ```
33
+ python -m sample_factory.huggingface.load_from_hub -r qgallouedec/bin-picking-v2
34
+ ```
35
+
36
+
37
+ ## Using the model
38
+
39
+ To run the model after download, use the `enjoy` script corresponding to this environment:
40
+ ```
41
+ python -m enjoy --algo=APPO --env=bin-picking-v2 --train_dir=./train_dir --experiment=bin-picking-v2
42
+ ```
43
+
44
+
45
+ You can also upload models to the Hugging Face Hub using the same script with the `--push_to_hub` flag.
46
+ See https://www.samplefactory.dev/10-huggingface/huggingface/ for more details.
47
+
48
+ ## Training with this model
49
+
50
+ To continue training with this model, use the `train` script corresponding to this environment:
51
+ ```
52
+ python -m train --algo=APPO --env=bin-picking-v2 --train_dir=./train_dir --experiment=bin-picking-v2 --restart_behavior=resume --train_for_env_steps=10000000000
53
+ ```
54
+
55
+ Note: you may have to set `--train_for_env_steps` to a suitably high number, because the experiment resumes from the step count at which it previously concluded.
56
+
checkpoint_p0/best_000018160_9297920_reward_202.428.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:83cb11421e93c30a4d5668d31b2d615e4a6526c4f88f6ac2ecd1ed5631765ebd
3
+ size 98239
checkpoint_p0/checkpoint_000019296_9879552.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cf8931a2ed9a5502496b448578eb87f863bb3703cafccfab6fc9171810e860e4
3
+ size 98567
checkpoint_p0/checkpoint_000019544_10006528.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f3aad58d66e16763d6e8c9e7dcdc53df4f7c7c3f9fb5e8bcd847343ff9ab654
3
+ size 98567
config.json ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "help": false,
3
+ "algo": "APPO",
4
+ "env": "bin-picking-v2",
5
+ "experiment": "bin-picking-v2",
6
+ "train_dir": "/home/qgallouedec/data/gia/data/envs/metaworld/train_dir",
7
+ "restart_behavior": "restart",
8
+ "device": "cpu",
9
+ "seed": null,
10
+ "num_policies": 1,
11
+ "async_rl": false,
12
+ "serial_mode": false,
13
+ "batched_sampling": false,
14
+ "num_batches_to_accumulate": 2,
15
+ "worker_num_splits": 2,
16
+ "policy_workers_per_policy": 1,
17
+ "max_policy_lag": 1000,
18
+ "num_workers": 8,
19
+ "num_envs_per_worker": 8,
20
+ "batch_size": 1024,
21
+ "num_batches_per_epoch": 4,
22
+ "num_epochs": 2,
23
+ "rollout": 64,
24
+ "recurrence": 1,
25
+ "shuffle_minibatches": false,
26
+ "gamma": 0.99,
27
+ "reward_scale": 0.1,
28
+ "reward_clip": 1000.0,
29
+ "value_bootstrap": true,
30
+ "normalize_returns": true,
31
+ "exploration_loss_coeff": 0.0,
32
+ "value_loss_coeff": 1.3,
33
+ "kl_loss_coeff": 0.1,
34
+ "exploration_loss": "entropy",
35
+ "gae_lambda": 0.95,
36
+ "ppo_clip_ratio": 0.2,
37
+ "ppo_clip_value": 1.0,
38
+ "with_vtrace": false,
39
+ "vtrace_rho": 1.0,
40
+ "vtrace_c": 1.0,
41
+ "optimizer": "adam",
42
+ "adam_eps": 1e-06,
43
+ "adam_beta1": 0.9,
44
+ "adam_beta2": 0.999,
45
+ "max_grad_norm": 3.5,
46
+ "learning_rate": 0.00295,
47
+ "lr_schedule": "linear_decay",
48
+ "lr_schedule_kl_threshold": 0.008,
49
+ "lr_adaptive_min": 1e-06,
50
+ "lr_adaptive_max": 0.01,
51
+ "obs_subtract_mean": 0.0,
52
+ "obs_scale": 1.0,
53
+ "normalize_input": true,
54
+ "normalize_input_keys": null,
55
+ "decorrelate_experience_max_seconds": 0,
56
+ "decorrelate_envs_on_one_worker": true,
57
+ "actor_worker_gpus": [],
58
+ "set_workers_cpu_affinity": true,
59
+ "force_envs_single_thread": false,
60
+ "default_niceness": 0,
61
+ "log_to_file": true,
62
+ "experiment_summaries_interval": 3,
63
+ "flush_summaries_interval": 30,
64
+ "stats_avg": 100,
65
+ "summaries_use_frameskip": true,
66
+ "heartbeat_interval": 20,
67
+ "heartbeat_reporting_interval": 180,
68
+ "train_for_env_steps": 10000000,
69
+ "train_for_seconds": 10000000000,
70
+ "save_every_sec": 15,
71
+ "keep_checkpoints": 2,
72
+ "load_checkpoint_kind": "latest",
73
+ "save_milestones_sec": -1,
74
+ "save_best_every_sec": 5,
75
+ "save_best_metric": "reward",
76
+ "save_best_after": 100000,
77
+ "benchmark": false,
78
+ "encoder_mlp_layers": [
79
+ 64,
80
+ 64
81
+ ],
82
+ "encoder_conv_architecture": "convnet_simple",
83
+ "encoder_conv_mlp_layers": [
84
+ 512
85
+ ],
86
+ "use_rnn": false,
87
+ "rnn_size": 512,
88
+ "rnn_type": "gru",
89
+ "rnn_num_layers": 1,
90
+ "decoder_mlp_layers": [],
91
+ "nonlinearity": "tanh",
92
+ "policy_initialization": "torch_default",
93
+ "policy_init_gain": 1.0,
94
+ "actor_critic_share_weights": true,
95
+ "adaptive_stddev": false,
96
+ "continuous_tanh_scale": 0.0,
97
+ "initial_stddev": 1.0,
98
+ "use_env_info_cache": false,
99
+ "env_gpu_actions": false,
100
+ "env_gpu_observations": true,
101
+ "env_frameskip": 1,
102
+ "env_framestack": 1,
103
+ "pixel_format": "CHW",
104
+ "use_record_episode_statistics": false,
105
+ "with_wandb": true,
106
+ "wandb_user": "qgallouedec",
107
+ "wandb_project": "sample_facotry_metaworld",
108
+ "wandb_group": null,
109
+ "wandb_job_type": "SF",
110
+ "wandb_tags": [],
111
+ "with_pbt": false,
112
+ "pbt_mix_policies_in_one_env": true,
113
+ "pbt_period_env_steps": 5000000,
114
+ "pbt_start_mutation": 20000000,
115
+ "pbt_replace_fraction": 0.3,
116
+ "pbt_mutation_rate": 0.15,
117
+ "pbt_replace_reward_gap": 0.1,
118
+ "pbt_replace_reward_gap_absolute": 1e-06,
119
+ "pbt_optimize_gamma": false,
120
+ "pbt_target_objective": "true_objective",
121
+ "pbt_perturb_min": 1.1,
122
+ "pbt_perturb_max": 1.5,
123
+ "command_line": "--env bin-picking-v2 --experiment bin-picking-v2 --with_wandb True --wandb_user qgallouedec --wandb_project sample_facotry_metaworld",
124
+ "cli_args": {
125
+ "env": "bin-picking-v2",
126
+ "experiment": "bin-picking-v2",
127
+ "with_wandb": true,
128
+ "wandb_user": "qgallouedec",
129
+ "wandb_project": "sample_facotry_metaworld"
130
+ },
131
+ "git_hash": "aed90d9e164e44f91bab1d70c09fac4dee064031",
132
+ "git_repo_name": "https://github.com/huggingface/gia",
133
+ "wandb_unique_id": "bin-picking-v2_20230707_164044_957676"
134
+ }
git.diff ADDED
@@ -0,0 +1,714 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ diff --git a/data/envs/download_expert_scores.py b/data/envs/download_expert_scores.py
2
+ index 4c3f06b..88b6c45 100644
3
+ --- a/data/envs/download_expert_scores.py
4
+ +++ b/data/envs/download_expert_scores.py
5
+ @@ -12,162 +12,162 @@ from tqdm import tqdm
6
+
7
+
8
+ ENV_NAMES = [
9
+ - "atari-alien",
10
+ - "atari-amidar",
11
+ - "atari-assault",
12
+ - "atari-asterix",
13
+ - "atari-asteroids",
14
+ - "atari-atlantis",
15
+ - "atari-bankheist",
16
+ - "atari-battlezone",
17
+ - "atari-beamrider",
18
+ - "atari-berzerk",
19
+ - "atari-bowling",
20
+ - "atari-boxing",
21
+ - "atari-breakout",
22
+ - "atari-centipede",
23
+ - "atari-choppercommand",
24
+ - "atari-crazyclimber",
25
+ - "atari-defender",
26
+ - "atari-demonattack",
27
+ - "atari-doubledunk",
28
+ - "atari-enduro",
29
+ - "atari-fishingderby",
30
+ - "atari-freeway",
31
+ - "atari-frostbite",
32
+ - "atari-gopher",
33
+ - "atari-gravitar",
34
+ - "atari-hero",
35
+ - "atari-icehockey",
36
+ - "atari-jamesbond",
37
+ - "atari-kangaroo",
38
+ - "atari-krull",
39
+ - "atari-kungfumaster",
40
+ - "atari-montezumarevenge",
41
+ - "atari-mspacman",
42
+ - "atari-namethisgame",
43
+ - "atari-phoenix",
44
+ - "atari-pitfall",
45
+ - "atari-pong",
46
+ - "atari-privateeye",
47
+ - "atari-qbert",
48
+ - "atari-riverraid",
49
+ - "atari-roadrunner",
50
+ - "atari-robotank",
51
+ - "atari-seaquest",
52
+ - "atari-skiing",
53
+ - "atari-solaris",
54
+ - "atari-spaceinvaders",
55
+ - "atari-stargunner",
56
+ - # "atari-surround", # Not in the dataset
57
+ - "atari-tennis",
58
+ - "atari-timepilot",
59
+ - "atari-tutankham",
60
+ - "atari-upndown",
61
+ - "atari-venture",
62
+ - "atari-videopinball",
63
+ - "atari-wizardofwor",
64
+ - "atari-yarsrevenge",
65
+ - "atari-zaxxon",
66
+ - "babyai-action-obj-door",
67
+ - "babyai-blocked-unlock-pickup",
68
+ - "babyai-boss-level-no-unlock",
69
+ - "babyai-boss-level",
70
+ - "babyai-find-obj-s5",
71
+ - "babyai-go-to-door",
72
+ - # "babyai-go-to-imp-unlock", # Not in the dataset
73
+ - "babyai-go-to-local",
74
+ - "babyai-go-to-obj-door",
75
+ - "babyai-go-to-obj",
76
+ - "babyai-go-to-red-ball-grey",
77
+ - "babyai-go-to-red-ball-no-dists",
78
+ - "babyai-go-to-red-ball",
79
+ - "babyai-go-to-red-blue-ball",
80
+ - "babyai-go-to-seq",
81
+ - "babyai-go-to",
82
+ - "babyai-key-corridor",
83
+ - "babyai-key-in-box",
84
+ - "babyai-mini-boss-level",
85
+ - "babyai-move-two-across",
86
+ - "babyai-one-room-s8",
87
+ - "babyai-open-door",
88
+ - "babyai-open-doors-order",
89
+ - "babyai-open-red-door",
90
+ - "babyai-open-two-doors",
91
+ - "babyai-open",
92
+ - "babyai-pickup-above",
93
+ - "babyai-pickup-dist",
94
+ - "babyai-pickup-loc",
95
+ - "babyai-pickup",
96
+ - "babyai-synth-loc",
97
+ - "babyai-synth-seq",
98
+ - "babyai-synth",
99
+ - "babyai-unblock-pickup",
100
+ - "babyai-unlock-local",
101
+ - "babyai-unlock-pickup",
102
+ - # "babyai-unlock-to-unlock", # Not in the dataset
103
+ - # "babyai-unlock", # Not in the dataset
104
+ + # "atari-alien",
105
+ + # "atari-amidar",
106
+ + # "atari-assault",
107
+ + # "atari-asterix",
108
+ + # "atari-asteroids",
109
+ + # "atari-atlantis",
110
+ + # "atari-bankheist",
111
+ + # "atari-battlezone",
112
+ + # "atari-beamrider",
113
+ + # "atari-berzerk",
114
+ + # "atari-bowling",
115
+ + # "atari-boxing",
116
+ + # "atari-breakout",
117
+ + # "atari-centipede",
118
+ + # "atari-choppercommand",
119
+ + # "atari-crazyclimber",
120
+ + # "atari-defender",
121
+ + # "atari-demonattack",
122
+ + # "atari-doubledunk",
123
+ + # "atari-enduro",
124
+ + # "atari-fishingderby",
125
+ + # "atari-freeway",
126
+ + # "atari-frostbite",
127
+ + # "atari-gopher",
128
+ + # "atari-gravitar",
129
+ + # "atari-hero",
130
+ + # "atari-icehockey",
131
+ + # "atari-jamesbond",
132
+ + # "atari-kangaroo",
133
+ + # "atari-krull",
134
+ + # "atari-kungfumaster",
135
+ + # "atari-montezumarevenge",
136
+ + # "atari-mspacman",
137
+ + # "atari-namethisgame",
138
+ + # "atari-phoenix",
139
+ + # "atari-pitfall",
140
+ + # "atari-pong",
141
+ + # "atari-privateeye",
142
+ + # "atari-qbert",
143
+ + # "atari-riverraid",
144
+ + # "atari-roadrunner",
145
+ + # "atari-robotank",
146
+ + # "atari-seaquest",
147
+ + # "atari-skiing",
148
+ + # "atari-solaris",
149
+ + # "atari-spaceinvaders",
150
+ + # "atari-stargunner",
151
+ + # # "atari-surround", # Not in the dataset
152
+ + # "atari-tennis",
153
+ + # "atari-timepilot",
154
+ + # "atari-tutankham",
155
+ + # "atari-upndown",
156
+ + # "atari-venture",
157
+ + # "atari-videopinball",
158
+ + # "atari-wizardofwor",
159
+ + # "atari-yarsrevenge",
160
+ + # "atari-zaxxon",
161
+ + # "babyai-action-obj-door",
162
+ + # "babyai-blocked-unlock-pickup",
163
+ + # "babyai-boss-level-no-unlock",
164
+ + # "babyai-boss-level",
165
+ + # "babyai-find-obj-s5",
166
+ + # "babyai-go-to-door",
167
+ + # # "babyai-go-to-imp-unlock", # Not in the dataset
168
+ + # "babyai-go-to-local",
169
+ + # "babyai-go-to-obj-door",
170
+ + # "babyai-go-to-obj",
171
+ + # "babyai-go-to-red-ball-grey",
172
+ + # "babyai-go-to-red-ball-no-dists",
173
+ + # "babyai-go-to-red-ball",
174
+ + # "babyai-go-to-red-blue-ball",
175
+ + # "babyai-go-to-seq",
176
+ + # "babyai-go-to",
177
+ + # "babyai-key-corridor",
178
+ + # "babyai-key-in-box",
179
+ + # "babyai-mini-boss-level",
180
+ + # "babyai-move-two-across",
181
+ + # "babyai-one-room-s8",
182
+ + # "babyai-open-door",
183
+ + # "babyai-open-doors-order",
184
+ + # "babyai-open-red-door",
185
+ + # "babyai-open-two-doors",
186
+ + # "babyai-open",
187
+ + # "babyai-pickup-above",
188
+ + # "babyai-pickup-dist",
189
+ + # "babyai-pickup-loc",
190
+ + # "babyai-pickup",
191
+ + # "babyai-synth-loc",
192
+ + # "babyai-synth-seq",
193
+ + # "babyai-synth",
194
+ + # "babyai-unblock-pickup",
195
+ + # "babyai-unlock-local",
196
+ + # "babyai-unlock-pickup",
197
+ + # # "babyai-unlock-to-unlock", # Not in the dataset
198
+ + # # "babyai-unlock", # Not in the dataset
199
+ "metaworld-assembly",
200
+ - "metaworld-basketball",
201
+ - "metaworld-bin-picking",
202
+ - "metaworld-box-close",
203
+ - "metaworld-button-press-topdown-wall",
204
+ - "metaworld-button-press-topdown",
205
+ - "metaworld-button-press-wall",
206
+ - "metaworld-button-press",
207
+ - "metaworld-coffee-button",
208
+ - "metaworld-coffee-pull",
209
+ - "metaworld-coffee-push",
210
+ - "metaworld-dial-turn",
211
+ - "metaworld-disassemble",
212
+ - "metaworld-door-close",
213
+ - "metaworld-door-lock",
214
+ - "metaworld-door-open",
215
+ - "metaworld-door-unlock",
216
+ - "metaworld-drawer-close",
217
+ - "metaworld-drawer-open",
218
+ - "metaworld-faucet-close",
219
+ - "metaworld-faucet-open",
220
+ - "metaworld-hammer",
221
+ - "metaworld-hand-insert",
222
+ - "metaworld-handle-press-side",
223
+ - "metaworld-handle-press",
224
+ - "metaworld-handle-pull-side",
225
+ - "metaworld-handle-pull",
226
+ - "metaworld-lever-pull",
227
+ - "metaworld-peg-insert-side",
228
+ - "metaworld-peg-unplug-side",
229
+ - "metaworld-pick-out-of-hole",
230
+ - "metaworld-pick-place-wall",
231
+ - "metaworld-pick-place",
232
+ - "metaworld-plate-slide-back-side",
233
+ - "metaworld-plate-slide-back",
234
+ - "metaworld-plate-slide-side",
235
+ - "metaworld-plate-slide",
236
+ - "metaworld-push-back",
237
+ - "metaworld-push-wall",
238
+ - "metaworld-push",
239
+ - "metaworld-reach-wall",
240
+ - "metaworld-reach",
241
+ - "metaworld-shelf-place",
242
+ - "metaworld-soccer",
243
+ - "metaworld-stick-pull",
244
+ - "metaworld-stick-push",
245
+ - "metaworld-sweep-into",
246
+ - "metaworld-sweep",
247
+ - "metaworld-window-close",
248
+ - "metaworld-window-open",
249
+ - "mujoco-ant",
250
+ - "mujoco-doublependulum",
251
+ - "mujoco-halfcheetah",
252
+ - "mujoco-hopper",
253
+ + # "metaworld-basketball",
254
+ + # "metaworld-bin-picking",
255
+ + # "metaworld-box-close",
256
+ + # "metaworld-button-press-topdown-wall",
257
+ + # "metaworld-button-press-topdown",
258
+ + # "metaworld-button-press-wall",
259
+ + # "metaworld-button-press",
260
+ + # "metaworld-coffee-button",
261
+ + # "metaworld-coffee-pull",
262
+ + # "metaworld-coffee-push",
263
+ + # "metaworld-dial-turn",
264
+ + # "metaworld-disassemble",
265
+ + # "metaworld-door-close",
266
+ + # "metaworld-door-lock",
267
+ + # "metaworld-door-open",
268
+ + # "metaworld-door-unlock",
269
+ + # "metaworld-drawer-close",
270
+ + # "metaworld-drawer-open",
271
+ + # "metaworld-faucet-close",
272
+ + # "metaworld-faucet-open",
273
+ + # "metaworld-hammer",
274
+ + # "metaworld-hand-insert",
275
+ + # "metaworld-handle-press-side",
276
+ + # "metaworld-handle-press",
277
+ + # "metaworld-handle-pull-side",
278
+ + # "metaworld-handle-pull",
279
+ + # "metaworld-lever-pull",
280
+ + # "metaworld-peg-insert-side",
281
+ + # "metaworld-peg-unplug-side",
282
+ + # "metaworld-pick-out-of-hole",
283
+ + # "metaworld-pick-place-wall",
284
+ + # "metaworld-pick-place",
285
+ + # "metaworld-plate-slide-back-side",
286
+ + # "metaworld-plate-slide-back",
287
+ + # "metaworld-plate-slide-side",
288
+ + # "metaworld-plate-slide",
289
+ + # "metaworld-push-back",
290
+ + # "metaworld-push-wall",
291
+ + # "metaworld-push",
292
+ + # "metaworld-reach-wall",
293
+ + # "metaworld-reach",
294
+ + # "metaworld-shelf-place",
295
+ + # "metaworld-soccer",
296
+ + # "metaworld-stick-pull",
297
+ + # "metaworld-stick-push",
298
+ + # "metaworld-sweep-into",
299
+ + # "metaworld-sweep",
300
+ + # "metaworld-window-close",
301
+ + # "metaworld-window-open",
302
+ + # "mujoco-ant",
303
+ + # "mujoco-doublependulum",
304
+ + # "mujoco-halfcheetah",
305
+ + # "mujoco-hopper",
306
+ # "mujoco-humanoid", # Not in the dataset
307
+ - "mujoco-pendulum",
308
+ - # "mujoco-pusher", # Not in the dataset
309
+ - "mujoco-reacher",
310
+ + # "mujoco-pendulum",
311
+ + # # "mujoco-pusher", # Not in the dataset
312
+ + # "mujoco-reacher",
313
+ # "mujoco-standup", # Not in the dataset
314
+ - "mujoco-swimmer",
315
+ - "mujoco-walker",
316
+ + # "mujoco-swimmer",
317
+ + # "mujoco-walker",
318
+ ]
319
+
320
+
321
+ diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
322
+ index e21b237..c2b1907 100644
323
+ --- a/data/envs/metaworld/generate_dataset.py
324
+ +++ b/data/envs/metaworld/generate_dataset.py
325
+ @@ -142,7 +142,8 @@ def create_dataset(cfg: Config, dataset_size: int = 100_000, split: str = "train
326
+
327
+ # Actions shape should be [num_agents, num_actions] even if it's [1, 1]
328
+ actions = preprocess_actions(env_info, actions)
329
+ -
330
+ + # Clamp actions to be in the range of the action space
331
+ + actions = np.clip(actions, env.action_space.low, env.action_space.high)
332
+ rnn_states = policy_outputs["new_rnn_states"]
333
+ dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0])
334
+ dataset["continuous_actions"][-1].append(actions[0])
335
+ diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
336
+ index cfdae2f..5db8c4b 100755
337
+ --- a/data/envs/metaworld/generate_dataset_all.sh
338
+ +++ b/data/envs/metaworld/generate_dataset_all.sh
339
+ @@ -2,58 +2,58 @@
340
+
341
+ ENVS=(
342
+ assembly
343
+ - basketball
344
+ - bin-picking
345
+ - box-close
346
+ - button-press-topdown
347
+ - button-press-topdown-wall
348
+ - button-press
349
+ - button-press-wall
350
+ - coffee-button
351
+ - coffee-pull
352
+ - coffee-push
353
+ - dial-turn
354
+ - disassemble
355
+ - door-close
356
+ - door-lock
357
+ - door-open
358
+ - door-unlock
359
+ - drawer-close
360
+ - drawer-open
361
+ - faucet-close
362
+ - faucet-open
363
+ - hammer
364
+ - hand-insert
365
+ - handle-press-side
366
+ - handle-press
367
+ - handle-pull-side
368
+ - handle-pull
369
+ - lever-pull
370
+ - peg-insert-side
371
+ - peg-unplug-side
372
+ - pick-out-of-hole
373
+ - pick-place
374
+ - pick-place-wall
375
+ - plate-slide-back-side
376
+ - plate-slide-back
377
+ - plate-slide-side
378
+ - plate-slide
379
+ - push-back
380
+ - push
381
+ - push-wall
382
+ - reach
383
+ - reach-wall
384
+ - shelf-place
385
+ - soccer
386
+ - stick-pull
387
+ - stick-push
388
+ - sweep-into
389
+ - sweep
390
+ - window-close
391
+ - window-open
392
+ + # basketball
393
+ + # bin-picking
394
+ + # box-close
395
+ + # button-press-topdown
396
+ + # button-press-topdown-wall
397
+ + # button-press
398
+ + # button-press-wall
399
+ + # coffee-button
400
+ + # coffee-pull
401
+ + # coffee-push
402
+ + # dial-turn
403
+ + # disassemble
404
+ + # door-close
405
+ + # door-lock
406
+ + # door-open
407
+ + # door-unlock
408
+ + # drawer-close
409
+ + # drawer-open
410
+ + # faucet-close
411
+ + # faucet-open
412
+ + # hammer
413
+ + # hand-insert
414
+ + # handle-press-side
415
+ + # handle-press
416
+ + # handle-pull-side
417
+ + # handle-pull
418
+ + # lever-pull
419
+ + # peg-insert-side
420
+ + # peg-unplug-side
421
+ + # pick-out-of-hole
422
+ + # pick-place
423
+ + # pick-place-wall
424
+ + # plate-slide-back-side
425
+ + # plate-slide-back
426
+ + # plate-slide-side
427
+ + # plate-slide
428
+ + # push-back
429
+ + # push
430
+ + # push-wall
431
+ + # reach
432
+ + # reach-wall
433
+ + # shelf-place
434
+ + # soccer
435
+ + # stick-pull
436
+ + # stick-push
437
+ + # sweep-into
438
+ + # sweep
439
+ + # window-close
440
+ + # window-open
441
+ )
442
+
443
+ for ENV in "${ENVS[@]}"; do
444
+ - python -m sample_factory.huggingface.load_from_hub -r qgallouedec/sample-factory-$ENV-v2
445
+ - python generate_dataset.py --env $ENV-v2 --experiment sample-factory-$ENV-v2 --train_dir=./train_dir
446
+ + python -m sample_factory.huggingface.load_from_hub -r qgallouedec/$ENV-v2
447
+ + python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
448
+ done
449
+ diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh
450
+ index 9d71467..5b05c6d 100755
451
+ --- a/data/envs/metaworld/push_all.sh
452
+ +++ b/data/envs/metaworld/push_all.sh
453
+ @@ -2,57 +2,57 @@
454
+
455
+ ENVS=(
456
+ assembly
457
+ - basketball
458
+ - bin-picking
459
+ - box-close
460
+ - button-press-topdown
461
+ - button-press-topdown-wall
462
+ - button-press
463
+ - button-press-wall
464
+ - coffee-button
465
+ - coffee-pull
466
+ - coffee-push
467
+ - dial-turn
468
+ - disassemble
469
+ - door-close
470
+ - door-lock
471
+ - door-open
472
+ - door-unlock
473
+ - drawer-close
474
+ - drawer-open
475
+ - faucet-close
476
+ - faucet-open
477
+ - hammer
478
+ - hand-insert
479
+ - handle-press-side
480
+ - handle-press
481
+ - handle-pull-side
482
+ - handle-pull
483
+ - lever-pull
484
+ - peg-insert-side
485
+ - peg-unplug-side
486
+ - pick-out-of-hole
487
+ - pick-place
488
+ - pick-place-wall
489
+ - plate-slide-back-side
490
+ - plate-slide-back
491
+ - plate-slide-side
492
+ - plate-slide
493
+ - push-back
494
+ - push
495
+ - push-wall
496
+ - reach
497
+ - reach-wall
498
+ - shelf-place
499
+ - soccer
500
+ - stick-pull
501
+ - stick-push
502
+ - sweep-into
503
+ - sweep
504
+ - window-close
505
+ - window-open
506
+ + # basketball
507
+ + # bin-picking
508
+ + # box-close
509
+ + # button-press-topdown
510
+ + # button-press-topdown-wall
511
+ + # button-press
512
+ + # button-press-wall
513
+ + # coffee-button
514
+ + # coffee-pull
515
+ + # coffee-push
516
+ + # dial-turn
517
+ + # disassemble
518
+ + # door-close
519
+ + # door-lock
520
+ + # door-open
521
+ + # door-unlock
522
+ + # drawer-close
523
+ + # drawer-open
524
+ + # faucet-close
525
+ + # faucet-open
526
+ + # hammer
527
+ + # hand-insert
528
+ + # handle-press-side
529
+ + # handle-press
530
+ + # handle-pull-side
531
+ + # handle-pull
532
+ + # lever-pull
533
+ + # peg-insert-side
534
+ + # peg-unplug-side
535
+ + # pick-out-of-hole
536
+ + # pick-place
537
+ + # pick-place-wall
538
+ + # plate-slide-back-side
539
+ + # plate-slide-back
540
+ + # plate-slide-side
541
+ + # plate-slide
542
+ + # push-back
543
+ + # push
544
+ + # push-wall
545
+ + # reach
546
+ + # reach-wall
547
+ + # shelf-place
548
+ + # soccer
549
+ + # stick-pull
550
+ + # stick-push
551
+ + # sweep-into
552
+ + # sweep
553
+ + # window-close
554
+ + # window-open
555
+ )
556
+
557
+ for ENV in "${ENVS[@]}"; do
558
+ - python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
559
+ + python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
560
+ done
561
+ diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
562
+ index 46dc581..c72f289 100644
563
+ --- a/data/envs/metaworld/train.py
564
+ +++ b/data/envs/metaworld/train.py
565
+ @@ -79,7 +79,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
566
+ num_workers=8,
567
+ num_envs_per_worker=8,
568
+ worker_num_splits=2,
569
+ - train_for_env_steps=100_000_000,
570
+ + train_for_env_steps=10_000_000,
571
+ encoder_mlp_layers=[64, 64],
572
+ env_frameskip=1,
573
+ nonlinearity="tanh",
574
+ diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh
575
+ index dbf328a..1b3c4c8 100755
576
+ --- a/data/envs/metaworld/train_all.sh
577
+ +++ b/data/envs/metaworld/train_all.sh
578
+ @@ -1,7 +1,7 @@
579
+ #!/bin/bash
580
+
581
+ ENVS=(
582
+ - assembly
583
+ + # assembly
584
+ basketball
585
+ bin-picking
586
+ box-close
587
+ diff --git a/gia/eval/callback.py b/gia/eval/callback.py
588
+ index 5c3a080..4b6198f 100644
589
+ --- a/gia/eval/callback.py
590
+ +++ b/gia/eval/callback.py
591
+ @@ -2,10 +2,10 @@ import glob
592
+ import json
593
+ import subprocess
594
+
595
+ -import wandb
596
+ from accelerate import Accelerator
597
+ from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
598
+
599
+ +import wandb
600
+ from gia.config import Arguments
601
+ from gia.eval.utils import is_slurm_available
602
+
603
+ diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
604
+ index 91b645c..3e2cae7 100644
605
+ --- a/gia/eval/evaluator.py
606
+ +++ b/gia/eval/evaluator.py
607
+ @@ -1,3 +1,5 @@
608
+ +from typing import Optional
609
+ +
610
+ import torch
611
+
612
+ from gia.config.arguments import Arguments
613
+ @@ -5,11 +7,12 @@ from gia.model import GiaModel
614
+
615
+
616
+ class Evaluator:
617
+ - def __init__(self, args: Arguments, task: str) -> None:
618
+ + def __init__(self, args: Arguments, task: str, mean_random: Optional[float] = None) -> None:
619
+ self.args = args
620
+ self.task = task
621
+ + self.mean_random = mean_random
622
+
623
+ - @torch.no_grad()
624
+ + @torch.inference_mode()
625
+ def evaluate(self, model: GiaModel) -> float:
626
+ return self._evaluate(model)
627
+
628
+ diff --git a/gia/eval/mappings.py b/gia/eval/mappings.py
629
+ deleted file mode 100644
630
+ index e7ba9d3..0000000
631
+ --- a/gia/eval/mappings.py
632
+ +++ /dev/null
633
+ @@ -1,11 +0,0 @@
634
+ -TASK_TO_ENV_MAPPING = {
635
+ - "mujoco-ant": "Ant-v4",
636
+ - "mujoco-halfcheetah": "HalfCheetah-v4",
637
+ - "mujoco-hopper": "Hopper-v4",
638
+ - "mujoco-doublependulum": "InvertedDoublePendulum-v4",
639
+ - "mujoco-pendulum": "InvertedPendulum-v4",
640
+ - "mujoco-reacher": "Reacher-v4",
641
+ - "mujoco-swimmer": "Swimmer-v4",
642
+ - "mujoco-walker": "Walker2d-v4",
643
+ - # Atari etc...
644
+ -}
645
+ diff --git a/gia/eval/rl/__init__.py b/gia/eval/rl/__init__.py
646
+ index 36d890b..da5e0c7 100644
647
+ --- a/gia/eval/rl/__init__.py
648
+ +++ b/gia/eval/rl/__init__.py
649
+ @@ -1,4 +1,5 @@
650
+ +from .envs.core import make
651
+ from .gym_evaluator import GymEvaluator
652
+
653
+
654
+ -__all__ = ["GymEvaluator"]
655
+ +__all__ = ["GymEvaluator", "make"]
656
+ diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
657
+ index f0d0b9b..04b9637 100644
658
+ --- a/gia/eval/rl/gia_agent.py
659
+ +++ b/gia/eval/rl/gia_agent.py
660
+ @@ -75,6 +75,11 @@ class GiaAgent:
661
+ ) -> Tuple[Tuple[Tensor, Tensor], ...]:
662
+ return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
663
+
664
+ + def set_model(self, model: GiaModel) -> None:
665
+ + self.model = model
666
+ + self.device = next(model.parameters()).device
667
+ + self._max_length = self.model.config.max_position_embeddings
668
+ +
669
+ def reset(self, num_envs: int = 1) -> None:
670
+ if self.prompter is not None:
671
+ prompts = self.prompter.generate_prompts(num_envs)
672
+ diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
673
+ index f8531ee..754c05d 100644
674
+ --- a/gia/eval/rl/gym_evaluator.py
675
+ +++ b/gia/eval/rl/gym_evaluator.py
676
+ @@ -1,7 +1,7 @@
677
+ import gym
678
+ from gym.vector.vector_env import VectorEnv
679
+
680
+ -from gia.eval.mappings import TASK_TO_ENV_MAPPING
681
+ +# from gia.eval.rl.envs.mappings import TASK_TO_ENV_MAPPING
682
+ from gia.eval.rl.rl_evaluator import RLEvaluator
683
+
684
+
685
+ diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py
686
+ index c5cc423..91189f3 100644
687
+ --- a/gia/eval/rl/rl_evaluator.py
688
+ +++ b/gia/eval/rl/rl_evaluator.py
689
+ @@ -8,6 +8,10 @@ from gia.eval.rl.gia_agent import GiaAgent
690
+
691
+
692
+ class RLEvaluator(Evaluator):
693
+ + def __init__(self, args, task):
694
+ + super().__init__(args, task)
695
+ + self.agent = GiaAgent()
696
+ +
697
+ def _build_env(self) -> VectorEnv: # TODO: maybe just a gym.Env ?
698
+ raise NotImplementedError
699
+
700
+ diff --git a/gia/eval/rl/scores_dict.json b/gia/eval/rl/scores_dict.json
701
+ index 1b8ebee..ff7d030 100644
702
+ --- a/gia/eval/rl/scores_dict.json
703
+ +++ b/gia/eval/rl/scores_dict.json
704
+ @@ -929,8 +929,8 @@
705
+ },
706
+ "metaworld-assembly": {
707
+ "expert": {
708
+ - "mean": 311.29314618777823,
709
+ - "std": 75.04282151450695
710
+ + "mean": 3523.81468486244,
711
+ + "std": 63.22745220327798
712
+ },
713
+ "random": {
714
+ "mean": 220.65601680730813,
replay.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d7d292cb6c8b454365d8a689263ddb710b0149da82dafb6442f0524ee8a0c486
3
+ size 2343193
sf_log.txt ADDED
The diff for this file is too large to render. See raw diff