alexandersoare commited on
Commit
dcf7d64
1 Parent(s): 39bad5d

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,39 +1,11 @@
1
- # Model Card for Diffusion Policy / PushT
2
 
3
- Diffusion Policy (as per [Diffusion Policy: Visuomotor Policy
4
- Learning via Action Diffusion](https://arxiv.org/abs/2303.04137)) trained for the `PushT` environment from [gym-pusht](https://github.com/huggingface/gym-pusht).
5
 
6
- ![demo](demo.gif)
 
 
7
 
8
- ## How to Get Started with the Model
9
 
10
- See the [LeRobot library](https://github.com/huggingface/lerobot) (particularly the [evaluation script](https://github.com/huggingface/lerobot/blob/main/lerobot/scripts/eval.py)) for instructions on how to load and evaluate this model.
11
-
12
- ## Training Details
13
-
14
- TODO commit hash.
15
-
16
- Trained with [LeRobot@d747195](https://github.com/huggingface/lerobot/tree/d747195c5733c4f68d4bfbe62632d6fc1b605712).
17
-
18
- The model was trained using [LeRobot's training script](https://github.com/huggingface/lerobot/blob/d747195c5733c4f68d4bfbe62632d6fc1b605712/lerobot/scripts/train.py) and with the [pusht](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3) dataset.
19
-
20
- Here are the [loss](./train_loss.csv), [evaluation score](./eval_avg_max_reward.csv), [evaluation success rate](./eval_pc_success.csv) (with 50 rollouts) during training.
21
-
22
- ![](training_curves.png)
23
-
24
- This took about 7 hours to train on an Nvida RTX 3090.
25
-
26
- ## Evaluation
27
-
28
- The model was evaluated on the `PushT` environment from [gym-pusht](https://github.com/huggingface/gym-pusht) and compared to a similar model trained with the original [Diffusion Policy code](https://github.com/real-stanford/diffusion_policy). There are two evaluation metrics on a per-episode basis:
29
-
30
- - Maximum overlap with target (seen as `eval/avg_max_reward` in the charts above). This ranges in [0, 1].
31
- - Success: whether or not the maximum overlap is at least 95%.
32
-
33
- Here are the metrics for 500 episodes worth of evaluation. For the succes rate we add an extra row with confidence bounds. This assumes a uniform prior over success probability and computes the beta posterior, then calculates the mean and lower/upper confidence bounds (with a 68.2% confidence interval centered on the mean).
34
-
35
- <blank>|Ours|Theirs
36
- -|-|-
37
- Average max. overlap ratio | 0.959 | 0.957
38
- Success rate for 500 episodes (%) | 63.8 | 64.2
39
- Beta distribution lower/mean/upper (%) | 61.6 / 63.7 / 65.9 | 62.0 / 64.1 / 66.3
 
1
+ This branch contains the model weights obtained from training on the original Diffusion Policy repository.
2
 
3
+ This is the command that was used for training:
 
4
 
5
+ ```bash
6
+ python train.py --config-dir=. --config-name=image_pusht_diffusion_policy_cnn.yaml training.seed=42 logging.name=benchmark
7
+ ```
8
 
9
+ The configuration file `image_pusht_diffusion_policy_cnn.yaml` is included in this branch.
10
 
11
+ The weights were converted with [`convert_weights.py`](convert_weights.py).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
config.yaml CHANGED
@@ -7,8 +7,8 @@ training:
7
  online_steps_between_rollouts: 1
8
  online_sampling_ratio: 0.5
9
  online_env_seed: ???
10
- eval_freq: 10000
11
- save_freq: 20000
12
  log_freq: 250
13
  save_model: true
14
  batch_size: 64
@@ -45,15 +45,13 @@ training:
45
  - 1.2
46
  - 1.3
47
  - 1.4
48
- n_end_keyframes_dropped: ${policy.horizon} - ${policy.n_action_steps} - ${policy.n_obs_steps}
49
- + 1
50
  eval:
51
  n_episodes: 50
52
  batch_size: 50
53
  use_async_envs: false
54
  wandb:
55
  enable: true
56
- disable_artifact: true
57
  project: lerobot
58
  notes: ''
59
  fps: 10
 
7
  online_steps_between_rollouts: 1
8
  online_sampling_ratio: 0.5
9
  online_env_seed: ???
10
+ eval_freq: 5000
11
+ save_freq: 5000
12
  log_freq: 250
13
  save_model: true
14
  batch_size: 64
 
45
  - 1.2
46
  - 1.3
47
  - 1.4
 
 
48
  eval:
49
  n_episodes: 50
50
  batch_size: 50
51
  use_async_envs: false
52
  wandb:
53
  enable: true
54
+ disable_artifact: false
55
  project: lerobot
56
  notes: ''
57
  fps: 10
convert_weights.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import product
2
+ from pathlib import Path
3
+
4
+ import torch
5
+ from omegaconf import OmegaConf
6
+
7
+ from lerobot.common.datasets.factory import make_dataset
8
+ from lerobot.common.policies.factory import make_policy
9
+ from lerobot.common.utils.utils import init_hydra_config
10
+
11
+ PATH_TO_ORIGINAL_WEIGHTS = "/tmp/dp.pt"
12
+ PATH_TO_CONFIG = "/home/alexander/Projects/lerobot/lerobot/configs/default.yaml"
13
+ PATH_TO_SAVE_NEW_WEIGHTS = "/tmp/dp"
14
+
15
+ cfg = init_hydra_config(PATH_TO_CONFIG)
16
+
17
+ policy = make_policy(cfg, dataset_stats=make_dataset(cfg).stats)
18
+
19
+ state_dict = torch.load(PATH_TO_ORIGINAL_WEIGHTS)
20
+
21
+ # Remove keys based on what they start with.
22
+
23
+ start_removals = ["normalizer.", "obs_encoder.obs_nets.rgb.backbone.nets.0.nets.0"]
24
+
25
+ for to_remove in start_removals:
26
+ for k in list(state_dict.keys()):
27
+ if k.startswith(to_remove):
28
+ del state_dict[k]
29
+
30
+
31
+ # Replace keys based on what they start with.
32
+
33
+ start_replacements = [
34
+ ("obs_encoder.obs_nets.image.backbone.nets", "rgb_encoder.backbone"),
35
+ ("obs_encoder.obs_nets.image.pool", "rgb_encoder.pool"),
36
+ ("obs_encoder.obs_nets.image.nets.3", "rgb_encoder.out"),
37
+ *[(f"model.up_modules.{i}.2.conv.", f"model.up_modules.{i}.2.") for i in range(2)],
38
+ *[(f"model.down_modules.{i}.2.conv.", f"model.down_modules.{i}.2.") for i in range(2)],
39
+ *[
40
+ (f"model.mid_modules.{i}.blocks.{k}.", f"model.mid_modules.{i}.conv{k + 1}.")
41
+ for i, k in product(range(3), range(2))
42
+ ],
43
+ *[
44
+ (f"model.down_modules.{i}.{j}.blocks.{k}.", f"model.down_modules.{i}.{j}.conv{k + 1}.")
45
+ for i, j, k in product(range(3), range(2), range(2))
46
+ ],
47
+ *[
48
+ (f"model.up_modules.{i}.{j}.blocks.{k}.", f"model.up_modules.{i}.{j}.conv{k + 1}.")
49
+ for i, j, k in product(range(3), range(2), range(2))
50
+ ],
51
+ ("model.", "unet.")
52
+ ]
53
+
54
+ for to_replace, replace_with in start_replacements:
55
+ for k in list(state_dict.keys()):
56
+ if k.startswith(to_replace):
57
+ k_ = replace_with + k.removeprefix(to_replace)
58
+ state_dict[k_] = state_dict[k]
59
+ del state_dict[k]
60
+
61
+ missing_keys, unexpected_keys = policy.diffusion.load_state_dict(state_dict, strict=False)
62
+
63
+ unexpected_keys = set(unexpected_keys)
64
+ allowed_unexpected_keys = eval(
65
+ "{'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.1.nets.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.1.pos_x', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.1.nets.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn1.weight', '_dummy_variable', 'mask_generator._dummy_variable', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.1.bias', 'obs_encoder.obs_nets.image.nets.1.temperature', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn2.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.1.bias', 'obs_encoder.obs_nets.image.nets.1.pos_y', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.5.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.bn1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.downsample.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.6.0.downsample.0.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.conv2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.7.1.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.downsample.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.5.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn1.bias', 'obs_encoder.obs_nets.image.nets.0.nets.7.0.conv1.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.1.bn2.weight', 'obs_encoder.obs_nets.image.nets.0.nets.4.0.bn2.bias'}"
66
+ )
67
+ if len(missing_keys) != 0:
68
+ print("MISSING KEYS")
69
+ print(missing_keys)
70
+ if unexpected_keys != allowed_unexpected_keys:
71
+ print("UNEXPECTED KEYS")
72
+ print(unexpected_keys)
73
+
74
+ if len(missing_keys) != 0 or unexpected_keys != allowed_unexpected_keys:
75
+ print("Failed due to mismatch in state dicts.")
76
+ exit()
77
+
78
+ torch.save(policy.state_dict(), "/tmp/policy.pt")
79
+ policy.save_pretrained(PATH_TO_SAVE_NEW_WEIGHTS)
80
+ OmegaConf.save(cfg, Path(PATH_TO_SAVE_NEW_WEIGHTS) / "config.yaml")
image_pusht_diffusion_policy_cnn.yaml ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ _target_: diffusion_policy.workspace.train_diffusion_unet_hybrid_workspace.TrainDiffusionUnetHybridWorkspace
2
+ checkpoint:
3
+ save_last_ckpt: true
4
+ save_last_snapshot: false
5
+ topk:
6
+ format_str: epoch={epoch:04d}-test_mean_score={test_mean_score:.3f}.ckpt
7
+ k: 5
8
+ mode: max
9
+ monitor_key: test_mean_score
10
+ dataloader:
11
+ batch_size: 64
12
+ num_workers: 8
13
+ persistent_workers: false
14
+ pin_memory: true
15
+ shuffle: true
16
+ dataset_obs_steps: 2
17
+ ema:
18
+ _target_: diffusion_policy.model.diffusion.ema_model.EMAModel
19
+ inv_gamma: 1.0
20
+ max_value: 0.9999
21
+ min_value: 0.0
22
+ power: 0.75
23
+ update_after_step: 0
24
+ exp_name: default
25
+ horizon: 16
26
+ keypoint_visible_rate: 1.0
27
+ logging:
28
+ group: null
29
+ id: null
30
+ mode: online
31
+ name: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
32
+ project: diffusion_policy_debug
33
+ resume: true
34
+ tags:
35
+ - train_diffusion_unet_hybrid
36
+ - pusht_image
37
+ - default
38
+ multi_run:
39
+ run_dir: data/outputs/2023.01.16/20.20.06_train_diffusion_unet_hybrid_pusht_image
40
+ wandb_name_base: 2023.01.16-20.20.06_train_diffusion_unet_hybrid_pusht_image
41
+ n_action_steps: 8
42
+ n_latency_steps: 0
43
+ n_obs_steps: 2
44
+ name: train_diffusion_unet_hybrid
45
+ obs_as_global_cond: true
46
+ optimizer:
47
+ _target_: torch.optim.AdamW
48
+ betas:
49
+ - 0.95
50
+ - 0.999
51
+ eps: 1.0e-08
52
+ lr: 0.0001
53
+ weight_decay: 1.0e-06
54
+ past_action_visible: false
55
+ policy:
56
+ _target_: diffusion_policy.policy.diffusion_unet_hybrid_image_policy.DiffusionUnetHybridImagePolicy
57
+ cond_predict_scale: true
58
+ crop_shape:
59
+ - 84
60
+ - 84
61
+ diffusion_step_embed_dim: 128
62
+ down_dims:
63
+ # - 256
64
+ # - 512
65
+ # - 1024
66
+ - 512
67
+ - 1024
68
+ - 2048
69
+ eval_fixed_crop: true
70
+ horizon: 16
71
+ kernel_size: 5
72
+ n_action_steps: 8
73
+ n_groups: 8
74
+ n_obs_steps: 2
75
+ noise_scheduler:
76
+ _target_: diffusers.schedulers.scheduling_ddpm.DDPMScheduler
77
+ beta_end: 0.02
78
+ beta_schedule: squaredcos_cap_v2
79
+ beta_start: 0.0001
80
+ clip_sample: true
81
+ num_train_timesteps: 100
82
+ prediction_type: epsilon
83
+ variance_type: fixed_small
84
+ num_inference_steps: 100
85
+ obs_as_global_cond: true
86
+ obs_encoder_group_norm: true
87
+ shape_meta:
88
+ action:
89
+ shape:
90
+ - 2
91
+ obs:
92
+ agent_pos:
93
+ shape:
94
+ - 2
95
+ type: low_dim
96
+ image:
97
+ shape:
98
+ - 3
99
+ - 96
100
+ - 96
101
+ type: rgb
102
+ shape_meta:
103
+ action:
104
+ shape:
105
+ - 2
106
+ obs:
107
+ agent_pos:
108
+ shape:
109
+ - 2
110
+ type: low_dim
111
+ image:
112
+ shape:
113
+ - 3
114
+ - 96
115
+ - 96
116
+ type: rgb
117
+ task:
118
+ dataset:
119
+ _target_: diffusion_policy.dataset.pusht_image_dataset.PushTImageDataset
120
+ horizon: 16
121
+ max_train_episodes: null
122
+ pad_after: 7
123
+ pad_before: 1
124
+ seed: 42
125
+ val_ratio: 0
126
+ zarr_path: data/pusht/pusht_cchi_v7_replay.zarr
127
+ env_runner:
128
+ _target_: diffusion_policy.env_runner.pusht_image_runner.PushTImageRunner
129
+ fps: 10
130
+ legacy_test: true
131
+ max_steps: 300
132
+ n_action_steps: 8
133
+ n_envs: null
134
+ n_obs_steps: 2
135
+ n_test: 50
136
+ n_test_vis: 4
137
+ n_train: 6
138
+ n_train_vis: 2
139
+ past_action: false
140
+ test_start_seed: 100000
141
+ train_start_seed: 0
142
+ image_shape:
143
+ - 3
144
+ - 96
145
+ - 96
146
+ name: pusht_image
147
+ shape_meta:
148
+ action:
149
+ shape:
150
+ - 2
151
+ obs:
152
+ agent_pos:
153
+ shape:
154
+ - 2
155
+ type: low_dim
156
+ image:
157
+ shape:
158
+ - 3
159
+ - 96
160
+ - 96
161
+ type: rgb
162
+ task_name: pusht_image
163
+ training:
164
+ checkpoint_every: 50
165
+ debug: false
166
+ device: cuda:0
167
+ gradient_accumulate_every: 1
168
+ lr_scheduler: cosine
169
+ lr_warmup_steps: 500
170
+ max_train_steps: null
171
+ max_val_steps: null
172
+ num_epochs: 500
173
+ resume: true
174
+ rollout_every: 50
175
+ sample_every: 5
176
+ seed: 42
177
+ tqdm_interval_sec: 1.0
178
+ use_ema: true
179
+ val_every: 50000000
180
+ val_dataloader:
181
+ batch_size: 64
182
+ num_workers: 8
183
+ persistent_workers: false
184
+ pin_memory: true
185
+ shuffle: false
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:877969d58d12af315d8c672a2328b3984071901b6f71bdf592b6f131056b520f
3
  size 1050862612
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9150bda22091932686db52309233586c2695be418bda16fa7202e497f56bfab8
3
  size 1050862612