Vijish committed on
Commit
03246eb
1 Parent(s): fc6571d

Upload 3 files

pipeline_controlnet_img2img.py ADDED
@@ -0,0 +1,1114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ import os
18
+ import warnings
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
26
+
27
+ from ...image_processor import VaeImageProcessor
28
+ from ...loaders import TextualInversionLoaderMixin
29
+ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+ from ...utils import (
32
+ PIL_INTERPOLATION,
33
+ deprecate,
34
+ is_accelerate_available,
35
+ is_accelerate_version,
36
+ is_compiled_module,
37
+ logging,
38
+ randn_tensor,
39
+ replace_example_docstring,
40
+ )
41
+ from ..pipeline_utils import DiffusionPipeline
42
+ from ..stable_diffusion import StableDiffusionPipelineOutput
43
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
44
+ from .multicontrolnet import MultiControlNetModel
45
+
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> # !pip install opencv-python transformers accelerate
54
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
55
+ >>> from diffusers.utils import load_image
56
+ >>> import numpy as np
57
+ >>> import torch
58
+
59
+ >>> import cv2
60
+ >>> from PIL import Image
61
+
62
+ >>> # download an image
63
+ >>> image = load_image(
64
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
65
+ ... )
66
+ >>> np_image = np.array(image)
67
+
68
+ >>> # get canny image
69
+ >>> np_image = cv2.Canny(np_image, 100, 200)
70
+ >>> np_image = np_image[:, :, None]
71
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
72
+ >>> canny_image = Image.fromarray(np_image)
73
+
74
+ >>> # load control net and stable diffusion v1-5
75
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
76
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
77
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
78
+ ... )
79
+
80
+ >>> # speed up diffusion process with faster scheduler and memory optimization
81
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
82
+ >>> pipe.enable_model_cpu_offload()
83
+
84
+ >>> # generate image
85
+ >>> generator = torch.manual_seed(0)
86
+ >>> image = pipe(
87
+ ... "futuristic-looking woman",
88
+ ... num_inference_steps=20,
89
+ ... generator=generator,
90
+ ... image=image,
91
+ ... control_image=canny_image,
92
+ ... ).images[0]
93
+ ```
94
+ """
95
+
96
+
97
+ def prepare_image(image):
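+ # Helper: converts the init image (PIL, numpy, tensor, or a list of these) into a batched NCHW float32 tensor; PIL/numpy inputs are rescaled from [0, 255] to [-1, 1], tensor inputs are assumed to already be in [-1, 1].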
98
+ if isinstance(image, torch.Tensor):
99
+ # Batch single image
100
+ if image.ndim == 3:
101
+ image = image.unsqueeze(0)
102
+
103
+ image = image.to(dtype=torch.float32)
104
+ else:
105
+ # preprocess image
106
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
107
+ image = [image]
108
+
109
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
110
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
111
+ image = np.concatenate(image, axis=0)
112
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
113
+ image = np.concatenate([i[None, :] for i in image], axis=0)
114
+
115
+ image = image.transpose(0, 3, 1, 2)
116
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
117
+
118
+ return image
119
+
120
+
121
+ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
122
+ r"""
123
+ Pipeline for text-guided image-to-image generation using Stable Diffusion with ControlNet guidance.
124
+
125
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
126
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
127
+
128
+ In addition, the pipeline inherits the following loading methods:
129
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
130
+
131
+ Args:
132
+ vae ([`AutoencoderKL`]):
133
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
134
+ text_encoder ([`CLIPTextModel`]):
135
+ Frozen text-encoder. Stable Diffusion uses the text portion of
136
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
137
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
138
+ tokenizer (`CLIPTokenizer`):
139
+ Tokenizer of class
140
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
141
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
142
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
143
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
144
+ as a list, the outputs from each ControlNet are added together to create one combined additional
145
+ conditioning.
146
+ scheduler ([`SchedulerMixin`]):
147
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
148
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
149
+ safety_checker ([`StableDiffusionSafetyChecker`]):
150
+ Classification module that estimates whether generated images could be considered offensive or harmful.
151
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
152
+ feature_extractor ([`CLIPImageProcessor`]):
153
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
154
+ """
155
+ _optional_components = ["safety_checker", "feature_extractor"]
156
+
157
+ def __init__(
158
+ self,
159
+ vae: AutoencoderKL,
160
+ text_encoder: CLIPTextModel,
161
+ tokenizer: CLIPTokenizer,
162
+ unet: UNet2DConditionModel,
163
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
164
+ scheduler: KarrasDiffusionSchedulers,
165
+ safety_checker: StableDiffusionSafetyChecker,
166
+ feature_extractor: CLIPImageProcessor,
167
+ requires_safety_checker: bool = True,
168
+ ):
169
+ super().__init__()
170
+
171
+ if safety_checker is None and requires_safety_checker:
172
+ logger.warning(
173
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
174
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
175
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
176
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
177
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
178
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
179
+ )
180
+
181
+ if safety_checker is not None and feature_extractor is None:
182
+ raise ValueError(
183
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
184
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
185
+ )
186
+
187
+ if isinstance(controlnet, (list, tuple)):
188
+ controlnet = MultiControlNetModel(controlnet)
189
+
190
+ self.register_modules(
191
+ vae=vae,
192
+ text_encoder=text_encoder,
193
+ tokenizer=tokenizer,
194
+ unet=unet,
195
+ controlnet=controlnet,
196
+ scheduler=scheduler,
197
+ safety_checker=safety_checker,
198
+ feature_extractor=feature_extractor,
199
+ )
200
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
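+ # e.g. for the standard SD 1.5 VAE (4 block_out_channels) this is 2**3 = 8, i.e. a 512x512 image maps to a 64x64 latent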
201
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
202
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
203
+
204
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
205
+ def enable_vae_slicing(self):
206
+ r"""
207
+ Enable sliced VAE decoding.
208
+
209
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
210
+ steps. This is useful to save some memory and allow larger batch sizes.
211
+ """
212
+ self.vae.enable_slicing()
213
+
214
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
215
+ def disable_vae_slicing(self):
216
+ r"""
217
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
218
+ computing decoding in one step.
219
+ """
220
+ self.vae.disable_slicing()
221
+
222
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
223
+ def enable_vae_tiling(self):
224
+ r"""
225
+ Enable tiled VAE decoding.
226
+
227
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
228
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
229
+ """
230
+ self.vae.enable_tiling()
231
+
232
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
233
+ def disable_vae_tiling(self):
234
+ r"""
235
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
236
+ computing decoding in one step.
237
+ """
238
+ self.vae.disable_tiling()
239
+
240
+ def enable_sequential_cpu_offload(self, gpu_id=0):
241
+ r"""
242
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
243
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
244
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
245
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
246
+ `enable_model_cpu_offload`, but performance is lower.
247
+ """
248
+ if is_accelerate_available():
249
+ from accelerate import cpu_offload
250
+ else:
251
+ raise ImportError("Please install accelerate via `pip install accelerate`")
252
+
253
+ device = torch.device(f"cuda:{gpu_id}")
254
+
255
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
256
+ cpu_offload(cpu_offloaded_model, device)
257
+
258
+ if self.safety_checker is not None:
259
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
260
+
261
+ def enable_model_cpu_offload(self, gpu_id=0):
262
+ r"""
263
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
264
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
265
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
266
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
267
+ """
268
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
269
+ from accelerate import cpu_offload_with_hook
270
+ else:
271
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
272
+
273
+ device = torch.device(f"cuda:{gpu_id}")
274
+
275
+ hook = None
276
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
277
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
278
+
279
+ if self.safety_checker is not None:
280
+ # the safety checker can offload the vae again
281
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
282
+
283
+ # control net hook has to be manually offloaded as it alternates with unet
284
+ cpu_offload_with_hook(self.controlnet, device)
285
+
286
+ # We'll offload the last model manually.
287
+ self.final_offload_hook = hook
288
+
289
+ @property
290
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
291
+ def _execution_device(self):
292
+ r"""
293
+ Returns the device on which the pipeline's models will be executed. After calling
294
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
295
+ hooks.
296
+ """
297
+ if not hasattr(self.unet, "_hf_hook"):
298
+ return self.device
299
+ for module in self.unet.modules():
300
+ if (
301
+ hasattr(module, "_hf_hook")
302
+ and hasattr(module._hf_hook, "execution_device")
303
+ and module._hf_hook.execution_device is not None
304
+ ):
305
+ return torch.device(module._hf_hook.execution_device)
306
+ return self.device
307
+
308
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
309
+ def _encode_prompt(
310
+ self,
311
+ prompt,
312
+ device,
313
+ num_images_per_prompt,
314
+ do_classifier_free_guidance,
315
+ negative_prompt=None,
316
+ prompt_embeds: Optional[torch.FloatTensor] = None,
317
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
318
+ ):
319
+ r"""
320
+ Encodes the prompt into text encoder hidden states.
321
+
322
+ Args:
323
+ prompt (`str` or `List[str]`, *optional*):
324
+ prompt to be encoded
325
+ device: (`torch.device`):
326
+ torch device
327
+ num_images_per_prompt (`int`):
328
+ number of images that should be generated per prompt
329
+ do_classifier_free_guidance (`bool`):
330
+ whether to use classifier free guidance or not
331
+ negative_prompt (`str` or `List[str]`, *optional*):
332
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
333
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
334
+ less than `1`).
335
+ prompt_embeds (`torch.FloatTensor`, *optional*):
336
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
337
+ provided, text embeddings will be generated from `prompt` input argument.
338
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
339
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
340
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
341
+ argument.
342
+ """
343
+ if prompt is not None and isinstance(prompt, str):
344
+ batch_size = 1
345
+ elif prompt is not None and isinstance(prompt, list):
346
+ batch_size = len(prompt)
347
+ else:
348
+ batch_size = prompt_embeds.shape[0]
349
+
350
+ if prompt_embeds is None:
351
+ # textual inversion: process multi-vector tokens if necessary
352
+ if isinstance(self, TextualInversionLoaderMixin):
353
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
354
+
355
+ text_inputs = self.tokenizer(
356
+ prompt,
357
+ padding="max_length",
358
+ max_length=self.tokenizer.model_max_length,
359
+ truncation=True,
360
+ return_tensors="pt",
361
+ )
362
+ text_input_ids = text_inputs.input_ids
363
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
364
+
365
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
366
+ text_input_ids, untruncated_ids
367
+ ):
368
+ removed_text = self.tokenizer.batch_decode(
369
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
370
+ )
371
+ logger.warning(
372
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
373
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
374
+ )
375
+
376
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
377
+ attention_mask = text_inputs.attention_mask.to(device)
378
+ else:
379
+ attention_mask = None
380
+
381
+ prompt_embeds = self.text_encoder(
382
+ text_input_ids.to(device),
383
+ attention_mask=attention_mask,
384
+ )
385
+ prompt_embeds = prompt_embeds[0]
386
+
387
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
388
+
389
+ bs_embed, seq_len, _ = prompt_embeds.shape
390
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
391
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
392
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
393
+
394
+ # get unconditional embeddings for classifier free guidance
395
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
396
+ uncond_tokens: List[str]
397
+ if negative_prompt is None:
398
+ uncond_tokens = [""] * batch_size
399
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
400
+ raise TypeError(
401
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
402
+ f" {type(prompt)}."
403
+ )
404
+ elif isinstance(negative_prompt, str):
405
+ uncond_tokens = [negative_prompt]
406
+ elif batch_size != len(negative_prompt):
407
+ raise ValueError(
408
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
409
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
410
+ " the batch size of `prompt`."
411
+ )
412
+ else:
413
+ uncond_tokens = negative_prompt
414
+
415
+ # textual inversion: process multi-vector tokens if necessary
416
+ if isinstance(self, TextualInversionLoaderMixin):
417
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
418
+
419
+ max_length = prompt_embeds.shape[1]
420
+ uncond_input = self.tokenizer(
421
+ uncond_tokens,
422
+ padding="max_length",
423
+ max_length=max_length,
424
+ truncation=True,
425
+ return_tensors="pt",
426
+ )
427
+
428
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
429
+ attention_mask = uncond_input.attention_mask.to(device)
430
+ else:
431
+ attention_mask = None
432
+
433
+ negative_prompt_embeds = self.text_encoder(
434
+ uncond_input.input_ids.to(device),
435
+ attention_mask=attention_mask,
436
+ )
437
+ negative_prompt_embeds = negative_prompt_embeds[0]
438
+
439
+ if do_classifier_free_guidance:
440
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
441
+ seq_len = negative_prompt_embeds.shape[1]
442
+
443
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
444
+
445
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
446
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
447
+
448
+ # For classifier free guidance, we need to do two forward passes.
449
+ # Here we concatenate the unconditional and text embeddings into a single batch
450
+ # to avoid doing two forward passes
451
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
452
+
453
+ return prompt_embeds
454
+
455
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
456
+ def run_safety_checker(self, image, device, dtype):
457
+ if self.safety_checker is None:
458
+ has_nsfw_concept = None
459
+ else:
460
+ if torch.is_tensor(image):
461
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
462
+ else:
463
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
464
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
465
+ image, has_nsfw_concept = self.safety_checker(
466
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
467
+ )
468
+ return image, has_nsfw_concept
469
+
470
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
471
+ def decode_latents(self, latents):
472
+ warnings.warn(
473
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
474
+ " use VaeImageProcessor instead",
475
+ FutureWarning,
476
+ )
477
+ latents = 1 / self.vae.config.scaling_factor * latents
478
+ image = self.vae.decode(latents, return_dict=False)[0]
479
+ image = (image / 2 + 0.5).clamp(0, 1)
480
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
481
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
482
+ return image
483
+
484
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
485
+ def prepare_extra_step_kwargs(self, generator, eta):
486
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
487
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
488
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
489
+ # and should be between [0, 1]
490
+
491
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
492
+ extra_step_kwargs = {}
493
+ if accepts_eta:
494
+ extra_step_kwargs["eta"] = eta
495
+
496
+ # check if the scheduler accepts generator
497
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
498
+ if accepts_generator:
499
+ extra_step_kwargs["generator"] = generator
500
+ return extra_step_kwargs
501
+
502
+ def check_inputs(
503
+ self,
504
+ prompt,
505
+ image,
506
+ height,
507
+ width,
508
+ callback_steps,
509
+ negative_prompt=None,
510
+ prompt_embeds=None,
511
+ negative_prompt_embeds=None,
512
+ controlnet_conditioning_scale=1.0,
513
+ ):
514
+ if height % 8 != 0 or width % 8 != 0:
515
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
516
+
517
+ if (callback_steps is None) or (
518
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
519
+ ):
520
+ raise ValueError(
521
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
522
+ f" {type(callback_steps)}."
523
+ )
524
+
525
+ if prompt is not None and prompt_embeds is not None:
526
+ raise ValueError(
527
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
528
+ " only forward one of the two."
529
+ )
530
+ elif prompt is None and prompt_embeds is None:
531
+ raise ValueError(
532
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
533
+ )
534
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
535
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
536
+
537
+ if negative_prompt is not None and negative_prompt_embeds is not None:
538
+ raise ValueError(
539
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
540
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
541
+ )
542
+
543
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
544
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
545
+ raise ValueError(
546
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
547
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
548
+ f" {negative_prompt_embeds.shape}."
549
+ )
550
+
551
+ # `prompt` needs more sophisticated handling when there are multiple
552
+ # conditionings.
553
+ if isinstance(self.controlnet, MultiControlNetModel):
554
+ if isinstance(prompt, list):
555
+ logger.warning(
556
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
557
+ " prompts. The conditionings will be fixed across the prompts."
558
+ )
559
+
560
+ # Check `image`
561
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
562
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
563
+ )
564
+ if (
565
+ isinstance(self.controlnet, ControlNetModel)
566
+ or is_compiled
567
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
568
+ ):
569
+ self.check_image(image, prompt, prompt_embeds)
570
+ elif (
571
+ isinstance(self.controlnet, MultiControlNetModel)
572
+ or is_compiled
573
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
574
+ ):
575
+ if not isinstance(image, list):
576
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
577
+
578
+ # When `image` is a nested list:
579
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
580
+ elif any(isinstance(i, list) for i in image):
581
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
582
+ elif len(image) != len(self.controlnet.nets):
583
+ raise ValueError(
584
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
585
+ )
586
+
587
+ for image_ in image:
588
+ self.check_image(image_, prompt, prompt_embeds)
589
+ else:
590
+ assert False
591
+
592
+ # Check `controlnet_conditioning_scale`
593
+ if (
594
+ isinstance(self.controlnet, ControlNetModel)
595
+ or is_compiled
596
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
597
+ ):
598
+ if not isinstance(controlnet_conditioning_scale, float):
599
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
600
+ elif (
601
+ isinstance(self.controlnet, MultiControlNetModel)
602
+ or is_compiled
603
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
604
+ ):
605
+ if isinstance(controlnet_conditioning_scale, list):
606
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
607
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
608
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
609
+ self.controlnet.nets
610
+ ):
611
+ raise ValueError(
612
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
613
+ " the same length as the number of controlnets"
614
+ )
615
+ else:
616
+ assert False
617
+
618
+ def check_image(self, image, prompt, prompt_embeds):
619
+ image_is_pil = isinstance(image, PIL.Image.Image)
620
+ image_is_tensor = isinstance(image, torch.Tensor)
621
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
622
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
623
+
624
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
625
+ raise TypeError(
626
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
627
+ )
628
+
629
+ if image_is_pil:
630
+ image_batch_size = 1
631
+ elif image_is_tensor:
632
+ image_batch_size = image.shape[0]
633
+ elif image_is_pil_list:
634
+ image_batch_size = len(image)
635
+ elif image_is_tensor_list:
636
+ image_batch_size = len(image)
637
+
638
+ if prompt is not None and isinstance(prompt, str):
639
+ prompt_batch_size = 1
640
+ elif prompt is not None and isinstance(prompt, list):
641
+ prompt_batch_size = len(prompt)
642
+ elif prompt_embeds is not None:
643
+ prompt_batch_size = prompt_embeds.shape[0]
644
+
645
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
646
+ raise ValueError(
647
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
648
+ )
649
+
650
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
651
+ def prepare_control_image(
652
+ self,
653
+ image,
654
+ width,
655
+ height,
656
+ batch_size,
657
+ num_images_per_prompt,
658
+ device,
659
+ dtype,
660
+ do_classifier_free_guidance=False,
661
+ guess_mode=False,
662
+ ):
663
+ if not isinstance(image, torch.Tensor):
664
+ if isinstance(image, PIL.Image.Image):
665
+ image = [image]
666
+
667
+ if isinstance(image[0], PIL.Image.Image):
668
+ images = []
669
+
670
+ for image_ in image:
671
+ image_ = image_.convert("RGB")
672
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
673
+ image_ = np.array(image_)
674
+ image_ = image_[None, :]
675
+ images.append(image_)
676
+
677
+ image = images
678
+
679
+ image = np.concatenate(image, axis=0)
680
+ image = np.array(image).astype(np.float32) / 255.0
681
+ image = image.transpose(0, 3, 1, 2)
682
+ image = torch.from_numpy(image)
683
+ elif isinstance(image[0], torch.Tensor):
684
+ image = torch.cat(image, dim=0)
685
+
686
+ image_batch_size = image.shape[0]
687
+
688
+ if image_batch_size == 1:
689
+ repeat_by = batch_size
690
+ else:
691
+ # image batch size is the same as prompt batch size
692
+ repeat_by = num_images_per_prompt
693
+
694
+ image = image.repeat_interleave(repeat_by, dim=0)
695
+
696
+ image = image.to(device=device, dtype=dtype)
697
+
698
+ if do_classifier_free_guidance and not guess_mode:
699
+ image = torch.cat([image] * 2)
700
+
701
+ return image
702
+
703
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
704
+ def get_timesteps(self, num_inference_steps, strength, device):
705
+ # get the original timestep using init_timestep
706
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
707
+
708
+ t_start = max(num_inference_steps - init_timestep, 0)
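+ # e.g. with num_inference_steps=50 and strength=0.8: init_timestep=40, t_start=10, so only the final 40 scheduler steps are denoised (for a first-order scheduler)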
709
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
710
+
711
+ return timesteps, num_inference_steps - t_start
712
+
713
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
714
+ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
715
+ if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
716
+ raise ValueError(
717
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
718
+ )
719
+
720
+ image = image.to(device=device, dtype=dtype)
721
+
722
+ batch_size = batch_size * num_images_per_prompt
723
+ if isinstance(generator, list) and len(generator) != batch_size:
724
+ raise ValueError(
725
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
726
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
727
+ )
728
+
729
+ if isinstance(generator, list):
730
+ init_latents = [
731
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
732
+ ]
733
+ init_latents = torch.cat(init_latents, dim=0)
734
+ else:
735
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
736
+
737
+ init_latents = self.vae.config.scaling_factor * init_latents
738
+
739
+ if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0:
740
+ # expand init_latents for batch_size
741
+ deprecation_message = (
742
+ f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial"
743
+ " images (`image`). Initial images are now duplicating to match the number of text prompts. Note"
744
+ " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update"
745
+ " your script to pass as many initial images as text prompts to suppress this warning."
746
+ )
747
+ deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
748
+ additional_image_per_prompt = batch_size // init_latents.shape[0]
749
+ init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0)
750
+ elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
751
+ raise ValueError(
752
+ f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
753
+ )
754
+ else:
755
+ init_latents = torch.cat([init_latents], dim=0)
756
+
757
+ shape = init_latents.shape
758
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
759
+
760
+ # get latents
761
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
762
+ latents = init_latents
763
+
764
+ return latents
765
+
766
+ def _default_height_width(self, height, width, image):
767
+ # NOTE: It is possible that a list of images have different
768
+ # dimensions for each image, so just checking the first image
769
+ # is not _exactly_ correct, but it is simple.
770
+ while isinstance(image, list):
771
+ image = image[0]
772
+
773
+ if height is None:
774
+ if isinstance(image, PIL.Image.Image):
775
+ height = image.height
776
+ elif isinstance(image, torch.Tensor):
777
+ height = image.shape[2]
778
+
779
+ height = (height // 8) * 8 # round down to nearest multiple of 8
780
+
781
+ if width is None:
782
+ if isinstance(image, PIL.Image.Image):
783
+ width = image.width
784
+ elif isinstance(image, torch.Tensor):
785
+ width = image.shape[3]
786
+
787
+ width = (width // 8) * 8 # round down to nearest multiple of 8
788
+
789
+ return height, width
790
+
791
+ # override DiffusionPipeline
792
+ def save_pretrained(
793
+ self,
794
+ save_directory: Union[str, os.PathLike],
795
+ safe_serialization: bool = False,
796
+ variant: Optional[str] = None,
797
+ ):
798
+ if isinstance(self.controlnet, ControlNetModel):
799
+ super().save_pretrained(save_directory, safe_serialization, variant)
800
+ else:
801
+ raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
802
+
803
+ @torch.no_grad()
804
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
805
+ def __call__(
806
+ self,
807
+ prompt: Union[str, List[str]] = None,
808
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
809
+ control_image: Union[
810
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
811
+ ] = None,
812
+ height: Optional[int] = None,
813
+ width: Optional[int] = None,
814
+ strength: float = 0.8,
815
+ num_inference_steps: int = 50,
816
+ guidance_scale: float = 7.5,
817
+ negative_prompt: Optional[Union[str, List[str]]] = None,
818
+ num_images_per_prompt: Optional[int] = 1,
819
+ eta: float = 0.0,
820
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
821
+ latents: Optional[torch.FloatTensor] = None,
822
+ prompt_embeds: Optional[torch.FloatTensor] = None,
823
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
824
+ output_type: Optional[str] = "pil",
825
+ return_dict: bool = True,
826
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
827
+ callback_steps: int = 1,
828
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
829
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
830
+ guess_mode: bool = False,
831
+ ):
832
+ r"""
833
+ Function invoked when calling the pipeline for generation.
834
+
835
+ Args:
836
+ prompt (`str` or `List[str]`, *optional*):
837
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
838
+ instead.
839
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
840
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
841
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to the UNet. If
842
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
843
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
844
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
845
+ specified in init, images must be passed as a list such that each element of the list can be correctly
846
+ batched for input to a single controlnet.
847
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
848
+ The height in pixels of the generated image.
849
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
850
+ The width in pixels of the generated image.
851
+ num_inference_steps (`int`, *optional*, defaults to 50):
852
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
853
+ expense of slower inference.
854
+ guidance_scale (`float`, *optional*, defaults to 7.5):
855
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
856
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
857
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
858
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
859
+ usually at the expense of lower image quality.
860
+ negative_prompt (`str` or `List[str]`, *optional*):
861
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
862
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
863
+ less than `1`).
864
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
865
+ The number of images to generate per prompt.
866
+ eta (`float`, *optional*, defaults to 0.0):
867
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
868
+ [`schedulers.DDIMScheduler`], will be ignored for others.
869
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
870
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
871
+ to make generation deterministic.
872
+ latents (`torch.FloatTensor`, *optional*):
873
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
874
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
875
+ tensor will be generated by sampling using the supplied random `generator`.
876
+ prompt_embeds (`torch.FloatTensor`, *optional*):
877
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
878
+ provided, text embeddings will be generated from `prompt` input argument.
879
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
880
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
881
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
882
+ argument.
883
+ output_type (`str`, *optional*, defaults to `"pil"`):
884
+ The output format of the generated image. Choose between
885
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
886
+ return_dict (`bool`, *optional*, defaults to `True`):
887
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
888
+ plain tuple.
889
+ callback (`Callable`, *optional*):
890
+ A function that will be called every `callback_steps` steps during inference. The function will be
891
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
892
+ callback_steps (`int`, *optional*, defaults to 1):
893
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
894
+ called at every step.
895
+ cross_attention_kwargs (`dict`, *optional*):
896
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
897
+ `self.processor` in
898
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
899
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
900
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
901
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
902
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for img2img
903
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
904
+ guess_mode (`bool`, *optional*, defaults to `False`):
905
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
906
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
907
+
908
+ Examples:
909
+
910
+ Returns:
911
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
912
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
913
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
914
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
915
+ (nsfw) content, according to the `safety_checker`.
916
+ """
917
+ # 0. Default height and width to unet
918
+ height, width = self._default_height_width(height, width, image)
919
+
920
+ # 1. Check inputs. Raise error if not correct
921
+ self.check_inputs(
922
+ prompt,
923
+ control_image,
924
+ height,
925
+ width,
926
+ callback_steps,
927
+ negative_prompt,
928
+ prompt_embeds,
929
+ negative_prompt_embeds,
930
+ controlnet_conditioning_scale,
931
+ )
932
+
933
+ # 2. Define call parameters
934
+ if prompt is not None and isinstance(prompt, str):
935
+ batch_size = 1
936
+ elif prompt is not None and isinstance(prompt, list):
937
+ batch_size = len(prompt)
938
+ else:
939
+ batch_size = prompt_embeds.shape[0]
940
+
941
+ device = self._execution_device
942
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
943
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
944
+ # corresponds to doing no classifier free guidance.
945
+ do_classifier_free_guidance = guidance_scale > 1.0
946
+
947
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
948
+
949
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
950
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
951
+
952
+ global_pool_conditions = (
953
+ controlnet.config.global_pool_conditions
954
+ if isinstance(controlnet, ControlNetModel)
955
+ else controlnet.nets[0].config.global_pool_conditions
956
+ )
957
+ guess_mode = guess_mode or global_pool_conditions
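+ # if the ControlNet was configured with `global_pool_conditions`, guess mode is always enabled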
958
+
959
+ # 3. Encode input prompt
960
+ prompt_embeds = self._encode_prompt(
961
+ prompt,
962
+ device,
963
+ num_images_per_prompt,
964
+ do_classifier_free_guidance,
965
+ negative_prompt,
966
+ prompt_embeds=prompt_embeds,
967
+ negative_prompt_embeds=negative_prompt_embeds,
968
+ )
969
+ # 4. Prepare image
970
+ image = prepare_image(image)
971
+
972
+ # 5. Prepare control image
973
+ if isinstance(controlnet, ControlNetModel):
974
+ control_image = self.prepare_control_image(
975
+ image=control_image,
976
+ width=width,
977
+ height=height,
978
+ batch_size=batch_size * num_images_per_prompt,
979
+ num_images_per_prompt=num_images_per_prompt,
980
+ device=device,
981
+ dtype=controlnet.dtype,
982
+ do_classifier_free_guidance=do_classifier_free_guidance,
983
+ guess_mode=guess_mode,
984
+ )
985
+ elif isinstance(controlnet, MultiControlNetModel):
986
+ control_images = []
987
+
988
+ for control_image_ in control_image:
989
+ control_image_ = self.prepare_control_image(
990
+ image=control_image_,
991
+ width=width,
992
+ height=height,
993
+ batch_size=batch_size * num_images_per_prompt,
994
+ num_images_per_prompt=num_images_per_prompt,
995
+ device=device,
996
+ dtype=controlnet.dtype,
997
+ do_classifier_free_guidance=do_classifier_free_guidance,
998
+ guess_mode=guess_mode,
999
+ )
1000
+
1001
+ control_images.append(control_image_)
1002
+
1003
+ control_image = control_images
1004
+ else:
1005
+ assert False
1006
+
1007
+ # 6. Prepare timesteps
1008
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1009
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1010
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
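+ # the single timestep at which the encoded init image is noised before denoising begins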
1011
+
1012
+ # 7. Prepare latent variables
1013
+ latents = self.prepare_latents(
1014
+ image,
1015
+ latent_timestep,
1016
+ batch_size,
1017
+ num_images_per_prompt,
1018
+ prompt_embeds.dtype,
1019
+ device,
1020
+ generator,
1021
+ )
1022
+
1023
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1024
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1025
+
1026
+ # 9. Denoising loop
1027
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1028
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1029
+ for i, t in enumerate(timesteps):
1030
+ # expand the latents if we are doing classifier free guidance
1031
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1032
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1033
+
1034
+ # controlnet(s) inference
1035
+ if guess_mode and do_classifier_free_guidance:
1036
+ # Infer ControlNet only for the conditional batch.
1037
+ control_model_input = latents
1038
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1039
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1040
+ else:
1041
+ control_model_input = latent_model_input
1042
+ controlnet_prompt_embeds = prompt_embeds
1043
+
1044
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1045
+ control_model_input,
1046
+ t,
1047
+ encoder_hidden_states=controlnet_prompt_embeds,
1048
+ controlnet_cond=control_image,
1049
+ conditioning_scale=controlnet_conditioning_scale,
1050
+ guess_mode=guess_mode,
1051
+ return_dict=False,
1052
+ )
1053
+
1054
+ if guess_mode and do_classifier_free_guidance:
1055
+ # Inferred ControlNet only for the conditional batch.
1056
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1057
+ # add 0 to the unconditional batch to keep it unchanged.
1058
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1059
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1060
+
1061
+ # predict the noise residual
1062
+ noise_pred = self.unet(
1063
+ latent_model_input,
1064
+ t,
1065
+ encoder_hidden_states=prompt_embeds,
1066
+ cross_attention_kwargs=cross_attention_kwargs,
1067
+ down_block_additional_residuals=down_block_res_samples,
1068
+ mid_block_additional_residual=mid_block_res_sample,
1069
+ return_dict=False,
1070
+ )[0]
1071
+
1072
+ # perform guidance
1073
+ if do_classifier_free_guidance:
1074
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1075
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1076
+
1077
+ # compute the previous noisy sample x_t -> x_t-1
1078
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1079
+
1080
+ # call the callback, if provided
1081
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1082
+ progress_bar.update()
1083
+ if callback is not None and i % callback_steps == 0:
1084
+ callback(i, t, latents)
1085
+
1086
+ # If we do sequential model offloading, let's offload unet and controlnet
1087
+ # manually for max memory savings
1088
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1089
+ self.unet.to("cpu")
1090
+ self.controlnet.to("cpu")
1091
+ torch.cuda.empty_cache()
1092
+
1093
+ if not output_type == "latent":
1094
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1095
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1096
+ else:
1097
+ image = latents
1098
+ has_nsfw_concept = None
1099
+
1100
+ if has_nsfw_concept is None:
1101
+ do_denormalize = [True] * image.shape[0]
1102
+ else:
1103
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1104
+
1105
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1106
+
1107
+ # Offload last model to CPU
1108
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1109
+ self.final_offload_hook.offload()
1110
+
1111
+ if not return_dict:
1112
+ return (image, has_nsfw_concept)
1113
+
1114
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
pipeline_controlnet_inpaint.py ADDED
@@ -0,0 +1,1344 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # This model implementation is heavily inspired by https://github.com/haofanwang/ControlNet-for-Diffusers/
16
+
17
+ import inspect
18
+ import os
19
+ import warnings
20
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
21
+
22
+ import numpy as np
23
+ import PIL.Image
24
+ import torch
25
+ import torch.nn.functional as F
26
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
27
+
28
+ from ...image_processor import VaeImageProcessor
29
+ from ...loaders import TextualInversionLoaderMixin
30
+ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
31
+ from ...schedulers import KarrasDiffusionSchedulers
32
+ from ...utils import (
33
+ PIL_INTERPOLATION,
34
+ is_accelerate_available,
35
+ is_accelerate_version,
36
+ is_compiled_module,
37
+ logging,
38
+ randn_tensor,
39
+ replace_example_docstring,
40
+ )
41
+ from ..pipeline_utils import DiffusionPipeline
42
+ from ..stable_diffusion import StableDiffusionPipelineOutput
43
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
44
+ from .multicontrolnet import MultiControlNetModel
45
+
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> # !pip install transformers accelerate
54
+ >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
55
+ >>> from diffusers.utils import load_image
56
+ >>> import numpy as np
57
+ >>> import torch
58
+
59
+ >>> init_image = load_image(
60
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
61
+ ... )
62
+ >>> init_image = init_image.resize((512, 512))
63
+
64
+ >>> generator = torch.Generator(device="cpu").manual_seed(1)
65
+
66
+ >>> mask_image = load_image(
67
+ ... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
68
+ ... )
69
+ >>> mask_image = mask_image.resize((512, 512))
70
+
71
+
72
+ >>> def make_inpaint_condition(image, image_mask):
73
+ ... image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
74
+ ... image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
75
+
76
+ ... assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
77
+ ... image[image_mask > 0.5] = -1.0 # set as masked pixel
78
+ ... image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
79
+ ... image = torch.from_numpy(image)
80
+ ... return image
81
+
82
+
83
+ >>> control_image = make_inpaint_condition(init_image, mask_image)
84
+
85
+ >>> controlnet = ControlNetModel.from_pretrained(
86
+ ... "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
87
+ ... )
88
+ >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
89
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
90
+ ... )
91
+
92
+ >>> pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
93
+ >>> pipe.enable_model_cpu_offload()
94
+
95
+ >>> # generate image
96
+ >>> image = pipe(
97
+ ... "a handsome man with ray-ban sunglasses",
98
+ ... num_inference_steps=20,
99
+ ... generator=generator,
100
+ ... eta=1.0,
101
+ ... image=init_image,
102
+ ... mask_image=mask_image,
103
+ ... control_image=control_image,
104
+ ... ).images[0]
105
+ ```
106
+ """
107
+
108
+
109
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
110
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
111
+ """
112
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
113
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
114
+ ``image`` and ``1`` for the ``mask``.
115
+
116
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
117
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
118
+
119
+ Args:
120
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
121
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
122
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
123
+ mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
124
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
125
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
126
+
127
+
128
+ Raises:
129
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
130
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
131
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
132
+ (or the other way around).
133
+
134
+ Returns:
135
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
136
+ dimensions: ``batch x channels x height x width``.
137
+ """
138
+
139
+ if image is None:
140
+ raise ValueError("`image` input cannot be undefined.")
141
+
142
+ if mask is None:
143
+ raise ValueError("`mask_image` input cannot be undefined.")
144
+
145
+ if isinstance(image, torch.Tensor):
146
+ if not isinstance(mask, torch.Tensor):
147
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)}) is not")
148
+
149
+ # Batch single image
150
+ if image.ndim == 3:
151
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
152
+ image = image.unsqueeze(0)
153
+
154
+ # Batch and add channel dim for single mask
155
+ if mask.ndim == 2:
156
+ mask = mask.unsqueeze(0).unsqueeze(0)
157
+
158
+ # Batch single mask or add channel dim
159
+ if mask.ndim == 3:
160
+ # Single batched mask, no channel dim or single mask not batched but channel dim
161
+ if mask.shape[0] == 1:
162
+ mask = mask.unsqueeze(0)
163
+
164
+ # Batched masks no channel dim
165
+ else:
166
+ mask = mask.unsqueeze(1)
167
+
168
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
169
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
170
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
171
+
172
+ # Check image is in [-1, 1]
173
+ if image.min() < -1 or image.max() > 1:
174
+ raise ValueError("Image should be in [-1, 1] range")
175
+
176
+ # Check mask is in [0, 1]
177
+ if mask.min() < 0 or mask.max() > 1:
178
+ raise ValueError("Mask should be in [0, 1] range")
179
+
180
+ # Binarize mask
181
+ mask[mask < 0.5] = 0
182
+ mask[mask >= 0.5] = 1
183
+
184
+ # Image as float32
185
+ image = image.to(dtype=torch.float32)
186
+ elif isinstance(mask, torch.Tensor):
187
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)}) is not")
188
+ else:
189
+ # preprocess image
190
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
191
+ image = [image]
192
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
193
+ # resize all images w.r.t. the passed height and width
194
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
195
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
196
+ image = np.concatenate(image, axis=0)
197
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
198
+ image = np.concatenate([i[None, :] for i in image], axis=0)
199
+
200
+ image = image.transpose(0, 3, 1, 2)
201
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
202
+
203
+ # preprocess mask
204
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
205
+ mask = [mask]
206
+
207
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
208
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
209
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
210
+ mask = mask.astype(np.float32) / 255.0
211
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
212
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
213
+
214
+ mask[mask < 0.5] = 0
215
+ mask[mask >= 0.5] = 1
216
+ mask = torch.from_numpy(mask)
217
+
218
+ masked_image = image * (mask < 0.5)
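+ # note: the binarized mask is 1 on the region to inpaint, so `mask < 0.5` is True on the
+ # pixels to keep; `masked_image` therefore preserves the original content and zeroes out
+ # the region that will be inpainted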
219
+
220
+ # n.b. ensure backwards compatibility as old function does not return image
221
+ if return_image:
222
+ return mask, masked_image, image
223
+
224
+ return mask, masked_image
225
+
226
+
227
+ class StableDiffusionControlNetInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
228
+ r"""
229
+ Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
230
+
231
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
232
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
233
+
234
+ In addition the pipeline inherits the following loading methods:
235
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
236
+
237
+ <Tip>
238
+
239
+ This pipeline can be used both with checkpoints that have been specifically fine-tuned for inpainting, such as
240
+ [runwayml/stable-diffusion-inpainting](https://huggingface.co/runwayml/stable-diffusion-inpainting)
241
+ as well as default text-to-image stable diffusion checkpoints, such as
242
+ [runwayml/stable-diffusion-v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
243
+ Default text-to-image stable diffusion checkpoints might be preferable for controlnets that have been fine-tuned on
244
+ those, such as [lllyasviel/control_v11p_sd15_inpaint](https://huggingface.co/lllyasviel/control_v11p_sd15_inpaint).
245
+
246
+ </Tip>
247
+
248
+ Args:
249
+ vae ([`AutoencoderKL`]):
250
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
251
+ text_encoder ([`CLIPTextModel`]):
252
+ Frozen text-encoder. Stable Diffusion uses the text portion of
253
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
254
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
255
+ tokenizer (`CLIPTokenizer`):
256
+ Tokenizer of class
257
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
258
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
259
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
260
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
261
+ as a list, the outputs from each ControlNet are added together to create one combined additional
262
+ conditioning.
263
+ scheduler ([`SchedulerMixin`]):
264
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
265
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
266
+ safety_checker ([`StableDiffusionSafetyChecker`]):
267
+ Classification module that estimates whether generated images could be considered offensive or harmful.
268
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
269
+ feature_extractor ([`CLIPImageProcessor`]):
270
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
271
+ """
272
+ _optional_components = ["safety_checker", "feature_extractor"]
273
+
274
+ def __init__(
275
+ self,
276
+ vae: AutoencoderKL,
277
+ text_encoder: CLIPTextModel,
278
+ tokenizer: CLIPTokenizer,
279
+ unet: UNet2DConditionModel,
280
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
281
+ scheduler: KarrasDiffusionSchedulers,
282
+ safety_checker: StableDiffusionSafetyChecker,
283
+ feature_extractor: CLIPImageProcessor,
284
+ requires_safety_checker: bool = True,
285
+ ):
286
+ super().__init__()
287
+
288
+ if safety_checker is None and requires_safety_checker:
289
+ logger.warning(
290
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
291
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
292
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
293
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
294
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
295
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
296
+ )
297
+
298
+ if safety_checker is not None and feature_extractor is None:
299
+ raise ValueError(
300
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
301
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
302
+ )
303
+
304
+ if isinstance(controlnet, (list, tuple)):
305
+ controlnet = MultiControlNetModel(controlnet)
306
+
307
+ self.register_modules(
308
+ vae=vae,
309
+ text_encoder=text_encoder,
310
+ tokenizer=tokenizer,
311
+ unet=unet,
312
+ controlnet=controlnet,
313
+ scheduler=scheduler,
314
+ safety_checker=safety_checker,
315
+ feature_extractor=feature_extractor,
316
+ )
317
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
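+ # note: for the standard Stable Diffusion v1 VAE (4 blocks) this evaluates to 8, so a
+ # 512x512 image maps to a 64x64 latent grid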
318
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
319
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
320
+
321
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
322
+ def enable_vae_slicing(self):
323
+ r"""
324
+ Enable sliced VAE decoding.
325
+
326
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
327
+ steps. This is useful to save some memory and allow larger batch sizes.
328
+ """
329
+ self.vae.enable_slicing()
330
+
331
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
332
+ def disable_vae_slicing(self):
333
+ r"""
334
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
335
+ computing decoding in one step.
336
+ """
337
+ self.vae.disable_slicing()
338
+
339
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
340
+ def enable_vae_tiling(self):
341
+ r"""
342
+ Enable tiled VAE decoding.
343
+
344
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
345
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
346
+ """
347
+ self.vae.enable_tiling()
348
+
349
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
350
+ def disable_vae_tiling(self):
351
+ r"""
352
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
353
+ computing decoding in one step.
354
+ """
355
+ self.vae.disable_tiling()
356
+
357
+ def enable_sequential_cpu_offload(self, gpu_id=0):
358
+ r"""
359
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
360
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
361
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
362
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
363
+ `enable_model_cpu_offload`, but performance is lower.
364
+ """
365
+ if is_accelerate_available():
366
+ from accelerate import cpu_offload
367
+ else:
368
+ raise ImportError("Please install accelerate via `pip install accelerate`")
369
+
370
+ device = torch.device(f"cuda:{gpu_id}")
371
+
372
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
373
+ cpu_offload(cpu_offloaded_model, device)
374
+
375
+ if self.safety_checker is not None:
376
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
377
+
378
+ def enable_model_cpu_offload(self, gpu_id=0):
379
+ r"""
380
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
381
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
382
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
383
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
384
+ """
385
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
386
+ from accelerate import cpu_offload_with_hook
387
+ else:
388
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
389
+
390
+ device = torch.device(f"cuda:{gpu_id}")
391
+
392
+ hook = None
393
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
394
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
395
+
396
+ if self.safety_checker is not None:
397
+ # the safety checker can offload the vae again
398
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
399
+
400
+ # the controlnet hook has to be offloaded manually, as it alternates with the unet
401
+ cpu_offload_with_hook(self.controlnet, device)
402
+
403
+ # We'll offload the last model manually.
404
+ self.final_offload_hook = hook
405
+
406
+ @property
407
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
408
+ def _execution_device(self):
409
+ r"""
410
+ Returns the device on which the pipeline's models will be executed. After calling
411
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
412
+ hooks.
413
+ """
414
+ if not hasattr(self.unet, "_hf_hook"):
415
+ return self.device
416
+ for module in self.unet.modules():
417
+ if (
418
+ hasattr(module, "_hf_hook")
419
+ and hasattr(module._hf_hook, "execution_device")
420
+ and module._hf_hook.execution_device is not None
421
+ ):
422
+ return torch.device(module._hf_hook.execution_device)
423
+ return self.device
424
+
425
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
426
+ def _encode_prompt(
427
+ self,
428
+ prompt,
429
+ device,
430
+ num_images_per_prompt,
431
+ do_classifier_free_guidance,
432
+ negative_prompt=None,
433
+ prompt_embeds: Optional[torch.FloatTensor] = None,
434
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
435
+ ):
436
+ r"""
437
+ Encodes the prompt into text encoder hidden states.
438
+
439
+ Args:
440
+ prompt (`str` or `List[str]`, *optional*):
441
+ prompt to be encoded
442
+ device: (`torch.device`):
443
+ torch device
444
+ num_images_per_prompt (`int`):
445
+ number of images that should be generated per prompt
446
+ do_classifier_free_guidance (`bool`):
447
+ whether to use classifier free guidance or not
448
+ negative_prompt (`str` or `List[str]`, *optional*):
449
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
450
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
451
+ less than `1`).
452
+ prompt_embeds (`torch.FloatTensor`, *optional*):
453
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
454
+ provided, text embeddings will be generated from `prompt` input argument.
455
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
456
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
457
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
458
+ argument.
459
+ """
460
+ if prompt is not None and isinstance(prompt, str):
461
+ batch_size = 1
462
+ elif prompt is not None and isinstance(prompt, list):
463
+ batch_size = len(prompt)
464
+ else:
465
+ batch_size = prompt_embeds.shape[0]
466
+
467
+ if prompt_embeds is None:
468
+ # textual inversion: process multi-vector tokens if necessary
469
+ if isinstance(self, TextualInversionLoaderMixin):
470
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
471
+
472
+ text_inputs = self.tokenizer(
473
+ prompt,
474
+ padding="max_length",
475
+ max_length=self.tokenizer.model_max_length,
476
+ truncation=True,
477
+ return_tensors="pt",
478
+ )
479
+ text_input_ids = text_inputs.input_ids
480
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
481
+
482
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
483
+ text_input_ids, untruncated_ids
484
+ ):
485
+ removed_text = self.tokenizer.batch_decode(
486
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
487
+ )
488
+ logger.warning(
489
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
490
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
491
+ )
492
+
493
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
494
+ attention_mask = text_inputs.attention_mask.to(device)
495
+ else:
496
+ attention_mask = None
497
+
498
+ prompt_embeds = self.text_encoder(
499
+ text_input_ids.to(device),
500
+ attention_mask=attention_mask,
501
+ )
502
+ prompt_embeds = prompt_embeds[0]
503
+
504
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
505
+
506
+ bs_embed, seq_len, _ = prompt_embeds.shape
507
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
508
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
509
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
510
+
511
+ # get unconditional embeddings for classifier free guidance
512
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
513
+ uncond_tokens: List[str]
514
+ if negative_prompt is None:
515
+ uncond_tokens = [""] * batch_size
516
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
517
+ raise TypeError(
518
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
519
+ f" {type(prompt)}."
520
+ )
521
+ elif isinstance(negative_prompt, str):
522
+ uncond_tokens = [negative_prompt]
523
+ elif batch_size != len(negative_prompt):
524
+ raise ValueError(
525
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
526
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
527
+ " the batch size of `prompt`."
528
+ )
529
+ else:
530
+ uncond_tokens = negative_prompt
531
+
532
+ # textual inversion: process multi-vector tokens if necessary
533
+ if isinstance(self, TextualInversionLoaderMixin):
534
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
535
+
536
+ max_length = prompt_embeds.shape[1]
537
+ uncond_input = self.tokenizer(
538
+ uncond_tokens,
539
+ padding="max_length",
540
+ max_length=max_length,
541
+ truncation=True,
542
+ return_tensors="pt",
543
+ )
544
+
545
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
546
+ attention_mask = uncond_input.attention_mask.to(device)
547
+ else:
548
+ attention_mask = None
549
+
550
+ negative_prompt_embeds = self.text_encoder(
551
+ uncond_input.input_ids.to(device),
552
+ attention_mask=attention_mask,
553
+ )
554
+ negative_prompt_embeds = negative_prompt_embeds[0]
555
+
556
+ if do_classifier_free_guidance:
557
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
558
+ seq_len = negative_prompt_embeds.shape[1]
559
+
560
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
561
+
562
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
563
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
564
+
565
+ # For classifier free guidance, we need to do two forward passes.
566
+ # Here we concatenate the unconditional and text embeddings into a single batch
567
+ # to avoid doing two forward passes
568
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
569
+
570
+ return prompt_embeds
571
+
572
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
573
+ def run_safety_checker(self, image, device, dtype):
574
+ if self.safety_checker is None:
575
+ has_nsfw_concept = None
576
+ else:
577
+ if torch.is_tensor(image):
578
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
579
+ else:
580
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
581
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
582
+ image, has_nsfw_concept = self.safety_checker(
583
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
584
+ )
585
+ return image, has_nsfw_concept
586
+
587
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
588
+ def decode_latents(self, latents):
589
+ warnings.warn(
590
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
591
+ " use VaeImageProcessor instead",
592
+ FutureWarning,
593
+ )
594
+ latents = 1 / self.vae.config.scaling_factor * latents
595
+ image = self.vae.decode(latents, return_dict=False)[0]
596
+ image = (image / 2 + 0.5).clamp(0, 1)
597
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
598
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
599
+ return image
600
+
601
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
602
+ def prepare_extra_step_kwargs(self, generator, eta):
603
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
604
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
605
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
606
+ # and should be between [0, 1]
607
+
608
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
609
+ extra_step_kwargs = {}
610
+ if accepts_eta:
611
+ extra_step_kwargs["eta"] = eta
612
+
613
+ # check if the scheduler accepts generator
614
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
615
+ if accepts_generator:
616
+ extra_step_kwargs["generator"] = generator
617
+ return extra_step_kwargs
618
+
619
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
620
+ def get_timesteps(self, num_inference_steps, strength, device):
621
+ # get the original timestep using init_timestep
622
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
623
+
624
+ t_start = max(num_inference_steps - init_timestep, 0)
625
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
626
+
627
+ return timesteps, num_inference_steps - t_start
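+ # worked example (assuming scheduler.order == 1): num_inference_steps=50 and strength=0.6
+ # give init_timestep=30 and t_start=20, so only the last 30 timesteps are actually run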
628
+
629
+ def check_inputs(
630
+ self,
631
+ prompt,
632
+ image,
633
+ height,
634
+ width,
635
+ callback_steps,
636
+ negative_prompt=None,
637
+ prompt_embeds=None,
638
+ negative_prompt_embeds=None,
639
+ controlnet_conditioning_scale=1.0,
640
+ ):
641
+ if height % 8 != 0 or width % 8 != 0:
642
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
643
+
644
+ if (callback_steps is None) or (
645
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
646
+ ):
647
+ raise ValueError(
648
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
649
+ f" {type(callback_steps)}."
650
+ )
651
+
652
+ if prompt is not None and prompt_embeds is not None:
653
+ raise ValueError(
654
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
655
+ " only forward one of the two."
656
+ )
657
+ elif prompt is None and prompt_embeds is None:
658
+ raise ValueError(
659
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
660
+ )
661
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
662
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
663
+
664
+ if negative_prompt is not None and negative_prompt_embeds is not None:
665
+ raise ValueError(
666
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
667
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
668
+ )
669
+
670
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
671
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
672
+ raise ValueError(
673
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
674
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
675
+ f" {negative_prompt_embeds.shape}."
676
+ )
677
+
678
+ # `prompt` needs more sophisticated handling when there are multiple
679
+ # conditionings.
680
+ if isinstance(self.controlnet, MultiControlNetModel):
681
+ if isinstance(prompt, list):
682
+ logger.warning(
683
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
684
+ " prompts. The conditionings will be fixed across the prompts."
685
+ )
686
+
687
+ # Check `image`
688
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
689
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
690
+ )
691
+ if (
692
+ isinstance(self.controlnet, ControlNetModel)
693
+ or is_compiled
694
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
695
+ ):
696
+ self.check_image(image, prompt, prompt_embeds)
697
+ elif (
698
+ isinstance(self.controlnet, MultiControlNetModel)
699
+ or is_compiled
700
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
701
+ ):
702
+ if not isinstance(image, list):
703
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
704
+
705
+ # When `image` is a nested list:
706
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
707
+ elif any(isinstance(i, list) for i in image):
708
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
709
+ elif len(image) != len(self.controlnet.nets):
710
+ raise ValueError(
711
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
712
+ )
713
+
714
+ for image_ in image:
715
+ self.check_image(image_, prompt, prompt_embeds)
716
+ else:
717
+ assert False
718
+
719
+ # Check `controlnet_conditioning_scale`
720
+ if (
721
+ isinstance(self.controlnet, ControlNetModel)
722
+ or is_compiled
723
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
724
+ ):
725
+ if not isinstance(controlnet_conditioning_scale, float):
726
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
727
+ elif (
728
+ isinstance(self.controlnet, MultiControlNetModel)
729
+ or is_compiled
730
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
731
+ ):
732
+ if isinstance(controlnet_conditioning_scale, list):
733
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
734
+ raise ValueError("A single batch of multiple conditionings are supported at the moment.")
735
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
736
+ self.controlnet.nets
737
+ ):
738
+ raise ValueError(
739
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
740
+ " the same length as the number of controlnets"
741
+ )
742
+ else:
743
+ assert False
744
+
745
+ def check_image(self, image, prompt, prompt_embeds):
746
+ image_is_pil = isinstance(image, PIL.Image.Image)
747
+ image_is_tensor = isinstance(image, torch.Tensor)
748
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
749
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
750
+
751
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
752
+ raise TypeError(
753
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
754
+ )
755
+
756
+ if image_is_pil:
757
+ image_batch_size = 1
758
+ elif image_is_tensor:
759
+ image_batch_size = image.shape[0]
760
+ elif image_is_pil_list:
761
+ image_batch_size = len(image)
762
+ elif image_is_tensor_list:
763
+ image_batch_size = len(image)
764
+
765
+ if prompt is not None and isinstance(prompt, str):
766
+ prompt_batch_size = 1
767
+ elif prompt is not None and isinstance(prompt, list):
768
+ prompt_batch_size = len(prompt)
769
+ elif prompt_embeds is not None:
770
+ prompt_batch_size = prompt_embeds.shape[0]
771
+
772
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
773
+ raise ValueError(
774
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
775
+ )
776
+
777
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
778
+ def prepare_control_image(
779
+ self,
780
+ image,
781
+ width,
782
+ height,
783
+ batch_size,
784
+ num_images_per_prompt,
785
+ device,
786
+ dtype,
787
+ do_classifier_free_guidance=False,
788
+ guess_mode=False,
789
+ ):
790
+ if not isinstance(image, torch.Tensor):
791
+ if isinstance(image, PIL.Image.Image):
792
+ image = [image]
793
+
794
+ if isinstance(image[0], PIL.Image.Image):
795
+ images = []
796
+
797
+ for image_ in image:
798
+ image_ = image_.convert("RGB")
799
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
800
+ image_ = np.array(image_)
801
+ image_ = image_[None, :]
802
+ images.append(image_)
803
+
804
+ image = images
805
+
806
+ image = np.concatenate(image, axis=0)
807
+ image = np.array(image).astype(np.float32) / 255.0
808
+ image = image.transpose(0, 3, 1, 2)
809
+ image = torch.from_numpy(image)
810
+ elif isinstance(image[0], torch.Tensor):
811
+ image = torch.cat(image, dim=0)
812
+
813
+ image_batch_size = image.shape[0]
814
+
815
+ if image_batch_size == 1:
816
+ repeat_by = batch_size
817
+ else:
818
+ # image batch size is the same as prompt batch size
819
+ repeat_by = num_images_per_prompt
820
+
821
+ image = image.repeat_interleave(repeat_by, dim=0)
822
+
823
+ image = image.to(device=device, dtype=dtype)
824
+
825
+ if do_classifier_free_guidance and not guess_mode:
826
+ image = torch.cat([image] * 2)
827
+
828
+ return image
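+ # note: PIL inputs are resized to (width, height) and scaled to [0, 1]; the batch is then
+ # repeated to match the requested batch size and doubled for classifier-free guidance
+ # unless guess_mode is enabled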
829
+
830
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents
831
+ def prepare_latents(
832
+ self,
833
+ batch_size,
834
+ num_channels_latents,
835
+ height,
836
+ width,
837
+ dtype,
838
+ device,
839
+ generator,
840
+ latents=None,
841
+ image=None,
842
+ timestep=None,
843
+ is_strength_max=True,
844
+ return_noise=False,
845
+ return_image_latents=False,
846
+ ):
847
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
848
+ if isinstance(generator, list) and len(generator) != batch_size:
849
+ raise ValueError(
850
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
851
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
852
+ )
853
+
854
+ if (image is None or timestep is None) and not is_strength_max:
855
+ raise ValueError(
856
+ "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise."
857
+ "However, either the image or the noise timestep has not been provided."
858
+ )
859
+
860
+ if return_image_latents or (latents is None and not is_strength_max):
861
+ image = image.to(device=device, dtype=dtype)
862
+ image_latents = self._encode_vae_image(image=image, generator=generator)
863
+
864
+ if latents is None:
865
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
866
+ # if strength is 1, initialise the latents to pure noise, otherwise to image + noise
867
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
868
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
869
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
870
+ else:
871
+ noise = latents.to(device)
872
+ latents = noise * self.scheduler.init_noise_sigma
873
+
874
+ outputs = (latents,)
875
+
876
+ if return_noise:
877
+ outputs += (noise,)
878
+
879
+ if return_image_latents:
880
+ outputs += (image_latents,)
881
+
882
+ return outputs
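+ # note: when `latents` is not passed, the starting latents are pure noise scaled by
+ # `init_noise_sigma` if strength == 1, otherwise the encoded image latents noised to `timestep`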
883
+
884
+ def _default_height_width(self, height, width, image):
885
+ # NOTE: It is possible that a list of images has different
886
+ # dimensions for each image, so just checking the first image
887
+ # is not _exactly_ correct, but it is simple.
888
+ while isinstance(image, list):
889
+ image = image[0]
890
+
891
+ if height is None:
892
+ if isinstance(image, PIL.Image.Image):
893
+ height = image.height
894
+ elif isinstance(image, torch.Tensor):
895
+ height = image.shape[2]
896
+
897
+ height = (height // 8) * 8 # round down to nearest multiple of 8
898
+
899
+ if width is None:
900
+ if isinstance(image, PIL.Image.Image):
901
+ width = image.width
902
+ elif isinstance(image, torch.Tensor):
903
+ width = image.shape[3]
904
+
905
+ width = (width // 8) * 8 # round down to nearest multiple of 8
906
+
907
+ return height, width
908
+
909
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
910
+ def prepare_mask_latents(
911
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
912
+ ):
913
+ # resize the mask to latents shape as we concatenate the mask to the latents
914
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
915
+ # and half precision
916
+ mask = torch.nn.functional.interpolate(
917
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
918
+ )
919
+ mask = mask.to(device=device, dtype=dtype)
920
+
921
+ masked_image = masked_image.to(device=device, dtype=dtype)
922
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
923
+
924
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
925
+ if mask.shape[0] < batch_size:
926
+ if not batch_size % mask.shape[0] == 0:
927
+ raise ValueError(
928
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
929
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
930
+ " of masks that you pass is divisible by the total requested batch size."
931
+ )
932
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
933
+ if masked_image_latents.shape[0] < batch_size:
934
+ if not batch_size % masked_image_latents.shape[0] == 0:
935
+ raise ValueError(
936
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
937
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
938
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
939
+ )
940
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
941
+
942
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
943
+ masked_image_latents = (
944
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
945
+ )
946
+
947
+ # aligning device to prevent device errors when concatenating it with the latent model input
948
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
949
+ return mask, masked_image_latents
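+ # note: the mask is downsampled to the latent resolution (height // vae_scale_factor,
+ # width // vae_scale_factor) so it can later be concatenated with the UNet latent input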
950
+
951
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline._encode_vae_image
952
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
953
+ if isinstance(generator, list):
954
+ image_latents = [
955
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
956
+ for i in range(image.shape[0])
957
+ ]
958
+ image_latents = torch.cat(image_latents, dim=0)
959
+ else:
960
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
961
+
962
+ image_latents = self.vae.config.scaling_factor * image_latents
963
+
964
+ return image_latents
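+ # note: latents are multiplied by `vae.config.scaling_factor` (0.18215 for the standard
+ # SD v1 VAE) so they are on the scale the UNet was trained with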
965
+
966
+ # override DiffusionPipeline
967
+ def save_pretrained(
968
+ self,
969
+ save_directory: Union[str, os.PathLike],
970
+ safe_serialization: bool = False,
971
+ variant: Optional[str] = None,
972
+ ):
973
+ if isinstance(self.controlnet, ControlNetModel):
974
+ super().save_pretrained(save_directory, safe_serialization, variant)
975
+ else:
976
+ raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
977
+
978
+ @torch.no_grad()
979
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
980
+ def __call__(
981
+ self,
982
+ prompt: Union[str, List[str]] = None,
983
+ image: Union[torch.Tensor, PIL.Image.Image] = None,
984
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
985
+ control_image: Union[
986
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
987
+ ] = None,
988
+ height: Optional[int] = None,
989
+ width: Optional[int] = None,
990
+ strength: float = 1.0,
991
+ num_inference_steps: int = 50,
992
+ guidance_scale: float = 7.5,
993
+ negative_prompt: Optional[Union[str, List[str]]] = None,
994
+ num_images_per_prompt: Optional[int] = 1,
995
+ eta: float = 0.0,
996
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
997
+ latents: Optional[torch.FloatTensor] = None,
998
+ prompt_embeds: Optional[torch.FloatTensor] = None,
999
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1000
+ output_type: Optional[str] = "pil",
1001
+ return_dict: bool = True,
1002
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
1003
+ callback_steps: int = 1,
1004
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1005
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.5,
1006
+ guess_mode: bool = False,
1007
+ ):
1008
+ r"""
1009
+ Function invoked when calling the pipeline for generation.
1010
+
1011
+ Args:
1012
+ prompt (`str` or `List[str]`, *optional*):
1013
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
1014
+ instead.
1015
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
1016
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
1017
+ The ControlNet input condition. ControlNet uses this input condition to generate guidance to the UNet. If
1018
+ the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
1019
+ also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
1020
+ height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
1021
+ specified in init, images must be passed as a list such that each element of the list can be correctly
1022
+ batched for input to a single controlnet.
1023
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1024
+ The height in pixels of the generated image.
1025
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1026
+ The width in pixels of the generated image.
1027
+ strength (`float`, *optional*, defaults to 1.):
1028
+ Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be
1029
+ between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the
1030
+ `strength`. The number of denoising steps depends on the amount of noise initially added. When
1031
+ `strength` is 1, added noise will be maximum and the denoising process will run for the full number of
1032
+ iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked
1033
+ portion of the reference `image`.
1034
+ num_inference_steps (`int`, *optional*, defaults to 50):
1035
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1036
+ expense of slower inference.
1037
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1038
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1039
+ `guidance_scale` is defined as `w` of equation 2 of [Imagen
1040
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1041
+ 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1042
+ usually at the expense of lower image quality.
1043
+ negative_prompt (`str` or `List[str]`, *optional*):
1044
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1045
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1046
+ less than `1`).
1047
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1048
+ The number of images to generate per prompt.
1049
+ eta (`float`, *optional*, defaults to 0.0):
1050
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1051
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1052
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1053
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1054
+ to make generation deterministic.
1055
+ latents (`torch.FloatTensor`, *optional*):
1056
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1057
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1058
+ tensor will be generated by sampling using the supplied random `generator`.
1059
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1060
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1061
+ provided, text embeddings will be generated from `prompt` input argument.
1062
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1063
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1064
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1065
+ argument.
1066
+ output_type (`str`, *optional*, defaults to `"pil"`):
1067
+ The output format of the generated image. Choose between
1068
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1069
+ return_dict (`bool`, *optional*, defaults to `True`):
1070
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1071
+ plain tuple.
1072
+ callback (`Callable`, *optional*):
1073
+ A function that will be called every `callback_steps` steps during inference. The function will be
1074
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1075
+ callback_steps (`int`, *optional*, defaults to 1):
1076
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1077
+ called at every step.
1078
+ cross_attention_kwargs (`dict`, *optional*):
1079
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1080
+ `self.processor` in
1081
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1082
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.5):
1083
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1084
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1085
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
1086
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
1087
+ guess_mode (`bool`, *optional*, defaults to `False`):
1088
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
1089
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
1090
+
1091
+ Examples:
1092
+
1093
+ Returns:
1094
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1095
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1096
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1097
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1098
+ (nsfw) content, according to the `safety_checker`.
1099
+ """
1100
+ # 0. Default height and width to unet
1101
+ height, width = self._default_height_width(height, width, image)
1102
+
1103
+ # 1. Check inputs. Raise error if not correct
1104
+ self.check_inputs(
1105
+ prompt,
1106
+ control_image,
1107
+ height,
1108
+ width,
1109
+ callback_steps,
1110
+ negative_prompt,
1111
+ prompt_embeds,
1112
+ negative_prompt_embeds,
1113
+ controlnet_conditioning_scale,
1114
+ )
1115
+
1116
+ # 2. Define call parameters
1117
+ if prompt is not None and isinstance(prompt, str):
1118
+ batch_size = 1
1119
+ elif prompt is not None and isinstance(prompt, list):
1120
+ batch_size = len(prompt)
1121
+ else:
1122
+ batch_size = prompt_embeds.shape[0]
1123
+
1124
+ device = self._execution_device
1125
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1126
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1127
+ # corresponds to doing no classifier free guidance.
1128
+ do_classifier_free_guidance = guidance_scale > 1.0
1129
+
1130
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1131
+
1132
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1133
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1134
+
1135
+ global_pool_conditions = (
1136
+ controlnet.config.global_pool_conditions
1137
+ if isinstance(controlnet, ControlNetModel)
1138
+ else controlnet.nets[0].config.global_pool_conditions
1139
+ )
1140
+ guess_mode = guess_mode or global_pool_conditions
1141
+
1142
+ # 3. Encode input prompt
1143
+ prompt_embeds = self._encode_prompt(
1144
+ prompt,
1145
+ device,
1146
+ num_images_per_prompt,
1147
+ do_classifier_free_guidance,
1148
+ negative_prompt,
1149
+ prompt_embeds=prompt_embeds,
1150
+ negative_prompt_embeds=negative_prompt_embeds,
1151
+ )
1152
+
1153
+ # 4. Prepare image
1154
+ if isinstance(controlnet, ControlNetModel):
1155
+ control_image = self.prepare_control_image(
1156
+ image=control_image,
1157
+ width=width,
1158
+ height=height,
1159
+ batch_size=batch_size * num_images_per_prompt,
1160
+ num_images_per_prompt=num_images_per_prompt,
1161
+ device=device,
1162
+ dtype=controlnet.dtype,
1163
+ do_classifier_free_guidance=do_classifier_free_guidance,
1164
+ guess_mode=guess_mode,
1165
+ )
1166
+ elif isinstance(controlnet, MultiControlNetModel):
1167
+ control_images = []
1168
+
1169
+ for control_image_ in control_image:
1170
+ control_image_ = self.prepare_control_image(
1171
+ image=control_image_,
1172
+ width=width,
1173
+ height=height,
1174
+ batch_size=batch_size * num_images_per_prompt,
1175
+ num_images_per_prompt=num_images_per_prompt,
1176
+ device=device,
1177
+ dtype=controlnet.dtype,
1178
+ do_classifier_free_guidance=do_classifier_free_guidance,
1179
+ guess_mode=guess_mode,
1180
+ )
1181
+
1182
+ control_images.append(control_image_)
1183
+
1184
+ control_image = control_images
1185
+ else:
1186
+ assert False
1187
+
1188
+ # 4.1 Preprocess mask and image - resizes image and mask w.r.t. height and width
1189
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
1190
+ image, mask_image, height, width, return_image=True
1191
+ )
1192
+
1193
+ # 5. Prepare timesteps
1194
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1195
+ timesteps, num_inference_steps = self.get_timesteps(
1196
+ num_inference_steps=num_inference_steps, strength=strength, device=device
1197
+ )
1198
+ # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
1199
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1200
+ # create a boolean to check if the strength is set to 1; if so, initialise the latents with pure noise
1201
+ is_strength_max = strength == 1.0
1202
+
1203
+ # 6. Prepare latent variables
1204
+ num_channels_latents = self.vae.config.latent_channels
1205
+ num_channels_unet = self.unet.config.in_channels
1206
+ return_image_latents = num_channels_unet == 4
1207
+ latents_outputs = self.prepare_latents(
1208
+ batch_size * num_images_per_prompt,
1209
+ num_channels_latents,
1210
+ height,
1211
+ width,
1212
+ prompt_embeds.dtype,
1213
+ device,
1214
+ generator,
1215
+ latents,
1216
+ image=init_image,
1217
+ timestep=latent_timestep,
1218
+ is_strength_max=is_strength_max,
1219
+ return_noise=True,
1220
+ return_image_latents=return_image_latents,
1221
+ )
1222
+
1223
+ if return_image_latents:
1224
+ latents, noise, image_latents = latents_outputs
1225
+ else:
1226
+ latents, noise = latents_outputs
1227
+
1228
+ # 7. Prepare mask latent variables
1229
+ mask, masked_image_latents = self.prepare_mask_latents(
1230
+ mask,
1231
+ masked_image,
1232
+ batch_size * num_images_per_prompt,
1233
+ height,
1234
+ width,
1235
+ prompt_embeds.dtype,
1236
+ device,
1237
+ generator,
1238
+ do_classifier_free_guidance,
1239
+ )
1240
+
1241
+ # 7.1 Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1242
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1243
+
1244
+ # 8. Denoising loop
1245
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1246
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1247
+ for i, t in enumerate(timesteps):
1248
+ # expand the latents if we are doing classifier free guidance
1249
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1250
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1251
+
1252
+ # controlnet(s) inference
1253
+ if guess_mode and do_classifier_free_guidance:
1254
+ # Infer ControlNet only for the conditional batch.
1255
+ control_model_input = latents
1256
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1257
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1258
+ else:
1259
+ control_model_input = latent_model_input
1260
+ controlnet_prompt_embeds = prompt_embeds
1261
+
1262
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1263
+ control_model_input,
1264
+ t,
1265
+ encoder_hidden_states=controlnet_prompt_embeds,
1266
+ controlnet_cond=control_image,
1267
+ conditioning_scale=controlnet_conditioning_scale,
1268
+ guess_mode=guess_mode,
1269
+ return_dict=False,
1270
+ )
1271
+
1272
+ if guess_mode and do_classifier_free_guidance:
1273
+ # Inferred ControlNet only for the conditional batch.
1274
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1275
+ # add 0 to the unconditional batch to keep it unchanged.
1276
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1277
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1278
+
1279
+ # predict the noise residual
1280
+ if num_channels_unet == 9:
1281
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
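+ # note: a 9-channel inpainting UNet expects 4 noisy latent channels + 1 mask channel
+ # + 4 masked-image latent channels, concatenated along the channel dimension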
1282
+
1283
+ noise_pred = self.unet(
1284
+ latent_model_input,
1285
+ t,
1286
+ encoder_hidden_states=prompt_embeds,
1287
+ cross_attention_kwargs=cross_attention_kwargs,
1288
+ down_block_additional_residuals=down_block_res_samples,
1289
+ mid_block_additional_residual=mid_block_res_sample,
1290
+ return_dict=False,
1291
+ )[0]
1292
+
1293
+ # perform guidance
1294
+ if do_classifier_free_guidance:
1295
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1296
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1297
+
1298
+ # compute the previous noisy sample x_t -> x_t-1
1299
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1300
+
1301
+ if num_channels_unet == 4:
1302
+ init_latents_proper = image_latents[:1]
1303
+ init_mask = mask[:1]
1304
+
1305
+ if i < len(timesteps) - 1:
1306
+ init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
1307
+
1308
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1309
+
1310
+ # call the callback, if provided
1311
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1312
+ progress_bar.update()
1313
+ if callback is not None and i % callback_steps == 0:
1314
+ callback(i, t, latents)
1315
+
1316
+ # If we do sequential model offloading, let's offload unet and controlnet
1317
+ # manually for max memory savings
1318
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1319
+ self.unet.to("cpu")
1320
+ self.controlnet.to("cpu")
1321
+ torch.cuda.empty_cache()
1322
+
1323
+ if not output_type == "latent":
1324
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1325
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1326
+ else:
1327
+ image = latents
1328
+ has_nsfw_concept = None
1329
+
1330
+ if has_nsfw_concept is None:
1331
+ do_denormalize = [True] * image.shape[0]
1332
+ else:
1333
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1334
+
1335
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1336
+
1337
+ # Offload last model to CPU
1338
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1339
+ self.final_offload_hook.offload()
1340
+
1341
+ if not return_dict:
1342
+ return (image, has_nsfw_concept)
1343
+
1344
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
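For 4-channel UNets, the denoising loop above pins the unmasked region back to the (re-noised) original image at every step. A minimal standalone sketch of that blending step, using hypothetical tensors on a 64x64 latent grid:

```py
import torch

# hypothetical latent-space tensors (batch 1, 4 channels, 64x64 latent grid)
latents = torch.randn(1, 4, 64, 64)              # current denoised latents
init_latents_proper = torch.randn(1, 4, 64, 64)  # original image latents, re-noised to timestep t
init_mask = torch.zeros(1, 1, 64, 64)
init_mask[..., 16:48, 16:48] = 1.0               # 1 = region to inpaint, 0 = keep the original

# outside the mask the original image wins, inside the mask the freshly denoised latents win
latents = (1 - init_mask) * init_latents_proper + init_mask * latents
```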
stable_diffusion_controlnet_inpaint_img2img.py ADDED
@@ -0,0 +1,1323 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ import inspect
17
+ import os
18
+ import warnings
19
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
20
+
21
+ import numpy as np
22
+ import PIL.Image
23
+ import torch
24
+ import torch.nn.functional as F
25
+ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
26
+
27
+ from ...image_processor import VaeImageProcessor
28
+ from ...loaders import TextualInversionLoaderMixin
29
+ from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
30
+ from ...schedulers import KarrasDiffusionSchedulers
31
+
32
+ from ...utils import (
33
+ PIL_INTERPOLATION,
34
+ is_accelerate_available,
35
+ is_accelerate_version,
36
+ is_compiled_module,
37
+ logging,
38
+ randn_tensor,
39
+ replace_example_docstring,
40
+ )
41
+ from ..pipeline_utils import DiffusionPipeline
42
+ from ..stable_diffusion import StableDiffusionPipelineOutput
43
+ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
44
+ from .multicontrolnet import MultiControlNetModel
45
+
46
+
47
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
48
+
49
+
50
+ EXAMPLE_DOC_STRING = """
51
+ Examples:
52
+ ```py
53
+ >>> # !pip install opencv-python transformers accelerate
54
+ >>> from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
55
+ >>> from diffusers.utils import load_image
56
+ >>> import numpy as np
57
+ >>> import torch
58
+
59
+ >>> import cv2
60
+ >>> from PIL import Image
61
+
62
+ >>> # download an image
63
+ >>> image = load_image(
64
+ ... "https://hf.co/datasets/huggingface/documentation-images/resolve/main/diffusers/input_image_vermeer.png"
65
+ ... )
66
+ >>> np_image = np.array(image)
67
+
68
+ >>> # get canny image
69
+ >>> np_image = cv2.Canny(np_image, 100, 200)
70
+ >>> np_image = np_image[:, :, None]
71
+ >>> np_image = np.concatenate([np_image, np_image, np_image], axis=2)
72
+ >>> canny_image = Image.fromarray(np_image)
73
+
74
+ >>> # load control net and stable diffusion v1-5
75
+ >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
76
+ >>> pipe = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
77
+ ... "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
78
+ ... )
79
+
80
+ >>> # speed up diffusion process with faster scheduler and memory optimization
81
+ >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
82
+ >>> pipe.enable_model_cpu_offload()
83
+
84
+ >>> # generate image
85
+ >>> generator = torch.manual_seed(0)
86
+ >>> image = pipe(
87
+ ... "futuristic-looking woman",
88
+ ... num_inference_steps=20,
89
+ ... generator=generator,
90
+ ... image=image,
91
+ ... control_image=canny_image,
92
+ ... ).images[0]
93
+ ```
94
+ """
95
+
96
+
97
+ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
98
+ """
99
+ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
100
+ converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
101
+ ``image`` and ``1`` for the ``mask``.
102
+
103
+ The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
104
+ binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
105
+
106
+ Args:
107
+ image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
108
+ It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
109
+ ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
110
+ mask (Union[np.array, PIL.Image, torch.Tensor]): The mask to apply to the image, i.e. regions to inpaint.
111
+ It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
112
+ ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
113
+
114
+
115
+ Raises:
116
+ ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
117
+ should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
118
+ TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
119
+ (or the other way around).
120
+
121
+ Returns:
122
+ tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
123
+ dimensions: ``batch x channels x height x width``.
124
+ """
125
+
126
+ if image is None:
127
+ raise ValueError("`image` input cannot be undefined.")
128
+
129
+ if mask is None:
130
+ raise ValueError("`mask_image` input cannot be undefined.")
131
+
132
+ if isinstance(image, torch.Tensor):
133
+ if not isinstance(mask, torch.Tensor):
134
+ raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
135
+
136
+ # Batch single image
137
+ if image.ndim == 3:
138
+ assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
139
+ image = image.unsqueeze(0)
140
+
141
+ # Batch and add channel dim for single mask
142
+ if mask.ndim == 2:
143
+ mask = mask.unsqueeze(0).unsqueeze(0)
144
+
145
+ # Batch single mask or add channel dim
146
+ if mask.ndim == 3:
147
+ # Single batched mask, no channel dim or single mask not batched but channel dim
148
+ if mask.shape[0] == 1:
149
+ mask = mask.unsqueeze(0)
150
+
151
+ # Batched masks no channel dim
152
+ else:
153
+ mask = mask.unsqueeze(1)
154
+
155
+ assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
156
+ assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
157
+ assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"
158
+
159
+ # Check image is in [-1, 1]
160
+ if image.min() < -1 or image.max() > 1:
161
+ raise ValueError("Image should be in [-1, 1] range")
162
+
163
+ # Check mask is in [0, 1]
164
+ if mask.min() < 0 or mask.max() > 1:
165
+ raise ValueError("Mask should be in [0, 1] range")
166
+
167
+ # Binarize mask
168
+ mask[mask < 0.5] = 0
169
+ mask[mask >= 0.5] = 1
170
+
171
+ # Image as float32
172
+ image = image.to(dtype=torch.float32)
173
+ elif isinstance(mask, torch.Tensor):
174
+ raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
175
+ else:
176
+ # preprocess image
177
+ if isinstance(image, (PIL.Image.Image, np.ndarray)):
178
+ image = [image]
179
+ if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
180
+ # resize all images w.r.t. the passed height and width
181
+ image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
182
+ image = [np.array(i.convert("RGB"))[None, :] for i in image]
183
+ image = np.concatenate(image, axis=0)
184
+ elif isinstance(image, list) and isinstance(image[0], np.ndarray):
185
+ image = np.concatenate([i[None, :] for i in image], axis=0)
186
+
187
+ image = image.transpose(0, 3, 1, 2)
188
+ image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
189
+
190
+ # preprocess mask
191
+ if isinstance(mask, (PIL.Image.Image, np.ndarray)):
192
+ mask = [mask]
193
+
194
+ if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
195
+ mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
196
+ mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
197
+ mask = mask.astype(np.float32) / 255.0
198
+ elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
199
+ mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
200
+
201
+ mask[mask < 0.5] = 0
202
+ mask[mask >= 0.5] = 1
203
+ mask = torch.from_numpy(mask)
204
+
205
+ masked_image = image * (mask < 0.5)
206
+
207
+ # n.b. ensure backwards compatibility as old function does not return image
208
+ if return_image:
209
+ return mask, masked_image, image
210
+
211
+ return mask, masked_image
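A hedged usage sketch of the helper above (assuming it is defined or imported in scope), showing the tensor shapes it produces for PIL inputs:

```py
from PIL import Image

# hypothetical inputs: a plain 512x512 image and a mask marking a square region to inpaint
img = Image.new("RGB", (512, 512), "white")
msk = Image.new("L", (512, 512), 0)
msk.paste(255, (128, 128, 384, 384))

mask, masked_image, image = prepare_mask_and_masked_image(img, msk, 512, 512, return_image=True)
print(mask.shape, masked_image.shape, image.shape)
# torch.Size([1, 1, 512, 512]) torch.Size([1, 3, 512, 512]) torch.Size([1, 3, 512, 512])
```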
212
+
213
+
214
+ class StableDiffusionControlNetImg2ImgPipeline(DiffusionPipeline, TextualInversionLoaderMixin):
215
+ r"""
216
+ Pipeline for image-to-image generation with inpainting using Stable Diffusion with ControlNet guidance.
217
+
218
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
219
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
220
+
221
+ In addition the pipeline inherits the following loading methods:
222
+ - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
223
+
224
+ Args:
225
+ vae ([`AutoencoderKL`]):
226
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
227
+ text_encoder ([`CLIPTextModel`]):
228
+ Frozen text-encoder. Stable Diffusion uses the text portion of
229
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
230
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
231
+ tokenizer (`CLIPTokenizer`):
232
+ Tokenizer of class
233
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
234
+ unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
235
+ controlnet ([`ControlNetModel`] or `List[ControlNetModel]`):
236
+ Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets
237
+ as a list, the outputs from each ControlNet are added together to create one combined additional
238
+ conditioning.
239
+ scheduler ([`SchedulerMixin`]):
240
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
241
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
242
+ safety_checker ([`StableDiffusionSafetyChecker`]):
243
+ Classification module that estimates whether generated images could be considered offensive or harmful.
244
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
245
+ feature_extractor ([`CLIPImageProcessor`]):
246
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
247
+ """
248
+ _optional_components = ["safety_checker", "feature_extractor"]
249
+
250
+ def __init__(
251
+ self,
252
+ vae: AutoencoderKL,
253
+ text_encoder: CLIPTextModel,
254
+ tokenizer: CLIPTokenizer,
255
+ unet: UNet2DConditionModel,
256
+ controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
257
+ scheduler: KarrasDiffusionSchedulers,
258
+ safety_checker: StableDiffusionSafetyChecker,
259
+ feature_extractor: CLIPImageProcessor,
260
+ requires_safety_checker: bool = True,
261
+ ):
262
+ super().__init__()
263
+
264
+ if safety_checker is None and requires_safety_checker:
265
+ logger.warning(
266
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
267
+ " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
268
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
269
+ " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
270
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
271
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
272
+ )
273
+
274
+ if safety_checker is not None and feature_extractor is None:
275
+ raise ValueError(
276
+ f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
277
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
278
+ )
279
+
280
+ if isinstance(controlnet, (list, tuple)):
281
+ controlnet = MultiControlNetModel(controlnet)
282
+
283
+ self.register_modules(
284
+ vae=vae,
285
+ text_encoder=text_encoder,
286
+ tokenizer=tokenizer,
287
+ unet=unet,
288
+ controlnet=controlnet,
289
+ scheduler=scheduler,
290
+ safety_checker=safety_checker,
291
+ feature_extractor=feature_extractor,
292
+ )
293
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
294
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
295
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
296
+
297
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
298
+ def enable_vae_slicing(self):
299
+ r"""
300
+ Enable sliced VAE decoding.
301
+
302
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
303
+ steps. This is useful to save some memory and allow larger batch sizes.
304
+ """
305
+ self.vae.enable_slicing()
306
+
307
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
308
+ def disable_vae_slicing(self):
309
+ r"""
310
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
311
+ computing decoding in one step.
312
+ """
313
+ self.vae.disable_slicing()
314
+
315
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
316
+ def enable_vae_tiling(self):
317
+ r"""
318
+ Enable tiled VAE decoding.
319
+
320
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
321
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
322
+ """
323
+ self.vae.enable_tiling()
324
+
325
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
326
+ def disable_vae_tiling(self):
327
+ r"""
328
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
329
+ computing decoding in one step.
330
+ """
331
+ self.vae.disable_tiling()
332
+
333
+ def enable_sequential_cpu_offload(self, gpu_id=0):
334
+ r"""
335
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
336
+ text_encoder, vae, controlnet, and safety checker have their state dicts saved to CPU and then are moved to a
337
+ `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
338
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
339
+ `enable_model_cpu_offload`, but performance is lower.
340
+ """
341
+ if is_accelerate_available():
342
+ from accelerate import cpu_offload
343
+ else:
344
+ raise ImportError("Please install accelerate via `pip install accelerate`")
345
+
346
+ device = torch.device(f"cuda:{gpu_id}")
347
+
348
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.controlnet]:
349
+ cpu_offload(cpu_offloaded_model, device)
350
+
351
+ if self.safety_checker is not None:
352
+ cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
353
+
354
+ def enable_model_cpu_offload(self, gpu_id=0):
355
+ r"""
356
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
357
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
358
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
359
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
360
+ """
361
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
362
+ from accelerate import cpu_offload_with_hook
363
+ else:
364
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
365
+
366
+ device = torch.device(f"cuda:{gpu_id}")
367
+
368
+ hook = None
369
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
370
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
371
+
372
+ if self.safety_checker is not None:
373
+ # the safety checker can offload the vae again
374
+ _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
375
+
376
+ # the controlnet hook has to be manually offloaded as it alternates with the unet
377
+ cpu_offload_with_hook(self.controlnet, device)
378
+
379
+ # We'll offload the last model manually.
380
+ self.final_offload_hook = hook
381
+
382
+ @property
383
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
384
+ def _execution_device(self):
385
+ r"""
386
+ Returns the device on which the pipeline's models will be executed. After calling
387
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
388
+ hooks.
389
+ """
390
+ if not hasattr(self.unet, "_hf_hook"):
391
+ return self.device
392
+ for module in self.unet.modules():
393
+ if (
394
+ hasattr(module, "_hf_hook")
395
+ and hasattr(module._hf_hook, "execution_device")
396
+ and module._hf_hook.execution_device is not None
397
+ ):
398
+ return torch.device(module._hf_hook.execution_device)
399
+ return self.device
400
+
401
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
402
+ def _encode_prompt(
403
+ self,
404
+ prompt,
405
+ device,
406
+ num_images_per_prompt,
407
+ do_classifier_free_guidance,
408
+ negative_prompt=None,
409
+ prompt_embeds: Optional[torch.FloatTensor] = None,
410
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
411
+ ):
412
+ r"""
413
+ Encodes the prompt into text encoder hidden states.
414
+
415
+ Args:
416
+ prompt (`str` or `List[str]`, *optional*):
417
+ prompt to be encoded
418
+ device: (`torch.device`):
419
+ torch device
420
+ num_images_per_prompt (`int`):
421
+ number of images that should be generated per prompt
422
+ do_classifier_free_guidance (`bool`):
423
+ whether to use classifier free guidance or not
424
+ negative_prompt (`str` or `List[str]`, *optional*):
425
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
426
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
427
+ less than `1`).
428
+ prompt_embeds (`torch.FloatTensor`, *optional*):
429
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
430
+ provided, text embeddings will be generated from `prompt` input argument.
431
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
432
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
433
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
434
+ argument.
435
+ """
436
+ if prompt is not None and isinstance(prompt, str):
437
+ batch_size = 1
438
+ elif prompt is not None and isinstance(prompt, list):
439
+ batch_size = len(prompt)
440
+ else:
441
+ batch_size = prompt_embeds.shape[0]
442
+
443
+ if prompt_embeds is None:
444
+ # textual inversion: process multi-vector tokens if necessary
445
+ if isinstance(self, TextualInversionLoaderMixin):
446
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
447
+
448
+ text_inputs = self.tokenizer(
449
+ prompt,
450
+ padding="max_length",
451
+ max_length=self.tokenizer.model_max_length,
452
+ truncation=True,
453
+ return_tensors="pt",
454
+ )
455
+ text_input_ids = text_inputs.input_ids
456
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
457
+
458
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
459
+ text_input_ids, untruncated_ids
460
+ ):
461
+ removed_text = self.tokenizer.batch_decode(
462
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
463
+ )
464
+ logger.warning(
465
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
466
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
467
+ )
468
+
469
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
470
+ attention_mask = text_inputs.attention_mask.to(device)
471
+ else:
472
+ attention_mask = None
473
+
474
+ prompt_embeds = self.text_encoder(
475
+ text_input_ids.to(device),
476
+ attention_mask=attention_mask,
477
+ )
478
+ prompt_embeds = prompt_embeds[0]
479
+
480
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
481
+
482
+ bs_embed, seq_len, _ = prompt_embeds.shape
483
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
484
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
485
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
486
+
487
+ # get unconditional embeddings for classifier free guidance
488
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
489
+ uncond_tokens: List[str]
490
+ if negative_prompt is None:
491
+ uncond_tokens = [""] * batch_size
492
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
493
+ raise TypeError(
494
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
495
+ f" {type(prompt)}."
496
+ )
497
+ elif isinstance(negative_prompt, str):
498
+ uncond_tokens = [negative_prompt]
499
+ elif batch_size != len(negative_prompt):
500
+ raise ValueError(
501
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
502
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
503
+ " the batch size of `prompt`."
504
+ )
505
+ else:
506
+ uncond_tokens = negative_prompt
507
+
508
+ # textual inversion: process multi-vector tokens if necessary
509
+ if isinstance(self, TextualInversionLoaderMixin):
510
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
511
+
512
+ max_length = prompt_embeds.shape[1]
513
+ uncond_input = self.tokenizer(
514
+ uncond_tokens,
515
+ padding="max_length",
516
+ max_length=max_length,
517
+ truncation=True,
518
+ return_tensors="pt",
519
+ )
520
+
521
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
522
+ attention_mask = uncond_input.attention_mask.to(device)
523
+ else:
524
+ attention_mask = None
525
+
526
+ negative_prompt_embeds = self.text_encoder(
527
+ uncond_input.input_ids.to(device),
528
+ attention_mask=attention_mask,
529
+ )
530
+ negative_prompt_embeds = negative_prompt_embeds[0]
531
+
532
+ if do_classifier_free_guidance:
533
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
534
+ seq_len = negative_prompt_embeds.shape[1]
535
+
536
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
537
+
538
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
539
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
540
+
541
+ # For classifier free guidance, we need to do two forward passes.
542
+ # Here we concatenate the unconditional and text embeddings into a single batch
543
+ # to avoid doing two forward passes
544
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
545
+
546
+ return prompt_embeds
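Under classifier-free guidance the returned embeddings stack the unconditional batch before the conditional one, which is why the denoising loop later splits predictions with `chunk(2)` in that order. A minimal sketch with dummy tensors (the 77x768 shape assumes an SD v1-style text encoder):

```py
import torch

negative_prompt_embeds = torch.zeros(2, 77, 768)  # dummy unconditional embeddings
prompt_embeds = torch.ones(2, 77, 768)            # dummy conditional embeddings

# same concatenation order as above: [unconditional, conditional]
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

noise_pred = prompt_embeds  # stand-in for a UNet output that keeps the same batch layout
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
print(noise_pred_uncond.mean().item(), noise_pred_text.mean().item())  # 0.0 1.0
```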
547
+
548
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
549
+ def run_safety_checker(self, image, device, dtype):
550
+ if self.safety_checker is None:
551
+ has_nsfw_concept = None
552
+ else:
553
+ if torch.is_tensor(image):
554
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
555
+ else:
556
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
557
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
558
+ image, has_nsfw_concept = self.safety_checker(
559
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
560
+ )
561
+ return image, has_nsfw_concept
562
+
563
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
564
+ def decode_latents(self, latents):
565
+ warnings.warn(
566
+ "The decode_latents method is deprecated and will be removed in a future version. Please"
567
+ " use VaeImageProcessor instead",
568
+ FutureWarning,
569
+ )
570
+ latents = 1 / self.vae.config.scaling_factor * latents
571
+ image = self.vae.decode(latents, return_dict=False)[0]
572
+ image = (image / 2 + 0.5).clamp(0, 1)
573
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
574
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
575
+ return image
576
+
577
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
578
+ def prepare_extra_step_kwargs(self, generator, eta):
579
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
580
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
581
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
582
+ # and should be between [0, 1]
583
+
584
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
585
+ extra_step_kwargs = {}
586
+ if accepts_eta:
587
+ extra_step_kwargs["eta"] = eta
588
+
589
+ # check if the scheduler accepts generator
590
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
591
+ if accepts_generator:
592
+ extra_step_kwargs["generator"] = generator
593
+ return extra_step_kwargs
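The filter above only forwards `eta` and `generator` when the scheduler's `step()` signature declares them, which is why DDIM (the one scheduler that uses `eta`, per the comment) receives it while other schedulers silently do not. A small illustrative check:

```py
import inspect

from diffusers import DDIMScheduler

scheduler = DDIMScheduler()
step_params = set(inspect.signature(scheduler.step).parameters.keys())
print("eta" in step_params, "generator" in step_params)  # True True
```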
594
+
595
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
596
+ def get_timesteps(self, num_inference_steps, strength, device):
597
+ # get the original timestep using init_timestep
598
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
599
+
600
+ t_start = max(num_inference_steps - init_timestep, 0)
601
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
602
+
603
+ return timesteps, num_inference_steps - t_start
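`strength` only selects how much of the schedule is actually run; a quick sketch of the arithmetic above, without instantiating a scheduler:

```py
# with 50 inference steps and strength 0.5, denoising starts halfway through the schedule
num_inference_steps, strength = 50, 0.5
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 25
t_start = max(num_inference_steps - init_timestep, 0)                          # 25
print(num_inference_steps - t_start)  # 25 -> only the last 25 timesteps are used
```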
604
+
605
+ def check_inputs(
606
+ self,
607
+ prompt,
608
+ image,
609
+ height,
610
+ width,
611
+ callback_steps,
612
+ negative_prompt=None,
613
+ prompt_embeds=None,
614
+ negative_prompt_embeds=None,
615
+ controlnet_conditioning_scale=1.0,
616
+ ):
617
+ if height % 8 != 0 or width % 8 != 0:
618
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
619
+
620
+ if (callback_steps is None) or (
621
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
622
+ ):
623
+ raise ValueError(
624
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
625
+ f" {type(callback_steps)}."
626
+ )
627
+
628
+ if prompt is not None and prompt_embeds is not None:
629
+ raise ValueError(
630
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
631
+ " only forward one of the two."
632
+ )
633
+ elif prompt is None and prompt_embeds is None:
634
+ raise ValueError(
635
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
636
+ )
637
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
638
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
639
+
640
+ if negative_prompt is not None and negative_prompt_embeds is not None:
641
+ raise ValueError(
642
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
643
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
644
+ )
645
+
646
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
647
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
648
+ raise ValueError(
649
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
650
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
651
+ f" {negative_prompt_embeds.shape}."
652
+ )
653
+
654
+ # `prompt` needs more sophisticated handling when there are multiple
655
+ # conditionings.
656
+ if isinstance(self.controlnet, MultiControlNetModel):
657
+ if isinstance(prompt, list):
658
+ logger.warning(
659
+ f"You have {len(self.controlnet.nets)} ControlNets and you have passed {len(prompt)}"
660
+ " prompts. The conditionings will be fixed across the prompts."
661
+ )
662
+
663
+ # Check `image`
664
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
665
+ self.controlnet, torch._dynamo.eval_frame.OptimizedModule
666
+ )
667
+ if (
668
+ isinstance(self.controlnet, ControlNetModel)
669
+ or is_compiled
670
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
671
+ ):
672
+ self.check_image(image, prompt, prompt_embeds)
673
+ elif (
674
+ isinstance(self.controlnet, MultiControlNetModel)
675
+ or is_compiled
676
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
677
+ ):
678
+ if not isinstance(image, list):
679
+ raise TypeError("For multiple controlnets: `image` must be type `list`")
680
+
681
+ # When `image` is a nested list:
682
+ # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]])
683
+ elif any(isinstance(i, list) for i in image):
684
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
685
+ elif len(image) != len(self.controlnet.nets):
686
+ raise ValueError(
687
+ "For multiple controlnets: `image` must have the same length as the number of controlnets."
688
+ )
689
+
690
+ for image_ in image:
691
+ self.check_image(image_, prompt, prompt_embeds)
692
+ else:
693
+ assert False
694
+
695
+ # Check `controlnet_conditioning_scale`
696
+ if (
697
+ isinstance(self.controlnet, ControlNetModel)
698
+ or is_compiled
699
+ and isinstance(self.controlnet._orig_mod, ControlNetModel)
700
+ ):
701
+ if not isinstance(controlnet_conditioning_scale, float):
702
+ raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
703
+ elif (
704
+ isinstance(self.controlnet, MultiControlNetModel)
705
+ or is_compiled
706
+ and isinstance(self.controlnet._orig_mod, MultiControlNetModel)
707
+ ):
708
+ if isinstance(controlnet_conditioning_scale, list):
709
+ if any(isinstance(i, list) for i in controlnet_conditioning_scale):
710
+ raise ValueError("Only a single batch of multiple conditionings is supported at the moment.")
711
+ elif isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
712
+ self.controlnet.nets
713
+ ):
714
+ raise ValueError(
715
+ "For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
716
+ " the same length as the number of controlnets"
717
+ )
718
+ else:
719
+ assert False
720
+
721
+ def check_image(self, image, prompt, prompt_embeds):
722
+ image_is_pil = isinstance(image, PIL.Image.Image)
723
+ image_is_tensor = isinstance(image, torch.Tensor)
724
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
725
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
726
+
727
+ if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
728
+ raise TypeError(
729
+ "image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
730
+ )
731
+
732
+ if image_is_pil:
733
+ image_batch_size = 1
734
+ elif image_is_tensor:
735
+ image_batch_size = image.shape[0]
736
+ elif image_is_pil_list:
737
+ image_batch_size = len(image)
738
+ elif image_is_tensor_list:
739
+ image_batch_size = len(image)
740
+
741
+ if prompt is not None and isinstance(prompt, str):
742
+ prompt_batch_size = 1
743
+ elif prompt is not None and isinstance(prompt, list):
744
+ prompt_batch_size = len(prompt)
745
+ elif prompt_embeds is not None:
746
+ prompt_batch_size = prompt_embeds.shape[0]
747
+
748
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
749
+ raise ValueError(
750
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
751
+ )
752
+
753
+ # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.prepare_image
754
+ def prepare_control_image(
755
+ self,
756
+ image,
757
+ width,
758
+ height,
759
+ batch_size,
760
+ num_images_per_prompt,
761
+ device,
762
+ dtype,
763
+ do_classifier_free_guidance=False,
764
+ guess_mode=False,
765
+ ):
766
+ if not isinstance(image, torch.Tensor):
767
+ if isinstance(image, PIL.Image.Image):
768
+ image = [image]
769
+
770
+ if isinstance(image[0], PIL.Image.Image):
771
+ images = []
772
+
773
+ for image_ in image:
774
+ image_ = image_.convert("RGB")
775
+ image_ = image_.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
776
+ image_ = np.array(image_)
777
+ image_ = image_[None, :]
778
+ images.append(image_)
779
+
780
+ image = images
781
+
782
+ image = np.concatenate(image, axis=0)
783
+ image = np.array(image).astype(np.float32) / 255.0
784
+ image = image.transpose(0, 3, 1, 2)
785
+ image = torch.from_numpy(image)
786
+ elif isinstance(image[0], torch.Tensor):
787
+ image = torch.cat(image, dim=0)
788
+
789
+ image_batch_size = image.shape[0]
790
+
791
+ if image_batch_size == 1:
792
+ repeat_by = batch_size
793
+ else:
794
+ # image batch size is the same as prompt batch size
795
+ repeat_by = num_images_per_prompt
796
+
797
+ image = image.repeat_interleave(repeat_by, dim=0)
798
+
799
+ image = image.to(device=device, dtype=dtype)
800
+
801
+ if do_classifier_free_guidance and not guess_mode:
802
+ image = torch.cat([image] * 2)
803
+
804
+ return image
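After preparation the control image is a float tensor in `[0, 1]`, repeated to the effective batch size and (outside guess mode) duplicated so it lines up with the `[unconditional, conditional]` latent batch. A shape-only sketch with a hypothetical tensor:

```py
import torch

control = torch.rand(1, 3, 512, 512)           # one conditioning image, already in [0, 1]
control = control.repeat_interleave(2, dim=0)  # batch_size * num_images_per_prompt = 2 (assumed)
control = torch.cat([control] * 2)             # duplicated for classifier-free guidance (guess_mode=False)
print(control.shape)                           # torch.Size([4, 3, 512, 512])
```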
805
+
806
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
807
+ def get_timesteps(self, num_inference_steps, strength, device):
808
+ # get the original timestep using init_timestep
809
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
810
+
811
+ t_start = max(num_inference_steps - init_timestep, 0)
812
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
813
+
814
+ return timesteps, num_inference_steps - t_start
815
+
816
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.prepare_latents
817
+ def prepare_latents(
818
+ self,
819
+ batch_size,
820
+ num_channels_latents,
821
+ height,
822
+ width,
823
+ dtype,
824
+ device,
825
+ generator,
826
+ latents=None,
827
+ image=None,
828
+ timestep=None,
829
+ is_strength_max=True,
830
+ return_noise=False,
831
+ return_image_latents=False,
832
+ ):
833
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
834
+ if isinstance(generator, list) and len(generator) != batch_size:
835
+ raise ValueError(
836
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
837
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
838
+ )
839
+
840
+ if (image is None or timestep is None) and not is_strength_max:
841
+ raise ValueError(
842
+ "Since strength < 1, the initial latents have to be initialised as a combination of image + noise."
843
+ "However, either the image or the noise timestep has not been provided."
844
+ )
845
+
846
+ if return_image_latents or (latents is None and not is_strength_max):
847
+ image = image.to(device=device, dtype=dtype)
848
+ image_latents = self._encode_vae_image(image=image, generator=generator)
849
+
850
+ if latents is None:
851
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
852
+ # if strength is 1. then initialise the latents to noise, else initial to image + noise
853
+ latents = noise if is_strength_max else self.scheduler.add_noise(image_latents, noise, timestep)
854
+ # if pure noise then scale the initial latents by the Scheduler's init sigma
855
+ latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
856
+ else:
857
+ noise = latents.to(device)
858
+ latents = noise * self.scheduler.init_noise_sigma
859
+
860
+ outputs = (latents,)
861
+
862
+ if return_noise:
863
+ outputs += (noise,)
864
+
865
+ if return_image_latents:
866
+ outputs += (image_latents,)
867
+
868
+ return outputs
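Depending on `strength`, the initial latents are either pure noise scaled by the scheduler's initial sigma or the encoded init image noised to the chosen start timestep. A standalone sketch of the two branches with a stock scheduler and dummy latents (no VAE involved):

```py
import torch

from diffusers import DDPMScheduler

scheduler = DDPMScheduler()
scheduler.set_timesteps(50)

image_latents = torch.zeros(1, 4, 64, 64)  # stand-in for VAE-encoded init-image latents
noise = torch.randn(1, 4, 64, 64)

# strength == 1.0: start from pure noise scaled by the scheduler's initial sigma
latents_full_strength = noise * scheduler.init_noise_sigma

# strength < 1.0: start from the init-image latents noised to the first kept timestep
latent_timestep = scheduler.timesteps[:1]
latents_partial_strength = scheduler.add_noise(image_latents, noise, latent_timestep)
```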
869
+
870
+ def _default_height_width(self, height, width, image):
871
+ # NOTE: It is possible that the images in a list have different
872
+ # dimensions, so just checking the first image
873
+ # is not _exactly_ correct, but it is simple.
874
+ while isinstance(image, list):
875
+ image = image[0]
876
+
877
+ if height is None:
878
+ if isinstance(image, PIL.Image.Image):
879
+ height = image.height
880
+ elif isinstance(image, torch.Tensor):
881
+ height = image.shape[2]
882
+
883
+ height = (height // 8) * 8 # round down to nearest multiple of 8
884
+
885
+ if width is None:
886
+ if isinstance(image, PIL.Image.Image):
887
+ width = image.width
888
+ elif isinstance(image, torch.Tensor):
889
+ width = image.shape[3]
890
+
891
+ width = (width // 8) * 8 # round down to nearest multiple of 8
892
+
893
+ return height, width
894
+
895
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
896
+ def prepare_mask_latents(
897
+ self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
898
+ ):
899
+ # resize the mask to latents shape as we concatenate the mask to the latents
900
+ # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
901
+ # and half precision
902
+ mask = torch.nn.functional.interpolate(
903
+ mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
904
+ )
905
+ mask = mask.to(device=device, dtype=dtype)
906
+
907
+ masked_image = masked_image.to(device=device, dtype=dtype)
908
+ masked_image_latents = self._encode_vae_image(masked_image, generator=generator)
909
+
910
+ # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
911
+ if mask.shape[0] < batch_size:
912
+ if not batch_size % mask.shape[0] == 0:
913
+ raise ValueError(
914
+ "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
915
+ f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
916
+ " of masks that you pass is divisible by the total requested batch size."
917
+ )
918
+ mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
919
+ if masked_image_latents.shape[0] < batch_size:
920
+ if not batch_size % masked_image_latents.shape[0] == 0:
921
+ raise ValueError(
922
+ "The passed images and the required batch size don't match. Images are supposed to be duplicated"
923
+ f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
924
+ " Make sure the number of images that you pass is divisible by the total requested batch size."
925
+ )
926
+ masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)
927
+
928
+ mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
929
+ masked_image_latents = (
930
+ torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
931
+ )
932
+
933
+ # align the device to prevent device errors when concatenating it with the latent model input
934
+ masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
935
+ return mask, masked_image_latents
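The mask is simply resized (nearest neighbour, the `F.interpolate` default) down to the latent resolution, i.e. height and width divided by the VAE scale factor (8 for SD v1-style VAEs). A minimal sketch:

```py
import torch
import torch.nn.functional as F

vae_scale_factor = 8  # assumption: SD v1-style VAE with three downsampling stages
mask = torch.zeros(1, 1, 512, 512)
mask[..., 128:384, 128:384] = 1.0

latent_mask = F.interpolate(mask, size=(512 // vae_scale_factor, 512 // vae_scale_factor))
print(latent_mask.shape)  # torch.Size([1, 1, 64, 64])
```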
936
+
937
+ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
938
+ if isinstance(generator, list):
939
+ image_latents = [
940
+ self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
941
+ for i in range(image.shape[0])
942
+ ]
943
+ image_latents = torch.cat(image_latents, dim=0)
944
+ else:
945
+ image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
946
+
947
+ image_latents = self.vae.config.scaling_factor * image_latents
948
+
949
+ return image_latents
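Encoding multiplies the sampled latents by `vae.config.scaling_factor`, and the final decode at the end of the denoising loop divides by it again. A hedged sketch of that round trip, assuming the SD v1-5 VAE weights are available locally or from the Hub:

```py
import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae")

image = torch.rand(1, 3, 512, 512) * 2 - 1  # hypothetical image in [-1, 1]
with torch.no_grad():
    latents = vae.encode(image).latent_dist.sample() * vae.config.scaling_factor
    decoded = vae.decode(latents / vae.config.scaling_factor).sample

print(latents.shape, decoded.shape)  # torch.Size([1, 4, 64, 64]) torch.Size([1, 3, 512, 512])
```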
950
+
951
+ # override DiffusionPipeline
952
+ def save_pretrained(
953
+ self,
954
+ save_directory: Union[str, os.PathLike],
955
+ safe_serialization: bool = False,
956
+ variant: Optional[str] = None,
957
+ ):
958
+ if isinstance(self.controlnet, ControlNetModel):
959
+ super().save_pretrained(save_directory, safe_serialization, variant)
960
+ else:
961
+ raise NotImplementedError("Currently, the `save_pretrained()` is not implemented for Multi-ControlNet.")
962
+
963
+ @torch.no_grad()
964
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
965
+ def __call__(
966
+ self,
967
+ prompt: Union[str, List[str]] = None,
968
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
969
+ mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
970
+ control_image: Union[
971
+ torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]
972
+ ] = None,
973
+ height: Optional[int] = None,
974
+ width: Optional[int] = None,
975
+ strength: float = 0.8,
976
+ num_inference_steps: int = 50,
977
+ guidance_scale: float = 7.5,
978
+ negative_prompt: Optional[Union[str, List[str]]] = None,
979
+ num_images_per_prompt: Optional[int] = 1,
980
+ eta: float = 0.0,
981
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
982
+ latents: Optional[torch.FloatTensor] = None,
983
+ prompt_embeds: Optional[torch.FloatTensor] = None,
984
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
985
+ output_type: Optional[str] = "pil",
986
+ return_dict: bool = True,
987
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
988
+ callback_steps: int = 1,
989
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
990
+ controlnet_conditioning_scale: Union[float, List[float]] = 0.8,
991
+ guess_mode: bool = False,
992
+ ):
993
+ r"""
994
+ Function invoked when calling the pipeline for generation.
995
+
996
+ Args:
997
+ prompt (`str` or `List[str]`, *optional*):
998
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
999
+ instead.
1000
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
1001
+ `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
1002
+ The initial image to start the image-to-image / inpainting process from. If the type is specified as
1003
+ `torch.FloatTensor`, it is passed on as is; `PIL.Image.Image` inputs are also accepted. The ControlNet
1004
+ conditioning is passed separately via `control_image`. The dimensions of the output image default to
1005
+ `image`'s dimensions. If height and/or width are passed, `image` is resized according to them. If
1006
+ multiple ControlNets are specified in init, the corresponding `control_image` entries must be passed as a
1007
+ list such that each element of the list can be correctly batched for input to a single controlnet.
1008
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1009
+ The height in pixels of the generated image.
1010
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
1011
+ The width in pixels of the generated image.
1012
+ num_inference_steps (`int`, *optional*, defaults to 50):
1013
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
1014
+ expense of slower inference.
1015
+ guidance_scale (`float`, *optional*, defaults to 7.5):
1016
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
1017
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
1018
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1019
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
1020
+ usually at the expense of lower image quality.
1021
+ negative_prompt (`str` or `List[str]`, *optional*):
1022
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1023
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1024
+ less than `1`).
1025
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
1026
+ The number of images to generate per prompt.
1027
+ eta (`float`, *optional*, defaults to 0.0):
1028
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
1029
+ [`schedulers.DDIMScheduler`], will be ignored for others.
1030
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
1031
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
1032
+ to make generation deterministic.
1033
+ latents (`torch.FloatTensor`, *optional*):
1034
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
1035
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
1036
+ tensor will be generated by sampling using the supplied random `generator`.
1037
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1038
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1039
+ provided, text embeddings will be generated from `prompt` input argument.
1040
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1041
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1042
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1043
+ argument.
1044
+ output_type (`str`, *optional*, defaults to `"pil"`):
1045
+ The output format of the generated image. Choose between
1046
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
1047
+ return_dict (`bool`, *optional*, defaults to `True`):
1048
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
1049
+ plain tuple.
1050
+ callback (`Callable`, *optional*):
1051
+ A function that will be called every `callback_steps` steps during inference. The function will be
1052
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
1053
+ callback_steps (`int`, *optional*, defaults to 1):
1054
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
1055
+ called at every step.
1056
+ cross_attention_kwargs (`dict`, *optional*):
1057
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
1058
+ `self.processor` in
1059
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
1060
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 0.8):
1061
+ The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
1062
+ to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
1063
+ corresponding scale as a list. Note that by default, we use a smaller conditioning scale for inpainting
1064
+ than for [`~StableDiffusionControlNetPipeline.__call__`].
1065
+ guess_mode (`bool`, *optional*, defaults to `False`):
1066
+ In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
1067
+ you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
1068
+
1069
+ Examples:
1070
+
1071
+ Returns:
1072
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
1073
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
1074
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
1075
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
1076
+ (nsfw) content, according to the `safety_checker`.
1077
+ """
1078
+ # 0. Default height and width to unet
1079
+ height, width = self._default_height_width(height, width, image)
1080
+
1081
+ # 1. Check inputs. Raise error if not correct
1082
+ self.check_inputs(
1083
+ prompt,
1084
+ control_image,
1085
+ height,
1086
+ width,
1087
+ callback_steps,
1088
+ negative_prompt,
1089
+ prompt_embeds,
1090
+ negative_prompt_embeds,
1091
+ controlnet_conditioning_scale,
1092
+ )
1093
+
1094
+ # 2. Define call parameters
1095
+ if prompt is not None and isinstance(prompt, str):
1096
+ batch_size = 1
1097
+ elif prompt is not None and isinstance(prompt, list):
1098
+ batch_size = len(prompt)
1099
+ else:
1100
+ batch_size = prompt_embeds.shape[0]
1101
+
1102
+ device = self._execution_device
1103
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
1104
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1105
+ # corresponds to doing no classifier free guidance.
1106
+ do_classifier_free_guidance = guidance_scale > 1.0
1107
+
1108
+ controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
1109
+
1110
+ if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
1111
+ controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
1112
+
1113
+ global_pool_conditions = (
1114
+ controlnet.config.global_pool_conditions
1115
+ if isinstance(controlnet, ControlNetModel)
1116
+ else controlnet.nets[0].config.global_pool_conditions
1117
+ )
1118
+ guess_mode = guess_mode or global_pool_conditions
1119
+
1120
+ # 3. Encode input prompt
1121
+ prompt_embeds = self._encode_prompt(
1122
+ prompt,
1123
+ device,
1124
+ num_images_per_prompt,
1125
+ do_classifier_free_guidance,
1126
+ negative_prompt,
1127
+ prompt_embeds=prompt_embeds,
1128
+ negative_prompt_embeds=negative_prompt_embeds,
1129
+ )
1130
+ # 4. Prepare the ControlNet conditioning image(s); the init image itself is preprocessed with the mask in step 5
1131
+ #image = prepare_image(image)
1132
+
1133
+ # A single ControlNetModel takes one conditioning image, a MultiControlNetModel takes a list of them
1134
+ if isinstance(controlnet, ControlNetModel):
1135
+ control_image = self.prepare_control_image(
1136
+ image=control_image,
1137
+ width=width,
1138
+ height=height,
1139
+ batch_size=batch_size * num_images_per_prompt,
1140
+ num_images_per_prompt=num_images_per_prompt,
1141
+ device=device,
1142
+ dtype=controlnet.dtype,
1143
+ do_classifier_free_guidance=do_classifier_free_guidance,
1144
+ guess_mode=guess_mode,
1145
+ )
1146
+ elif isinstance(controlnet, MultiControlNetModel):
1147
+ control_images = []
1148
+
1149
+ for control_image_ in control_image:
1150
+ control_image_ = self.prepare_control_image(
1151
+ image=control_image_,
1152
+ width=width,
1153
+ height=height,
1154
+ batch_size=batch_size * num_images_per_prompt,
1155
+ num_images_per_prompt=num_images_per_prompt,
1156
+ device=device,
1157
+ dtype=controlnet.dtype,
1158
+ do_classifier_free_guidance=do_classifier_free_guidance,
1159
+ guess_mode=guess_mode,
1160
+ )
1161
+
1162
+ control_images.append(control_image_)
1163
+
1164
+ control_image = control_images
1165
+ else:
1166
+ raise ValueError(f"`controlnet` must be a ControlNetModel or a MultiControlNetModel, got {type(controlnet)}")
1167
+
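+ # At this point `control_image` holds the prepared conditioning tensor(s): resized to (height, width),
+ # cast to the ControlNet dtype/device and, unless guess mode is active, duplicated for classifier-free guidance.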
1168
+ # 5. Preprocess the mask and init image - resizes both w.r.t. height and width
1169
+ mask, masked_image, init_image = prepare_mask_and_masked_image(
1170
+ image, mask_image, height, width, return_image=True
1171
+ )
1172
+ # 6. Prepare timesteps
1173
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
1174
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1175
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1176
+
1177
+ is_strength_max = strength == 1.0
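+ # The init image is noised to `latent_timestep` (the first timestep actually run);
+ # at strength == 1.0 it is replaced entirely by noise.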
1178
+
1179
+ # 7. Prepare latent variables
1180
+ num_channels_latents = self.vae.config.latent_channels
1181
+ num_channels_unet = self.unet.config.in_channels
1182
+ return_image_latents = num_channels_unet == 4
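+ # A 4-channel UNet is a regular text-to-image UNet: inpainting is then emulated by blending the original
+ # image latents back into the unmasked region after every step, so those latents must be returned.
+ # A 9-channel UNet is a dedicated inpainting UNet that takes mask and masked-image latents as extra input channels.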
1183
+ latents_outputs = self.prepare_latents(
1184
+ batch_size * num_images_per_prompt,
1185
+ num_channels_latents,
1186
+ height,
1187
+ width,
1188
+ prompt_embeds.dtype,
1189
+ device,
1190
+ generator,
1191
+ latents,
1192
+ image=init_image,
1193
+ timestep=latent_timestep,
1194
+ is_strength_max=is_strength_max,
1195
+ return_noise=True,
1196
+ return_image_latents=return_image_latents,
1197
+ )
1198
+
1199
+ if return_image_latents:
1200
+ latents, noise, image_latents = latents_outputs
1201
+ else:
1202
+ latents, noise = latents_outputs
1203
+
1204
+ # 8. Prepare mask latent variables
1205
+ mask, masked_image_latents = self.prepare_mask_latents(
1206
+ mask,
1207
+ masked_image,
1208
+ batch_size * num_images_per_prompt,
1209
+ height,
1210
+ width,
1211
+ prompt_embeds.dtype,
1212
+ device,
1213
+ generator,
1214
+ do_classifier_free_guidance,
1215
+ )
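+ # The mask is downsampled to the latent resolution and the masked image is encoded with the VAE;
+ # both are duplicated along the batch dimension when classifier-free guidance is used.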
1216
+
1217
+ # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
1218
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
1219
+
1220
+ # 10. Denoising loop
1221
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1222
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1223
+ for i, t in enumerate(timesteps):
1224
+ # expand the latents if we are doing classifier free guidance
1225
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1226
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1227
+
1228
+ # controlnet(s) inference
1229
+ if guess_mode and do_classifier_free_guidance:
1230
+ # Infer ControlNet only for the conditional batch.
1231
+ control_model_input = latents
1232
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
1233
+ controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
1234
+ else:
1235
+ control_model_input = latent_model_input
1236
+ controlnet_prompt_embeds = prompt_embeds
1237
+
1238
+ down_block_res_samples, mid_block_res_sample = self.controlnet(
1239
+ control_model_input,
1240
+ t,
1241
+ encoder_hidden_states=controlnet_prompt_embeds,
1242
+ controlnet_cond=control_image,
1243
+ conditioning_scale=controlnet_conditioning_scale,
1244
+ guess_mode=guess_mode,
1245
+ return_dict=False,
1246
+ )
1247
+
1248
+ if guess_mode and do_classifier_free_guidance:
1249
+ # Inferred ControlNet only for the conditional batch.
1250
+ # To apply the output of ControlNet to both the unconditional and conditional batches,
1251
+ # add 0 to the unconditional batch to keep it unchanged.
1252
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
1253
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
1254
+
1255
+
1256
+ # for the 9-channel inpainting UNet, concatenate latents, mask and masked-image latents in the channel dimension
1257
+ if num_channels_unet == 9:
1258
+ latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
1259
+
1260
+ # predict the noise residual
1261
+ noise_pred = self.unet(
1262
+ latent_model_input,
1263
+ t,
1264
+ encoder_hidden_states=prompt_embeds,
1265
+ cross_attention_kwargs=cross_attention_kwargs,
1266
+ down_block_additional_residuals=down_block_res_samples,
1267
+ mid_block_additional_residual=mid_block_res_sample,
1268
+ return_dict=False,
1269
+ )[0]
1270
+
1271
+ # perform guidance
1272
+ if do_classifier_free_guidance:
1273
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1274
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1275
+
1276
+ # compute the previous noisy sample x_t -> x_t-1
1277
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1278
+
1279
+
1280
+ if num_channels_unet == 4:
1281
+ init_latents_proper = image_latents[:1]
1282
+ init_mask = mask[:1]
1283
+
1284
+ if i < len(timesteps) - 1:
1285
+ init_latents_proper = self.scheduler.add_noise(init_latents_proper, noise, torch.tensor([t]))
1286
+
1287
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
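+ # For the 4-channel UNet, overwrite the unmasked region with a freshly noised copy of the original
+ # image latents so that only the masked region is actually generated.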
1288
+
1289
+ # call the callback, if provided
1290
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1291
+ progress_bar.update()
1292
+ if callback is not None and i % callback_steps == 0:
1293
+ callback(i, t, latents)
1294
+
1295
+ # If we do sequential model offloading, let's offload unet and controlnet
1296
+ # manually for max memory savings
1297
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1298
+ self.unet.to("cpu")
1299
+ self.controlnet.to("cpu")
1300
+ torch.cuda.empty_cache()
1301
+
1302
+ if output_type != "latent":
1303
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1304
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
1305
+ else:
1306
+ image = latents
1307
+ has_nsfw_concept = None
1308
+
1309
+ if has_nsfw_concept is None:
1310
+ do_denormalize = [True] * image.shape[0]
1311
+ else:
1312
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
1313
+
1314
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1315
+
1316
+ # Offload last model to CPU
1317
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
1318
+ self.final_offload_hook.offload()
1319
+
1320
+ if not return_dict:
1321
+ return (image, has_nsfw_concept)
1322
+
1323
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)