ashawkey committed on
Commit af95a32
Parent: 8241d5f

draft imagedream, before merge

README.md CHANGED
@@ -8,11 +8,22 @@ modified from https://github.com/KokeCacao/mvdream-hf.
  pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
 
  # download original ckpt
+ cd models
  wget https://huggingface.co/MVDream/MVDream/resolve/main/sd-v2.1-base-4view.pt
  wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd-v2-base.yaml
+ cd ..
 
  # convert
- python convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view.pt --dump_path ./weights --original_config_file ./sd-v2-base.yaml --half --to_safetensors --test
+ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view.pt --dump_path ./weights_mvdream --original_config_file models/sd-v2-base.yaml --half --to_safetensors --test
+ ```
+
+ ```bash
+ # download original ckpt
+ wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv-local.pt
+ wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv_local.yaml
+
+ # convert
+ python convert_imagedream_to_diffusers.py --checkpoint_path ./sd-v2.1-base-4view-ipmv-local.pt --dump_path ./weights_imagedream --original_config_file ./sd_v2_base_ipmv_local.yaml --half --to_safetensors --test
  ```
 
  ### usage
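For the converted ImageDream weights, a minimal loading sketch (not part of this commit): the paths follow the convert command above, `input.png` is an assumed local reference image, and the draft pipeline below expects a float32 RGB numpy array.

```python
# Sketch only: assumes ./weights_imagedream exists (the --dump_path above) and that
# an RGB reference image is available at ./input.png.
import numpy as np
import torch
from PIL import Image
from imagedream.pipeline_imagedream import ImageDreamPipeline

pipe = ImageDreamPipeline.from_pretrained("./weights_imagedream", torch_dtype=torch.float16)

# The draft __call__ asserts a float32 numpy image; the [0, 1] range matches encode_image().
image = np.asarray(Image.open("input.png").convert("RGB").resize((256, 256)), dtype=np.float32) / 255.0

views = pipe(image, prompt="a car", guidance_scale=7.5, num_inference_steps=50, output_type="pil")
for i, view in enumerate(views):
    view.save(f"view_{i}.png")
```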
convert_imagedream_to_diffusers.py ADDED
@@ -0,0 +1,561 @@
1
+ # Modified from https://github.com/huggingface/diffusers/blob/bc691231360a4cbc7d19a58742ebb8ed0f05e027/scripts/convert_original_stable_diffusion_to_diffusers.py
2
+
3
+ import argparse
4
+ import torch
5
+ import sys
6
+
7
+ sys.path.insert(0, ".")
8
+
9
+ from diffusers.models import (
10
+ AutoencoderKL,
11
+ )
12
+ from omegaconf import OmegaConf
13
+ from diffusers.schedulers import DDIMScheduler
14
+ from diffusers.utils import logging
15
+ from typing import Any
16
+ from accelerate import init_empty_weights
17
+ from accelerate.utils import set_module_tensor_to_device
18
+ from imagedream.models import MultiViewUNetModel
19
+ from imagedream.pipeline_imagedream import ImageDreamPipeline
20
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPFeatureExtractor
21
+
22
+ logger = logging.get_logger(__name__)
23
+
24
+
25
+ def assign_to_checkpoint(
26
+ paths,
27
+ checkpoint,
28
+ old_checkpoint,
29
+ attention_paths_to_split=None,
30
+ additional_replacements=None,
31
+ config=None,
32
+ ):
33
+ """
34
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
35
+ attention layers, and takes into account additional replacements that may arise.
36
+ Assigns the weights to the new checkpoint.
37
+ """
38
+ assert isinstance(
39
+ paths, list
40
+ ), "Paths should be a list of dicts containing 'old' and 'new' keys."
41
+
42
+ # Splits the attention layers into three variables.
43
+ if attention_paths_to_split is not None:
44
+ for path, path_map in attention_paths_to_split.items():
45
+ old_tensor = old_checkpoint[path]
46
+ channels = old_tensor.shape[0] // 3
47
+
48
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
49
+
50
+ assert config is not None
51
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
52
+
53
+ old_tensor = old_tensor.reshape(
54
+ (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
55
+ )
56
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
57
+
58
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
59
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
60
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
61
+
62
+ for path in paths:
63
+ new_path = path["new"]
64
+
65
+ # These have already been assigned
66
+ if (
67
+ attention_paths_to_split is not None
68
+ and new_path in attention_paths_to_split
69
+ ):
70
+ continue
71
+
72
+ # Global renaming happens here
73
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
74
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
75
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
76
+
77
+ if additional_replacements is not None:
78
+ for replacement in additional_replacements:
79
+ new_path = new_path.replace(replacement["old"], replacement["new"])
80
+
81
+ # proj_attn.weight has to be converted from conv 1D to linear
82
+ is_attn_weight = "proj_attn.weight" in new_path or (
83
+ "attentions" in new_path and "to_" in new_path
84
+ )
85
+ shape = old_checkpoint[path["old"]].shape
86
+ if is_attn_weight and len(shape) == 3:
87
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
88
+ elif is_attn_weight and len(shape) == 4:
89
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
90
+ else:
91
+ checkpoint[new_path] = old_checkpoint[path["old"]]
92
+
93
+
94
+ def shave_segments(path, n_shave_prefix_segments=1):
95
+ """
96
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
97
+ """
98
+ if n_shave_prefix_segments >= 0:
99
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
100
+ else:
101
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
102
+
103
+
104
+ def create_vae_diffusers_config(original_config, image_size: int):
105
+ """
106
+ Creates a config for the diffusers based on the config of the LDM model.
107
+ """
108
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
109
+ _ = original_config.model.params.first_stage_config.params.embed_dim
110
+
111
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
112
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
113
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
114
+
115
+ config = {
116
+ "sample_size": image_size,
117
+ "in_channels": vae_params.in_channels,
118
+ "out_channels": vae_params.out_ch,
119
+ "down_block_types": tuple(down_block_types),
120
+ "up_block_types": tuple(up_block_types),
121
+ "block_out_channels": tuple(block_out_channels),
122
+ "latent_channels": vae_params.z_channels,
123
+ "layers_per_block": vae_params.num_res_blocks,
124
+ }
125
+ return config
126
+
127
+
128
+ def convert_ldm_vae_checkpoint(checkpoint, config):
129
+ # extract state dict for VAE
130
+ vae_state_dict = {}
131
+ vae_key = "first_stage_model."
132
+ keys = list(checkpoint.keys())
133
+ for key in keys:
134
+ if key.startswith(vae_key):
135
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
136
+
137
+ new_checkpoint = {}
138
+
139
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
140
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
141
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
142
+ "encoder.conv_out.weight"
143
+ ]
144
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
145
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
146
+ "encoder.norm_out.weight"
147
+ ]
148
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
149
+ "encoder.norm_out.bias"
150
+ ]
151
+
152
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
153
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
154
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
155
+ "decoder.conv_out.weight"
156
+ ]
157
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
158
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
159
+ "decoder.norm_out.weight"
160
+ ]
161
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
162
+ "decoder.norm_out.bias"
163
+ ]
164
+
165
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
166
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
167
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
168
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
169
+
170
+ # Retrieves the keys for the encoder down blocks only
171
+ num_down_blocks = len(
172
+ {
173
+ ".".join(layer.split(".")[:3])
174
+ for layer in vae_state_dict
175
+ if "encoder.down" in layer
176
+ }
177
+ )
178
+ down_blocks = {
179
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
180
+ for layer_id in range(num_down_blocks)
181
+ }
182
+
183
+ # Retrieves the keys for the decoder up blocks only
184
+ num_up_blocks = len(
185
+ {
186
+ ".".join(layer.split(".")[:3])
187
+ for layer in vae_state_dict
188
+ if "decoder.up" in layer
189
+ }
190
+ )
191
+ up_blocks = {
192
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
193
+ for layer_id in range(num_up_blocks)
194
+ }
195
+
196
+ for i in range(num_down_blocks):
197
+ resnets = [
198
+ key
199
+ for key in down_blocks[i]
200
+ if f"down.{i}" in key and f"down.{i}.downsample" not in key
201
+ ]
202
+
203
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
204
+ new_checkpoint[
205
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
206
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
207
+ new_checkpoint[
208
+ f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
209
+ ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
210
+
211
+ paths = renew_vae_resnet_paths(resnets)
212
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
213
+ assign_to_checkpoint(
214
+ paths,
215
+ new_checkpoint,
216
+ vae_state_dict,
217
+ additional_replacements=[meta_path],
218
+ config=config,
219
+ )
220
+
221
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
222
+ num_mid_res_blocks = 2
223
+ for i in range(1, num_mid_res_blocks + 1):
224
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
225
+
226
+ paths = renew_vae_resnet_paths(resnets)
227
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
228
+ assign_to_checkpoint(
229
+ paths,
230
+ new_checkpoint,
231
+ vae_state_dict,
232
+ additional_replacements=[meta_path],
233
+ config=config,
234
+ )
235
+
236
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
237
+ paths = renew_vae_attention_paths(mid_attentions)
238
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
239
+ assign_to_checkpoint(
240
+ paths,
241
+ new_checkpoint,
242
+ vae_state_dict,
243
+ additional_replacements=[meta_path],
244
+ config=config,
245
+ )
246
+ conv_attn_to_linear(new_checkpoint)
247
+
248
+ for i in range(num_up_blocks):
249
+ block_id = num_up_blocks - 1 - i
250
+ resnets = [
251
+ key
252
+ for key in up_blocks[block_id]
253
+ if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
254
+ ]
255
+
256
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
257
+ new_checkpoint[
258
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
259
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
260
+ new_checkpoint[
261
+ f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
262
+ ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
263
+
264
+ paths = renew_vae_resnet_paths(resnets)
265
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
266
+ assign_to_checkpoint(
267
+ paths,
268
+ new_checkpoint,
269
+ vae_state_dict,
270
+ additional_replacements=[meta_path],
271
+ config=config,
272
+ )
273
+
274
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
275
+ num_mid_res_blocks = 2
276
+ for i in range(1, num_mid_res_blocks + 1):
277
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
278
+
279
+ paths = renew_vae_resnet_paths(resnets)
280
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
281
+ assign_to_checkpoint(
282
+ paths,
283
+ new_checkpoint,
284
+ vae_state_dict,
285
+ additional_replacements=[meta_path],
286
+ config=config,
287
+ )
288
+
289
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
290
+ paths = renew_vae_attention_paths(mid_attentions)
291
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
292
+ assign_to_checkpoint(
293
+ paths,
294
+ new_checkpoint,
295
+ vae_state_dict,
296
+ additional_replacements=[meta_path],
297
+ config=config,
298
+ )
299
+ conv_attn_to_linear(new_checkpoint)
300
+ return new_checkpoint
301
+
302
+
303
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
304
+ """
305
+ Updates paths inside resnets to the new naming scheme (local renaming)
306
+ """
307
+ mapping = []
308
+ for old_item in old_list:
309
+ new_item = old_item
310
+
311
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
312
+ new_item = shave_segments(
313
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
314
+ )
315
+
316
+ mapping.append({"old": old_item, "new": new_item})
317
+
318
+ return mapping
319
+
320
+
321
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
322
+ """
323
+ Updates paths inside attentions to the new naming scheme (local renaming)
324
+ """
325
+ mapping = []
326
+ for old_item in old_list:
327
+ new_item = old_item
328
+
329
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
330
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
331
+
332
+ new_item = new_item.replace("q.weight", "to_q.weight")
333
+ new_item = new_item.replace("q.bias", "to_q.bias")
334
+
335
+ new_item = new_item.replace("k.weight", "to_k.weight")
336
+ new_item = new_item.replace("k.bias", "to_k.bias")
337
+
338
+ new_item = new_item.replace("v.weight", "to_v.weight")
339
+ new_item = new_item.replace("v.bias", "to_v.bias")
340
+
341
+ new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
342
+ new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
343
+
344
+ new_item = shave_segments(
345
+ new_item, n_shave_prefix_segments=n_shave_prefix_segments
346
+ )
347
+
348
+ mapping.append({"old": old_item, "new": new_item})
349
+
350
+ return mapping
351
+
352
+
353
+ def conv_attn_to_linear(checkpoint):
354
+ keys = list(checkpoint.keys())
355
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
356
+ for key in keys:
357
+ if ".".join(key.split(".")[-2:]) in attn_keys:
358
+ if checkpoint[key].ndim > 2:
359
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
360
+ elif "proj_attn.weight" in key:
361
+ if checkpoint[key].ndim > 2:
362
+ checkpoint[key] = checkpoint[key][:, :, 0]
363
+
364
+
365
+ def create_unet_config(original_config) -> Any:
366
+ return OmegaConf.to_container(
367
+ original_config.model.params.unet_config.params, resolve=True
368
+ )
369
+
370
+
371
+ def convert_from_original_imagedream_ckpt(checkpoint_path, original_config_file, device):
372
+ checkpoint = torch.load(checkpoint_path, map_location=device)
373
+ # print(f"Checkpoint: {checkpoint.keys()}")
374
+ torch.cuda.empty_cache()
375
+
376
+ original_config = OmegaConf.load(original_config_file)
377
+ # print(f"Original Config: {original_config}")
378
+ prediction_type = "epsilon"
379
+ image_size = 256
380
+ num_train_timesteps = (
381
+ getattr(original_config.model.params, "timesteps", None) or 1000
382
+ )
383
+ beta_start = getattr(original_config.model.params, "linear_start", None) or 0.00085  # SD 2.x base default
+ beta_end = getattr(original_config.model.params, "linear_end", None) or 0.012  # SD 2.x base default
385
+ scheduler = DDIMScheduler(
386
+ beta_end=beta_end,
387
+ beta_schedule="scaled_linear",
388
+ beta_start=beta_start,
389
+ num_train_timesteps=num_train_timesteps,
390
+ steps_offset=1,
391
+ clip_sample=False,
392
+ set_alpha_to_one=False,
393
+ prediction_type=prediction_type,
394
+ )
395
+ scheduler.register_to_config(clip_sample=False)
396
+
397
+ # Convert the UNet2DConditionModel model.
398
+ # upcast_attention = None
399
+ # unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
400
+ # unet_config["upcast_attention"] = upcast_attention
401
+ # with init_empty_weights():
402
+ # unet = UNet2DConditionModel(**unet_config)
403
+ # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
404
+ # checkpoint, unet_config, path=None, extract_ema=extract_ema
405
+ # )
406
+ # print(f"Unet Config: {original_config.model.params.unet_config.params}")
407
+ unet_config = create_unet_config(original_config)
408
+
409
+ # remove unused configs
410
+ del unet_config['legacy']
411
+ del unet_config['use_linear_in_transformer']
412
+ del unet_config['use_spatial_transformer']
413
+ del unet_config['ip_mode']
414
+
415
+ unet = MultiViewUNetModel(**unet_config)
416
+ unet.register_to_config(**unet_config)
417
+ # print(f"Unet State Dict: {unet.state_dict().keys()}")
418
+ unet.load_state_dict(
419
+ {
420
+ key.replace("model.diffusion_model.", ""): value
421
+ for key, value in checkpoint.items()
422
+ if key.replace("model.diffusion_model.", "") in unet.state_dict()
423
+ }
424
+ )
425
+ for param_name, param in unet.state_dict().items():
426
+ set_module_tensor_to_device(unet, param_name, device=device, value=param)
427
+
428
+ # Convert the VAE model.
429
+ vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
430
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
431
+
432
+ if (
433
+ "model" in original_config
434
+ and "params" in original_config.model
435
+ and "scale_factor" in original_config.model.params
436
+ ):
437
+ vae_scaling_factor = original_config.model.params.scale_factor
438
+ else:
439
+ vae_scaling_factor = 0.18215 # default SD scaling factor
440
+
441
+ vae_config["scaling_factor"] = vae_scaling_factor
442
+
443
+ with init_empty_weights():
444
+ vae = AutoencoderKL(**vae_config)
445
+
446
+ for param_name, param in converted_vae_checkpoint.items():
447
+ set_module_tensor_to_device(vae, param_name, device=device, value=param)
448
+
449
+
450
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
451
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
452
+
453
+ # OpenCLIP ViT-H/14, the CLIP variant behind SD 2.1; reused here as the image encoder
454
+ feature_extractor: CLIPFeatureExtractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
455
+ image_encoder: CLIPVisionModel = CLIPVisionModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
456
+
457
+ pipe = ImageDreamPipeline(
458
+ vae=vae,
459
+ unet=unet,
460
+ tokenizer=tokenizer,
461
+ text_encoder=text_encoder,
462
+ scheduler=scheduler,
463
+ feature_extractor=feature_extractor,
464
+ image_encoder=image_encoder,
465
+ )
466
+
467
+ return pipe
468
+
469
+
470
+ if __name__ == "__main__":
471
+ parser = argparse.ArgumentParser()
472
+
473
+ parser.add_argument(
474
+ "--checkpoint_path",
475
+ default=None,
476
+ type=str,
477
+ required=True,
478
+ help="Path to the checkpoint to convert.",
479
+ )
480
+ parser.add_argument(
481
+ "--original_config_file",
482
+ default=None,
483
+ type=str,
484
+ help="The YAML config file corresponding to the original architecture.",
485
+ )
486
+ parser.add_argument(
487
+ "--to_safetensors",
488
+ action="store_true",
489
+ help="Whether to store pipeline in safetensors format or not.",
490
+ )
491
+ parser.add_argument(
492
+ "--half", action="store_true", help="Save weights in half precision."
493
+ )
494
+ parser.add_argument(
495
+ "--test",
496
+ action="store_true",
497
+ help="Whether to test inference after convertion.",
498
+ )
499
+ parser.add_argument(
500
+ "--dump_path",
501
+ default=None,
502
+ type=str,
503
+ required=True,
504
+ help="Path to the output model.",
505
+ )
506
+ parser.add_argument(
507
+ "--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
508
+ )
509
+ args = parser.parse_args()
510
+
511
+ args.device = torch.device(
512
+ args.device
513
+ if args.device is not None
514
+ else "cuda"
515
+ if torch.cuda.is_available()
516
+ else "cpu"
517
+ )
518
+
519
+ pipe = convert_from_original_imagedream_ckpt(
520
+ checkpoint_path=args.checkpoint_path,
521
+ original_config_file=args.original_config_file,
522
+ device=args.device,
523
+ )
524
+
525
+ if args.half:
526
+ pipe.to(torch_dtype=torch.float16)
527
+
528
+ print(f"Saving pipeline to {args.dump_path}...")
529
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
530
+
531
+ # TODO: input image...
532
+ if args.test:
533
+ try:
534
+ print(f"Testing each subcomponent of the pipeline...")
535
+ images = pipe(
536
+ prompt="Head of Hatsune Miku",
537
+ negative_prompt="painting, bad quality, flat",
538
+ output_type="pil",
539
+ guidance_scale=7.5,
540
+ num_inference_steps=50,
541
+ device=args.device,
542
+ )
543
+ for i, image in enumerate(images):
544
+ image.save(f"image_{i}.png") # type: ignore
545
+
546
+ print(f"Testing entire pipeline...")
547
+ loaded_pipe = ImageDreamPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
548
+ images = loaded_pipe(
549
+ prompt="Head of Hatsune Miku",
550
+ negative_prompt="painting, bad quality, flat",
551
+ output_type="pil",
552
+ guidance_scale=7.5,
553
+ num_inference_steps=50,
554
+ device=args.device,
555
+ )
556
+ for i, image in enumerate(images):
557
+ image.save(f"image_{i}.png") # type: ignore
558
+ except Exception as e:
559
+ print(f"Failed to test inference: {e}")
560
+ raise e from e
561
+ print("Inference test passed!")
imagedream/adaptor.py ADDED
@@ -0,0 +1,141 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ # FFN
6
+ def FeedForward(dim, mult=4):
7
+ inner_dim = int(dim * mult)
8
+ return nn.Sequential(
9
+ nn.LayerNorm(dim),
10
+ nn.Linear(dim, inner_dim, bias=False),
11
+ nn.GELU(),
12
+ nn.Linear(inner_dim, dim, bias=False),
13
+ )
14
+
15
+
16
+ def reshape_tensor(x, heads):
17
+ bs, length, width = x.shape
18
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
19
+ x = x.view(bs, length, heads, -1)
20
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
21
+ x = x.transpose(1, 2)
22
+ # stays (bs, n_heads, length, dim_per_head); the reshape keeps the 4D layout used by the matmuls below
23
+ x = x.reshape(bs, heads, length, -1)
24
+ return x
25
+
26
+
27
+ class PerceiverAttention(nn.Module):
28
+ def __init__(self, *, dim, dim_head=64, heads=8):
29
+ super().__init__()
30
+ self.scale = dim_head ** -0.5
31
+ self.dim_head = dim_head
32
+ self.heads = heads
33
+ inner_dim = dim_head * heads
34
+
35
+ self.norm1 = nn.LayerNorm(dim)
36
+ self.norm2 = nn.LayerNorm(dim)
37
+
38
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
39
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
40
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
41
+
42
+ def forward(self, x, latents):
43
+ """
44
+ Args:
45
+ x (torch.Tensor): image features
46
+ shape (b, n1, D)
47
+ latent (torch.Tensor): latent features
48
+ shape (b, n2, D)
49
+ """
50
+ x = self.norm1(x)
51
+ latents = self.norm2(latents)
52
+
53
+ b, l, _ = latents.shape
54
+
55
+ q = self.to_q(latents)
56
+ kv_input = torch.cat((x, latents), dim=-2)
57
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
58
+
59
+ q = reshape_tensor(q, self.heads)
60
+ k = reshape_tensor(k, self.heads)
61
+ v = reshape_tensor(v, self.heads)
62
+
63
+ # attention
64
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
65
+ weight = (q * scale) @ (k * scale).transpose(
66
+ -2, -1
67
+ ) # More stable with f16 than dividing afterwards
68
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
69
+ out = weight @ v
70
+
71
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
72
+
73
+ return self.to_out(out)
74
+
75
+
76
+ class ImageProjModel(torch.nn.Module):
77
+ """Projection Model"""
78
+
79
+ def __init__(
80
+ self,
81
+ cross_attention_dim=1024,
82
+ clip_embeddings_dim=1024,
83
+ clip_extra_context_tokens=4,
84
+ ):
85
+ super().__init__()
86
+ self.cross_attention_dim = cross_attention_dim
87
+ self.clip_extra_context_tokens = clip_extra_context_tokens
88
+
89
+ # from 1024 -> 4 * 1024
90
+ self.proj = torch.nn.Linear(
91
+ clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
92
+ )
93
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
94
+
95
+ def forward(self, image_embeds):
96
+ embeds = image_embeds
97
+ clip_extra_context_tokens = self.proj(embeds).reshape(
98
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
99
+ )
100
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
101
+ return clip_extra_context_tokens
102
+
103
+
104
+ class Resampler(nn.Module):
105
+ def __init__(
106
+ self,
107
+ dim=1024,
108
+ depth=8,
109
+ dim_head=64,
110
+ heads=16,
111
+ num_queries=8,
112
+ embedding_dim=768,
113
+ output_dim=1024,
114
+ ff_mult=4,
115
+ ):
116
+ super().__init__()
117
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
118
+ self.proj_in = nn.Linear(embedding_dim, dim)
119
+ self.proj_out = nn.Linear(dim, output_dim)
120
+ self.norm_out = nn.LayerNorm(output_dim)
121
+
122
+ self.layers = nn.ModuleList([])
123
+ for _ in range(depth):
124
+ self.layers.append(
125
+ nn.ModuleList(
126
+ [
127
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
128
+ FeedForward(dim=dim, mult=ff_mult),
129
+ ]
130
+ )
131
+ )
132
+
133
+ def forward(self, x):
134
+ latents = self.latents.repeat(x.size(0), 1, 1)
135
+ x = self.proj_in(x)
136
+ for attn, ff in self.layers:
137
+ latents = attn(x, latents) + latents
138
+ latents = ff(latents) + latents
139
+
140
+ latents = self.proj_out(latents)
141
+ return self.norm_out(latents)
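A small shape check for the Resampler as models.py instantiates it (not part of the commit; the 257 x 1280 input mimics CLIP ViT-H penultimate hidden states):

```python
import torch
from imagedream.adaptor import Resampler

# Mirrors the call in MultiViewUNetModel.__init__ (context_dim = 1024, ip_dim = 16).
resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12,
                      num_queries=16, embedding_dim=1280, output_dim=1024, ff_mult=4)

clip_hidden = torch.randn(2, 257, 1280)  # CLIP ViT-H penultimate hidden states, batch of 2
tokens = resampler(clip_hidden)
print(tokens.shape)  # torch.Size([2, 16, 1024]): 16 image tokens per sample
```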
imagedream/attention.py CHANGED
@@ -1,26 +1,16 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- from torch.amp.autocast_mode import autocast
5
 
6
  from inspect import isfunction
7
  from einops import rearrange, repeat
8
  from typing import Optional, Any
9
- from .util import checkpoint, zero_module
10
-
11
- try:
12
- import xformers # type: ignore
13
- import xformers.ops # type: ignore
14
- XFORMERS_IS_AVAILBLE = True
15
- except:
16
- print(f'[WARN] xformers is unavailable!')
17
- XFORMERS_IS_AVAILBLE = False
18
 
19
- # CrossAttn precision handling
20
- import os
21
-
22
- _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
23
 
 
24
 
25
  def default(val, d):
26
  if val is not None:
@@ -57,68 +47,35 @@ class FeedForward(nn.Module):
57
  return self.net(x)
58
 
59
 
60
- class CrossAttention(nn.Module):
61
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
62
- super().__init__()
63
- inner_dim = dim_head * heads
64
- context_dim = default(context_dim, query_dim)
65
-
66
- self.scale = dim_head**-0.5
67
- self.heads = heads
68
-
69
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
70
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
71
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
72
-
73
- self.to_out = nn.Sequential(
74
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
75
- )
76
-
77
- def forward(self, x, context=None, mask=None):
78
- h = self.heads
79
-
80
- q = self.to_q(x)
81
- context = default(context, x)
82
- k = self.to_k(context)
83
- v = self.to_v(context)
84
-
85
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
86
-
87
- # force cast to fp32 to avoid overflowing
88
- if _ATTN_PRECISION == "fp32":
89
- with autocast(enabled=False, device_type="cuda"):
90
- q, k = q.float(), k.float()
91
- sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
92
- else:
93
- sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
94
-
95
- del q, k
96
-
97
- if mask is not None:
98
- mask = rearrange(mask, "b ... -> b (...)")
99
- max_neg_value = -torch.finfo(sim.dtype).max
100
- mask = repeat(mask, "b j -> (b h) () j", h=h)
101
- sim.masked_fill_(~mask, max_neg_value)
102
-
103
- # attention, what we cannot get enough of
104
- sim = sim.softmax(dim=-1)
105
-
106
- out = torch.einsum("b i j, b j d -> b i d", sim, v)
107
- out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
108
- return self.to_out(out)
109
-
110
-
111
  class MemoryEfficientCrossAttention(nn.Module):
112
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
113
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
114
  super().__init__()
115
- # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads.")
116
  inner_dim = dim_head * heads
117
  context_dim = default(context_dim, query_dim)
118
 
119
  self.heads = heads
120
  self.dim_head = dim_head
121
 
122
  self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
123
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
124
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
@@ -128,9 +85,18 @@ class MemoryEfficientCrossAttention(nn.Module):
128
  )
129
  self.attention_op: Optional[Any] = None
130
 
131
- def forward(self, x, context=None, mask=None):
132
  q = self.to_q(x)
133
  context = default(context, x)
134
  k = self.to_k(context)
135
  v = self.to_v(context)
136
 
@@ -149,8 +115,21 @@ class MemoryEfficientCrossAttention(nn.Module):
149
  q, k, v, attn_bias=None, op=self.attention_op
150
  )
151
 
152
- if mask is not None:
153
- raise NotImplementedError
154
  out = (
155
  out.unsqueeze(0)
156
  .reshape(b, self.heads, out.shape[1], self.dim_head)
@@ -160,148 +139,47 @@ class MemoryEfficientCrossAttention(nn.Module):
160
  return self.to_out(out)
161
 
162
 
163
- class BasicTransformerBlock(nn.Module):
164
- ATTENTION_MODES = {
165
- "softmax": CrossAttention,
166
- "softmax-xformers": MemoryEfficientCrossAttention,
167
- } # vanilla attention
168
-
169
  def __init__(
170
  self,
171
  dim,
 
172
  n_heads,
173
  d_head,
174
  dropout=0.0,
175
- context_dim=None,
176
  gated_ff=True,
177
  checkpoint=True,
178
- disable_self_attn=False,
 
 
179
  ):
180
  super().__init__()
181
- attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
182
- assert attn_mode in self.ATTENTION_MODES
183
- attn_cls = self.ATTENTION_MODES[attn_mode]
184
- self.disable_self_attn = disable_self_attn
185
- self.attn1 = attn_cls(
186
  query_dim=dim,
 
187
  heads=n_heads,
188
  dim_head=d_head,
189
  dropout=dropout,
190
- context_dim=context_dim if self.disable_self_attn else None,
191
- ) # is a self-attention if not self.disable_self_attn
192
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
193
- self.attn2 = attn_cls(
194
  query_dim=dim,
195
  context_dim=context_dim,
196
  heads=n_heads,
197
  dim_head=d_head,
198
  dropout=dropout,
199
- ) # is self-attn if context is none
 
 
 
 
200
  self.norm1 = nn.LayerNorm(dim)
201
  self.norm2 = nn.LayerNorm(dim)
202
  self.norm3 = nn.LayerNorm(dim)
203
  self.checkpoint = checkpoint
204
 
205
- def forward(self, x, context=None):
206
- return checkpoint(
207
- self._forward, (x, context), self.parameters(), self.checkpoint
208
- )
209
-
210
- def _forward(self, x, context=None):
211
- x = (
212
- self.attn1(
213
- self.norm1(x), context=context if self.disable_self_attn else None
214
- )
215
- + x
216
- )
217
- x = self.attn2(self.norm2(x), context=context) + x
218
- x = self.ff(self.norm3(x)) + x
219
- return x
220
-
221
-
222
- class SpatialTransformer(nn.Module):
223
- """
224
- Transformer block for image-like data.
225
- First, project the input (aka embedding)
226
- and reshape to b, t, d.
227
- Then apply standard transformer action.
228
- Finally, reshape to image
229
- NEW: use_linear for more efficiency instead of the 1x1 convs
230
- """
231
-
232
- def __init__(
233
- self,
234
- in_channels,
235
- n_heads,
236
- d_head,
237
- depth=1,
238
- dropout=0.0,
239
- context_dim=None,
240
- disable_self_attn=False,
241
- use_linear=False,
242
- use_checkpoint=True,
243
- ):
244
- super().__init__()
245
- assert context_dim is not None
246
- if not isinstance(context_dim, list):
247
- context_dim = [context_dim]
248
- self.in_channels = in_channels
249
- inner_dim = n_heads * d_head
250
- self.norm = nn.GroupNorm(
251
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
252
- )
253
- if not use_linear:
254
- self.proj_in = nn.Conv2d(
255
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
256
- )
257
- else:
258
- self.proj_in = nn.Linear(in_channels, inner_dim)
259
-
260
- self.transformer_blocks = nn.ModuleList(
261
- [
262
- BasicTransformerBlock(
263
- inner_dim,
264
- n_heads,
265
- d_head,
266
- dropout=dropout,
267
- context_dim=context_dim[d],
268
- disable_self_attn=disable_self_attn,
269
- checkpoint=use_checkpoint,
270
- )
271
- for d in range(depth)
272
- ]
273
- )
274
- if not use_linear:
275
- self.proj_out = zero_module(
276
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
277
- )
278
- else:
279
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
280
- self.use_linear = use_linear
281
-
282
- def forward(self, x, context=None):
283
- # note: if no context is given, cross-attention defaults to self-attention
284
- if not isinstance(context, list):
285
- context = [context]
286
- b, c, h, w = x.shape
287
- x_in = x
288
- x = self.norm(x)
289
- if not self.use_linear:
290
- x = self.proj_in(x)
291
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
292
- if self.use_linear:
293
- x = self.proj_in(x)
294
- for i, block in enumerate(self.transformer_blocks):
295
- x = block(x, context=context[i])
296
- if self.use_linear:
297
- x = self.proj_out(x)
298
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
299
- if not self.use_linear:
300
- x = self.proj_out(x)
301
- return x + x_in
302
-
303
-
304
- class BasicTransformerBlock3D(BasicTransformerBlock):
305
  def forward(self, x, context=None, num_frames=1):
306
  return checkpoint(
307
  self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
@@ -309,12 +187,7 @@ class BasicTransformerBlock3D(BasicTransformerBlock):
309
 
310
  def _forward(self, x, context=None, num_frames=1):
311
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
312
- x = (
313
- self.attn1(
314
- self.norm1(x), context=context if self.disable_self_attn else None
315
- )
316
- + x
317
- )
318
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
319
  x = self.attn2(self.norm2(x), context=context) + x
320
  x = self.ff(self.norm3(x)) + x
@@ -322,35 +195,32 @@ class BasicTransformerBlock3D(BasicTransformerBlock):
322
 
323
 
324
  class SpatialTransformer3D(nn.Module):
325
- """3D self-attention"""
326
 
327
  def __init__(
328
  self,
329
  in_channels,
330
  n_heads,
331
  d_head,
 
332
  depth=1,
333
  dropout=0.0,
334
- context_dim=None,
335
- disable_self_attn=False,
336
- use_linear=True,
337
  use_checkpoint=True,
338
  ):
339
  super().__init__()
340
- assert context_dim is not None
341
  if not isinstance(context_dim, list):
342
  context_dim = [context_dim]
 
343
  self.in_channels = in_channels
 
344
  inner_dim = n_heads * d_head
345
  self.norm = nn.GroupNorm(
346
  num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
347
  )
348
- if not use_linear:
349
- self.proj_in = nn.Conv2d(
350
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
351
- )
352
- else:
353
- self.proj_in = nn.Linear(in_channels, inner_dim)
354
 
355
  self.transformer_blocks = nn.ModuleList(
356
  [
@@ -358,21 +228,19 @@ class SpatialTransformer3D(nn.Module):
358
  inner_dim,
359
  n_heads,
360
  d_head,
361
- dropout=dropout,
362
  context_dim=context_dim[d],
363
- disable_self_attn=disable_self_attn,
364
  checkpoint=use_checkpoint,
 
 
 
365
  )
366
  for d in range(depth)
367
  ]
368
  )
369
- if not use_linear:
370
- self.proj_out = zero_module(
371
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
372
- )
373
- else:
374
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
375
- self.use_linear = use_linear
376
 
377
  def forward(self, x, context=None, num_frames=1):
378
  # note: if no context is given, cross-attention defaults to self-attention
@@ -381,16 +249,11 @@ class SpatialTransformer3D(nn.Module):
381
  b, c, h, w = x.shape
382
  x_in = x
383
  x = self.norm(x)
384
- if not self.use_linear:
385
- x = self.proj_in(x)
386
  x = rearrange(x, "b c h w -> b (h w) c").contiguous()
387
- if self.use_linear:
388
- x = self.proj_in(x)
389
  for i, block in enumerate(self.transformer_blocks):
390
  x = block(x, context=context[i], num_frames=num_frames)
391
- if self.use_linear:
392
- x = self.proj_out(x)
393
  x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
394
- if not self.use_linear:
395
- x = self.proj_out(x)
396
  return x + x_in
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
  from inspect import isfunction
6
  from einops import rearrange, repeat
7
  from typing import Optional, Any
 
8
 
9
+ # requires xformers (hard dependency for the memory-efficient attention below)
10
+ import xformers # type: ignore
11
+ import xformers.ops # type: ignore
 
12
 
13
+ from .util import checkpoint, zero_module
14
 
15
  def default(val, d):
16
  if val is not None:
 
47
  return self.net(x)
48
 
49
 
 
50
  class MemoryEfficientCrossAttention(nn.Module):
51
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
52
+ def __init__(
53
+ self,
54
+ query_dim,
55
+ context_dim=None,
56
+ heads=8,
57
+ dim_head=64,
58
+ dropout=0.0,
59
+ with_ip=False,
60
+ ip_dim=16,
61
+ ip_weight=1,
62
+ ):
63
  super().__init__()
64
+
65
  inner_dim = dim_head * heads
66
  context_dim = default(context_dim, query_dim)
67
 
68
  self.heads = heads
69
  self.dim_head = dim_head
70
 
71
+ self.with_ip = with_ip and (context_dim is not None)
72
+ self.ip_dim = ip_dim
73
+ self.ip_weight = ip_weight
74
+
75
+ if self.with_ip:
76
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
77
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
78
+
79
  self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
80
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
81
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
85
  )
86
  self.attention_op: Optional[Any] = None
87
 
88
+ def forward(self, x, context=None):
89
  q = self.to_q(x)
90
  context = default(context, x)
91
+
92
+ if self.with_ip:
93
+ # context dim [(b frame_num), (77 + img_token), 1024]
94
+ token_len = context.shape[1]
95
+ context_ip = context[:, -self.ip_dim :, :]
96
+ k_ip = self.to_k_ip(context_ip)
97
+ v_ip = self.to_v_ip(context_ip)
98
+ context = context[:, : (token_len - self.ip_dim), :]
99
+
100
  k = self.to_k(context)
101
  v = self.to_v(context)
102
 
 
115
  q, k, v, attn_bias=None, op=self.attention_op
116
  )
117
 
118
+ if self.with_ip:
119
+ k_ip, v_ip = map(
120
+ lambda t: t.unsqueeze(3)
121
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
122
+ .permute(0, 2, 1, 3)
123
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
124
+ .contiguous(),
125
+ (k_ip, v_ip),
126
+ )
127
+ # actually compute the attention, what we cannot get enough of
128
+ out_ip = xformers.ops.memory_efficient_attention(
129
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
130
+ )
131
+ out = out + self.ip_weight * out_ip
132
+
133
  out = (
134
  out.unsqueeze(0)
135
  .reshape(b, self.heads, out.shape[1], self.dim_head)
 
139
  return self.to_out(out)
140
 
141
 
142
+ class BasicTransformerBlock3D(nn.Module):
143
+
 
 
 
 
144
  def __init__(
145
  self,
146
  dim,
147
+ context_dim,
148
  n_heads,
149
  d_head,
150
  dropout=0.0,
 
151
  gated_ff=True,
152
  checkpoint=True,
153
+ with_ip=False,
154
+ ip_dim=16,
155
+ ip_weight=1,
156
  ):
157
  super().__init__()
158
+
159
+ self.attn1 = MemoryEfficientCrossAttention(
 
 
 
160
  query_dim=dim,
161
+ context_dim=None, # self-attention
162
  heads=n_heads,
163
  dim_head=d_head,
164
  dropout=dropout,
165
+ )
 
166
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
167
+ self.attn2 = MemoryEfficientCrossAttention(
168
  query_dim=dim,
169
  context_dim=context_dim,
170
  heads=n_heads,
171
  dim_head=d_head,
172
  dropout=dropout,
173
+ # ip only applies to cross-attention
174
+ with_ip=with_ip,
175
+ ip_dim=ip_dim,
176
+ ip_weight=ip_weight,
177
+ )
178
  self.norm1 = nn.LayerNorm(dim)
179
  self.norm2 = nn.LayerNorm(dim)
180
  self.norm3 = nn.LayerNorm(dim)
181
  self.checkpoint = checkpoint
182
 
183
  def forward(self, x, context=None, num_frames=1):
184
  return checkpoint(
185
  self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
 
187
 
188
  def _forward(self, x, context=None, num_frames=1):
189
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
190
+ x = self.attn1(self.norm1(x), context=None) + x
 
 
 
 
 
191
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
192
  x = self.attn2(self.norm2(x), context=context) + x
193
  x = self.ff(self.norm3(x)) + x
 
195
 
196
 
197
  class SpatialTransformer3D(nn.Module):
 
198
 
199
  def __init__(
200
  self,
201
  in_channels,
202
  n_heads,
203
  d_head,
204
+ context_dim, # cross attention input dim
205
  depth=1,
206
  dropout=0.0,
207
+ with_ip=False,
208
+ ip_dim=16,
209
+ ip_weight=1,
210
  use_checkpoint=True,
211
  ):
212
  super().__init__()
213
+
214
  if not isinstance(context_dim, list):
215
  context_dim = [context_dim]
216
+
217
  self.in_channels = in_channels
218
+
219
  inner_dim = n_heads * d_head
220
  self.norm = nn.GroupNorm(
221
  num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
222
  )
223
+ self.proj_in = nn.Linear(in_channels, inner_dim)
 
 
 
 
 
224
 
225
  self.transformer_blocks = nn.ModuleList(
226
  [
 
228
  inner_dim,
229
  n_heads,
230
  d_head,
 
231
  context_dim=context_dim[d],
232
+ dropout=dropout,
233
  checkpoint=use_checkpoint,
234
+ with_ip=with_ip,
235
+ ip_dim=ip_dim,
236
+ ip_weight=ip_weight,
237
  )
238
  for d in range(depth)
239
  ]
240
  )
241
+
242
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
243
+
 
 
 
 
244
 
245
  def forward(self, x, context=None, num_frames=1):
246
  # note: if no context is given, cross-attention defaults to self-attention
 
249
  b, c, h, w = x.shape
250
  x_in = x
251
  x = self.norm(x)
 
 
252
  x = rearrange(x, "b c h w -> b (h w) c").contiguous()
253
+ x = self.proj_in(x)
 
254
  for i, block in enumerate(self.transformer_blocks):
255
  x = block(x, context=context[i], num_frames=num_frames)
256
+ x = self.proj_out(x)
 
257
  x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
258
+
 
259
  return x + x_in
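To make the `with_ip` path easier to follow: the cross-attention now expects the Resampler's image tokens appended to the text tokens and splits them apart again internally, running a second attention whose output is scaled by `ip_weight`. A standalone layout sketch (dimensions follow the in-code comment; not part of the commit):

```python
import torch

# 5 views per sample (num_frames + the extra reference view), 77 text tokens, 16 image tokens.
text_ctx = torch.randn(5, 77, 1024)                 # per-view text embeddings
ip_tokens = torch.randn(5, 16, 1024)                # Resampler output, broadcast to every view
context = torch.cat([text_ctx, ip_tokens], dim=1)   # [5, 93, 1024], as built in models.py

# Inside MemoryEfficientCrossAttention.forward (with_ip=True, ip_dim=16):
context_ip = context[:, -16:, :]    # routed through to_k_ip / to_v_ip
context_txt = context[:, :-16, :]   # routed through to_k / to_v; out = out + ip_weight * out_ip
print(context_txt.shape, context_ip.shape)  # torch.Size([5, 77, 1024]) torch.Size([5, 16, 1024])
```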
imagedream/models.py CHANGED
@@ -13,8 +13,8 @@ from .util import (
13
  zero_module,
14
  timestep_embedding,
15
  )
16
- from .attention import SpatialTransformer, SpatialTransformer3D
17
-
18
 
19
  class CondSequential(nn.Sequential):
20
  """
@@ -28,8 +28,6 @@ class CondSequential(nn.Sequential):
28
  x = layer(x, emb)
29
  elif isinstance(layer, SpatialTransformer3D):
30
  x = layer(x, context, num_frames=num_frames)
31
- elif isinstance(layer, SpatialTransformer):
32
- x = layer(x, context)
33
  else:
34
  x = layer(x)
35
  return x
@@ -274,6 +272,9 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
274
  disable_middle_self_attn=False,
275
  adm_in_channels=None,
276
  camera_dim=None,
 
 
 
277
  **kwargs,
278
  ):
279
  super().__init__()
@@ -305,9 +306,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
305
  "as a list/tuple (per-level) with the same length as channel_mult"
306
  )
307
  self.num_res_blocks = num_res_blocks
308
- if disable_self_attentions is not None:
309
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
310
- assert len(disable_self_attentions) == len(channel_mult)
311
  if num_attention_blocks is not None:
312
  assert len(num_attention_blocks) == len(self.num_res_blocks)
313
  assert all(
@@ -334,6 +333,22 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
334
  self.num_heads_upsample = num_heads_upsample
335
  self.predict_codebook_ids = n_embed is not None
336
 
337
  time_embed_dim = model_channels * 4
338
  self.time_embed = nn.Sequential(
339
  nn.Linear(model_channels, time_embed_dim),
@@ -398,11 +413,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
398
  else:
399
  num_heads = ch // num_head_channels
400
  dim_head = num_head_channels
401
-
402
- if disable_self_attentions is not None:
403
- disabled_sa = disable_self_attentions[level]
404
- else:
405
- disabled_sa = False
406
 
407
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
408
  layers.append(
@@ -410,10 +420,12 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
410
  ch,
411
  num_heads,
412
  dim_head,
413
- depth=transformer_depth,
414
  context_dim=context_dim,
415
- disable_self_attn=disabled_sa,
416
  use_checkpoint=use_checkpoint,
 
 
 
417
  )
418
  )
419
  self.input_blocks.append(CondSequential(*layers))
@@ -463,10 +475,12 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
463
  ch,
464
  num_heads,
465
  dim_head,
466
- depth=transformer_depth,
467
  context_dim=context_dim,
468
- disable_self_attn=disable_middle_self_attn,
469
  use_checkpoint=use_checkpoint,
 
 
 
470
  ),
471
  ResBlock(
472
  ch,
@@ -501,11 +515,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
501
  else:
502
  num_heads = ch // num_head_channels
503
  dim_head = num_head_channels
504
-
505
- if disable_self_attentions is not None:
506
- disabled_sa = disable_self_attentions[level]
507
- else:
508
- disabled_sa = False
509
 
510
  if num_attention_blocks is None or i < num_attention_blocks[level]:
511
  layers.append(
@@ -513,10 +522,12 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
513
  ch,
514
  num_heads,
515
  dim_head,
516
- depth=transformer_depth,
517
  context_dim=context_dim,
518
- disable_self_attn=disabled_sa,
519
  use_checkpoint=use_checkpoint,
 
 
 
520
  )
521
  )
522
  if level and i == self.num_res_blocks[level]:
@@ -559,6 +570,9 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
559
  y: Optional[Tensor] = None,
560
  camera=None,
561
  num_frames=1,
 
 
 
562
  **kwargs,
563
  ):
564
  """
@@ -592,6 +606,11 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
592
  if camera is not None:
593
  assert camera.shape[0] == emb.shape[0]
594
  emb = emb + self.camera_embed(camera)
 
 
 
 
 
595
 
596
  h = x
597
  for module in self.input_blocks:
 
13
  zero_module,
14
  timestep_embedding,
15
  )
16
+ from .attention import SpatialTransformer3D
17
+ from .adaptor import Resampler, ImageProjModel
18
 
19
  class CondSequential(nn.Sequential):
20
  """
 
28
  x = layer(x, emb)
29
  elif isinstance(layer, SpatialTransformer3D):
30
  x = layer(x, context, num_frames=num_frames)
 
 
31
  else:
32
  x = layer(x)
33
  return x
 
272
  disable_middle_self_attn=False,
273
  adm_in_channels=None,
274
  camera_dim=None,
275
+ with_ip=True,
276
+ ip_dim=16,
277
+ ip_weight=1.0,
278
  **kwargs,
279
  ):
280
  super().__init__()
 
306
  "as a list/tuple (per-level) with the same length as channel_mult"
307
  )
308
  self.num_res_blocks = num_res_blocks
309
+
 
 
310
  if num_attention_blocks is not None:
311
  assert len(num_attention_blocks) == len(self.num_res_blocks)
312
  assert all(
 
333
  self.num_heads_upsample = num_heads_upsample
334
  self.predict_codebook_ids = n_embed is not None
335
 
336
+ self.with_ip = with_ip
337
+ self.ip_dim = ip_dim
338
+ self.ip_weight = ip_weight
339
+
340
+ if self.with_ip and self.ip_dim > 0:
341
+ self.image_embed = Resampler(
342
+ dim=context_dim,
343
+ depth=4,
344
+ dim_head=64,
345
+ heads=12,
346
+ num_queries=ip_dim, # num token
347
+ embedding_dim=1280,
348
+ output_dim=context_dim,
349
+ ff_mult=4,
350
+ )
351
+
352
  time_embed_dim = model_channels * 4
353
  self.time_embed = nn.Sequential(
354
  nn.Linear(model_channels, time_embed_dim),
 
413
  else:
414
  num_heads = ch // num_head_channels
415
  dim_head = num_head_channels
 
 
 
 
 
416
 
417
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
418
  layers.append(
 
420
  ch,
421
  num_heads,
422
  dim_head,
 
423
  context_dim=context_dim,
424
+ depth=transformer_depth,
425
  use_checkpoint=use_checkpoint,
426
+ with_ip=self.with_ip,
427
+ ip_dim=self.ip_dim,
428
+ ip_weight=self.ip_weight,
429
  )
430
  )
431
  self.input_blocks.append(CondSequential(*layers))
 
475
  ch,
476
  num_heads,
477
  dim_head,
 
478
  context_dim=context_dim,
479
+ depth=transformer_depth,
480
  use_checkpoint=use_checkpoint,
481
+ with_ip=self.with_ip,
482
+ ip_dim=self.ip_dim,
483
+ ip_weight=self.ip_weight,
484
  ),
485
  ResBlock(
486
  ch,
 
515
  else:
516
  num_heads = ch // num_head_channels
517
  dim_head = num_head_channels
 
 
 
 
 
518
 
519
  if num_attention_blocks is None or i < num_attention_blocks[level]:
520
  layers.append(
 
522
  ch,
523
  num_heads,
524
  dim_head,
 
525
  context_dim=context_dim,
526
+ depth=transformer_depth,
527
  use_checkpoint=use_checkpoint,
528
+ with_ip=self.with_ip,
529
+ ip_dim=self.ip_dim,
530
+ ip_weight=self.ip_weight,
531
  )
532
  )
533
  if level and i == self.num_res_blocks[level]:
 
570
  y: Optional[Tensor] = None,
571
  camera=None,
572
  num_frames=1,
573
+ # should be provided if with_ip
574
+ ip = None,
575
+ ip_img = None,
576
  **kwargs,
577
  ):
578
  """
 
606
  if camera is not None:
607
  assert camera.shape[0] == emb.shape[0]
608
  emb = emb + self.camera_embed(camera)
609
+
610
+ if self.with_ip:
611
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
612
+ ip_emb = self.image_embed(ip)
613
+ context = torch.cat((context, ip_emb), 1)
614
 
615
  h = x
616
  for module in self.input_blocks:
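The other new piece of `forward()` is the reference-view slot: `get_camera(..., extra_view=True)` appends one extra view per group, and its latent is overwritten with the clean VAE latent of the input image before the UNet runs. A minimal indexing sketch (a reading aid, not part of the commit):

```python
import torch

num_frames = 5                              # 4 generated views + 1 reference slot per CFG branch
x = torch.randn(2 * num_frames, 4, 32, 32)  # noisy latents for the uncond and cond branches
ip_img = torch.randn(2, 4, 32, 32)          # clean latent of the input image, one per branch

x[(num_frames - 1)::num_frames] = ip_img    # slots 4 and 9 now hold the reference latent
print(x[4].equal(ip_img[0]), x[9].equal(ip_img[1]))  # True True
```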
imagedream/pipeline_imagedream.py CHANGED
@@ -2,7 +2,7 @@ import torch
2
  import inspect
3
  import numpy as np
4
  from typing import Callable, List, Optional, Union
5
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
6
  from diffusers import AutoencoderKL, DiffusionPipeline
7
  from diffusers.utils import (
8
  deprecate,
@@ -16,6 +16,8 @@ from diffusers.utils.torch_utils import randn_tensor
16
 
17
  from .models import MultiViewUNetModel
18
 
 
 
19
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
 
21
 
@@ -62,7 +64,7 @@ def convert_opengl_to_blender(camera_matrix):
62
 
63
 
64
  def get_camera(
65
- num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True
66
  ):
67
  angle_gap = azimuth_span / num_frames
68
  cameras = []
@@ -71,6 +73,9 @@ def get_camera(
71
  if blender_coord:
72
  camera_matrix = convert_opengl_to_blender(camera_matrix)
73
  cameras.append(camera_matrix.flatten())
 
 
 
74
  return torch.tensor(np.stack(cameras, 0)).float()
75
 
76
 
@@ -82,8 +87,8 @@ class ImageDreamPipeline(DiffusionPipeline):
82
  tokenizer: CLIPTokenizer,
83
  text_encoder: CLIPTextModel,
84
  scheduler: DDIMScheduler,
85
- feature_extractor: CLIPImageProcessor,
86
- image_encoder: CLIPVisionModel,
87
  requires_safety_checker: bool = False,
88
  ):
89
  super().__init__()
@@ -449,10 +454,36 @@ class ImageDreamPipeline(DiffusionPipeline):
449
  latents = latents * self.scheduler.init_noise_sigma
450
  return latents
451
 
452
  @torch.no_grad()
453
  def __call__(
454
  self,
455
- image, # input image (TODO: pil?)
456
  prompt: str = "a car",
457
  height: int = 256,
458
  width: int = 256,
@@ -465,7 +496,7 @@ class ImageDreamPipeline(DiffusionPipeline):
465
  output_type: Optional[str] = "image",
466
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
467
  callback_steps: int = 1,
468
- batch_size: int = 4,
469
  device=torch.device("cuda:0"),
470
  ):
471
  self.unet = self.unet.to(device=device)
@@ -482,7 +513,18 @@ class ImageDreamPipeline(DiffusionPipeline):
482
  self.scheduler.set_timesteps(num_inference_steps, device=device)
483
  timesteps = self.scheduler.timesteps
484
 
485
- _prompt_embeds: torch.Tensor = self._encode_prompt(
 
486
  prompt=prompt,
487
  device=device,
488
  num_images_per_prompt=num_images_per_prompt,
@@ -493,8 +535,8 @@ class ImageDreamPipeline(DiffusionPipeline):
493
 
494
  # Prepare latent variables
495
  latents: torch.Tensor = self.prepare_latents(
496
- batch_size * num_images_per_prompt,
497
- 4,
498
  height,
499
  width,
500
  prompt_embeds_pos.dtype,
@@ -503,9 +545,10 @@ class ImageDreamPipeline(DiffusionPipeline):
503
  None,
504
  )
505
 
506
- camera = get_camera(batch_size).to(dtype=latents.dtype, device=device)
 
507
 
508
- # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
509
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
510
 
511
  # Denoising loop
@@ -523,15 +566,22 @@ class ImageDreamPipeline(DiffusionPipeline):
523
  noise_pred = self.unet.forward(
524
  x=latent_model_input,
525
  timesteps=torch.tensor(
526
- [t] * 4 * multiplier,
527
  dtype=latent_model_input.dtype,
528
  device=device,
529
  ),
530
  context=torch.cat(
531
- [prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4
532
  ),
533
- num_frames=4,
534
  camera=torch.cat([camera] * multiplier),
 
 
 
 
 
 
 
535
  )
536
 
537
  # perform guidance
@@ -542,7 +592,6 @@ class ImageDreamPipeline(DiffusionPipeline):
542
  )
543
 
544
  # compute the previous noisy sample x_t -> x_t-1
545
- # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
546
  latents: torch.Tensor = self.scheduler.step(
547
  noise_pred, t, latents, **extra_step_kwargs, return_dict=False
548
  )[0]
 
2
  import inspect
3
  import numpy as np
4
  from typing import Callable, List, Optional, Union
5
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPFeatureExtractor
6
  from diffusers import AutoencoderKL, DiffusionPipeline
7
  from diffusers.utils import (
8
  deprecate,
 
16
 
17
  from .models import MultiViewUNetModel
18
 
19
+ import kiui
20
+
21
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
 
23
 
 
64
 
65
 
66
  def get_camera(
67
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
68
  ):
69
  angle_gap = azimuth_span / num_frames
70
  cameras = []
 
73
  if blender_coord:
74
  camera_matrix = convert_opengl_to_blender(camera_matrix)
75
  cameras.append(camera_matrix.flatten())
76
+ if extra_view:
77
+ dim = len(cameras[0])
78
+ cameras.append(np.zeros(dim))
79
  return torch.tensor(np.stack(cameras, 0)).float()
80
 
81
 
 
87
  tokenizer: CLIPTokenizer,
88
  text_encoder: CLIPTextModel,
89
  scheduler: DDIMScheduler,
90
+ feature_extractor: CLIPFeatureExtractor = None,
91
+ image_encoder: CLIPVisionModel = None,
92
  requires_safety_checker: bool = False,
93
  ):
94
  super().__init__()
 
454
  latents = latents * self.scheduler.init_noise_sigma
455
  return latents
456
 
457
+ def encode_image(self, image, device, num_images_per_prompt):
458
+ dtype = next(self.image_encoder.parameters()).dtype
459
+
460
+ image = (image * 255).astype(np.uint8)
461
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
462
+
463
+ image = image.to(device=device, dtype=dtype)
464
+
465
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
466
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
467
+
468
+ # ImageDream directly uses zeros as the unconditional image embedding
469
+ uncond_image_enc_hidden_states = torch.zeros_like(image_enc_hidden_states)
470
+
471
+ return uncond_image_enc_hidden_states, image_enc_hidden_states
472
+
473
+ def encode_image_latents(self, image, device, num_images_per_prompt):
474
+
475
+ image = torch.from_numpy(image).to(device)
476
+ posterior = self.vae.encode(image).latent_dist
477
+
478
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
479
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
480
+
481
+ return torch.zeros_like(latents), latents
482
+
483
  @torch.no_grad()
484
  def __call__(
485
  self,
486
+ image, # input image, np.ndarray float32!
487
  prompt: str = "a car",
488
  height: int = 256,
489
  width: int = 256,
 
496
  output_type: Optional[str] = "image",
497
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
498
  callback_steps: int = 1,
499
+ num_frames: int = 4,
500
  device=torch.device("cuda:0"),
501
  ):
502
  self.unet = self.unet.to(device=device)
 
513
  self.scheduler.set_timesteps(num_inference_steps, device=device)
514
  timesteps = self.scheduler.timesteps
515
 
516
+ # encode image
517
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
518
+
519
+ self.image_encoder = self.image_encoder.to(device=device)
520
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
521
+ kiui.lo(image_embeds_pos) # should be [1, 257, 1280]?
522
+
523
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
524
+ kiui.lo(image_latents_pos)
525
+
526
+ # encode text
527
+ _prompt_embeds = self._encode_prompt(
528
  prompt=prompt,
529
  device=device,
530
  num_images_per_prompt=num_images_per_prompt,
 
535
 
536
  # Prepare latent variables
537
  latents: torch.Tensor = self.prepare_latents(
538
+ (num_frames + 1) * num_images_per_prompt,
539
+ 4, # channel
540
  height,
541
  width,
542
  prompt_embeds_pos.dtype,
 
545
  None,
546
  )
547
 
548
+ camera = get_camera(num_frames, extra_view=True).to(dtype=latents.dtype, device=device)
549
+ camera = camera.repeat(num_images_per_prompt, 1).to(self.device)
550
 
551
+ # Prepare extra step kwargs.
552
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
553
 
554
  # Denoising loop
 
566
  noise_pred = self.unet.forward(
567
  x=latent_model_input,
568
  timesteps=torch.tensor(
569
+ [t] * (num_frames + 1) * multiplier,
570
  dtype=latent_model_input.dtype,
571
  device=device,
572
  ),
573
  context=torch.cat(
574
+ [prompt_embeds_neg] * (num_frames + 1) + [prompt_embeds_pos] * (num_frames + 1)
575
  ),
576
+ num_frames=num_frames + 1,
577
  camera=torch.cat([camera] * multiplier),
578
+ # for with_ip
579
+ ip=torch.cat(
580
+ [image_embeds_neg] * (num_frames + 1) + [image_embeds_pos] * (num_frames + 1)
581
+ ),
582
+ ip_img=torch.cat(
583
+ [image_latents_neg] * (num_frames + 1) + [image_latents_pos] * (num_frames + 1)
584
+ ),
585
  )
586
 
587
  # perform guidance
 
592
  )
593
 
594
  # compute the previous noisy sample x_t -> x_t-1
 
595
  latents: torch.Tensor = self.scheduler.step(
596
  noise_pred, t, latents, **extra_step_kwargs, return_dict=False
597
  )[0]
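For reference, the batch layout the draft denoising loop assembles: every tensor carries `num_frames + 1` slots per classifier-free-guidance branch. A standalone shape sketch (random stand-ins, not part of the commit):

```python
import torch

num_frames, multiplier = 4, 2                    # multiplier = 2 when guidance_scale > 1
prompt_embeds_neg = torch.randn(1, 77, 1024)
prompt_embeds_pos = torch.randn(1, 77, 1024)

context = torch.cat([prompt_embeds_neg] * (num_frames + 1)
                    + [prompt_embeds_pos] * (num_frames + 1))       # [10, 77, 1024]
timesteps = torch.tensor([999] * (num_frames + 1) * multiplier)     # one timestep per slot
camera = torch.cat([torch.randn(num_frames + 1, 16)] * multiplier)  # stand-in for get_camera(4, extra_view=True)
print(context.shape, timesteps.shape, camera.shape)
# torch.Size([10, 77, 1024]) torch.Size([10]) torch.Size([10, 16])
```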