Koke_Cacao committed
Commit · 5b08d3b
Parent(s): 57e6edd
:sparkles: finish inference

Files changed:
- .gitignore +1 -0
- scripts/README.md +0 -1
- scripts/convert_mvdream_to_diffusers.py +17 -13
- scripts/models.py +11 -0
- scripts/pipeline_mvdream.py +186 -130
.gitignore CHANGED
@@ -2,3 +2,4 @@
 *.yaml
 converted
 __pycache__
+*.png
scripts/README.md CHANGED
@@ -14,6 +14,5 @@ wget https://raw.githubusercontent.com/bytedance/MVDream/main/mvdream/configs/sd
 
 Hugging Face diffusers weights are converted by script:
 ```bash
-mkdir converted
 python ./scripts/convert_mvdream_to_diffusers.py --checkpoint_path ./sd-v1.5-4view.pt --dump_path ./converted --original_config_file ./sd-v1.yaml
 ```
scripts/convert_mvdream_to_diffusers.py CHANGED
@@ -20,7 +20,7 @@ from diffusers.utils import logging
 from accelerate import init_empty_weights
 from accelerate.utils import set_module_tensor_to_device
 from rich import print, print_json
-from models import MultiViewUNetModel
+from models import MultiViewUNetModel, MultiViewUNetWrapperModel
 from pipeline_mvdream import MVDreamStableDiffusionPipeline
 from transformers import CLIPTokenizer, CLIPTextModel
 
@@ -659,7 +659,6 @@ def conv_attn_to_linear(checkpoint):
             if checkpoint[key].ndim > 2:
                 checkpoint[key] = checkpoint[key][:, :, 0]
 
-
 def convert_from_original_mvdream_ckpt(
     checkpoint_path,
     original_config_file,
@@ -667,13 +666,13 @@ def convert_from_original_mvdream_ckpt(
     device
 ):
     checkpoint = torch.load(checkpoint_path, map_location=device)
-    print(f"Checkpoint: {checkpoint.keys()}")
+    # print(f"Checkpoint: {checkpoint.keys()}")
    torch.cuda.empty_cache()
 
     from omegaconf import OmegaConf
 
     original_config = OmegaConf.load(original_config_file)
-    print(f"Original Config: {original_config}")
+    # print(f"Original Config: {original_config}")
     prediction_type = "epsilon"
     image_size = 256
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
@@ -700,10 +699,11 @@ def convert_from_original_mvdream_ckpt(
     # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
     #     checkpoint, unet_config, path=None, extract_ema=extract_ema
     # )
-    print(f"Unet Config: {original_config.model.params.unet_config.params}")
-    unet:
+    # print(f"Unet Config: {original_config.model.params.unet_config.params}")
+    unet: MultiViewUNetWrapperModel = MultiViewUNetWrapperModel(**original_config.model.params.unet_config.params)
+    # print(f"Unet State Dict: {unet.state_dict().keys()}")
     unet.load_state_dict({
-        key.replace("model.diffusion_model.", ""): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "") in unet.state_dict()
+        key.replace("model.diffusion_model.", "unet."): value for key, value in checkpoint.items() if key.replace("model.diffusion_model.", "unet.") in unet.state_dict()
     })
     for param_name, param in unet.state_dict().items():
         set_module_tensor_to_device(unet, param_name, "cuda:0", value=param)
@@ -738,9 +738,6 @@ def convert_from_original_mvdream_ckpt(
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         scheduler=scheduler,
-        safety_checker=None,
-        feature_extractor=None,
-        requires_safety_checker=False
     )
 
     return pipe
@@ -787,8 +784,15 @@ if __name__ == "__main__":
     if args.half:
         pipe.to(torch_dtype=torch.float16)
 
-
-    assert False
-
+    images = pipe(
+        prompt="Head of Hatsune Miku",
+        negative_prompt="painting, bad quality, flat",
+        output_type="pil",
+        return_dict=False,
+        guidance_scale=7.5,
+        num_inference_steps=50,
+    )
+    for i, image in enumerate(images):
+        image.save(f"image_{i}.png")
 
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
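Once the script has run, the converted weights live under `--dump_path`. A minimal loading sketch (not part of this commit), assuming the dump path is `./converted` and that `scripts/pipeline_mvdream.py` and `scripts/models.py` are importable; depending on the diffusers version, resolving the custom classes may need extra care:

```python
# Hypothetical usage sketch: reload the converted pipeline and repeat the
# inference that convert_mvdream_to_diffusers.py now runs at the end.
from pipeline_mvdream import MVDreamStableDiffusionPipeline

pipe = MVDreamStableDiffusionPipeline.from_pretrained("./converted")
pipe.to("cuda:0")

images = pipe(
    prompt="Head of Hatsune Miku",
    negative_prompt="painting, bad quality, flat",
    output_type="pil",
    return_dict=False,
    guidance_scale=7.5,
    num_inference_steps=50,
)
for i, image in enumerate(images):
    image.save(f"image_{i}.png")
```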
scripts/models.py CHANGED
@@ -1,5 +1,6 @@
 from abc import abstractmethod
 import math
+from typing import Any, Mapping
 
 import numpy as np
 import torch as th
@@ -18,6 +19,16 @@ from util import (
 from attention import SpatialTransformer, SpatialTransformer3D, exists
 
 
+from diffusers.configuration_utils import ConfigMixin
+from diffusers.models.modeling_utils import ModelMixin
+class MultiViewUNetWrapperModel(ModelMixin, ConfigMixin):
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.unet: MultiViewUNetModel = MultiViewUNetModel(*args, **kwargs)
+
+    def forward(self, *args, **kwargs):
+        return self.unet(*args, **kwargs)
+
 # dummy replace
 def convert_module_to_f16(x):
     pass
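Because the wrapper stores the real network as `self.unet`, every key in its `state_dict()` is prefixed with `unet.`; that is exactly why the conversion script above remaps `model.diffusion_model.*` checkpoint keys to `unet.*`. A small self-contained sketch of the effect (the `Inner`/`Wrapper` classes are illustrative stand-ins, not part of the repo):

```python
import torch.nn as nn

class Inner(nn.Module):          # stands in for MultiViewUNetModel
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

class Wrapper(nn.Module):        # stands in for MultiViewUNetWrapperModel
    def __init__(self):
        super().__init__()
        self.unet = Inner()      # nesting under `self.unet` adds a key prefix

print(list(Inner().state_dict()))    # ['proj.weight', 'proj.bias']
print(list(Wrapper().state_dict()))  # ['unet.proj.weight', 'unet.proj.bias']
# A raw MVDream key "model.diffusion_model.proj.weight" therefore has to be
# renamed to "unet.proj.weight" before load_state_dict() on the wrapper.
```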
scripts/pipeline_mvdream.py CHANGED
@@ -2,11 +2,10 @@ import inspect
 from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
-from
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+from transformers import CLIPTextModel, CLIPTokenizer
 
-from diffusers import AutoencoderKL,
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+from diffusers import AutoencoderKL, DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import (
     deprecate,
     is_accelerate_available,
@@ -16,26 +15,21 @@ from diffusers.utils import (
 )
 
 try:
-    from diffusers import randn_tensor
+    from diffusers import randn_tensor  # old import
 except ImportError:
-    from diffusers.utils.torch_utils import randn_tensor
+    from diffusers.utils.torch_utils import randn_tensor  # new import
 
 from diffusers.configuration_utils import FrozenDict
-import PIL
 import numpy as np
-import kornia
-from diffusers.configuration_utils import ConfigMixin
-from diffusers.models.modeling_utils import ModelMixin
-
-from models import MultiViewUNetModel
 from diffusers.schedulers import DDIMScheduler
+from models import MultiViewUNetModel, MultiViewUNetWrapperModel
 
 EXAMPLE_DOC_STRING = ""
 
-logger = logging.get_logger(__name__)
-
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 import numpy as np
+
 def create_camera_to_world_matrix(elevation, azimuth):
     elevation = np.radians(elevation)
     azimuth = np.radians(azimuth)
@@ -43,12 +37,12 @@ def create_camera_to_world_matrix(elevation, azimuth):
     x = np.cos(elevation) * np.sin(azimuth)
     y = np.sin(elevation)
     z = np.cos(elevation) * np.cos(azimuth)
-
+
     # Calculate camera position, target, and up vectors
     camera_pos = np.array([x, y, z])
     target = np.array([0, 0, 0])
     up = np.array([0, 1, 0])
-
+
     # Construct view matrix
     forward = target - camera_pos
     forward /= np.linalg.norm(forward)
@@ -61,90 +55,96 @@ def create_camera_to_world_matrix(elevation, azimuth):
     cam2world[:3, 3] = camera_pos
     return cam2world
 
+
 def convert_opengl_to_blender(camera_matrix):
     if isinstance(camera_matrix, np.ndarray):
         # Construct transformation matrix to convert from OpenGL space to Blender space
-        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
+        flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
+                            [0, 0, 0, 1]])
         camera_matrix_blender = np.dot(flip_yz, camera_matrix)
     else:
         # Construct transformation matrix to convert from OpenGL space to Blender space
-        flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
+        flip_yz = torch.tensor([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0],
+                                [0, 0, 0, 1]])
         if camera_matrix.ndim == 3:
             flip_yz = flip_yz.unsqueeze(0)
-        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix),
+        camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix),
+                                             camera_matrix)
     return camera_matrix_blender
 
+
+def get_camera(num_frames,
+               elevation=15,
+               azimuth_start=0,
+               azimuth_span=360,
+               blender_coord=True):
     angle_gap = azimuth_span / num_frames
     cameras = []
-    for azimuth in np.arange(azimuth_start, azimuth_span+azimuth_start,
+    for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start,
+                             angle_gap):
         camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
         if blender_coord:
             camera_matrix = convert_opengl_to_blender(camera_matrix)
         cameras.append(camera_matrix.flatten())
     return torch.tensor(np.stack(cameras, 0)).float()
 
+
 class MVDreamStableDiffusionPipeline(DiffusionPipeline):
+
     def __init__(
         self,
         vae: AutoencoderKL,
-        unet:
+        unet: MultiViewUNetWrapperModel,
         tokenizer: CLIPTokenizer,
         text_encoder: CLIPTextModel,
         scheduler: DDIMScheduler,
-        safety_checker: Optional[StableDiffusionSafetyChecker],
-        feature_extractor: Optional[CLIPFeatureExtractor],
-        requires_safety_checker: bool = True,
     ):
         super().__init__()
 
-        if hasattr(scheduler.config,
+        if hasattr(scheduler.config,
+                   "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file")
+            deprecate("steps_offset!=1",
+                      "1.0.0",
+                      deprecation_message,
+                      standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if hasattr(scheduler.config,
+        if hasattr(scheduler.config,
+                   "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set",
+                      "1.0.0",
+                      deprecation_message,
+                      standard_warn=False)
             new_config = dict(scheduler.config)
             new_config["clip_sample"] = False
             scheduler._internal_dict = FrozenDict(new_config)
 
-        if safety_checker is None and requires_safety_checker:
-            logger.warning(f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                           " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                           " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                           " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                           " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                           " information, please have a look at https://github.com/huggingface/diffusers/pull/254 .")
-
-        if safety_checker is not None and feature_extractor is None:
-            raise ValueError("Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
-                             " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead.")
-
         self.register_modules(
             vae=vae,
             unet=unet,
             scheduler=scheduler,
             tokenizer=tokenizer,
             text_encoder=text_encoder,
-            safety_checker=safety_checker,
-            feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) -
+        self.vae_scale_factor = 2**(len(self.vae.config.block_out_channels) -
+                                    1)
+        self.register_to_config(requires_safety_checker=False)
 
     def enable_vae_slicing(self):
         r"""
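For reference, `get_camera(num_frames)` returns one flattened 4×4 camera-to-world matrix per view. A minimal sketch (assuming the `scripts/` directory is on the import path):

```python
import torch
from pipeline_mvdream import get_camera

camera = get_camera(4)  # azimuths 0/90/180/270 at elevation 15, Blender coords
print(camera.shape)     # torch.Size([4, 16]): one flattened 4x4 matrix per view
# __call__ later doubles it for classifier-free guidance:
print(torch.cat([camera] * 2).shape)  # torch.Size([8, 16])
```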
@@ -189,20 +189,20 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
             from accelerate import cpu_offload
         else:
-            raise ImportError(
+            raise ImportError(
+                "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
+            )
 
         device = torch.device(f"cuda:{gpu_id}")
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache(
+            torch.cuda.empty_cache(
+            )  # otherwise we don't see the memory savings (but they probably exist)
 
         for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
-        if self.safety_checker is not None:
-            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
-
     def enable_model_cpu_offload(self, gpu_id=0):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
@@ -210,23 +210,26 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        if is_accelerate_available() and is_accelerate_version(
+        if is_accelerate_available() and is_accelerate_version(
+                ">=", "0.17.0.dev0"):
             from accelerate import cpu_offload_with_hook
         else:
-            raise ImportError(
+            raise ImportError(
+                "`enable_model_offload` requires `accelerate v0.17.0` or higher."
+            )
 
         device = torch.device(f"cuda:{gpu_id}")
 
         if self.device.type != "cpu":
             self.to("cpu", silence_dtype_warnings=True)
-            torch.cuda.empty_cache(
+            torch.cuda.empty_cache(
+            )  # otherwise we don't see the memory savings (but they probably exist)
 
         hook = None
         for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
-            _, hook = cpu_offload_with_hook(cpu_offloaded_model,
-            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model,
+                                            device,
+                                            prev_module_hook=hook)
 
         # We'll offload the last model manually.
         self.final_offload_hook = hook
@@ -241,7 +244,9 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
-            if (hasattr(module, "_hf_hook")
+            if (hasattr(module, "_hf_hook")
+                    and hasattr(module._hf_hook, "execution_device")
+                    and module._hf_hook.execution_device is not None):
                 return torch.device(module._hf_hook.execution_device)
         return self.device
 
@@ -295,14 +300,21 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         text_input_ids = text_inputs.input_ids
-        untruncated_ids = self.tokenizer(prompt,
+        untruncated_ids = self.tokenizer(prompt,
+                                         padding="longest",
+                                         return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[
+                -1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(
+                untruncated_ids[:, self.tokenizer.model_max_length - 1:-1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+            )
+
+        if hasattr(self.text_encoder.config, "use_attention_mask"
+                   ) and self.text_encoder.config.use_attention_mask:
             attention_mask = text_inputs.attention_mask.to(device)
         else:
             attention_mask = None
@@ -313,12 +325,14 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         )
         prompt_embeds = prompt_embeds[0]
 
-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype,
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype,
+                                         device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt,
+                                           seq_len, -1)
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
@@ -326,14 +340,16 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             if negative_prompt is None:
                 uncond_tokens = [""] * batch_size
             elif type(prompt) is not type(negative_prompt):
-                raise TypeError(
+                raise TypeError(
+                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}.")
             elif isinstance(negative_prompt, str):
                 uncond_tokens = [negative_prompt]
             elif batch_size != len(negative_prompt):
-                raise ValueError(
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`.")
             else:
                 uncond_tokens = negative_prompt
 
@@ -346,7 +362,8 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
 
-            if hasattr(self.text_encoder.config, "use_attention_mask"
+            if hasattr(self.text_encoder.config, "use_attention_mask"
+                       ) and self.text_encoder.config.use_attention_mask:
                 attention_mask = uncond_input.attention_mask.to(device)
             else:
                 attention_mask = None
@@ -361,10 +378,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(
+            negative_prompt_embeds = negative_prompt_embeds.to(
+                dtype=self.text_encoder.dtype, device=device)
 
-            negative_prompt_embeds = negative_prompt_embeds.repeat(
+            negative_prompt_embeds = negative_prompt_embeds.repeat(
+                1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1)
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -373,14 +393,6 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
 
         return prompt_embeds
 
-    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is not None:
-            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values.to(dtype))
-        else:
-            has_nsfw_concept = None
-        return image, has_nsfw_concept
-
     def decode_latents(self, latents):
         latents = 1 / self.vae.config.scaling_factor * latents
         image = self.vae.decode(latents).sample
@@ -395,25 +407,42 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
         # and should be between [0, 1]
 
-        accepts_eta = "eta" in set(
+        accepts_eta = "eta" in set(
+            inspect.signature(self.scheduler.step).parameters.keys())
         extra_step_kwargs = {}
         if accepts_eta:
             extra_step_kwargs["eta"] = eta
 
         # check if the scheduler accepts generator
-        accepts_generator = "generator" in set(
+        accepts_generator = "generator" in set(
+            inspect.signature(self.scheduler.step).parameters.keys())
         if accepts_generator:
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def prepare_latents(self,
+    def prepare_latents(self,
+                        batch_size,
+                        num_channels_latents,
+                        height,
+                        width,
+                        dtype,
+                        device,
+                        generator,
+                        latents=None):
+        shape = (batch_size, num_channels_latents,
+                 height // self.vae_scale_factor,
+                 width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
 
         if latents is None:
-            latents = randn_tensor(shape,
+            latents = randn_tensor(shape,
+                                   generator=generator,
+                                   device=device,
+                                   dtype=dtype)
         else:
             latents = latents.to(device)
 
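For the hard-coded `batch_size = 4`, `prepare_latents` produces one 4-channel latent per view. A shape-only sketch (assuming the usual SD 1.x VAE, i.e. `vae_scale_factor = 8`, and the 256×256 resolution used by the conversion script):

```python
import torch

batch_size = 4            # one latent per view (hard-coded in __call__)
num_channels_latents = 4  # SD latent channels
height = width = 256      # matches image_size in convert_mvdream_to_diffusers.py
vae_scale_factor = 8      # 2 ** (len(block_out_channels) - 1) for the SD 1.x VAE

shape = (batch_size, num_channels_latents,
         height // vae_scale_factor, width // vae_scale_factor)
latents = torch.randn(shape)  # stands in for randn_tensor(shape, ...)
print(latents.shape)          # torch.Size([4, 4, 32, 32])
# Under classifier-free guidance the model input is torch.cat([latents] * 2),
# i.e. [2*4, 4, 32, 32], matching the commented shape prints in the loop.
```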
@@ -433,10 +462,12 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         negative_prompt: str = "bad quality",
         num_images_per_prompt: int = 1,
         eta: float = 0.0,
-        generator: Optional[Union[torch.Generator,
+        generator: Optional[Union[torch.Generator,
+                                  List[torch.Generator]]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
-        callback: Optional[Callable[[int, int, torch.FloatTensor],
+        callback: Optional[Callable[[int, int, torch.FloatTensor],
+                                    None]] = None,
         callback_steps: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         controlnet_conditioning_scale: float = 1.0,
@@ -514,9 +545,9 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         # 0. Default height and width to unet
         batch_size = 4
         device = torch.device("cuda:0")
-        camera = get_camera(
+
+        camera = get_camera(batch_size).to(device=device)
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -525,14 +556,15 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+
+        _: torch.Tensor = self._encode_prompt(
             prompt=prompt,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=
+            do_classifier_free_guidance=do_classifier_free_guidance,
             negative_prompt=negative_prompt,
-        )
+        )  # type: ignore
+        prompt_embeds_neg, prompt_embeds_pos = _.chunk(2)
 
         # 5. Prepare latent variables
         latents: torch.Tensor = self.prepare_latents(
@@ -540,44 +572,65 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
             4,
             height,
             width,
+            prompt_embeds_pos.dtype,
             device,
             generator,
             None,
         )
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
-        num_warmup_steps = len(
+        num_warmup_steps = len(
+            timesteps) - num_inference_steps * self.scheduler.order
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input =
+                multiplier = 2 if do_classifier_free_guidance else 1
+                latent_model_input = torch.cat([latents] * multiplier)
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t)
 
                 # predict the noise residual
-                print(f"shape of
+                # print(
+                #     f"shape of latent_model_input: {latent_model_input.shape}"
+                # )  # [2*4, 4, 32, 32]
+                # print(f"shape of prompt_embeds: {prompt_embeds.shape}"
+                #       )  # [2*4, 77, 768]
+                # print(f"shape of camera: {camera.shape}")  # [4, 16]
+                noise_pred = self.unet.forward(
+                    x=latent_model_input,
+                    timesteps=torch.tensor([t] * 4 * multiplier,
+                                           device=device),
+                    context=torch.cat([prompt_embeds_neg] * 4 +
+                                      [prompt_embeds_pos] * 4),
+                    num_frames=4,
+                    camera=torch.cat([camera] * multiplier),
+                )
 
                 # perform guidance
                 if do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                    noise_pred = noise_pred_uncond + guidance_scale * (
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond)
 
                 # compute the previous noisy sample x_t -> x_t-1
                 # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
-                latents: torch.Tensor = self.scheduler.step(
+                latents: torch.Tensor = self.scheduler.step(
+                    noise_pred,
+                    t,
+                    latents,
+                    **extra_step_kwargs,
+                    return_dict=False)[0]
 
                 # call the callback, if provided
-                if i == len(timesteps) - 1 or (
+                if i == len(timesteps) - 1 or (
+                        (i + 1) > num_warmup_steps and
+                        (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
                     if callback is not None and i % callback_steps == 0:
-                        callback(i, t, latents)
+                        callback(i, t, latents)  # type: ignore
 
         # 8. Post-processing
         if output_type == "latent":
@@ -592,10 +645,13 @@ class MVDreamStableDiffusionPipeline(DiffusionPipeline):
         image = self.decode_latents(latents)
 
         # Offload last model to CPU
-        if hasattr(
+        if hasattr(
+                self,
+                "final_offload_hook") and self.final_offload_hook is not None:
             self.final_offload_hook.offload()
 
         if not return_dict:
-            return
+            return image
 
-        return StableDiffusionPipelineOutput(images=image,
+        return StableDiffusionPipelineOutput(images=image,
+                                             nsfw_content_detected=None)
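The new denoising loop batches the four views and both guidance branches into a single UNet call: 8 latents, text embeddings ordered `[neg]*4 + [pos]*4`, and the camera matrices repeated once per branch, so `noise_pred.chunk(2)` splits cleanly into the unconditional and conditional halves. A shape-only sketch of that batching and of the guidance update, using dummy tensors in place of the real UNet:

```python
import torch

num_frames, guidance_scale = 4, 7.5
latents = torch.randn(num_frames, 4, 32, 32)
prompt_embeds_neg = torch.randn(1, 77, 768)  # unconditional CLIP embedding
prompt_embeds_pos = torch.randn(1, 77, 768)  # conditional CLIP embedding
camera = torch.randn(num_frames, 16)

# Same batching as in __call__ (multiplier = 2 under classifier-free guidance):
latent_model_input = torch.cat([latents] * 2)                           # (8, 4, 32, 32)
context = torch.cat([prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4)  # (8, 77, 768)
camera_input = torch.cat([camera] * 2)                                  # (8, 16)

# Stand-in for self.unet.forward(...); the real call also returns (8, 4, 32, 32).
noise_pred = torch.randn_like(latent_model_input)

noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)  # (4, 4, 32, 32) each
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text -
                                                   noise_pred_uncond)
print(noise_pred.shape)  # torch.Size([4, 4, 32, 32]): one prediction per view
```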