SunderAli17 committed on
Commit 85dc639 · verified · 1 Parent(s): 13aeda5

Create pipeline_stable_diffusion_xl_chatglm_256_inpainting.py

SAK/pipelines/pipeline_stable_diffusion_xl_chatglm_256_inpainting.py ADDED
@@ -0,0 +1,1770 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import PIL.Image
20
+ import torch
21
+ from transformers import (
22
+ CLIPImageProcessor,
23
+ CLIPTextModel,
24
+ CLIPTextModelWithProjection,
25
+ CLIPTokenizer,
26
+ CLIPVisionModelWithProjection,
27
+ )
28
+
29
+ from diffusers.callbacks import MultiPipelineCallbacks, PipelineCallback
30
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
31
+ from diffusers.loaders import (
32
+ FromSingleFileMixin,
33
+ IPAdapterMixin,
34
+ StableDiffusionXLLoraLoaderMixin,
35
+ TextualInversionLoaderMixin,
36
+ )
37
+ from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel
38
+ from diffusers.models.attention_processor import (
39
+ AttnProcessor2_0,
40
+ LoRAAttnProcessor2_0,
41
+ LoRAXFormersAttnProcessor,
42
+ XFormersAttnProcessor,
43
+ )
44
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
45
+ from diffusers.schedulers import KarrasDiffusionSchedulers
46
+ from diffusers.utils import (
47
+ USE_PEFT_BACKEND,
48
+ deprecate,
49
+ is_invisible_watermark_available,
50
+ is_torch_xla_available,
51
+ logging,
52
+ replace_example_docstring,
53
+ scale_lora_layers,
54
+ unscale_lora_layers,
55
+ )
56
+ from diffusers.utils.torch_utils import randn_tensor
57
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
58
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
59
+
60
+
61
+ if is_invisible_watermark_available():
62
+ from .watermark import StableDiffusionXLWatermarker
63
+
64
+ if is_torch_xla_available():
65
+ import torch_xla.core.xla_model as xm
66
+
67
+ XLA_AVAILABLE = True
68
+ else:
69
+ XLA_AVAILABLE = False
70
+
71
+
72
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
73
+
74
+
75
+ EXAMPLE_DOC_STRING = """
76
+ Examples:
77
+ ```py
78
+ >>> import torch
79
+ >>> from diffusers import StableDiffusionXLInpaintPipeline
80
+ >>> from diffusers.utils import load_image
81
+ >>> pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
82
+ ... "stabilityai/stable-diffusion-xl-base-1.0",
83
+ ... torch_dtype=torch.float16,
84
+ ... variant="fp16",
85
+ ... use_safetensors=True,
86
+ ... )
87
+ >>> pipe.to("cuda")
88
+ >>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
89
+ >>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
90
+ >>> init_image = load_image(img_url).convert("RGB")
91
+ >>> mask_image = load_image(mask_url).convert("RGB")
92
+ >>> prompt = "A majestic tiger sitting on a bench"
93
+ >>> image = pipe(
94
+ ... prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
95
+ ... ).images[0]
96
+ ```
97
+ """
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
101
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
102
+ """
103
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
104
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
105
+ """
106
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
107
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
108
+ # rescale the results from guidance (fixes overexposure)
109
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
110
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
111
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
112
+ return noise_cfg
113
+
114
+
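+ # Illustrative note (not part of the original file): with phi = guidance_rescale, the function above
+ # returns phi * noise_cfg * (std_text / std_cfg) + (1 - phi) * noise_cfg, i.e. it pulls the standard
+ # deviation of the CFG output back toward that of the text-conditioned prediction to reduce overexposure.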
115
+ def mask_pil_to_torch(mask, height, width):
116
+ # preprocess mask
117
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
118
+ mask = [mask]
119
+
120
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
121
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
122
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
123
+ mask = mask.astype(np.float32) / 255.0
124
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
125
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
126
+
127
+ mask = torch.from_numpy(mask)
128
+ return mask
129
+
130
+
131
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False):
132
+ """
133
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
134
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
135
+ ``image`` and ``1`` for the ``mask``.
136
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
137
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
138
+ Args:
139
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
140
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
141
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
142
+ mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
143
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
144
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
145
+ Raises:
146
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
147
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
148
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
149
+ (or the other way around).
150
+ Returns:
151
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
152
+ dimensions: ``batch x channels x height x width``.
153
+ """
154
+
155
+ # TODO (Yiyi): need to clean this up later
156
+ deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
157
+ deprecate(
158
+ "prepare_mask_and_masked_image",
159
+ "0.30.0",
160
+ deprecation_message,
161
+ )
162
+ if image is None:
163
+ raise ValueError("`image` input cannot be undefined.")
164
+
165
+ if mask is None:
166
+ raise ValueError("`mask_image` input cannot be undefined.")
167
+
168
+ if isinstance(image, torch.Tensor):
169
+ if not isinstance(mask, torch.Tensor):
170
+ mask = mask_pil_to_torch(mask, height, width)
171
+
172
+ if image.ndim == 3:
173
+ image = image.unsqueeze(0)
174
+
175
+ # Batch and add channel dim for single mask
176
+ if mask.ndim == 2:
177
+ mask = mask.unsqueeze(0).unsqueeze(0)
178
+
179
+ # Batch single mask or add channel dim
180
+ if mask.ndim == 3:
181
+ # Single batched mask, no channel dim or single mask not batched but channel dim
182
+ if mask.shape[0] == 1:
183
+ mask = mask.unsqueeze(0)
184
+
185
+ # Batched masks no channel dim
186
+ else:
187
+ mask = mask.unsqueeze(1)
188
+
189
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
190
+ # assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
191
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
192
+
193
+ # Check image is in [-1, 1]
194
+ # if image.min() < -1 or image.max() > 1:
195
+ # raise ValueError("Image should be in [-1, 1] range")
196
+
197
+ # Check mask is in [0, 1]
198
+ if mask.min() < 0 or mask.max() > 1:
199
+ raise ValueError("Mask should be in [0, 1] range")
200
+
201
+ # Binarize mask
202
+ mask[mask < 0.5] = 0
203
+ mask[mask >= 0.5] = 1
204
+
205
+ # Image as float32
206
+ image = image.to(dtype=torch.float32)
207
+ elif isinstance(mask, torch.Tensor):
208
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
209
+ else:
210
+ # preprocess image
211
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
212
+ image = [image]
213
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
214
+ # resize all images w.r.t. the passed height and width
215
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
216
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
217
+ image = np.concatenate(image, axis=0)
218
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
219
+ image = np.concatenate([i[None, :] for i in image], axis=0)
220
+
221
+ image = image.transpose(0, 3, 1, 2)
222
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
223
+
224
+ mask = mask_pil_to_torch(mask, height, width)
225
+ mask[mask < 0.5] = 0
226
+ mask[mask >= 0.5] = 1
227
+
228
+ if image.shape[1] == 4:
229
+ # images are in latent space and thus can't
230
+ # be masked, so set masked_image to None
231
+ # we assume that the checkpoint is not an inpainting
232
+ # checkpoint. TODO (Yiyi): need to clean this up later
233
+ masked_image = None
234
+ else:
235
+ masked_image = image * (mask < 0.5)
236
+
237
+ # n.b. ensure backwards compatibility as old function does not return image
238
+ if return_image:
239
+ return mask, masked_image, image
240
+
241
+ return mask, masked_image
242
+
243
+
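+ # Shape sketch (illustrative, not part of the original file): for a single 512x512 RGB PIL image and a
+ # matching PIL mask passed with height=512, width=512, the helper above returns
+ #   mask:         torch.float32 of shape (1, 1, 512, 512), binarized to {0, 1}
+ #   masked_image: torch.float32 of shape (1, 3, 512, 512), scaled to [-1, 1] and zeroed where mask >= 0.5
+ # (plus the preprocessed image itself when return_image=True).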
244
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
245
+ def retrieve_latents(
246
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
247
+ ):
248
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
249
+ return encoder_output.latent_dist.sample(generator)
250
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
251
+ return encoder_output.latent_dist.mode()
252
+ elif hasattr(encoder_output, "latents"):
253
+ return encoder_output.latents
254
+ else:
255
+ raise AttributeError("Could not access latents of provided encoder_output")
256
+
257
+
258
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
259
+ def retrieve_timesteps(
260
+ scheduler,
261
+ num_inference_steps: Optional[int] = None,
262
+ device: Optional[Union[str, torch.device]] = None,
263
+ timesteps: Optional[List[int]] = None,
264
+ sigmas: Optional[List[float]] = None,
265
+ **kwargs,
266
+ ):
267
+ """
268
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
269
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
270
+ Args:
271
+ scheduler (`SchedulerMixin`):
272
+ The scheduler to get timesteps from.
273
+ num_inference_steps (`int`):
274
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
275
+ must be `None`.
276
+ device (`str` or `torch.device`, *optional*):
277
+ The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
278
+ timesteps (`List[int]`, *optional*):
279
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
280
+ `num_inference_steps` and `sigmas` must be `None`.
281
+ sigmas (`List[float]`, *optional*):
282
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
283
+ `num_inference_steps` and `timesteps` must be `None`.
284
+ Returns:
285
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
286
+ second element is the number of inference steps.
287
+ """
288
+ if timesteps is not None and sigmas is not None:
289
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
290
+ if timesteps is not None:
291
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
292
+ if not accepts_timesteps:
293
+ raise ValueError(
294
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
295
+ f" timestep schedules. Please check whether you are using the correct scheduler."
296
+ )
297
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
298
+ timesteps = scheduler.timesteps
299
+ num_inference_steps = len(timesteps)
300
+ elif sigmas is not None:
301
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
302
+ if not accept_sigmas:
303
+ raise ValueError(
304
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
305
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
306
+ )
307
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
308
+ timesteps = scheduler.timesteps
309
+ num_inference_steps = len(timesteps)
310
+ else:
311
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
312
+ timesteps = scheduler.timesteps
313
+ return timesteps, num_inference_steps
314
+
315
+
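+ # Usage sketch (illustrative assumption, not part of the original file): given any diffusers scheduler
+ # instance `scheduler`, the helper above can be called with a step count or, for schedulers whose
+ # `set_timesteps` accepts them, with custom timesteps:
+ #   timesteps, num_inference_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cuda")
+ #   timesteps, num_inference_steps = retrieve_timesteps(scheduler, timesteps=[999, 749, 499, 249])
+ # Passing both `timesteps` and `sigmas` raises a ValueError, as enforced above.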
316
+ class StableDiffusionXLInpaintPipeline(
317
+ DiffusionPipeline,
318
+ StableDiffusionMixin,
319
+ TextualInversionLoaderMixin,
320
+ StableDiffusionXLLoraLoaderMixin,
321
+ FromSingleFileMixin,
322
+ IPAdapterMixin,
323
+ ):
324
+ r"""
325
+ Pipeline for text-guided image inpainting using Stable Diffusion XL.
326
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
327
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
328
+ The pipeline also inherits the following loading methods:
329
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
330
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
331
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
332
+ - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
333
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
334
+ Args:
335
+ vae ([`AutoencoderKL`]):
336
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
337
+ text_encoder ([`CLIPTextModel`]):
338
+ Frozen text-encoder. Stable Diffusion XL uses the text portion of
339
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
340
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
341
+ text_encoder_2 ([`CLIPTextModelWithProjection`]):
342
+ Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
343
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
344
+ specifically the
345
+ [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
346
+ variant.
347
+ tokenizer (`CLIPTokenizer`):
348
+ Tokenizer of class
349
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
350
+ tokenizer_2 (`CLIPTokenizer`):
351
+ Second Tokenizer of class
352
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
353
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
354
+ scheduler ([`SchedulerMixin`]):
355
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
356
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
357
+ requires_aesthetics_score (`bool`, *optional*, defaults to `"False"`):
358
+ Whether the `unet` requires an `aesthetic_score` condition to be passed during inference. Also see the config
359
+ of `stabilityai/stable-diffusion-xl-refiner-1-0`.
360
+ force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
361
+ Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
362
+ `stabilityai/stable-diffusion-xl-base-1-0`.
363
+ add_watermarker (`bool`, *optional*):
364
+ Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
365
+ watermark output images. If not defined, it will default to True if the package is installed, otherwise no
366
+ watermarker will be used.
367
+ """
368
+
369
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->unet->vae"
370
+
371
+ _optional_components = [
372
+ "tokenizer",
373
+ "tokenizer_2",
374
+ "text_encoder",
375
+ "text_encoder_2",
376
+ "image_encoder",
377
+ "feature_extractor",
378
+ ]
379
+ _callback_tensor_inputs = [
380
+ "latents",
381
+ "prompt_embeds",
382
+ "negative_prompt_embeds",
383
+ "add_text_embeds",
384
+ "add_time_ids",
385
+ "negative_pooled_prompt_embeds",
386
+ "add_neg_time_ids",
387
+ "mask",
388
+ "masked_image_latents",
389
+ ]
390
+
391
+ def __init__(
392
+ self,
393
+ vae: AutoencoderKL,
394
+ text_encoder: CLIPTextModel,
395
+ tokenizer: CLIPTokenizer,
396
+ unet: UNet2DConditionModel,
397
+ scheduler: KarrasDiffusionSchedulers,
398
+ tokenizer_2: CLIPTokenizer = None,
399
+ text_encoder_2: CLIPTextModelWithProjection = None,
400
+ image_encoder: CLIPVisionModelWithProjection = None,
401
+ feature_extractor: CLIPImageProcessor = None,
402
+ requires_aesthetics_score: bool = False,
403
+ force_zeros_for_empty_prompt: bool = True,
404
+ add_watermarker: Optional[bool] = None,
405
+ ):
406
+ super().__init__()
407
+
408
+ self.register_modules(
409
+ vae=vae,
410
+ text_encoder=text_encoder,
411
+ text_encoder_2=text_encoder_2,
412
+ tokenizer=tokenizer,
413
+ tokenizer_2=tokenizer_2,
414
+ unet=unet,
415
+ image_encoder=image_encoder,
416
+ feature_extractor=feature_extractor,
417
+ scheduler=scheduler,
418
+ )
419
+ self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
420
+ self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
421
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
422
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
423
+ self.mask_processor = VaeImageProcessor(
424
+ vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
425
+ )
426
+
427
+ add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
428
+
429
+ if add_watermarker:
430
+ self.watermark = StableDiffusionXLWatermarker()
431
+ else:
432
+ self.watermark = None
433
+
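+ # Illustrative note (assumption, not part of the original file): for the standard SDXL VAE, whose config
+ # lists four `block_out_channels`, the `vae_scale_factor` computed in __init__ is 2 ** 3 = 8, so a
+ # 1024x1024 input image corresponds to a 128x128 latent grid.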
434
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
435
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
436
+ dtype = next(self.image_encoder.parameters()).dtype
437
+
438
+ if not isinstance(image, torch.Tensor):
439
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
440
+
441
+ image = image.to(device=device, dtype=dtype)
442
+ if output_hidden_states:
443
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
444
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
445
+ uncond_image_enc_hidden_states = self.image_encoder(
446
+ torch.zeros_like(image), output_hidden_states=True
447
+ ).hidden_states[-2]
448
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
449
+ num_images_per_prompt, dim=0
450
+ )
451
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
452
+ else:
453
+ image_embeds = self.image_encoder(image).image_embeds
454
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
455
+ uncond_image_embeds = torch.zeros_like(image_embeds)
456
+
457
+ return image_embeds, uncond_image_embeds
458
+
459
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
460
+ def prepare_ip_adapter_image_embeds(
461
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
462
+ ):
463
+ if ip_adapter_image_embeds is None:
464
+ if not isinstance(ip_adapter_image, list):
465
+ ip_adapter_image = [ip_adapter_image]
466
+
467
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
468
+ raise ValueError(
469
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
470
+ )
471
+
472
+ image_embeds = []
473
+ for single_ip_adapter_image, image_proj_layer in zip(
474
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
475
+ ):
476
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
477
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
478
+ single_ip_adapter_image, device, 1, output_hidden_state
479
+ )
480
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
481
+ single_negative_image_embeds = torch.stack(
482
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
483
+ )
484
+
485
+ if do_classifier_free_guidance:
486
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
487
+ single_image_embeds = single_image_embeds.to(device)
488
+
489
+ image_embeds.append(single_image_embeds)
490
+ else:
491
+ repeat_dims = [1]
492
+ image_embeds = []
493
+ for single_image_embeds in ip_adapter_image_embeds:
494
+ if do_classifier_free_guidance:
495
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
496
+ single_image_embeds = single_image_embeds.repeat(
497
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
498
+ )
499
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
500
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
501
+ )
502
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
503
+ else:
504
+ single_image_embeds = single_image_embeds.repeat(
505
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
506
+ )
507
+ image_embeds.append(single_image_embeds)
508
+
509
+ return image_embeds
510
+
511
+ def encode_prompt(
512
+ self,
513
+ prompt,
514
+ device: Optional[torch.device] = None,
515
+ num_images_per_prompt: int = 1,
516
+ do_classifier_free_guidance: bool = True,
517
+ negative_prompt=None,
518
+ prompt_embeds: Optional[torch.FloatTensor] = None,
519
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
520
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
521
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
522
+ lora_scale: Optional[float] = None,
523
+ ):
524
+ r"""
525
+ Encodes the prompt into text encoder hidden states.
526
+ Args:
527
+ prompt (`str` or `List[str]`, *optional*):
528
+ prompt to be encoded
529
+ device: (`torch.device`):
530
+ torch device
531
+ num_images_per_prompt (`int`):
532
+ number of images that should be generated per prompt
533
+ do_classifier_free_guidance (`bool`):
534
+ whether to use classifier free guidance or not
535
+ negative_prompt (`str` or `List[str]`, *optional*):
536
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
537
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
538
+ less than `1`).
539
+ prompt_embeds (`torch.FloatTensor`, *optional*):
540
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
541
+ provided, text embeddings will be generated from `prompt` input argument.
542
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
543
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
544
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
545
+ argument.
546
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
547
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
548
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
549
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
550
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
551
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
552
+ input argument.
553
+ lora_scale (`float`, *optional*):
554
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
555
+ """
556
+ # from IPython import embed; embed(); exit()
557
+ device = device or self._execution_device
558
+
559
+ # set lora scale so that monkey patched LoRA
560
+ # function of text encoder can correctly access it
561
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
562
+ self._lora_scale = lora_scale
563
+
564
+ if prompt is not None and isinstance(prompt, str):
565
+ batch_size = 1
566
+ elif prompt is not None and isinstance(prompt, list):
567
+ batch_size = len(prompt)
568
+ else:
569
+ batch_size = prompt_embeds.shape[0]
570
+
571
+ # Define tokenizers and text encoders
572
+ tokenizers = [self.tokenizer]
573
+ text_encoders = [self.text_encoder]
574
+
575
+ if prompt_embeds is None:
576
+ # textual inversion: process multi-vector tokens if necessary
577
+ prompt_embeds_list = []
578
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
579
+ if isinstance(self, TextualInversionLoaderMixin):
580
+ prompt = self.maybe_convert_prompt(prompt, tokenizer)
581
+
582
+ text_inputs = tokenizer(
583
+ prompt,
584
+ padding="max_length",
585
+ max_length=256,
586
+ truncation=True,
587
+ return_tensors="pt",
588
+ ).to(device)
589
+ output = text_encoder(
590
+ input_ids=text_inputs['input_ids'],
591
+ attention_mask=text_inputs['attention_mask'],
592
+ position_ids=text_inputs['position_ids'],
593
+ output_hidden_states=True)
594
+ prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
595
+ text_proj = output.hidden_states[-1][-1, :, :].clone()
596
+ bs_embed, seq_len, _ = prompt_embeds.shape
597
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
598
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
599
+ prompt_embeds_list.append(prompt_embeds)
600
+
601
+ prompt_embeds = prompt_embeds_list[0]
602
+
603
+ # get unconditional embeddings for classifier free guidance
604
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
605
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
606
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
607
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
608
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
609
+ # negative_prompt = negative_prompt or ""
610
+ uncond_tokens: List[str]
611
+ if negative_prompt is None:
612
+ uncond_tokens = [""] * batch_size
613
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
614
+ raise TypeError(
615
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
616
+ f" {type(prompt)}."
617
+ )
618
+ elif isinstance(negative_prompt, str):
619
+ uncond_tokens = [negative_prompt]
620
+ elif batch_size != len(negative_prompt):
621
+ raise ValueError(
622
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
623
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
624
+ " the batch size of `prompt`."
625
+ )
626
+ else:
627
+ uncond_tokens = negative_prompt
628
+
629
+ negative_prompt_embeds_list = []
630
+ for tokenizer, text_encoder in zip(tokenizers, text_encoders):
631
+ # textual inversion: process multi-vector tokens if necessary
632
+ if isinstance(self, TextualInversionLoaderMixin):
633
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
634
+
635
+ max_length = prompt_embeds.shape[1]
636
+ uncond_input = tokenizer(
637
+ uncond_tokens,
638
+ padding="max_length",
639
+ max_length=max_length,
640
+ truncation=True,
641
+ return_tensors="pt",
642
+ ).to(device)
643
+ output = text_encoder(
644
+ input_ids=uncond_input['input_ids'],
645
+ attention_mask=uncond_input['attention_mask'],
646
+ position_ids=uncond_input['position_ids'],
647
+ output_hidden_states=True)
648
+ negative_prompt_embeds = output.hidden_states[-2].permute(1, 0, 2).clone()
649
+ negative_text_proj = output.hidden_states[-1][-1, :, :].clone()
650
+
651
+ if do_classifier_free_guidance:
652
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
653
+ seq_len = negative_prompt_embeds.shape[1]
654
+
655
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
656
+
657
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
658
+ negative_prompt_embeds = negative_prompt_embeds.view(
659
+ batch_size * num_images_per_prompt, seq_len, -1
660
+ )
661
+
662
+ # For classifier free guidance, we need to do two forward passes.
663
+ # Here we concatenate the unconditional and text embeddings into a single batch
664
+ # to avoid doing two forward passes
665
+
666
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
667
+
668
+ negative_prompt_embeds = negative_prompt_embeds_list[0]
669
+
670
+ bs_embed = text_proj.shape[0]
671
+ text_proj = text_proj.repeat(1, num_images_per_prompt).view(
672
+ bs_embed * num_images_per_prompt, -1
673
+ )
674
+ negative_text_proj = negative_text_proj.repeat(1, num_images_per_prompt).view(
675
+ bs_embed * num_images_per_prompt, -1
676
+ )
677
+
678
+ return prompt_embeds, negative_prompt_embeds, text_proj, negative_text_proj
679
+
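+ # Descriptive note (not part of the original file): unlike the stock SDXL encode_prompt, the method above
+ # uses a single ChatGLM-style text encoder with a 256-token context (max_length=256, explicit position_ids),
+ # takes hidden_states[-2] (permuted to batch-first) as the token embeddings, and uses the final token of the
+ # last hidden state as the pooled projection, returning
+ # (prompt_embeds, negative_prompt_embeds, text_proj, negative_text_proj).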
680
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
681
+ def prepare_extra_step_kwargs(self, generator, eta):
682
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
683
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
684
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
685
+ # and should be between [0, 1]
686
+
687
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
688
+ extra_step_kwargs = {}
689
+ if accepts_eta:
690
+ extra_step_kwargs["eta"] = eta
691
+
692
+ # check if the scheduler accepts generator
693
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
694
+ if accepts_generator:
695
+ extra_step_kwargs["generator"] = generator
696
+ return extra_step_kwargs
697
+
698
+ def check_inputs(
699
+ self,
700
+ prompt,
701
+ prompt_2,
702
+ image,
703
+ mask_image,
704
+ height,
705
+ width,
706
+ strength,
707
+ callback_steps,
708
+ output_type,
709
+ negative_prompt=None,
710
+ negative_prompt_2=None,
711
+ prompt_embeds=None,
712
+ negative_prompt_embeds=None,
713
+ ip_adapter_image=None,
714
+ ip_adapter_image_embeds=None,
715
+ callback_on_step_end_tensor_inputs=None,
716
+ padding_mask_crop=None,
717
+ ):
718
+ if strength < 0 or strength > 1:
719
+ raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
720
+
721
+ if height % 8 != 0 or width % 8 != 0:
722
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
723
+
724
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
725
+ raise ValueError(
726
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
727
+ f" {type(callback_steps)}."
728
+ )
729
+
730
+ if callback_on_step_end_tensor_inputs is not None and not all(
731
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
732
+ ):
733
+ raise ValueError(
734
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
735
+ )
736
+
737
+ if prompt is not None and prompt_embeds is not None:
738
+ raise ValueError(
739
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
740
+ " only forward one of the two."
741
+ )
742
+ elif prompt_2 is not None and prompt_embeds is not None:
743
+ raise ValueError(
744
+ f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
745
+ " only forward one of the two."
746
+ )
747
+ elif prompt is None and prompt_embeds is None:
748
+ raise ValueError(
749
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
750
+ )
751
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
752
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
753
+ elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
754
+ raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
755
+
756
+ if negative_prompt is not None and negative_prompt_embeds is not None:
757
+ raise ValueError(
758
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
759
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
760
+ )
761
+ elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
762
+ raise ValueError(
763
+ f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
764
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
765
+ )
766
+
767
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
768
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
769
+ raise ValueError(
770
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
771
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
772
+ f" {negative_prompt_embeds.shape}."
773
+ )
774
+ if padding_mask_crop is not None:
775
+ if not isinstance(image, PIL.Image.Image):
776
+ raise ValueError(
777
+ f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
778
+ )
779
+ if not isinstance(mask_image, PIL.Image.Image):
780
+ raise ValueError(
781
+ f"The mask image should be a PIL image when inpainting mask crop, but is of type"
782
+ f" {type(mask_image)}."
783
+ )
784
+ if output_type != "pil":
785
+ raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")
786
+
787
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
788
+ raise ValueError(
789
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
790
+ )
791
+
792
+ if ip_adapter_image_embeds is not None:
793
+ if not isinstance(ip_adapter_image_embeds, list):
794
+ raise ValueError(
795
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
796
+ )
797
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
798
+ raise ValueError(
799
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
800
+ )
801
+
802
+ def prepare_latents(
803
+ self,
804
+ batch_size,
805
+ num_channels_latents,
806
+ height,
807
+ width,
808
+ dtype,
809
+ device,
810
+ generator,
811
+ latents=None,
812
+ image=None,
813
+ timestep=None,
814
+ is_strength_max=True,
815
+ add_noise=True,
816
+ return_noise=False,
817
+ return_image_latents=False,
818
+ ):
819
+ shape = (
820
+ batch_size,
821
+ num_channels_latents,
822
+ int(height) // self.vae_scale_factor,
823
+ int(width) // self.vae_scale_factor,
824
+ )
825
+ if isinstance(generator, list) and len(generator) != batch_size:
826
+ raise ValueError(
827
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
828
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
829
+ )
830
+
831
+ if (image is None or timestep is None) and not is_strength_max:
832
+ raise ValueError(
833
+ "Since strength < 1, initial latents are to be initialised as a combination of Image + Noise. "
834
+ "However, either the image or the noise timestep has not been provided."
835
+ )
836
+
837
+ if image.shape[1] == 4:
838
+ image_latents = image.to(device=device, dtype=dtype)
839
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
840
+ elif return_image_latents or (latents is None and not is_strength_max):
841
+ image = image.to(device=device, dtype=dtype)
842
+ image_latents = self._encode_vae_image(image=image, generator=generator)
843
+ image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1)
844
+
845
+ if latents is None and add_noise:
846
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
847
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
848
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
849
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
850
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
851
+ elif add_noise:
852
+ noise = latents.to(device)
853
+ latents = noise * self.scheduler.init_noise_sigma
854
+ else:
855
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
856
+ latents = image_latents.to(device)
857
+
858
+ outputs = (latents,)
859
+
860
+ if return_noise:
861
+ outputs += (noise,)
862
+
863
+ if return_image_latents:
864
+ outputs += (image_latents,)
865
+
866
+ return outputs
867
+
868
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
869
+ dtype = image.dtype
870
+ if self.vae.config.force_upcast:
871
+ image = image.float()
872
+ self.vae.to(dtype=torch.float32)
873
+
874
+ if isinstance(generator, list):
875
+ image_latents = [
876
+ retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
877
+ for i in range(image.shape[0])
878
+ ]
879
+ image_latents = torch.cat(image_latents, dim=0)
880
+ else:
881
+ image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
882
+
883
+ if self.vae.config.force_upcast:
884
+ self.vae.to(dtype)
885
+
886
+ image_latents = image_latents.to(dtype)
887
+ image_latents = self.vae.config.scaling_factor * image_latents
888
+
889
+ return image_latents
890
+
891
+ def prepare_mask_latents(
892
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
893
+ ):
894
+ # resize the mask to latents shape as we concatenate the mask to the latents
895
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
896
+ # and half precision
897
+ mask = torch.nn.functional.interpolate(
898
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
899
+ )
900
+ mask = mask.to(device=device, dtype=dtype)
901
+
902
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
903
+ if mask.shape[0] < batch_size:
904
+ if not batch_size % mask.shape[0] == 0:
905
+ raise ValueError(
906
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
907
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
908
+ " of masks that you pass is divisible by the total requested batch size."
909
+ )
910
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
911
+
912
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
913
+
914
+ if masked_image is not None and masked_image.shape[1] == 4:
915
+ masked_image_latents = masked_image
916
+ else:
917
+ masked_image_latents = None
918
+
919
+ if masked_image is not None:
920
+ if masked_image_latents is None:
921
+ masked_image = masked_image.to(device=device, dtype=dtype)
922
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
923
+
924
+ if masked_image_latents.shape[0] < batch_size:
925
+ if not batch_size % masked_image_latents.shape[0] == 0:
926
+ raise ValueError(
927
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
928
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
929
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
930
+ )
931
+ masked_image_latents = masked_image_latents.repeat(
932
+ batch_size // masked_image_latents.shape[0], 1, 1, 1
933
+ )
934
+
935
+ masked_image_latents = (
936
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
937
+ )
938
+
939
+ # aligning device to prevent device errors when concatenating it with the latent model input
940
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
941
+
942
+ return mask, masked_image_latents
943
+
944
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps
945
+ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None):
946
+ # get the original timestep using init_timestep
947
+ if denoising_start is None:
948
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
949
+ t_start = max(num_inference_steps - init_timestep, 0)
950
+ else:
951
+ t_start = 0
952
+
953
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
954
+
955
+ # Strength is irrelevant if we directly request a timestep to start at;
956
+ # that is, strength is determined by the denoising_start instead.
957
+ if denoising_start is not None:
958
+ discrete_timestep_cutoff = int(
959
+ round(
960
+ self.scheduler.config.num_train_timesteps
961
+ - (denoising_start * self.scheduler.config.num_train_timesteps)
962
+ )
963
+ )
964
+
965
+ num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
966
+ if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
967
+ # if the scheduler is a 2nd order scheduler we might have to do +1
968
+ # because `num_inference_steps` might be even given that every timestep
969
+ # (except the highest one) is duplicated. If `num_inference_steps` is even it would
970
+ # mean that we cut the timesteps in the middle of the denoising step
971
+ # (between 1st and 2nd derivative) which leads to incorrect results. By adding 1
972
+ # we ensure that the denoising process always ends after the 2nd derivative step of the scheduler
973
+ num_inference_steps = num_inference_steps + 1
974
+
975
+ # because t_n+1 >= t_n, we slice the timesteps starting from the end
976
+ timesteps = timesteps[-num_inference_steps:]
977
+ return timesteps, num_inference_steps
978
+
979
+ return timesteps, num_inference_steps - t_start
980
+
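+ # Worked example (illustrative, assuming a first-order scheduler and denoising_start=None): with
+ # num_inference_steps=50 and strength=0.6, init_timestep = min(int(50 * 0.6), 50) = 30 and
+ # t_start = 50 - 30 = 20, so the method above returns the last 30 scheduler timesteps and
+ # num_inference_steps - t_start = 30.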
981
+ # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline._get_add_time_ids
982
+ def _get_add_time_ids(
983
+ self,
984
+ original_size,
985
+ crops_coords_top_left,
986
+ target_size,
987
+ aesthetic_score,
988
+ negative_aesthetic_score,
989
+ negative_original_size,
990
+ negative_crops_coords_top_left,
991
+ negative_target_size,
992
+ dtype,
993
+ text_encoder_projection_dim=None,
994
+ ):
995
+ if self.config.requires_aesthetics_score:
996
+ add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
997
+ add_neg_time_ids = list(
998
+ negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
999
+ )
1000
+ else:
1001
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1002
+ add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size)
1003
+
1004
+ passed_add_embed_dim = (
1005
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + 4096
1006
+ )
1007
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
1008
+
1009
+ if (
1010
+ expected_add_embed_dim > passed_add_embed_dim
1011
+ and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
1012
+ ):
1013
+ raise ValueError(
1014
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
1015
+ )
1016
+ elif (
1017
+ expected_add_embed_dim < passed_add_embed_dim
1018
+ and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
1019
+ ):
1020
+ raise ValueError(
1021
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
1022
+ )
1023
+ elif expected_add_embed_dim != passed_add_embed_dim:
1024
+ raise ValueError(
1025
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
1026
+ )
1027
+
1028
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
1029
+ add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
1030
+
1031
+ return add_time_ids, add_neg_time_ids
1032
+
1033
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae
1034
+ def upcast_vae(self):
1035
+ dtype = self.vae.dtype
1036
+ self.vae.to(dtype=torch.float32)
1037
+ use_torch_2_0_or_xformers = isinstance(
1038
+ self.vae.decoder.mid_block.attentions[0].processor,
1039
+ (
1040
+ AttnProcessor2_0,
1041
+ XFormersAttnProcessor,
1042
+ LoRAXFormersAttnProcessor,
1043
+ LoRAAttnProcessor2_0,
1044
+ ),
1045
+ )
1046
+ # if xformers or torch_2_0 is used attention block does not need
1047
+ # to be in float32 which can save lots of memory
1048
+ if use_torch_2_0_or_xformers:
1049
+ self.vae.post_quant_conv.to(dtype)
1050
+ self.vae.decoder.conv_in.to(dtype)
1051
+ self.vae.decoder.mid_block.to(dtype)
1052
+
1053
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1054
+ def get_guidance_scale_embedding(
1055
+ self, w: torch.Tensor, embedding_dim: int = 512, dtype: torch.dtype = torch.float32
1056
+ ) -> torch.Tensor:
1057
+ """
1058
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1059
+ Args:
1060
+ w (`torch.Tensor`):
1061
+ Generate embedding vectors with a specified guidance scale to subsequently enrich timestep embeddings.
1062
+ embedding_dim (`int`, *optional*, defaults to 512):
1063
+ Dimension of the embeddings to generate.
1064
+ dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
1065
+ Data type of the generated embeddings.
1066
+ Returns:
1067
+ `torch.Tensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
1068
+ """
1069
+ assert len(w.shape) == 1
1070
+ w = w * 1000.0
1071
+
1072
+ half_dim = embedding_dim // 2
1073
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1074
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1075
+ emb = w.to(dtype)[:, None] * emb[None, :]
1076
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1077
+ if embedding_dim % 2 == 1: # zero pad
1078
+ emb = torch.nn.functional.pad(emb, (0, 1))
1079
+ assert emb.shape == (w.shape[0], embedding_dim)
1080
+ return emb
1081
+
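+ # Illustrative example (not part of the original file): for w = torch.tensor([7.5]) and embedding_dim=512,
+ # the method above scales w to 7500.0 and returns a (1, 512) tensor whose first 256 entries are sin terms
+ # and last 256 entries are cos terms of 7500.0 multiplied by log-spaced frequencies.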
1082
+ @property
1083
+ def guidance_scale(self):
1084
+ return self._guidance_scale
1085
+
1086
+ @property
1087
+ def guidance_rescale(self):
1088
+ return self._guidance_rescale
1089
+
1090
+ @property
1091
+ def clip_skip(self):
1092
+ return self._clip_skip
1093
+
1094
+ # here `guidance_scale` is defined analogous to the guidance weight `w` of equation (2)
1095
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1096
+ # corresponds to doing no classifier free guidance.
1097
+ @property
1098
+ def do_classifier_free_guidance(self):
1099
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1100
+
1101
+ @property
1102
+ def cross_attention_kwargs(self):
1103
+ return self._cross_attention_kwargs
1104
+
1105
+ @property
1106
+ def denoising_end(self):
1107
+ return self._denoising_end
1108
+
1109
+ @property
1110
+ def denoising_start(self):
1111
+ return self._denoising_start
1112
+
1113
+ @property
1114
+ def num_timesteps(self):
1115
+ return self._num_timesteps
1116
+
1117
+ @property
1118
+ def interrupt(self):
1119
+ return self._interrupt
1120
+
1121
+ @torch.no_grad()
1122
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
1123
+ def __call__(
1124
+ self,
1125
+ prompt: Union[str, List[str]] = None,
1126
+ prompt_2: Optional[Union[str, List[str]]] = None,
1127
+ image: PipelineImageInput = None,
1128
+ mask_image: PipelineImageInput = None,
1129
+ masked_image_latents: torch.Tensor = None,
1130
+ height: Optional[int] = None,
1131
+ width: Optional[int] = None,
1132
+ padding_mask_crop: Optional[int] = None,
1133
+ strength: float = 0.9999,
1134
+ num_inference_steps: int = 50,
1135
+ timesteps: List[int] = None,
1136
+ sigmas: List[float] = None,
1137
+ denoising_start: Optional[float] = None,
1138
+ denoising_end: Optional[float] = None,
1139
+ guidance_scale: float = 7.5,
1140
+ negative_prompt: Optional[Union[str, List[str]]] = None,
1141
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
1142
+ num_images_per_prompt: Optional[int] = 1,
1143
+ eta: float = 0.0,
1144
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
1145
+ latents: Optional[torch.Tensor] = None,
1146
+ prompt_embeds: Optional[torch.Tensor] = None,
1147
+ negative_prompt_embeds: Optional[torch.Tensor] = None,
1148
+ pooled_prompt_embeds: Optional[torch.Tensor] = None,
1149
+ negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
1150
+ ip_adapter_image: Optional[PipelineImageInput] = None,
1151
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
1152
+ output_type: Optional[str] = "pil",
1153
+ return_dict: bool = True,
1154
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1155
+ guidance_rescale: float = 0.0,
1156
+ original_size: Tuple[int, int] = None,
1157
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
1158
+ target_size: Tuple[int, int] = None,
1159
+ negative_original_size: Optional[Tuple[int, int]] = None,
1160
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
1161
+ negative_target_size: Optional[Tuple[int, int]] = None,
1162
+ aesthetic_score: float = 6.0,
1163
+ negative_aesthetic_score: float = 2.5,
1164
+ clip_skip: Optional[int] = None,
1165
+ callback_on_step_end: Optional[
1166
+ Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
1167
+ ] = None,
1168
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
1169
+ **kwargs,
1170
+ ):
1171
+ r"""
1172
+ Function invoked when calling the pipeline for generation.
1173
+ Args:
1174
+ prompt (`str` or `List[str]`, *optional*):
1175
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1176
+ instead.
1177
+ prompt_2 (`str` or `List[str]`, *optional*):
1178
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
1179
+ used in both text-encoders.
1180
+ image (`PIL.Image.Image`):
1181
+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
1182
+ be masked out with `mask_image` and repainted according to `prompt`.
1183
+ mask_image (`PIL.Image.Image`):
1184
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
1185
+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
1186
+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
1187
+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
1188
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1189
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
1190
+ Anything below 512 pixels won't work well for
1191
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1192
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1193
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1194
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
1195
+ Anything below 512 pixels won't work well for
1196
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
1197
+ and checkpoints that are not specifically fine-tuned on low resolutions.
1198
+ padding_mask_crop (`int`, *optional*, defaults to `None`):
1199
+ The size of margin in the crop to be applied to the image and masking. If `None`, no crop is applied to
1200
+ image and mask_image. If `padding_mask_crop` is not `None`, it will first find a rectangular region
1201
+ with the same aspect ratio as the image that contains all of the masked area, and then expand that area based
1202
+ on `padding_mask_crop`. The image and mask_image will then be cropped based on the expanded area before
1203
+ resizing to the original image size for inpainting. This is useful when the masked area is small while
1204
+ the image is large and contains information irrelevant for inpainting, such as background.
1205
+ strength (`float`, *optional*, defaults to 0.9999):
1206
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1207
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1208
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1209
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1210
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1211
+ portion of the reference `image`. Note that in the case of `denoising_start` being declared as an
1212
+ integer, the value of `strength` will be ignored.
1213
+ num_inference_steps (`int`, *optional*, defaults to 50):
1214
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1215
+ expense of slower inference.
1216
+ timesteps (`List[int]`, *optional*):
1217
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
1218
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
1219
+ passed will be used. Must be in descending order.
1220
+ sigmas (`List[float]`, *optional*):
1221
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
1222
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
1223
+ will be used.
1224
+ denoising_start (`float`, *optional*):
1225
+ When specified, indicates the fraction (between 0.0 and 1.0) of the total denoising process to be
1226
+ bypassed before it is initiated. Consequently, the initial part of the denoising process is skipped and
1227
+ it is assumed that the passed `image` is a partly denoised image. Note that when this is specified,
1228
+ strength will be ignored. The `denoising_start` parameter is particularly beneficial when this pipeline
1229
+ is integrated into a "Mixture of Denoisers" multi-pipeline setup, as detailed in [**Refining the Image
1230
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1231
+ denoising_end (`float`, *optional*):
1232
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
1233
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
1234
+ still retain a substantial amount of noise (ca. final 20% of timesteps still needed) and should be
1235
+ denoised by a successor pipeline that has `denoising_start` set to 0.8 so that it only denoises the
1236
+ final 20% of the scheduler. The denoising_end parameter should ideally be utilized when this pipeline
1237
+ forms a part of a "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
1238
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output).
1239
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1240
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1241
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1242
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1243
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1244
+ usually at the expense of lower image quality.
1245
+ negative_prompt (`str` or `List[str]`, *optional*):
1246
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1247
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1248
+ less than `1`).
1249
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
1250
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
1251
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
1252
+ prompt_embeds (`torch.Tensor`, *optional*):
1253
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1254
+ provided, text embeddings will be generated from `prompt` input argument.
1255
+ negative_prompt_embeds (`torch.Tensor`, *optional*):
1256
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1257
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1258
+ argument.
1259
+ pooled_prompt_embeds (`torch.Tensor`, *optional*):
1260
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
1261
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
1262
+ negative_pooled_prompt_embeds (`torch.Tensor`, *optional*):
1263
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1264
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
1265
+ input argument.
1266
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
1267
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
1268
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the number of
1269
+ IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should
1270
+ contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not
1271
+ provided, embeddings are computed from the `ip_adapter_image` input argument.
1272
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1273
+ The number of images to generate per prompt.
1274
+ eta (`float`, *optional*, defaults to 0.0):
1275
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1276
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1277
+ generator (`torch.Generator`, *optional*):
1278
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1279
+ to make generation deterministic.
1280
+ latents (`torch.Tensor`, *optional*):
1281
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1282
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1283
+ tensor will be generated by sampling using the supplied random `generator`.
1284
+ output_type (`str`, *optional*, defaults to `"pil"`):
1285
+ The output format of the generated image. Choose between
1286
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1287
+ return_dict (`bool`, *optional*, defaults to `True`):
1288
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1289
+ plain tuple.
1290
+ cross_attention_kwargs (`dict`, *optional*):
1291
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1292
+ `self.processor` in
1293
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
1294
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1295
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
1296
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
1297
+ explained in section 2.2 of
1298
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1299
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1300
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
1301
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
1302
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
1303
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1304
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1305
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
1306
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
1307
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1308
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1309
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
1310
+ micro-conditioning as explained in section 2.2 of
1311
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1312
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1313
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
1314
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
1315
+ micro-conditioning as explained in section 2.2 of
1316
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1317
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1318
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
1319
+ To negatively condition the generation process based on a target image resolution. It should be the same
1320
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
1321
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
1322
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
1323
+ aesthetic_score (`float`, *optional*, defaults to 6.0):
1324
+ Used to simulate an aesthetic score of the generated image by influencing the positive text condition.
1325
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1326
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
1327
+ negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
1328
+ Part of SDXL's micro-conditioning as explained in section 2.2 of
1329
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). Can be used to
1330
+ simulate an aesthetic score of the generated image by influencing the negative text condition.
1331
+ clip_skip (`int`, *optional*):
1332
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1333
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1334
+ callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*):
1335
+ A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of
1336
+ each denoising step during inference, with the following arguments: `callback_on_step_end(self:
1337
+ DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a
1338
+ list of all tensors as specified by `callback_on_step_end_tensor_inputs`.
1339
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
1340
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
1341
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
1342
+ `._callback_tensor_inputs` attribute of your pipeline class.
1343
+ Examples:
1344
+ Returns:
1345
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
1346
+ [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
1347
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
1348
+ """
1349
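+ # Illustrative call sketch (comments only, not executed). The pipeline class name, checkpoint and
+ # input images are placeholders; EXAMPLE_DOC_STRING remains the maintained example.
+ #   pipe = SomeSDXLInpaintPipeline.from_pretrained("...")   # hypothetical setup
+ #   result = pipe(
+ #       prompt="a tabby cat sitting on a park bench",
+ #       image=init_image,        # PIL image or tensor to repaint
+ #       mask_image=mask,         # white pixels are repainted, black pixels are kept
+ #       strength=0.99,           # close to 1.0 effectively ignores the masked content
+ #       guidance_scale=7.5,
+ #       num_inference_steps=50,
+ #   ).images[0]
+ #   result.save("inpainted.png")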
+
1350
+ callback = kwargs.pop("callback", None)
1351
+ callback_steps = kwargs.pop("callback_steps", None)
1352
+
1353
+ if callback is not None:
1354
+ deprecate(
1355
+ "callback",
1356
+ "1.0.0",
1357
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1358
+ )
1359
+ if callback_steps is not None:
1360
+ deprecate(
1361
+ "callback_steps",
1362
+ "1.0.0",
1363
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
1364
+ )
1365
+
1366
+ if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
1367
+ callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
1368
+
1369
+ # 0. Default height and width to unet
1370
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
1371
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
1372
+
1373
+ # 1. Check inputs
1374
+ self.check_inputs(
1375
+ prompt,
1376
+ prompt_2,
1377
+ image,
1378
+ mask_image,
1379
+ height,
1380
+ width,
1381
+ strength,
1382
+ callback_steps,
1383
+ output_type,
1384
+ negative_prompt,
1385
+ negative_prompt_2,
1386
+ prompt_embeds,
1387
+ negative_prompt_embeds,
1388
+ ip_adapter_image,
1389
+ ip_adapter_image_embeds,
1390
+ callback_on_step_end_tensor_inputs,
1391
+ padding_mask_crop,
1392
+ )
1393
+
1394
+ self._guidance_scale = guidance_scale
1395
+ self._guidance_rescale = guidance_rescale
1396
+ self._clip_skip = clip_skip
1397
+ self._cross_attention_kwargs = cross_attention_kwargs
1398
+ self._denoising_end = denoising_end
1399
+ self._denoising_start = denoising_start
1400
+ self._interrupt = False
1401
+
1402
+ # 2. Define call parameters
1403
+ if prompt is not None and isinstance(prompt, str):
1404
+ batch_size = 1
1405
+ elif prompt is not None and isinstance(prompt, list):
1406
+ batch_size = len(prompt)
1407
+ else:
1408
+ batch_size = prompt_embeds.shape[0]
1409
+
1410
+ device = self._execution_device
1411
+
1412
+ # 3. Encode input prompt
1413
+ text_encoder_lora_scale = (
1414
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1415
+ )
1416
+
1417
+ (
1418
+ prompt_embeds,
1419
+ negative_prompt_embeds,
1420
+ pooled_prompt_embeds,
1421
+ negative_pooled_prompt_embeds,
1422
+ ) = self.encode_prompt(
1423
+ prompt=prompt,
1424
+ device=device,
1425
+ num_images_per_prompt=num_images_per_prompt,
1426
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
1427
+ negative_prompt=negative_prompt,
1428
+ prompt_embeds=prompt_embeds,
1429
+ negative_prompt_embeds=negative_prompt_embeds,
1430
+ pooled_prompt_embeds=pooled_prompt_embeds,
1431
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1432
+ lora_scale=text_encoder_lora_scale,
1433
+ )
1434
+
1435
+ # 4. set timesteps
1436
+ def denoising_value_valid(dnv):
1437
+ return isinstance(dnv, float) and 0 < dnv < 1
1438
+
1439
+ timesteps, num_inference_steps = retrieve_timesteps(
1440
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
1441
+ )
1442
+ timesteps, num_inference_steps = self.get_timesteps(
1443
+ num_inference_steps,
1444
+ strength,
1445
+ device,
1446
+ denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
1447
+ )
1448
+ # check that number of inference steps is not < 1 - as this doesn't make sense
1449
+ if num_inference_steps < 1:
1450
+ raise ValueError(
1451
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1452
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1453
+ )
1454
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1455
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1456
+ # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
1457
+ is_strength_max = strength == 1.0
1458
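+ # Worked example (illustrative numbers): with num_inference_steps=50 and strength=0.6,
+ # get_timesteps keeps roughly the last 30 scheduler steps, latent_timestep is the first of the
+ # retained timesteps, and is_strength_max is False, so the initial latents start from a noised
+ # copy of the encoded `image` rather than from pure noise.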
+
1459
+ # 5. Preprocess mask and image
1460
+ if padding_mask_crop is not None:
1461
+ crops_coords = self.mask_processor.get_crop_region(mask_image, width, height, pad=padding_mask_crop)
1462
+ resize_mode = "fill"
1463
+ else:
1464
+ crops_coords = None
1465
+ resize_mode = "default"
1466
+
1467
+ original_image = image
1468
+ init_image = self.image_processor.preprocess(
1469
+ image, height=height, width=width, crops_coords=crops_coords, resize_mode=resize_mode
1470
+ )
1471
+ init_image = init_image.to(dtype=torch.float32)
1472
+
1473
+ mask = self.mask_processor.preprocess(
1474
+ mask_image, height=height, width=width, resize_mode=resize_mode, crops_coords=crops_coords
1475
+ )
1476
+
1477
+ if masked_image_latents is not None:
1478
+ masked_image = masked_image_latents
1479
+ elif init_image.shape[1] == 4:
1480
+ # if images are in latent space, we can't mask it
1481
+ masked_image = None
1482
+ else:
1483
+ masked_image = init_image * (mask < 0.5)
1484
+
1485
+ # 6. Prepare latent variables
1486
+ num_channels_latents = self.vae.config.latent_channels
1487
+ num_channels_unet = self.unet.config.in_channels
1488
+ return_image_latents = num_channels_unet == 4
1489
+
1490
+ add_noise = self.denoising_start is None
1491
+ latents_outputs = self.prepare_latents(
1492
+ batch_size * num_images_per_prompt,
1493
+ num_channels_latents,
1494
+ height,
1495
+ width,
1496
+ prompt_embeds.dtype,
1497
+ device,
1498
+ generator,
1499
+ latents,
1500
+ image=init_image,
1501
+ timestep=latent_timestep,
1502
+ is_strength_max=is_strength_max,
1503
+ add_noise=add_noise,
1504
+ return_noise=True,
1505
+ return_image_latents=return_image_latents,
1506
+ )
1507
+
1508
+ if return_image_latents:
1509
+ latents, noise, image_latents = latents_outputs
1510
+ else:
1511
+ latents, noise = latents_outputs
1512
+
1513
+ # 7. Prepare mask latent variables
1514
+ mask, masked_image_latents = self.prepare_mask_latents(
1515
+ mask,
1516
+ masked_image,
1517
+ batch_size * num_images_per_prompt,
1518
+ height,
1519
+ width,
1520
+ prompt_embeds.dtype,
1521
+ device,
1522
+ generator,
1523
+ self.do_classifier_free_guidance,
1524
+ )
1525
+
1526
+ # 8. Check that sizes of mask, masked image and latents match
1527
+ if num_channels_unet == 9:
1528
+ # default case for 9-channel inpainting UNets (e.g. dedicated SD/SDXL inpainting checkpoints)
1529
+ num_channels_mask = mask.shape[1]
1530
+ num_channels_masked_image = masked_image_latents.shape[1]
1531
+ if num_channels_latents + num_channels_mask + num_channels_masked_image != self.unet.config.in_channels:
1532
+ raise ValueError(
1533
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
1534
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
1535
+ f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
1536
+ f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
1537
+ " `pipeline.unet` or your `mask_image` or `image` input."
1538
+ )
1539
+ elif num_channels_unet != 4:
1540
+ raise ValueError(
1541
+ f"The unet {self.unet.__class__} should have either 4 or 9 input channels, not {self.unet.config.in_channels}."
1542
+ )
1543
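+ # For reference (assuming the standard SDXL latent layout): a 9-channel inpainting UNet receives
+ # 4 latent channels + 1 mask channel + 4 masked-image latent channels concatenated along dim=1,
+ # whereas a 4-channel UNet relies on the explicit latent blending performed later in the loop.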
+ # 8.1 Prepare extra step kwargs.
1544
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1545
+
1546
+ # 9. Recompute the output height and width from the latent shape
1547
+ height, width = latents.shape[-2:]
1548
+ height = height * self.vae_scale_factor
1549
+ width = width * self.vae_scale_factor
1550
+
1551
+ original_size = original_size or (height, width)
1552
+ target_size = target_size or (height, width)
1553
+
1554
+ # 10. Prepare added time ids & embeddings
1555
+ if negative_original_size is None:
1556
+ negative_original_size = original_size
1557
+ if negative_target_size is None:
1558
+ negative_target_size = target_size
1559
+
1560
+ add_text_embeds = pooled_prompt_embeds
1561
+ if self.text_encoder_2 is None:
1562
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1563
+ else:
1564
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1565
+
1566
+ add_time_ids, add_neg_time_ids = self._get_add_time_ids(
1567
+ original_size,
1568
+ crops_coords_top_left,
1569
+ target_size,
1570
+ aesthetic_score,
1571
+ negative_aesthetic_score,
1572
+ negative_original_size,
1573
+ negative_crops_coords_top_left,
1574
+ negative_target_size,
1575
+ dtype=prompt_embeds.dtype,
1576
+ text_encoder_projection_dim=text_encoder_projection_dim,
1577
+ )
1578
+ add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1579
+
1580
+ if self.do_classifier_free_guidance:
1581
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1582
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1583
+ add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
1584
+ add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)
1585
+
1586
+ prompt_embeds = prompt_embeds.to(device)
1587
+ add_text_embeds = add_text_embeds.to(device)
1588
+ add_time_ids = add_time_ids.to(device)
1589
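+ # Shape note (assuming classifier-free guidance is active): prompt_embeds, add_text_embeds and
+ # add_time_ids are each doubled along the batch dimension with the negative/unconditional half
+ # first, so the single UNet forward pass below can later be split via `noise_pred.chunk(2)`.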
+
1590
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1591
+ image_embeds = self.prepare_ip_adapter_image_embeds(
1592
+ ip_adapter_image,
1593
+ ip_adapter_image_embeds,
1594
+ device,
1595
+ batch_size * num_images_per_prompt,
1596
+ self.do_classifier_free_guidance,
1597
+ )
1598
+
1599
+
1600
+ # 11. Denoising loop
1601
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1602
+
1603
+ if (
1604
+ self.denoising_end is not None
1605
+ and self.denoising_start is not None
1606
+ and denoising_value_valid(self.denoising_end)
1607
+ and denoising_value_valid(self.denoising_start)
1608
+ and self.denoising_start >= self.denoising_end
1609
+ ):
1610
+ raise ValueError(
1611
+ f"`denoising_start`: {self.denoising_start} cannot be larger than or equal to `denoising_end`: "
1612
+ + f" {self.denoising_end} when using type float."
1613
+ )
1614
+ elif self.denoising_end is not None and denoising_value_valid(self.denoising_end):
1615
+ discrete_timestep_cutoff = int(
1616
+ round(
1617
+ self.scheduler.config.num_train_timesteps
1618
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1619
+ )
1620
+ )
1621
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1622
+ timesteps = timesteps[:num_inference_steps]
1623
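+ # Worked example (illustrative numbers): with num_train_timesteps=1000 and denoising_end=0.8,
+ # the cutoff is round(1000 - 0.8 * 1000) = 200, so only timesteps >= 200 (the first ~80% of the
+ # trajectory) run here; a successor pipeline with denoising_start=0.8 is expected to finish the rest.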
+
1624
+ # 11.1 Optionally get Guidance Scale Embedding
1625
+ timestep_cond = None
1626
+ if self.unet.config.time_cond_proj_dim is not None:
1627
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1628
+ timestep_cond = self.get_guidance_scale_embedding(
1629
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1630
+ ).to(device=device, dtype=latents.dtype)
1631
+
1632
+ self._num_timesteps = len(timesteps)
1633
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1634
+ for i, t in enumerate(timesteps):
1635
+ if self.interrupt:
1636
+ continue
1637
+ # expand the latents if we are doing classifier free guidance
1638
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1639
+
1640
+ # concat latents, mask, masked_image_latents in the channel dimension
1641
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1642
+
1643
+ if num_channels_unet == 9:
1644
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1645
+
1646
+ # predict the noise residual
1647
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1648
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
1649
+ added_cond_kwargs["image_embeds"] = image_embeds
1650
+ noise_pred = self.unet(
1651
+ latent_model_input,
1652
+ t,
1653
+ encoder_hidden_states=prompt_embeds,
1654
+ timestep_cond=timestep_cond,
1655
+ cross_attention_kwargs=self.cross_attention_kwargs,
1656
+ added_cond_kwargs=added_cond_kwargs,
1657
+ return_dict=False,
1658
+ )[0]
1659
+
1660
+ # perform guidance
1661
+ if self.do_classifier_free_guidance:
1662
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1663
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1664
+
1665
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1666
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1667
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1668
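+ # `guidance_rescale` follows Sec. 3.4 of https://arxiv.org/pdf/2305.08891.pdf: the guided
+ # prediction is rescaled toward the standard deviation of the text-conditional prediction to
+ # counter over-exposure; the default of 0.0 leaves noise_pred untouched.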
+
1669
+ # compute the previous noisy sample x_t -> x_t-1
1670
+ latents_dtype = latents.dtype
1671
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1672
+ if latents.dtype != latents_dtype:
1673
+ if torch.backends.mps.is_available():
1674
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1675
+ latents = latents.to(latents_dtype)
1676
+
1677
+ if num_channels_unet == 4:
1678
+ init_latents_proper = image_latents
1679
+ if self.do_classifier_free_guidance:
1680
+ init_mask, _ = mask.chunk(2)
1681
+ else:
1682
+ init_mask = mask
1683
+
1684
+ if i < len(timesteps) - 1:
1685
+ noise_timestep = timesteps[i + 1]
1686
+ init_latents_proper = self.scheduler.add_noise(
1687
+ init_latents_proper, noise, torch.tensor([noise_timestep])
1688
+ )
1689
+
1690
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1691
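+ # Blending recap: init_mask is ~0 where the original image must be preserved and ~1 where
+ # inpainting happens, so unmasked regions are refreshed with a re-noised copy of the original
+ # latents for the next timestep while masked regions keep the denoised prediction.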
+
1692
+ if callback_on_step_end is not None:
1693
+ callback_kwargs = {}
1694
+ for k in callback_on_step_end_tensor_inputs:
1695
+ callback_kwargs[k] = locals()[k]
1696
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1697
+
1698
+ latents = callback_outputs.pop("latents", latents)
1699
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1700
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1701
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1702
+ negative_pooled_prompt_embeds = callback_outputs.pop(
1703
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1704
+ )
1705
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1706
+ add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
1707
+ mask = callback_outputs.pop("mask", mask)
1708
+ masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
1709
+
1710
+ # call the callback, if provided
1711
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1712
+ progress_bar.update()
1713
+ if callback is not None and i % callback_steps == 0:
1714
+ step_idx = i // getattr(self.scheduler, "order", 1)
1715
+ callback(step_idx, t, latents)
1716
+
1717
+ if XLA_AVAILABLE:
1718
+ xm.mark_step()
1719
+
1720
+ if not output_type == "latent":
1721
+ # make sure the VAE is in float32 mode, as it overflows in float16
1722
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1723
+
1724
+ if needs_upcasting:
1725
+ self.upcast_vae()
1726
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1727
+ elif latents.dtype != self.vae.dtype:
1728
+ if torch.backends.mps.is_available():
1729
+ # some platforms (e.g. Apple MPS) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1730
+ self.vae = self.vae.to(latents.dtype)
1731
+
1732
+ # unscale/denormalize the latents
1733
+ # denormalize with the mean and std if available and not None
1734
+ has_latents_mean = hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None
1735
+ has_latents_std = hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None
1736
+ if has_latents_mean and has_latents_std:
1737
+ latents_mean = (
1738
+ torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1739
+ )
1740
+ latents_std = (
1741
+ torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1).to(latents.device, latents.dtype)
1742
+ )
1743
+ latents = latents * latents_std / self.vae.config.scaling_factor + latents_mean
1744
+ else:
1745
+ latents = latents / self.vae.config.scaling_factor
1746
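+ # Denormalization sketch (assuming the stock SDXL VAE, scaling_factor ~= 0.13025): latents stored
+ # as z * scaling_factor are mapped back to z before decoding; when latents_mean/latents_std exist
+ # in the VAE config, the affine form latents * std / scaling_factor + mean is used instead.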
+
1747
+ image = self.vae.decode(latents, return_dict=False)[0]
1748
+
1749
+ # cast back to fp16 if needed
1750
+ if needs_upcasting:
1751
+ self.vae.to(dtype=torch.float16)
1752
+ else:
1753
+ return StableDiffusionXLPipelineOutput(images=latents)
1754
+
1755
+ # apply watermark if available
1756
+ if self.watermark is not None:
1757
+ image = self.watermark.apply_watermark(image)
1758
+
1759
+ image = self.image_processor.postprocess(image, output_type=output_type)
1760
+
1761
+ if padding_mask_crop is not None:
1762
+ image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
1763
+
1764
+ # Offload all models
1765
+ self.maybe_free_model_hooks()
1766
+
1767
+ if not return_dict:
1768
+ return (image,)
1769
+
1770
+ return StableDiffusionXLPipelineOutput(images=image)