wyysf committed on
Commit a71c535 • 1 Parent(s): 2a77245

Upload 15 files

apps/third_party/LGM/.gitignore ADDED
@@ -0,0 +1,8 @@
+ *.pt
+ *.yaml
+ **/__pycache__
+ *.pyc
+
+ weights*
+ models
+ sd-v2*
apps/third_party/LGM/README.md ADDED
@@ -0,0 +1,71 @@
+ # MVDream-diffusers
+
+ A **unified** diffusers implementation of [MVDream](https://github.com/bytedance/MVDream) and [ImageDream](https://github.com/bytedance/ImageDream).
+
+ We provide converted `fp16` weights on huggingface:
+ * [MVDream](https://huggingface.co/ashawkey/mvdream-sd2.1-diffusers)
+ * [ImageDream](https://huggingface.co/ashawkey/imagedream-ipmv-diffusers)
+
+ ### Install
+ ```bash
+ # dependency
+ pip install -r requirements.txt
+
+ # xformers is required! please refer to https://github.com/facebookresearch/xformers
+ pip install ninja
+ pip install -v -U git+https://github.com/facebookresearch/xformers.git@main#egg=xformers
+ ```
+
+ ### Usage
+
+ ```bash
+ python run_mvdream.py "a cute owl"
+ python run_imagedream.py data/anya_rgba.png
+ ```
+
+ ### Convert weights
+
+ MVDream:
+ ```bash
+ # download original ckpt (we only support the SD 2.1 version)
+ mkdir models
+ cd models
+ wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
+ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
+ cd ..
+
+ # convert
+ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view.pt --dump_path ./weights_mvdream --original_config_file models/sd-v2-base.yaml --half --to_safetensors --test
+ ```
+
+ ImageDream:
+ ```bash
+ # download original ckpt (we only support the pixel-controller version)
+ cd models
+ wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv.pt
+ wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv.yaml
+ cd ..
+
+ # convert
+ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv.pt --dump_path ./weights_imagedream --original_config_file models/sd_v2_base_ipmv.yaml --half --to_safetensors --test
+ ```
+
+ ### Acknowledgement
+
+ * The original papers:
+ ```bibtex
+ @article{shi2023MVDream,
+   author = {Shi, Yichun and Wang, Peng and Ye, Jianglong and Mai, Long and Li, Kejie and Yang, Xiao},
+   title = {MVDream: Multi-view Diffusion for 3D Generation},
+   journal = {arXiv:2308.16512},
+   year = {2023},
+ }
+ @article{wang2023imagedream,
+   title={ImageDream: Image-Prompt Multi-view Diffusion for 3D Generation},
+   author={Wang, Peng and Shi, Yichun},
+   journal={arXiv preprint arXiv:2312.02201},
+   year={2023}
+ }
+ ```
+ * This codebase is modified from [mvdream-hf](https://github.com/KokeCacao/mvdream-hf).
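For reference, a converted dump can be loaded directly with the bundled pipeline. The sketch below is illustrative only: it mirrors the `--test` branch of `convert_mvdream_to_diffusers.py`, uses `./weights_mvdream` as a placeholder for whatever `--dump_path` was chosen, and assumes `pipeline_mvdream.py` is importable (inside this repository that would be `apps.third_party.LGM.pipeline_mvdream`).

```python
import torch
from pipeline_mvdream import MVDreamPipeline  # assumption: this module is on PYTHONPATH

# Load the directory written by --dump_path (placeholder name).
pipe = MVDreamPipeline.from_pretrained("./weights_mvdream", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# One image is returned per generated view (4 views for MVDream).
images = pipe(
    prompt="a cute owl",
    negative_prompt="painting, bad quality, flat",
    guidance_scale=7.5,
    num_inference_steps=50,
    output_type="pil",
    device="cuda",  # the conversion script's test block passes device explicitly
)
for i, image in enumerate(images):
    image.save(f"view_{i}.png")
```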
apps/third_party/LGM/__pycache__/mv_unet.cpython-310.pyc ADDED
Binary file (23.4 kB).

apps/third_party/LGM/__pycache__/mv_unet.cpython-38.pyc ADDED
Binary file (23.6 kB).

apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-310.pyc ADDED
Binary file (15.7 kB).

apps/third_party/LGM/__pycache__/pipeline_mvdream.cpython-38.pyc ADDED
Binary file (15.7 kB).
apps/third_party/LGM/convert_mvdream_to_diffusers.py ADDED
@@ -0,0 +1,597 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/bc691231360a4cbc7d19a58742ebb8ed0f05e027/scripts/convert_original_stable_diffusion_to_diffusers.py
2
+
3
+ import argparse
4
+ import torch
5
+ import sys
6
+
7
+ sys.path.insert(0, ".")
8
+
9
+ from diffusers.models import (
10
+ AutoencoderKL,
11
+ )
12
+ from omegaconf import OmegaConf
13
+ from diffusers.schedulers import DDIMScheduler
14
+ from diffusers.utils import logging
15
+ from typing import Any
16
+ from accelerate import init_empty_weights
17
+ from accelerate.utils import set_module_tensor_to_device
18
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
19
+
20
+ from mv_unet import MultiViewUNetModel
21
+ from pipeline_mvdream import MVDreamPipeline
22
+ import kiui
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+
27
+ def assign_to_checkpoint(
28
+ paths,
29
+ checkpoint,
30
+ old_checkpoint,
31
+ attention_paths_to_split=None,
32
+ additional_replacements=None,
33
+ config=None,
34
+ ):
35
+ """
36
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
37
+ attention layers, and takes into account additional replacements that may arise.
38
+ Assigns the weights to the new checkpoint.
39
+ """
40
+ assert isinstance(
41
+ paths, list
42
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
43
+
44
+ # Splits the attention layers into three variables.
45
+ if attention_paths_to_split is not None:
46
+ for path, path_map in attention_paths_to_split.items():
47
+ old_tensor = old_checkpoint[path]
48
+ channels = old_tensor.shape[0] // 3
49
+
50
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
51
+
52
+ assert config is not None
53
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
54
+
55
+ old_tensor = old_tensor.reshape(
56
+ (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
57
+ )
58
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
59
+
60
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
61
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
62
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
63
+
64
+ for path in paths:
65
+ new_path = path["new"]
66
+
67
+ # These have already been assigned
68
+ if (
69
+ attention_paths_to_split is not None
70
+ and new_path in attention_paths_to_split
71
+ ):
72
+ continue
73
+
74
+ # Global renaming happens here
75
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
76
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
77
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
78
+
79
+ if additional_replacements is not None:
80
+ for replacement in additional_replacements:
81
+ new_path = new_path.replace(replacement["old"], replacement["new"])
82
+
83
+ # proj_attn.weight has to be converted from conv 1D to linear
84
+ is_attn_weight = "proj_attn.weight" in new_path or (
85
+ "attentions" in new_path and "to_" in new_path
86
+ )
87
+ shape = old_checkpoint[path["old"]].shape
88
+ if is_attn_weight and len(shape) == 3:
89
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
90
+ elif is_attn_weight and len(shape) == 4:
91
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
92
+ else:
93
+ checkpoint[new_path] = old_checkpoint[path["old"]]
94
+
95
+
96
+ def shave_segments(path, n_shave_prefix_segments=1):
97
+ """
98
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
99
+ """
100
+ if n_shave_prefix_segments >= 0:
101
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
102
+ else:
103
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
104
+
105
+
106
+ def create_vae_diffusers_config(original_config, image_size):
107
+ """
108
+ Creates a config for the diffusers based on the config of the LDM model.
109
+ """
110
+
111
+
112
+ if 'imagedream' in original_config.model.target:
113
+ vae_params = original_config.model.params.vae_config.params.ddconfig
114
+ _ = original_config.model.params.vae_config.params.embed_dim
115
+ vae_key = "vae_model."
116
+ else:
117
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
118
+ _ = original_config.model.params.first_stage_config.params.embed_dim
119
+ vae_key = "first_stage_model."
120
+
121
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
122
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
123
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
124
+
125
+ config = {
126
+ "sample_size": image_size,
127
+ "in_channels": vae_params.in_channels,
128
+ "out_channels": vae_params.out_ch,
129
+ "down_block_types": tuple(down_block_types),
130
+ "up_block_types": tuple(up_block_types),
131
+ "block_out_channels": tuple(block_out_channels),
132
+ "latent_channels": vae_params.z_channels,
133
+ "layers_per_block": vae_params.num_res_blocks,
134
+ }
135
+ return config, vae_key
136
+
137
+
138
+ def convert_ldm_vae_checkpoint(checkpoint, config, vae_key):
139
+ # extract state dict for VAE
140
+ vae_state_dict = {}
141
+ keys = list(checkpoint.keys())
142
+ for key in keys:
143
+ if key.startswith(vae_key):
144
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
145
+
146
+ new_checkpoint = {}
147
+
148
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
149
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
150
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
151
+ "encoder.conv_out.weight"
152
+ ]
153
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
154
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
155
+ "encoder.norm_out.weight"
156
+ ]
157
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
158
+ "encoder.norm_out.bias"
159
+ ]
160
+
161
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
162
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
163
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
164
+ "decoder.conv_out.weight"
165
+ ]
166
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
167
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
168
+ "decoder.norm_out.weight"
169
+ ]
170
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
171
+ "decoder.norm_out.bias"
172
+ ]
173
+
174
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
175
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
176
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
177
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
178
+
179
+ # Retrieves the keys for the encoder down blocks only
180
+ num_down_blocks = len(
181
+ {
182
+ ".".join(layer.split(".")[:3])
183
+ for layer in vae_state_dict
184
+ if "encoder.down" in layer
185
+ }
186
+ )
187
+ down_blocks = {
188
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
189
+ for layer_id in range(num_down_blocks)
190
+ }
191
+
192
+ # Retrieves the keys for the decoder up blocks only
193
+ num_up_blocks = len(
194
+ {
195
+ ".".join(layer.split(".")[:3])
196
+ for layer in vae_state_dict
197
+ if "decoder.up" in layer
198
+ }
199
+ )
200
+ up_blocks = {
201
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
202
+ for layer_id in range(num_up_blocks)
203
+ }
204
+
205
+ for i in range(num_down_blocks):
206
+ resnets = [
207
+ key
208
+ for key in down_blocks[i]
209
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
210
+ ]
211
+
212
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
213
+ new_checkpoint[
214
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
215
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
216
+ new_checkpoint[
217
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
218
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
219
+
220
+ paths = renew_vae_resnet_paths(resnets)
221
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
222
+ assign_to_checkpoint(
223
+ paths,
224
+ new_checkpoint,
225
+ vae_state_dict,
226
+ additional_replacements=[meta_path],
227
+ config=config,
228
+ )
229
+
230
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
231
+ num_mid_res_blocks = 2
232
+ for i in range(1, num_mid_res_blocks + 1):
233
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
234
+
235
+ paths = renew_vae_resnet_paths(resnets)
236
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
237
+ assign_to_checkpoint(
238
+ paths,
239
+ new_checkpoint,
240
+ vae_state_dict,
241
+ additional_replacements=[meta_path],
242
+ config=config,
243
+ )
244
+
245
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
246
+ paths = renew_vae_attention_paths(mid_attentions)
247
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
248
+ assign_to_checkpoint(
249
+ paths,
250
+ new_checkpoint,
251
+ vae_state_dict,
252
+ additional_replacements=[meta_path],
253
+ config=config,
254
+ )
255
+ conv_attn_to_linear(new_checkpoint)
256
+
257
+ for i in range(num_up_blocks):
258
+ block_id = num_up_blocks - 1 - i
259
+ resnets = [
260
+ key
261
+ for key in up_blocks[block_id]
262
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
263
+ ]
264
+
265
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
266
+ new_checkpoint[
267
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
268
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
269
+ new_checkpoint[
270
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
271
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
272
+
273
+ paths = renew_vae_resnet_paths(resnets)
274
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
275
+ assign_to_checkpoint(
276
+ paths,
277
+ new_checkpoint,
278
+ vae_state_dict,
279
+ additional_replacements=[meta_path],
280
+ config=config,
281
+ )
282
+
283
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
284
+ num_mid_res_blocks = 2
285
+ for i in range(1, num_mid_res_blocks + 1):
286
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
287
+
288
+ paths = renew_vae_resnet_paths(resnets)
289
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
290
+ assign_to_checkpoint(
291
+ paths,
292
+ new_checkpoint,
293
+ vae_state_dict,
294
+ additional_replacements=[meta_path],
295
+ config=config,
296
+ )
297
+
298
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
299
+ paths = renew_vae_attention_paths(mid_attentions)
300
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
301
+ assign_to_checkpoint(
302
+ paths,
303
+ new_checkpoint,
304
+ vae_state_dict,
305
+ additional_replacements=[meta_path],
306
+ config=config,
307
+ )
308
+ conv_attn_to_linear(new_checkpoint)
309
+ return new_checkpoint
310
+
311
+
312
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
313
+ """
314
+ Updates paths inside resnets to the new naming scheme (local renaming)
315
+ """
316
+ mapping = []
317
+ for old_item in old_list:
318
+ new_item = old_item
319
+
320
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
321
+ new_item = shave_segments(
322
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
323
+ )
324
+
325
+ mapping.append({"old": old_item, "new": new_item})
326
+
327
+ return mapping
328
+
329
+
330
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
331
+ """
332
+ Updates paths inside attentions to the new naming scheme (local renaming)
333
+ """
334
+ mapping = []
335
+ for old_item in old_list:
336
+ new_item = old_item
337
+
338
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
339
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
340
+
341
+ new_item = new_item.replace("q.weight", "to_q.weight")
342
+ new_item = new_item.replace("q.bias", "to_q.bias")
343
+
344
+ new_item = new_item.replace("k.weight", "to_k.weight")
345
+ new_item = new_item.replace("k.bias", "to_k.bias")
346
+
347
+ new_item = new_item.replace("v.weight", "to_v.weight")
348
+ new_item = new_item.replace("v.bias", "to_v.bias")
349
+
350
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
351
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
352
+
353
+ new_item = shave_segments(
354
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
355
+ )
356
+
357
+ mapping.append({"old": old_item, "new": new_item})
358
+
359
+ return mapping
360
+
361
+
362
+ def conv_attn_to_linear(checkpoint):
363
+ keys = list(checkpoint.keys())
364
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
365
+ for key in keys:
366
+ if ".".join(key.split(".")[-2:]) in attn_keys:
367
+ if checkpoint[key].ndim > 2:
368
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
369
+ elif "proj_attn.weight" in key:
370
+ if checkpoint[key].ndim > 2:
371
+ checkpoint[key] = checkpoint[key][:, :, 0]
372
+
373
+
374
+ def create_unet_config(original_config):
375
+ return OmegaConf.to_container(
376
+ original_config.model.params.unet_config.params, resolve=True
377
+ )
378
+
379
+
380
+ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, device):
381
+ checkpoint = torch.load(checkpoint_path, map_location=device)
382
+ # print(f"Checkpoint: {checkpoint.keys()}")
383
+ torch.cuda.empty_cache()
384
+
385
+ original_config = OmegaConf.load(original_config_file)
386
+ # print(f"Original Config: {original_config}")
387
+ prediction_type = "epsilon"
388
+ image_size = 256
389
+ num_train_timesteps = (
390
+ getattr(original_config.model.params, "timesteps", None) or 1000
391
+ )
392
+ beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
393
+ beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
394
+ scheduler = DDIMScheduler(
395
+ beta_end=beta_end,
396
+ beta_schedule="scaled_linear",
397
+ beta_start=beta_start,
398
+ num_train_timesteps=num_train_timesteps,
399
+ steps_offset=1,
400
+ clip_sample=False,
401
+ set_alpha_to_one=False,
402
+ prediction_type=prediction_type,
403
+ )
404
+ scheduler.register_to_config(clip_sample=False)
405
+
406
+ unet_config = create_unet_config(original_config)
407
+
408
+ # remove unused configs
409
+ unet_config.pop('legacy', None)
410
+ unet_config.pop('use_linear_in_transformer', None)
411
+ unet_config.pop('use_spatial_transformer', None)
412
+
413
+ unet_config.pop('ip_mode', None)
414
+ unet_config.pop('with_ip', None)
415
+
416
+ unet = MultiViewUNetModel(**unet_config)
417
+ unet.register_to_config(**unet_config)
418
+ # print(f"Unet State Dict: {unet.state_dict().keys()}")
419
+ unet.load_state_dict(
420
+ {
421
+ key.replace("model.diffusion_model.", ""): value
422
+ for key, value in checkpoint.items()
423
+ if key.replace("model.diffusion_model.", "") in unet.state_dict()
424
+ }
425
+ )
426
+ for param_name, param in unet.state_dict().items():
427
+ set_module_tensor_to_device(unet, param_name, device=device, value=param)
428
+
429
+ # Convert the VAE model.
430
+ vae_config, vae_key = create_vae_diffusers_config(original_config, image_size=image_size)
431
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config, vae_key)
432
+
433
+ if (
434
+ "model" in original_config
435
+ and "params" in original_config.model
436
+ and "scale_factor" in original_config.model.params
437
+ ):
438
+ vae_scaling_factor = original_config.model.params.scale_factor
439
+ else:
440
+ vae_scaling_factor = 0.18215 # default SD scaling factor
441
+
442
+ vae_config["scaling_factor"] = vae_scaling_factor
443
+
444
+ with init_empty_weights():
445
+ vae = AutoencoderKL(**vae_config)
446
+
447
+ for param_name, param in converted_vae_checkpoint.items():
448
+ set_module_tensor_to_device(vae, param_name, device=device, value=param)
449
+
450
+ # we only support SD 2.1 based models
451
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
452
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
453
+
454
+ # imagedream variant
455
+ if unet.ip_dim > 0:
456
+ feature_extractor: CLIPImageProcessor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
457
+ image_encoder: CLIPVisionModel = CLIPVisionModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
458
+ else:
459
+ feature_extractor = None
460
+ image_encoder = None
461
+
462
+ pipe = MVDreamPipeline(
463
+ vae=vae,
464
+ unet=unet,
465
+ tokenizer=tokenizer,
466
+ text_encoder=text_encoder,
467
+ scheduler=scheduler,
468
+ feature_extractor=feature_extractor,
469
+ image_encoder=image_encoder,
470
+ )
471
+
472
+ return pipe
473
+
474
+
475
+ if __name__ == "__main__":
476
+ parser = argparse.ArgumentParser()
477
+
478
+ parser.add_argument(
479
+ "--checkpoint_path",
480
+ default=None,
481
+ type=str,
482
+ required=True,
483
+ help="Path to the checkpoint to convert.",
484
+ )
485
+ parser.add_argument(
486
+ "--original_config_file",
487
+ default=None,
488
+ type=str,
489
+ help="The YAML config file corresponding to the original architecture.",
490
+ )
491
+ parser.add_argument(
492
+ "--to_safetensors",
493
+ action="store_true",
494
+ help="Whether to store pipeline in safetensors format or not.",
495
+ )
496
+ parser.add_argument(
497
+ "--half", action="store_true", help="Save weights in half precision."
498
+ )
499
+ parser.add_argument(
500
+ "--test",
501
+ action="store_true",
502
+ help="Whether to test inference after convertion.",
503
+ )
504
+ parser.add_argument(
505
+ "--dump_path",
506
+ default=None,
507
+ type=str,
508
+ required=True,
509
+ help="Path to the output model.",
510
+ )
511
+ parser.add_argument(
512
+ "--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
513
+ )
514
+ args = parser.parse_args()
515
+
516
+ args.device = torch.device(
517
+ args.device
518
+ if args.device is not None
519
+ else "cuda"
520
+ if torch.cuda.is_available()
521
+ else "cpu"
522
+ )
523
+
524
+ pipe = convert_from_original_mvdream_ckpt(
525
+ checkpoint_path=args.checkpoint_path,
526
+ original_config_file=args.original_config_file,
527
+ device=args.device,
528
+ )
529
+
530
+ if args.half:
531
+ pipe.to(torch_dtype=torch.float16)
532
+
533
+ print(f"Saving pipeline to {args.dump_path}...")
534
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
535
+
536
+ if args.test:
537
+ try:
538
+ # mvdream
539
+ if pipe.unet.ip_dim == 0:
540
+ print(f"Testing each subcomponent of the pipeline...")
541
+ images = pipe(
542
+ prompt="Head of Hatsune Miku",
543
+ negative_prompt="painting, bad quality, flat",
544
+ output_type="pil",
545
+ guidance_scale=7.5,
546
+ num_inference_steps=50,
547
+ device=args.device,
548
+ )
549
+ for i, image in enumerate(images):
550
+ image.save(f"test_image_{i}.png") # type: ignore
551
+
552
+ print(f"Testing entire pipeline...")
553
+ loaded_pipe = MVDreamPipeline.from_pretrained(args.dump_path) # type: ignore
554
+ images = loaded_pipe(
555
+ prompt="Head of Hatsune Miku",
556
+ negative_prompt="painting, bad quality, flat",
557
+ output_type="pil",
558
+ guidance_scale=7.5,
559
+ num_inference_steps=50,
560
+ device=args.device,
561
+ )
562
+ for i, image in enumerate(images):
563
+ image.save(f"test_image_{i}.png") # type: ignore
564
+ # imagedream
565
+ else:
566
+ input_image = kiui.read_image('data/anya_rgba.png', mode='float')
567
+ print(f"Testing each subcomponent of the pipeline...")
568
+ images = pipe(
569
+ image=input_image,
570
+ prompt="",
571
+ negative_prompt="",
572
+ output_type="pil",
573
+ guidance_scale=5.0,
574
+ num_inference_steps=50,
575
+ device=args.device,
576
+ )
577
+ for i, image in enumerate(images):
578
+ image.save(f"test_image_{i}.png") # type: ignore
579
+
580
+ print(f"Testing entire pipeline...")
581
+ loaded_pipe = MVDreamPipeline.from_pretrained(args.dump_path) # type: ignore
582
+ images = loaded_pipe(
583
+ image=input_image,
584
+ prompt="",
585
+ negative_prompt="",
586
+ output_type="pil",
587
+ guidance_scale=5.0,
588
+ num_inference_steps=50,
589
+ device=args.device,
590
+ )
591
+ for i, image in enumerate(images):
592
+ image.save(f"test_image_{i}.png") # type: ignore
593
+
594
+
595
+ print("Inference test passed!")
596
+ except Exception as e:
597
+ print(f"Failed to test inference: {e}")
apps/third_party/LGM/data/anya_rgba.png ADDED
apps/third_party/LGM/data/corgi.jpg ADDED
apps/third_party/LGM/mv_unet.py ADDED
@@ -0,0 +1,1005 @@
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
43
+ """
44
+ Create sinusoidal timestep embeddings.
45
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
46
+ These may be fractional.
47
+ :param dim: the dimension of the output.
48
+ :param max_period: controls the minimum frequency of the embeddings.
49
+ :return: an [N x dim] Tensor of positional embeddings.
50
+ """
51
+ if not repeat_only:
52
+ half = dim // 2
53
+ freqs = torch.exp(
54
+ -math.log(max_period)
55
+ * torch.arange(start=0, end=half, dtype=torch.float32)
56
+ / half
57
+ ).to(device=timesteps.device)
58
+ args = timesteps[:, None] * freqs[None]
59
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
60
+ if dim % 2:
61
+ embedding = torch.cat(
62
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
63
+ )
64
+ else:
65
+ embedding = repeat(timesteps, "b -> b d", d=dim)
66
+ # import pdb; pdb.set_trace()
67
+ return embedding
68
+
69
+
70
+ def zero_module(module):
71
+ """
72
+ Zero out the parameters of a module and return it.
73
+ """
74
+ for p in module.parameters():
75
+ p.detach().zero_()
76
+ return module
77
+
78
+
79
+ def conv_nd(dims, *args, **kwargs):
80
+ """
81
+ Create a 1D, 2D, or 3D convolution module.
82
+ """
83
+ if dims == 1:
84
+ return nn.Conv1d(*args, **kwargs)
85
+ elif dims == 2:
86
+ return nn.Conv2d(*args, **kwargs)
87
+ elif dims == 3:
88
+ return nn.Conv3d(*args, **kwargs)
89
+ raise ValueError(f"unsupported dimensions: {dims}")
90
+
91
+
92
+ def avg_pool_nd(dims, *args, **kwargs):
93
+ """
94
+ Create a 1D, 2D, or 3D average pooling module.
95
+ """
96
+ if dims == 1:
97
+ return nn.AvgPool1d(*args, **kwargs)
98
+ elif dims == 2:
99
+ return nn.AvgPool2d(*args, **kwargs)
100
+ elif dims == 3:
101
+ return nn.AvgPool3d(*args, **kwargs)
102
+ raise ValueError(f"unsupported dimensions: {dims}")
103
+
104
+
105
+ def default(val, d):
106
+ if val is not None:
107
+ return val
108
+ return d() if isfunction(d) else d
109
+
110
+
111
+ class GEGLU(nn.Module):
112
+ def __init__(self, dim_in, dim_out):
113
+ super().__init__()
114
+ self.proj = nn.Linear(dim_in, dim_out * 2)
115
+
116
+ def forward(self, x):
117
+ x, gate = self.proj(x).chunk(2, dim=-1)
118
+ return x * F.gelu(gate)
119
+
120
+
121
+ class FeedForward(nn.Module):
122
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
123
+ super().__init__()
124
+ inner_dim = int(dim * mult)
125
+ dim_out = default(dim_out, dim)
126
+ project_in = (
127
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
128
+ if not glu
129
+ else GEGLU(dim, inner_dim)
130
+ )
131
+
132
+ self.net = nn.Sequential(
133
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
134
+ )
135
+
136
+ def forward(self, x):
137
+ return self.net(x)
138
+
139
+
140
+ class MemoryEfficientCrossAttention(nn.Module):
141
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
142
+ def __init__(
143
+ self,
144
+ query_dim,
145
+ context_dim=None,
146
+ heads=8,
147
+ dim_head=64,
148
+ dropout=0.0,
149
+ ip_dim=0,
150
+ ip_weight=1,
151
+ ):
152
+ super().__init__()
153
+
154
+ inner_dim = dim_head * heads
155
+ context_dim = default(context_dim, query_dim)
156
+
157
+ self.heads = heads
158
+ self.dim_head = dim_head
159
+
160
+ self.ip_dim = ip_dim
161
+ self.ip_weight = ip_weight
162
+
163
+ if self.ip_dim > 0:
164
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
165
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
166
+
167
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
168
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
169
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
170
+
171
+ self.to_out = nn.Sequential(
172
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
173
+ )
174
+ self.attention_op = None
175
+
176
+ def forward(self, x, context=None):
177
+ q = self.to_q(x)
178
+ context = default(context, x)
179
+
180
+ if self.ip_dim > 0:
181
+ # context: [B, 77 + 16(ip), 1024]
182
+ token_len = context.shape[1]
183
+ context_ip = context[:, -self.ip_dim :, :]
184
+ k_ip = self.to_k_ip(context_ip)
185
+ v_ip = self.to_v_ip(context_ip)
186
+ context = context[:, : (token_len - self.ip_dim), :]
187
+
188
+ k = self.to_k(context)
189
+ v = self.to_v(context)
190
+
191
+ b, _, _ = q.shape
192
+ q, k, v = map(
193
+ lambda t: t.unsqueeze(3)
194
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
195
+ .permute(0, 2, 1, 3)
196
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
197
+ .contiguous(),
198
+ (q, k, v),
199
+ )
200
+
201
+ # actually compute the attention, what we cannot get enough of
202
+ out = xformers.ops.memory_efficient_attention(
203
+ q, k, v, attn_bias=None, op=self.attention_op
204
+ )
205
+
206
+ if self.ip_dim > 0:
207
+ k_ip, v_ip = map(
208
+ lambda t: t.unsqueeze(3)
209
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
210
+ .permute(0, 2, 1, 3)
211
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
212
+ .contiguous(),
213
+ (k_ip, v_ip),
214
+ )
215
+ # actually compute the attention, what we cannot get enough of
216
+ out_ip = xformers.ops.memory_efficient_attention(
217
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
218
+ )
219
+ out = out + self.ip_weight * out_ip
220
+
221
+ out = (
222
+ out.unsqueeze(0)
223
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
224
+ .permute(0, 2, 1, 3)
225
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
226
+ )
227
+ return self.to_out(out)
228
+
229
+
230
+ class BasicTransformerBlock3D(nn.Module):
231
+
232
+ def __init__(
233
+ self,
234
+ dim,
235
+ n_heads,
236
+ d_head,
237
+ context_dim,
238
+ dropout=0.0,
239
+ gated_ff=True,
240
+ ip_dim=0,
241
+ ip_weight=1,
242
+ ):
243
+ super().__init__()
244
+
245
+ self.attn1 = MemoryEfficientCrossAttention(
246
+ query_dim=dim,
247
+ context_dim=None, # self-attention
248
+ heads=n_heads,
249
+ dim_head=d_head,
250
+ dropout=dropout,
251
+ )
252
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
253
+ self.attn2 = MemoryEfficientCrossAttention(
254
+ query_dim=dim,
255
+ context_dim=context_dim,
256
+ heads=n_heads,
257
+ dim_head=d_head,
258
+ dropout=dropout,
259
+ # ip only applies to cross-attention
260
+ ip_dim=ip_dim,
261
+ ip_weight=ip_weight,
262
+ )
263
+ self.norm1 = nn.LayerNorm(dim)
264
+ self.norm2 = nn.LayerNorm(dim)
265
+ self.norm3 = nn.LayerNorm(dim)
266
+
267
+ def forward(self, x, context=None, num_frames=1):
268
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
269
+ x = self.attn1(self.norm1(x), context=None) + x
270
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
271
+ x = self.attn2(self.norm2(x), context=context) + x
272
+ x = self.ff(self.norm3(x)) + x
273
+ return x
274
+
275
+
276
+ class SpatialTransformer3D(nn.Module):
277
+
278
+ def __init__(
279
+ self,
280
+ in_channels,
281
+ n_heads,
282
+ d_head,
283
+ context_dim, # cross attention input dim
284
+ depth=1,
285
+ dropout=0.0,
286
+ ip_dim=0,
287
+ ip_weight=1,
288
+ ):
289
+ super().__init__()
290
+
291
+ if not isinstance(context_dim, list):
292
+ context_dim = [context_dim]
293
+
294
+ self.in_channels = in_channels
295
+
296
+ inner_dim = n_heads * d_head
297
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
298
+ self.proj_in = nn.Linear(in_channels, inner_dim)
299
+
300
+ self.transformer_blocks = nn.ModuleList(
301
+ [
302
+ BasicTransformerBlock3D(
303
+ inner_dim,
304
+ n_heads,
305
+ d_head,
306
+ context_dim=context_dim[d],
307
+ dropout=dropout,
308
+ ip_dim=ip_dim,
309
+ ip_weight=ip_weight,
310
+ )
311
+ for d in range(depth)
312
+ ]
313
+ )
314
+
315
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
316
+
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ # note: if no context is given, cross-attention defaults to self-attention
320
+ if not isinstance(context, list):
321
+ context = [context]
322
+ b, c, h, w = x.shape
323
+ x_in = x
324
+ x = self.norm(x)
325
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
326
+ x = self.proj_in(x)
327
+ for i, block in enumerate(self.transformer_blocks):
328
+ x = block(x, context=context[i], num_frames=num_frames)
329
+ x = self.proj_out(x)
330
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
331
+
332
+ return x + x_in
333
+
334
+
335
+ class PerceiverAttention(nn.Module):
336
+ def __init__(self, *, dim, dim_head=64, heads=8):
337
+ super().__init__()
338
+ self.scale = dim_head ** -0.5
339
+ self.dim_head = dim_head
340
+ self.heads = heads
341
+ inner_dim = dim_head * heads
342
+
343
+ self.norm1 = nn.LayerNorm(dim)
344
+ self.norm2 = nn.LayerNorm(dim)
345
+
346
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
347
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
348
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
349
+
350
+ def forward(self, x, latents):
351
+ """
352
+ Args:
353
+ x (torch.Tensor): image features
354
+ shape (b, n1, D)
355
+ latent (torch.Tensor): latent features
356
+ shape (b, n2, D)
357
+ """
358
+ x = self.norm1(x)
359
+ latents = self.norm2(latents)
360
+
361
+ b, l, _ = latents.shape
362
+
363
+ q = self.to_q(latents)
364
+ kv_input = torch.cat((x, latents), dim=-2)
365
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
366
+
367
+ q, k, v = map(
368
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
369
+ .transpose(1, 2)
370
+ .reshape(b, self.heads, t.shape[1], -1)
371
+ .contiguous(),
372
+ (q, k, v),
373
+ )
374
+
375
+ # attention
376
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
377
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
378
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
379
+ out = weight @ v
380
+
381
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
382
+
383
+ return self.to_out(out)
384
+
385
+
386
+ class Resampler(nn.Module):
387
+ def __init__(
388
+ self,
389
+ dim=1024,
390
+ depth=8,
391
+ dim_head=64,
392
+ heads=16,
393
+ num_queries=8,
394
+ embedding_dim=768,
395
+ output_dim=1024,
396
+ ff_mult=4,
397
+ ):
398
+ super().__init__()
399
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
400
+ self.proj_in = nn.Linear(embedding_dim, dim)
401
+ self.proj_out = nn.Linear(dim, output_dim)
402
+ self.norm_out = nn.LayerNorm(output_dim)
403
+
404
+ self.layers = nn.ModuleList([])
405
+ for _ in range(depth):
406
+ self.layers.append(
407
+ nn.ModuleList(
408
+ [
409
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
410
+ nn.Sequential(
411
+ nn.LayerNorm(dim),
412
+ nn.Linear(dim, dim * ff_mult, bias=False),
413
+ nn.GELU(),
414
+ nn.Linear(dim * ff_mult, dim, bias=False),
415
+ )
416
+ ]
417
+ )
418
+ )
419
+
420
+ def forward(self, x):
421
+ latents = self.latents.repeat(x.size(0), 1, 1)
422
+ x = self.proj_in(x)
423
+ for attn, ff in self.layers:
424
+ latents = attn(x, latents) + latents
425
+ latents = ff(latents) + latents
426
+
427
+ latents = self.proj_out(latents)
428
+ return self.norm_out(latents)
429
+
430
+
431
+ class CondSequential(nn.Sequential):
432
+ """
433
+ A sequential module that passes timestep embeddings to the children that
434
+ support it as an extra input.
435
+ """
436
+
437
+ def forward(self, x, emb, context=None, num_frames=1):
438
+ for layer in self:
439
+ if isinstance(layer, ResBlock):
440
+ x = layer(x, emb)
441
+ elif isinstance(layer, SpatialTransformer3D):
442
+ x = layer(x, context, num_frames=num_frames)
443
+ else:
444
+ x = layer(x)
445
+ return x
446
+
447
+
448
+ class Upsample(nn.Module):
449
+ """
450
+ An upsampling layer with an optional convolution.
451
+ :param channels: channels in the inputs and outputs.
452
+ :param use_conv: a bool determining if a convolution is applied.
453
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
454
+ upsampling occurs in the inner-two dimensions.
455
+ """
456
+
457
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
458
+ super().__init__()
459
+ self.channels = channels
460
+ self.out_channels = out_channels or channels
461
+ self.use_conv = use_conv
462
+ self.dims = dims
463
+ if use_conv:
464
+ self.conv = conv_nd(
465
+ dims, self.channels, self.out_channels, 3, padding=padding
466
+ )
467
+
468
+ def forward(self, x):
469
+ assert x.shape[1] == self.channels
470
+ if self.dims == 3:
471
+ x = F.interpolate(
472
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
473
+ )
474
+ else:
475
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
476
+ if self.use_conv:
477
+ x = self.conv(x)
478
+ return x
479
+
480
+
481
+ class Downsample(nn.Module):
482
+ """
483
+ A downsampling layer with an optional convolution.
484
+ :param channels: channels in the inputs and outputs.
485
+ :param use_conv: a bool determining if a convolution is applied.
486
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
487
+ downsampling occurs in the inner-two dimensions.
488
+ """
489
+
490
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
491
+ super().__init__()
492
+ self.channels = channels
493
+ self.out_channels = out_channels or channels
494
+ self.use_conv = use_conv
495
+ self.dims = dims
496
+ stride = 2 if dims != 3 else (1, 2, 2)
497
+ if use_conv:
498
+ self.op = conv_nd(
499
+ dims,
500
+ self.channels,
501
+ self.out_channels,
502
+ 3,
503
+ stride=stride,
504
+ padding=padding,
505
+ )
506
+ else:
507
+ assert self.channels == self.out_channels
508
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
509
+
510
+ def forward(self, x):
511
+ assert x.shape[1] == self.channels
512
+ return self.op(x)
513
+
514
+
515
+ class ResBlock(nn.Module):
516
+ """
517
+ A residual block that can optionally change the number of channels.
518
+ :param channels: the number of input channels.
519
+ :param emb_channels: the number of timestep embedding channels.
520
+ :param dropout: the rate of dropout.
521
+ :param out_channels: if specified, the number of out channels.
522
+ :param use_conv: if True and out_channels is specified, use a spatial
523
+ convolution instead of a smaller 1x1 convolution to change the
524
+ channels in the skip connection.
525
+ :param dims: determines if the signal is 1D, 2D, or 3D.
526
+ :param up: if True, use this block for upsampling.
527
+ :param down: if True, use this block for downsampling.
528
+ """
529
+
530
+ def __init__(
531
+ self,
532
+ channels,
533
+ emb_channels,
534
+ dropout,
535
+ out_channels=None,
536
+ use_conv=False,
537
+ use_scale_shift_norm=False,
538
+ dims=2,
539
+ up=False,
540
+ down=False,
541
+ ):
542
+ super().__init__()
543
+ self.channels = channels
544
+ self.emb_channels = emb_channels
545
+ self.dropout = dropout
546
+ self.out_channels = out_channels or channels
547
+ self.use_conv = use_conv
548
+ self.use_scale_shift_norm = use_scale_shift_norm
549
+
550
+ self.in_layers = nn.Sequential(
551
+ nn.GroupNorm(32, channels),
552
+ nn.SiLU(),
553
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
554
+ )
555
+
556
+ self.updown = up or down
557
+
558
+ if up:
559
+ self.h_upd = Upsample(channels, False, dims)
560
+ self.x_upd = Upsample(channels, False, dims)
561
+ elif down:
562
+ self.h_upd = Downsample(channels, False, dims)
563
+ self.x_upd = Downsample(channels, False, dims)
564
+ else:
565
+ self.h_upd = self.x_upd = nn.Identity()
566
+
567
+ self.emb_layers = nn.Sequential(
568
+ nn.SiLU(),
569
+ nn.Linear(
570
+ emb_channels,
571
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
572
+ ),
573
+ )
574
+ self.out_layers = nn.Sequential(
575
+ nn.GroupNorm(32, self.out_channels),
576
+ nn.SiLU(),
577
+ nn.Dropout(p=dropout),
578
+ zero_module(
579
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
580
+ ),
581
+ )
582
+
583
+ if self.out_channels == channels:
584
+ self.skip_connection = nn.Identity()
585
+ elif use_conv:
586
+ self.skip_connection = conv_nd(
587
+ dims, channels, self.out_channels, 3, padding=1
588
+ )
589
+ else:
590
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
591
+
592
+ def forward(self, x, emb):
593
+ if self.updown:
594
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
595
+ h = in_rest(x)
596
+ h = self.h_upd(h)
597
+ x = self.x_upd(x)
598
+ h = in_conv(h)
599
+ else:
600
+ h = self.in_layers(x)
601
+ emb_out = self.emb_layers(emb).type(h.dtype)
602
+ while len(emb_out.shape) < len(h.shape):
603
+ emb_out = emb_out[..., None]
604
+ if self.use_scale_shift_norm:
605
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
606
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
607
+ h = out_norm(h) * (1 + scale) + shift
608
+ h = out_rest(h)
609
+ else:
610
+ h = h + emb_out
611
+ h = self.out_layers(h)
612
+ return self.skip_connection(x) + h
613
+
614
+
615
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
616
+ """
617
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
618
+ :param in_channels: channels in the input Tensor.
619
+ :param model_channels: base channel count for the model.
620
+ :param out_channels: channels in the output Tensor.
621
+ :param num_res_blocks: number of residual blocks per downsample.
622
+ :param attention_resolutions: a collection of downsample rates at which
623
+ attention will take place. May be a set, list, or tuple.
624
+ For example, if this contains 4, then at 4x downsampling, attention
625
+ will be used.
626
+ :param dropout: the dropout probability.
627
+ :param channel_mult: channel multiplier for each level of the UNet.
628
+ :param conv_resample: if True, use learned convolutions for upsampling and
629
+ downsampling.
630
+ :param dims: determines if the signal is 1D, 2D, or 3D.
631
+ :param num_classes: if specified (as an int), then this model will be
632
+ class-conditional with `num_classes` classes.
633
+ :param num_heads: the number of attention heads in each attention layer.
634
+ :param num_head_channels: if specified, ignore num_heads and instead use
635
+ a fixed channel width per attention head.
636
+ :param num_heads_upsample: works with num_heads to set a different number
637
+ of heads for upsampling. Deprecated.
638
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
639
+ :param resblock_updown: use residual blocks for up/downsampling.
640
+ :param use_new_attention_order: use a different attention pattern for potentially
641
+ increased efficiency.
642
+ :param camera_dim: dimensionality of camera input.
643
+ """
644
+
645
+ def __init__(
646
+ self,
647
+ image_size,
648
+ in_channels,
649
+ model_channels,
650
+ out_channels,
651
+ num_res_blocks,
652
+ attention_resolutions,
653
+ dropout=0,
654
+ channel_mult=(1, 2, 4, 8),
655
+ conv_resample=True,
656
+ dims=2,
657
+ num_classes=None,
658
+ num_heads=-1,
659
+ num_head_channels=-1,
660
+ num_heads_upsample=-1,
661
+ use_scale_shift_norm=False,
662
+ resblock_updown=False,
663
+ transformer_depth=1,
664
+ context_dim=None,
665
+ n_embed=None,
666
+ num_attention_blocks=None,
667
+ adm_in_channels=None,
668
+ camera_dim=None,
669
+ ip_dim=0, # imagedream uses ip_dim > 0
670
+ ip_weight=1.0,
671
+ **kwargs,
672
+ ):
673
+ super().__init__()
674
+ assert context_dim is not None
675
+
676
+ if num_heads_upsample == -1:
677
+ num_heads_upsample = num_heads
678
+
679
+ if num_heads == -1:
680
+ assert (
681
+ num_head_channels != -1
682
+ ), "Either num_heads or num_head_channels has to be set"
683
+
684
+ if num_head_channels == -1:
685
+ assert (
686
+ num_heads != -1
687
+ ), "Either num_heads or num_head_channels has to be set"
688
+
689
+ self.image_size = image_size
690
+ self.in_channels = in_channels
691
+ self.model_channels = model_channels
692
+ self.out_channels = out_channels
693
+ if isinstance(num_res_blocks, int):
694
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
695
+ else:
696
+ if len(num_res_blocks) != len(channel_mult):
697
+ raise ValueError(
698
+ "provide num_res_blocks either as an int (globally constant) or "
699
+ "as a list/tuple (per-level) with the same length as channel_mult"
700
+ )
701
+ self.num_res_blocks = num_res_blocks
702
+
703
+ if num_attention_blocks is not None:
704
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
705
+ assert all(
706
+ map(
707
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
708
+ range(len(num_attention_blocks)),
709
+ )
710
+ )
711
+ print(
712
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
713
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
714
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
715
+ f"attention will still not be set."
716
+ )
717
+
718
+ self.attention_resolutions = attention_resolutions
719
+ self.dropout = dropout
720
+ self.channel_mult = channel_mult
721
+ self.conv_resample = conv_resample
722
+ self.num_classes = num_classes
723
+ self.num_heads = num_heads
724
+ self.num_head_channels = num_head_channels
725
+ self.num_heads_upsample = num_heads_upsample
726
+ self.predict_codebook_ids = n_embed is not None
727
+
728
+ self.ip_dim = ip_dim
729
+ self.ip_weight = ip_weight
730
+
731
+ if self.ip_dim > 0:
732
+ self.image_embed = Resampler(
733
+ dim=context_dim,
734
+ depth=4,
735
+ dim_head=64,
736
+ heads=12,
737
+ num_queries=ip_dim, # num token
738
+ embedding_dim=1280,
739
+ output_dim=context_dim,
740
+ ff_mult=4,
741
+ )
742
+
743
+ time_embed_dim = model_channels * 4
744
+ self.time_embed = nn.Sequential(
745
+ nn.Linear(model_channels, time_embed_dim),
746
+ nn.SiLU(),
747
+ nn.Linear(time_embed_dim, time_embed_dim),
748
+ )
749
+
750
+ if camera_dim is not None:
751
+ time_embed_dim = model_channels * 4
752
+ self.camera_embed = nn.Sequential(
753
+ nn.Linear(camera_dim, time_embed_dim),
754
+ nn.SiLU(),
755
+ nn.Linear(time_embed_dim, time_embed_dim),
756
+ )
757
+
758
+ if self.num_classes is not None:
759
+ if isinstance(self.num_classes, int):
760
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
761
+ elif self.num_classes == "continuous":
762
+ # print("setting up linear c_adm embedding layer")
763
+ self.label_emb = nn.Linear(1, time_embed_dim)
764
+ elif self.num_classes == "sequential":
765
+ assert adm_in_channels is not None
766
+ self.label_emb = nn.Sequential(
767
+ nn.Sequential(
768
+ nn.Linear(adm_in_channels, time_embed_dim),
769
+ nn.SiLU(),
770
+ nn.Linear(time_embed_dim, time_embed_dim),
771
+ )
772
+ )
773
+ else:
774
+ raise ValueError()
775
+
776
+ self.input_blocks = nn.ModuleList(
777
+ [
778
+ CondSequential(
779
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
780
+ )
781
+ ]
782
+ )
783
+ self._feature_size = model_channels
784
+ input_block_chans = [model_channels]
785
+ ch = model_channels
786
+ ds = 1
787
+ for level, mult in enumerate(channel_mult):
788
+ for nr in range(self.num_res_blocks[level]):
789
+ layers = [
790
+ ResBlock(
791
+ ch,
792
+ time_embed_dim,
793
+ dropout,
794
+ out_channels=mult * model_channels,
795
+ dims=dims,
796
+ use_scale_shift_norm=use_scale_shift_norm,
797
+ )
798
+ ]
799
+ ch = mult * model_channels
800
+ if ds in attention_resolutions:
801
+ if num_head_channels == -1:
802
+ dim_head = ch // num_heads
803
+ else:
804
+ num_heads = ch // num_head_channels
805
+ dim_head = num_head_channels
806
+
807
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
808
+ layers.append(
809
+ SpatialTransformer3D(
810
+ ch,
811
+ num_heads,
812
+ dim_head,
813
+ context_dim=context_dim,
814
+ depth=transformer_depth,
815
+ ip_dim=self.ip_dim,
816
+ ip_weight=self.ip_weight,
817
+ )
818
+ )
819
+ self.input_blocks.append(CondSequential(*layers))
820
+ self._feature_size += ch
821
+ input_block_chans.append(ch)
822
+ if level != len(channel_mult) - 1:
823
+ out_ch = ch
824
+ self.input_blocks.append(
825
+ CondSequential(
826
+ ResBlock(
827
+ ch,
828
+ time_embed_dim,
829
+ dropout,
830
+ out_channels=out_ch,
831
+ dims=dims,
832
+ use_scale_shift_norm=use_scale_shift_norm,
833
+ down=True,
834
+ )
835
+ if resblock_updown
836
+ else Downsample(
837
+ ch, conv_resample, dims=dims, out_channels=out_ch
838
+ )
839
+ )
840
+ )
841
+ ch = out_ch
842
+ input_block_chans.append(ch)
843
+ ds *= 2
844
+ self._feature_size += ch
845
+
846
+ if num_head_channels == -1:
847
+ dim_head = ch // num_heads
848
+ else:
849
+ num_heads = ch // num_head_channels
850
+ dim_head = num_head_channels
851
+
852
+ self.middle_block = CondSequential(
853
+ ResBlock(
854
+ ch,
855
+ time_embed_dim,
856
+ dropout,
857
+ dims=dims,
858
+ use_scale_shift_norm=use_scale_shift_norm,
859
+ ),
860
+ SpatialTransformer3D(
861
+ ch,
862
+ num_heads,
863
+ dim_head,
864
+ context_dim=context_dim,
865
+ depth=transformer_depth,
866
+ ip_dim=self.ip_dim,
867
+ ip_weight=self.ip_weight,
868
+ ),
869
+ ResBlock(
870
+ ch,
871
+ time_embed_dim,
872
+ dropout,
873
+ dims=dims,
874
+ use_scale_shift_norm=use_scale_shift_norm,
875
+ ),
876
+ )
877
+ self._feature_size += ch
878
+
879
+ self.output_blocks = nn.ModuleList([])
880
+ for level, mult in list(enumerate(channel_mult))[::-1]:
881
+ for i in range(self.num_res_blocks[level] + 1):
882
+ ich = input_block_chans.pop()
883
+ layers = [
884
+ ResBlock(
885
+ ch + ich,
886
+ time_embed_dim,
887
+ dropout,
888
+ out_channels=model_channels * mult,
889
+ dims=dims,
890
+ use_scale_shift_norm=use_scale_shift_norm,
891
+ )
892
+ ]
893
+ ch = model_channels * mult
894
+ if ds in attention_resolutions:
895
+ if num_head_channels == -1:
896
+ dim_head = ch // num_heads
897
+ else:
898
+ num_heads = ch // num_head_channels
899
+ dim_head = num_head_channels
900
+
901
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
902
+ layers.append(
903
+ SpatialTransformer3D(
904
+ ch,
905
+ num_heads,
906
+ dim_head,
907
+ context_dim=context_dim,
908
+ depth=transformer_depth,
909
+ ip_dim=self.ip_dim,
910
+ ip_weight=self.ip_weight,
911
+ )
912
+ )
913
+ if level and i == self.num_res_blocks[level]:
914
+ out_ch = ch
915
+ layers.append(
916
+ ResBlock(
917
+ ch,
918
+ time_embed_dim,
919
+ dropout,
920
+ out_channels=out_ch,
921
+ dims=dims,
922
+ use_scale_shift_norm=use_scale_shift_norm,
923
+ up=True,
924
+ )
925
+ if resblock_updown
926
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
927
+ )
928
+ ds //= 2
929
+ self.output_blocks.append(CondSequential(*layers))
930
+ self._feature_size += ch
931
+
932
+ self.out = nn.Sequential(
933
+ nn.GroupNorm(32, ch),
934
+ nn.SiLU(),
935
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
936
+ )
937
+ if self.predict_codebook_ids:
938
+ self.id_predictor = nn.Sequential(
939
+ nn.GroupNorm(32, ch),
940
+ conv_nd(dims, model_channels, n_embed, 1),
941
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
942
+ )
943
+
944
+ def forward(
945
+ self,
946
+ x,
947
+ timesteps=None,
948
+ context=None,
949
+ y=None,
950
+ camera=None,
951
+ num_frames=1,
952
+ ip=None,
953
+ ip_img=None,
954
+ **kwargs,
955
+ ):
956
+ """
957
+ Apply the model to an input batch.
958
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
959
+ :param timesteps: a 1-D batch of timesteps.
960
+ :param context: conditioning plugged in via crossattn
961
+ :param y: an [N] Tensor of labels, if class-conditional.
962
+ :param num_frames: an integer indicating the number of frames for tensor reshaping.
963
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
964
+ """
965
+ assert (
966
+ x.shape[0] % num_frames == 0
967
+ ), "input batch size must be divisible by num_frames!"
968
+ assert (y is not None) == (
969
+ self.num_classes is not None
970
+ ), "must specify y if and only if the model is class-conditional"
971
+
972
+ hs = []
973
+
974
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
975
+
976
+ emb = self.time_embed(t_emb)
977
+
978
+ if self.num_classes is not None:
979
+ assert y is not None
980
+ assert y.shape[0] == x.shape[0]
981
+ emb = emb + self.label_emb(y)
982
+
983
+ # Add camera embeddings
984
+ if camera is not None:
985
+ emb = emb + self.camera_embed(camera)
986
+
987
+ # imagedream variant
988
+ if self.ip_dim > 0:
989
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place the encoded input view at the last slot of each group (e.g. indices [4, 9] for a CFG batch of 2 x 5)
990
+ ip_emb = self.image_embed(ip)
991
+ context = torch.cat((context, ip_emb), 1)
992
+
993
+ h = x
994
+ for module in self.input_blocks:
995
+ h = module(h, emb, context, num_frames=num_frames)
996
+ hs.append(h)
997
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
998
+ for module in self.output_blocks:
999
+ h = torch.cat([h, hs.pop()], dim=1)
1000
+ h = module(h, emb, context, num_frames=num_frames)
1001
+ h = h.type(x.dtype)
1002
+ if self.predict_codebook_ids:
1003
+ return self.id_predictor(h)
1004
+ else:
1005
+ return self.out(h)
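For reference, a minimal sketch of how the pipeline below feeds `MultiViewUNetModel.forward`: the F views of one object are packed along the batch dimension, with one timestep entry and one flattened camera matrix per view. The `./weights_mvdream` folder, the diffusers-style `from_pretrained(..., subfolder="unet")` load, the 4x32x32 latents (256x256 images), the 77x1024 text context (SD 2.1) and `camera_dim = 16` are assumptions for illustration, not taken from this diff.

```python
import torch
from apps.third_party.LGM.mv_unet import MultiViewUNetModel, get_camera  # run from repo root

# Load the converted UNet (assumes ./weights_mvdream from the README exists).
unet = MultiViewUNetModel.from_pretrained("./weights_mvdream", subfolder="unet").cuda()

num_frames = 4                                        # F views of the same object
x = torch.randn(num_frames, 4, 32, 32).cuda()         # [(N x F), C, H, W] with N = 1
t = torch.full((num_frames,), 500).cuda()             # one timestep entry per view
context = torch.randn(num_frames, 77, 1024).cuda()    # per-view cross-attention context
camera = get_camera(num_frames, elevation=0).cuda()   # flattened 4x4 camera matrices

with torch.no_grad():
    eps = unet(x, timesteps=t, context=context, camera=camera, num_frames=num_frames)
print(eps.shape)  # noise prediction in the same [(N x F), C, H, W] layout as x
```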
apps/third_party/LGM/pipeline_mvdream.py ADDED
@@ -0,0 +1,557 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import inspect
4
+ import numpy as np
5
+ from typing import Callable, List, Optional, Union
6
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
7
+ from diffusers import AutoencoderKL, DiffusionPipeline
8
+ from diffusers.utils import (
9
+ deprecate,
10
+ is_accelerate_available,
11
+ is_accelerate_version,
12
+ logging,
13
+ )
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.schedulers import DDIMScheduler
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ from apps.third_party.LGM.mv_unet import MultiViewUNetModel, get_camera
19
+
20
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
+
22
+
23
+ class MVDreamPipeline(DiffusionPipeline):
24
+
25
+ _optional_components = ["feature_extractor", "image_encoder"]
26
+
27
+ def __init__(
28
+ self,
29
+ vae: AutoencoderKL,
30
+ unet: MultiViewUNetModel,
31
+ tokenizer: CLIPTokenizer,
32
+ text_encoder: CLIPTextModel,
33
+ scheduler: DDIMScheduler,
34
+ # imagedream variant
35
+ feature_extractor: CLIPImageProcessor,
36
+ image_encoder: CLIPVisionModel,
37
+ requires_safety_checker: bool = False,
38
+ ):
39
+ super().__init__()
40
+
41
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
42
+ deprecation_message = (
43
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
44
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
45
+ "to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
46
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
47
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
48
+ " file"
49
+ )
50
+ deprecate(
51
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
52
+ )
53
+ new_config = dict(scheduler.config)
54
+ new_config["steps_offset"] = 1
55
+ scheduler._internal_dict = FrozenDict(new_config)
56
+
57
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
58
+ deprecation_message = (
59
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
60
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
61
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
62
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
63
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
64
+ )
65
+ deprecate(
66
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
67
+ )
68
+ new_config = dict(scheduler.config)
69
+ new_config["clip_sample"] = False
70
+ scheduler._internal_dict = FrozenDict(new_config)
71
+
72
+ self.register_modules(
73
+ vae=vae,
74
+ unet=unet,
75
+ scheduler=scheduler,
76
+ tokenizer=tokenizer,
77
+ text_encoder=text_encoder,
78
+ feature_extractor=feature_extractor,
79
+ image_encoder=image_encoder,
80
+ )
81
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
82
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
83
+
84
+ def enable_vae_slicing(self):
85
+ r"""
86
+ Enable sliced VAE decoding.
87
+
88
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
89
+ steps. This is useful to save some memory and allow larger batch sizes.
90
+ """
91
+ self.vae.enable_slicing()
92
+
93
+ def disable_vae_slicing(self):
94
+ r"""
95
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
96
+ computing decoding in one step.
97
+ """
98
+ self.vae.disable_slicing()
99
+
100
+ def enable_vae_tiling(self):
101
+ r"""
102
+ Enable tiled VAE decoding.
103
+
104
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
105
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
106
+ """
107
+ self.vae.enable_tiling()
108
+
109
+ def disable_vae_tiling(self):
110
+ r"""
111
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
112
+ computing decoding in one step.
113
+ """
114
+ self.vae.disable_tiling()
115
+
116
+ def enable_sequential_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
119
+ text_encoder and vae have their state dicts saved to CPU and then are moved to a
120
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
121
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
122
+ `enable_model_cpu_offload`, but performance is lower.
123
+ """
124
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
125
+ from accelerate import cpu_offload
126
+ else:
127
+ raise ImportError(
128
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
129
+ )
130
+
131
+ device = torch.device(f"cuda:{gpu_id}")
132
+
133
+ if self.device.type != "cpu":
134
+ self.to("cpu", silence_dtype_warnings=True)
135
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
136
+
137
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
138
+ cpu_offload(cpu_offloaded_model, device)
139
+
140
+ def enable_model_cpu_offload(self, gpu_id=0):
141
+ r"""
142
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
143
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
144
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
145
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
146
+ """
147
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
148
+ from accelerate import cpu_offload_with_hook
149
+ else:
150
+ raise ImportError(
151
+ "`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher."
152
+ )
153
+
154
+ device = torch.device(f"cuda:{gpu_id}")
155
+
156
+ if self.device.type != "cpu":
157
+ self.to("cpu", silence_dtype_warnings=True)
158
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
159
+
160
+ hook = None
161
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
162
+ _, hook = cpu_offload_with_hook(
163
+ cpu_offloaded_model, device, prev_module_hook=hook
164
+ )
165
+
166
+ # We'll offload the last model manually.
167
+ self.final_offload_hook = hook
168
+
169
+ @property
170
+ def _execution_device(self):
171
+ r"""
172
+ Returns the device on which the pipeline's models will be executed. After calling
173
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
174
+ hooks.
175
+ """
176
+ if not hasattr(self.unet, "_hf_hook"):
177
+ return self.device
178
+ for module in self.unet.modules():
179
+ if (
180
+ hasattr(module, "_hf_hook")
181
+ and hasattr(module._hf_hook, "execution_device")
182
+ and module._hf_hook.execution_device is not None
183
+ ):
184
+ return torch.device(module._hf_hook.execution_device)
185
+ return self.device
186
+
187
+ def _encode_prompt(
188
+ self,
189
+ prompt,
190
+ device,
191
+ num_images_per_prompt,
192
+ do_classifier_free_guidance: bool,
193
+ negative_prompt=None,
194
+ ):
195
+ r"""
196
+ Encodes the prompt into text encoder hidden states.
197
+
198
+ Args:
199
+ prompt (`str` or `List[str]`, *optional*):
200
+ prompt to be encoded
201
+ device: (`torch.device`):
202
+ torch device
203
+ num_images_per_prompt (`int`):
204
+ number of images that should be generated per prompt
205
+ do_classifier_free_guidance (`bool`):
206
+ whether to use classifier free guidance or not
207
+ negative_prompt (`str` or `List[str]`, *optional*):
208
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
209
+ `negative_prompt_embeds` instead.
210
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
211
+ prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
216
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
217
+ argument.
218
+ """
219
+ if prompt is not None and isinstance(prompt, str):
220
+ batch_size = 1
221
+ elif prompt is not None and isinstance(prompt, list):
222
+ batch_size = len(prompt)
223
+ else:
224
+ raise ValueError(
225
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
226
+ )
227
+
228
+ text_inputs = self.tokenizer(
229
+ prompt,
230
+ padding="max_length",
231
+ max_length=self.tokenizer.model_max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+ text_input_ids = text_inputs.input_ids
236
+ untruncated_ids = self.tokenizer(
237
+ prompt, padding="longest", return_tensors="pt"
238
+ ).input_ids
239
+
240
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
241
+ text_input_ids, untruncated_ids
242
+ ):
243
+ removed_text = self.tokenizer.batch_decode(
244
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
245
+ )
246
+ logger.warning(
247
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
248
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
249
+ )
250
+
251
+ if (
252
+ hasattr(self.text_encoder.config, "use_attention_mask")
253
+ and self.text_encoder.config.use_attention_mask
254
+ ):
255
+ attention_mask = text_inputs.attention_mask.to(device)
256
+ else:
257
+ attention_mask = None
258
+
259
+ prompt_embeds = self.text_encoder(
260
+ text_input_ids.to(device),
261
+ attention_mask=attention_mask,
262
+ )
263
+ prompt_embeds = prompt_embeds[0]
264
+
265
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
+
267
+ bs_embed, seq_len, _ = prompt_embeds.shape
268
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(
271
+ bs_embed * num_images_per_prompt, seq_len, -1
272
+ )
273
+
274
+ # get unconditional embeddings for classifier free guidance
275
+ if do_classifier_free_guidance:
276
+ uncond_tokens: List[str]
277
+ if negative_prompt is None:
278
+ uncond_tokens = [""] * batch_size
279
+ elif type(prompt) is not type(negative_prompt):
280
+ raise TypeError(
281
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
282
+ f" {type(prompt)}."
283
+ )
284
+ elif isinstance(negative_prompt, str):
285
+ uncond_tokens = [negative_prompt]
286
+ elif batch_size != len(negative_prompt):
287
+ raise ValueError(
288
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
289
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
290
+ " the batch size of `prompt`."
291
+ )
292
+ else:
293
+ uncond_tokens = negative_prompt
294
+
295
+ max_length = prompt_embeds.shape[1]
296
+ uncond_input = self.tokenizer(
297
+ uncond_tokens,
298
+ padding="max_length",
299
+ max_length=max_length,
300
+ truncation=True,
301
+ return_tensors="pt",
302
+ )
303
+
304
+ if (
305
+ hasattr(self.text_encoder.config, "use_attention_mask")
306
+ and self.text_encoder.config.use_attention_mask
307
+ ):
308
+ attention_mask = uncond_input.attention_mask.to(device)
309
+ else:
310
+ attention_mask = None
311
+
312
+ negative_prompt_embeds = self.text_encoder(
313
+ uncond_input.input_ids.to(device),
314
+ attention_mask=attention_mask,
315
+ )
316
+ negative_prompt_embeds = negative_prompt_embeds[0]
317
+
318
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
319
+ seq_len = negative_prompt_embeds.shape[1]
320
+
321
+ negative_prompt_embeds = negative_prompt_embeds.to(
322
+ dtype=self.text_encoder.dtype, device=device
323
+ )
324
+
325
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
326
+ 1, num_images_per_prompt, 1
327
+ )
328
+ negative_prompt_embeds = negative_prompt_embeds.view(
329
+ batch_size * num_images_per_prompt, seq_len, -1
330
+ )
331
+
332
+ # For classifier free guidance, we need to do two forward passes.
333
+ # Here we concatenate the unconditional and text embeddings into a single batch
334
+ # to avoid doing two forward passes
335
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
336
+
337
+ return prompt_embeds
338
+
339
+ def decode_latents(self, latents):
340
+ latents = 1 / self.vae.config.scaling_factor * latents
341
+ image = self.vae.decode(latents).sample
342
+ image = (image / 2 + 0.5).clamp(0, 1)
343
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
344
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
345
+ return image
346
+
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(
354
+ inspect.signature(self.scheduler.step).parameters.keys()
355
+ )
356
+ extra_step_kwargs = {}
357
+ if accepts_eta:
358
+ extra_step_kwargs["eta"] = eta
359
+
360
+ # check if the scheduler accepts generator
361
+ accepts_generator = "generator" in set(
362
+ inspect.signature(self.scheduler.step).parameters.keys()
363
+ )
364
+ if accepts_generator:
365
+ extra_step_kwargs["generator"] = generator
366
+ return extra_step_kwargs
367
+
368
+ def prepare_latents(
369
+ self,
370
+ batch_size,
371
+ num_channels_latents,
372
+ height,
373
+ width,
374
+ dtype,
375
+ device,
376
+ generator,
377
+ latents=None,
378
+ ):
379
+ shape = (
380
+ batch_size,
381
+ num_channels_latents,
382
+ height // self.vae_scale_factor,
383
+ width // self.vae_scale_factor,
384
+ )
385
+ if isinstance(generator, list) and len(generator) != batch_size:
386
+ raise ValueError(
387
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
388
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
389
+ )
390
+
391
+ if latents is None:
392
+ latents = randn_tensor(
393
+ shape, generator=generator, device=device, dtype=dtype
394
+ )
395
+ else:
396
+ latents = latents.to(device)
397
+
398
+ # scale the initial noise by the standard deviation required by the scheduler
399
+ latents = latents * self.scheduler.init_noise_sigma
400
+ return latents
401
+
402
+ def encode_image(self, image, device, num_images_per_prompt):
403
+ dtype = next(self.image_encoder.parameters()).dtype
404
+
405
+ if image.dtype == np.float32:
406
+ image = (image * 255).astype(np.uint8)
407
+
408
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
409
+ image = image.to(device=device, dtype=dtype)
410
+
411
+ image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
412
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
413
+
414
+ return torch.zeros_like(image_embeds), image_embeds
415
+
416
+ def encode_image_latents(self, image, device, num_images_per_prompt):
417
+
418
+ dtype = next(self.image_encoder.parameters()).dtype
419
+
420
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
421
+ image = 2 * image - 1
422
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
423
+ image = image.to(dtype=dtype)
424
+
425
+ posterior = self.vae.encode(image).latent_dist
426
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
427
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
428
+
429
+ return torch.zeros_like(latents), latents
430
+
431
+ @torch.no_grad()
432
+ def __call__(
433
+ self,
434
+ prompt: str = "",
435
+ image: Optional[np.ndarray] = None,
436
+ height: int = 256,
437
+ width: int = 256,
438
+ elevation: float = 0,
439
+ num_inference_steps: int = 50,
440
+ guidance_scale: float = 7.0,
441
+ negative_prompt: str = "",
442
+ num_images_per_prompt: int = 1,
443
+ eta: float = 0.0,
444
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
445
+ output_type: Optional[str] = "numpy", # pil, numpy, latent
446
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
447
+ callback_steps: int = 1,
448
+ num_frames: int = 4,
449
+ device=torch.device("cuda:0"),
450
+ ):
451
+ self.unet = self.unet.to(device=device)
452
+ self.vae = self.vae.to(device=device)
453
+ self.text_encoder = self.text_encoder.to(device=device)
454
+
455
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
456
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
457
+ # corresponds to doing no classifier free guidance.
458
+ do_classifier_free_guidance = guidance_scale > 1.0
459
+
460
+ # Prepare timesteps
461
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
462
+ timesteps = self.scheduler.timesteps
463
+
464
+ # imagedream variant
465
+ if image is not None:
466
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
467
+ self.image_encoder = self.image_encoder.to(device=device)
468
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
469
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
470
+
471
+ _prompt_embeds = self._encode_prompt(
472
+ prompt=prompt,
473
+ device=device,
474
+ num_images_per_prompt=num_images_per_prompt,
475
+ do_classifier_free_guidance=do_classifier_free_guidance,
476
+ negative_prompt=negative_prompt,
477
+ ) # type: ignore
478
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
479
+
480
+ # Prepare latent variables
481
+ actual_num_frames = num_frames if image is None else num_frames + 1
482
+ latents: torch.Tensor = self.prepare_latents(
483
+ actual_num_frames * num_images_per_prompt,
484
+ 4,
485
+ height,
486
+ width,
487
+ prompt_embeds_pos.dtype,
488
+ device,
489
+ generator,
490
+ None,
491
+ )
492
+
493
+ # Get camera
494
+ camera = get_camera(num_frames, elevation=elevation, extra_view=(image is not None)).to(dtype=latents.dtype, device=device)
495
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
496
+
497
+ # Prepare extra step kwargs.
498
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
499
+
500
+ # Denoising loop
501
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
502
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
503
+ for i, t in enumerate(timesteps):
504
+ # expand the latents if we are doing classifier free guidance
505
+ multiplier = 2 if do_classifier_free_guidance else 1
506
+ latent_model_input = torch.cat([latents] * multiplier)
507
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
508
+
509
+ unet_inputs = {
510
+ 'x': latent_model_input,
511
+ 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
512
+ 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
513
+ 'num_frames': actual_num_frames,
514
+ 'camera': torch.cat([camera] * multiplier),
515
+ }
516
+
517
+ if image is not None:
518
+ unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
519
+ unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
520
+
521
+ # predict the noise residual
522
+ noise_pred = self.unet.forward(**unet_inputs)
523
+
524
+ # perform guidance
525
+ if do_classifier_free_guidance:
526
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
527
+ noise_pred = noise_pred_uncond + guidance_scale * (
528
+ noise_pred_text - noise_pred_uncond
529
+ )
530
+
531
+ # compute the previous noisy sample x_t -> x_t-1
532
+ latents: torch.Tensor = self.scheduler.step(
533
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
534
+ )[0]
535
+
536
+ # call the callback, if provided
537
+ if i == len(timesteps) - 1 or (
538
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
539
+ ):
540
+ progress_bar.update()
541
+ if callback is not None and i % callback_steps == 0:
542
+ callback(i, t, latents) # type: ignore
543
+
544
+ # Post-processing
545
+ if output_type == "latent":
546
+ image = latents
547
+ elif output_type == "pil":
548
+ image = self.decode_latents(latents)
549
+ image = self.numpy_to_pil(image)
550
+ else: # numpy
551
+ image = self.decode_latents(latents)
552
+
553
+ # Offload last model to CPU
554
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
555
+ self.final_offload_hook.offload()
556
+
557
+ return image
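A hedged usage sketch for the pipeline class above; it mirrors `run_mvdream.py` further below but returns PIL images and enables sliced VAE decoding. The `./weights_mvdream` folder is assumed to be the converter output described in the README.

```python
import torch
from apps.third_party.LGM.pipeline_mvdream import MVDreamPipeline  # run from repo root

# Assumes ./weights_mvdream was produced by convert_mvdream_to_diffusers.py.
pipe = MVDreamPipeline.from_pretrained("./weights_mvdream", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_vae_slicing()  # decode the four views in slices to save a bit of memory

views = pipe(
    "a cute owl 3d model",
    guidance_scale=7.0,
    num_inference_steps=30,
    elevation=0,
    output_type="pil",     # list of PIL.Image instead of the default numpy array
)
for i, view in enumerate(views):  # one image per surrounding camera view
    view.save(f"owl_view_{i}.png")
```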
apps/third_party/LGM/requirements.lock.txt ADDED
@@ -0,0 +1,7 @@
1
+ omegaconf == 2.3.0
2
+ diffusers == 0.23.1
3
+ safetensors == 0.4.1
4
+ huggingface_hub == 0.19.4
5
+ transformers == 4.35.2
6
+ accelerate == 0.25.0.dev0
7
+ kiui == 0.2.0
apps/third_party/LGM/requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ omegaconf
2
+ diffusers
3
+ safetensors
4
+ huggingface_hub
5
+ transformers
6
+ accelerate
7
+ kiui
8
+ einops
9
+ rich
apps/third_party/LGM/run_imagedream.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import kiui
3
+ import numpy as np
4
+ import argparse
5
+ from pipeline_mvdream import MVDreamPipeline
6
+ import ipdb
7
+ pipe = MVDreamPipeline.from_pretrained(
8
+ # "./weights_imagedream", # local weights
9
+ "/mnt/cfs/home/liweiyu/codes/3DNativeGeneration/ckpts/pretrained_weights/huggingface/hub/models--ashawkey--imagedream-ipmv-diffusers/snapshots/73a034178e748421506492e91790cc62d6aefef5", # HF hub weights (cached snapshot of ashawkey/imagedream-ipmv-diffusers)
10
+ torch_dtype=torch.float16,
11
+ trust_remote_code=True,
12
+ )
13
+ pipe = pipe.to("cuda")
14
+
15
+
16
+ parser = argparse.ArgumentParser(description="ImageDream")
17
+ parser.add_argument("image", type=str, nargs='?', default='data/anya_rgba.png')
18
+ parser.add_argument("--prompt", type=str, default="")
19
+ args = parser.parse_args()
20
+
21
+ for i in range(5):
22
+ input_image = kiui.read_image(args.image, mode='float')
23
+ image = pipe(args.prompt, input_image, guidance_scale=5, num_inference_steps=30, elevation=0)
24
+ # ipdb.set_trace()  # optional: break here to inspect the generated views
25
+ # print(image)
26
+ grid = np.concatenate(
27
+ [
28
+ np.concatenate([image[0], image[2]], axis=0),
29
+ np.concatenate([image[1], image[3]], axis=0),
30
+ ],
31
+ axis=1,
32
+ )
33
+ # kiui.vis.plot_image(grid)
34
+ kiui.write_image(f'test_imagedream_{i}.jpg', grid)
apps/third_party/LGM/run_mvdream.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch
2
+ import kiui
3
+ import numpy as np
4
+ import argparse
5
+ from pipeline_mvdream import MVDreamPipeline
6
+
7
+ import ipdb
8
+ pipe = MVDreamPipeline.from_pretrained(
9
+ # "./weights_mvdream", # local weights
10
+ '/mnt/cfs/home/liweiyu/codes/3DNativeGeneration/ckpts/pretrained_weights/huggingface/hub/models--ashawkey--mvdream-sd2.1-diffusers/snapshots/503bb19fc2b2bc542c2afdb7d73ac87a7cbc2253', # HF hub weights (cached snapshot of ashawkey/mvdream-sd2.1-diffusers)
11
+ torch_dtype=torch.float16,
12
+ # trust_remote_code=True,
13
+ )
14
+
15
+ pipe = pipe.to("cuda")
16
+
17
+
18
+ parser = argparse.ArgumentParser(description="MVDream")
19
+ parser.add_argument("prompt", type=str, nargs='?', default="a cute owl 3d model")
20
+ args = parser.parse_args()
21
+
22
+ for i in range(5):
23
+ image = pipe(args.prompt, guidance_scale=5, num_inference_steps=30, elevation=0)
24
+ # ipdb.set_trace()  # optional: break here to inspect the generated views
25
+ grid = np.concatenate(
26
+ [
27
+ np.concatenate([image[0], image[2]], axis=0),
28
+ np.concatenate([image[1], image[3]], axis=0),
29
+ ],
30
+ axis=1,
31
+ )
32
+ # kiui.vis.plot_image(grid)
33
+ kiui.write_image(f'test_mvdream_{i}.jpg', grid)
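The 2x2 grid assembly used in both run scripts can also be written with `einops` (already listed in requirements.txt). A small sketch, assuming `image` is the default numpy output of shape `[4, H, W, 3]`:

```python
import numpy as np
from einops import rearrange

def make_grid(views: np.ndarray) -> np.ndarray:
    """Arrange 4 views of shape [4, H, W, 3] into the same 2x2 layout as the scripts above."""
    return rearrange(views, "(r c) h w ch -> (r h) (c w) ch", r=2, c=2)

# grid = make_grid(image)  # drop-in replacement for the np.concatenate block
```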