ashawkey commited on
Commit
35925cc
1 Parent(s): 34f2f42

Upload folder using huggingface_hub

Browse files
feature_extractor/preprocessor_config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "crop_size": {
3
+ "height": 224,
4
+ "width": 224
5
+ },
6
+ "do_center_crop": true,
7
+ "do_convert_rgb": true,
8
+ "do_normalize": true,
9
+ "do_rescale": true,
10
+ "do_resize": true,
11
+ "feature_extractor_type": "CLIPFeatureExtractor",
12
+ "image_mean": [
13
+ 0.48145466,
14
+ 0.4578275,
15
+ 0.40821073
16
+ ],
17
+ "image_processor_type": "CLIPImageProcessor",
18
+ "image_std": [
19
+ 0.26862954,
20
+ 0.26130258,
21
+ 0.27577711
22
+ ],
23
+ "resample": 3,
24
+ "rescale_factor": 0.00392156862745098,
25
+ "size": {
26
+ "shortest_edge": 224
27
+ },
28
+ "use_square_size": false
29
+ }
image_encoder/config.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
3
+ "architectures": [
4
+ "CLIPVisionModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "dropout": 0.0,
8
+ "hidden_act": "gelu",
9
+ "hidden_size": 1280,
10
+ "image_size": 224,
11
+ "initializer_factor": 1.0,
12
+ "initializer_range": 0.02,
13
+ "intermediate_size": 5120,
14
+ "layer_norm_eps": 1e-05,
15
+ "model_type": "clip_vision_model",
16
+ "num_attention_heads": 16,
17
+ "num_channels": 3,
18
+ "num_hidden_layers": 32,
19
+ "patch_size": 14,
20
+ "projection_dim": 1024,
21
+ "torch_dtype": "float16",
22
+ "transformers_version": "4.35.2"
23
+ }
image_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2a56cfd4ffcf40be097c430324ec184cc37187f6dafef128ef9225438a3c03c4
3
+ size 1261595704
model_index.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MVDreamPipeline",
3
+ "_diffusers_version": "0.25.0",
4
+ "feature_extractor": [
5
+ "transformers",
6
+ "CLIPImageProcessor"
7
+ ],
8
+ "image_encoder": [
9
+ "transformers",
10
+ "CLIPVisionModel"
11
+ ],
12
+ "requires_safety_checker": false,
13
+ "scheduler": [
14
+ "diffusers",
15
+ "DDIMScheduler"
16
+ ],
17
+ "text_encoder": [
18
+ "transformers",
19
+ "CLIPTextModel"
20
+ ],
21
+ "tokenizer": [
22
+ "transformers",
23
+ "CLIPTokenizer"
24
+ ],
25
+ "unet": [
26
+ "mv_unet",
27
+ "MultiViewUNetModel"
28
+ ],
29
+ "vae": [
30
+ "diffusers",
31
+ "AutoencoderKL"
32
+ ]
33
+ }
pipeline_mvdream.py ADDED
@@ -0,0 +1,558 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import inspect
4
+ import numpy as np
5
+ from typing import Callable, List, Optional, Union
6
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
7
+ from diffusers import AutoencoderKL, DiffusionPipeline
8
+ from diffusers.utils import (
9
+ deprecate,
10
+ is_accelerate_available,
11
+ is_accelerate_version,
12
+ logging,
13
+ )
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.schedulers import DDIMScheduler
16
+ from diffusers.utils.torch_utils import randn_tensor
17
+
18
+ from mv_unet import MultiViewUNetModel, get_camera
19
+
20
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
21
+
22
+
23
+ class MVDreamPipeline(DiffusionPipeline):
24
+
25
+ _optional_components = ["feature_extractor", "image_encoder"]
26
+
27
+ def __init__(
28
+ self,
29
+ vae: AutoencoderKL,
30
+ unet: MultiViewUNetModel,
31
+ tokenizer: CLIPTokenizer,
32
+ text_encoder: CLIPTextModel,
33
+ scheduler: DDIMScheduler,
34
+ # imagedream variant
35
+ feature_extractor: CLIPImageProcessor,
36
+ image_encoder: CLIPVisionModel,
37
+ requires_safety_checker: bool = False,
38
+ ):
39
+ super().__init__()
40
+
41
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
42
+ deprecation_message = (
43
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
44
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
45
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
46
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
47
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
48
+ " file"
49
+ )
50
+ deprecate(
51
+ "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
52
+ )
53
+ new_config = dict(scheduler.config)
54
+ new_config["steps_offset"] = 1
55
+ scheduler._internal_dict = FrozenDict(new_config)
56
+
57
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
58
+ deprecation_message = (
59
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
60
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
61
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
62
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
63
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
64
+ )
65
+ deprecate(
66
+ "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
67
+ )
68
+ new_config = dict(scheduler.config)
69
+ new_config["clip_sample"] = False
70
+ scheduler._internal_dict = FrozenDict(new_config)
71
+
72
+ self.register_modules(
73
+ vae=vae,
74
+ unet=unet,
75
+ scheduler=scheduler,
76
+ tokenizer=tokenizer,
77
+ text_encoder=text_encoder,
78
+ feature_extractor=feature_extractor,
79
+ image_encoder=image_encoder,
80
+ )
81
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
82
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
83
+
84
+ def enable_vae_slicing(self):
85
+ r"""
86
+ Enable sliced VAE decoding.
87
+
88
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
89
+ steps. This is useful to save some memory and allow larger batch sizes.
90
+ """
91
+ self.vae.enable_slicing()
92
+
93
+ def disable_vae_slicing(self):
94
+ r"""
95
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
96
+ computing decoding in one step.
97
+ """
98
+ self.vae.disable_slicing()
99
+
100
+ def enable_vae_tiling(self):
101
+ r"""
102
+ Enable tiled VAE decoding.
103
+
104
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
105
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
106
+ """
107
+ self.vae.enable_tiling()
108
+
109
+ def disable_vae_tiling(self):
110
+ r"""
111
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
112
+ computing decoding in one step.
113
+ """
114
+ self.vae.disable_tiling()
115
+
116
+ def enable_sequential_cpu_offload(self, gpu_id=0):
117
+ r"""
118
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
119
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
120
+ `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
121
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
122
+ `enable_model_cpu_offload`, but performance is lower.
123
+ """
124
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
125
+ from accelerate import cpu_offload
126
+ else:
127
+ raise ImportError(
128
+ "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
129
+ )
130
+
131
+ device = torch.device(f"cuda:{gpu_id}")
132
+
133
+ if self.device.type != "cpu":
134
+ self.to("cpu", silence_dtype_warnings=True)
135
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
136
+
137
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
138
+ cpu_offload(cpu_offloaded_model, device)
139
+
140
+ def enable_model_cpu_offload(self, gpu_id=0):
141
+ r"""
142
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
143
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
144
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
145
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
146
+ """
147
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
148
+ from accelerate import cpu_offload_with_hook
149
+ else:
150
+ raise ImportError(
151
+ "`enable_model_offload` requires `accelerate v0.17.0` or higher."
152
+ )
153
+
154
+ device = torch.device(f"cuda:{gpu_id}")
155
+
156
+ if self.device.type != "cpu":
157
+ self.to("cpu", silence_dtype_warnings=True)
158
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
159
+
160
+ hook = None
161
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
162
+ _, hook = cpu_offload_with_hook(
163
+ cpu_offloaded_model, device, prev_module_hook=hook
164
+ )
165
+
166
+ # We'll offload the last model manually.
167
+ self.final_offload_hook = hook
168
+
169
+ @property
170
+ def _execution_device(self):
171
+ r"""
172
+ Returns the device on which the pipeline's models will be executed. After calling
173
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
174
+ hooks.
175
+ """
176
+ if not hasattr(self.unet, "_hf_hook"):
177
+ return self.device
178
+ for module in self.unet.modules():
179
+ if (
180
+ hasattr(module, "_hf_hook")
181
+ and hasattr(module._hf_hook, "execution_device")
182
+ and module._hf_hook.execution_device is not None
183
+ ):
184
+ return torch.device(module._hf_hook.execution_device)
185
+ return self.device
186
+
187
+ def _encode_prompt(
188
+ self,
189
+ prompt,
190
+ device,
191
+ num_images_per_prompt,
192
+ do_classifier_free_guidance: bool,
193
+ negative_prompt=None,
194
+ ):
195
+ r"""
196
+ Encodes the prompt into text encoder hidden states.
197
+
198
+ Args:
199
+ prompt (`str` or `List[str]`, *optional*):
200
+ prompt to be encoded
201
+ device: (`torch.device`):
202
+ torch device
203
+ num_images_per_prompt (`int`):
204
+ number of images that should be generated per prompt
205
+ do_classifier_free_guidance (`bool`):
206
+ whether to use classifier free guidance or not
207
+ negative_prompt (`str` or `List[str]`, *optional*):
208
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
209
+ `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
210
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
211
+ prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
213
+ provided, text embeddings will be generated from `prompt` input argument.
214
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
215
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
216
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
217
+ argument.
218
+ """
219
+ if prompt is not None and isinstance(prompt, str):
220
+ batch_size = 1
221
+ elif prompt is not None and isinstance(prompt, list):
222
+ batch_size = len(prompt)
223
+ else:
224
+ raise ValueError(
225
+ f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
226
+ )
227
+
228
+ text_inputs = self.tokenizer(
229
+ prompt,
230
+ padding="max_length",
231
+ max_length=self.tokenizer.model_max_length,
232
+ truncation=True,
233
+ return_tensors="pt",
234
+ )
235
+ text_input_ids = text_inputs.input_ids
236
+ untruncated_ids = self.tokenizer(
237
+ prompt, padding="longest", return_tensors="pt"
238
+ ).input_ids
239
+
240
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
241
+ text_input_ids, untruncated_ids
242
+ ):
243
+ removed_text = self.tokenizer.batch_decode(
244
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
245
+ )
246
+ logger.warning(
247
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
248
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
249
+ )
250
+
251
+ if (
252
+ hasattr(self.text_encoder.config, "use_attention_mask")
253
+ and self.text_encoder.config.use_attention_mask
254
+ ):
255
+ attention_mask = text_inputs.attention_mask.to(device)
256
+ else:
257
+ attention_mask = None
258
+
259
+ prompt_embeds = self.text_encoder(
260
+ text_input_ids.to(device),
261
+ attention_mask=attention_mask,
262
+ )
263
+ prompt_embeds = prompt_embeds[0]
264
+
265
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
266
+
267
+ bs_embed, seq_len, _ = prompt_embeds.shape
268
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
269
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
270
+ prompt_embeds = prompt_embeds.view(
271
+ bs_embed * num_images_per_prompt, seq_len, -1
272
+ )
273
+
274
+ # get unconditional embeddings for classifier free guidance
275
+ if do_classifier_free_guidance:
276
+ uncond_tokens: List[str]
277
+ if negative_prompt is None:
278
+ uncond_tokens = [""] * batch_size
279
+ elif type(prompt) is not type(negative_prompt):
280
+ raise TypeError(
281
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
282
+ f" {type(prompt)}."
283
+ )
284
+ elif isinstance(negative_prompt, str):
285
+ uncond_tokens = [negative_prompt]
286
+ elif batch_size != len(negative_prompt):
287
+ raise ValueError(
288
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
289
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
290
+ " the batch size of `prompt`."
291
+ )
292
+ else:
293
+ uncond_tokens = negative_prompt
294
+
295
+ max_length = prompt_embeds.shape[1]
296
+ uncond_input = self.tokenizer(
297
+ uncond_tokens,
298
+ padding="max_length",
299
+ max_length=max_length,
300
+ truncation=True,
301
+ return_tensors="pt",
302
+ )
303
+
304
+ if (
305
+ hasattr(self.text_encoder.config, "use_attention_mask")
306
+ and self.text_encoder.config.use_attention_mask
307
+ ):
308
+ attention_mask = uncond_input.attention_mask.to(device)
309
+ else:
310
+ attention_mask = None
311
+
312
+ negative_prompt_embeds = self.text_encoder(
313
+ uncond_input.input_ids.to(device),
314
+ attention_mask=attention_mask,
315
+ )
316
+ negative_prompt_embeds = negative_prompt_embeds[0]
317
+
318
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
319
+ seq_len = negative_prompt_embeds.shape[1]
320
+
321
+ negative_prompt_embeds = negative_prompt_embeds.to(
322
+ dtype=self.text_encoder.dtype, device=device
323
+ )
324
+
325
+ negative_prompt_embeds = negative_prompt_embeds.repeat(
326
+ 1, num_images_per_prompt, 1
327
+ )
328
+ negative_prompt_embeds = negative_prompt_embeds.view(
329
+ batch_size * num_images_per_prompt, seq_len, -1
330
+ )
331
+
332
+ # For classifier free guidance, we need to do two forward passes.
333
+ # Here we concatenate the unconditional and text embeddings into a single batch
334
+ # to avoid doing two forward passes
335
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
336
+
337
+ return prompt_embeds
338
+
339
+ def decode_latents(self, latents):
340
+ latents = 1 / self.vae.config.scaling_factor * latents
341
+ image = self.vae.decode(latents).sample
342
+ image = (image / 2 + 0.5).clamp(0, 1)
343
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
344
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
345
+ return image
346
+
347
+ def prepare_extra_step_kwargs(self, generator, eta):
348
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
349
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
350
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
351
+ # and should be between [0, 1]
352
+
353
+ accepts_eta = "eta" in set(
354
+ inspect.signature(self.scheduler.step).parameters.keys()
355
+ )
356
+ extra_step_kwargs = {}
357
+ if accepts_eta:
358
+ extra_step_kwargs["eta"] = eta
359
+
360
+ # check if the scheduler accepts generator
361
+ accepts_generator = "generator" in set(
362
+ inspect.signature(self.scheduler.step).parameters.keys()
363
+ )
364
+ if accepts_generator:
365
+ extra_step_kwargs["generator"] = generator
366
+ return extra_step_kwargs
367
+
368
+ def prepare_latents(
369
+ self,
370
+ batch_size,
371
+ num_channels_latents,
372
+ height,
373
+ width,
374
+ dtype,
375
+ device,
376
+ generator,
377
+ latents=None,
378
+ ):
379
+ shape = (
380
+ batch_size,
381
+ num_channels_latents,
382
+ height // self.vae_scale_factor,
383
+ width // self.vae_scale_factor,
384
+ )
385
+ if isinstance(generator, list) and len(generator) != batch_size:
386
+ raise ValueError(
387
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
388
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
389
+ )
390
+
391
+ if latents is None:
392
+ latents = randn_tensor(
393
+ shape, generator=generator, device=device, dtype=dtype
394
+ )
395
+ else:
396
+ latents = latents.to(device)
397
+
398
+ # scale the initial noise by the standard deviation required by the scheduler
399
+ latents = latents * self.scheduler.init_noise_sigma
400
+ return latents
401
+
402
+ def encode_image(self, image, device, num_images_per_prompt):
403
+ dtype = next(self.image_encoder.parameters()).dtype
404
+
405
+ if image.dtype == np.float32:
406
+ image = (image * 255).astype(np.uint8)
407
+
408
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
409
+ image = image.to(device=device, dtype=dtype)
410
+
411
+ image_embeds = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
412
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
413
+
414
+ return torch.zeros_like(image_embeds), image_embeds
415
+
416
+ def encode_image_latents(self, image, device, num_images_per_prompt):
417
+
418
+ dtype = next(self.image_encoder.parameters()).dtype
419
+
420
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2).to(device=device) # [1, 3, H, W]
421
+ image = 2 * image - 1
422
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
423
+ image = image.to(dtype=dtype)
424
+
425
+ posterior = self.vae.encode(image).latent_dist
426
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
427
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
428
+
429
+ return torch.zeros_like(latents), latents
430
+
431
+ @torch.no_grad()
432
+ def __call__(
433
+ self,
434
+ prompt: str = "",
435
+ image: Optional[np.ndarray] = None,
436
+ height: int = 256,
437
+ width: int = 256,
438
+ num_inference_steps: int = 50,
439
+ guidance_scale: float = 7.0,
440
+ negative_prompt: str = "",
441
+ num_images_per_prompt: int = 1,
442
+ eta: float = 0.0,
443
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
444
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
445
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
446
+ callback_steps: int = 1,
447
+ num_frames: int = 4,
448
+ device=torch.device("cuda:0"),
449
+ ):
450
+ self.unet = self.unet.to(device=device)
451
+ self.vae = self.vae.to(device=device)
452
+ self.text_encoder = self.text_encoder.to(device=device)
453
+
454
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
455
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
456
+ # corresponds to doing no classifier free guidance.
457
+ do_classifier_free_guidance = guidance_scale > 1.0
458
+
459
+ # Prepare timesteps
460
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
461
+ timesteps = self.scheduler.timesteps
462
+
463
+ # imagedream variant
464
+ if image is not None:
465
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
466
+ self.image_encoder = self.image_encoder.to(device=device)
467
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
468
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
469
+
470
+ _prompt_embeds = self._encode_prompt(
471
+ prompt=prompt,
472
+ device=device,
473
+ num_images_per_prompt=num_images_per_prompt,
474
+ do_classifier_free_guidance=do_classifier_free_guidance,
475
+ negative_prompt=negative_prompt,
476
+ ) # type: ignore
477
+ prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
478
+
479
+ # Prepare latent variables
480
+ actual_num_frames = num_frames if image is None else num_frames + 1
481
+ latents: torch.Tensor = self.prepare_latents(
482
+ actual_num_frames * num_images_per_prompt,
483
+ 4,
484
+ height,
485
+ width,
486
+ prompt_embeds_pos.dtype,
487
+ device,
488
+ generator,
489
+ None,
490
+ )
491
+
492
+ if image is not None:
493
+ camera = get_camera(num_frames, elevation=5, extra_view=True).to(dtype=latents.dtype, device=device)
494
+ else:
495
+ camera = get_camera(num_frames, elevation=15, extra_view=False).to(dtype=latents.dtype, device=device)
496
+ camera = camera.repeat_interleave(num_images_per_prompt, dim=0)
497
+
498
+ # Prepare extra step kwargs.
499
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
500
+
501
+ # Denoising loop
502
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
503
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
504
+ for i, t in enumerate(timesteps):
505
+ # expand the latents if we are doing classifier free guidance
506
+ multiplier = 2 if do_classifier_free_guidance else 1
507
+ latent_model_input = torch.cat([latents] * multiplier)
508
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
509
+
510
+ unet_inputs = {
511
+ 'x': latent_model_input,
512
+ 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
513
+ 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
514
+ 'num_frames': actual_num_frames,
515
+ 'camera': torch.cat([camera] * multiplier),
516
+ }
517
+
518
+ if image is not None:
519
+ unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
520
+ unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
521
+
522
+ # predict the noise residual
523
+ noise_pred = self.unet.forward(**unet_inputs)
524
+
525
+ # perform guidance
526
+ if do_classifier_free_guidance:
527
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
528
+ noise_pred = noise_pred_uncond + guidance_scale * (
529
+ noise_pred_text - noise_pred_uncond
530
+ )
531
+
532
+ # compute the previous noisy sample x_t -> x_t-1
533
+ latents: torch.Tensor = self.scheduler.step(
534
+ noise_pred, t, latents, **extra_step_kwargs, return_dict=False
535
+ )[0]
536
+
537
+ # call the callback, if provided
538
+ if i == len(timesteps) - 1 or (
539
+ (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
540
+ ):
541
+ progress_bar.update()
542
+ if callback is not None and i % callback_steps == 0:
543
+ callback(i, t, latents) # type: ignore
544
+
545
+ # Post-processing
546
+ if output_type == "latent":
547
+ image = latents
548
+ elif output_type == "pil":
549
+ image = self.decode_latents(latents)
550
+ image = self.numpy_to_pil(image)
551
+ else: # numpy
552
+ image = self.decode_latents(latents)
553
+
554
+ # Offload last model to CPU
555
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
556
+ self.final_offload_hook.offload()
557
+
558
+ return image
scheduler/scheduler_config.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "DDIMScheduler",
3
+ "_diffusers_version": "0.25.0",
4
+ "beta_end": 0.012,
5
+ "beta_schedule": "scaled_linear",
6
+ "beta_start": 0.00085,
7
+ "clip_sample": false,
8
+ "clip_sample_range": 1.0,
9
+ "dynamic_thresholding_ratio": 0.995,
10
+ "num_train_timesteps": 1000,
11
+ "prediction_type": "epsilon",
12
+ "rescale_betas_zero_snr": false,
13
+ "sample_max_value": 1.0,
14
+ "set_alpha_to_one": false,
15
+ "steps_offset": 1,
16
+ "thresholding": false,
17
+ "timestep_spacing": "leading",
18
+ "trained_betas": null
19
+ }
text_encoder/config.json ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "stabilityai/stable-diffusion-2-1",
3
+ "architectures": [
4
+ "CLIPTextModel"
5
+ ],
6
+ "attention_dropout": 0.0,
7
+ "bos_token_id": 0,
8
+ "dropout": 0.0,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_size": 1024,
12
+ "initializer_factor": 1.0,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 77,
17
+ "model_type": "clip_text_model",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 23,
20
+ "pad_token_id": 1,
21
+ "projection_dim": 512,
22
+ "torch_dtype": "float16",
23
+ "transformers_version": "4.35.2",
24
+ "vocab_size": 49408
25
+ }
text_encoder/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bc1827c465450322616f06dea41596eac7d493f4e95904dcb51f0fc745c4e13f
3
+ size 680820392
tokenizer/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|startoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": "!",
17
+ "unk_token": {
18
+ "content": "<|endoftext|>",
19
+ "lstrip": false,
20
+ "normalized": true,
21
+ "rstrip": false,
22
+ "single_word": false
23
+ }
24
+ }
tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_prefix_space": false,
3
+ "added_tokens_decoder": {
4
+ "0": {
5
+ "content": "!",
6
+ "lstrip": false,
7
+ "normalized": false,
8
+ "rstrip": false,
9
+ "single_word": false,
10
+ "special": true
11
+ },
12
+ "49406": {
13
+ "content": "<|startoftext|>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false,
18
+ "special": true
19
+ },
20
+ "49407": {
21
+ "content": "<|endoftext|>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false,
26
+ "special": true
27
+ }
28
+ },
29
+ "bos_token": "<|startoftext|>",
30
+ "clean_up_tokenization_spaces": true,
31
+ "do_lower_case": true,
32
+ "eos_token": "<|endoftext|>",
33
+ "errors": "replace",
34
+ "model_max_length": 77,
35
+ "pad_token": "!",
36
+ "tokenizer_class": "CLIPTokenizer",
37
+ "unk_token": "<|endoftext|>"
38
+ }
tokenizer/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
unet/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "MultiViewUNetModel",
3
+ "_diffusers_version": "0.25.0",
4
+ "attention_resolutions": [
5
+ 4,
6
+ 2,
7
+ 1
8
+ ],
9
+ "camera_dim": 16,
10
+ "channel_mult": [
11
+ 1,
12
+ 2,
13
+ 4,
14
+ 4
15
+ ],
16
+ "context_dim": 1024,
17
+ "image_size": 32,
18
+ "in_channels": 4,
19
+ "ip_dim": 16,
20
+ "model_channels": 320,
21
+ "num_head_channels": 64,
22
+ "num_res_blocks": 2,
23
+ "out_channels": 4,
24
+ "transformer_depth": 1,
25
+ "use_checkpoint": false
26
+ }
unet/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:28d8b241a54125fa0a041c1818a5dcdb717e6f5270eea1268172acd3ab0238e0
3
+ size 1883435904
unet/mv_unet.py ADDED
@@ -0,0 +1,1089 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ from inspect import isfunction
4
+ from typing import Optional, Any, List
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange, repeat
10
+
11
+ from diffusers.configuration_utils import ConfigMixin
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+
14
+ # require xformers!
15
+ import xformers
16
+ import xformers.ops
17
+
18
+ from kiui.cam import orbit_camera
19
+
20
+ def get_camera(
21
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
22
+ ):
23
+ angle_gap = azimuth_span / num_frames
24
+ cameras = []
25
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
26
+
27
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
28
+
29
+ # opengl to blender
30
+ if blender_coord:
31
+ pose[2] *= -1
32
+ pose[[1, 2]] = pose[[2, 1]]
33
+
34
+ cameras.append(pose.flatten())
35
+
36
+ if extra_view:
37
+ cameras.append(np.zeros_like(cameras[0]))
38
+
39
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
40
+
41
+
42
+ def checkpoint(func, inputs, params, flag):
43
+ """
44
+ Evaluate a function without caching intermediate activations, allowing for
45
+ reduced memory at the expense of extra compute in the backward pass.
46
+ :param func: the function to evaluate.
47
+ :param inputs: the argument sequence to pass to `func`.
48
+ :param params: a sequence of parameters `func` depends on but does not
49
+ explicitly take as arguments.
50
+ :param flag: if False, disable gradient checkpointing.
51
+ """
52
+ if flag:
53
+ args = tuple(inputs) + tuple(params)
54
+ return CheckpointFunction.apply(func, len(inputs), *args)
55
+ else:
56
+ return func(*inputs)
57
+
58
+
59
+ class CheckpointFunction(torch.autograd.Function):
60
+ @staticmethod
61
+ def forward(ctx, run_function, length, *args):
62
+ ctx.run_function = run_function
63
+ ctx.input_tensors = list(args[:length])
64
+ ctx.input_params = list(args[length:])
65
+
66
+ with torch.no_grad():
67
+ output_tensors = ctx.run_function(*ctx.input_tensors)
68
+ return output_tensors
69
+
70
+ @staticmethod
71
+ def backward(ctx, *output_grads):
72
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
73
+ with torch.enable_grad():
74
+ # Fixes a bug where the first op in run_function modifies the
75
+ # Tensor storage in place, which is not allowed for detach()'d
76
+ # Tensors.
77
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
78
+ output_tensors = ctx.run_function(*shallow_copies)
79
+ input_grads = torch.autograd.grad(
80
+ output_tensors,
81
+ ctx.input_tensors + ctx.input_params,
82
+ output_grads,
83
+ allow_unused=True,
84
+ )
85
+ del ctx.input_tensors
86
+ del ctx.input_params
87
+ del output_tensors
88
+ return (None, None) + input_grads
89
+
90
+
91
+ def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
92
+ """
93
+ Create sinusoidal timestep embeddings.
94
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
95
+ These may be fractional.
96
+ :param dim: the dimension of the output.
97
+ :param max_period: controls the minimum frequency of the embeddings.
98
+ :return: an [N x dim] Tensor of positional embeddings.
99
+ """
100
+ if not repeat_only:
101
+ half = dim // 2
102
+ freqs = torch.exp(
103
+ -math.log(max_period)
104
+ * torch.arange(start=0, end=half, dtype=torch.float32)
105
+ / half
106
+ ).to(device=timesteps.device)
107
+ args = timesteps[:, None] * freqs[None]
108
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
109
+ if dim % 2:
110
+ embedding = torch.cat(
111
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
112
+ )
113
+ else:
114
+ embedding = repeat(timesteps, "b -> b d", d=dim)
115
+ # import pdb; pdb.set_trace()
116
+ return embedding
117
+
118
+
119
+ def zero_module(module):
120
+ """
121
+ Zero out the parameters of a module and return it.
122
+ """
123
+ for p in module.parameters():
124
+ p.detach().zero_()
125
+ return module
126
+
127
+
128
+ def conv_nd(dims, *args, **kwargs):
129
+ """
130
+ Create a 1D, 2D, or 3D convolution module.
131
+ """
132
+ if dims == 1:
133
+ return nn.Conv1d(*args, **kwargs)
134
+ elif dims == 2:
135
+ return nn.Conv2d(*args, **kwargs)
136
+ elif dims == 3:
137
+ return nn.Conv3d(*args, **kwargs)
138
+ raise ValueError(f"unsupported dimensions: {dims}")
139
+
140
+
141
+ def avg_pool_nd(dims, *args, **kwargs):
142
+ """
143
+ Create a 1D, 2D, or 3D average pooling module.
144
+ """
145
+ if dims == 1:
146
+ return nn.AvgPool1d(*args, **kwargs)
147
+ elif dims == 2:
148
+ return nn.AvgPool2d(*args, **kwargs)
149
+ elif dims == 3:
150
+ return nn.AvgPool3d(*args, **kwargs)
151
+ raise ValueError(f"unsupported dimensions: {dims}")
152
+
153
+
154
+ def default(val, d):
155
+ if val is not None:
156
+ return val
157
+ return d() if isfunction(d) else d
158
+
159
+
160
+ class GEGLU(nn.Module):
161
+ def __init__(self, dim_in, dim_out):
162
+ super().__init__()
163
+ self.proj = nn.Linear(dim_in, dim_out * 2)
164
+
165
+ def forward(self, x):
166
+ x, gate = self.proj(x).chunk(2, dim=-1)
167
+ return x * F.gelu(gate)
168
+
169
+
170
+ class FeedForward(nn.Module):
171
+ def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
172
+ super().__init__()
173
+ inner_dim = int(dim * mult)
174
+ dim_out = default(dim_out, dim)
175
+ project_in = (
176
+ nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
177
+ if not glu
178
+ else GEGLU(dim, inner_dim)
179
+ )
180
+
181
+ self.net = nn.Sequential(
182
+ project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
183
+ )
184
+
185
+ def forward(self, x):
186
+ return self.net(x)
187
+
188
+
189
+ class MemoryEfficientCrossAttention(nn.Module):
190
+ # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
191
+ def __init__(
192
+ self,
193
+ query_dim,
194
+ context_dim=None,
195
+ heads=8,
196
+ dim_head=64,
197
+ dropout=0.0,
198
+ ip_dim=0,
199
+ ip_weight=1,
200
+ ):
201
+ super().__init__()
202
+
203
+ inner_dim = dim_head * heads
204
+ context_dim = default(context_dim, query_dim)
205
+
206
+ self.heads = heads
207
+ self.dim_head = dim_head
208
+
209
+ self.ip_dim = ip_dim
210
+ self.ip_weight = ip_weight
211
+
212
+ if self.ip_dim > 0:
213
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
214
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
215
+
216
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
217
+ self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
218
+ self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
219
+
220
+ self.to_out = nn.Sequential(
221
+ nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
222
+ )
223
+ self.attention_op: Optional[Any] = None
224
+
225
+ def forward(self, x, context=None):
226
+ q = self.to_q(x)
227
+ context = default(context, x)
228
+
229
+ if self.ip_dim > 0:
230
+ # context: [B, 77 + 16(ip), 1024]
231
+ token_len = context.shape[1]
232
+ context_ip = context[:, -self.ip_dim :, :]
233
+ k_ip = self.to_k_ip(context_ip)
234
+ v_ip = self.to_v_ip(context_ip)
235
+ context = context[:, : (token_len - self.ip_dim), :]
236
+
237
+ k = self.to_k(context)
238
+ v = self.to_v(context)
239
+
240
+ b, _, _ = q.shape
241
+ q, k, v = map(
242
+ lambda t: t.unsqueeze(3)
243
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
244
+ .permute(0, 2, 1, 3)
245
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
246
+ .contiguous(),
247
+ (q, k, v),
248
+ )
249
+
250
+ # actually compute the attention, what we cannot get enough of
251
+ out = xformers.ops.memory_efficient_attention(
252
+ q, k, v, attn_bias=None, op=self.attention_op
253
+ )
254
+
255
+ if self.ip_dim > 0:
256
+ k_ip, v_ip = map(
257
+ lambda t: t.unsqueeze(3)
258
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
259
+ .permute(0, 2, 1, 3)
260
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
261
+ .contiguous(),
262
+ (k_ip, v_ip),
263
+ )
264
+ # actually compute the attention, what we cannot get enough of
265
+ out_ip = xformers.ops.memory_efficient_attention(
266
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
267
+ )
268
+ out = out + self.ip_weight * out_ip
269
+
270
+ out = (
271
+ out.unsqueeze(0)
272
+ .reshape(b, self.heads, out.shape[1], self.dim_head)
273
+ .permute(0, 2, 1, 3)
274
+ .reshape(b, out.shape[1], self.heads * self.dim_head)
275
+ )
276
+ return self.to_out(out)
277
+
278
+
279
+ class BasicTransformerBlock3D(nn.Module):
280
+
281
+ def __init__(
282
+ self,
283
+ dim,
284
+ n_heads,
285
+ d_head,
286
+ context_dim,
287
+ dropout=0.0,
288
+ gated_ff=True,
289
+ checkpoint=True,
290
+ ip_dim=0,
291
+ ip_weight=1,
292
+ ):
293
+ super().__init__()
294
+
295
+ self.attn1 = MemoryEfficientCrossAttention(
296
+ query_dim=dim,
297
+ context_dim=None, # self-attention
298
+ heads=n_heads,
299
+ dim_head=d_head,
300
+ dropout=dropout,
301
+ )
302
+ self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
303
+ self.attn2 = MemoryEfficientCrossAttention(
304
+ query_dim=dim,
305
+ context_dim=context_dim,
306
+ heads=n_heads,
307
+ dim_head=d_head,
308
+ dropout=dropout,
309
+ # ip only applies to cross-attention
310
+ ip_dim=ip_dim,
311
+ ip_weight=ip_weight,
312
+ )
313
+ self.norm1 = nn.LayerNorm(dim)
314
+ self.norm2 = nn.LayerNorm(dim)
315
+ self.norm3 = nn.LayerNorm(dim)
316
+ self.checkpoint = checkpoint
317
+
318
+ def forward(self, x, context=None, num_frames=1):
319
+ return checkpoint(
320
+ self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
321
+ )
322
+
323
+ def _forward(self, x, context=None, num_frames=1):
324
+ x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
325
+ x = self.attn1(self.norm1(x), context=None) + x
326
+ x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
327
+ x = self.attn2(self.norm2(x), context=context) + x
328
+ x = self.ff(self.norm3(x)) + x
329
+ return x
330
+
331
+
332
+ class SpatialTransformer3D(nn.Module):
333
+
334
+ def __init__(
335
+ self,
336
+ in_channels,
337
+ n_heads,
338
+ d_head,
339
+ context_dim, # cross attention input dim
340
+ depth=1,
341
+ dropout=0.0,
342
+ ip_dim=0,
343
+ ip_weight=1,
344
+ use_checkpoint=True,
345
+ ):
346
+ super().__init__()
347
+
348
+ if not isinstance(context_dim, list):
349
+ context_dim = [context_dim]
350
+
351
+ self.in_channels = in_channels
352
+
353
+ inner_dim = n_heads * d_head
354
+ self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
355
+ self.proj_in = nn.Linear(in_channels, inner_dim)
356
+
357
+ self.transformer_blocks = nn.ModuleList(
358
+ [
359
+ BasicTransformerBlock3D(
360
+ inner_dim,
361
+ n_heads,
362
+ d_head,
363
+ context_dim=context_dim[d],
364
+ dropout=dropout,
365
+ checkpoint=use_checkpoint,
366
+ ip_dim=ip_dim,
367
+ ip_weight=ip_weight,
368
+ )
369
+ for d in range(depth)
370
+ ]
371
+ )
372
+
373
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
374
+
375
+
376
+ def forward(self, x, context=None, num_frames=1):
377
+ # note: if no context is given, cross-attention defaults to self-attention
378
+ if not isinstance(context, list):
379
+ context = [context]
380
+ b, c, h, w = x.shape
381
+ x_in = x
382
+ x = self.norm(x)
383
+ x = rearrange(x, "b c h w -> b (h w) c").contiguous()
384
+ x = self.proj_in(x)
385
+ for i, block in enumerate(self.transformer_blocks):
386
+ x = block(x, context=context[i], num_frames=num_frames)
387
+ x = self.proj_out(x)
388
+ x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
389
+
390
+ return x + x_in
391
+
392
+
393
+ class PerceiverAttention(nn.Module):
394
+ def __init__(self, *, dim, dim_head=64, heads=8):
395
+ super().__init__()
396
+ self.scale = dim_head ** -0.5
397
+ self.dim_head = dim_head
398
+ self.heads = heads
399
+ inner_dim = dim_head * heads
400
+
401
+ self.norm1 = nn.LayerNorm(dim)
402
+ self.norm2 = nn.LayerNorm(dim)
403
+
404
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
405
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
406
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
407
+
408
+ def forward(self, x, latents):
409
+ """
410
+ Args:
411
+ x (torch.Tensor): image features
412
+ shape (b, n1, D)
413
+ latent (torch.Tensor): latent features
414
+ shape (b, n2, D)
415
+ """
416
+ x = self.norm1(x)
417
+ latents = self.norm2(latents)
418
+
419
+ b, l, _ = latents.shape
420
+
421
+ q = self.to_q(latents)
422
+ kv_input = torch.cat((x, latents), dim=-2)
423
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
424
+
425
+ q, k, v = map(
426
+ lambda t: t.reshape(b, t.shape[1], self.heads, -1)
427
+ .transpose(1, 2)
428
+ .reshape(b, self.heads, t.shape[1], -1)
429
+ .contiguous(),
430
+ (q, k, v),
431
+ )
432
+
433
+ # attention
434
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
435
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
436
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
437
+ out = weight @ v
438
+
439
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
440
+
441
+ return self.to_out(out)
442
+
443
+
444
+ class Resampler(nn.Module):
445
+ def __init__(
446
+ self,
447
+ dim=1024,
448
+ depth=8,
449
+ dim_head=64,
450
+ heads=16,
451
+ num_queries=8,
452
+ embedding_dim=768,
453
+ output_dim=1024,
454
+ ff_mult=4,
455
+ ):
456
+ super().__init__()
457
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim ** 0.5)
458
+ self.proj_in = nn.Linear(embedding_dim, dim)
459
+ self.proj_out = nn.Linear(dim, output_dim)
460
+ self.norm_out = nn.LayerNorm(output_dim)
461
+
462
+ self.layers = nn.ModuleList([])
463
+ for _ in range(depth):
464
+ self.layers.append(
465
+ nn.ModuleList(
466
+ [
467
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
468
+ nn.Sequential(
469
+ nn.LayerNorm(dim),
470
+ nn.Linear(dim, dim * ff_mult, bias=False),
471
+ nn.GELU(),
472
+ nn.Linear(dim * ff_mult, dim, bias=False),
473
+ )
474
+ ]
475
+ )
476
+ )
477
+
478
+ def forward(self, x):
479
+ latents = self.latents.repeat(x.size(0), 1, 1)
480
+ x = self.proj_in(x)
481
+ for attn, ff in self.layers:
482
+ latents = attn(x, latents) + latents
483
+ latents = ff(latents) + latents
484
+
485
+ latents = self.proj_out(latents)
486
+ return self.norm_out(latents)
487
+
488
+
489
+ class CondSequential(nn.Sequential):
490
+ """
491
+ A sequential module that passes timestep embeddings to the children that
492
+ support it as an extra input.
493
+ """
494
+
495
+ def forward(self, x, emb, context=None, num_frames=1):
496
+ for layer in self:
497
+ if isinstance(layer, ResBlock):
498
+ x = layer(x, emb)
499
+ elif isinstance(layer, SpatialTransformer3D):
500
+ x = layer(x, context, num_frames=num_frames)
501
+ else:
502
+ x = layer(x)
503
+ return x
504
+
505
+
506
+ class Upsample(nn.Module):
507
+ """
508
+ An upsampling layer with an optional convolution.
509
+ :param channels: channels in the inputs and outputs.
510
+ :param use_conv: a bool determining if a convolution is applied.
511
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
512
+ upsampling occurs in the inner-two dimensions.
513
+ """
514
+
515
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
516
+ super().__init__()
517
+ self.channels = channels
518
+ self.out_channels = out_channels or channels
519
+ self.use_conv = use_conv
520
+ self.dims = dims
521
+ if use_conv:
522
+ self.conv = conv_nd(
523
+ dims, self.channels, self.out_channels, 3, padding=padding
524
+ )
525
+
526
+ def forward(self, x):
527
+ assert x.shape[1] == self.channels
528
+ if self.dims == 3:
529
+ x = F.interpolate(
530
+ x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
531
+ )
532
+ else:
533
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
534
+ if self.use_conv:
535
+ x = self.conv(x)
536
+ return x
537
+
538
+
539
+ class Downsample(nn.Module):
540
+ """
541
+ A downsampling layer with an optional convolution.
542
+ :param channels: channels in the inputs and outputs.
543
+ :param use_conv: a bool determining if a convolution is applied.
544
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
545
+ downsampling occurs in the inner-two dimensions.
546
+ """
547
+
548
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
549
+ super().__init__()
550
+ self.channels = channels
551
+ self.out_channels = out_channels or channels
552
+ self.use_conv = use_conv
553
+ self.dims = dims
554
+ stride = 2 if dims != 3 else (1, 2, 2)
555
+ if use_conv:
556
+ self.op = conv_nd(
557
+ dims,
558
+ self.channels,
559
+ self.out_channels,
560
+ 3,
561
+ stride=stride,
562
+ padding=padding,
563
+ )
564
+ else:
565
+ assert self.channels == self.out_channels
566
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
567
+
568
+ def forward(self, x):
569
+ assert x.shape[1] == self.channels
570
+ return self.op(x)
571
+
572
+
573
+ class ResBlock(nn.Module):
574
+ """
575
+ A residual block that can optionally change the number of channels.
576
+ :param channels: the number of input channels.
577
+ :param emb_channels: the number of timestep embedding channels.
578
+ :param dropout: the rate of dropout.
579
+ :param out_channels: if specified, the number of out channels.
580
+ :param use_conv: if True and out_channels is specified, use a spatial
581
+ convolution instead of a smaller 1x1 convolution to change the
582
+ channels in the skip connection.
583
+ :param dims: determines if the signal is 1D, 2D, or 3D.
584
+ :param use_checkpoint: if True, use gradient checkpointing on this module.
585
+ :param up: if True, use this block for upsampling.
586
+ :param down: if True, use this block for downsampling.
587
+ """
588
+
589
+ def __init__(
590
+ self,
591
+ channels,
592
+ emb_channels,
593
+ dropout,
594
+ out_channels=None,
595
+ use_conv=False,
596
+ use_scale_shift_norm=False,
597
+ dims=2,
598
+ use_checkpoint=False,
599
+ up=False,
600
+ down=False,
601
+ ):
602
+ super().__init__()
603
+ self.channels = channels
604
+ self.emb_channels = emb_channels
605
+ self.dropout = dropout
606
+ self.out_channels = out_channels or channels
607
+ self.use_conv = use_conv
608
+ self.use_checkpoint = use_checkpoint
609
+ self.use_scale_shift_norm = use_scale_shift_norm
610
+
611
+ self.in_layers = nn.Sequential(
612
+ nn.GroupNorm(32, channels),
613
+ nn.SiLU(),
614
+ conv_nd(dims, channels, self.out_channels, 3, padding=1),
615
+ )
616
+
617
+ self.updown = up or down
618
+
619
+ if up:
620
+ self.h_upd = Upsample(channels, False, dims)
621
+ self.x_upd = Upsample(channels, False, dims)
622
+ elif down:
623
+ self.h_upd = Downsample(channels, False, dims)
624
+ self.x_upd = Downsample(channels, False, dims)
625
+ else:
626
+ self.h_upd = self.x_upd = nn.Identity()
627
+
628
+ self.emb_layers = nn.Sequential(
629
+ nn.SiLU(),
630
+ nn.Linear(
631
+ emb_channels,
632
+ 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
633
+ ),
634
+ )
635
+ self.out_layers = nn.Sequential(
636
+ nn.GroupNorm(32, self.out_channels),
637
+ nn.SiLU(),
638
+ nn.Dropout(p=dropout),
639
+ zero_module(
640
+ conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
641
+ ),
642
+ )
643
+
644
+ if self.out_channels == channels:
645
+ self.skip_connection = nn.Identity()
646
+ elif use_conv:
647
+ self.skip_connection = conv_nd(
648
+ dims, channels, self.out_channels, 3, padding=1
649
+ )
650
+ else:
651
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
652
+
653
+ def forward(self, x, emb):
654
+ """
655
+ Apply the block to a Tensor, conditioned on a timestep embedding.
656
+ :param x: an [N x C x ...] Tensor of features.
657
+ :param emb: an [N x emb_channels] Tensor of timestep embeddings.
658
+ :return: an [N x C x ...] Tensor of outputs.
659
+ """
660
+ return checkpoint(
661
+ self._forward, (x, emb), self.parameters(), self.use_checkpoint
662
+ )
663
+
664
+ def _forward(self, x, emb):
665
+ if self.updown:
666
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
667
+ h = in_rest(x)
668
+ h = self.h_upd(h)
669
+ x = self.x_upd(x)
670
+ h = in_conv(h)
671
+ else:
672
+ h = self.in_layers(x)
673
+ emb_out = self.emb_layers(emb).type(h.dtype)
674
+ while len(emb_out.shape) < len(h.shape):
675
+ emb_out = emb_out[..., None]
676
+ if self.use_scale_shift_norm:
677
+ out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
678
+ scale, shift = torch.chunk(emb_out, 2, dim=1)
679
+ h = out_norm(h) * (1 + scale) + shift
680
+ h = out_rest(h)
681
+ else:
682
+ h = h + emb_out
683
+ h = self.out_layers(h)
684
+ return self.skip_connection(x) + h
685
+
686
+
687
+ class MultiViewUNetModel(ModelMixin, ConfigMixin):
688
+ """
689
+ The full multi-view UNet model with attention, timestep embedding and camera embedding.
690
+ :param in_channels: channels in the input Tensor.
691
+ :param model_channels: base channel count for the model.
692
+ :param out_channels: channels in the output Tensor.
693
+ :param num_res_blocks: number of residual blocks per downsample.
694
+ :param attention_resolutions: a collection of downsample rates at which
695
+ attention will take place. May be a set, list, or tuple.
696
+ For example, if this contains 4, then at 4x downsampling, attention
697
+ will be used.
698
+ :param dropout: the dropout probability.
699
+ :param channel_mult: channel multiplier for each level of the UNet.
700
+ :param conv_resample: if True, use learned convolutions for upsampling and
701
+ downsampling.
702
+ :param dims: determines if the signal is 1D, 2D, or 3D.
703
+ :param num_classes: if specified (as an int), then this model will be
704
+ class-conditional with `num_classes` classes.
705
+ :param use_checkpoint: use gradient checkpointing to reduce memory usage.
706
+ :param num_heads: the number of attention heads in each attention layer.
707
+ :param num_heads_channels: if specified, ignore num_heads and instead use
708
+ a fixed channel width per attention head.
709
+ :param num_heads_upsample: works with num_heads to set a different number
710
+ of heads for upsampling. Deprecated.
711
+ :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
712
+ :param resblock_updown: use residual blocks for up/downsampling.
713
+ :param use_new_attention_order: use a different attention pattern for potentially
714
+ increased efficiency.
715
+ :param camera_dim: dimensionality of camera input.
716
+ """
717
+
718
+ def __init__(
719
+ self,
720
+ image_size,
721
+ in_channels,
722
+ model_channels,
723
+ out_channels,
724
+ num_res_blocks,
725
+ attention_resolutions,
726
+ dropout=0,
727
+ channel_mult=(1, 2, 4, 8),
728
+ conv_resample=True,
729
+ dims=2,
730
+ num_classes=None,
731
+ use_checkpoint=False,
732
+ num_heads=-1,
733
+ num_head_channels=-1,
734
+ num_heads_upsample=-1,
735
+ use_scale_shift_norm=False,
736
+ resblock_updown=False,
737
+ transformer_depth=1,
738
+ context_dim=None,
739
+ n_embed=None,
740
+ num_attention_blocks=None,
741
+ adm_in_channels=None,
742
+ camera_dim=None,
743
+ ip_dim=0, # imagedream uses ip_dim > 0
744
+ ip_weight=1.0,
745
+ **kwargs,
746
+ ):
747
+ super().__init__()
748
+ assert context_dim is not None
749
+
750
+ if num_heads_upsample == -1:
751
+ num_heads_upsample = num_heads
752
+
753
+ if num_heads == -1:
754
+ assert (
755
+ num_head_channels != -1
756
+ ), "Either num_heads or num_head_channels has to be set"
757
+
758
+ if num_head_channels == -1:
759
+ assert (
760
+ num_heads != -1
761
+ ), "Either num_heads or num_head_channels has to be set"
762
+
763
+ self.image_size = image_size
764
+ self.in_channels = in_channels
765
+ self.model_channels = model_channels
766
+ self.out_channels = out_channels
767
+ if isinstance(num_res_blocks, int):
768
+ self.num_res_blocks = len(channel_mult) * [num_res_blocks]
769
+ else:
770
+ if len(num_res_blocks) != len(channel_mult):
771
+ raise ValueError(
772
+ "provide num_res_blocks either as an int (globally constant) or "
773
+ "as a list/tuple (per-level) with the same length as channel_mult"
774
+ )
775
+ self.num_res_blocks = num_res_blocks
776
+
777
+ if num_attention_blocks is not None:
778
+ assert len(num_attention_blocks) == len(self.num_res_blocks)
779
+ assert all(
780
+ map(
781
+ lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
782
+ range(len(num_attention_blocks)),
783
+ )
784
+ )
785
+ print(
786
+ f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
787
+ f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
788
+ f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
789
+ f"attention will still not be set."
790
+ )
791
+
792
+ self.attention_resolutions = attention_resolutions
793
+ self.dropout = dropout
794
+ self.channel_mult = channel_mult
795
+ self.conv_resample = conv_resample
796
+ self.num_classes = num_classes
797
+ self.use_checkpoint = use_checkpoint
798
+ self.num_heads = num_heads
799
+ self.num_head_channels = num_head_channels
800
+ self.num_heads_upsample = num_heads_upsample
801
+ self.predict_codebook_ids = n_embed is not None
802
+
803
+ self.ip_dim = ip_dim
804
+ self.ip_weight = ip_weight
805
+
806
+ if self.ip_dim > 0:
807
+ self.image_embed = Resampler(
808
+ dim=context_dim,
809
+ depth=4,
810
+ dim_head=64,
811
+ heads=12,
812
+ num_queries=ip_dim, # num token
813
+ embedding_dim=1280,
814
+ output_dim=context_dim,
815
+ ff_mult=4,
816
+ )
817
+
818
+ time_embed_dim = model_channels * 4
819
+ self.time_embed = nn.Sequential(
820
+ nn.Linear(model_channels, time_embed_dim),
821
+ nn.SiLU(),
822
+ nn.Linear(time_embed_dim, time_embed_dim),
823
+ )
824
+
825
+ if camera_dim is not None:
826
+ time_embed_dim = model_channels * 4
827
+ self.camera_embed = nn.Sequential(
828
+ nn.Linear(camera_dim, time_embed_dim),
829
+ nn.SiLU(),
830
+ nn.Linear(time_embed_dim, time_embed_dim),
831
+ )
832
+
833
+ if self.num_classes is not None:
834
+ if isinstance(self.num_classes, int):
835
+ self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
836
+ elif self.num_classes == "continuous":
837
+ # print("setting up linear c_adm embedding layer")
838
+ self.label_emb = nn.Linear(1, time_embed_dim)
839
+ elif self.num_classes == "sequential":
840
+ assert adm_in_channels is not None
841
+ self.label_emb = nn.Sequential(
842
+ nn.Sequential(
843
+ nn.Linear(adm_in_channels, time_embed_dim),
844
+ nn.SiLU(),
845
+ nn.Linear(time_embed_dim, time_embed_dim),
846
+ )
847
+ )
848
+ else:
849
+ raise ValueError()
850
+
851
+ self.input_blocks = nn.ModuleList(
852
+ [
853
+ CondSequential(
854
+ conv_nd(dims, in_channels, model_channels, 3, padding=1)
855
+ )
856
+ ]
857
+ )
858
+ self._feature_size = model_channels
859
+ input_block_chans = [model_channels]
860
+ ch = model_channels
861
+ ds = 1
862
+ for level, mult in enumerate(channel_mult):
863
+ for nr in range(self.num_res_blocks[level]):
864
+ layers: List[Any] = [
865
+ ResBlock(
866
+ ch,
867
+ time_embed_dim,
868
+ dropout,
869
+ out_channels=mult * model_channels,
870
+ dims=dims,
871
+ use_checkpoint=use_checkpoint,
872
+ use_scale_shift_norm=use_scale_shift_norm,
873
+ )
874
+ ]
875
+ ch = mult * model_channels
876
+ if ds in attention_resolutions:
877
+ if num_head_channels == -1:
878
+ dim_head = ch // num_heads
879
+ else:
880
+ num_heads = ch // num_head_channels
881
+ dim_head = num_head_channels
882
+
883
+ if num_attention_blocks is None or nr < num_attention_blocks[level]:
884
+ layers.append(
885
+ SpatialTransformer3D(
886
+ ch,
887
+ num_heads,
888
+ dim_head,
889
+ context_dim=context_dim,
890
+ depth=transformer_depth,
891
+ use_checkpoint=use_checkpoint,
892
+ ip_dim=self.ip_dim,
893
+ ip_weight=self.ip_weight,
894
+ )
895
+ )
896
+ self.input_blocks.append(CondSequential(*layers))
897
+ self._feature_size += ch
898
+ input_block_chans.append(ch)
899
+ if level != len(channel_mult) - 1:
900
+ out_ch = ch
901
+ self.input_blocks.append(
902
+ CondSequential(
903
+ ResBlock(
904
+ ch,
905
+ time_embed_dim,
906
+ dropout,
907
+ out_channels=out_ch,
908
+ dims=dims,
909
+ use_checkpoint=use_checkpoint,
910
+ use_scale_shift_norm=use_scale_shift_norm,
911
+ down=True,
912
+ )
913
+ if resblock_updown
914
+ else Downsample(
915
+ ch, conv_resample, dims=dims, out_channels=out_ch
916
+ )
917
+ )
918
+ )
919
+ ch = out_ch
920
+ input_block_chans.append(ch)
921
+ ds *= 2
922
+ self._feature_size += ch
923
+
924
+ if num_head_channels == -1:
925
+ dim_head = ch // num_heads
926
+ else:
927
+ num_heads = ch // num_head_channels
928
+ dim_head = num_head_channels
929
+
930
+ self.middle_block = CondSequential(
931
+ ResBlock(
932
+ ch,
933
+ time_embed_dim,
934
+ dropout,
935
+ dims=dims,
936
+ use_checkpoint=use_checkpoint,
937
+ use_scale_shift_norm=use_scale_shift_norm,
938
+ ),
939
+ SpatialTransformer3D(
940
+ ch,
941
+ num_heads,
942
+ dim_head,
943
+ context_dim=context_dim,
944
+ depth=transformer_depth,
945
+ use_checkpoint=use_checkpoint,
946
+ ip_dim=self.ip_dim,
947
+ ip_weight=self.ip_weight,
948
+ ),
949
+ ResBlock(
950
+ ch,
951
+ time_embed_dim,
952
+ dropout,
953
+ dims=dims,
954
+ use_checkpoint=use_checkpoint,
955
+ use_scale_shift_norm=use_scale_shift_norm,
956
+ ),
957
+ )
958
+ self._feature_size += ch
959
+
960
+ self.output_blocks = nn.ModuleList([])
961
+ for level, mult in list(enumerate(channel_mult))[::-1]:
962
+ for i in range(self.num_res_blocks[level] + 1):
963
+ ich = input_block_chans.pop()
964
+ layers = [
965
+ ResBlock(
966
+ ch + ich,
967
+ time_embed_dim,
968
+ dropout,
969
+ out_channels=model_channels * mult,
970
+ dims=dims,
971
+ use_checkpoint=use_checkpoint,
972
+ use_scale_shift_norm=use_scale_shift_norm,
973
+ )
974
+ ]
975
+ ch = model_channels * mult
976
+ if ds in attention_resolutions:
977
+ if num_head_channels == -1:
978
+ dim_head = ch // num_heads
979
+ else:
980
+ num_heads = ch // num_head_channels
981
+ dim_head = num_head_channels
982
+
983
+ if num_attention_blocks is None or i < num_attention_blocks[level]:
984
+ layers.append(
985
+ SpatialTransformer3D(
986
+ ch,
987
+ num_heads,
988
+ dim_head,
989
+ context_dim=context_dim,
990
+ depth=transformer_depth,
991
+ use_checkpoint=use_checkpoint,
992
+ ip_dim=self.ip_dim,
993
+ ip_weight=self.ip_weight,
994
+ )
995
+ )
996
+ if level and i == self.num_res_blocks[level]:
997
+ out_ch = ch
998
+ layers.append(
999
+ ResBlock(
1000
+ ch,
1001
+ time_embed_dim,
1002
+ dropout,
1003
+ out_channels=out_ch,
1004
+ dims=dims,
1005
+ use_checkpoint=use_checkpoint,
1006
+ use_scale_shift_norm=use_scale_shift_norm,
1007
+ up=True,
1008
+ )
1009
+ if resblock_updown
1010
+ else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
1011
+ )
1012
+ ds //= 2
1013
+ self.output_blocks.append(CondSequential(*layers))
1014
+ self._feature_size += ch
1015
+
1016
+ self.out = nn.Sequential(
1017
+ nn.GroupNorm(32, ch),
1018
+ nn.SiLU(),
1019
+ zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
1020
+ )
1021
+ if self.predict_codebook_ids:
1022
+ self.id_predictor = nn.Sequential(
1023
+ nn.GroupNorm(32, ch),
1024
+ conv_nd(dims, model_channels, n_embed, 1),
1025
+ # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
1026
+ )
1027
+
1028
+ def forward(
1029
+ self,
1030
+ x,
1031
+ timesteps=None,
1032
+ context=None,
1033
+ y=None,
1034
+ camera=None,
1035
+ num_frames=1,
1036
+ ip=None,
1037
+ ip_img=None,
1038
+ **kwargs,
1039
+ ):
1040
+ """
1041
+ Apply the model to an input batch.
1042
+ :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
1043
+ :param timesteps: a 1-D batch of timesteps.
1044
+ :param context: conditioning plugged in via crossattn
1045
+ :param y: an [N] Tensor of labels, if class-conditional.
1046
+ :param num_frames: a integer indicating number of frames for tensor reshaping.
1047
+ :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
1048
+ """
1049
+ assert (
1050
+ x.shape[0] % num_frames == 0
1051
+ ), "input batch size must be dividable by num_frames!"
1052
+ assert (y is not None) == (
1053
+ self.num_classes is not None
1054
+ ), "must specify y if and only if the model is class-conditional"
1055
+
1056
+ hs = []
1057
+
1058
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
1059
+
1060
+ emb = self.time_embed(t_emb)
1061
+
1062
+ if self.num_classes is not None:
1063
+ assert y is not None
1064
+ assert y.shape[0] == x.shape[0]
1065
+ emb = emb + self.label_emb(y)
1066
+
1067
+ # Add camera embeddings
1068
+ if camera is not None:
1069
+ emb = emb + self.camera_embed(camera)
1070
+
1071
+ # imagedream variant
1072
+ if self.ip_dim > 0:
1073
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img # place at [4, 9]
1074
+ ip_emb = self.image_embed(ip)
1075
+ context = torch.cat((context, ip_emb), 1)
1076
+
1077
+ h = x
1078
+ for module in self.input_blocks:
1079
+ h = module(h, emb, context, num_frames=num_frames)
1080
+ hs.append(h)
1081
+ h = self.middle_block(h, emb, context, num_frames=num_frames)
1082
+ for module in self.output_blocks:
1083
+ h = torch.cat([h, hs.pop()], dim=1)
1084
+ h = module(h, emb, context, num_frames=num_frames)
1085
+ h = h.type(x.dtype)
1086
+ if self.predict_codebook_ids:
1087
+ return self.id_predictor(h)
1088
+ else:
1089
+ return self.out(h)
vae/config.json ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AutoencoderKL",
3
+ "_diffusers_version": "0.25.0",
4
+ "act_fn": "silu",
5
+ "block_out_channels": [
6
+ 128,
7
+ 256,
8
+ 512,
9
+ 512
10
+ ],
11
+ "down_block_types": [
12
+ "DownEncoderBlock2D",
13
+ "DownEncoderBlock2D",
14
+ "DownEncoderBlock2D",
15
+ "DownEncoderBlock2D"
16
+ ],
17
+ "force_upcast": true,
18
+ "in_channels": 3,
19
+ "latent_channels": 4,
20
+ "layers_per_block": 2,
21
+ "norm_num_groups": 32,
22
+ "out_channels": 3,
23
+ "sample_size": 256,
24
+ "scaling_factor": 0.18215,
25
+ "up_block_types": [
26
+ "UpDecoderBlock2D",
27
+ "UpDecoderBlock2D",
28
+ "UpDecoderBlock2D",
29
+ "UpDecoderBlock2D"
30
+ ]
31
+ }
vae/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e4c08995484ee61270175e9e7a072b66a6e4eeb5f0c266667fe1f45b90daf9a
3
+ size 167335342