kadirnar committed
Commit 8bf6bdc
1 Parent(s): bc13541

add controlnet inpaint model

app.py CHANGED
@@ -11,6 +11,7 @@ from diffusion_webui.helpers import (
     stable_diffusion_controlnet_seg_app,
     stable_diffusion_img2img_app,
     stable_diffusion_inpaint_app,
+    stable_diffusion_inpiant_controlnet_canny_app,
     stable_diffusion_text2img_app,
 )
 
@@ -56,7 +57,11 @@ with app:
         with gr.Tab("Scribble"):
             stable_diffusion_controlnet_scribble_app()
 
+        with gr.Tab("ControlNet Inpaint"):
+            with gr.Tab("Inpaint Canny"):
+                stable_diffusion_inpiant_controlnet_canny_app()
+
         with gr.Tab("Keras Diffusion"):
             keras_diffusion_app = keras_stable_diffusion_app()
 
-app.launch(debug=True)
+app.launch(debug=True, enable_queue=True)
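Note on the new launch flag: `enable_queue=True` routes requests through Gradio's queue, which keeps long GPU generations from hitting request timeouts. On later Gradio releases this flag is deprecated in favor of the `queue()` method; a minimal equivalent sketch, assuming such a release:

```python
# Hypothetical equivalent on a newer Gradio version where enable_queue is deprecated:
app.queue()             # enable the request queue on the Blocks object
app.launch(debug=True)  # launch without the deprecated flag
```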
diffusion_webui/controlnet_inpaint/__init__.py ADDED
File without changes
diffusion_webui/controlnet_inpaint/controlnet_inpaint_app.py ADDED
@@ -0,0 +1,163 @@
+import gradio as gr
+import numpy as np
+import torch
+from diffusers import UniPCMultistepScheduler
+from PIL import Image
+
+from diffusion_webui.controlnet.controlnet_canny import controlnet_canny
+from diffusion_webui.controlnet_inpaint.pipeline_stable_diffusion_controlnet_inpaint import (
+    StableDiffusionControlNetInpaintPipeline,
+)
+
+stable_inpaint_model_list = [
+    "stabilityai/stable-diffusion-2-inpainting",
+    "runwayml/stable-diffusion-inpainting",
+]
+
+controlnet_model_list = [
+    "lllyasviel/sd-controlnet-canny",
+]
+
+prompt_list = [
+    "a red panda sitting on a bench",
+]
+
+negative_prompt_list = [
+    "bad, ugly",
+]
+
+
+def load_img(image_path: str):
+    image = Image.open(image_path)
+    image = np.array(image)
+    image = Image.fromarray(image)
+
+    return image
+
+
+def stable_diffusion_inpiant_controlnet_canny(
+    normal_image_path: dict,
+    stable_model_path: str,
+    controlnet_model_path: str,
+    prompt: str,
+    negative_prompt: str,
+    controlnet_conditioning_scale: float,
+    guidance_scale: float,
+    num_inference_steps: int,
+):
+    # The sketch tool returns a dict of file paths: {"image": ..., "mask": ...}
+    normal_image = Image.open(normal_image_path["image"]).convert("RGB").resize((512, 512))
+    mask_image = Image.open(normal_image_path["mask"]).convert("RGB").resize((512, 512))
+
+    controlnet, control_image = controlnet_canny(
+        image_path=normal_image_path["image"],
+        controlnet_model_path=controlnet_model_path,
+    )
+
+    pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+        pretrained_model_name_or_path=stable_model_path,
+        controlnet=controlnet,
+        torch_dtype=torch.float16,
+    )
+    pipe.to("cuda")
+    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+    pipe.enable_xformers_memory_efficient_attention()
+
+    generator = torch.manual_seed(0)
+
+    output = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        num_inference_steps=num_inference_steps,
+        generator=generator,
+        image=normal_image,
+        control_image=control_image,
+        controlnet_conditioning_scale=controlnet_conditioning_scale,
+        guidance_scale=guidance_scale,
+        mask_image=mask_image,
+    ).images
+
+    return output[0]
+
+
+def stable_diffusion_inpiant_controlnet_canny_app():
+    with gr.Blocks():
+        with gr.Row():
+            with gr.Column():
+                inpaint_image_file = gr.Image(
+                    source="upload",
+                    tool="sketch",
+                    elem_id="image_upload",
+                    type="filepath",
+                    label="Upload",
+                )
+
+                inpaint_model_id = gr.Dropdown(
+                    choices=stable_inpaint_model_list,
+                    value=stable_inpaint_model_list[0],
+                    label="Inpaint Model Id",
+                )
+
+                inpaint_controlnet_model_id = gr.Dropdown(
+                    choices=controlnet_model_list,
+                    value=controlnet_model_list[0],
+                    label="ControlNet Model Id",
+                )
+
+                inpaint_prompt = gr.Textbox(
+                    lines=1, value=prompt_list[0], label="Prompt"
+                )
+
+                inpaint_negative_prompt = gr.Textbox(
+                    lines=1,
+                    value=negative_prompt_list[0],
+                    label="Negative Prompt",
+                )
+
+                with gr.Accordion("Advanced Options", open=False):
+                    controlnet_conditioning_scale = gr.Slider(
+                        minimum=0.1,
+                        maximum=1,
+                        step=0.1,
+                        value=0.5,
+                        label="ControlNet Conditioning Scale",
+                    )
+
+                    inpaint_guidance_scale = gr.Slider(
+                        minimum=0.1,
+                        maximum=15,
+                        step=0.1,
+                        value=7.5,
+                        label="Guidance Scale",
+                    )
+
+                    inpaint_num_inference_step = gr.Slider(
+                        minimum=1,
+                        maximum=100,
+                        step=1,
+                        value=50,
+                        label="Num Inference Step",
+                    )
+
+                inpaint_predict = gr.Button(value="Generate")
+
+            with gr.Column():
+                output_image = gr.Image(label="Outputs")
+
+        inpaint_predict.click(
+            fn=stable_diffusion_inpiant_controlnet_canny,
+            inputs=[
+                inpaint_image_file,
+                inpaint_model_id,
+                inpaint_controlnet_model_id,
+                inpaint_prompt,
+                inpaint_negative_prompt,
+                controlnet_conditioning_scale,
+                inpaint_guidance_scale,
+                inpaint_num_inference_step,
+            ],
+            outputs=output_image,
+        )
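The new tab can also be exercised on its own while developing, without the full multi-tab app. A minimal sketch, assuming a CUDA GPU and the checkpoints listed above are reachable:

```python
import gradio as gr

from diffusion_webui.controlnet_inpaint.controlnet_inpaint_app import (
    stable_diffusion_inpiant_controlnet_canny_app,
)

# Throwaway interface containing only the new ControlNet inpaint tab.
with gr.Blocks() as demo:
    with gr.Tab("Inpaint Canny"):
        stable_diffusion_inpiant_controlnet_canny_app()

if __name__ == "__main__":
    demo.queue()            # generations take a while on GPU, so queue requests
    demo.launch(debug=True)
```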
diffusion_webui/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py ADDED
@@ -0,0 +1,607 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import PIL.Image
+import torch
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> # !pip install opencv-python transformers accelerate
+        >>> from diffusers import StableDiffusionControlNetInpaintPipeline, ControlNetModel, UniPCMultistepScheduler
+        >>> from diffusers.utils import load_image
+        >>> import numpy as np
+        >>> import torch
+
+        >>> import cv2
+        >>> from PIL import Image
+        >>> # download an image
+        >>> image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+        ... )
+        >>> image = np.array(image)
+        >>> mask_image = load_image(
+        ...     "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+        ... )
+        >>> mask_image = np.array(mask_image)
+        >>> # get canny image
+        >>> canny_image = cv2.Canny(image, 100, 200)
+        >>> canny_image = canny_image[:, :, None]
+        >>> canny_image = np.concatenate([canny_image, canny_image, canny_image], axis=2)
+        >>> canny_image = Image.fromarray(canny_image)
+
+        >>> # load control net and stable diffusion v1-5
+        >>> controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
+        >>> pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+        ...     "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
+        ... )
+
+        >>> # speed up diffusion process with faster scheduler and memory optimization
+        >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
+        >>> # remove following line if xformers is not installed
+        >>> pipe.enable_xformers_memory_efficient_attention()
+
+        >>> pipe.enable_model_cpu_offload()
+
+        >>> # generate image
+        >>> generator = torch.manual_seed(0)
+        >>> image = pipe(
+        ...     "futuristic-looking doggo",
+        ...     num_inference_steps=20,
+        ...     generator=generator,
+        ...     image=image,
+        ...     control_image=canny_image,
+        ...     mask_image=mask_image
+        ... ).images[0]
+        ```
+"""
+
+
+def prepare_mask_and_masked_image(image, mask):
+    """
+    Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
+    converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
+    ``image`` and ``1`` for the ``mask``.
+    The ``image`` will be converted to ``torch.float32`` and normalized to be in ``[-1, 1]``. The ``mask`` will be
+    binarized (``mask > 0.5``) and cast to ``torch.float32`` too.
+    Args:
+        image (Union[np.array, PIL.Image, torch.Tensor]): The image to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width x 3`` ``np.array`` or a ``channels x height x width``
+            ``torch.Tensor`` or a ``batch x channels x height x width`` ``torch.Tensor``.
+        mask (_type_): The mask to apply to the image, i.e. regions to inpaint.
+            It can be a ``PIL.Image``, or a ``height x width`` ``np.array`` or a ``1 x height x width``
+            ``torch.Tensor`` or a ``batch x 1 x height x width`` ``torch.Tensor``.
+    Raises:
+        ValueError: ``torch.Tensor`` images should be in the ``[-1, 1]`` range. ValueError: ``torch.Tensor`` mask
+            should be in the ``[0, 1]`` range. ValueError: ``mask`` and ``image`` should have the same spatial dimensions.
+        TypeError: ``mask`` is a ``torch.Tensor`` but ``image`` is not
+            (or the other way around).
+    Returns:
+        tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
+            dimensions: ``batch x channels x height x width``.
+    """
+    if isinstance(image, torch.Tensor):
+        if not isinstance(mask, torch.Tensor):
+            raise TypeError(
+                f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
+            )
+
+        # Batch single image
+        if image.ndim == 3:
+            assert (
+                image.shape[0] == 3
+            ), "Image outside a batch should be of shape (3, H, W)"
+            image = image.unsqueeze(0)
+
+        # Batch and add channel dim for single mask
+        if mask.ndim == 2:
+            mask = mask.unsqueeze(0).unsqueeze(0)
+
+        # Batch single mask or add channel dim
+        if mask.ndim == 3:
+            # Single batched mask, no channel dim or single mask not batched but channel dim
+            if mask.shape[0] == 1:
+                mask = mask.unsqueeze(0)
+
+            # Batched masks no channel dim
+            else:
+                mask = mask.unsqueeze(1)
+
+        assert (
+            image.ndim == 4 and mask.ndim == 4
+        ), "Image and Mask must have 4 dimensions"
+        assert (
+            image.shape[-2:] == mask.shape[-2:]
+        ), "Image and Mask must have the same spatial dimensions"
+        assert (
+            image.shape[0] == mask.shape[0]
+        ), "Image and Mask must have the same batch size"
+
+        # Check image is in [-1, 1]
+        if image.min() < -1 or image.max() > 1:
+            raise ValueError("Image should be in [-1, 1] range")
+
+        # Check mask is in [0, 1]
+        if mask.min() < 0 or mask.max() > 1:
+            raise ValueError("Mask should be in [0, 1] range")
+
+        # Binarize mask
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+
+        # Image as float32
+        image = image.to(dtype=torch.float32)
+    elif isinstance(mask, torch.Tensor):
+        raise TypeError(
+            f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
+        )
+    else:
+        # preprocess image
+        if isinstance(image, (PIL.Image.Image, np.ndarray)):
+            image = [image]
+
+        if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
+            image = [np.array(i.convert("RGB"))[None, :] for i in image]
+            image = np.concatenate(image, axis=0)
+        elif isinstance(image, list) and isinstance(image[0], np.ndarray):
+            image = np.concatenate([i[None, :] for i in image], axis=0)
+
+        image = image.transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image).to(dtype=torch.float32) / 127.5 - 1.0
+
+        # preprocess mask
+        if isinstance(mask, (PIL.Image.Image, np.ndarray)):
+            mask = [mask]
+
+        if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
+            mask = np.concatenate(
+                [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
+            )
+            mask = mask.astype(np.float32) / 255.0
+        elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
+            mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
+
+        mask[mask < 0.5] = 0
+        mask[mask >= 0.5] = 1
+        mask = torch.from_numpy(mask)
+
+    masked_image = image * (mask < 0.5)
+
+    return mask, masked_image
+
+
+class StableDiffusionControlNetInpaintPipeline(
+    StableDiffusionControlNetPipeline
+):
+    r"""
+    Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.
+
+    This model inherits from [`StableDiffusionControlNetPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
+        controlnet ([`ControlNetModel`]):
+            Provides additional conditioning to the unet during the denoising process
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+        feature_extractor ([`CLIPFeatureExtractor`]):
+            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+    """
+
+    def prepare_mask_latents(
+        self,
+        mask,
+        masked_image,
+        batch_size,
+        height,
+        width,
+        dtype,
+        device,
+        generator,
+        do_classifier_free_guidance,
+    ):
+        # resize the mask to latents shape as we concatenate the mask to the latents
+        # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
+        # and half precision
+        mask = torch.nn.functional.interpolate(
+            mask,
+            size=(
+                height // self.vae_scale_factor,
+                width // self.vae_scale_factor,
+            ),
+        )
+        mask = mask.to(device=device, dtype=dtype)
+
+        masked_image = masked_image.to(device=device, dtype=dtype)
+
+        # encode the mask image into latents space so we can concatenate it to the latents
+        if isinstance(generator, list):
+            masked_image_latents = [
+                self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
+                    generator=generator[i]
+                )
+                for i in range(batch_size)
+            ]
+            masked_image_latents = torch.cat(masked_image_latents, dim=0)
+        else:
+            masked_image_latents = self.vae.encode(
+                masked_image
+            ).latent_dist.sample(generator=generator)
+        masked_image_latents = (
+            self.vae.config.scaling_factor * masked_image_latents
+        )
+
+        # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
+        if mask.shape[0] < batch_size:
+            if not batch_size % mask.shape[0] == 0:
+                raise ValueError(
+                    "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
+                    f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
+                    " of masks that you pass is divisible by the total requested batch size."
+                )
+            mask = mask.repeat(batch_size // mask.shape[0], 1, 1, 1)
+        if masked_image_latents.shape[0] < batch_size:
+            if not batch_size % masked_image_latents.shape[0] == 0:
+                raise ValueError(
+                    "The passed images and the required batch size don't match. Images are supposed to be duplicated"
+                    f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
+                    " Make sure the number of images that you pass is divisible by the total requested batch size."
+                )
+            masked_image_latents = masked_image_latents.repeat(
+                batch_size // masked_image_latents.shape[0], 1, 1, 1
+            )
+
+        mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
+        masked_image_latents = (
+            torch.cat([masked_image_latents] * 2)
+            if do_classifier_free_guidance
+            else masked_image_latents
+        )
+
+        # aligning device to prevent device errors when concating it with the latent model input
+        masked_image_latents = masked_image_latents.to(
+            device=device, dtype=dtype
+        )
+        return mask, masked_image_latents
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        prompt: Union[str, List[str]] = None,
+        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        control_image: Union[
+            torch.FloatTensor,
+            PIL.Image.Image,
+            List[torch.FloatTensor],
+            List[PIL.Image.Image],
+        ] = None,
+        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        negative_prompt: Optional[Union[str, List[str]]] = None,
+        num_images_per_prompt: Optional[int] = 1,
+        eta: float = 0.0,
+        generator: Optional[
+            Union[torch.Generator, List[torch.Generator]]
+        ] = None,
+        latents: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        output_type: Optional[str] = "pil",
+        return_dict: bool = True,
+        callback: Optional[
+            Callable[[int, int, torch.FloatTensor], None]
+        ] = None,
+        callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_conditioning_scale: float = 1.0,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+                instead.
+            image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
+                be masked out with `mask_image` and repainted according to `prompt`.
+            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]`):
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
+                the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
+                also be accepted as an image. The control image is automatically resized to fit the output image.
+            mask_image (`PIL.Image.Image`):
+                `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
+                repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
+                to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
+                instead of 3, so the expected shape would be `(B, H, W, 1)`.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
+            num_inference_steps (`int`, *optional*, defaults to 50):
+                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+                expense of slower inference.
+            guidance_scale (`float`, *optional*, defaults to 7.5):
+                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+                `guidance_scale` is defined as `w` of equation 2. of [Imagen
+                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+                1`. Higher guidance scale encourages generating images that are closely linked to the text `prompt`,
+                usually at the expense of lower image quality.
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+                Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+            num_images_per_prompt (`int`, *optional*, defaults to 1):
+                The number of images to generate per prompt.
+            eta (`float`, *optional*, defaults to 0.0):
+                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+                [`schedulers.DDIMScheduler`], will be ignored for others.
+            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+                to make generation deterministic.
+            latents (`torch.FloatTensor`, *optional*):
+                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor will be generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+                argument.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between
+                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            callback (`Callable`, *optional*):
+                A function that will be called every `callback_steps` steps during inference. The function will be
+                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+            callback_steps (`int`, *optional*, defaults to 1):
+                The frequency at which the `callback` function will be called. If not specified, the callback will be
+                called at every step.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+            controlnet_conditioning_scale (`float`, *optional*, defaults to 1.0):
+                The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
+                to the residual in the original unet.
+        Examples:
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+            When returning a tuple, the first element is a list with the generated images, and the second element is a
+            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
+            (nsfw) content, according to the `safety_checker`.
+        """
+        # 0. Default height and width to unet
+        height, width = self._default_height_width(height, width, control_image)
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            control_image,
+            height,
+            width,
+            callback_steps,
+            negative_prompt,
+            prompt_embeds,
+            negative_prompt_embeds,
+        )
+
+        # 2. Define call parameters
+        if prompt is not None and isinstance(prompt, str):
+            batch_size = 1
+        elif prompt is not None and isinstance(prompt, list):
+            batch_size = len(prompt)
+        else:
+            batch_size = prompt_embeds.shape[0]
+
+        device = self._execution_device
+        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+        # corresponds to doing no classifier free guidance.
+        do_classifier_free_guidance = guidance_scale > 1.0
+
+        # 3. Encode input prompt
+        prompt_embeds = self._encode_prompt(
+            prompt,
+            device,
+            num_images_per_prompt,
+            do_classifier_free_guidance,
+            negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+        )
+
+        # 4. Prepare image
+        control_image = self.prepare_image(
+            control_image,
+            width,
+            height,
+            batch_size * num_images_per_prompt,
+            num_images_per_prompt,
+            device,
+            self.controlnet.dtype,
+        )
+
+        if do_classifier_free_guidance:
+            control_image = torch.cat([control_image] * 2)
+
+        # 5. Prepare timesteps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
+        timesteps = self.scheduler.timesteps
+
+        # 6. Prepare latent variables
+        num_channels_latents = self.controlnet.in_channels
+        latents = self.prepare_latents(
+            batch_size * num_images_per_prompt,
+            num_channels_latents,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            latents,
+        )
+
+        # EXTRA: prepare mask latents
+        mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
+        mask, masked_image_latents = self.prepare_mask_latents(
+            mask,
+            masked_image,
+            batch_size * num_images_per_prompt,
+            height,
+            width,
+            prompt_embeds.dtype,
+            device,
+            generator,
+            do_classifier_free_guidance,
+        )
+
+        # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+        # 8. Denoising loop
+        num_warmup_steps = (
+            len(timesteps) - num_inference_steps * self.scheduler.order
+        )
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = (
+                    torch.cat([latents] * 2)
+                    if do_classifier_free_guidance
+                    else latents
+                )
+                latent_model_input = self.scheduler.scale_model_input(
+                    latent_model_input, t
+                )
+
+                down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    controlnet_cond=control_image,
+                    return_dict=False,
+                )
+
+                down_block_res_samples = [
+                    down_block_res_sample * controlnet_conditioning_scale
+                    for down_block_res_sample in down_block_res_samples
+                ]
+                mid_block_res_sample *= controlnet_conditioning_scale
+
+                # predict the noise residual
+                latent_model_input = torch.cat(
+                    [latent_model_input, mask, masked_image_latents], dim=1
+                )
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    down_block_additional_residuals=down_block_res_samples,
+                    mid_block_additional_residual=mid_block_res_sample,
+                ).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (
+                        noise_pred_text - noise_pred_uncond
+                    )
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(
+                    noise_pred, t, latents, **extra_step_kwargs
+                ).prev_sample
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or (
+                    (i + 1) > num_warmup_steps
+                    and (i + 1) % self.scheduler.order == 0
+                ):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        callback(i, t, latents)
+
+        # If we do sequential model offloading, let's offload unet and controlnet
+        # manually for max memory savings
+        if (
+            hasattr(self, "final_offload_hook")
+            and self.final_offload_hook is not None
+        ):
+            self.unet.to("cpu")
+            self.controlnet.to("cpu")
+            torch.cuda.empty_cache()
+
+        if output_type == "latent":
+            image = latents
+            has_nsfw_concept = None
+        elif output_type == "pil":
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(
+                image, device, prompt_embeds.dtype
+            )
+
+            # 10. Convert to PIL
+            image = self.numpy_to_pil(image)
+        else:
+            # 8. Post-processing
+            image = self.decode_latents(latents)
+
+            # 9. Run safety checker
+            image, has_nsfw_concept = self.run_safety_checker(
+                image, device, prompt_embeds.dtype
+            )
+
+        # Offload last model to CPU
+        if (
+            hasattr(self, "final_offload_hook")
+            and self.final_offload_hook is not None
+        ):
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(
+            images=image, nsfw_content_detected=has_nsfw_concept
+        )
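The essential difference from the base `StableDiffusionControlNetPipeline` is the UNet input assembled inside the denoising loop: the 4 noise-latent channels are concatenated with 1 binarized mask channel and the 4 VAE-encoded masked-image channels, which is why a 9-channel inpainting checkpoint such as `runwayml/stable-diffusion-inpainting` is expected as the base model. A minimal shape-bookkeeping sketch (tensor values here are random placeholders, not from a real run):

```python
import torch

# Illustrative shapes for a 512x512 generation (VAE scale factor 8 -> 64x64 latents),
# with the batch doubled for classifier-free guidance.
batch, latent_channels, h, w = 2, 4, 64, 64
latents = torch.randn(batch, latent_channels, h, w)               # noise latents
mask = (torch.rand(batch, 1, h, w) > 0.5).float()                 # binarized, downsampled mask
masked_image_latents = torch.randn(batch, latent_channels, h, w)  # VAE-encoded masked image

unet_input = torch.cat([latents, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([2, 9, 64, 64]) -- matches the inpainting UNet's in_channels=9
```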
diffusion_webui/helpers.py CHANGED
@@ -1,33 +1,48 @@
 from diffusion_webui.controlnet.controlnet_canny import (
+    stable_diffusion_controlnet_canny,
     stable_diffusion_controlnet_canny_app,
 )
 from diffusion_webui.controlnet.controlnet_depth import (
+    stable_diffusion_controlnet_depth,
     stable_diffusion_controlnet_depth_app,
 )
 from diffusion_webui.controlnet.controlnet_hed import (
+    stable_diffusion_controlnet_hed,
     stable_diffusion_controlnet_hed_app,
 )
 from diffusion_webui.controlnet.controlnet_mlsd import (
+    stable_diffusion_controlnet_mlsd,
     stable_diffusion_controlnet_mlsd_app,
 )
 from diffusion_webui.controlnet.controlnet_pose import (
+    stable_diffusion_controlnet_pose,
     stable_diffusion_controlnet_pose_app,
 )
 from diffusion_webui.controlnet.controlnet_scribble import (
+    stable_diffusion_controlnet_scribble,
     stable_diffusion_controlnet_scribble_app,
 )
 from diffusion_webui.controlnet.controlnet_seg import (
+    stable_diffusion_controlnet_seg,
     stable_diffusion_controlnet_seg_app,
 )
+from diffusion_webui.controlnet_inpaint.controlnet_inpaint_app import (
+    stable_diffusion_inpiant_controlnet_canny,
+    stable_diffusion_inpiant_controlnet_canny_app,
+)
 from diffusion_webui.stable_diffusion.img2img_app import (
+    stable_diffusion_img2img,
     stable_diffusion_img2img_app,
 )
 from diffusion_webui.stable_diffusion.inpaint_app import (
+    stable_diffusion_inpaint,
     stable_diffusion_inpaint_app,
 )
 from diffusion_webui.stable_diffusion.keras_txt2img import (
+    keras_stable_diffusion,
     keras_stable_diffusion_app,
 )
 from diffusion_webui.stable_diffusion.text2img_app import (
+    stable_diffusion_text2img,
     stable_diffusion_text2img_app,
 )