Maikou committed on
Commit b621857
Parent: 391d2ef

related files and example data

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50)
  1. configs/aligned_shape_latents/shapevae-256.yaml +46 -0
  2. configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml +181 -0
  3. configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml +180 -0
  4. configs/image_cond_diffuser_asl/image-ASLDM-256.yaml +97 -0
  5. configs/text_cond_diffuser_asl/text-ASLDM-256.yaml +98 -0
  6. example_data/image/car.jpg +0 -0
  7. example_data/surface/surface.npz +3 -0
  8. gradio_cached_dir/example/img_example/airplane.jpg +0 -0
  9. gradio_cached_dir/example/img_example/alita.jpg +0 -0
  10. gradio_cached_dir/example/img_example/bag.jpg +0 -0
  11. gradio_cached_dir/example/img_example/bench.jpg +0 -0
  12. gradio_cached_dir/example/img_example/building.jpg +0 -0
  13. gradio_cached_dir/example/img_example/burger.jpg +0 -0
  14. gradio_cached_dir/example/img_example/car.jpg +0 -0
  15. gradio_cached_dir/example/img_example/loopy.jpg +0 -0
  16. gradio_cached_dir/example/img_example/mario.jpg +0 -0
  17. gradio_cached_dir/example/img_example/ship.jpg +0 -0
  18. michelangelo/__init__.py +1 -0
  19. michelangelo/__pycache__/__init__.cpython-39.pyc +0 -0
  20. michelangelo/data/__init__.py +1 -0
  21. michelangelo/data/__pycache__/__init__.cpython-39.pyc +0 -0
  22. michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc +0 -0
  23. michelangelo/data/__pycache__/tokenizer.cpython-39.pyc +0 -0
  24. michelangelo/data/__pycache__/transforms.cpython-39.pyc +0 -0
  25. michelangelo/data/__pycache__/utils.cpython-39.pyc +0 -0
  26. michelangelo/data/templates.json +69 -0
  27. michelangelo/data/transforms.py +407 -0
  28. michelangelo/data/utils.py +59 -0
  29. michelangelo/graphics/__init__.py +1 -0
  30. michelangelo/graphics/__pycache__/__init__.cpython-39.pyc +0 -0
  31. michelangelo/graphics/primitives/__init__.py +9 -0
  32. michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc +0 -0
  33. michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc +0 -0
  34. michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc +0 -0
  35. michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc +0 -0
  36. michelangelo/graphics/primitives/mesh.py +114 -0
  37. michelangelo/graphics/primitives/volume.py +21 -0
  38. michelangelo/models/__init__.py +1 -0
  39. michelangelo/models/__pycache__/__init__.cpython-39.pyc +0 -0
  40. michelangelo/models/asl_diffusion/__init__.py +1 -0
  41. michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
  42. michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc +0 -0
  43. michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc +0 -0
  44. michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc +0 -0
  45. michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py +483 -0
  46. michelangelo/models/asl_diffusion/asl_udt.py +104 -0
  47. michelangelo/models/asl_diffusion/base.py +13 -0
  48. michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py +393 -0
  49. michelangelo/models/asl_diffusion/inference_utils.py +80 -0
  50. michelangelo/models/conditional_encoders/__init__.py +3 -0
configs/aligned_shape_latents/shapevae-256.yaml ADDED
@@ -0,0 +1,46 @@
+ model:
+   target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+   params:
+     shape_module_cfg:
+       target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+       params:
+         num_latents: 256
+         embed_dim: 64
+         point_feats: 3 # normal
+         num_freqs: 8
+         include_pi: false
+         heads: 12
+         width: 768
+         num_encoder_layers: 8
+         num_decoder_layers: 16
+         use_ln_post: true
+         init_scale: 0.25
+         qkv_bias: false
+         use_checkpoint: true
+     aligned_module_cfg:
+       target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+       params:
+         clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+     loss_cfg:
+       target: michelangelo.models.tsal.loss.ContrastKLNearFar
+       params:
+         contrast_weight: 0.1
+         near_weight: 0.1
+         kl_weight: 0.001
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
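Editor's note: this config follows the target/params convention, where each block names a class and its constructor arguments. Below is a minimal sketch of how such a block is resolved, assuming OmegaConf is installed; it mirrors the instantiate_from_config helper that michelangelo.utils is expected to provide (it is imported by michelangelo/data/transforms.py later in this commit), not a verbatim copy of it.

import importlib
from omegaconf import OmegaConf

def get_obj_from_str(name: str):
    module, cls = name.rsplit(".", 1)
    return getattr(importlib.import_module(module), cls)

def instantiate_from_config(config, **extra_kwargs):
    # "target" names the class; "params" holds its constructor kwargs
    return get_obj_from_str(config["target"])(**config.get("params", dict()), **extra_kwargs)

cfg = OmegaConf.load("configs/aligned_shape_latents/shapevae-256.yaml")
shape_vae = instantiate_from_config(cfg.model)  # builds AlignedShapeAsLatentPLModule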
configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml ADDED
@@ -0,0 +1,181 @@
+ name: "0630_clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000"
+ #wandb:
+ #  project: "image_diffuser"
+ #  offline: false
+
+
+ training:
+   steps: 500000
+   use_amp: true
+   ckpt_path: ""
+   base_lr: 1.e-4
+   gradient_clip_val: 5.0
+   gradient_clip_algorithm: "norm"
+   every_n_train_steps: 5000
+   val_check_interval: 1024
+   limit_val_batches: 16
+
+ dataset:
+   target: michelangelo.data.asl_webdataset.MultiAlignedShapeLatentModule
+   params:
+     batch_size: 38
+     num_workers: 4
+     val_num_workers: 4
+     buffer_size: 256
+     return_normal: true
+     random_crop: false
+     surface_sampling: true
+     pc_size: &pc_size 4096
+     image_size: 384
+     mean: &mean [0.5, 0.5, 0.5]
+     std: &std [0.5, 0.5, 0.5]
+     cond_stage_key: "image"
+
+     meta_info:
+       3D-FUTURE:
+         render_folder: "/root/workspace/cq_workspace/datasets/3D-FUTURE/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/3D-FUTURE"
+
+       ABO:
+         render_folder: "/root/workspace/cq_workspace/datasets/ABO/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/ABO"
+
+       GSO:
+         render_folder: "/root/workspace/cq_workspace/datasets/GSO/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/GSO"
+
+       TOYS4K:
+         render_folder: "/root/workspace/cq_workspace/datasets/TOYS4K/TOYS4K/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/TOYS4K"
+
+       3DCaricShop:
+         render_folder: "/root/workspace/cq_workspace/datasets/3DCaricShop/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/3DCaricShop"
+
+       Thingi10K:
+         render_folder: "/root/workspace/cq_workspace/datasets/Thingi10K/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/Thingi10K"
+
+       shapenet:
+         render_folder: "/root/workspace/cq_workspace/datasets/shapenet/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/shapenet"
+
+       pokemon:
+         render_folder: "/root/workspace/cq_workspace/datasets/pokemon/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/pokemon"
+
+       objaverse:
+         render_folder: "/root/workspace/cq_workspace/datasets/objaverse/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/objaverse"
+
+ model:
+   target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
+   params:
+     first_stage_config:
+       target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+       params:
+         shape_module_cfg:
+           target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+           params:
+             num_latents: &num_latents 256
+             embed_dim: &embed_dim 64
+             point_feats: 3 # normal
+             num_freqs: 8
+             include_pi: false
+             heads: 12
+             width: 768
+             num_encoder_layers: 8
+             num_decoder_layers: 16
+             use_ln_post: true
+             init_scale: 0.25
+             qkv_bias: false
+             use_checkpoint: false
+         aligned_module_cfg:
+           target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+           params:
+             clip_model_version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
+             # clip_model_version: "/root/workspace/checkpoints/clip/clip-vit-large-patch14"
+
+         loss_cfg:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
+       params:
+         version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
+         # version: "/root/workspace/checkpoints/clip/clip-vit-large-patch14"
+         zero_embedding_radio: 0.1
+
+     first_stage_key: "surface"
+     cond_stage_key: "image"
+     scale_by_std: false
+
+     denoiser_cfg:
+       target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
+       params:
+         input_channels: *embed_dim
+         output_channels: *embed_dim
+         n_ctx: *num_latents
+         width: 768
+         layers: 6 # 2 * 6 + 1 = 13
+         heads: 12
+         context_dim: 1024
+         init_scale: 1.0
+         skip_ln: true
+         use_checkpoint: true
+
+     scheduler_cfg:
+       guidance_scale: 7.5
+       num_inference_steps: 50
+       eta: 0.0
+
+       noise:
+         target: diffusers.schedulers.DDPMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           variance_type: "fixed_small"
+           clip_sample: false
+       denoise:
+         target: diffusers.schedulers.DDIMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           clip_sample: false # clip sample to -1~1
+           set_alpha_to_one: false
+           steps_offset: 1
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
+
+     loss_cfg:
+       loss_type: "mse"
+
+ logger:
+   target: michelangelo.utils.trainings.mesh_log_callback.ImageConditionalASLDiffuserLogger
+   params:
+     step_frequency: 2000
+     num_samples: 4
+     sample_times: 4
+     mean: *mean
+     std: *std
+     bounds: [-1.1, -1.1, -1.1, 1.1, 1.1, 1.1]
+     octree_depth: 7
+     num_chunks: 10000
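Editor's note: two YAML anchors do real work in this file: *embed_dim ties the shape VAE's latent width to the denoiser's input/output channels, and *num_latents ties the number of shape latents to the denoiser's context length. A quick verification sketch, assuming OmegaConf (anchors are resolved by the YAML parser at load time, so the loaded values are plain):

from omegaconf import OmegaConf

cfg = OmegaConf.load(
    "configs/deploy/clip_aslp_3df+3dc+abo+gso+toy+t10k+obj+sp+pk"
    "=256_01_4096_8_ckpt_250000_udt=110M_finetune_500000_deploy.yaml"
)
shape = cfg.model.params.first_stage_config.params.shape_module_cfg.params
denoiser = cfg.model.params.denoiser_cfg.params
assert denoiser.input_channels == shape.embed_dim == 64
assert denoiser.n_ctx == shape.num_latents == 256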
configs/deploy/clip_sp+pk_aslperceiver=256_01_4096_8_udt=03.yaml ADDED
@@ -0,0 +1,180 @@
+ name: "0428_clip_subsp+pk_sal_perceiver=256_01_4096_8_udt=03"
+ #wandb:
+ #  project: "image_diffuser"
+ #  offline: false
+
+ training:
+   steps: 500000
+   use_amp: true
+   ckpt_path: ""
+   base_lr: 1.e-4
+   gradient_clip_val: 5.0
+   gradient_clip_algorithm: "norm"
+   every_n_train_steps: 5000
+   val_check_interval: 1024
+   limit_val_batches: 16
+
+ # dataset
+ dataset:
+   target: michelangelo.data.asl_torch_dataset.MultiAlignedShapeImageTextModule
+   params:
+     batch_size: 38
+     num_workers: 4
+     val_num_workers: 4
+     buffer_size: 256
+     return_normal: true
+     random_crop: false
+     surface_sampling: true
+     pc_size: &pc_size 4096
+     image_size: 384
+     mean: &mean [0.5, 0.5, 0.5]
+     std: &std [0.5, 0.5, 0.5]
+
+     cond_stage_key: "text"
+
+     meta_info:
+       3D-FUTURE:
+         render_folder: "/root/workspace/cq_workspace/datasets/3D-FUTURE/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/3D-FUTURE"
+
+       ABO:
+         render_folder: "/root/workspace/cq_workspace/datasets/ABO/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/ABO"
+
+       GSO:
+         render_folder: "/root/workspace/cq_workspace/datasets/GSO/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/GSO"
+
+       TOYS4K:
+         render_folder: "/root/workspace/cq_workspace/datasets/TOYS4K/TOYS4K/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/TOYS4K"
+
+       3DCaricShop:
+         render_folder: "/root/workspace/cq_workspace/datasets/3DCaricShop/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/3DCaricShop"
+
+       Thingi10K:
+         render_folder: "/root/workspace/cq_workspace/datasets/Thingi10K/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/Thingi10K"
+
+       shapenet:
+         render_folder: "/root/workspace/cq_workspace/datasets/shapenet/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/shapenet"
+
+       pokemon:
+         render_folder: "/root/workspace/cq_workspace/datasets/pokemon/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/pokemon"
+
+       objaverse:
+         render_folder: "/root/workspace/cq_workspace/datasets/objaverse/renders"
+         tar_folder: "/root/workspace/datasets/make_tars/objaverse"
+
+ model:
+   target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
+   params:
+     first_stage_config:
+       target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+       params:
+         # ckpt_path: "/root/workspace/cq_workspace/michelangelo/experiments/aligned_shape_latents/clip_aslperceiver_sp+pk_01_01/ckpt/ckpt-step=00230000.ckpt"
+         shape_module_cfg:
+           target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+           params:
+             num_latents: &num_latents 256
+             embed_dim: &embed_dim 64
+             point_feats: 3 # normal
+             num_freqs: 8
+             include_pi: false
+             heads: 12
+             width: 768
+             num_encoder_layers: 8
+             num_decoder_layers: 16
+             use_ln_post: true
+             init_scale: 0.25
+             qkv_bias: false
+             use_checkpoint: true
+         aligned_module_cfg:
+           target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+           params:
+             clip_model_version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
+
+         loss_cfg:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
+       params:
+         version: "/mnt/shadow_cv_training/stevenxxliu/checkpoints/clip/clip-vit-large-patch14"
+         zero_embedding_radio: 0.1
+         max_length: 77
+
+     first_stage_key: "surface"
+     cond_stage_key: "text"
+     scale_by_std: false
+
+     denoiser_cfg:
+       target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
+       params:
+         input_channels: *embed_dim
+         output_channels: *embed_dim
+         n_ctx: *num_latents
+         width: 768
+         layers: 8 # 2 * 8 + 1 = 17
+         heads: 12
+         context_dim: 768
+         init_scale: 1.0
+         skip_ln: true
+         use_checkpoint: true
+
+     scheduler_cfg:
+       guidance_scale: 7.5
+       num_inference_steps: 50
+       eta: 0.0
+
+       noise:
+         target: diffusers.schedulers.DDPMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           variance_type: "fixed_small"
+           clip_sample: false
+       denoise:
+         target: diffusers.schedulers.DDIMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           clip_sample: false # clip sample to -1~1
+           set_alpha_to_one: false
+           steps_offset: 1
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
+
+     loss_cfg:
+       loss_type: "mse"
+
+ logger:
+   target: michelangelo.utils.trainings.mesh_log_callback.TextConditionalASLDiffuserLogger
+   params:
+     step_frequency: 1000
+     num_samples: 4
+     sample_times: 4
+     bounds: [-1.1, -1.1, -1.1, 1.1, 1.1, 1.1]
+     octree_depth: 7
+     num_chunks: 10000
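Editor's note: both deploy configs pair zero_embedding_radio: 0.1 (conditions zeroed out for roughly 10% of training samples) with guidance_scale: 7.5 at inference, i.e. classifier-free guidance. A sketch of the guided prediction follows; model and cond are hypothetical names and this is the standard CFG formula, not code from this repo:

import torch

def guided_noise_pred(model, noisy_z, t, cond, guidance_scale=7.5):
    uncond = torch.zeros_like(cond)  # matches the zero-embedding used during training
    eps_cond = model(noisy_z, t, cond)
    eps_uncond = model(noisy_z, t, uncond)
    # push the prediction away from the unconditional direction
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)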
configs/image_cond_diffuser_asl/image-ASLDM-256.yaml ADDED
@@ -0,0 +1,97 @@
+ model:
+   target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
+   params:
+     first_stage_config:
+       target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+       params:
+         shape_module_cfg:
+           target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+           params:
+             num_latents: &num_latents 256
+             embed_dim: &embed_dim 64
+             point_feats: 3 # normal
+             num_freqs: 8
+             include_pi: false
+             heads: 12
+             width: 768
+             num_encoder_layers: 8
+             num_decoder_layers: 16
+             use_ln_post: true
+             init_scale: 0.25
+             qkv_bias: false
+             use_checkpoint: false
+         aligned_module_cfg:
+           target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+           params:
+             clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+         loss_cfg:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: michelangelo.models.conditional_encoders.encoder_factory.FrozenCLIPImageGridEmbedder
+       params:
+         version: "./checkpoints/clip/clip-vit-large-patch14"
+         zero_embedding_radio: 0.1
+
+     first_stage_key: "surface"
+     cond_stage_key: "image"
+     scale_by_std: false
+
+     denoiser_cfg:
+       target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
+       params:
+         input_channels: *embed_dim
+         output_channels: *embed_dim
+         n_ctx: *num_latents
+         width: 768
+         layers: 6 # 2 * 6 + 1 = 13
+         heads: 12
+         context_dim: 1024
+         init_scale: 1.0
+         skip_ln: true
+         use_checkpoint: true
+
+     scheduler_cfg:
+       guidance_scale: 7.5
+       num_inference_steps: 50
+       eta: 0.0
+
+       noise:
+         target: diffusers.schedulers.DDPMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           variance_type: "fixed_small"
+           clip_sample: false
+       denoise:
+         target: diffusers.schedulers.DDIMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           clip_sample: false # clip sample to -1~1
+           set_alpha_to_one: false
+           steps_offset: 1
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
+
+     loss_cfg:
+       loss_type: "mse"
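Editor's note: scheduler_cfg separates training noise (DDPM) from inference denoising (DDIM, 50 steps, eta 0). A minimal sampling sketch using the diffusers API with exactly the denoise parameters above; denoiser and cond are placeholders for ConditionalASLUDTDenoiser and the CLIP image-grid embedding, and the latent shape follows num_latents/embed_dim:

import torch
from diffusers.schedulers import DDIMScheduler

scheduler = DDIMScheduler(
    num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
    beta_schedule="scaled_linear", clip_sample=False,
    set_alpha_to_one=False, steps_offset=1,
)
scheduler.set_timesteps(50)

latents = torch.randn(1, 256, 64)  # [batch, n_ctx = num_latents, embed_dim]
for t in scheduler.timesteps:
    noise_pred = denoiser(latents, t.expand(1), cond)  # placeholder call
    latents = scheduler.step(noise_pred, t, latents, eta=0.0).prev_sample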
configs/text_cond_diffuser_asl/text-ASLDM-256.yaml ADDED
@@ -0,0 +1,98 @@
+ model:
+   target: michelangelo.models.asl_diffusion.clip_asl_diffuser_pl_module.ClipASLDiffuser
+   params:
+     first_stage_config:
+       target: michelangelo.models.tsal.asl_pl_module.AlignedShapeAsLatentPLModule
+       params:
+         shape_module_cfg:
+           target: michelangelo.models.tsal.sal_perceiver.AlignedShapeLatentPerceiver
+           params:
+             num_latents: &num_latents 256
+             embed_dim: &embed_dim 64
+             point_feats: 3 # normal
+             num_freqs: 8
+             include_pi: false
+             heads: 12
+             width: 768
+             num_encoder_layers: 8
+             num_decoder_layers: 16
+             use_ln_post: true
+             init_scale: 0.25
+             qkv_bias: false
+             use_checkpoint: true
+         aligned_module_cfg:
+           target: michelangelo.models.tsal.clip_asl_module.CLIPAlignedShapeAsLatentModule
+           params:
+             clip_model_version: "./checkpoints/clip/clip-vit-large-patch14"
+
+         loss_cfg:
+           target: torch.nn.Identity
+
+     cond_stage_config:
+       target: michelangelo.models.conditional_encoders.encoder_factory.FrozenAlignedCLIPTextEmbedder
+       params:
+         version: "./checkpoints/clip/clip-vit-large-patch14"
+         zero_embedding_radio: 0.1
+         max_length: 77
+
+     first_stage_key: "surface"
+     cond_stage_key: "text"
+     scale_by_std: false
+
+     denoiser_cfg:
+       target: michelangelo.models.asl_diffusion.asl_udt.ConditionalASLUDTDenoiser
+       params:
+         input_channels: *embed_dim
+         output_channels: *embed_dim
+         n_ctx: *num_latents
+         width: 768
+         layers: 8 # 2 * 8 + 1 = 17
+         heads: 12
+         context_dim: 768
+         init_scale: 1.0
+         skip_ln: true
+         use_checkpoint: true
+
+     scheduler_cfg:
+       guidance_scale: 7.5
+       num_inference_steps: 50
+       eta: 0.0
+
+       noise:
+         target: diffusers.schedulers.DDPMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           variance_type: "fixed_small"
+           clip_sample: false
+       denoise:
+         target: diffusers.schedulers.DDIMScheduler
+         params:
+           num_train_timesteps: 1000
+           beta_start: 0.00085
+           beta_end: 0.012
+           beta_schedule: "scaled_linear"
+           clip_sample: false # clip sample to -1~1
+           set_alpha_to_one: false
+           steps_offset: 1
+
+     optimizer_cfg:
+       optimizer:
+         target: torch.optim.AdamW
+         params:
+           betas: [0.9, 0.99]
+           eps: 1.e-6
+           weight_decay: 1.e-2
+
+       scheduler:
+         target: michelangelo.utils.trainings.lr_scheduler.LambdaWarmUpCosineFactorScheduler
+         params:
+           warm_up_steps: 5000
+           f_start: 1.e-6
+           f_min: 1.e-3
+           f_max: 1.0
+
+     loss_cfg:
+       loss_type: "mse"
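Editor's note: all of these configs share LambdaWarmUpCosineFactorScheduler. The class itself lives in michelangelo.utils.trainings.lr_scheduler, which is not part of this commit, so the following is an assumed reimplementation of the warm-up-then-cosine factor schedule its parameters imply: linear warm-up from f_start to f_max over warm_up_steps, then cosine decay toward f_min; the factor multiplies the base learning rate (base_lr).

import math

def lr_factor(step, warm_up_steps=5000, f_start=1e-6, f_min=1e-3, f_max=1.0,
              max_decay_steps=500_000):
    # assumed reimplementation, not the repo's class
    if step < warm_up_steps:
        # linear warm-up
        return f_start + (f_max - f_start) * step / warm_up_steps
    # cosine decay from f_max down to f_min
    progress = (step - warm_up_steps) / max(1, max_decay_steps - warm_up_steps)
    return f_min + 0.5 * (f_max - f_min) * (1.0 + math.cos(math.pi * min(progress, 1.0)))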
example_data/image/car.jpg ADDED
example_data/surface/surface.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0893e44d82ada683baa656a718beaf6ec19fc28b6816b451f56645530d5bb962
+ size 1201024
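Editor's note: what is committed here is a Git LFS pointer, not the array data itself; after `git lfs pull`, the roughly 1.2 MB archive can be read with NumPy. The key names below are an assumption based on the transforms in this commit, which expect fields such as "surface":

import numpy as np

data = np.load("example_data/surface/surface.npz")
print(data.files)                # inspect the stored array names
surface = data[data.files[0]]    # e.g. a point cloud of shape [N, 3 + point_feats]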
gradio_cached_dir/example/img_example/airplane.jpg ADDED
gradio_cached_dir/example/img_example/alita.jpg ADDED
gradio_cached_dir/example/img_example/bag.jpg ADDED
gradio_cached_dir/example/img_example/bench.jpg ADDED
gradio_cached_dir/example/img_example/building.jpg ADDED
gradio_cached_dir/example/img_example/burger.jpg ADDED
gradio_cached_dir/example/img_example/car.jpg ADDED
gradio_cached_dir/example/img_example/loopy.jpg ADDED
gradio_cached_dir/example/img_example/mario.jpg ADDED
gradio_cached_dir/example/img_example/ship.jpg ADDED
michelangelo/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (176 Bytes).
michelangelo/data/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/data/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (181 Bytes).
michelangelo/data/__pycache__/asl_webdataset.cpython-39.pyc ADDED
Binary file (9.43 kB).
michelangelo/data/__pycache__/tokenizer.cpython-39.pyc ADDED
Binary file (6.48 kB).
michelangelo/data/__pycache__/transforms.cpython-39.pyc ADDED
Binary file (11.4 kB).
michelangelo/data/__pycache__/utils.cpython-39.pyc ADDED
Binary file (1.13 kB).
michelangelo/data/templates.json ADDED
@@ -0,0 +1,69 @@
+ {
+   "shape": [
+     "a point cloud model of {}.",
+     "There is a {} in the scene.",
+     "There is the {} in the scene.",
+     "a photo of a {} in the scene.",
+     "a photo of the {} in the scene.",
+     "a photo of one {} in the scene.",
+     "itap of a {}.",
+     "itap of my {}.",
+     "itap of the {}.",
+     "a photo of a {}.",
+     "a photo of my {}.",
+     "a photo of the {}.",
+     "a photo of one {}.",
+     "a photo of many {}.",
+     "a good photo of a {}.",
+     "a good photo of the {}.",
+     "a bad photo of a {}.",
+     "a bad photo of the {}.",
+     "a photo of a nice {}.",
+     "a photo of the nice {}.",
+     "a photo of a cool {}.",
+     "a photo of the cool {}.",
+     "a photo of a weird {}.",
+     "a photo of the weird {}.",
+     "a photo of a small {}.",
+     "a photo of the small {}.",
+     "a photo of a large {}.",
+     "a photo of the large {}.",
+     "a photo of a clean {}.",
+     "a photo of the clean {}.",
+     "a photo of a dirty {}.",
+     "a photo of the dirty {}.",
+     "a bright photo of a {}.",
+     "a bright photo of the {}.",
+     "a dark photo of a {}.",
+     "a dark photo of the {}.",
+     "a photo of a hard to see {}.",
+     "a photo of the hard to see {}.",
+     "a low resolution photo of a {}.",
+     "a low resolution photo of the {}.",
+     "a cropped photo of a {}.",
+     "a cropped photo of the {}.",
+     "a close-up photo of a {}.",
+     "a close-up photo of the {}.",
+     "a jpeg corrupted photo of a {}.",
+     "a jpeg corrupted photo of the {}.",
+     "a blurry photo of a {}.",
+     "a blurry photo of the {}.",
+     "a pixelated photo of a {}.",
+     "a pixelated photo of the {}.",
+     "a black and white photo of the {}.",
+     "a black and white photo of a {}",
+     "a plastic {}.",
+     "the plastic {}.",
+     "a toy {}.",
+     "the toy {}.",
+     "a plushie {}.",
+     "the plushie {}.",
+     "a cartoon {}.",
+     "the cartoon {}.",
+     "an embroidered {}.",
+     "the embroidered {}.",
+     "a painting of the {}.",
+     "a painting of a {}."
+   ]
+
+ }
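Editor's note: these are CLIP-style prompt templates with a {} slot for the object category. A usage sketch; "chair" is just an example label, and the repo's tokenizer wiring is not shown in this commit:

import json

with open("michelangelo/data/templates.json") as f:
    templates = json.load(f)["shape"]

prompts = [t.format("chair") for t in templates]
print(prompts[0])  # "a point cloud model of chair."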
michelangelo/data/transforms.py ADDED
@@ -0,0 +1,407 @@
+ # -*- coding: utf-8 -*-
+ import os
+ import time
+ import numpy as np
+ import warnings
+ import random
+ from omegaconf.listconfig import ListConfig
+ from webdataset import pipelinefilter
+ import torch
+ import torchvision.transforms.functional as TVF
+ from torchvision.transforms import InterpolationMode
+ from torchvision.transforms.transforms import _interpolation_modes_from_int
+ from typing import Sequence
+
+ from michelangelo.utils import instantiate_from_config
+
+
+ def _uid_buffer_pick(buf_dict, rng):
+     uid_keys = list(buf_dict.keys())
+     selected_uid = rng.choice(uid_keys)
+     buf = buf_dict[selected_uid]
+
+     k = rng.randint(0, len(buf) - 1)
+     sample = buf[k]
+     buf[k] = buf[-1]
+     buf.pop()
+
+     if len(buf) == 0:
+         del buf_dict[selected_uid]
+
+     return sample
+
+
+ def _add_to_buf_dict(buf_dict, sample):
+     key = sample["__key__"]
+     uid, uid_sample_id = key.split("_")
+     if uid not in buf_dict:
+         buf_dict[uid] = []
+     buf_dict[uid].append(sample)
+
+     return buf_dict
+
+
+ def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
+     """Shuffle the data in the stream.
+
+     This uses a buffer of size `bufsize`. Shuffling at
+     startup is less random; this is traded off against
+     yielding samples quickly.
+
+     data: iterator
+     bufsize: buffer size for shuffling
+     returns: iterator
+     rng: either random module or random.Random instance
+
+     """
+     if rng is None:
+         rng = random.Random(int((os.getpid() + time.time()) * 1e9))
+     initial = min(initial, bufsize)
+     buf_dict = dict()
+     current_samples = 0
+     for sample in data:
+         _add_to_buf_dict(buf_dict, sample)
+         current_samples += 1
+
+         if current_samples < bufsize:
+             try:
+                 _add_to_buf_dict(buf_dict, next(data))  # skipcq: PYL-R1708
+                 current_samples += 1
+             except StopIteration:
+                 pass
+
+         if current_samples >= initial:
+             current_samples -= 1
+             yield _uid_buffer_pick(buf_dict, rng)
+
+     while current_samples > 0:
+         current_samples -= 1
+         yield _uid_buffer_pick(buf_dict, rng)
+
+
+ uid_shuffle = pipelinefilter(_uid_shuffle)
+
+
+ class RandomSample(object):
+     def __init__(self,
+                  num_volume_samples: int = 1024,
+                  num_near_samples: int = 1024):
+
+         super().__init__()
+
+         self.num_volume_samples = num_volume_samples
+         self.num_near_samples = num_near_samples
+
+     def __call__(self, sample):
+         rng = np.random.default_rng()
+
+         # 1. sample surface input
+         total_surface = sample["surface"]
+         ind = rng.choice(total_surface.shape[0], replace=False)
+         surface = total_surface[ind]
+
+         # 2. sample volume/near geometric points
+         vol_points = sample["vol_points"]
+         vol_label = sample["vol_label"]
+         near_points = sample["near_points"]
+         near_label = sample["near_label"]
+
+         ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
+         vol_points = vol_points[ind]
+         vol_label = vol_label[ind]
+         vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
+
+         ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
+         near_points = near_points[ind]
+         near_label = near_label[ind]
+         near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
+
+         # concat sampled volume and near points
+         geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
+
+         sample = {
+             "surface": surface,
+             "geo_points": geo_points
+         }
+
+         return sample
+
+
+ class SplitRandomSample(object):
+     def __init__(self,
+                  use_surface_sample: bool = False,
+                  num_surface_samples: int = 4096,
+                  num_volume_samples: int = 1024,
+                  num_near_samples: int = 1024):
+
+         super().__init__()
+
+         self.use_surface_sample = use_surface_sample
+         self.num_surface_samples = num_surface_samples
+         self.num_volume_samples = num_volume_samples
+         self.num_near_samples = num_near_samples
+
+     def __call__(self, sample):
+
+         rng = np.random.default_rng()
+
+         # 1. sample surface input
+         surface = sample["surface"]
+
+         if self.use_surface_sample:
+             replace = surface.shape[0] < self.num_surface_samples
+             ind = rng.choice(surface.shape[0], self.num_surface_samples, replace=replace)
+             surface = surface[ind]
+
+         # 2. sample volume/near geometric points
+         vol_points = sample["vol_points"]
+         vol_label = sample["vol_label"]
+         near_points = sample["near_points"]
+         near_label = sample["near_label"]
+
+         ind = rng.choice(vol_points.shape[0], self.num_volume_samples, replace=False)
+         vol_points = vol_points[ind]
+         vol_label = vol_label[ind]
+         vol_points_labels = np.concatenate([vol_points, vol_label[:, np.newaxis]], axis=1)
+
+         ind = rng.choice(near_points.shape[0], self.num_near_samples, replace=False)
+         near_points = near_points[ind]
+         near_label = near_label[ind]
+         near_points_labels = np.concatenate([near_points, near_label[:, np.newaxis]], axis=1)
+
+         # concat sampled volume and near points
+         geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)
+
+         sample = {
+             "surface": surface,
+             "geo_points": geo_points
+         }
+
+         return sample
+
+
+ class FeatureSelection(object):
+
+     VALID_SURFACE_FEATURE_DIMS = {
+         "none": [0, 1, 2],  # xyz
+         "watertight_normal": [0, 1, 2, 3, 4, 5],  # xyz, normal
+         "normal": [0, 1, 2, 6, 7, 8]
+     }
+
+     def __init__(self, surface_feature_type: str):
+
+         self.surface_feature_type = surface_feature_type
+         self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]
+
+     def __call__(self, sample):
+         sample["surface"] = sample["surface"][:, self.surface_dims]
+         return sample
+
+
+ class AxisScaleTransform(object):
+     def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
+         assert isinstance(interval, (tuple, list, ListConfig))
+         self.interval = interval
+         self.min_val = interval[0]
+         self.max_val = interval[1]
+         self.inter_size = interval[1] - interval[0]
+         self.jitter = jitter
+         self.jitter_scale = jitter_scale
+
+     def __call__(self, sample):
+
+         surface = sample["surface"][..., 0:3]
+         geo_points = sample["geo_points"][..., 0:3]
+
+         scaling = torch.rand(1, 3) * self.inter_size + self.min_val
+         # print(scaling)
+         surface = surface * scaling
+         geo_points = geo_points * scaling
+
+         scale = (1 / torch.abs(surface).max().item()) * 0.999999
+         surface *= scale
+         geo_points *= scale
+
+         if self.jitter:
+             surface += self.jitter_scale * torch.randn_like(surface)
+             surface.clamp_(min=-1.015, max=1.015)
+
+         sample["surface"][..., 0:3] = surface
+         sample["geo_points"][..., 0:3] = geo_points
+
+         return sample
+
+
+ class ToTensor(object):
+
+     def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
+         self.tensor_keys = tensor_keys
+
+     def __call__(self, sample):
+         for key in self.tensor_keys:
+             if key not in sample:
+                 continue
+
+             sample[key] = torch.tensor(sample[key], dtype=torch.float32)
+
+         return sample
+
+
+ class AxisScale(object):
+     def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
+         assert isinstance(interval, (tuple, list, ListConfig))
+         self.interval = interval
+         self.jitter = jitter
+         self.jitter_scale = jitter_scale
+
+     def __call__(self, surface, *args):
+         scaling = torch.rand(1, 3) * 0.5 + 0.75
+         # print(scaling)
+         surface = surface * scaling
+         scale = (1 / torch.abs(surface).max().item()) * 0.999999
+         surface *= scale
+
+         args_outputs = []
+         for _arg in args:
+             _arg = _arg * scaling * scale
+             args_outputs.append(_arg)
+
+         if self.jitter:
+             surface += self.jitter_scale * torch.randn_like(surface)
+             surface.clamp_(min=-1, max=1)
+
+         if len(args) == 0:
+             return surface
+         else:
+             return surface, *args_outputs
+
+
+ class RandomResize(torch.nn.Module):
+     """Apply randomly Resize with a given probability."""
+
+     def __init__(
+             self,
+             size,
+             resize_radio=(0.5, 1),
+             allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
+             interpolation=InterpolationMode.BICUBIC,
+             max_size=None,
+             antialias=None,
+     ):
+         super().__init__()
+         if not isinstance(size, (int, Sequence)):
+             raise TypeError(f"Size should be int or sequence. Got {type(size)}")
+         if isinstance(size, Sequence) and len(size) not in (1, 2):
+             raise ValueError("If size is a sequence, it should have 1 or 2 values")
+
+         self.size = size
+         self.max_size = max_size
+         # Backward compatibility with integer value
+         if isinstance(interpolation, int):
+             warnings.warn(
+                 "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
+                 "Please use InterpolationMode enum."
+             )
+             interpolation = _interpolation_modes_from_int(interpolation)
+
+         self.interpolation = interpolation
+         self.antialias = antialias
+
+         self.resize_radio = resize_radio
+         self.allow_resize_interpolations = allow_resize_interpolations
+
+     def random_resize_params(self):
+         radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]
+
+         if isinstance(self.size, int):
+             size = int(self.size * radio)
+         elif isinstance(self.size, Sequence):
+             size = list(self.size)
+             size = (int(size[0] * radio), int(size[1] * radio))
+         else:
+             raise RuntimeError()
+
+         interpolation = self.allow_resize_interpolations[
+             torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
+         ]
+         return size, interpolation
+
+     def forward(self, img):
+         size, interpolation = self.random_resize_params()
+         img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
+         img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
+         return img
+
+     def __repr__(self) -> str:
+         detail = f"(size={self.size}, interpolation={self.interpolation.value},"
+         detail += f"max_size={self.max_size}, antialias={self.antialias}), resize_radio={self.resize_radio}"
+         return f"{self.__class__.__name__}{detail}"
+
+
+ class Compose(object):
+     """Composes several transforms together. This transform does not support torchscript.
+     Please, see the note below.
+
+     Args:
+         transforms (list of ``Transform`` objects): list of transforms to compose.
+
+     Example:
+         >>> transforms.Compose([
+         >>>     transforms.CenterCrop(10),
+         >>>     transforms.ToTensor(),
+         >>> ])
+
+     .. note::
+         In order to script the transformations, please use ``torch.nn.Sequential`` as below.
+
+         >>> transforms = torch.nn.Sequential(
+         >>>     transforms.CenterCrop(10),
+         >>>     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+         >>> )
+         >>> scripted_transforms = torch.jit.script(transforms)
+
+     Make sure to use only scriptable transformations, i.e. that work with ``torch.Tensor``, does not require
+     `lambda` functions or ``PIL.Image``.
+
+     """
+
+     def __init__(self, transforms):
+         self.transforms = transforms
+
+     def __call__(self, *args):
+         for t in self.transforms:
+             args = t(*args)
+         return args
+
+     def __repr__(self):
+         format_string = self.__class__.__name__ + '('
+         for t in self.transforms:
+             format_string += '\n'
+             format_string += '    {0}'.format(t)
+         format_string += '\n)'
+         return format_string
+
+
+ def identity(*args, **kwargs):
+     if len(args) == 1:
+         return args[0]
+     else:
+         return args
+
+
+ def build_transforms(cfg):
+
+     if cfg is None:
+         return identity
+
+     transforms = []
+
+     for transform_name, cfg_instance in cfg.items():
+         transform_instance = instantiate_from_config(cfg_instance)
+         transforms.append(transform_instance)
+         print(f"Build transform: {transform_instance}")
+
+     transforms = Compose(transforms)
+
+     return transforms
+
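Editor's note: a usage sketch for the sampling transforms above on a synthetic sample; the array shapes are illustrative, not the dataset's actual sizes. Note that the Compose defined here chains positional arguments (t(*args)), which suits tensor-in/tensor-out transforms such as AxisScale; the dict-based transforms are simply applied in sequence.

import numpy as np

sample = {
    "surface": np.random.rand(4096, 6).astype(np.float32),  # xyz + normal
    "vol_points": np.random.rand(2048, 3).astype(np.float32),
    "vol_label": np.random.randint(0, 2, 2048).astype(np.float32),
    "near_points": np.random.rand(2048, 3).astype(np.float32),
    "near_label": np.random.randint(0, 2, 2048).astype(np.float32),
}

sampler = SplitRandomSample(use_surface_sample=True, num_surface_samples=1024,
                            num_volume_samples=512, num_near_samples=512)
sample = ToTensor()(sampler(sample))
print(sample["surface"].shape, sample["geo_points"].shape)
# torch.Size([1024, 6]) torch.Size([1024, 4])  (geo_points = xyz + occupancy label)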
michelangelo/data/utils.py ADDED
@@ -0,0 +1,59 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ import numpy as np
+
+
+ def worker_init_fn(_):
+     worker_info = torch.utils.data.get_worker_info()
+     worker_id = worker_info.id
+
+     # dataset = worker_info.dataset
+     # split_size = dataset.num_records // worker_info.num_workers
+     # # reset num_records to the true number to retain reliable length information
+     # dataset.sample_ids = dataset.valid_ids[worker_id * split_size:(worker_id + 1) * split_size]
+     # current_id = np.random.choice(len(np.random.get_state()[1]), 1)
+     # return np.random.seed(np.random.get_state()[1][current_id] + worker_id)
+
+     return np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+
+ def collation_fn(samples, combine_tensors=True, combine_scalars=True):
+     """
+
+     Args:
+         samples (list[dict]):
+         combine_tensors:
+         combine_scalars:
+
+     Returns:
+
+     """
+
+     result = {}
+
+     keys = samples[0].keys()
+
+     for key in keys:
+         result[key] = []
+
+     for sample in samples:
+         for key in keys:
+             val = sample[key]
+             result[key].append(val)
+
+     for key in keys:
+         val_list = result[key]
+         if isinstance(val_list[0], (int, float)):
+             if combine_scalars:
+                 result[key] = np.array(result[key])
+
+         elif isinstance(val_list[0], torch.Tensor):
+             if combine_tensors:
+                 result[key] = torch.stack(val_list)
+
+         elif isinstance(val_list[0], np.ndarray):
+             if combine_tensors:
+                 result[key] = np.stack(val_list)
+
+     return result
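Editor's note: a wiring sketch for these helpers; `dataset` is a placeholder for any dataset yielding dicts of same-shaped arrays or tensors, and the batch size matches the configs above.

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=38, num_workers=4,
                    collate_fn=collation_fn, worker_init_fn=worker_init_fn)
batch = next(iter(loader))  # dict of stacked tensors/arrays keyed like the samples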
michelangelo/graphics/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/graphics/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (185 Bytes).
michelangelo/graphics/primitives/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # -*- coding: utf-8 -*-
+
+ from .volume import generate_dense_grid_points
+
+ from .mesh import (
+     MeshOutput,
+     save_obj,
+     savemeshtes2
+ )
michelangelo/graphics/primitives/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (334 Bytes).
michelangelo/graphics/primitives/__pycache__/extract_texture_map.cpython-39.pyc ADDED
Binary file (2.46 kB).
michelangelo/graphics/primitives/__pycache__/mesh.cpython-39.pyc ADDED
Binary file (2.93 kB).
michelangelo/graphics/primitives/__pycache__/volume.cpython-39.pyc ADDED
Binary file (860 Bytes).
michelangelo/graphics/primitives/mesh.py ADDED
@@ -0,0 +1,114 @@
+ # -*- coding: utf-8 -*-
+
+ import os
+ import cv2
+ import numpy as np
+ import PIL.Image
+ from typing import Optional
+
+ import trimesh
+
+
+ def save_obj(pointnp_px3, facenp_fx3, fname):
+     fid = open(fname, "w")
+     write_str = ""
+     for pidx, p in enumerate(pointnp_px3):
+         pp = p
+         write_str += "v %f %f %f\n" % (pp[0], pp[1], pp[2])
+
+     for i, f in enumerate(facenp_fx3):
+         f1 = f + 1
+         write_str += "f %d %d %d\n" % (f1[0], f1[1], f1[2])
+     fid.write(write_str)
+     fid.close()
+     return
+
+
+ def savemeshtes2(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, tex_map, fname):
+     fol, na = os.path.split(fname)
+     na, _ = os.path.splitext(na)
+
+     matname = "%s/%s.mtl" % (fol, na)
+     fid = open(matname, "w")
+     fid.write("newmtl material_0\n")
+     fid.write("Kd 1 1 1\n")
+     fid.write("Ka 0 0 0\n")
+     fid.write("Ks 0.4 0.4 0.4\n")
+     fid.write("Ns 10\n")
+     fid.write("illum 2\n")
+     fid.write("map_Kd %s.png\n" % na)
+     fid.close()
+     ####
+
+     fid = open(fname, "w")
+     fid.write("mtllib %s.mtl\n" % na)
+
+     for pidx, p in enumerate(pointnp_px3):
+         pp = p
+         fid.write("v %f %f %f\n" % (pp[0], pp[1], pp[2]))
+
+     for pidx, p in enumerate(tcoords_px2):
+         pp = p
+         fid.write("vt %f %f\n" % (pp[0], pp[1]))
+
+     fid.write("usemtl material_0\n")
+     for i, f in enumerate(facenp_fx3):
+         f1 = f + 1
+         f2 = facetex_fx3[i] + 1
+         fid.write("f %d/%d %d/%d %d/%d\n" % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2]))
+     fid.close()
+
+     PIL.Image.fromarray(np.ascontiguousarray(tex_map), "RGB").save(
+         os.path.join(fol, "%s.png" % na))
+
+     return
+
+
+ class MeshOutput(object):
+
+     def __init__(self,
+                  mesh_v: np.ndarray,
+                  mesh_f: np.ndarray,
+                  vertex_colors: Optional[np.ndarray] = None,
+                  uvs: Optional[np.ndarray] = None,
+                  mesh_tex_idx: Optional[np.ndarray] = None,
+                  tex_map: Optional[np.ndarray] = None):
+
+         self.mesh_v = mesh_v
+         self.mesh_f = mesh_f
+         self.vertex_colors = vertex_colors
+         self.uvs = uvs
+         self.mesh_tex_idx = mesh_tex_idx
+         self.tex_map = tex_map
+
+     def contain_uv_texture(self):
+         return (self.uvs is not None) and (self.mesh_tex_idx is not None) and (self.tex_map is not None)
+
+     def contain_vertex_colors(self):
+         return self.vertex_colors is not None
+
+     def export(self, fname):
+
+         if self.contain_uv_texture():
+             savemeshtes2(
+                 self.mesh_v,
+                 self.uvs,
+                 self.mesh_f,
+                 self.mesh_tex_idx,
+                 self.tex_map,
+                 fname
+             )
+
+         elif self.contain_vertex_colors():
+             mesh_obj = trimesh.Trimesh(vertices=self.mesh_v, faces=self.mesh_f, vertex_colors=self.vertex_colors)
+             mesh_obj.export(fname)
+
+         else:
+             save_obj(
+                 self.mesh_v,
+                 self.mesh_f,
+                 fname
+             )
+
+
+
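Editor's note: a usage sketch exporting a unit tetrahedron through MeshOutput. With neither vertex colors nor a texture set, export() falls through to the plain save_obj path; the filename is arbitrary.

import numpy as np

verts = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
faces = np.array([[0, 1, 2], [0, 1, 3], [0, 2, 3], [1, 2, 3]], dtype=np.int64)
MeshOutput(mesh_v=verts, mesh_f=faces).export("tetrahedron.obj")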
michelangelo/graphics/primitives/volume.py ADDED
@@ -0,0 +1,21 @@
+ # -*- coding: utf-8 -*-
+
+ import numpy as np
+
+
+ def generate_dense_grid_points(bbox_min: np.ndarray,
+                                bbox_max: np.ndarray,
+                                octree_depth: int,
+                                indexing: str = "ij"):
+     length = bbox_max - bbox_min
+     num_cells = np.exp2(octree_depth)
+     x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32)
+     y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32)
+     z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32)
+     [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
+     xyz = np.stack((xs, ys, zs), axis=-1)
+     xyz = xyz.reshape(-1, 3)
+     grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1]
+
+     return xyz, grid_size, length
+
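Editor's note: a usage sketch matching the logger settings in the configs above (bounds of +/-1.1, octree_depth 7): 2**7 = 128 cells per axis yields a 129^3 grid of query points for mesh extraction.

import numpy as np

bbox_min = np.array([-1.1, -1.1, -1.1])
bbox_max = np.array([1.1, 1.1, 1.1])
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)
print(xyz.shape, grid_size)  # (2146689, 3) [129, 129, 129]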
michelangelo/models/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/models/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (183 Bytes).
michelangelo/models/asl_diffusion/__init__.py ADDED
@@ -0,0 +1 @@
+ # -*- coding: utf-8 -*-
michelangelo/models/asl_diffusion/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (197 Bytes).
michelangelo/models/asl_diffusion/__pycache__/asl_udt.cpython-39.pyc ADDED
Binary file (2.64 kB).
michelangelo/models/asl_diffusion/__pycache__/clip_asl_diffuser_pl_module.cpython-39.pyc ADDED
Binary file (9.87 kB).
michelangelo/models/asl_diffusion/__pycache__/inference_utils.cpython-39.pyc ADDED
Binary file (1.75 kB).
michelangelo/models/asl_diffusion/asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from omegaconf import DictConfig
4
+ from typing import List, Tuple, Dict, Optional, Union
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from torch.optim import lr_scheduler
10
+ import pytorch_lightning as pl
11
+ from pytorch_lightning.utilities import rank_zero_only
12
+
13
+ from einops import rearrange
14
+
15
+ from diffusers.schedulers import (
16
+ DDPMScheduler,
17
+ DDIMScheduler,
18
+ KarrasVeScheduler,
19
+ DPMSolverMultistepScheduler
20
+ )
21
+
22
+ from michelangelo.utils import instantiate_from_config
23
+ # from michelangelo.models.tsal.tsal_base import ShapeAsLatentPLModule
24
+ from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
25
+ from michelangelo.models.asl_diffusion.inference_utils import ddim_sample
26
+
27
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
28
+
29
+
30
+ def disabled_train(self, mode=True):
31
+ """Overwrite model.train with this function to make sure train/eval mode
32
+ does not change anymore."""
33
+ return self
34
+
35
+
36
+ class ASLDiffuser(pl.LightningModule):
37
+ first_stage_model: Optional[AlignedShapeAsLatentPLModule]
38
+ # cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
39
+ model: nn.Module
40
+
41
+ def __init__(self, *,
42
+ first_stage_config,
43
+ denoiser_cfg,
44
+ scheduler_cfg,
45
+ optimizer_cfg,
46
+ loss_cfg,
47
+ first_stage_key: str = "surface",
48
+ cond_stage_key: str = "image",
49
+ cond_stage_trainable: bool = True,
50
+ scale_by_std: bool = False,
51
+ z_scale_factor: float = 1.0,
52
+ ckpt_path: Optional[str] = None,
53
+ ignore_keys: Union[Tuple[str], List[str]] = ()):
54
+
55
+ super().__init__()
56
+
57
+ self.first_stage_key = first_stage_key
58
+ self.cond_stage_key = cond_stage_key
59
+ self.cond_stage_trainable = cond_stage_trainable
60
+
61
+ # 1. initialize first stage.
62
+ # Note: the condition model contained in the first stage model.
63
+ self.first_stage_config = first_stage_config
64
+ self.first_stage_model = None
65
+ # self.instantiate_first_stage(first_stage_config)
66
+
67
+ # 2. initialize conditional stage
68
+ # self.instantiate_cond_stage(cond_stage_config)
69
+ self.cond_stage_model = {
70
+ "image": self.encode_image,
71
+ "image_unconditional_embedding": self.empty_img_cond,
72
+ "text": self.encode_text,
73
+ "text_unconditional_embedding": self.empty_text_cond,
74
+ "surface": self.encode_surface,
75
+ "surface_unconditional_embedding": self.empty_surface_cond,
76
+ }
77
+
78
+ # 3. diffusion model
79
+ self.model = instantiate_from_config(
80
+ denoiser_cfg, device=None, dtype=None
81
+ )
82
+
83
+ self.optimizer_cfg = optimizer_cfg
84
+
85
+ # 4. scheduling strategy
86
+ self.scheduler_cfg = scheduler_cfg
87
+
88
+ self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
89
+ self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
90
+
91
+ # 5. loss configures
92
+ self.loss_cfg = loss_cfg
93
+
94
+ self.scale_by_std = scale_by_std
95
+ if scale_by_std:
96
+ self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
97
+ else:
98
+ self.z_scale_factor = z_scale_factor
99
+
100
+ self.ckpt_path = ckpt_path
101
+ if ckpt_path is not None:
102
+ self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
103
+
104
+ def instantiate_first_stage(self, config):
105
+ model = instantiate_from_config(config)
106
+ self.first_stage_model = model.eval()
107
+ self.first_stage_model.train = disabled_train
108
+ for param in self.first_stage_model.parameters():
109
+ param.requires_grad = False
110
+
111
+ self.first_stage_model = self.first_stage_model.to(self.device)
112
+
113
+ # def instantiate_cond_stage(self, config):
114
+ # if not self.cond_stage_trainable:
115
+ # if config == "__is_first_stage__":
116
+ # print("Using first stage also as cond stage.")
117
+ # self.cond_stage_model = self.first_stage_model
118
+ # elif config == "__is_unconditional__":
119
+ # print(f"Training {self.__class__.__name__} as an unconditional model.")
120
+ # self.cond_stage_model = None
121
+ # # self.be_unconditional = True
122
+ # else:
123
+ # model = instantiate_from_config(config)
124
+ # self.cond_stage_model = model.eval()
125
+ # self.cond_stage_model.train = disabled_train
126
+ # for param in self.cond_stage_model.parameters():
127
+ # param.requires_grad = False
128
+ # else:
129
+ # assert config != "__is_first_stage__"
130
+ # assert config != "__is_unconditional__"
131
+ # model = instantiate_from_config(config)
132
+ # self.cond_stage_model = model
133
+
134
+ def init_from_ckpt(self, path, ignore_keys=()):
135
+ state_dict = torch.load(path, map_location="cpu")["state_dict"]
136
+
137
+ keys = list(state_dict.keys())
138
+ for k in keys:
139
+ for ik in ignore_keys:
140
+ if k.startswith(ik):
141
+ print("Deleting key {} from state_dict.".format(k))
142
+ del state_dict[k]
143
+
144
+ missing, unexpected = self.load_state_dict(state_dict, strict=False)
145
+ print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
146
+ if len(missing) > 0:
147
+ print(f"Missing Keys: {missing}")
148
+ print(f"Unexpected Keys: {unexpected}")
149
+
150
+ @property
151
+ def zero_rank(self):
152
+ if self._trainer:
153
+ zero_rank = self.trainer.local_rank == 0
154
+ else:
155
+ zero_rank = True
156
+
157
+ return zero_rank
158
+
159
+ def configure_optimizers(self) -> Tuple[List, List]:
160
+
161
+ lr = self.learning_rate
162
+
163
+ trainable_parameters = list(self.model.parameters())
164
+ # if the conditional encoder is trainable
165
+
166
+ # if self.cond_stage_trainable:
167
+ # conditioner_params = [p for p in self.cond_stage_model.parameters() if p.requires_grad]
168
+ # trainable_parameters += conditioner_params
169
+ # print(f"number of trainable conditional parameters: {len(conditioner_params)}.")
170
+
171
+ if self.optimizer_cfg is None:
172
+ optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
173
+ schedulers = []
174
+ else:
175
+ optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
176
+ scheduler_func = instantiate_from_config(
177
+ self.optimizer_cfg.scheduler,
178
+ max_decay_steps=self.trainer.max_steps,
179
+ lr_max=lr
180
+ )
181
+ scheduler = {
182
+ "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
183
+ "interval": "step",
184
+ "frequency": 1
185
+ }
186
+ optimizers = [optimizer]
187
+ schedulers = [scheduler]
188
+
189
+ return optimizers, schedulers
190
+
191
+ @torch.no_grad()
192
+ def encode_text(self, text):
193
+
194
+ b = text.shape[0]
195
+ text_tokens = rearrange(text, "b t l -> (b t) l")
196
+ text_embed = self.first_stage_model.model.encode_text_embed(text_tokens)
197
+ text_embed = rearrange(text_embed, "(b t) d -> b t d", b=b)
198
+ text_embed = text_embed.mean(dim=1)
199
+ text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
200
+
201
+ return text_embed
202
+
203
+ @torch.no_grad()
204
+ def encode_image(self, img):
205
+
206
+ return self.first_stage_model.model.encode_image_embed(img)
207
+
208
+ @torch.no_grad()
209
+ def encode_surface(self, surface):
210
+
211
+ return self.first_stage_model.model.encode_shape_embed(surface, return_latents=False)
212
+
213
+ @torch.no_grad()
214
+ def empty_text_cond(self, cond):
215
+
216
+ return torch.zeros_like(cond, device=cond.device)
217
+
218
+ @torch.no_grad()
219
+ def empty_img_cond(self, cond):
220
+
221
+ return torch.zeros_like(cond, device=cond.device)
222
+
223
+ @torch.no_grad()
224
+ def empty_surface_cond(self, cond):
225
+
226
+ return torch.zeros_like(cond, device=cond.device)
227
+
228
+ @torch.no_grad()
229
+ def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
230
+
231
+ z_q = self.first_stage_model.encode(surface, sample_posterior)
232
+ z_q = self.z_scale_factor * z_q
233
+
234
+ return z_q
235
+
236
+ @torch.no_grad()
237
+ def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
238
+
239
+ z_q = 1. / self.z_scale_factor * z_q
240
+ latents = self.first_stage_model.decode(z_q, **kwargs)
241
+ return latents
242
+
243
+ @rank_zero_only
244
+ @torch.no_grad()
245
+ def on_train_batch_start(self, batch, batch_idx):
246
+ # only for very first batch
247
+ if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
248
+ and batch_idx == 0 and self.ckpt_path is None:
249
+ # set rescale weight to 1./std of encodings
250
+ print("### USING STD-RESCALING ###")
251
+
252
+ z_q = self.encode_first_stage(batch[self.first_stage_key])
253
+ z = z_q.detach()
254
+
255
+ del self.z_scale_factor
256
+ self.register_buffer("z_scale_factor", 1. / z.flatten().std())
257
+ print(f"setting self.z_scale_factor to {self.z_scale_factor}")
258
+
259
+ print("### USING STD-RESCALING ###")
260
+
261
+ def compute_loss(self, model_outputs, split):
262
+ """
263
+
264
+ Args:
265
+ model_outputs (dict):
266
+ - x_0:
267
+ - noise:
268
+ - noise_prior:
269
+ - noise_pred:
270
+ - noise_pred_prior:
271
+
272
+ split (str):
273
+
274
+ Returns:
275
+
276
+ """
277
+
278
+ pred = model_outputs["pred"]
279
+
280
+ if self.noise_scheduler.prediction_type == "epsilon":
281
+ target = model_outputs["noise"]
282
+ elif self.noise_scheduler.prediction_type == "sample":
283
+ target = model_outputs["x_0"]
284
+ else:
285
+ raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
286
+
287
+ if self.loss_cfg.loss_type == "l1":
288
+ simple = F.l1_loss(pred, target, reduction="mean")
289
+ elif self.loss_cfg.loss_type in ["mse", "l2"]:
290
+ simple = F.mse_loss(pred, target, reduction="mean")
291
+ else:
292
+ raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
293
+
294
+ total_loss = simple
295
+
296
+ loss_dict = {
297
+ f"{split}/total_loss": total_loss.clone().detach(),
298
+ f"{split}/simple": simple.detach(),
299
+ }
300
+
301
+ return total_loss, loss_dict
302
+
303
+ def forward(self, batch):
304
+ """
305
+
306
+ Args:
307
+ batch:
308
+
309
+ Returns:
310
+
311
+ """
312
+
313
+ if self.first_stage_model is None:
314
+ self.instantiate_first_stage(self.first_stage_config)
315
+
316
+ latents = self.encode_first_stage(batch[self.first_stage_key])
317
+
318
+ # conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
319
+
320
+ conditions = self.cond_stage_model[self.cond_stage_key](batch[self.cond_stage_key]).unsqueeze(1)
321
+
322
+ mask = torch.rand((len(conditions), 1, 1), device=conditions.device, dtype=conditions.dtype) >= 0.1
323
+ conditions = conditions * mask.to(conditions)
324
+
325
+ # Sample noise that we"ll add to the latents
326
+ # [batch_size, n_token, latent_dim]
327
+ noise = torch.randn_like(latents)
328
+ bs = latents.shape[0]
329
+ # Sample a random timestep for each motion
330
+ timesteps = torch.randint(
331
+ 0,
332
+ self.noise_scheduler.config.num_train_timesteps,
333
+ (bs,),
334
+ device=latents.device,
335
+ )
336
+ timesteps = timesteps.long()
337
+ # Add noise to the latents according to the noise magnitude at each timestep
338
+ noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
339
+
340
+ # diffusion model forward
341
+ noise_pred = self.model(noisy_z, timesteps, conditions)
342
+
+         diffusion_outputs = {
+             "x_0": latents,  # clean latents: the loss target when prediction_type == "sample" (was noisy_z)
+             "noise": noise,
+             "pred": noise_pred
+         }
+
+         return diffusion_outputs
+
+     def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
+                       batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
+         """
+         Args:
+             batch (dict): the batch sample, and it contains:
+                 - surface (torch.FloatTensor):
+                 - image (torch.FloatTensor): if provided, [bs, 3, h, w], items in range [0, 1]
+                 - depth (torch.FloatTensor): if provided, [bs, 1, h, w], items in range [-1, 1]
+                 - normal (torch.FloatTensor): if provided, [bs, 3, h, w], items in range [-1, 1]
+                 - text (list of str):
+
+             batch_idx (int):
+             optimizer_idx (int):
+
+         Returns:
+             loss (torch.FloatTensor):
+         """
+
+         diffusion_outputs = self(batch)
+
+         loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
+         self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
+
+         return loss
+
+     def validation_step(self, batch: Dict[str, torch.FloatTensor],
+                         batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
+         """
+         Args:
+             batch (dict): the batch sample, and it contains:
+                 - surface_pc (torch.FloatTensor): [n_pts, 4]
+                 - surface_feats (torch.FloatTensor): [n_pts, c]
+                 - text (list of str):
+
+             batch_idx (int):
+             optimizer_idx (int):
+
+         Returns:
+             loss (torch.FloatTensor):
+         """
+
+         diffusion_outputs = self(batch)
+
+         loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
+         self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
+
+         return loss
+
+     @torch.no_grad()
+     def sample(self,
+                batch: Dict[str, Union[torch.FloatTensor, List[str]]],
+                sample_times: int = 1,
+                steps: Optional[int] = None,
+                guidance_scale: Optional[float] = None,
+                eta: float = 0.0,
+                return_intermediates: bool = False, **kwargs):
+
+         if self.first_stage_model is None:
+             self.instantiate_first_stage(self.first_stage_config)
+
+         if steps is None:
+             steps = self.scheduler_cfg.num_inference_steps
+
+         if guidance_scale is None:
+             guidance_scale = self.scheduler_cfg.guidance_scale
+         do_classifier_free_guidance = guidance_scale > 0
+
+         # conditional encode
+         xc = batch[self.cond_stage_key]
+         # cond = self.cond_stage_model[self.cond_stage_key](xc)
+         cond = self.cond_stage_model[self.cond_stage_key](xc).unsqueeze(1)
+
+         if do_classifier_free_guidance:
+             # Note: there are two kinds of unconditional embedding for text:
+             #   1. "" as the uncond text (as in SAL diffusion);
+             #   2. zeros_like(cond) as the uncond text (as in MDM).
+             # un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
+             un_cond = self.cond_stage_model[f"{self.cond_stage_key}_unconditional_embedding"](cond)
+             # un_cond = torch.zeros_like(cond, device=cond.device)
+             cond = torch.cat([un_cond, cond], dim=0)
+
+         outputs = []
+         latents = None
+
+         if not return_intermediates:
+             for _ in range(sample_times):
+                 sample_loop = ddim_sample(
+                     self.denoise_scheduler,
+                     self.model,
+                     shape=self.first_stage_model.latent_shape,
+                     cond=cond,
+                     steps=steps,
+                     guidance_scale=guidance_scale,
+                     do_classifier_free_guidance=do_classifier_free_guidance,
+                     device=self.device,
+                     eta=eta,
+                     disable_prog=not self.zero_rank
+                 )
+                 for sample, t in sample_loop:
+                     latents = sample
+                 outputs.append(self.decode_first_stage(latents, **kwargs))
+         else:
+             sample_loop = ddim_sample(
+                 self.denoise_scheduler,
+                 self.model,
+                 shape=self.first_stage_model.latent_shape,
+                 cond=cond,
+                 steps=steps,
+                 guidance_scale=guidance_scale,
+                 do_classifier_free_guidance=do_classifier_free_guidance,
+                 device=self.device,
+                 eta=eta,
+                 disable_prog=not self.zero_rank
+             )
+
+             iter_size = steps // sample_times
+             i = 0
+             for sample, t in sample_loop:
+                 latents = sample
+                 if i % iter_size == 0 or i == steps - 1:
+                     outputs.append(self.decode_first_stage(latents, **kwargs))
+                 i += 1
+
+         return outputs
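
A minimal usage sketch for the module above: ASLDiffuser is driven through its sample() entry point at inference time. The sketch assumes the image-conditioned YAML shipped in this commit exposes the Lightning module under a `model:` key, that the conditioning batch key is "image", and that a CUDA device is available; all three are assumptions, not facts from the diff.

# Minimal sketch: drive ASLDiffuser.sample() (config key `model:` and the
# "image" batch key are assumptions; adjust to the actual YAML layout).
import torch
from omegaconf import OmegaConf
from michelangelo.utils import instantiate_from_config

cfg = OmegaConf.load("configs/image_cond_diffuser_asl/image-ASLDM-256.yaml")
module = instantiate_from_config(cfg.model).eval().cuda()

batch = {"image": torch.rand(1, 3, 224, 224, device="cuda")}
with torch.no_grad():
    # returns one decoded first-stage output per sample_times
    outputs = module.sample(batch, sample_times=1, steps=50, guidance_scale=7.5)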
michelangelo/models/asl_diffusion/asl_udt.py ADDED
@@ -0,0 +1,104 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ import torch.nn as nn
+ from typing import Optional
+ from diffusers.models.embeddings import Timesteps
+ import math
+
+ from michelangelo.models.modules.transformer_blocks import MLP
+ from michelangelo.models.modules.diffusion_transformer import UNetDiffusionTransformer
+
+
+ class ConditionalASLUDTDenoiser(nn.Module):
+
+     def __init__(self, *,
+                  device: Optional[torch.device],
+                  dtype: Optional[torch.dtype],
+                  input_channels: int,
+                  output_channels: int,
+                  n_ctx: int,
+                  width: int,
+                  layers: int,
+                  heads: int,
+                  context_dim: int,
+                  context_ln: bool = True,
+                  skip_ln: bool = False,
+                  init_scale: float = 0.25,
+                  flip_sin_to_cos: bool = False,
+                  use_checkpoint: bool = False):
+         super().__init__()
+
+         self.use_checkpoint = use_checkpoint
+
+         init_scale = init_scale * math.sqrt(1.0 / width)
+
+         self.backbone = UNetDiffusionTransformer(
+             device=device,
+             dtype=dtype,
+             n_ctx=n_ctx,
+             width=width,
+             layers=layers,
+             heads=heads,
+             skip_ln=skip_ln,
+             init_scale=init_scale,
+             use_checkpoint=use_checkpoint
+         )
+         self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
+         self.input_proj = nn.Linear(input_channels, width, device=device, dtype=dtype)
+         self.output_proj = nn.Linear(width, output_channels, device=device, dtype=dtype)
+
+         # timestep embedding
+         self.time_embed = Timesteps(width, flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=0)
+         self.time_proj = MLP(
+             device=device, dtype=dtype, width=width, init_scale=init_scale
+         )
+
+         # context (condition) projector, optionally preceded by a LayerNorm
+         if context_ln:
+             self.context_embed = nn.Sequential(
+                 nn.LayerNorm(context_dim, device=device, dtype=dtype),
+                 nn.Linear(context_dim, width, device=device, dtype=dtype),
+             )
+         else:
+             self.context_embed = nn.Linear(context_dim, width, device=device, dtype=dtype)
+
+     def forward(self,
+                 model_input: torch.FloatTensor,
+                 timestep: torch.LongTensor,
+                 context: torch.FloatTensor):
+         r"""
+         Args:
+             model_input (torch.FloatTensor): [bs, n_data, c]
+             timestep (torch.LongTensor): [bs,]
+             context (torch.FloatTensor): [bs, context_tokens, c]
+
+         Returns:
+             sample (torch.FloatTensor): [bs, n_data, c]
+         """
+
+         _, n_data, _ = model_input.shape
+
+         # 1. timestep token
+         t_emb = self.time_proj(self.time_embed(timestep)).unsqueeze(dim=1)
+
+         # 2. condition projector
+         context = self.context_embed(context)
+
+         # 3. denoiser: run the transformer over [time | context | latents]
+         x = self.input_proj(model_input)
+         x = torch.cat([t_emb, context, x], dim=1)
+         x = self.backbone(x)
+         x = self.ln_post(x)
+         x = x[:, -n_data:]
+         sample = self.output_proj(x)
+
+         return sample
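
The token layout of ConditionalASLUDTDenoiser.forward is the part worth internalizing: one timestep token and the projected context tokens are prepended to the projected latent tokens, the transformer runs over the combined sequence, and only the last n_data positions are projected back out. A shape walk-through under illustrative sizes (every number here is an assumption, not a config value):

# Shape walk-through for the forward pass above (illustrative sizes only).
import torch

bs, n_data, c_in, n_ctx_tok, width = 2, 256, 64, 1, 768

model_input = torch.randn(bs, n_data, c_in)   # noisy shape latents
timestep = torch.randint(0, 1000, (bs,))      # one diffusion step per sample
context = torch.randn(bs, n_ctx_tok, 1024)    # e.g. a CLIP embedding, unsqueezed to one token

# t_emb   = time_proj(time_embed(timestep)).unsqueeze(1)  -> [bs, 1, width]
# context = context_embed(context)                        -> [bs, n_ctx_tok, width]
# x       = input_proj(model_input)                       -> [bs, n_data, width]
# cat     = [t_emb | context | x]                         -> [bs, 1 + n_ctx_tok + n_data, width]
# backbone and ln_post keep the sequence length; x[:, -n_data:] drops the
# time/context tokens, and output_proj maps back          -> [bs, n_data, output_channels]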
michelangelo/models/asl_diffusion/base.py ADDED
@@ -0,0 +1,13 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ import torch.nn as nn
+
+
+ class BaseDenoiser(nn.Module):
+
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x, t, context):
+         raise NotImplementedError
michelangelo/models/asl_diffusion/clip_asl_diffuser_pl_module.py ADDED
@@ -0,0 +1,393 @@
+ # -*- coding: utf-8 -*-
+
+ from omegaconf import DictConfig
+ from typing import List, Tuple, Dict, Optional, Union
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim import lr_scheduler
+ import pytorch_lightning as pl
+ from pytorch_lightning.utilities import rank_zero_only
+
+ from diffusers.schedulers import (
+     DDPMScheduler,
+     DDIMScheduler,
+     KarrasVeScheduler,
+     DPMSolverMultistepScheduler
+ )
+
+ from michelangelo.utils import instantiate_from_config
+ from michelangelo.models.tsal.tsal_base import AlignedShapeAsLatentPLModule
+ from michelangelo.models.asl_diffusion.inference_utils import ddim_sample
+
+ SchedulerType = Union[DDIMScheduler, KarrasVeScheduler, DPMSolverMultistepScheduler]
+
+
+ def disabled_train(self, mode=True):
+     """Overwrite model.train with this function to make sure train/eval mode
+     does not change anymore."""
+     return self
+
+
+ class ClipASLDiffuser(pl.LightningModule):
+     first_stage_model: Optional[AlignedShapeAsLatentPLModule]
+     cond_stage_model: Optional[Union[nn.Module, pl.LightningModule]]
+     model: nn.Module
+
+     def __init__(self, *,
+                  first_stage_config,
+                  cond_stage_config,
+                  denoiser_cfg,
+                  scheduler_cfg,
+                  optimizer_cfg,
+                  loss_cfg,
+                  first_stage_key: str = "surface",
+                  cond_stage_key: str = "image",
+                  scale_by_std: bool = False,
+                  z_scale_factor: float = 1.0,
+                  ckpt_path: Optional[str] = None,
+                  ignore_keys: Union[Tuple[str], List[str]] = ()):
+
+         super().__init__()
+
+         self.first_stage_key = first_stage_key
+         self.cond_stage_key = cond_stage_key
+
+         # 1. lazily initialize the first stage
+         self.instantiate_first_stage(first_stage_config)
+
+         # 2. initialize the conditional stage
+         self.instantiate_cond_stage(cond_stage_config)
+
+         # 3. diffusion model
+         self.model = instantiate_from_config(
+             denoiser_cfg, device=None, dtype=None
+         )
+
+         self.optimizer_cfg = optimizer_cfg
+
+         # 4. scheduling strategy
+         self.scheduler_cfg = scheduler_cfg
+
+         self.noise_scheduler: DDPMScheduler = instantiate_from_config(scheduler_cfg.noise)
+         self.denoise_scheduler: SchedulerType = instantiate_from_config(scheduler_cfg.denoise)
+
+         # 5. loss configuration
+         self.loss_cfg = loss_cfg
+
+         self.scale_by_std = scale_by_std
+         if scale_by_std:
+             self.register_buffer("z_scale_factor", torch.tensor(z_scale_factor))
+         else:
+             self.z_scale_factor = z_scale_factor
+
+         self.ckpt_path = ckpt_path
+         if ckpt_path is not None:
+             self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
+
+     def instantiate_non_trainable_model(self, config):
+         model = instantiate_from_config(config)
+         model = model.eval()
+         model.train = disabled_train
+         for param in model.parameters():
+             param.requires_grad = False
+
+         return model
+
+     def instantiate_first_stage(self, first_stage_config):
+         self.first_stage_model = self.instantiate_non_trainable_model(first_stage_config)
+         self.first_stage_model.set_shape_model_only()
+
+     def instantiate_cond_stage(self, cond_stage_config):
+         self.cond_stage_model = self.instantiate_non_trainable_model(cond_stage_config)
+
+     def init_from_ckpt(self, path, ignore_keys=()):
+         state_dict = torch.load(path, map_location="cpu")["state_dict"]
+
+         keys = list(state_dict.keys())
+         for k in keys:
+             for ik in ignore_keys:
+                 if k.startswith(ik):
+                     print("Deleting key {} from state_dict.".format(k))
+                     del state_dict[k]
+
+         missing, unexpected = self.load_state_dict(state_dict, strict=False)
+         print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
+         if len(missing) > 0:
+             print(f"Missing Keys: {missing}")
+             print(f"Unexpected Keys: {unexpected}")
+
+     @property
+     def zero_rank(self):
+         if self._trainer:
+             zero_rank = self.trainer.local_rank == 0
+         else:
+             zero_rank = True
+
+         return zero_rank
+
+     def configure_optimizers(self) -> Tuple[List, List]:
+
+         lr = self.learning_rate
+
+         trainable_parameters = list(self.model.parameters())
+         if self.optimizer_cfg is None:
+             optimizers = [torch.optim.AdamW(trainable_parameters, lr=lr, betas=(0.9, 0.99), weight_decay=1e-3)]
+             schedulers = []
+         else:
+             optimizer = instantiate_from_config(self.optimizer_cfg.optimizer, params=trainable_parameters)
+             scheduler_func = instantiate_from_config(
+                 self.optimizer_cfg.scheduler,
+                 max_decay_steps=self.trainer.max_steps,
+                 lr_max=lr
+             )
+             scheduler = {
+                 "scheduler": lr_scheduler.LambdaLR(optimizer, lr_lambda=scheduler_func.schedule),
+                 "interval": "step",
+                 "frequency": 1
+             }
+             optimizers = [optimizer]
+             schedulers = [scheduler]
+
+         return optimizers, schedulers
+
+     @torch.no_grad()
+     def encode_first_stage(self, surface: torch.FloatTensor, sample_posterior=True):
+
+         z_q = self.first_stage_model.encode(surface, sample_posterior)
+         z_q = self.z_scale_factor * z_q
+
+         return z_q
+
+     @torch.no_grad()
+     def decode_first_stage(self, z_q: torch.FloatTensor, **kwargs):
+
+         z_q = 1. / self.z_scale_factor * z_q
+         latents = self.first_stage_model.decode(z_q, **kwargs)
+         return latents
+
+     @rank_zero_only
+     @torch.no_grad()
+     def on_train_batch_start(self, batch, batch_idx):
+         # only for the very first batch
+         if self.scale_by_std and self.current_epoch == 0 and self.global_step == 0 \
+                 and batch_idx == 0 and self.ckpt_path is None:
+             # set rescale weight to 1. / std of encodings
+             print("### USING STD-RESCALING ###")
+
+             z_q = self.encode_first_stage(batch[self.first_stage_key])
+             z = z_q.detach()
+
+             del self.z_scale_factor
+             self.register_buffer("z_scale_factor", 1. / z.flatten().std())
+             print(f"setting self.z_scale_factor to {self.z_scale_factor}")
+
+             print("### USING STD-RESCALING ###")
+
+     def compute_loss(self, model_outputs, split):
+         """
+         Args:
+             model_outputs (dict):
+                 - x_0: the clean latents
+                 - noise: the noise added to the latents
+                 - pred: the denoiser output
+             split (str): "train" or "val"
+
+         Returns:
+             total_loss, loss_dict
+         """
+
+         pred = model_outputs["pred"]
+
+         if self.noise_scheduler.prediction_type == "epsilon":
+             target = model_outputs["noise"]
+         elif self.noise_scheduler.prediction_type == "sample":
+             target = model_outputs["x_0"]
+         else:
+             raise NotImplementedError(f"Prediction Type: {self.noise_scheduler.prediction_type} not yet supported.")
+
+         if self.loss_cfg.loss_type == "l1":
+             simple = F.l1_loss(pred, target, reduction="mean")
+         elif self.loss_cfg.loss_type in ["mse", "l2"]:
+             simple = F.mse_loss(pred, target, reduction="mean")
+         else:
+             raise NotImplementedError(f"Loss Type: {self.loss_cfg.loss_type} not yet supported.")
+
+         total_loss = simple
+
+         loss_dict = {
+             f"{split}/total_loss": total_loss.clone().detach(),
+             f"{split}/simple": simple.detach(),
+         }
+
+         return total_loss, loss_dict
+
+     def forward(self, batch):
+
+         latents = self.encode_first_stage(batch[self.first_stage_key])
+         conditions = self.cond_stage_model.encode(batch[self.cond_stage_key])
+
+         # Sample noise that we'll add to the latents
+         # [batch_size, n_token, latent_dim]
+         noise = torch.randn_like(latents)
+         bs = latents.shape[0]
+         # Sample a random timestep for each sample in the batch
+         timesteps = torch.randint(
+             0,
+             self.noise_scheduler.config.num_train_timesteps,
+             (bs,),
+             device=latents.device,
+         )
+         timesteps = timesteps.long()
+         # Add noise to the latents according to the noise magnitude at each timestep
+         noisy_z = self.noise_scheduler.add_noise(latents, noise, timesteps)
+
+         # diffusion model forward
+         noise_pred = self.model(noisy_z, timesteps, conditions)
+
+ diffusion_outputs = {
262
+ "x_0": noisy_z,
263
+ "noise": noise,
264
+ "pred": noise_pred
265
+ }
266
+
267
+ return diffusion_outputs
268
+
269
+ def training_step(self, batch: Dict[str, Union[torch.FloatTensor, List[str]]],
270
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
271
+ """
272
+
273
+ Args:
274
+ batch (dict): the batch sample, and it contains:
275
+ - surface (torch.FloatTensor):
276
+ - image (torch.FloatTensor): if provide, [bs, 3, h, w], item range [0, 1]
277
+ - depth (torch.FloatTensor): if provide, [bs, 1, h, w], item range [-1, 1]
278
+ - normal (torch.FloatTensor): if provide, [bs, 3, h, w], item range [-1, 1]
279
+ - text (list of str):
280
+
281
+ batch_idx (int):
282
+
283
+ optimizer_idx (int):
284
+
285
+ Returns:
286
+ loss (torch.FloatTensor):
287
+
288
+ """
289
+
290
+ diffusion_outputs = self(batch)
291
+
292
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "train")
293
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
294
+
295
+ return loss
296
+
297
+ def validation_step(self, batch: Dict[str, torch.FloatTensor],
298
+ batch_idx: int, optimizer_idx: int = 0) -> torch.FloatTensor:
299
+ """
300
+
301
+ Args:
302
+ batch (dict): the batch sample, and it contains:
303
+ - surface_pc (torch.FloatTensor): [n_pts, 4]
304
+ - surface_feats (torch.FloatTensor): [n_pts, c]
305
+ - text (list of str):
306
+
307
+ batch_idx (int):
308
+
309
+ optimizer_idx (int):
310
+
311
+ Returns:
312
+ loss (torch.FloatTensor):
313
+
314
+ """
315
+
316
+ diffusion_outputs = self(batch)
317
+
318
+ loss, loss_dict = self.compute_loss(diffusion_outputs, "val")
319
+ self.log_dict(loss_dict, prog_bar=True, logger=True, sync_dist=False, rank_zero_only=True)
320
+
321
+ return loss
322
+
323
+ @torch.no_grad()
324
+ def sample(self,
325
+ batch: Dict[str, Union[torch.FloatTensor, List[str]]],
326
+ sample_times: int = 1,
327
+ steps: Optional[int] = None,
328
+ guidance_scale: Optional[float] = None,
329
+ eta: float = 0.0,
330
+ return_intermediates: bool = False, **kwargs):
331
+
332
+ if steps is None:
333
+ steps = self.scheduler_cfg.num_inference_steps
334
+
335
+ if guidance_scale is None:
336
+ guidance_scale = self.scheduler_cfg.guidance_scale
337
+ do_classifier_free_guidance = guidance_scale > 0
338
+
339
+ # conditional encode
340
+ xc = batch[self.cond_stage_key]
341
+
342
+ # print(self.first_stage_model.device, self.cond_stage_model.device, self.device)
343
+
344
+ cond = self.cond_stage_model(xc)
345
+
346
+ if do_classifier_free_guidance:
347
+ un_cond = self.cond_stage_model.unconditional_embedding(batch_size=len(xc))
348
+ cond = torch.cat([un_cond, cond], dim=0)
349
+
350
+ outputs = []
351
+ latents = None
352
+
353
+ if not return_intermediates:
354
+ for _ in range(sample_times):
355
+ sample_loop = ddim_sample(
356
+ self.denoise_scheduler,
357
+ self.model,
358
+ shape=self.first_stage_model.latent_shape,
359
+ cond=cond,
360
+ steps=steps,
361
+ guidance_scale=guidance_scale,
362
+ do_classifier_free_guidance=do_classifier_free_guidance,
363
+ device=self.device,
364
+ eta=eta,
365
+ disable_prog=not self.zero_rank
366
+ )
367
+ for sample, t in sample_loop:
368
+ latents = sample
369
+ outputs.append(self.decode_first_stage(latents, **kwargs))
370
+ else:
371
+
372
+ sample_loop = ddim_sample(
373
+ self.denoise_scheduler,
374
+ self.model,
375
+ shape=self.first_stage_model.latent_shape,
376
+ cond=cond,
377
+ steps=steps,
378
+ guidance_scale=guidance_scale,
379
+ do_classifier_free_guidance=do_classifier_free_guidance,
380
+ device=self.device,
381
+ eta=eta,
382
+ disable_prog=not self.zero_rank
383
+ )
384
+
385
+ iter_size = steps // sample_times
386
+ i = 0
387
+ for sample, t in sample_loop:
388
+ latents = sample
389
+ if i % iter_size == 0 or i == steps - 1:
390
+ outputs.append(self.decode_first_stage(latents, **kwargs))
391
+ i += 1
392
+
393
+ return outputs
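
One behavior shared by both diffuser modules deserves a note: with scale_by_std enabled and no checkpoint given, on_train_batch_start replaces z_scale_factor with the reciprocal standard deviation of the first batch's first-stage encodings, so diffusion operates on roughly unit-variance latents. A self-contained sketch of that rescaling rule with stand-in data:

# Sketch of the std-rescaling rule from on_train_batch_start (stand-in encodings).
import torch

z_q = torch.randn(8, 256, 64) * 3.7          # pretend first-stage encodings, std ~3.7
z_scale_factor = 1.0 / z_q.flatten().std()   # the value the module registers as a buffer
scaled = z_scale_factor * z_q                # what encode_first_stage then returns
assert torch.isclose(scaled.std(), torch.tensor(1.0), atol=1e-3)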
michelangelo/models/asl_diffusion/inference_utils.py ADDED
@@ -0,0 +1,80 @@
+ # -*- coding: utf-8 -*-
+
+ import torch
+ from tqdm import tqdm
+ from typing import Tuple, List, Union, Optional
+ from diffusers.schedulers import DDIMScheduler
+
+
+ __all__ = ["ddim_sample"]
+
+
+ def ddim_sample(ddim_scheduler: DDIMScheduler,
+                 diffusion_model: torch.nn.Module,
+                 shape: Union[List[int], Tuple[int]],
+                 cond: torch.FloatTensor,
+                 steps: int,
+                 eta: float = 0.0,
+                 guidance_scale: float = 3.0,
+                 do_classifier_free_guidance: bool = True,
+                 generator: Optional[torch.Generator] = None,
+                 device: Union[torch.device, str] = "cuda:0",
+                 disable_prog: bool = True):
+
+     assert steps > 0, f"steps must be > 0, got {steps}."
+
+     # init latents; under classifier-free guidance, cond is a doubled batch [uncond, cond]
+     bsz = cond.shape[0]
+     if do_classifier_free_guidance:
+         bsz = bsz // 2
+
+     latents = torch.randn(
+         (bsz, *shape),
+         generator=generator,
+         device=cond.device,
+         dtype=cond.dtype,
+     )
+     # scale the initial noise by the standard deviation required by the scheduler
+     latents = latents * ddim_scheduler.init_noise_sigma
+     # set timesteps
+     ddim_scheduler.set_timesteps(steps)
+     timesteps = ddim_scheduler.timesteps.to(device)
+     # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+     # eta (η) is only used with the DDIMScheduler and should be in [0, 1]
+     extra_step_kwargs = {
+         "eta": eta,
+         "generator": generator
+     }
+
+     # reverse diffusion
+     for i, t in enumerate(tqdm(timesteps, disable=disable_prog, desc="DDIM Sampling:", leave=False)):
+         # expand the latents if we are doing classifier-free guidance
+         latent_model_input = (
+             torch.cat([latents] * 2)
+             if do_classifier_free_guidance
+             else latents
+         )
+         # latent_model_input = scheduler.scale_model_input(latent_model_input, t)
+         # predict the noise residual
+         timestep_tensor = torch.tensor([t], dtype=torch.long, device=device)
+         timestep_tensor = timestep_tensor.expand(latent_model_input.shape[0])
+         noise_pred = diffusion_model.forward(latent_model_input, timestep_tensor, cond)
+
+         # perform guidance
+         if do_classifier_free_guidance:
+             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+             noise_pred = noise_pred_uncond + guidance_scale * (
+                 noise_pred_text - noise_pred_uncond
+             )
+         # compute the previous noisy sample x_t -> x_t-1
+         latents = ddim_scheduler.step(
+             noise_pred, t, latents, **extra_step_kwargs
+         ).prev_sample
+
+         yield latents, t
+
+
+ def karra_sample():
+     # placeholder for a Karras-style sampling loop (e.g. with KarrasVeScheduler)
+     pass
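
Because ddim_sample is a generator that yields (latents, t) after every denoising step, callers either drain it for the final latents or tap intermediates along the way, exactly as the Lightning modules above do. A minimal consumption sketch with a dummy denoiser; the module, shapes, and step count are illustrative:

# Minimal consumption sketch for ddim_sample (dummy denoiser, illustrative shapes).
import torch
from diffusers.schedulers import DDIMScheduler
from michelangelo.models.asl_diffusion.inference_utils import ddim_sample

class DummyDenoiser(torch.nn.Module):
    def forward(self, x, t, cond):
        return torch.zeros_like(x)  # stand-in noise prediction

scheduler = DDIMScheduler(num_train_timesteps=1000)
cond = torch.randn(2, 1, 768)  # doubled batch ordered [uncond, cond] for CFG

final = None
for latents, t in ddim_sample(scheduler, DummyDenoiser(), shape=(256, 64),
                              cond=cond, steps=10, guidance_scale=3.0,
                              do_classifier_free_guidance=True, device="cpu"):
    final = latents  # keep only the last step's latents

print(final.shape)  # torch.Size([1, 256, 64])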
michelangelo/models/conditional_encoders/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # -*- coding: utf-8 -*-
+
+ from .clip import CLIPEncoder