Koke_Cacao committed
Commit 5b08d3b · 1 Parent(s): 57e6edd

:sparkles: finish inference
.gitignore CHANGED
@@ -2,3 +2,4 @@
 *.yaml
 converted
 __pycache__
+*.png
scripts/README.md CHANGED
@@ -14,6 +14,5 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd
 
 Hugging Face diffusers weights are converted by script:
 ```bash
-mkdir converted
 python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path ./converted --original_config_file ./sd-v1.yaml
 ```
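Note: a minimal usage sketch for the converted weights (not part of this commit). It assumes the conversion above wrote `./converted` and that the `scripts/` modules (`pipeline_mvdream.py`, `models.py`) are importable; whether `from_pretrained` resolves the custom classes cleanly depends on how they are registered, so treat this as illustrative only.

```python
# Hypothetical loading/inference sketch; mirrors the __main__ block added in
# convert_mvdream_to_diffusers.py below.
from pipeline_mvdream import MVDreamStableDiffusionPipeline

# Assumes the custom pipeline/unet classes are importable when loading.
pipe = MVDreamStableDiffusionPipeline.from_pretrained("./converted")
pipe = pipe.to("cuda")

images = pipe(
    prompt="Head of Hatsune Miku",
    negative_prompt="painting, bad quality, flat",
    return_dict=False,  # return the list of PIL images directly
    guidance_scale=7.5,
    num_inference_steps=50,
)
for i, image in enumerate(images):
    image.save(f"image_{i}.png")
```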
scripts/convert_mvdream_to_diffusers.py CHANGED
@@ -20,7 +20,7 @@ from diffusers.utils import logging
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
 from rich import print, print_json
-from models import MultiViewUNetModel
+from models import MultiViewUNetModel, MultiViewUNetWrapperModel
 from pipeline_mvdream import MVDreamStableDiffusionPipeline
 from transformers import CLIPTokenizer, CLIPTextModel
 
@@ -659,7 +659,6 @@ def conv_attn_to_linear(checkpoint):
             if checkpoint[key].ndim > 2:
                 checkpoint[key] = checkpoint[key][:, :, 0]
 
-
 def convert_from_original_mvdream_ckpt(
     checkpoint_path,
     original_config_file,
@@ -667,13 +666,13 @@ def convert_from_original_mvdream_ckpt(
     device
 ):
     checkpoint = torch.load(checkpoint_path, map_location=device)
-    print(f"Checkpoint: {checkpoint.keys()}")
+    # print(f"Checkpoint: {checkpoint.keys()}")
     torch.cuda.empty_cache()
 
     from omegaconf import OmegaConf
 
     original_config = OmegaConf.load(original_config_file)
-    print(f"Original Config: {original_config}")
+    # print(f"Original Config: {original_config}")
     prediction_type = "epsilon"
     image_size = 256
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
@@ -700,10 +699,11 @@ def convert_from_original_mvdream_ckpt(
     # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
     #     checkpoint, unet_config, path=None, extract_ema=extract_ema
     # )
-    print(f"Unet Config: {original_config.model.params.unet_config.params}")
-    unet: MultiViewUNetModel = MultiViewUNetModel(**original_config.model.params.unet_config.params)
+    # print(f"Unet Config: {original_config.model.params.unet_config.params}")
+    unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**original_config.model.params.unet_config.params)
+    # print(f"Unet State Dict: {unet.state_dict().keys()}")
     unet.load_state_dict({
-        key.replace("model.diffusion_model.", ""): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "") in unet.state_dict()
+        key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()
     })
     for param_name, param in unet.state_dict().items():
         set_module_tensor_to_device(unet, param_name, "cuda:0", value=param)
@@ -738,9 +738,6 @@ def convert_from_original_mvdream_ckpt(
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=scheduler,
-        safety_checker=None,
-        feature_extractor=None,
-        requires_safety_checker=False
     )
 
     return pipe
@@ -787,8 +784,15 @@ if __name__ == "__main__":
     if args.half:
         pipe.to(torch_dtype=torch.float16)
 
-    out = pipe()
+    images = pipe(
+        prompt="Head of Hatsune Miku",
+        negative_prompt="painting, bad quality, flat",
+        output_type="pil",
+        return_dict=False,
+        guidance_scale=7.5,
+        num_inference_steps=50,
+    )
+    for i, image in enumerate(images):
+        image.save(f"image_{i}.png")
 
-    assert False
-
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
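Note on the state-dict remapping above: because the UNet now lives under a `unet.` attribute of `MultiViewUNetWrapperModel`, the original checkpoint keys (`model.diffusion_model.*`) have to be renamed to `unet.*` before `load_state_dict`, and keys the wrapper does not own (VAE, text encoder, EMA copies) are filtered out. A small equivalent helper, with a hypothetical name, just to make the inline dict comprehension explicit:

```python
# Equivalent of the inline dict comprehension in convert_from_original_mvdream_ckpt;
# the helper name is illustrative, not part of the script.
def remap_unet_keys(checkpoint: dict, unet_state_dict: dict) -> dict:
    remapped = {}
    for key, value in checkpoint.items():
        new_key = key.replace("model.diffusion_model.", "unet.")
        if new_key in unet_state_dict:  # drop keys the wrapper does not know about
            remapped[new_key] = value
    return remapped
```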
scripts/models.py CHANGED
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 import math
+from typing import Any, Mapping
 
 import numpy as np
 import torch as th
@@ -18,6 +19,16 @@ from util import (
 from attention import SpatialTransformer, SpatialTransformer3D, exists
 
 
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.models.modeling_utils import ModelMixin
+class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.unet: MultiViewUNetModel = MultiViewUNetModel(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return self.unet(*args, **kwargs)
+
 # dummy replace
 def convert_module_to_f16(x):
     pass
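Note: the wrapper's only job is to make the raw `MultiViewUNetModel` look like a diffusers model, so `register_modules` and `save_pretrained` can treat it like any other component while `forward` simply delegates. A toy sketch of the same pattern (not MVDream code); diffusers normally records constructor kwargs via `register_to_config`, which the wrapper above skips because it is rebuilt from the YAML config during conversion:

```python
# Toy illustration of the ModelMixin/ConfigMixin wrapper pattern used above.
import torch
import torch.nn as nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyWrapperModel(ModelMixin, ConfigMixin):

    @register_to_config  # records hidden_dim in config.json so from_pretrained can rebuild it
    def __init__(self, hidden_dim: int = 8):
        super().__init__()
        self.net = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, x):
        return self.net(x)


wrapper = ToyWrapperModel(hidden_dim=8)
print(wrapper(torch.randn(1, 8)).shape)  # torch.Size([1, 8])
```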
scripts/pipeline_mvdream.py CHANGED
@@ -2,11 +2,10 @@ import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from packaging import version
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPTextModel, CLIPTokenizer
 
-from diffusers import AutoencoderKL, UNet2DConditionModel, DiffusionPipeline
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+from diffusers import AutoencoderKL, DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import (
     deprecate,
     is_accelerate_available,
@@ -16,26 +15,21 @@ from diffusers.utils import (
 )
 
 try:
-    from diffusers import randn_tensor # old import
+    from diffusers import randn_tensor  # old import
 except ImportError:
-    from diffusers.utils.torch_utils import randn_tensor # new import
+    from diffusers.utils.torch_utils import randn_tensor  # new import
 
 from diffusers.configuration_utils import FrozenDict
-import PIL
 import numpy as np
-import kornia
-from diffusers.configuration_utils import ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
-
-from models import MultiViewUNetModel
 from diffusers.schedulers import DDIMScheduler
+from models import MultiViewUNetModel, MultiViewUNetWrapperModel
 
 EXAMPLE_DOC_STRING = ""
 
-logger = logging.get_logger(__name__) # pylint: disable=invalid-name
-
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 import numpy as np
+
 def create_camera_to_world_matrix(elevation, azimuth):
     elevation = np.radians(elevation)
     azimuth = np.radians(azimuth)
@@ -43,12 +37,12 @@ def create_camera_to_world_matrix(elevation, azimuth):
     x = np.cos(elevation) * np.sin(azimuth)
     y = np.sin(elevation)
     z = np.cos(elevation) * np.cos(azimuth)
-
+
     # Calculate camera position, target, and up vectors
     camera_pos = np.array([x, y, z])
     target = np.array([0, 0, 0])
     up = np.array([0, 1, 0])
-
+
     # Construct view matrix
     forward = target - camera_pos
     forward /= np.linalg.norm(forward)
@@ -61,90 +55,96 @@ def create_camera_to_world_matrix(elevation, azimuth):
     cam2world[:3, 3] = camera_pos
     return cam2world
 
+
 def convert_opengl_to_blender(camera_matrix):
     if isinstance(camera_matrix, np.ndarray):
         # Construct transformation matrix to convert from OpenGL space to Blender space
-        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
+                            [0, 0, 0, 1]])
         camera_matrix_blender = np.dot(flip_yz, camera_matrix)
     else:
         # Construct transformation matrix to convert from OpenGL space to Blender space
-        flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
+        flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
+                                [0, 0, 0, 1]])
         if camera_matrix.ndim == 3:
             flip_yz = flip_yz.unsqueeze(0)
-        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
+        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix),
+                                             camera_matrix)
     return camera_matrix_blender
 
-def get_camera(num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True):
+
+def get_camera(num_frames,
+               elevation=15,
+               azimuth_start=0,
+               azimuth_span=360,
+               blender_coord=True):
     angle_gap = azimuth_span / num_frames
     cameras = []
-    for azimuth in np.arange(azimuth_start, azimuth_span+azimuth_start, angle_gap):
+    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start,
+                             angle_gap):
         camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
         if blender_coord:
             camera_matrix = convert_opengl_to_blender(camera_matrix)
         cameras.append(camera_matrix.flatten())
     return torch.tensor(np.stack(cameras, 0)).float()
 
+
 class MVDreamStableDiffusionPipeline(DiffusionPipeline):
+
     def __init__(
         self,
         vae: AutoencoderKL,
-        unet: MultiViewUNetModel,
+        unet: MultiViewUNetWrapperModel,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
         scheduler: DDIMScheduler,
-        safety_checker: Optional[StableDiffusionSafetyChecker],
-        feature_extractor: Optional[CLIPFeatureExtractor],
-        requires_safety_checker: bool = True,
     ):
         super().__init__()
 
-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            deprecation_message = (f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                                   f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
-                                   "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
-                                   " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
-                                   " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                                   " file")
-            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+        if hasattr(scheduler.config,
+                   "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file")
+            deprecate("steps_offset!=1",
+                      "1.0.0",
+                      deprecation_message,
+                      standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
-            deprecation_message = (f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
-                                   " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
-                                   " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
-                                   " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
-                                   " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file")
-            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+        if hasattr(scheduler.config,
+                   "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set",
+                      "1.0.0",
+                      deprecation_message,
+                      standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["clip_sample"] = False
            scheduler._internal_dict = FrozenDict(new_config)
 
-        if safety_checker is None and requires_safety_checker:
-            logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                           " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                           " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                           " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                           " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                           " information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")
-
-        if safety_checker is not None and feature_extractor is None:
-            raise ValueError("Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
-                             " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")
-
         self.register_modules(
             vae=vae,
             unet=unet,
             scheduler=scheduler,
             tokenizer=tokenizer,
             text_encoder=text_encoder,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) - 1)
-        self.register_to_config(requires_safety_checker=requires_safety_checker)
-        # self.model_mode = None
+        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) -
+                                    1)
+        self.register_to_config(requires_safety_checker=False)
 
     def enable_vae_slicing(self):
         r"""
@@ -189,20 +189,20 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
             from accelerate import cpu_offload
         else:
-            raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+            raise ImportError(
+                "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
+            )
 
         device = torch.device(f"cuda:{gpu_id}")
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+            torch.cuda.empty_cache(
+            )  # otherwise we don't see the memory savings (but they probably exist)
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
-        if self.safety_checker is not None:
-            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
-
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -210,23 +210,26 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+        if is_accelerate_available() and is_accelerate_version(
+                ">=", "0.17.0.dev0"):
             from accelerate import cpu_offload_with_hook
         else:
-            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+            raise ImportError(
+                "`enable_model_offload` requires `accelerate v0.17.0` or higher."
+            )
 
         device = torch.device(f"cuda:{gpu_id}")
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+            torch.cuda.empty_cache(
+            )  # otherwise we don't see the memory savings (but they probably exist)
 
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
-
-        if self.safety_checker is not None:
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model,
+                                            device,
+                                            prev_module_hook=hook)
 
         # We'll offload the last model manually.
         self.final_offload_hook = hook
@@ -241,7 +244,9 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
-            if (hasattr(module, "_hf_hook") and hasattr(module._hf_hook, "execution_device") and module._hf_hook.execution_device is not None):
+            if (hasattr(module, "_hf_hook")
+                    and hasattr(module._hf_hook, "execution_device")
                    and module._hf_hook.execution_device is not None):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
@@ -295,14 +300,21 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-
-        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
-            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
-            logger.warning("The following part of your input was truncated because CLIP can only handle sequences up to"
-                           f" {self.tokenizer.model_max_length} tokens: {removed_text}")
-
-        if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+        untruncated_ids = self.tokenizer(prompt,
+                                         padding="longest",
+                                         return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[
+                -1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(
+                untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+
+        if hasattr(self.text_encoder.config, "use_attention_mask"
+                   ) and self.text_encoder.config.use_attention_mask:
             attention_mask = text_inputs.attention_mask.to(device)
         else:
             attention_mask = None
@@ -313,12 +325,14 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         )
         prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype,
+                                         device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
+                                           seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -326,14 +340,16 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
             elif type(prompt) is not type(negative_prompt):
-                raise TypeError(f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                                f" {type(prompt)}.")
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}.")
             elif isinstance(negative_prompt, str):
                 uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
-                raise ValueError(f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                                 f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                                 " the batch size of `prompt`.")
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`.")
             else:
                 uncond_tokens = negative_prompt
 
@@ -346,7 +362,8 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
 
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+            if hasattr(self.text_encoder.config, "use_attention_mask"
+                       ) and self.text_encoder.config.use_attention_mask:
                 attention_mask = uncond_input.attention_mask.to(device)
             else:
                 attention_mask = None
@@ -361,10 +378,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(
+                dtype=self.text_encoder.dtype, device=device)
 
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -373,14 +393,6 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
 
         return prompt_embeds
 
-    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
-        else:
-            has_nsfw_concept = None
-        return image, has_nsfw_concept
-
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
@@ -395,25 +407,42 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
 
-        accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys())
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+    def prepare_latents(self,
+                        batch_size,
+                        num_channels_latents,
+                        height,
+                        width,
+                        dtype,
+                        device,
+                        generator,
+                        latents=None):
+        shape = (batch_size, num_channels_latents,
+                 height // self.vae_scale_factor,
+                 width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                             f" size of {batch_size}. Make sure the batch size matches the length of the generators.")
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
 
         if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+            latents = randn_tensor(shape,
+                                   generator=generator,
+                                   device=device,
+                                   dtype=dtype)
         else:
             latents = latents.to(device)
 
@@ -433,10 +462,12 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         negative_prompt: str = "bad quality",
         num_images_per_prompt: int = 1,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+        generator: Optional[Union[torch.Generator,
+                                  List[torch.Generator]]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+        callback: Optional[Callable[[int, int, torch.FloatTensor],
+                                    None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
@@ -514,9 +545,9 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         # 0. Default height and width to unet
         batch_size = 4
         device = torch.device("cuda:0")
-
-        camera = get_camera(4).to(device=device)
-
+
+        camera = get_camera(batch_size).to(device=device)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -525,14 +556,15 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
-
-        prompt_embeds: torch.Tensor = self._encode_prompt(
+
+        _: torch.Tensor = self._encode_prompt(
             prompt=prompt,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=True,
+            do_classifier_free_guidance=do_classifier_free_guidance,
             negative_prompt=negative_prompt,
-        ) # type: ignore
+        )  # type: ignore
+        prompt_embeds_neg, prompt_embeds_pos = _.chunk(2)
 
         # 5. Prepare latent variables
         latents: torch.Tensor = self.prepare_latents(
@@ -540,44 +572,65 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             4,
             height,
             width,
-            prompt_embeds.dtype,
+            prompt_embeds_pos.dtype,
             device,
             generator,
             None,
         )
-
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = len(
+            timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                multiplier = 2 if do_classifier_free_guidance else 1
+                latent_model_input = torch.cat([latents] * multiplier)
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t)
 
                 # predict the noise residual
-                prompt_embeds = torch.cat([prompt_embeds] * 4)
-                print(f"shape of latent_model_input: {latent_model_input.shape}") # [2*4, 4, 32, 32]
-                print(f"shape of prompt_embeds: {prompt_embeds.shape}") # [2*4, 77, 768]
-                print(f"shape of camera: {camera.shape}") # [4, 16]
-                noise_pred = self.unet.forward(x=latent_model_input, timesteps=torch.tensor([t], device=device), context=prompt_embeds, num_frames=4)
+                # print(
+                #     f"shape of latent_model_input: {latent_model_input.shape}"
+                # )  # [2*4, 4, 32, 32]
+                # print(f"shape of prompt_embeds: {prompt_embeds.shape}"
+                #       )  # [2*4, 77, 768]
+                # print(f"shape of camera: {camera.shape}")  # [4, 16]
+                noise_pred = self.unet.forward(
+                    x=latent_model_input,
+                    timesteps=torch.tensor([t] * 4 * multiplier,
+                                           device=device),
+                    context=torch.cat([prompt_embeds_neg] * 4 +
+                                      [prompt_embeds_pos] * 4),
+                    num_frames=4,
+                    camera=torch.cat([camera] * multiplier),
+                )
 
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
-                latents: torch.Tensor = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+                latents: torch.Tensor = self.scheduler.step(
+                    noise_pred,
+                    t,
+                    latents,
+                    **extra_step_kwargs,
+                    return_dict=False)[0]
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                if i == len(timesteps) - 1 or (
+                        (i + 1) > num_warmup_steps and
+                        (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents) # type: ignore
+                        callback(i, t, latents)  # type: ignore
 
         # 8. Post-processing
         if output_type == "latent":
@@ -592,10 +645,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             image = self.decode_latents(latents)
 
         # Offload last model to CPU
-        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+        if hasattr(
+                self,
+                "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
         if not return_dict:
-            return (image, None)
+            return image
 
-        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
+        return StableDiffusionPipelineOutput(images=image,
+                                             nsfw_content_detected=None)
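Note on the denoising loop above: with classifier-free guidance the per-step batch stacks an unconditional half and a conditional half, each carrying the same four views and the same four cameras, and the per-prompt embedding is repeated once per view. A shape sketch with dummy tensors (illustrative sizes taken from the commented-out prints):

```python
# Dummy tensors only; shapes match the comments in __call__
# ([2*4, 4, 32, 32], [2*4, 77, 768], [4, 16]).
import torch

num_frames, multiplier = 4, 2                   # 4 views; uncond + cond halves
latents = torch.randn(num_frames, 4, 32, 32)
prompt_embeds_neg = torch.randn(1, 77, 768)
prompt_embeds_pos = torch.randn(1, 77, 768)
camera = torch.randn(num_frames, 16)

latent_model_input = torch.cat([latents] * multiplier)     # [8, 4, 32, 32]
context = torch.cat([prompt_embeds_neg] * num_frames +
                    [prompt_embeds_pos] * num_frames)       # [8, 77, 768]
camera_input = torch.cat([camera] * multiplier)             # [8, 16]
print(latent_model_input.shape, context.shape, camera_input.shape)
```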