jakebabbidge committed on
Commit
410bce5
1 Parent(s): f057048

Create pipeline.py

Files changed (1)
  1. pipeline.py +1411 -0
pipeline.py ADDED
@@ -0,0 +1,1411 @@
1
+ # Copyright 2023 Jake Babbidge, TencentARC and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL
20
+ import torch
21
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
22
+
23
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
24
+
25
+ from diffusers.image_processor import VaeImageProcessor
26
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
27
+ from diffusers.models import AutoencoderKL, MultiAdapter, T2IAdapter, UNet2DConditionModel
28
+ from diffusers.models.attention_processor import (
29
+ AttnProcessor2_0,
30
+ LoRAAttnProcessor2_0,
31
+ LoRAXFormersAttnProcessor,
32
+ XFormersAttnProcessor,
33
+ )
34
+ from diffusers.schedulers import KarrasDiffusionSchedulers
35
+ from diffusers.utils import (
36
+ PIL_INTERPOLATION,
37
+ is_accelerate_available,
38
+ is_accelerate_version,
39
+ logging,
40
+ randn_tensor,
41
+ replace_example_docstring,
42
+ )
43
+ from diffusers import DiffusionPipeline
44
+
45
+
46
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> import torch
52
+ >>> from diffusers import DiffusionPipeline, T2IAdapter
53
+ >>> from diffusers.utils import load_image
54
+ >>> from PIL import Image
55
+
56
+ >>> adapter = T2IAdapter.from_pretrained(
57
+ ... "TencentARC/t2i-adapter-sketch-sdxl-1.0", torch_dtype=torch.float16, variant="fp16"
58
+ ... ).to("cuda")
59
+
60
+ >>> pipe = DiffusionPipeline.from_pretrained(
61
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
62
+ ... torch_dtype=torch.float16,
63
+ ... variant="fp16",
64
+ ... use_safetensors=True,
65
+ ... custom_pipeline="jakebabbidge/sdxl-adapter-inpaint",
66
+ ... adapter=adapter
67
+ ... ).to("cuda")
68
+
69
+ >>> image = Image.open(image_path).convert("RGB")
70
+ >>> mask = Image.open(mask_path).convert("RGB")
71
+ >>> adapter_sketch = Image.open(adapter_sketch_path).convert("RGB")
72
+
73
+ >>> result_image = pipe(
74
+ ... image=image,
75
+ ... mask_image=mask,
76
+ ... adapter_image=adapter_sketch,
77
+ ... prompt="a photo of a dog in real world, high quality",
78
+ ... negative_prompt="extra digit, fewer digits, cropped, worst quality, low quality",
79
+ ... num_inference_steps=50
80
+ ... ).images[0]
81
+ ```
82
+ """
83
+
84
+
85
+ def _preprocess_adapter_image(image, height, width):
86
+ if isinstance(image, torch.Tensor):
87
+ return image
88
+ elif isinstance(image, PIL.Image.Image):
89
+ image = [image]
90
+
91
+ if isinstance(image[0], PIL.Image.Image):
92
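+ # resize PIL inputs to the requested resolution and convert them to a float tensor in [0, 1] with NCHW layout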
+ image = [np.array(i.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])) for i in image]
93
+ image = [
94
+ i[None, ..., None] if i.ndim == 2 else i[None, ...] for i in image
95
+ ] # expand [h, w] or [h, w, c] to [b, h, w, c]
96
+ image = np.concatenate(image, axis=0)
97
+ image = np.array(image).astype(np.float32) / 255.0
98
+ image = image.transpose(0, 3, 1, 2)
99
+ image = torch.from_numpy(image)
100
+ elif isinstance(image[0], torch.Tensor):
101
+ if image[0].ndim == 3:
102
+ image = torch.stack(image, dim=0)
103
+ elif image[0].ndim == 4:
104
+ image = torch.cat(image, dim=0)
105
+ else:
106
+ raise ValueError(
107
+ f"Invalid image tensor! Expecting an image tensor with 3 or 4 dimensions, but received: {image[0].ndim}"
108
+ )
109
+ return image
110
+
111
+
112
+ def mask_pil_to_torch(mask, height, width):
113
+ # preprocess mask
114
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
115
+ mask = [mask]
116
+
117
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
118
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
119
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
120
+ mask = mask.astype(np.float32) / 255.0
121
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
122
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
123
+
124
+ mask = torch.from_numpy(mask)
125
+ return mask
126
+
127
+
128
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
129
+ """
130
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
131
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
132
+ ``image`` and ``1`` for the ``mask``.
133
+
134
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
135
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
136
+
137
+ Args:
138
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
139
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
140
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
141
+ mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
142
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
143
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
144
+
145
+
146
+ Raises:
147
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
148
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
149
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
150
+ (or the other way around).
151
+
152
+ Returns:
153
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
154
+ dimensions: ``batch x channels x height x width``.
155
+ """
156
+
157
+ # checkpoint. TODO(Yiyi) - need to clean this up later
158
+ if image is None:
159
+ raise ValueError("`image` input cannot be undefined.")
160
+
161
+ if mask is None:
162
+ raise ValueError("`mask_image` input cannot be undefined.")
163
+
164
+ if isinstance(image, torch.Tensor):
165
+ if not isinstance(mask, torch.Tensor):
166
+ mask = mask_pil_to_torch(mask, height, width)
167
+
168
+ if image.ndim == 3:
169
+ image = image.unsqueeze(0)
170
+
171
+ # Batch and add channel dim for single mask
172
+ if mask.ndim == 2:
173
+ mask = mask.unsqueeze(0).unsqueeze(0)
174
+
175
+ # Batch single mask or add channel dim
176
+ if mask.ndim == 3:
177
+ # Single batched mask, no channel dim or single mask not batched but channel dim
178
+ if mask.shape[0] == 1:
179
+ mask = mask.unsqueeze(0)
180
+
181
+ # Batched masks no channel dim
182
+ else:
183
+ mask = mask.unsqueeze(1)
184
+
185
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
186
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
187
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
188
+
189
+ # Check image is in [-1, 1]
190
+ # if image.min() < -1 or image.max() > 1:
191
+ # raise ValueError("Image should be in [-1, 1] range")
192
+
193
+ # Check mask is in [0, 1]
194
+ if mask.min() < 0 or mask.max() > 1:
195
+ raise ValueError("Mask should be in [0, 1] range")
196
+
197
+ # Binarize mask
198
+ mask[mask < 0.5] = 0
199
+ mask[mask >= 0.5] = 1
200
+
201
+ # Image as float32
202
+ image = image.to(dtype=torch.float32)
203
+ elif isinstance(mask, torch.Tensor):
204
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
205
+ else:
206
+ # preprocess image
207
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
208
+ image = [image]
209
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
210
+ # resize all images w.r.t. the passed height and width
211
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
212
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
213
+ image = np.concatenate(image, axis=0)
214
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
215
+ image = np.concatenate([i[None, :] for i in image], axis=0)
216
+
217
+ image = image.transpose(0, 3, 1, 2)
218
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
219
+
220
+ mask = mask_pil_to_torch(mask, height, width)
221
+ mask[mask < 0.5] = 0
222
+ mask[mask >= 0.5] = 1
223
+
224
+ if image.shape[1] == 4:
225
+ # images are in latent space and thus can't be masked; set masked_image to None.
+ # We assume that the checkpoint is not an inpainting checkpoint. TODO(Yiyi) - need to clean this up later
229
+ masked_image = None
230
+ else:
231
+ masked_image = image * (mask < 0.5)
232
+
233
+ # n.b. ensure backwards compatibility as old function does not return image
234
+ if return_image:
235
+ return mask, masked_image, image
236
+
237
+ return mask, masked_image
238
+
239
+
240
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
241
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
242
+ """
243
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
244
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
245
+ """
246
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
247
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
248
+ # rescale the results from guidance (fixes overexposure)
249
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
250
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
251
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
252
+ return noise_cfg
253
+
254
+
255
+ class StableDiffusionXLAdapterInpaintPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
256
+ r"""
257
+ Pipeline for text-guided image inpainting using Stable Diffusion XL augmented with T2I-Adapter
258
+ https://arxiv.org/abs/2302.08453
259
+
260
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
261
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
262
+
263
+ Args:
264
+ adapter ([`T2IAdapter`] or [`MultiAdapter`] or `List[T2IAdapter]`):
265
+ Provides additional conditioning to the unet during the denoising process. If you set multiple Adapter as a
266
+ list, the outputs from each Adapter are added together to create one combined additional conditioning.
267
+ adapter_weights (`List[float]`, *optional*, defaults to None):
268
+ List of floats representing the weight by which each adapter's output will be multiplied before adding them
269
+ together.
270
+ vae ([`AutoencoderKL`]):
271
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
272
+ text_encoder ([`CLIPTextModel`]):
273
+ Frozen text-encoder. Stable Diffusion uses the text portion of
274
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
275
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
276
+ tokenizer (`CLIPTokenizer`):
277
+ Tokenizer of class
278
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
279
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
280
+ scheduler ([`SchedulerMixin`]):
281
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
282
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
283
+ safety_checker ([`StableDiffusionSafetyChecker`]):
284
+ Classification module that estimates whether generated images could be considered offensive or harmful.
285
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
286
+ feature_extractor ([`CLIPFeatureExtractor`]):
287
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
288
+ requires_aesthetics_score (`bool`, *optional*, defaults to `False`):
289
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
290
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
291
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `True`):
292
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
293
+ `stabilityai/stable-diffusion-xl-base-1-0`.
294
+ """
295
+
296
+ def __init__(
297
+ self,
298
+ vae: AutoencoderKL,
299
+ text_encoder: CLIPTextModel,
300
+ text_encoder_2: CLIPTextModelWithProjection,
301
+ tokenizer: CLIPTokenizer,
302
+ tokenizer_2: CLIPTokenizer,
303
+ unet: UNet2DConditionModel,
304
+ adapter: Union[T2IAdapter, MultiAdapter, List[T2IAdapter]],
305
+ scheduler: KarrasDiffusionSchedulers,
306
+ requires_aesthetics_score: bool = False,
307
+ force_zeros_for_empty_prompt: bool = True,
308
+ ):
309
+ super().__init__()
310
+
311
+ self.register_modules(
312
+ vae=vae,
313
+ text_encoder=text_encoder,
314
+ text_encoder_2=text_encoder_2,
315
+ tokenizer=tokenizer,
316
+ tokenizer_2=tokenizer_2,
317
+ unet=unet,
318
+ adapter=adapter,
319
+ scheduler=scheduler,
320
+ )
321
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
322
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
323
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
324
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
325
+ self.default_sample_size = self.unet.config.sample_size
326
+
327
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
328
+ def enable_vae_slicing(self):
329
+ r"""
330
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
331
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
332
+ """
333
+ self.vae.enable_slicing()
334
+
335
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
336
+ def disable_vae_slicing(self):
337
+ r"""
338
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
339
+ computing decoding in one step.
340
+ """
341
+ self.vae.disable_slicing()
342
+
343
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
344
+ def enable_vae_tiling(self):
345
+ r"""
346
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
347
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
348
+ processing larger images.
349
+ """
350
+ self.vae.enable_tiling()
351
+
352
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
353
+ def disable_vae_tiling(self):
354
+ r"""
355
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
356
+ computing decoding in one step.
357
+ """
358
+ self.vae.disable_tiling()
359
+
360
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.enable_model_cpu_offload
361
+ def enable_model_cpu_offload(self, gpu_id=0):
362
+ r"""
363
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
364
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
365
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
366
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
367
+ """
368
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
369
+ from accelerate import cpu_offload_with_hook
370
+ else:
371
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
372
+
373
+ device = torch.device(f"cuda:{gpu_id}")
374
+
375
+ if self.device.type != "cpu":
376
+ self.to("cpu", silence_dtype_warnings=True)
377
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
378
+
379
+ model_sequence = (
380
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
381
+ )
382
+ model_sequence.extend([self.unet, self.vae])
383
+
384
+ hook = None
385
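+ # chain the offload hooks: each model is moved onto the GPU when its forward is called, and the previous one is offloaded back to CPU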
+ for cpu_offloaded_model in model_sequence:
386
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
387
+
388
+ # We'll offload the last model manually.
389
+ self.final_offload_hook = hook
390
+
391
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
392
+ def encode_prompt(
393
+ self,
394
+ prompt: str,
395
+ prompt_2: Optional[str] = None,
396
+ device: Optional[torch.device] = None,
397
+ num_images_per_prompt: int = 1,
398
+ do_classifier_free_guidance: bool = True,
399
+ negative_prompt: Optional[str] = None,
400
+ negative_prompt_2: Optional[str] = None,
401
+ prompt_embeds: Optional[torch.FloatTensor] = None,
402
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
403
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
404
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
405
+ lora_scale: Optional[float] = None,
406
+ ):
407
+ r"""
408
+ Encodes the prompt into text encoder hidden states.
409
+
410
+ Args:
411
+ prompt (`str` or `List[str]`, *optional*):
412
+ prompt to be encoded
413
+ prompt_2 (`str` or `List[str]`, *optional*):
414
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
415
+ used in both text-encoders
416
+ device: (`torch.device`):
417
+ torch device
418
+ num_images_per_prompt (`int`):
419
+ number of images that should be generated per prompt
420
+ do_classifier_free_guidance (`bool`):
421
+ whether to use classifier free guidance or not
422
+ negative_prompt (`str` or `List[str]`, *optional*):
423
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
424
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
425
+ less than `1`).
426
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
427
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
428
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
429
+ prompt_embeds (`torch.FloatTensor`, *optional*):
430
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
431
+ provided, text embeddings will be generated from `prompt` input argument.
432
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
433
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
434
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
435
+ argument.
436
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
437
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
438
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
439
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
440
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
441
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
442
+ input argument.
443
+ lora_scale (`float`, *optional*):
444
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
445
+ """
446
+ device = device or self._execution_device
447
+
448
+ # set lora scale so that monkey patched LoRA
449
+ # function of text encoder can correctly access it
450
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
451
+ self._lora_scale = lora_scale
452
+
453
+ if prompt is not None and isinstance(prompt, str):
454
+ batch_size = 1
455
+ elif prompt is not None and isinstance(prompt, list):
456
+ batch_size = len(prompt)
457
+ else:
458
+ batch_size = prompt_embeds.shape[0]
459
+
460
+ # Define tokenizers and text encoders
461
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
462
+ text_encoders = (
463
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
464
+ )
465
+
466
+ if prompt_embeds is None:
467
+ prompt_2 = prompt_2 or prompt
468
+ # textual inversion: process multi-vector tokens if necessary
469
+ prompt_embeds_list = []
470
+ prompts = [prompt, prompt_2]
471
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
472
+ if isinstance(self, TextualInversionLoaderMixin):
473
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
474
+
475
+ text_inputs = tokenizer(
476
+ prompt,
477
+ padding="max_length",
478
+ max_length=tokenizer.model_max_length,
479
+ truncation=True,
480
+ return_tensors="pt",
481
+ )
482
+
483
+ text_input_ids = text_inputs.input_ids
484
+ untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
485
+
486
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
487
+ text_input_ids, untruncated_ids
488
+ ):
489
+ removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
490
+ logger.warning(
491
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
492
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
493
+ )
494
+
495
+ prompt_embeds = text_encoder(
496
+ text_input_ids.to(device),
497
+ output_hidden_states=True,
498
+ )
499
+
500
+ # We are always only interested in the pooled output of the final text encoder
501
+ pooled_prompt_embeds = prompt_embeds[0]
502
+ prompt_embeds = prompt_embeds.hidden_states[-2]
503
+
504
+ prompt_embeds_list.append(prompt_embeds)
505
+
506
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
507
+
508
+ # get unconditional embeddings for classifier free guidance
509
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
510
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
511
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
512
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
513
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
514
+ negative_prompt = negative_prompt or ""
515
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
516
+
517
+ uncond_tokens: List[str]
518
+ if prompt is not None and type(prompt) is not type(negative_prompt):
519
+ raise TypeError(
520
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
521
+ f" {type(prompt)}."
522
+ )
523
+ elif isinstance(negative_prompt, str):
524
+ uncond_tokens = [negative_prompt, negative_prompt_2]
525
+ elif batch_size != len(negative_prompt):
526
+ raise ValueError(
527
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
528
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
529
+ " the batch size of `prompt`."
530
+ )
531
+ else:
532
+ uncond_tokens = [negative_prompt, negative_prompt_2]
533
+
534
+ negative_prompt_embeds_list = []
535
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
536
+ if isinstance(self, TextualInversionLoaderMixin):
537
+ negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
538
+
539
+ max_length = prompt_embeds.shape[1]
540
+ uncond_input = tokenizer(
541
+ negative_prompt,
542
+ padding="max_length",
543
+ max_length=max_length,
544
+ truncation=True,
545
+ return_tensors="pt",
546
+ )
547
+
548
+ negative_prompt_embeds = text_encoder(
549
+ uncond_input.input_ids.to(device),
550
+ output_hidden_states=True,
551
+ )
552
+ # We are always only interested in the pooled output of the final text encoder
553
+ negative_pooled_prompt_embeds = negative_prompt_embeds[0]
554
+ negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
555
+
556
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
557
+
558
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
559
+
560
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
561
+ bs_embed, seq_len, _ = prompt_embeds.shape
562
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
563
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
564
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
565
+
566
+ if do_classifier_free_guidance:
567
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
568
+ seq_len = negative_prompt_embeds.shape[1]
569
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
570
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
571
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
572
+
573
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
574
+ bs_embed * num_images_per_prompt, -1
575
+ )
576
+ if do_classifier_free_guidance:
577
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
578
+ bs_embed * num_images_per_prompt, -1
579
+ )
580
+
581
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
582
+
583
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
584
+ def prepare_extra_step_kwargs(self, generator, eta):
585
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
586
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
587
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
588
+ # and should be between [0, 1]
589
+
590
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
591
+ extra_step_kwargs = {}
592
+ if accepts_eta:
593
+ extra_step_kwargs["eta"] = eta
594
+
595
+ # check if the scheduler accepts generator
596
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
597
+ if accepts_generator:
598
+ extra_step_kwargs["generator"] = generator
599
+ return extra_step_kwargs
600
+
601
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.check_inputs
602
+ def check_inputs(
603
+ self,
604
+ prompt,
605
+ prompt_2,
606
+ strength,
607
+ num_inference_steps,
608
+ height,
609
+ width,
610
+ callback_steps,
611
+ negative_prompt=None,
612
+ negative_prompt_2=None,
613
+ prompt_embeds=None,
614
+ negative_prompt_embeds=None,
615
+ pooled_prompt_embeds=None,
616
+ negative_pooled_prompt_embeds=None,
617
+ ):
618
+ if strength < 0 or strength > 1:
619
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
620
+ if num_inference_steps is None:
621
+ raise ValueError("`num_inference_steps` cannot be None.")
622
+ elif not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
623
+ raise ValueError(
624
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
625
+ f" {type(num_inference_steps)}."
626
+ )
627
+ if height % 8 != 0 or width % 8 != 0:
628
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
629
+
630
+ if (callback_steps is None) or (
631
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
632
+ ):
633
+ raise ValueError(
634
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
635
+ f" {type(callback_steps)}."
636
+ )
637
+
638
+ if prompt is not None and prompt_embeds is not None:
639
+ raise ValueError(
640
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
641
+ " only forward one of the two."
642
+ )
643
+ elif prompt_2 is not None and prompt_embeds is not None:
644
+ raise ValueError(
645
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
646
+ " only forward one of the two."
647
+ )
648
+ elif prompt is None and prompt_embeds is None:
649
+ raise ValueError(
650
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
651
+ )
652
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
653
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
654
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
655
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
656
+
657
+ if negative_prompt is not None and negative_prompt_embeds is not None:
658
+ raise ValueError(
659
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
660
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
661
+ )
662
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
663
+ raise ValueError(
664
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
665
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
666
+ )
667
+
668
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
669
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
670
+ raise ValueError(
671
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
672
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
673
+ f" {negative_prompt_embeds.shape}."
674
+ )
675
+
676
+ if prompt_embeds is not None and pooled_prompt_embeds is None:
677
+ raise ValueError(
678
+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
679
+ )
680
+
681
+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
682
+ raise ValueError(
683
+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
684
+ )
685
+
686
+ def prepare_latents(
687
+ self,
688
+ batch_size,
689
+ num_channels_latents,
690
+ height,
691
+ width,
692
+ dtype,
693
+ device,
694
+ generator,
695
+ latents=None,
696
+ image=None,
697
+ timestep=None,
698
+ is_strength_max=True,
699
+ add_noise=True,
700
+ return_noise=False,
701
+ return_image_latents=False,
702
+ ):
703
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
704
+ if isinstance(generator, list) and len(generator) != batch_size:
705
+ raise ValueError(
706
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
707
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
708
+ )
709
+
710
+ if (image is None or timestep is None) and not is_strength_max:
711
+ raise ValueError(
712
+ "Since strength < 1, the initial latents have to be initialised as a combination of image and noise. "
713
+ "However, either the image or the noise timestep has not been provided."
714
+ )
715
+
716
+ if image.shape[1] == 4:
717
+ image_latents = image.to(device=device, dtype=dtype)
718
+ elif return_image_latents or (latents is None and not is_strength_max):
719
+ image = image.to(device=device, dtype=dtype)
720
+ image_latents = self._encode_vae_image(image=image, generator=generator)
721
+
722
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
723
+
724
+ if latents is None and add_noise:
725
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
726
+ # if strength is 1, then initialise the latents to pure noise, else initialise them to image + noise
727
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
728
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
729
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
730
+ elif add_noise:
731
+ noise = latents.to(device)
732
+ latents = noise * self.scheduler.init_noise_sigma
733
+ else:
734
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
735
+ latents = image_latents.to(device)
736
+
737
+ outputs = (latents,)
738
+
739
+ if return_noise:
740
+ outputs += (noise,)
741
+
742
+ if return_image_latents:
743
+ outputs += (image_latents,)
744
+
745
+ return outputs
746
+
747
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
748
+ dtype = image.dtype
749
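+ # the SDXL VAE can overflow in float16; when `force_upcast` is set, encode in float32 and cast back below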
+ if self.vae.config.force_upcast:
750
+ image = image.float()
751
+ self.vae.to(dtype=torch.float32)
752
+
753
+ if isinstance(generator, list):
754
+ image_latents = [
755
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
756
+ for i in range(image.shape[0])
757
+ ]
758
+ image_latents = torch.cat(image_latents, dim=0)
759
+ else:
760
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
761
+
762
+ if self.vae.config.force_upcast:
763
+ self.vae.to(dtype)
764
+
765
+ image_latents = image_latents.to(dtype)
766
+ image_latents = self.vae.config.scaling_factor * image_latents
767
+
768
+ return image_latents
769
+
770
+ def prepare_mask_latents(
771
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
772
+ ):
773
+ # resize the mask to latents shape as we concatenate the mask to the latents
774
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
775
+ # and half precision
776
+ mask = torch.nn.functional.interpolate(
777
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
778
+ )
779
+ mask = mask.to(device=device, dtype=dtype)
780
+
781
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
782
+ if mask.shape[0] < batch_size:
783
+ if not batch_size % mask.shape[0] == 0:
784
+ raise ValueError(
785
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
786
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
787
+ " of masks that you pass is divisible by the total requested batch size."
788
+ )
789
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
790
+
791
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
792
+
793
+ masked_image_latents = None
794
+ if masked_image is not None:
795
+ masked_image = masked_image.to(device=device, dtype=dtype)
796
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
797
+ if masked_image_latents.shape[0] < batch_size:
798
+ if not batch_size % masked_image_latents.shape[0] == 0:
799
+ raise ValueError(
800
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
801
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
802
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
803
+ )
804
+ masked_image_latents = masked_image_latents.repeat(
805
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
806
+ )
807
+
808
+ masked_image_latents = (
809
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
810
+ )
811
+
812
+ # align the device to prevent device errors when concatenating it with the latent model input
813
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
814
+
815
+ return mask, masked_image_latents
816
+
817
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
818
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
819
+ # get the original timestep using init_timestep
820
+ if denoising_start is None:
821
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
822
+ t_start = max(num_inference_steps - init_timestep, 0)
823
+ else:
824
+ t_start = 0
825
+
826
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
827
+
828
+ # Strength is irrelevant if we directly request a timestep to start at;
829
+ # that is, strength is determined by the denoising_start instead.
830
+ if denoising_start is not None:
831
+ discrete_timestep_cutoff = int(
832
+ round(
833
+ self.scheduler.config.num_train_timesteps
834
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
835
+ )
836
+ )
837
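+ # `denoising_start` is the fraction of the schedule assumed to be already denoised, so keep only the
+ # timesteps below the corresponding training-timestep cutoff (i.e. skip the noisiest steps)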
+ timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps))
838
+ return torch.tensor(timesteps), len(timesteps)
839
+
840
+ return timesteps, num_inference_steps - t_start
841
+
842
+ def _get_add_time_ids(
843
+ self,
844
+ original_size,
845
+ crops_coords_top_left,
846
+ target_size,
847
+ aesthetic_score,
848
+ negative_aesthetic_score,
849
+ dtype,
850
+ text_encoder_projection_dim=None,
851
+ ):
852
+ if self.config.requires_aesthetics_score:
853
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
854
+ add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
855
+ else:
856
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
857
+ add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
858
+
859
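+ # sanity-check: the number of conditioning values must match the width of the UNet's additional time embedding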
+ passed_add_embed_dim = (
860
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
861
+ )
862
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
863
+
864
+ if (
865
+ expected_add_embed_dim > passed_add_embed_dim
866
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
867
+ ):
868
+ raise ValueError(
869
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
870
+ )
871
+ elif (
872
+ expected_add_embed_dim < passed_add_embed_dim
873
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
874
+ ):
875
+ raise ValueError(
876
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
877
+ )
878
+ elif expected_add_embed_dim != passed_add_embed_dim:
879
+ raise ValueError(
880
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
881
+ )
882
+
883
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
884
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
885
+
886
+ return add_time_ids, add_neg_time_ids
887
+
888
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
889
+ def upcast_vae(self):
890
+ dtype = self.vae.dtype
891
+ self.vae.to(dtype=torch.float32)
892
+ use_torch_2_0_or_xformers = isinstance(
893
+ self.vae.decoder.mid_block.attentions[0].processor,
894
+ (
895
+ AttnProcessor2_0,
896
+ XFormersAttnProcessor,
897
+ LoRAXFormersAttnProcessor,
898
+ LoRAAttnProcessor2_0,
899
+ ),
900
+ )
901
+ # if xformers or torch_2_0 is used attention block does not need
902
+ # to be in float32 which can save lots of memory
903
+ if use_torch_2_0_or_xformers:
904
+ self.vae.post_quant_conv.to(dtype)
905
+ self.vae.decoder.conv_in.to(dtype)
906
+ self.vae.decoder.mid_block.to(dtype)
907
+
908
+ # Copied from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter.StableDiffusionAdapterPipeline._default_height_width
909
+ def _default_height_width(self, height, width, image):
910
+ # NOTE: It is possible that the images in a list have different
911
+ # dimensions for each image, so just checking the first image
912
+ # is not _exactly_ correct, but it is simple.
913
+ while isinstance(image, list):
914
+ image = image[0]
915
+
916
+ if height is None:
917
+ if isinstance(image, PIL.Image.Image):
918
+ height = image.height
919
+ elif isinstance(image, torch.Tensor):
920
+ height = image.shape[-2]
921
+
922
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
923
+ height = (height // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
924
+
925
+ if width is None:
926
+ if isinstance(image, PIL.Image.Image):
927
+ width = image.width
928
+ elif isinstance(image, torch.Tensor):
929
+ width = image.shape[-1]
930
+
931
+ # round down to nearest multiple of `self.adapter.total_downscale_factor`
932
+ width = (width // self.adapter.total_downscale_factor) * self.adapter.total_downscale_factor
933
+
934
+ return height, width
935
+
936
+ @torch.no_grad()
937
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
938
+ def __call__(
939
+ self,
940
+ prompt: Union[str, List[str]] = None,
941
+ prompt_2: Optional[Union[str, List[str]]] = None,
942
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
943
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
944
+ adapter_image: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
945
+ height: Optional[int] = None,
946
+ width: Optional[int] = None,
947
+ strength: float = 0.9999,
948
+ num_inference_steps: int = 50,
949
+ denoising_start: Optional[float] = None,
950
+ denoising_end: Optional[float] = None,
951
+ guidance_scale: float = 5.0,
952
+ negative_prompt: Optional[Union[str, List[str]]] = None,
953
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
954
+ num_images_per_prompt: Optional[int] = 1,
955
+ eta: float = 0.0,
956
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
957
+ latents: Optional[torch.FloatTensor] = None,
958
+ prompt_embeds: Optional[torch.FloatTensor] = None,
959
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
960
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
961
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
962
+ output_type: Optional[str] = "pil",
963
+ return_dict: bool = True,
964
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
965
+ callback_steps: int = 1,
966
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
967
+ guidance_rescale: float = 0.0,
968
+ original_size: Optional[Tuple[int, int]] = None,
969
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
970
+ target_size: Optional[Tuple[int, int]] = None,
971
+ adapter_conditioning_scale: Union[float, List[float]] = 1.0,
972
+ cond_tau: float = 1.0,
973
+ aesthetic_score: float = 6.0,
974
+ negative_aesthetic_score: float = 2.5,
975
+ ):
976
+ r"""
977
+ Function invoked when calling the pipeline for generation.
978
+
979
+ Args:
980
+ prompt (`str` or `List[str]`, *optional*):
981
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
982
+ instead.
983
+ prompt_2 (`str` or `List[str]`, *optional*):
984
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
985
+ used in both text-encoders
986
+ image (`PIL.Image.Image`):
987
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
988
+ be masked out with `mask_image` and repainted according to `prompt`.
989
+ mask_image (`PIL.Image.Image`):
990
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
991
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
992
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
993
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
994
+ adapter_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
995
+ The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the
996
+ type is specified as `torch.FloatTensor`, it is passed to the Adapter as is. `PIL.Image.Image` can also be
997
+ accepted as an image. The control image is automatically resized to fit the output image.
998
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
999
+ The height in pixels of the generated image.
1000
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1001
+ The width in pixels of the generated image.
1002
+ strength (`float`, *optional*, defaults to 0.9999):
1003
+ Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
1004
+ starting point and more noise is added the higher the `strength`. The number of denoising steps depends
1005
+ on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising
1006
+ process runs for the full number of iterations specified in `num_inference_steps`. A value of 1
1007
+ essentially ignores `image`.
1008
+ num_inference_steps (`int`, *optional*, defaults to 50):
1009
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1010
+ expense of slower inference.
1011
+ denoising_start (`float`, *optional*):
1012
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1013
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1014
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1015
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1016
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1017
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1018
+ denoising_end (`float`, *optional*):
1019
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1020
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1021
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
1022
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
1023
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1024
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
1025
+ guidance_scale (`float`, *optional*, defaults to 5.0):
1026
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1027
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1028
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1029
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
1030
+ usually at the expense of lower image quality.
1031
+ negative_prompt (`str` or `List[str]`, *optional*):
1032
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1033
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1034
+ less than `1`).
1035
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1036
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1037
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
1038
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1039
+ The number of images to generate per prompt.
1040
+ eta (`float`, *optional*, defaults to 0.0):
1041
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1042
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1043
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1044
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1045
+ to make generation deterministic.
1046
+ latents (`torch.FloatTensor`, *optional*):
1047
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1048
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1049
+ tensor will be generated by sampling using the supplied random `generator`.
1050
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1051
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1052
+ provided, text embeddings will be generated from `prompt` input argument.
1053
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1054
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1055
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1056
+ argument.
1057
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1058
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1059
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1060
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
1061
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1062
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1063
+ input argument.
1064
+ output_type (`str`, *optional*, defaults to `"pil"`):
1065
+ The output format of the generated image. Choose between
1066
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1067
+ return_dict (`bool`, *optional*, defaults to `True`):
1068
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`]
1069
+ instead of a plain tuple.
1070
+ callback (`Callable`, *optional*):
1071
+ A function that will be called every `callback_steps` steps during inference. The function will be
1072
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1073
+ callback_steps (`int`, *optional*, defaults to 1):
1074
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1075
+ called at every step.
1076
+ cross_attention_kwargs (`dict`, *optional*):
1077
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1078
+ `self.processor` in
1079
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1080
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
1081
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
1082
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_rescale` is defined as `φ` in equation 16 of
1083
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
1084
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
1085
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1086
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1087
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
1088
+ explained in section 2.2 of
1089
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1090
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1091
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1092
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1093
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1094
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1095
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1096
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1097
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
1098
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1099
+ adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
1100
+ The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
1101
+ residual in the original unet. If multiple adapters are specified in init, you can set the
1102
+ corresponding scale as a list.
1103
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
1104
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1105
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1106
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1107
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1108
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1109
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1110
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
1111
+ Examples:
1112
+
1113
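+            A minimal usage sketch (illustrative only: `pipe` is assumed to be an already-constructed instance of
+            this pipeline, and `init_image`, `mask_image`, and `sketch_image` are assumed to be `PIL.Image.Image`
+            inputs prepared by the caller):
+
+            ```py
+            result = pipe(
+                prompt="a photo of a tiger sitting on a park bench",
+                image=init_image,
+                mask_image=mask_image,
+                adapter_image=sketch_image,
+                strength=0.8,
+                num_inference_steps=50,
+            )
+            images = result.images
+            ```
+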
+        Returns:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
+        """
+        # 0. Default height and width to unet
+
+        height, width = self._default_height_width(height, width, adapter_image)
+        device = self._execution_device
+
+        adapter_input = _preprocess_adapter_image(adapter_image, height, width).to(device)
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            strength,
+            num_inference_steps,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            negative_prompt_2,
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+        )
+
+        # 4. set timesteps
+        def denoising_value_valid(dnv):
+            return isinstance(dnv, float) and 0 < dnv < 1
+
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps, num_inference_steps = self.get_timesteps(
+            num_inference_steps, strength, device, denoising_start=denoising_start if denoising_value_valid(denoising_start) else None
+        )
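+        # NOTE: `strength` controls how much of the schedule is actually run, e.g. num_inference_steps=50 with
+        # strength=0.6 denoises over roughly the last 30 timesteps (assuming the usual diffusers `get_timesteps`
+        # behaviour, whose body is not shown here).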
+        # check that number of inference steps is not < 1 - as this doesn't make sense
+        if num_inference_steps < 1:
+            raise ValueError(
+                f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline "
+                f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
+            )
+        # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
+        latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+        # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
+        is_strength_max = strength == 1.0
+
+        # 5. Preprocess mask and image - resizes image and mask w.r.t height and width
+        mask, masked_image, init_image = prepare_mask_and_masked_image(
+            image, mask_image, height, width, return_image=True
+        )
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.vae.config.latent_channels
+        num_channels_unet = self.unet.config.in_channels
+        return_image_latents = num_channels_unet == 4
+
+        add_noise = True if denoising_start is None else False
+        latents_outputs = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+            image=init_image,
+            timestep=latent_timestep,
+            is_strength_max=is_strength_max,
+            add_noise=add_noise,
+            return_noise=True,
+            return_image_latents=return_image_latents,
+        )
+
+        if return_image_latents:
+            latents, noise, image_latents = latents_outputs
+        else:
+            latents, noise = latents_outputs
+
+        # 7. Prepare mask latent variables
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
+        # 8. Check that sizes of mask, masked image and latents match
+        if num_channels_unet == 9:
+            # default case for runwayml/stable-diffusion-inpainting
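+            # the 9 expected input channels are 4 noisy-latent + 1 mask + 4 masked-image-latent channels,
+            # which are concatenated along the channel dimension inside the denoising loop below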
+            num_channels_mask = mask.shape[1]
+            num_channels_masked_image = masked_image_latents.shape[1]
+            if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
+                raise ValueError(
+                    f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+                    f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+                    f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
+                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    " `pipeline.unet` or your `mask_image` or `image` input."
+                )
+        elif num_channels_unet != 4:
+            raise ValueError(
+                f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
+            )
+
+        # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 10. Prepare added time ids & embeddings & adapter features
+        adapter_input = adapter_input.type(latents.dtype)
+        adapter_state = self.adapter(adapter_input)
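+        # each adapter residual is scaled by `adapter_conditioning_scale`, tiled for `num_images_per_prompt`,
+        # and duplicated for classifier-free guidance so it matches the UNet's doubled batch size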
+        for k, v in enumerate(adapter_state):
+            adapter_state[k] = v * adapter_conditioning_scale
+        if num_images_per_prompt > 1:
+            for k, v in enumerate(adapter_state):
+                adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
+        if do_classifier_free_guidance:
+            for k, v in enumerate(adapter_state):
+                adapter_state[k] = torch.cat([v] * 2, dim=0)
+
+        add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
+        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            aesthetic_score,
+            negative_aesthetic_score,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
+        )
+        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+
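+        # for classifier-free guidance, negative and positive conditioning are stacked into one batch
+        # ([negative, positive]) so that both branches are computed in a single UNet forward pass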
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
+            add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
+            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
+
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device)
+
+        # 11. Denoising loop
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        # 11.1 Apply denoising_end
+        if (
+            denoising_end is not None
+            and denoising_start is not None
+            and denoising_value_valid(denoising_end)
+            and denoising_value_valid(denoising_start)
+            and denoising_start >= denoising_end
+        ):
+            raise ValueError(
+                f"`denoising_start`: {denoising_start} cannot be larger than or equal to `denoising_end`: "
+                f"{denoising_end} when using type float."
+            )
+        elif denoising_end is not None and denoising_value_valid(denoising_end):
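+            # e.g. with 1000 training timesteps and denoising_end=0.8, the cutoff is timestep 200: only timesteps
+            # >= 200 are run here, leaving the last 20% of the noise schedule for a refiner model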
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                if num_channels_unet == 9:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+
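+                # the adapter residuals are only injected for the first `cond_tau` fraction of the steps;
+                # the remaining steps run the UNet without the extra down-block residuals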
+                if i < int(num_inference_steps * cond_tau):
+                    down_block_additional_residuals = [state.clone() for state in adapter_state]
+                else:
+                    down_block_additional_residuals = None
+
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                    down_block_additional_residuals=down_block_additional_residuals,
+                )[0]
+
+                # perform guidance
+                if do_classifier_free_guidance:
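+                    # standard classifier-free guidance: eps = eps_uncond + guidance_scale * (eps_text - eps_uncond)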
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
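+                    # the guided prediction is rescaled toward the std of the text-conditioned branch and then
+                    # blended with the original via `guidance_rescale` (assuming the standard diffusers
+                    # `rescale_noise_cfg` behaviour)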
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents
+                    if do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
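+                    # keep the re-noised original latents outside the mask and the newly denoised latents inside it,
+                    # so that only the masked region is actually regenerated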
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # make sure the VAE is in float32 mode, as it overflows in float16
+        if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+            self.upcast_vae()
+            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
+
+        if not output_type == "latent":
+            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+        else:
+            image = latents
+            return StableDiffusionXLPipelineOutput(images=image)
+
+        image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image,)
+
+        return StableDiffusionXLPipelineOutput(images=image)