voidDescriptor committed on
Commit dd4ae0d
1 Parent(s): 53c6380

Upload 2 files

hotshot_xl_controlnet_pipeline.py ADDED
@@ -0,0 +1,1389 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications:
16
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
17
+ # - Adapted the SDXL Controlnet Pipeline to work temporally
18
+
19
+ import inspect
20
+ import os
21
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
22
+
23
+ import numpy as np
24
+ import PIL.Image
25
+ import torch
26
+ import torch.nn.functional as F
27
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
28
+
29
+ from hotshot_xl import HotshotPipelineXLOutput
30
+
31
+ from diffusers.image_processor import VaeImageProcessor
32
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
33
+ from diffusers.models import AutoencoderKL, ControlNetModel
34
+ from diffusers.models.attention_processor import (
35
+ AttnProcessor2_0,
36
+ LoRAAttnProcessor2_0,
37
+ LoRAXFormersAttnProcessor,
38
+ XFormersAttnProcessor,
39
+ )
40
+ from diffusers.schedulers import KarrasDiffusionSchedulers
41
+ from diffusers.utils import (
42
+ is_accelerate_available,
43
+ is_accelerate_version,
44
+ logging,
45
+ replace_example_docstring,
46
+ )
47
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
48
+ from diffusers.utils.torch_utils import randn_tensor, is_compiled_module
49
+
50
+ from ..models.unet import UNet3DConditionModel
51
+
52
+ from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
53
+ from einops import rearrange
54
+ from tqdm import tqdm
55
+
56
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
57
+
58
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
59
+ """
60
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
61
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
62
+ """
63
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
64
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
65
+ # rescale the results from guidance (fixes overexposure)
66
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
67
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
68
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
69
+ return noise_cfg
70
+
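+ # Illustrative usage (editor's sketch, not part of the upstream file): `noise_cfg` below is the usual
+ # classifier-free-guidance combination, and `guidance_rescale` around 0.7 pulls its per-sample std back
+ # toward that of the text-conditioned prediction to curb overexposure.
+ #
+ # >>> noise_uncond = torch.randn(1, 4, 8, 48, 84)
+ # >>> noise_text = torch.randn(1, 4, 8, 48, 84)
+ # >>> noise_cfg = noise_uncond + 7.5 * (noise_text - noise_uncond)
+ # >>> noise_cfg = rescale_noise_cfg(noise_cfg, noise_text, guidance_rescale=0.7)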
71
+ EXAMPLE_DOC_STRING = """
72
+ Examples:
73
+ ```py
74
+ >>> import torch
75
+ >>> from hotshot_xl import HotshotXLControlNetPipeline
76
+ >>> from diffusers import ControlNetModel
+ >>> import cv2
+ >>> import numpy as np
+ >>> from PIL import Image
77
+
78
+ >>> pipe = HotshotXLControlNetPipeline.from_pretrained(
79
+ ... "hotshotco/Hotshot-XL",
80
+ ... controlnet=ControlNetModel.from_pretrained("diffusers/controlnet-canny-sdxl-1.0")
81
+ ... )
82
+
83
+ >>> def canny(image):
84
+ ... image = cv2.Canny(np.array(image), 100, 200)
85
+ ... image = image[:, :, None]
86
+ ... image = np.concatenate([image, image, image], axis=2)
87
+ ... return Image.fromarray(image)
88
+
89
+ >>> # assuming you have 8 keyframes in current directory...
90
+
91
+ >>> keyframes = [f"image_{i}.jpg" for i in range(8)]
92
+ >>> control_images = [canny(Image.open(fp)) for fp in keyframes]
93
+
94
+ >>> pipe = pipe.to("cuda")
95
+
96
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
97
+ >>> video = pipe(prompt,
98
+ ... width=672, height=384,
99
+ ... original_size=(1920, 1080),
100
+ ... target_size=(512, 512),
101
+ ... output_type="tensor",
102
+ ... controlnet_conditioning_scale=0.7,
103
+ ... control_images=control_images
104
+ ... ).video
105
+ ```
106
+ """
107
+ class HotshotXLControlNetPipeline(
108
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
109
+ ):
110
+ r"""
111
+ Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance.
112
+
113
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
114
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
115
+
116
+ The pipeline also inherits the following loading methods:
117
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
118
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
119
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
120
+
121
+ Args:
122
+ vae ([`AutoencoderKL`]):
123
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
124
+ text_encoder ([`~transformers.CLIPTextModel`]):
125
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
126
+ text_encoder_2 ([`~transformers.CLIPTextModelWithProjection`]):
127
+ Second frozen text-encoder
128
+ ([laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)).
129
+ tokenizer ([`~transformers.CLIPTokenizer`]):
130
+ A `CLIPTokenizer` to tokenize text.
131
+ tokenizer_2 ([`~transformers.CLIPTokenizer`]):
132
+ A `CLIPTokenizer` to tokenize text.
133
+ unet ([`UNet3DConditionModel`]):
134
+ A `UNet3DConditionModel` to denoise the encoded image latents.
135
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
136
+ Provides additional conditioning to the `unet` during the denoising process. If you set multiple
137
+ ControlNets as a list, the outputs from each ControlNet are added together to create one combined
138
+ additional conditioning.
139
+ scheduler ([`SchedulerMixin`]):
140
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
141
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
142
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
143
+ Whether the negative prompt embeddings should always be set to 0. Also see the config of
144
+ `stabilityai/stable-diffusion-xl-base-1-0`.
145
+ add_watermarker (`bool`, *optional*):
146
+ Whether to use the [invisible_watermark](https://github.com/ShieldMnt/invisible-watermark/) library to
147
+ watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no
148
+ watermarker is used.
149
+ """
150
+
151
+ def __init__(
152
+ self,
153
+ vae: AutoencoderKL,
154
+ text_encoder: CLIPTextModel,
155
+ text_encoder_2: CLIPTextModelWithProjection,
156
+ tokenizer: CLIPTokenizer,
157
+ tokenizer_2: CLIPTokenizer,
158
+ unet: UNet3DConditionModel,
159
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
160
+ scheduler: KarrasDiffusionSchedulers,
161
+ force_zeros_for_empty_prompt: bool = True,
162
+ add_watermarker: Optional[bool] = None,
163
+ ):
164
+ super().__init__()
165
+
166
+ if isinstance(controlnet, (list, tuple)):
167
+ controlnet = MultiControlNetModel(controlnet)
168
+
169
+ self.register_modules(
170
+ vae=vae,
171
+ text_encoder=text_encoder,
172
+ text_encoder_2=text_encoder_2,
173
+ tokenizer=tokenizer,
174
+ tokenizer_2=tokenizer_2,
175
+ unet=unet,
176
+ controlnet=controlnet,
177
+ scheduler=scheduler,
178
+ )
179
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
180
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
181
+ self.control_image_processor = VaeImageProcessor(
182
+ vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
183
+ )
184
+
185
+ self.watermark = None
186
+
187
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
188
+
189
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
190
+ def enable_vae_slicing(self):
191
+ r"""
192
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
193
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
194
+ """
195
+ self.vae.enable_slicing()
196
+
197
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
198
+ def disable_vae_slicing(self):
199
+ r"""
200
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
201
+ computing decoding in one step.
202
+ """
203
+ self.vae.disable_slicing()
204
+
205
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
206
+ def enable_vae_tiling(self):
207
+ r"""
208
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
209
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
210
+ processing larger images.
211
+ """
212
+ self.vae.enable_tiling()
213
+
214
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
215
+ def disable_vae_tiling(self):
216
+ r"""
217
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
218
+ computing decoding in one step.
219
+ """
220
+ self.vae.disable_tiling()
221
+
222
+ def enable_model_cpu_offload(self, gpu_id=0):
223
+ r"""
224
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
225
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
226
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
227
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
228
+ """
229
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
230
+ from accelerate import cpu_offload_with_hook
231
+ else:
232
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
233
+
234
+ device = torch.device(f"cuda:{gpu_id}")
235
+
236
+ if self.device.type != "cpu":
237
+ self.to("cpu", silence_dtype_warnings=True)
238
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
239
+
240
+ model_sequence = (
241
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
242
+ )
243
+ model_sequence.extend([self.unet, self.vae])
244
+
245
+ hook = None
246
+ for cpu_offloaded_model in model_sequence:
247
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
248
+
249
+ cpu_offload_with_hook(self.controlnet, device)
250
+
251
+ # We'll offload the last model manually.
252
+ self.final_offload_hook = hook
253
+
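+ # Usage sketch (editor's note, not part of the upstream file): offloading keeps only the sub-model that
+ # is currently executing on the GPU, which helps fit the temporal UNet plus ControlNet in VRAM.
+ #
+ # >>> pipe = HotshotXLControlNetPipeline.from_pretrained("hotshotco/Hotshot-XL", controlnet=controlnet)
+ # >>> pipe.enable_model_cpu_offload() # use instead of pipe.to("cuda")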
254
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
255
+ def encode_prompt(
256
+ self,
257
+ prompt: str,
258
+ prompt_2: Optional[str] = None,
259
+ device: Optional[torch.device] = None,
260
+ num_images_per_prompt: int = 1,
261
+ do_classifier_free_guidance: bool = True,
262
+ negative_prompt: Optional[str] = None,
263
+ negative_prompt_2: Optional[str] = None,
264
+ prompt_embeds: Optional[torch.FloatTensor] = None,
265
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
266
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
267
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
268
+ lora_scale: Optional[float] = None,
269
+ ):
270
+ r"""
271
+ Encodes the prompt into text encoder hidden states.
272
+
273
+ Args:
274
+ prompt (`str` or `List[str]`, *optional*):
275
+ prompt to be encoded
276
+ prompt_2 (`str` or `List[str]`, *optional*):
277
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
278
+ used in both text-encoders
279
+ device: (`torch.device`):
280
+ torch device
281
+ num_images_per_prompt (`int`):
282
+ number of images that should be generated per prompt
283
+ do_classifier_free_guidance (`bool`):
284
+ whether to use classifier free guidance or not
285
+ negative_prompt (`str` or `List[str]`, *optional*):
286
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
287
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
288
+ less than `1`).
289
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
290
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
291
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
292
+ prompt_embeds (`torch.FloatTensor`, *optional*):
293
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
294
+ provided, text embeddings will be generated from `prompt` input argument.
295
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
296
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
297
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
298
+ argument.
299
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
300
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
301
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
302
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
303
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
304
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
305
+ input argument.
306
+ lora_scale (`float`, *optional*):
307
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
308
+ """
309
+ device = device or self._execution_device
310
+
311
+ # set lora scale so that monkey patched LoRA
312
+ # function of text encoder can correctly access it
313
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
314
+ self._lora_scale = lora_scale
315
+
316
+ if prompt is not None and isinstance(prompt, str):
317
+ batch_size = 1
318
+ elif prompt is not None and isinstance(prompt, list):
319
+ batch_size = len(prompt)
320
+ else:
321
+ batch_size = prompt_embeds.shape[0]
322
+
323
+ # Define tokenizers and text encoders
324
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
325
+ text_encoders = (
326
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
327
+ )
328
+
329
+ if prompt_embeds is None:
330
+ prompt_2 = prompt_2 or prompt
331
+ # textual inversion: process multi-vector tokens if necessary
332
+ prompt_embeds_list = []
333
+ prompts = [prompt, prompt_2]
334
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
335
+ if isinstance(self, TextualInversionLoaderMixin):
336
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
337
+
338
+ text_inputs = tokenizer(
339
+ prompt,
340
+ padding="max_length",
341
+ max_length=tokenizer.model_max_length,
342
+ truncation=True,
343
+ return_tensors="pt",
344
+ )
345
+
346
+ text_input_ids = text_inputs.input_ids
347
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
348
+
349
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
350
+ text_input_ids, untruncated_ids
351
+ ):
352
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
353
+ logger.warning(
354
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
355
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
356
+ )
357
+
358
+ prompt_embeds = text_encoder(
359
+ text_input_ids.to(device),
360
+ output_hidden_states=True,
361
+ )
362
+
363
+ # We are always only interested in the pooled output of the final text encoder
364
+ pooled_prompt_embeds = prompt_embeds[0]
365
+ prompt_embeds = prompt_embeds.hidden_states[-2]
366
+
367
+ prompt_embeds_list.append(prompt_embeds)
368
+
369
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
370
+
371
+ # get unconditional embeddings for classifier free guidance
372
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
373
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
374
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
375
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
376
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
377
+ negative_prompt = negative_prompt or ""
378
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
379
+
380
+ uncond_tokens: List[str]
381
+ if prompt is not None and type(prompt) is not type(negative_prompt):
382
+ raise TypeError(
383
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
384
+ f" {type(prompt)}."
385
+ )
386
+ elif isinstance(negative_prompt, str):
387
+ uncond_tokens = [negative_prompt, negative_prompt_2]
388
+ elif batch_size != len(negative_prompt):
389
+ raise ValueError(
390
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
391
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
392
+ " the batch size of `prompt`."
393
+ )
394
+ else:
395
+ uncond_tokens = [negative_prompt, negative_prompt_2]
396
+
397
+ negative_prompt_embeds_list = []
398
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
399
+ if isinstance(self, TextualInversionLoaderMixin):
400
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
401
+
402
+ max_length = prompt_embeds.shape[1]
403
+ uncond_input = tokenizer(
404
+ negative_prompt,
405
+ padding="max_length",
406
+ max_length=max_length,
407
+ truncation=True,
408
+ return_tensors="pt",
409
+ )
410
+
411
+ negative_prompt_embeds = text_encoder(
412
+ uncond_input.input_ids.to(device),
413
+ output_hidden_states=True,
414
+ )
415
+ # We are always only interested in the pooled output of the final text encoder
416
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
417
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
418
+
419
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
420
+
421
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
422
+
423
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
424
+ bs_embed, seq_len, _ = prompt_embeds.shape
425
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
426
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
427
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
428
+
429
+ if do_classifier_free_guidance:
430
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
431
+ seq_len = negative_prompt_embeds.shape[1]
432
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
433
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
434
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
435
+
436
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
437
+ bs_embed * num_images_per_prompt, -1
438
+ )
439
+ if do_classifier_free_guidance:
440
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
441
+ bs_embed * num_images_per_prompt, -1
442
+ )
443
+
444
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
445
+
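+ # Shape sketch (editor's note, not part of the upstream file): with the standard SDXL text encoders the
+ # per-token embeddings of both encoders are concatenated along the channel axis (768 + 1280 = 2048),
+ # while only the pooled output of `text_encoder_2` (1280-dim) is kept for the added conditioning.
+ #
+ # >>> pe, npe, ppe, nppe = pipe.encode_prompt("a photo of a dog", device=device)
+ # >>> pe.shape # torch.Size([1, 77, 2048])
+ # >>> ppe.shape # torch.Size([1, 1280])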
446
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
447
+ def prepare_extra_step_kwargs(self, generator, eta):
448
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
449
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
450
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
451
+ # and should be between [0, 1]
452
+
453
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
454
+ extra_step_kwargs = {}
455
+ if accepts_eta:
456
+ extra_step_kwargs["eta"] = eta
457
+
458
+ # check if the scheduler accepts generator
459
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
460
+ if accepts_generator:
461
+ extra_step_kwargs["generator"] = generator
462
+ return extra_step_kwargs
463
+
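+ # Illustrative note (editor's sketch, not part of the upstream file): only kwargs that the scheduler's
+ # `step()` signature accepts are forwarded, so e.g. DDIM receives both `eta` and `generator`, while
+ # schedulers without an `eta` argument receive only the keys they support.
+ #
+ # >>> import inspect
+ # >>> from diffusers import DDIMScheduler
+ # >>> "eta" in inspect.signature(DDIMScheduler.step).parameters # True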
464
+ def check_inputs(
465
+ self,
466
+ prompt,
467
+ prompt_2,
468
+ control_images,
469
+ video_length,
470
+ callback_steps,
471
+ negative_prompt=None,
472
+ negative_prompt_2=None,
473
+ prompt_embeds=None,
474
+ negative_prompt_embeds=None,
475
+ pooled_prompt_embeds=None,
476
+ negative_pooled_prompt_embeds=None,
477
+ controlnet_conditioning_scale=1.0,
478
+ control_guidance_start=0.0,
479
+ control_guidance_end=1.0,
480
+ ):
481
+ if (callback_steps is None) or (
482
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
483
+ ):
484
+ raise ValueError(
485
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
486
+ f" {type(callback_steps)}."
487
+ )
488
+
489
+ if prompt is not None and prompt_embeds is not None:
490
+ raise ValueError(
491
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
492
+ " only forward one of the two."
493
+ )
494
+ elif prompt_2 is not None and prompt_embeds is not None:
495
+ raise ValueError(
496
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
497
+ " only forward one of the two."
498
+ )
499
+ elif prompt is None and prompt_embeds is None:
500
+ raise ValueError(
501
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
502
+ )
503
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
504
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
505
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
506
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
507
+
508
+ if negative_prompt is not None and negative_prompt_embeds is not None:
509
+ raise ValueError(
510
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
511
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
512
+ )
513
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
514
+ raise ValueError(
515
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
516
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
517
+ )
518
+
519
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
520
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
521
+ raise ValueError(
522
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
523
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
524
+ f" {negative_prompt_embeds.shape}."
525
+ )
526
+
527
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
528
+ raise ValueError(
529
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
530
+ )
531
+
532
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
533
+ raise ValueError(
534
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
535
+ )
536
+
537
+ # `prompt` needs more sophisticated handling when there are multiple
538
+ # conditionings.
539
+ if isinstance(self.controlnet, MultiControlNetModel):
540
+ if isinstance(prompt, list):
541
+ logger.warning(
542
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
543
+ " prompts. The conditionings will be fixed across the prompts."
544
+ )
545
+
546
+ # Check `image`
547
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
548
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
549
+ )
550
+ if (
551
+ isinstance(self.controlnet, ControlNetModel)
552
+ or is_compiled
553
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
554
+ ):
555
+
556
+ assert len(control_images) == video_length
557
+ # for image in control_images:
558
+ # self.check_image(image, prompt, prompt_embeds)
559
+ elif (
560
+ isinstance(self.controlnet, MultiControlNetModel)
561
+ or is_compiled
562
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
563
+ ):
564
+ ...
565
+ # todo
566
+ #
567
+ # if not isinstance(image, list):
568
+ # raise TypeError("For multiple controlnets: `image` must be type `list`")
569
+ #
570
+ # # When `image` is a nested list:
571
+ # # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
572
+ # elif any(isinstance(i, list) for i in image):
573
+ # raise ValueError("A single batch of multiple conditionings are supported at the moment.")
574
+ # elif len(image) != len(self.controlnet.nets):
575
+ # raise ValueError(
576
+ # f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets."
577
+ # )
578
+ #
579
+ # for image_ in image:
580
+ # self.check_image(image_, prompt, prompt_embeds)
581
+ else:
582
+ assert False
583
+
584
+ # Check `controlnet_conditioning_scale`
585
+ if (
586
+ isinstance(self.controlnet, ControlNetModel)
587
+ or is_compiled
588
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
589
+ ):
590
+ if not isinstance(controlnet_conditioning_scale, float):
591
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
592
+ elif (
593
+ isinstance(self.controlnet, MultiControlNetModel)
594
+ or is_compiled
595
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
596
+ ):
597
+ if isinstance(controlnet_conditioning_scale, list):
598
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
599
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
600
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
601
+ self.controlnet.nets
602
+ ):
603
+ raise ValueError(
604
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
605
+ " the same length as the number of controlnets"
606
+ )
607
+ else:
608
+ assert False
609
+
610
+ if not isinstance(control_guidance_start, (tuple, list)):
611
+ control_guidance_start = [control_guidance_start]
612
+
613
+ if not isinstance(control_guidance_end, (tuple, list)):
614
+ control_guidance_end = [control_guidance_end]
615
+
616
+ if len(control_guidance_start) != len(control_guidance_end):
617
+ raise ValueError(
618
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
619
+ )
620
+
621
+ if isinstance(self.controlnet, MultiControlNetModel):
622
+ if len(control_guidance_start) != len(self.controlnet.nets):
623
+ raise ValueError(
624
+ f"`control_guidance_start`: {control_guidance_start} has {len(control_guidance_start)} elements but there are {len(self.controlnet.nets)} controlnets available. Make sure to provide {len(self.controlnet.nets)}."
625
+ )
626
+
627
+ for start, end in zip(control_guidance_start, control_guidance_end):
628
+ if start >= end:
629
+ raise ValueError(
630
+ f"control guidance start: {start} cannot be larger than or equal to control guidance end: {end}."
631
+ )
632
+ if start < 0.0:
633
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
634
+ if end > 1.0:
635
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
636
+
637
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image
638
+ def check_image(self, image, prompt, prompt_embeds):
639
+ image_is_pil = isinstance(image, PIL.Image.Image)
640
+ image_is_tensor = isinstance(image, torch.Tensor)
641
+ image_is_np = isinstance(image, np.ndarray)
642
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
643
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
644
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
645
+
646
+ if (
647
+ not image_is_pil
648
+ and not image_is_tensor
649
+ and not image_is_np
650
+ and not image_is_pil_list
651
+ and not image_is_tensor_list
652
+ and not image_is_np_list
653
+ ):
654
+ raise TypeError(
655
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
656
+ )
657
+
658
+ if image_is_pil:
659
+ image_batch_size = 1
660
+ else:
661
+ image_batch_size = len(image)
662
+
663
+ if prompt is not None and isinstance(prompt, str):
664
+ prompt_batch_size = 1
665
+ elif prompt is not None and isinstance(prompt, list):
666
+ prompt_batch_size = len(prompt)
667
+ elif prompt_embeds is not None:
668
+ prompt_batch_size = prompt_embeds.shape[0]
669
+
670
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
671
+ raise ValueError(
672
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
673
+ )
674
+
675
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
676
+ def prepare_images(
677
+ self,
678
+ images,
679
+ width,
680
+ height,
681
+ batch_size,
682
+ num_images_per_prompt,
683
+ device,
684
+ dtype,
685
+ do_classifier_free_guidance=False,
686
+ guess_mode=False,
687
+ ):
688
+ images_pre_processed = [self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) for image in images]
689
+
690
+ images_pre_processed = torch.cat(images_pre_processed, dim=0)
691
+
692
+ repeat_factor = [1] * len(images_pre_processed.shape)
693
+ repeat_factor[0] = batch_size * num_images_per_prompt
694
+ images_pre_processed = images_pre_processed.repeat(*repeat_factor)
695
+
696
+ images = images_pre_processed.unsqueeze(0)
697
+
698
+ # image_batch_size = image.shape[0]
699
+ #
700
+ # if image_batch_size == 1:
701
+ # repeat_by = batch_size
702
+ # else:
703
+ # # image batch size is the same as prompt batch size
704
+ # repeat_by = num_images_per_prompt
705
+
706
+ #image = image.repeat_interleave(repeat_by, dim=0)
707
+
708
+ images = images.to(device=device, dtype=dtype)
709
+
710
+ if do_classifier_free_guidance and not guess_mode:
711
+ repeat_factor = [1] * len(images.shape)
712
+ repeat_factor[0] = 2
713
+ images = images.repeat(*repeat_factor)
714
+
715
+ return images
716
+
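+ # Shape sketch (editor's note, not part of the upstream file): for 8 control frames, batch_size=1 and
+ # classifier-free guidance enabled, the result is a (2, 8, 3, height, width) tensor whose leading dim
+ # holds the duplicated negative/positive CFG copies; it is later flattened with
+ # rearrange(images, "b f c h w -> (b f) c h w") before being fed to the ControlNet.
+ #
+ # >>> imgs = pipe.prepare_images(control_images, 672, 384, 1, 1, device, torch.float16,
+ # ... do_classifier_free_guidance=True)
+ # >>> imgs.shape # torch.Size([2, 8, 3, 384, 672])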
717
+ # def prepare_images(self,
718
+ # images: list,
719
+ # width,
720
+ # height,
721
+ # batch_size,
722
+ # num_images_per_prompt,
723
+ # device,
724
+ # dtype,
725
+ # do_classifier_free_guidance=False,
726
+ # guess_mode=False):
727
+ #
728
+ # images = [self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) for image in images]
729
+ #
730
+ # image_batch_size = image.shape[0]
731
+ #
732
+ # if image_batch_size == 1:
733
+ # repeat_by = batch_size
734
+ # else:
735
+ # # image batch size is the same as prompt batch size
736
+ # repeat_by = num_images_per_prompt
737
+ #
738
+ # image = image.repeat_interleave(repeat_by, dim=0)
739
+ #
740
+ # image = image.to(device=device, dtype=dtype)
741
+ #
742
+ # if do_classifier_free_guidance and not guess_mode:
743
+ # image = torch.cat([image] * 2)
744
+ #
745
+ # return image
746
+
747
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
748
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
749
+ #shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
750
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
751
+ if isinstance(generator, list) and len(generator) != batch_size:
752
+ raise ValueError(
753
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
754
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
755
+ )
756
+
757
+ if latents is None:
758
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
759
+ else:
760
+ latents = latents.to(device)
761
+
762
+ # scale the initial noise by the standard deviation required by the scheduler
763
+ latents = latents * self.scheduler.init_noise_sigma
764
+ return latents
765
+
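+ # Shape sketch (editor's note, not part of the upstream file): the latents carry an extra frame axis
+ # compared to the image pipeline, e.g. for a 672x384, 8-frame video (assuming the usual 4-channel SDXL
+ # latent space and a VAE scale factor of 8):
+ #
+ # >>> latents = pipe.prepare_latents(1, 4, 8, 384, 672, torch.float16, device, generator)
+ # >>> latents.shape # torch.Size([1, 4, 8, 48, 84])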
766
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
767
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
768
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
769
+
770
+ passed_add_embed_dim = (
771
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
772
+ )
773
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
774
+
775
+ if expected_add_embed_dim != passed_add_embed_dim:
776
+ raise ValueError(
777
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
778
+ )
779
+
780
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
781
+ return add_time_ids
782
+
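+ # Worked example (editor's sketch, not part of the upstream file): the micro-conditioning vector is the
+ # six integers (original_size + crops_coords_top_left + target_size); with the standard SDXL config each
+ # entry is embedded into 256 dims and concatenated with the 1280-dim pooled text embedding, matching the
+ # expected 6 * 256 + 1280 = 2816 input features of `unet.add_embedding.linear_1`.
+ #
+ # >>> pipe._get_add_time_ids((1920, 1080), (0, 0), (512, 512), dtype=torch.float16)
+ # tensor([[1920., 1080., 0., 0., 512., 512.]], dtype=torch.float16)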
783
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
784
+ def upcast_vae(self):
785
+ dtype = self.vae.dtype
786
+ self.vae.to(dtype=torch.float32)
787
+ use_torch_2_0_or_xformers = isinstance(
788
+ self.vae.decoder.mid_block.attentions[0].processor,
789
+ (
790
+ AttnProcessor2_0,
791
+ XFormersAttnProcessor,
792
+ LoRAXFormersAttnProcessor,
793
+ LoRAAttnProcessor2_0,
794
+ ),
795
+ )
796
+ # if xformers or torch_2_0 is used attention block does not need
797
+ # to be in float32 which can save lots of memory
798
+ if use_torch_2_0_or_xformers:
799
+ self.vae.post_quant_conv.to(dtype)
800
+ self.vae.decoder.conv_in.to(dtype)
801
+ self.vae.decoder.mid_block.to(dtype)
802
+
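+ # Usage sketch (editor's note, not part of the upstream file): the fp16 SDXL VAE is prone to overflow
+ # while decoding, so a caller may upcast it before decoding the final latents:
+ #
+ # >>> pipe.upcast_vae()
+ # >>> latents = latents.to(next(iter(pipe.vae.post_quant_conv.parameters())).dtype)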
803
+ @torch.no_grad()
804
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
805
+ def __call__(
806
+ self,
807
+ prompt: Union[str, List[str]] = None,
808
+ prompt_2: Optional[Union[str, List[str]]] = None,
809
+ video_length: Optional[int] = 8,
810
+ control_images: List[PIL.Image.Image] = None,
811
+ height: Optional[int] = None,
812
+ width: Optional[int] = None,
813
+ num_inference_steps: int = 50,
814
+ guidance_scale: float = 5.0,
815
+ negative_prompt: Optional[Union[str, List[str]]] = None,
816
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
817
+ num_images_per_prompt: Optional[int] = 1,
818
+ eta: float = 0.0,
819
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
820
+ latents: Optional[torch.FloatTensor] = None,
821
+ prompt_embeds: Optional[torch.FloatTensor] = None,
822
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
823
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
824
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
825
+ output_type: Optional[str] = "pil",
826
+ return_dict: bool = True,
827
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
828
+ callback_steps: int = 1,
829
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
830
+ guidance_rescale: float = 0.0,
831
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
832
+ guess_mode: bool = False,
833
+ control_guidance_start: Union[float, List[float]] = 0.0,
834
+ control_guidance_end: Union[float, List[float]] = 1.0,
835
+ original_size: Tuple[int, int] = None,
836
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
837
+ target_size: Tuple[int, int] = None,
838
+ negative_original_size: Optional[Tuple[int, int]] = None,
839
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
840
+ negative_target_size: Optional[Tuple[int, int]] = None,
841
+ ):
842
+ r"""
843
+ The call function to the pipeline for generation.
844
+
845
+ Args:
846
+ prompt (`str` or `List[str]`, *optional*):
847
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
848
+ prompt_2 (`str` or `List[str]`, *optional*):
849
+ The prompt or prompts to be sent to `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
850
+ used in both text-encoders.
851
+ video_length (`int`, *optional*, defaults to 8):
+ The number of video frames to generate. One control image is expected per frame.
+ control_images (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,
852
+ `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
853
+ The ControlNet input condition to provide guidance to the `unet` for generation. If the type is
854
+ specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can also be
855
+ accepted as an image. The dimensions of the output image default to the control images' dimensions. If height
856
+ and/or width are passed, `image` is resized accordingly. If multiple ControlNets are specified in
857
+ `init`, images must be passed as a list such that each element of the list can be correctly batched for
858
+ input to a single ControlNet.
859
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
860
+ The height in pixels of the generated image.
861
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
862
+ The width in pixels of the generated image.
863
+ num_inference_steps (`int`, *optional*, defaults to 50):
864
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
865
+ expense of slower inference.
866
+ guidance_scale (`float`, *optional*, defaults to 5.0):
867
+ A higher guidance scale value encourages the model to generate images closely linked to the text
868
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
869
+ negative_prompt (`str` or `List[str]`, *optional*):
870
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
871
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
872
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
873
+ The prompt or prompts to guide what to not include in image generation. This is sent to `tokenizer_2`
874
+ and `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
875
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
876
+ The number of images to generate per prompt.
877
+ eta (`float`, *optional*, defaults to 0.0):
878
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
879
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
880
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
881
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
882
+ generation deterministic.
883
+ latents (`torch.FloatTensor`, *optional*):
884
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
885
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
886
+ tensor is generated by sampling using the supplied random `generator`.
887
+ prompt_embeds (`torch.FloatTensor`, *optional*):
888
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
889
+ provided, text embeddings are generated from the `prompt` input argument.
890
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
891
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
892
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
893
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
894
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
895
+ not provided, pooled text embeddings are generated from `prompt` input argument.
896
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
897
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs (prompt
898
+ weighting). If not provided, pooled `negative_prompt_embeds` are generated from `negative_prompt` input
899
+ argument.
900
+ output_type (`str`, *optional*, defaults to `"pil"`):
901
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
902
+ return_dict (`bool`, *optional*, defaults to `True`):
903
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
904
+ plain tuple.
905
+ callback (`Callable`, *optional*):
906
+ A function that calls every `callback_steps` steps during inference. The function is called with the
907
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
908
+ callback_steps (`int`, *optional*, defaults to 1):
909
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
910
+ every step.
911
+ cross_attention_kwargs (`dict`, *optional*):
912
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
913
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
914
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
915
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
916
+ to the residual in the original `unet`. If multiple ControlNets are specified in `init`, you can set
917
+ the corresponding scale as a list.
918
+ guess_mode (`bool`, *optional*, defaults to `False`):
919
+ The ControlNet encoder tries to recognize the content of the input image even if you remove all
920
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
921
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
922
+ The percentage of total steps at which the ControlNet starts applying.
923
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
924
+ The percentage of total steps at which the ControlNet stops applying.
925
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
926
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
927
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
928
+ explained in section 2.2 of
929
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
930
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
931
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
932
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
933
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
934
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
935
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
936
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
937
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
938
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
939
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
940
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
941
+ micro-conditioning as explained in section 2.2 of
942
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
943
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
944
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
945
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
946
+ micro-conditioning as explained in section 2.2 of
947
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
948
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
949
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
950
+ To negatively condition the generation process based on a target image resolution. It should be the same
951
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
952
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
953
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
954
+
955
+ Examples:
956
+
957
+ Returns:
958
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
959
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
960
+ otherwise a `tuple` is returned containing the output images.
961
+ """
962
+
963
+
964
+ if video_length > 1 and num_images_per_prompt > 1:
965
+ print(f"Warning - setting num_images_per_prompt = 1 because video_length = {video_length}")
966
+ num_images_per_prompt = 1
967
+
968
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
969
+
970
+ # align format for control guidance
971
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
972
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
973
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
974
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
975
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
976
+ mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
977
+ control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [
978
+ control_guidance_end
979
+ ]
980
+
981
+ # 1. Check inputs. Raise error if not correct
982
+ self.check_inputs(
983
+ prompt,
984
+ prompt_2,
985
+ control_images,
986
+ video_length,
987
+ callback_steps,
988
+ negative_prompt,
989
+ negative_prompt_2,
990
+ prompt_embeds,
991
+ negative_prompt_embeds,
992
+ pooled_prompt_embeds,
993
+ negative_pooled_prompt_embeds,
994
+ controlnet_conditioning_scale,
995
+ control_guidance_start,
996
+ control_guidance_end,
997
+ )
998
+
999
+ # 2. Define call parameters
1000
+ if prompt is not None and isinstance(prompt, str):
1001
+ batch_size = 1
1002
+ elif prompt is not None and isinstance(prompt, list):
1003
+ batch_size = len(prompt)
1004
+ else:
1005
+ batch_size = prompt_embeds.shape[0]
1006
+
1007
+ device = self._execution_device
1008
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1009
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1010
+ # corresponds to doing no classifier free guidance.
1011
+ do_classifier_free_guidance = guidance_scale > 1.0
1012
+
1013
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1014
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1015
+
1016
+ global_pool_conditions = (
1017
+ controlnet.config.global_pool_conditions
1018
+ if isinstance(controlnet, ControlNetModel)
1019
+ else controlnet.nets[0].config.global_pool_conditions
1020
+ )
1021
+ guess_mode = guess_mode or global_pool_conditions
1022
+
1023
+ # 3. Encode input prompt
1024
+ text_encoder_lora_scale = (
1025
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
1026
+ )
1027
+ (
1028
+ prompt_embeds,
1029
+ negative_prompt_embeds,
1030
+ pooled_prompt_embeds,
1031
+ negative_pooled_prompt_embeds,
1032
+ ) = self.encode_prompt(
1033
+ prompt,
1034
+ prompt_2,
1035
+ device,
1036
+ num_images_per_prompt,
1037
+ do_classifier_free_guidance,
1038
+ negative_prompt,
1039
+ negative_prompt_2,
1040
+ prompt_embeds=prompt_embeds,
1041
+ negative_prompt_embeds=negative_prompt_embeds,
1042
+ pooled_prompt_embeds=pooled_prompt_embeds,
1043
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1044
+ lora_scale=text_encoder_lora_scale,
1045
+ )
1046
+
1047
+
1048
+ # 4. Prepare image
1049
+ if isinstance(controlnet, ControlNetModel):
1050
+
1051
+ assert len(control_images) == video_length * batch_size
1052
+
1053
+ images = self.prepare_images(
1054
+ images=control_images,
1055
+ width=width,
1056
+ height=height,
1057
+ batch_size=batch_size * num_images_per_prompt,
1058
+ num_images_per_prompt=num_images_per_prompt,
1059
+ device=device,
1060
+ dtype=controlnet.dtype,
1061
+ do_classifier_free_guidance=do_classifier_free_guidance,
1062
+ guess_mode=guess_mode,
1063
+ )
1064
+
1065
+ height, width = images.shape[-2:]
1066
+ elif isinstance(controlnet, MultiControlNetModel):
1067
+
1068
+ raise Exception("not supported yet")
1069
+
1070
+ # images = []
1071
+ #
1072
+ # for image_ in control_images:
1073
+ # image_ = self.prepare_image(
1074
+ # image=image_,
1075
+ # width=width,
1076
+ # height=height,
1077
+ # batch_size=batch_size * num_images_per_prompt,
1078
+ # num_images_per_prompt=num_images_per_prompt,
1079
+ # device=device,
1080
+ # dtype=controlnet.dtype,
1081
+ # do_classifier_free_guidance=do_classifier_free_guidance,
1082
+ # guess_mode=guess_mode,
1083
+ # )
1084
+ #
1085
+ # images.append(image_)
1086
+ #
1087
+ # image = images
1088
+ # height, width = image[0].shape[-2:]
1089
+ else:
1090
+ assert False
1091
+
1092
+ # 5. Prepare timesteps
1093
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1094
+ timesteps = self.scheduler.timesteps
1095
+
1096
+ # 6. Prepare latent variables
1097
+ num_channels_latents = self.unet.config.in_channels
1098
+ latents = self.prepare_latents(
1099
+ batch_size * num_images_per_prompt,
1100
+ num_channels_latents,
1101
+ video_length,
1102
+ height,
1103
+ width,
1104
+ prompt_embeds.dtype,
1105
+ device,
1106
+ generator,
1107
+ latents,
1108
+ )
1109
+
1110
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1111
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1112
+
1113
+ # 7.1 Create tensor stating which controlnets to keep
1114
+ controlnet_keep = []
1115
+ for i in range(len(timesteps)):
1116
+ keeps = [
1117
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1118
+ for s, e in zip(control_guidance_start, control_guidance_end)
1119
+ ]
1120
+ controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
1121
+
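+ # Worked example (editor's note, not part of the upstream file): with 30 steps,
+ # control_guidance_start=0.0 and control_guidance_end=0.5 the ControlNet is kept for roughly the first
+ # half of the schedule:
+ # i = 14 -> 1.0 - float(14/30 < 0.0 or 15/30 > 0.5) = 1.0
+ # i = 15 -> 1.0 - float(15/30 < 0.0 or 16/30 > 0.5) = 0.0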
1122
+ # 7.2 Prepare added time ids & embeddings
1123
+ # if isinstance(image, list):
1124
+ # original_size = original_size or image[0].shape[-2:]
1125
+ # else:
1126
+ original_size = original_size or images.shape[-2:]
1127
+ target_size = target_size or (height, width)
1128
+
1129
+ add_text_embeds = pooled_prompt_embeds
1130
+ add_time_ids = self._get_add_time_ids(
1131
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
1132
+ )
1133
+
1134
+ if negative_original_size is not None and negative_target_size is not None:
1135
+ negative_add_time_ids = self._get_add_time_ids(
1136
+ negative_original_size,
1137
+ negative_crops_coords_top_left,
1138
+ negative_target_size,
1139
+ dtype=prompt_embeds.dtype,
1140
+ )
1141
+ else:
1142
+ negative_add_time_ids = add_time_ids
1143
+
1144
+ if do_classifier_free_guidance:
1145
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1146
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1147
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1148
+
1149
+ prompt_embeds = prompt_embeds.to(device)
1150
+ add_text_embeds = add_text_embeds.to(device)
1151
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1152
+
1153
+ # 8. Denoising loop
1154
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1155
+
1156
+ images = rearrange(images, "b f c h w -> (b f) c h w")
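+ # Shape note: the control images go from (batch, frames, channels, h, w) to
+ # (batch * frames, channels, h, w) here, so the 2D ControlNet below sees every
+ # frame as an independent sample in the batch dimension.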
1157
+
1158
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1159
+ for i, t in enumerate(timesteps):
1160
+ # expand the latents if we are doing classifier free guidance
1161
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1162
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1163
+
1164
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1165
+
1166
+ # controlnet(s) inference
1167
+ if guess_mode and do_classifier_free_guidance:
1168
+ # Infer ControlNet only for the conditional batch.
1169
+ control_model_input = latents
1170
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1171
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1172
+ controlnet_added_cond_kwargs = {
1173
+ "text_embeds": add_text_embeds.chunk(2)[1],
1174
+ "time_ids": add_time_ids.chunk(2)[1],
1175
+ }
1176
+ else:
1177
+ control_model_input = latent_model_input
1178
+ controlnet_prompt_embeds = prompt_embeds
1179
+ controlnet_added_cond_kwargs = added_cond_kwargs
1180
+
1181
+ if isinstance(controlnet_keep[i], list):
1182
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1183
+ else:
1184
+ controlnet_cond_scale = controlnet_conditioning_scale
1185
+ if isinstance(controlnet_cond_scale, list):
1186
+ controlnet_cond_scale = controlnet_cond_scale[0]
1187
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1188
+
1189
+
1190
+ # after this rearrange the frames are laid out batch-major (not interleaved)
1191
+ control_model_input = rearrange(control_model_input, "b c f h w -> (b f) c h w")
1192
+ # if we chunked this by 2, the first half of the flattened frames belongs to the
1193
+ # unconditional (negative) branch of CFG and the second half to the conditional one
1194
+
1195
+ if video_length > 1:
1196
+ # use repeat_interleave as we need to match the rearrangement above.
1197
+
1198
+ controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(video_length, dim=0)
1199
+ controlnet_added_cond_kwargs = {
1200
+ "text_embeds": controlnet_added_cond_kwargs['text_embeds'].repeat_interleave(video_length, dim=0),
1201
+ "time_ids": controlnet_added_cond_kwargs['time_ids'].repeat_interleave(video_length, dim=0)
1202
+ }
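+ # For example, with two prompts [p0, p1] and 8 frames, repeat_interleave yields
+ # [p0 x8, p1 x8], which matches the batch-major (b f) ordering produced by the
+ # rearrange above; a plain repeat/tile would interleave the prompts instead.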
1203
+
1204
+ # if type(image) is list:
1205
+ # image = torch.cat(image, dim=0)
1206
+
1207
+ # todo - check if video_length > 1 this needs to produce num_frames * batch_size samples...
1208
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1209
+ control_model_input,
1210
+ t,
1211
+ encoder_hidden_states=controlnet_prompt_embeds,
1212
+ controlnet_cond=images,
1213
+ conditioning_scale=cond_scale,
1214
+ guess_mode=guess_mode,
1215
+ added_cond_kwargs=controlnet_added_cond_kwargs,
1216
+ return_dict=False,
1217
+ )
1218
+
1219
+ for j, sample in enumerate(down_block_res_samples):
1220
+ down_block_res_samples[j] = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
1221
+
1222
+ mid_block_res_sample = rearrange(mid_block_res_sample, "(b f) c h w -> b c f h w", f=video_length)
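+ # The ControlNet returns per-frame 4D residuals; folding them back to
+ # (batch, channels, frames, h, w) lets the temporal UNet below add them to its
+ # own 5D feature maps.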
1223
+
1224
+ if guess_mode and do_classifier_free_guidance:
1225
+ # Inferred ControlNet only for the conditional batch.
1226
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1227
+ # add 0 to the unconditional batch to keep it unchanged.
1228
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1229
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1230
+
1231
+ # predict the noise residual
1232
+ noise_pred = self.unet(
1233
+ latent_model_input,
1234
+ t,
1235
+ encoder_hidden_states=prompt_embeds,
1236
+ cross_attention_kwargs=cross_attention_kwargs,
1237
+ down_block_additional_residuals=down_block_res_samples,
1238
+ mid_block_additional_residual=mid_block_res_sample,
1239
+ added_cond_kwargs=added_cond_kwargs,
1240
+ return_dict=False,
1241
+ enable_temporal_attentions=video_length > 1
1242
+ )[0]
1243
+
1244
+ # perform guidance
1245
+ if do_classifier_free_guidance:
1246
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1247
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1248
+
1249
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
1250
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1251
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
1252
+
1253
+ # compute the previous noisy sample x_t -> x_t-1
1254
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1255
+
1256
+ # call the callback, if provided
1257
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1258
+ progress_bar.update()
1259
+ if callback is not None and i % callback_steps == 0:
1260
+ callback(i, t, latents)
1261
+
1262
+ # make sure the VAE is in float32 mode, as it overflows in float16
1263
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
1264
+ self.upcast_vae()
1265
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1266
+
1267
+ # If we do sequential model offloading, let's offload unet and controlnet
1268
+ # manually for max memory savings
1269
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1270
+ self.unet.to("cpu")
1271
+ self.controlnet.to("cpu")
1272
+ torch.cuda.empty_cache()
1273
+
1274
+ # if not output_type == "latent":
1275
+ # # make sure the VAE is in float32 mode, as it overflows in float16
1276
+ # needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1277
+ #
1278
+ # if needs_upcasting:
1279
+ # self.upcast_vae()
1280
+ # latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1281
+ #
1282
+ # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1283
+ #
1284
+ # # cast back to fp16 if needed
1285
+ # if needs_upcasting:
1286
+ # self.vae.to(dtype=torch.float16)
1287
+ # else:
1288
+ # image = latents
1289
+ # return StableDiffusionXLPipelineOutput(images=image)
1290
+
1291
+ video = self.decode_latents(latents)
1292
+
1293
+ # Convert to tensor
1294
+ if output_type == "tensor":
1295
+ video = torch.from_numpy(video)
1296
+
1297
+ if not return_dict:
1298
+ return video
1299
+
1300
+ return HotshotPipelineXLOutput(videos=video)
1301
+
1302
+ def decode_latents(self, latents):
1303
+ video_length = latents.shape[2]
1304
+ latents = 1 / self.vae.config.scaling_factor * latents
1305
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
1306
+ # video = self.vae.decode(latents).sample
1307
+ video = []
1308
+ for frame_idx in tqdm(range(latents.shape[0])):
1309
+ video.append(self.vae.decode(
1310
+ latents[frame_idx:frame_idx+1]).sample)
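+ # Decoding one frame at a time keeps peak VAE memory proportional to a single
+ # frame instead of the whole clip, at the cost of a slower decode.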
1311
+ video = torch.cat(video)
1312
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
1313
+ video = (video / 2.0 + 0.5).clamp(0, 1)
1314
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1315
+ video = video.cpu().float().numpy()
1316
+ return video
1317
+
1318
+ # Override to properly handle the loading and unloading of the additional text encoder.
1319
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
1320
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
1321
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
1322
+ # it here explicitly to be able to tell that it's coming from an SDXL
1323
+ # pipeline.
1324
+ state_dict, network_alphas = self.lora_state_dict(
1325
+ pretrained_model_name_or_path_or_dict,
1326
+ unet_config=self.unet.config,
1327
+ **kwargs,
1328
+ )
1329
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
1330
+
1331
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
1332
+ if len(text_encoder_state_dict) > 0:
1333
+ self.load_lora_into_text_encoder(
1334
+ text_encoder_state_dict,
1335
+ network_alphas=network_alphas,
1336
+ text_encoder=self.text_encoder,
1337
+ prefix="text_encoder",
1338
+ lora_scale=self.lora_scale,
1339
+ )
1340
+
1341
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
1342
+ if len(text_encoder_2_state_dict) > 0:
1343
+ self.load_lora_into_text_encoder(
1344
+ text_encoder_2_state_dict,
1345
+ network_alphas=network_alphas,
1346
+ text_encoder=self.text_encoder_2,
1347
+ prefix="text_encoder_2",
1348
+ lora_scale=self.lora_scale,
1349
+ )
1350
+
1351
+ @classmethod
1352
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
1353
+ def save_lora_weights(
1354
+ self,
1355
+ save_directory: Union[str, os.PathLike],
1356
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1357
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1358
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
1359
+ is_main_process: bool = True,
1360
+ weight_name: str = None,
1361
+ save_function: Callable = None,
1362
+ safe_serialization: bool = True,
1363
+ ):
1364
+ state_dict = {}
1365
+
1366
+ def pack_weights(layers, prefix):
1367
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
1368
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
1369
+ return layers_state_dict
1370
+
1371
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
1372
+
1373
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
1374
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
1375
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
1376
+
1377
+ self.write_lora_layers(
1378
+ state_dict=state_dict,
1379
+ save_directory=save_directory,
1380
+ is_main_process=is_main_process,
1381
+ weight_name=weight_name,
1382
+ save_function=save_function,
1383
+ safe_serialization=safe_serialization,
1384
+ )
1385
+
1386
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
1387
+ def _remove_text_encoder_monkey_patch(self):
1388
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
1389
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
hotshot_xl_pipeline.py ADDED
@@ -0,0 +1,996 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modifications:
16
+ # Copyright 2023 Natural Synthetics Inc. All rights reserved.
17
+ # - Adapted the SDXL Pipeline to work temporally
18
+
19
+
20
+ import os
21
+ import inspect
22
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
26
+ from hotshot_xl import HotshotPipelineXLOutput
27
+
28
+ from diffusers.image_processor import VaeImageProcessor
29
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
30
+ from diffusers.models import AutoencoderKL
31
+ from hotshot_xl.models.unet import UNet3DConditionModel
32
+ from diffusers.models.attention_processor import (
33
+ AttnProcessor2_0,
34
+ LoRAAttnProcessor2_0,
35
+ LoRAXFormersAttnProcessor,
36
+ XFormersAttnProcessor,
37
+ )
38
+ from diffusers.schedulers import KarrasDiffusionSchedulers
39
+ from diffusers.utils import (
40
+ is_accelerate_available,
41
+ is_accelerate_version,
42
+ logging,
43
+ replace_example_docstring,
44
+ )
45
+ from diffusers.utils.torch_utils import randn_tensor
46
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
47
+ from tqdm import tqdm
48
+ from einops import repeat, rearrange
49
+ from diffusers.utils import deprecate, logging
50
+ import gc
51
+
52
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
53
+
54
+ EXAMPLE_DOC_STRING = """
55
+ Examples:
56
+ ```py
57
+ >>> import torch
58
+ >>> from hotshot_xl import HotshotXLPipeline
59
+
60
+ >>> pipe = HotshotXLPipeline.from_pretrained(
61
+ ... "hotshotco/Hotshot-XL"
62
+ ... )
63
+ >>> pipe = pipe.to("cuda")
64
+
65
+ >>> prompt = "a photo of an astronaut riding a horse on mars"
66
+ >>> video = pipe(prompt,
67
+ ... width=672, height=384,
68
+ ... original_size=(1920, 1080),
69
+ ... target_size=(512, 512),
70
+ ... output_type="tensor"
71
+ ... ).videos
72
+ ```
73
+ """
74
+
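+ # Note on the example above: with output_type="tensor" the returned `videos` is a
+ # float tensor of shape (batch, channels, frames, height, width) in [0, 1] (see
+ # `decode_latents`). One simple way to turn it into PIL frames, assuming that layout
+ # and `from PIL import Image`:
+ #
+ # frames = (video[0] * 255).clamp(0, 255).byte().permute(1, 2, 3, 0).cpu().numpy()
+ # pil_frames = [Image.fromarray(f) for f in frames]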
75
+
76
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
77
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
78
+ """
79
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
80
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
81
+ """
82
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
83
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
84
+ # rescale the results from guidance (fixes overexposure)
85
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
86
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
87
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
88
+ return noise_cfg
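+ # With guidance_rescale=0.0 the input noise_cfg is returned unchanged; at 1.0 the
+ # prediction's per-sample std is fully matched to that of noise_pred_text, and
+ # values in between linearly blend the two.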
89
+
90
+
91
+
92
+
93
+ class HotshotXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
94
+ r"""
95
+ Pipeline for text-to-video generation using Hotshot-XL, built on Stable Diffusion XL.
96
+
97
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
98
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
99
+
100
+ In addition the pipeline inherits the following loading methods:
101
+ - *LoRA*: [`HotshotXLPipeline.load_lora_weights`]
102
+ - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
103
+
104
+ as well as the following saving methods:
105
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
106
+
107
+ Args:
108
+ vae ([`AutoencoderKL`]):
109
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
110
+ text_encoder ([`CLIPTextModel`]):
111
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
112
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
113
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
114
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
115
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
116
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
117
+ specifically the
118
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
119
+ variant.
120
+ tokenizer (`CLIPTokenizer`):
121
+ Tokenizer of class
122
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
123
+ tokenizer_2 (`CLIPTokenizer`):
124
+ Second Tokenizer of class
125
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
126
+ unet ([`UNet3DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
127
+ scheduler ([`SchedulerMixin`]):
128
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
129
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
130
+ """
131
+
132
+ def __init__(
133
+ self,
134
+ vae: AutoencoderKL,
135
+ text_encoder: CLIPTextModel,
136
+ text_encoder_2: CLIPTextModelWithProjection,
137
+ tokenizer: CLIPTokenizer,
138
+ tokenizer_2: CLIPTokenizer,
139
+ unet: UNet3DConditionModel,
140
+ scheduler: KarrasDiffusionSchedulers,
141
+ force_zeros_for_empty_prompt: bool = True,
142
+ add_watermarker: Optional[bool] = None,
143
+ ):
144
+ super().__init__()
145
+
146
+ self.register_modules(
147
+ vae=vae,
148
+ text_encoder=text_encoder,
149
+ text_encoder_2=text_encoder_2,
150
+ tokenizer=tokenizer,
151
+ tokenizer_2=tokenizer_2,
152
+ unet=unet,
153
+ scheduler=scheduler,
154
+ )
155
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
156
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
157
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
158
+ self.default_sample_size = self.unet.config.sample_size
159
+ self.watermark = None
160
+
161
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
162
+ def enable_vae_slicing(self):
163
+ r"""
164
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
165
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
166
+ """
167
+ self.vae.enable_slicing()
168
+
169
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
170
+ def disable_vae_slicing(self):
171
+ r"""
172
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
173
+ computing decoding in one step.
174
+ """
175
+ self.vae.disable_slicing()
176
+
177
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
178
+ def enable_vae_tiling(self):
179
+ r"""
180
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
181
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
182
+ processing larger images.
183
+ """
184
+ self.vae.enable_tiling()
185
+
186
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
187
+ def disable_vae_tiling(self):
188
+ r"""
189
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
190
+ computing decoding in one step.
191
+ """
192
+ self.vae.disable_tiling()
193
+
194
+ def enable_model_cpu_offload(self, gpu_id=0):
195
+ r"""
196
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
197
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
198
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
199
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
200
+ """
201
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
202
+ from accelerate import cpu_offload_with_hook
203
+ else:
204
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
205
+
206
+ device = torch.device(f"cuda:{gpu_id}")
207
+
208
+ if self.device.type != "cpu":
209
+ self.to("cpu", silence_dtype_warnings=True)
210
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
211
+
212
+ model_sequence = (
213
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
214
+ )
215
+ model_sequence.extend([self.unet, self.vae])
216
+
217
+ hook = None
218
+ for cpu_offloaded_model in model_sequence:
219
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
220
+
221
+ # We'll offload the last model manually.
222
+ self.final_offload_hook = hook
223
+
224
+ def encode_prompt(
225
+ self,
226
+ prompt: str,
227
+ prompt_2: Optional[str] = None,
228
+ device: Optional[torch.device] = None,
229
+ num_images_per_prompt: int = 1,
230
+ do_classifier_free_guidance: bool = True,
231
+ negative_prompt: Optional[str] = None,
232
+ negative_prompt_2: Optional[str] = None,
233
+ prompt_embeds: Optional[torch.FloatTensor] = None,
234
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
235
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
236
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
237
+ lora_scale: Optional[float] = None,
238
+ ):
239
+ r"""
240
+ Encodes the prompt into text encoder hidden states.
241
+
242
+ Args:
243
+ prompt (`str` or `List[str]`, *optional*):
244
+ prompt to be encoded
245
+ prompt_2 (`str` or `List[str]`, *optional*):
246
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
247
+ used in both text-encoders
248
+ device: (`torch.device`):
249
+ torch device
250
+ num_images_per_prompt (`int`):
251
+ number of images that should be generated per prompt
252
+ do_classifier_free_guidance (`bool`):
253
+ whether to use classifier free guidance or not
254
+ negative_prompt (`str` or `List[str]`, *optional*):
255
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
256
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
257
+ less than `1`).
258
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
259
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
260
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
261
+ prompt_embeds (`torch.FloatTensor`, *optional*):
262
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
263
+ provided, text embeddings will be generated from `prompt` input argument.
264
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
265
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
266
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
267
+ argument.
268
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
269
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
270
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
271
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
272
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
273
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
274
+ input argument.
275
+ lora_scale (`float`, *optional*):
276
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
277
+ """
278
+ device = device or self._execution_device
279
+
280
+ # set lora scale so that monkey patched LoRA
281
+ # function of text encoder can correctly access it
282
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
283
+ self._lora_scale = lora_scale
284
+
285
+ if prompt is not None and isinstance(prompt, str):
286
+ batch_size = 1
287
+ elif prompt is not None and isinstance(prompt, list):
288
+ batch_size = len(prompt)
289
+ else:
290
+ batch_size = prompt_embeds.shape[0]
291
+
292
+ # Define tokenizers and text encoders
293
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
294
+ text_encoders = (
295
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
296
+ )
297
+
298
+ if prompt_embeds is None:
299
+ prompt_2 = prompt_2 or prompt
300
+ # textual inversion: process multi-vector tokens if necessary
301
+ prompt_embeds_list = []
302
+ prompts = [prompt, prompt_2]
303
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
304
+ if isinstance(self, TextualInversionLoaderMixin):
305
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
306
+
307
+ text_inputs = tokenizer(
308
+ prompt,
309
+ padding="max_length",
310
+ max_length=tokenizer.model_max_length,
311
+ truncation=True,
312
+ return_tensors="pt",
313
+ )
314
+
315
+ text_input_ids = text_inputs.input_ids
316
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
317
+
318
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
319
+ text_input_ids, untruncated_ids
320
+ ):
321
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
322
+ logger.warning(
323
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
324
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
325
+ )
326
+
327
+ prompt_embeds = text_encoder(
328
+ text_input_ids.to(device),
329
+ output_hidden_states=True,
330
+ )
331
+
332
+ # Only the pooled output of the final text encoder is ever used
333
+ pooled_prompt_embeds = prompt_embeds[0]
334
+ prompt_embeds = prompt_embeds.hidden_states[-2]
335
+
336
+ prompt_embeds_list.append(prompt_embeds)
337
+
338
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
339
+
340
+ # get unconditional embeddings for classifier free guidance
341
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
342
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
343
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
344
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
345
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
346
+ negative_prompt = negative_prompt or ""
347
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
348
+
349
+ uncond_tokens: List[str]
350
+ if prompt is not None and type(prompt) is not type(negative_prompt):
351
+ raise TypeError(
352
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
353
+ f" {type(prompt)}."
354
+ )
355
+ elif isinstance(negative_prompt, str):
356
+ uncond_tokens = [negative_prompt, negative_prompt_2]
357
+ elif batch_size != len(negative_prompt):
358
+ raise ValueError(
359
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
360
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
361
+ " the batch size of `prompt`."
362
+ )
363
+ else:
364
+ uncond_tokens = [negative_prompt, negative_prompt_2]
365
+
366
+ negative_prompt_embeds_list = []
367
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
368
+ if isinstance(self, TextualInversionLoaderMixin):
369
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
370
+
371
+ max_length = prompt_embeds.shape[1]
372
+ uncond_input = tokenizer(
373
+ negative_prompt,
374
+ padding="max_length",
375
+ max_length=max_length,
376
+ truncation=True,
377
+ return_tensors="pt",
378
+ )
379
+
380
+ negative_prompt_embeds = text_encoder(
381
+ uncond_input.input_ids.to(device),
382
+ output_hidden_states=True,
383
+ )
384
+ # Only the pooled output of the final text encoder is ever used
385
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
386
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
387
+
388
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
389
+
390
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
391
+
392
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
393
+ bs_embed, seq_len, _ = prompt_embeds.shape
394
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
395
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
396
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
397
+
398
+ if do_classifier_free_guidance:
399
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
400
+ seq_len = negative_prompt_embeds.shape[1]
401
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
402
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
403
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
404
+
405
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
406
+ bs_embed * num_images_per_prompt, -1
407
+ )
408
+ if do_classifier_free_guidance:
409
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
410
+ bs_embed * num_images_per_prompt, -1
411
+ )
412
+
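+ # For the stock SDXL text encoders this yields prompt embeddings of shape
+ # (batch * num_images_per_prompt, 77, 2048) -- the 768- and 1280-dim hidden states
+ # concatenated -- and pooled embeddings of shape (batch * num_images_per_prompt, 1280).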
413
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
414
+
415
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
416
+ def prepare_extra_step_kwargs(self, generator, eta):
417
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
418
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
419
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
420
+ # and should be between [0, 1]
421
+
422
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
423
+ extra_step_kwargs = {}
424
+ if accepts_eta:
425
+ extra_step_kwargs["eta"] = eta
426
+
427
+ # check if the scheduler accepts generator
428
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
429
+ if accepts_generator:
430
+ extra_step_kwargs["generator"] = generator
431
+ return extra_step_kwargs
432
+
433
+ def check_inputs(
434
+ self,
435
+ prompt,
436
+ prompt_2,
437
+ height,
438
+ width,
439
+ callback_steps,
440
+ negative_prompt=None,
441
+ negative_prompt_2=None,
442
+ prompt_embeds=None,
443
+ negative_prompt_embeds=None,
444
+ pooled_prompt_embeds=None,
445
+ negative_pooled_prompt_embeds=None,
446
+ ):
447
+ if height % 8 != 0 or width % 8 != 0:
448
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
449
+
450
+ if (callback_steps is None) or (
451
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
452
+ ):
453
+ raise ValueError(
454
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
455
+ f" {type(callback_steps)}."
456
+ )
457
+
458
+ if prompt is not None and prompt_embeds is not None:
459
+ raise ValueError(
460
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
461
+ " only forward one of the two."
462
+ )
463
+ elif prompt_2 is not None and prompt_embeds is not None:
464
+ raise ValueError(
465
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
466
+ " only forward one of the two."
467
+ )
468
+ elif prompt is None and prompt_embeds is None:
469
+ raise ValueError(
470
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
471
+ )
472
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
473
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
474
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
475
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
476
+
477
+ if negative_prompt is not None and negative_prompt_embeds is not None:
478
+ raise ValueError(
479
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
480
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
481
+ )
482
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
483
+ raise ValueError(
484
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
485
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
486
+ )
487
+
488
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
489
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
490
+ raise ValueError(
491
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
492
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
493
+ f" {negative_prompt_embeds.shape}."
494
+ )
495
+
496
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
497
+ raise ValueError(
498
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
499
+ )
500
+
501
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
502
+ raise ValueError(
503
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
504
+ )
505
+
506
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
507
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
508
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
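+ # e.g. width=672, height=384, video_length=8 with the usual SDXL setup
+ # (4 latent channels, vae_scale_factor=8) gives a latent shape of (batch, 4, 8, 48, 84)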
509
+ if isinstance(generator, list) and len(generator) != batch_size:
510
+ raise ValueError(
511
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
512
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
513
+ )
514
+
515
+ if latents is None:
516
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
517
+ else:
518
+ latents = latents.to(device)
519
+
520
+ # scale the initial noise by the standard deviation required by the scheduler
521
+ latents = latents * self.scheduler.init_noise_sigma
522
+ return latents
523
+
524
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
525
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
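+ # e.g. original_size=(1920, 1080), crops_coords_top_left=(0, 0), target_size=(512, 512)
+ # becomes the 6-element SDXL micro-conditioning vector [1920, 1080, 0, 0, 512, 512]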
526
+
527
+ passed_add_embed_dim = (
528
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
529
+ )
530
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
531
+
532
+ if expected_add_embed_dim != passed_add_embed_dim:
533
+ raise ValueError(
534
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
535
+ )
536
+
537
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
538
+ return add_time_ids
539
+
540
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
541
+ def upcast_vae(self):
542
+ dtype = self.vae.dtype
543
+ self.vae.to(dtype=torch.float32)
544
+ use_torch_2_0_or_xformers = isinstance(
545
+ self.vae.decoder.mid_block.attentions[0].processor,
546
+ (
547
+ AttnProcessor2_0,
548
+ XFormersAttnProcessor,
549
+ LoRAXFormersAttnProcessor,
550
+ LoRAAttnProcessor2_0,
551
+ ),
552
+ )
553
+ # if xformers or torch_2_0 is used attention block does not need
554
+ # to be in float32 which can save lots of memory
555
+ if use_torch_2_0_or_xformers:
556
+ self.vae.post_quant_conv.to(dtype)
557
+ self.vae.decoder.conv_in.to(dtype)
558
+ self.vae.decoder.mid_block.to(dtype)
559
+
560
+ @torch.no_grad()
561
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
562
+ def __call__(
563
+ self,
564
+ prompt: Union[str, List[str]] = None,
565
+ prompt_2: Optional[Union[str, List[str]]] = None,
566
+ video_length: Optional[int] = 8,
567
+ num_images_per_prompt: Optional[int] = 1,
568
+ height: Optional[int] = None,
569
+ width: Optional[int] = None,
570
+ num_inference_steps: int = 50,
571
+ denoising_end: Optional[float] = None,
572
+ guidance_scale: float = 5.0,
573
+ negative_prompt: Optional[Union[str, List[str]]] = None,
574
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
575
+ eta: float = 0.0,
576
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
577
+ latents: Optional[torch.FloatTensor] = None,
578
+ prompt_embeds: Optional[torch.FloatTensor] = None,
579
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
580
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
581
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
582
+ output_type: Optional[str] = "pil",
583
+ return_dict: bool = True,
584
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
585
+ callback_steps: int = 1,
586
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
587
+ guidance_rescale: float = 0.0,
588
+ original_size: Optional[Tuple[int, int]] = None,
589
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
590
+ target_size: Optional[Tuple[int, int]] = None,
591
+ low_vram_mode: Optional[bool] = False
592
+ ):
593
+ r"""
594
+ Function invoked when calling the pipeline for generation.
595
+
596
+ Args:
597
+ prompt (`str` or `List[str]`, *optional*):
598
+ The prompt or prompts to guide the image generation. If not defined, one has to pass
599
+ `prompt_embeds` instead.
600
+ prompt_2 (`str` or `List[str]`, *optional*):
601
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
602
+ used in both text-encoders
603
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
604
+ The height in pixels of the generated image.
605
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
606
+ The width in pixels of the generated image.
607
+ num_inference_steps (`int`, *optional*, defaults to 50):
608
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
609
+ expense of slower inference.
610
+ denoising_end (`float`, *optional*):
611
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
612
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
613
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
614
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
615
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
616
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
617
+ guidance_scale (`float`, *optional*, defaults to 5.0):
618
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
619
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
620
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
621
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
622
+ usually at the expense of lower image quality.
623
+ negative_prompt (`str` or `List[str]`, *optional*):
624
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
625
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
626
+ less than `1`).
627
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
628
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
629
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
630
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
631
+ The number of images to generate per prompt.
632
+ eta (`float`, *optional*, defaults to 0.0):
633
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
634
+ [`schedulers.DDIMScheduler`], will be ignored for others.
635
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
636
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
637
+ to make generation deterministic.
638
+ latents (`torch.FloatTensor`, *optional*):
639
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
640
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
641
+ tensor will be generated by sampling using the supplied random `generator`.
642
+ prompt_embeds (`torch.FloatTensor`, *optional*):
643
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
644
+ provided, text embeddings will be generated from `prompt` input argument.
645
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
646
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
647
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
648
+ argument.
649
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
650
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
651
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
652
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
653
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
654
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
655
+ input argument.
656
+ output_type (`str`, *optional*, defaults to `"pil"`):
658
+ The output format of the generated video. Pass `"tensor"` to get a
659
+ `torch.Tensor`; any other value returns a NumPy array.
659
+ return_dict (`bool`, *optional*, defaults to `True`):
660
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
661
+ of a plain tuple.
662
+ callback (`Callable`, *optional*):
663
+ A function that will be called every `callback_steps` steps during inference. The function will be
664
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
665
+ callback_steps (`int`, *optional*, defaults to 1):
666
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
667
+ called at every step.
668
+ cross_attention_kwargs (`dict`, *optional*):
669
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
670
+ `self.processor` in
671
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
672
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
673
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
674
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
675
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
676
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
677
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
678
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
679
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
680
+ explained in section 2.2 of
681
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
682
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
683
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
684
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
685
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
686
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
687
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
688
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
689
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
690
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
691
+
692
+ Examples:
693
+
694
+ Returns:
695
+ [`~hotshot_xl.HotshotPipelineXLOutput`] or `tuple`:
696
+ [`~hotshot_xl.HotshotPipelineXLOutput`] if `return_dict` is True, otherwise a
697
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
698
+ """
699
+ self.low_vram_mode = low_vram_mode
700
+
701
+ if video_length > 1:
702
+ print(f"Warning - setting num_images_per_prompt = 1 because video_length = {video_length}")
703
+ num_images_per_prompt = 1
704
+
705
+ # 0. Default height and width to unet
706
+ height = height or self.default_sample_size * self.vae_scale_factor
707
+ width = width or self.default_sample_size * self.vae_scale_factor
708
+
709
+ original_size = original_size or (height, width)
710
+ target_size = target_size or (height, width)
711
+
712
+ # 1. Check inputs. Raise error if not correct
713
+ self.check_inputs(
714
+ prompt,
715
+ prompt_2,
716
+ height,
717
+ width,
718
+ callback_steps,
719
+ negative_prompt,
720
+ negative_prompt_2,
721
+ prompt_embeds,
722
+ negative_prompt_embeds,
723
+ pooled_prompt_embeds,
724
+ negative_pooled_prompt_embeds,
725
+ )
726
+
727
+ # 2. Define call parameters
728
+ if prompt is not None and isinstance(prompt, str):
729
+ batch_size = 1
730
+ elif prompt is not None and isinstance(prompt, list):
731
+ batch_size = len(prompt)
732
+ else:
733
+ batch_size = prompt_embeds.shape[0]
734
+
735
+ device = self._execution_device
736
+
737
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
738
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
739
+ # corresponds to doing no classifier free guidance.
740
+ do_classifier_free_guidance = guidance_scale > 1.0
741
+
742
+ if self.low_vram_mode:
743
+ self.text_encoder.to(device)
744
+ self.text_encoder_2.to(device)
745
+
746
+ # 3. Encode input prompt
747
+ text_encoder_lora_scale = (
748
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
749
+ )
750
+ (
751
+ prompt_embeds,
752
+ negative_prompt_embeds,
753
+ pooled_prompt_embeds,
754
+ negative_pooled_prompt_embeds,
755
+ ) = self.encode_prompt(
756
+ prompt=prompt,
757
+ prompt_2=prompt_2,
758
+ device=device,
759
+ num_images_per_prompt=num_images_per_prompt,
760
+ do_classifier_free_guidance=do_classifier_free_guidance,
761
+ negative_prompt=negative_prompt,
762
+ negative_prompt_2=negative_prompt_2,
763
+ prompt_embeds=prompt_embeds,
764
+ negative_prompt_embeds=negative_prompt_embeds,
765
+ pooled_prompt_embeds=pooled_prompt_embeds,
766
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
767
+ lora_scale=text_encoder_lora_scale,
768
+ )
769
+
770
+ if self.low_vram_mode:
771
+ self.text_encoder.to(torch.device("cpu"))
772
+ self.text_encoder_2.to(torch.device("cpu"))
773
+ self.vae.to(torch.device("cpu"))
774
+ torch.cuda.empty_cache()
775
+ torch.cuda.synchronize()
776
+ gc.collect()
777
+
778
+ # 4. Prepare timesteps
779
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
780
+
781
+ timesteps = self.scheduler.timesteps
782
+
783
+ # 5. Prepare latent variables
784
+ num_channels_latents = self.unet.config.in_channels
785
+ latents = self.prepare_latents(
786
+ batch_size * num_images_per_prompt,
787
+ num_channels_latents,
788
+ video_length,
789
+ height,
790
+ width,
791
+ prompt_embeds.dtype,
792
+ device,
793
+ generator,
794
+ latents,
795
+ )
796
+
797
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
798
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
799
+
800
+ # 7. Prepare added time ids & embeddings
801
+ add_text_embeds = pooled_prompt_embeds
802
+ add_time_ids = self._get_add_time_ids(
803
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
804
+ )
805
+
806
+ # todo - negative_original_size from latest diffusers for cfg
807
+
808
+ if do_classifier_free_guidance:
809
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
810
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
811
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)
812
+
813
+ prompt_embeds = prompt_embeds.to(device)
814
+ add_text_embeds = add_text_embeds.to(device)
815
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
816
+
817
+ # 8. Denoising loop
818
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
819
+
820
+ # 7.1 Apply denoising_end
821
+ if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
822
+ discrete_timestep_cutoff = int(
823
+ round(
824
+ self.scheduler.config.num_train_timesteps
825
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
826
+ )
827
+ )
828
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
829
+ timesteps = timesteps[:num_inference_steps]
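+ # e.g. with num_train_timesteps=1000 and denoising_end=0.8 the cutoff is timestep 200,
+ # so only the first 80% of the inference steps are run and the returned latents still
+ # carry the remaining noise (useful for a refiner-style handoff).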
830
+
831
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
832
+ for i, t in enumerate(timesteps):
833
+ # expand the latents if we are doing classifier free guidance
834
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
835
+
836
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
837
+
838
+ # predict the noise residual
839
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
840
+ noise_pred = self.unet(
841
+ latent_model_input,
842
+ t,
843
+ encoder_hidden_states=prompt_embeds,
844
+ cross_attention_kwargs=cross_attention_kwargs,
845
+ added_cond_kwargs=added_cond_kwargs,
846
+ return_dict=False,
847
+ enable_temporal_attentions=video_length > 1
848
+ )[0]
849
+
850
+ # perform guidance
851
+ if do_classifier_free_guidance:
852
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
853
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
854
+
855
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
856
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
857
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
858
+
859
+ # compute the previous noisy sample x_t -> x_t-1
860
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
861
+
862
+ # call the callback, if provided
863
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
864
+ progress_bar.update()
865
+ if callback is not None and i % callback_steps == 0:
866
+ callback(i, t, latents)
867
+
868
+ # make sure the VAE is in float32 mode, as it overflows in float16
869
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
870
+ self.upcast_vae()
871
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
872
+
873
+ # if not output_type == "latent":
874
+ # image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
875
+ # else:
876
+ # image = latents
877
+ # return StableDiffusionXLPipelineOutput(images=image)
878
+
879
+ # apply watermark if available
880
+ # if self.watermark is not None:
881
+ # image = self.watermark.apply_watermark(image)
882
+
883
+ #image = self.image_processor.postprocess(image, output_type=output_type)
884
+
885
+ if self.low_vram_mode:
886
+ self.vae.to(device)
887
+ torch.cuda.empty_cache()
888
+ torch.cuda.synchronize()
889
+ gc.collect()
890
+
891
+         video = self.decode_latents(latents)
+
+         # Convert to tensor
+         if output_type == "tensor":
+             video = torch.from_numpy(video)
+
+         if not return_dict:
+             return video
+
+         return HotshotPipelineXLOutput(videos=video)
+
+         #
+         # # Offload last model to CPU
+         # if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+         #     self.final_offload_hook.offload()
+         #
+         # if not return_dict:
+         #     return (image,)
+         #
+         # return StableDiffusionXLPipelineOutput(images=image)
+
+     # Override to properly handle the loading and unloading of the additional text encoder.
+     def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+         # We could have accessed the unet config from `lora_state_dict()` too. We pass
+         # it here explicitly to be able to tell that it's coming from an SDXL
+         # pipeline.
+         state_dict, network_alphas = self.lora_state_dict(
+             pretrained_model_name_or_path_or_dict,
+             unet_config=self.unet.config,
+             **kwargs,
+         )
+         self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+         text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+         if len(text_encoder_state_dict) > 0:
+             self.load_lora_into_text_encoder(
+                 text_encoder_state_dict,
+                 network_alphas=network_alphas,
+                 text_encoder=self.text_encoder,
+                 prefix="text_encoder",
+                 lora_scale=self.lora_scale,
+             )
+
+         text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+         if len(text_encoder_2_state_dict) > 0:
+             self.load_lora_into_text_encoder(
+                 text_encoder_2_state_dict,
+                 network_alphas=network_alphas,
+                 text_encoder=self.text_encoder_2,
+                 prefix="text_encoder_2",
+                 lora_scale=self.lora_scale,
+             )
+
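+     # Illustrative usage of the override above (the path is a placeholder, not a file in this repo):
+     #
+     #     pipe.load_lora_weights("my_sdxl_lora.safetensors")
+     #
+     # UNet, `text_encoder` and `text_encoder_2` LoRA weights are split out of the combined state
+     # dict and loaded into their respective modules.
+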
+     @classmethod
+     def save_lora_weights(
+         self,
+         save_directory: Union[str, os.PathLike],
+         unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+         is_main_process: bool = True,
+         weight_name: str = None,
+         save_function: Callable = None,
+         safe_serialization: bool = False,
+     ):
+         state_dict = {}
+
+         def pack_weights(layers, prefix):
+             layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+             layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+             return layers_state_dict
+
+         state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+         if text_encoder_lora_layers and text_encoder_2_lora_layers:
+             state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+             state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+         self.write_lora_layers(
+             state_dict=state_dict,
+             save_directory=save_directory,
+             is_main_process=is_main_process,
+             weight_name=weight_name,
+             save_function=save_function,
+             safe_serialization=safe_serialization,
+         )
+
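+     # Latents are decoded one frame at a time below to keep peak VRAM low; the commented-out batched
+     # `self.vae.decode(latents).sample` variant would be faster but needs considerably more memory.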
+     def decode_latents(self, latents):
+         video_length = latents.shape[2]
+         latents = 1 / self.vae.config.scaling_factor * latents
+         latents = rearrange(latents, "b c f h w -> (b f) c h w")
+         # video = self.vae.decode(latents).sample
+         video = []
+         for frame_idx in tqdm(range(latents.shape[0])):
+             video.append(self.vae.decode(
+                 latents[frame_idx:frame_idx + 1]).sample)
+         video = torch.cat(video)
+         video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
+         video = (video / 2.0 + 0.5).clamp(0, 1)
+         # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+         video = video.cpu().float().numpy()
+         return video
+
+     def _remove_text_encoder_monkey_patch(self):
+         self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+         self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)