kadirnar committed
Commit 9f1d723
1 Parent(s): 7517ead

Update diffusion_webui/diffusion_models/controlnet/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py

diffusion_webui/diffusion_models/controlnet/controlnet_inpaint/pipeline_stable_diffusion_controlnet_inpaint.py CHANGED
@@ -12,13 +12,11 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

-
- import numpy as np
- import PIL.Image
  import torch
- from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *
+ import PIL.Image
+ import numpy as np

- # https://github.com/mikonvergence/ControlNetInpaint
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import *

  EXAMPLE_DOC_STRING = """
      Examples:
@@ -98,15 +96,11 @@ def prepare_mask_and_masked_image(image, mask)
      """
      if isinstance(image, torch.Tensor):
          if not isinstance(mask, torch.Tensor):
-             raise TypeError(
-                 f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not"
-             )
+             raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")

          # Batch single image
          if image.ndim == 3:
-             assert (
-                 image.shape[0] == 3
-             ), "Image outside a batch should be of shape (3, H, W)"
+             assert image.shape[0] == 3, "Image outside a batch should be of shape (3, H, W)"
              image = image.unsqueeze(0)

          # Batch and add channel dim for single mask
@@ -123,15 +117,9 @@ def prepare_mask_and_masked_image(image, mask)
          else:
              mask = mask.unsqueeze(1)

-         assert (
-             image.ndim == 4 and mask.ndim == 4
-         ), "Image and Mask must have 4 dimensions"
-         assert (
-             image.shape[-2:] == mask.shape[-2:]
-         ), "Image and Mask must have the same spatial dimensions"
-         assert (
-             image.shape[0] == mask.shape[0]
-         ), "Image and Mask must have the same batch size"
+         assert image.ndim == 4 and mask.ndim == 4, "Image and Mask must have 4 dimensions"
+         assert image.shape[-2:] == mask.shape[-2:], "Image and Mask must have the same spatial dimensions"
+         assert image.shape[0] == mask.shape[0], "Image and Mask must have the same batch size"

          # Check image is in [-1, 1]
          if image.min() < -1 or image.max() > 1:
@@ -148,9 +136,7 @@ def prepare_mask_and_masked_image(image, mask)
          # Image as float32
          image = image.to(dtype=torch.float32)
      elif isinstance(mask, torch.Tensor):
-         raise TypeError(
-             f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not"
-         )
+         raise TypeError(f"`mask` is a torch.Tensor but `image` (type: {type(image)} is not")
      else:
          # preprocess image
          if isinstance(image, (PIL.Image.Image, np.ndarray)):
@@ -170,9 +156,7 @@ def prepare_mask_and_masked_image(image, mask)
              mask = [mask]

          if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
-             mask = np.concatenate(
-                 [np.array(m.convert("L"))[None, None, :] for m in mask], axis=0
-             )
+             mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
              mask = mask.astype(np.float32) / 255.0
          elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
              mask = np.concatenate([m[None, None, :] for m in mask], axis=0)
@@ -185,10 +169,7 @@ def prepare_mask_and_masked_image(image, mask)

      return mask, masked_image

-
- class StableDiffusionControlNetInpaintPipeline(
-     StableDiffusionControlNetPipeline
- ):
+ class StableDiffusionControlNetInpaintPipeline(StableDiffusionControlNetPipeline):
      r"""
      Pipeline for text-guided image inpainting using Stable Diffusion with ControlNet guidance.

@@ -217,28 +198,15 @@ class StableDiffusionControlNetInpaintPipeline(
          feature_extractor ([`CLIPFeatureExtractor`]):
              Model that extracts features from generated images to be used as inputs for the `safety_checker`.
      """
-
+
      def prepare_mask_latents(
-         self,
-         mask,
-         masked_image,
-         batch_size,
-         height,
-         width,
-         dtype,
-         device,
-         generator,
-         do_classifier_free_guidance,
+         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
      ):
          # resize the mask to latents shape as we concatenate the mask to the latents
          # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
          # and half precision
          mask = torch.nn.functional.interpolate(
-             mask,
-             size=(
-                 height // self.vae_scale_factor,
-                 width // self.vae_scale_factor,
-             ),
+             mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
          )
          mask = mask.to(device=device, dtype=dtype)

@@ -247,19 +215,13 @@ class StableDiffusionControlNetInpaintPipeline(
          # encode the mask image into latents space so we can concatenate it to the latents
          if isinstance(generator, list):
              masked_image_latents = [
-                 self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(
-                     generator=generator[i]
-                 )
+                 self.vae.encode(masked_image[i : i + 1]).latent_dist.sample(generator=generator[i])
                  for i in range(batch_size)
              ]
              masked_image_latents = torch.cat(masked_image_latents, dim=0)
          else:
-             masked_image_latents = self.vae.encode(
-                 masked_image
-             ).latent_dist.sample(generator=generator)
-         masked_image_latents = (
-             self.vae.config.scaling_factor * masked_image_latents
-         )
+             masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
+         masked_image_latents = self.vae.config.scaling_factor * masked_image_latents

          # duplicate mask and masked_image_latents for each generation per prompt, using mps friendly method
          if mask.shape[0] < batch_size:
@@ -277,35 +239,24 @@ class StableDiffusionControlNetInpaintPipeline(
                      f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                      " Make sure the number of images that you pass is divisible by the total requested batch size."
                  )
-             masked_image_latents = masked_image_latents.repeat(
-                 batch_size // masked_image_latents.shape[0], 1, 1, 1
-             )
+             masked_image_latents = masked_image_latents.repeat(batch_size // masked_image_latents.shape[0], 1, 1, 1)

          mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
          masked_image_latents = (
-             torch.cat([masked_image_latents] * 2)
-             if do_classifier_free_guidance
-             else masked_image_latents
+             torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
          )

          # aligning device to prevent device errors when concating it with the latent model input
-         masked_image_latents = masked_image_latents.to(
-             device=device, dtype=dtype
-         )
+         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
          return mask, masked_image_latents
-
+
      @torch.no_grad()
      @replace_example_docstring(EXAMPLE_DOC_STRING)
      def __call__(
          self,
-         prompt: Union[str, List[str]] = None,
+         prompt: Union[str, List[str]] = None,
          image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-         control_image: Union[
-             torch.FloatTensor,
-             PIL.Image.Image,
-             List[torch.FloatTensor],
-             List[PIL.Image.Image],
-         ] = None,
+         control_image: Union[torch.FloatTensor, PIL.Image.Image, List[torch.FloatTensor], List[PIL.Image.Image]] = None,
          mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
          height: Optional[int] = None,
          width: Optional[int] = None,
@@ -314,17 +265,13 @@ class StableDiffusionControlNetInpaintPipeline(
          negative_prompt: Optional[Union[str, List[str]]] = None,
          num_images_per_prompt: Optional[int] = 1,
          eta: float = 0.0,
-         generator: Optional[
-             Union[torch.Generator, List[torch.Generator]]
-         ] = None,
+         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
          latents: Optional[torch.FloatTensor] = None,
          prompt_embeds: Optional[torch.FloatTensor] = None,
          negative_prompt_embeds: Optional[torch.FloatTensor] = None,
          output_type: Optional[str] = "pil",
          return_dict: bool = True,
-         callback: Optional[
-             Callable[[int, int, torch.FloatTensor], None]
-         ] = None,
+         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
          callback_steps: int = 1,
          cross_attention_kwargs: Optional[Dict[str, Any]] = None,
          controlnet_conditioning_scale: float = 1.0,
@@ -346,7 +293,7 @@ class StableDiffusionControlNetInpaintPipeline(
                  `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
                  repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                  to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
-                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
+                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
              height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                  The height in pixels of the generated image.
              width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -415,14 +362,7 @@ class StableDiffusionControlNetInpaintPipeline(

          # 1. Check inputs. Raise error if not correct
          self.check_inputs(
-             prompt,
-             control_image,
-             height,
-             width,
-             callback_steps,
-             negative_prompt,
-             prompt_embeds,
-             negative_prompt_embeds,
+             prompt, control_image, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
          )

          # 2. Define call parameters
@@ -452,15 +392,15 @@ class StableDiffusionControlNetInpaintPipeline(

          # 4. Prepare image
          control_image = self.prepare_image(
-             control_image,
-             width,
-             height,
-             batch_size * num_images_per_prompt,
-             num_images_per_prompt,
-             device,
-             do_classifier_free_guidance,
-             self.controlnet.dtype,
-         )
+             control_image,
+             width,
+             height,
+             batch_size * num_images_per_prompt,
+             num_images_per_prompt,
+             device,
+             self.controlnet.dtype,
+         )
+
          if do_classifier_free_guidance:
              control_image = torch.cat([control_image] * 2)

@@ -469,7 +409,7 @@ class StableDiffusionControlNetInpaintPipeline(
          timesteps = self.scheduler.timesteps

          # 6. Prepare latent variables
-         num_channels_latents = self.controlnet.in_channels
+         num_channels_latents = self.controlnet.config.in_channels
          latents = self.prepare_latents(
              batch_size * num_images_per_prompt,
              num_channels_latents,
@@ -480,7 +420,7 @@ class StableDiffusionControlNetInpaintPipeline(
              generator,
              latents,
          )
-
+
          # EXTRA: prepare mask latents
          mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
          mask, masked_image_latents = self.prepare_mask_latents(
@@ -499,20 +439,12 @@ class StableDiffusionControlNetInpaintPipeline(
          extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

          # 8. Denoising loop
-         num_warmup_steps = (
-             len(timesteps) - num_inference_steps * self.scheduler.order
-         )
+         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
          with self.progress_bar(total=num_inference_steps) as progress_bar:
              for i, t in enumerate(timesteps):
                  # expand the latents if we are doing classifier free guidance
-                 latent_model_input = (
-                     torch.cat([latents] * 2)
-                     if do_classifier_free_guidance
-                     else latents
-                 )
-                 latent_model_input = self.scheduler.scale_model_input(
-                     latent_model_input, t
-                 )
+                 latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                  down_block_res_samples, mid_block_res_sample = self.controlnet(
                      latent_model_input,
@@ -529,9 +461,7 @@ class StableDiffusionControlNetInpaintPipeline(
                  mid_block_res_sample *= controlnet_conditioning_scale

                  # predict the noise residual
-                 latent_model_input = torch.cat(
-                     [latent_model_input, mask, masked_image_latents], dim=1
-                 )
+                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
                  noise_pred = self.unet(
                      latent_model_input,
                      t,
@@ -544,30 +474,20 @@ class StableDiffusionControlNetInpaintPipeline(
                  # perform guidance
                  if do_classifier_free_guidance:
                      noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                     noise_pred = noise_pred_uncond + guidance_scale * (
-                         noise_pred_text - noise_pred_uncond
-                     )
+                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                  # compute the previous noisy sample x_t -> x_t-1
-                 latents = self.scheduler.step(
-                     noise_pred, t, latents, **extra_step_kwargs
-                 ).prev_sample
+                 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                  # call the callback, if provided
-                 if i == len(timesteps) - 1 or (
-                     (i + 1) > num_warmup_steps
-                     and (i + 1) % self.scheduler.order == 0
-                 ):
+                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                      progress_bar.update()
                      if callback is not None and i % callback_steps == 0:
                          callback(i, t, latents)

          # If we do sequential model offloading, let's offload unet and controlnet
          # manually for max memory savings
-         if (
-             hasattr(self, "final_offload_hook")
-             and self.final_offload_hook is not None
-         ):
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
              self.unet.to("cpu")
              self.controlnet.to("cpu")
              torch.cuda.empty_cache()
@@ -580,9 +500,7 @@ class StableDiffusionControlNetInpaintPipeline(
              image = self.decode_latents(latents)

              # 9. Run safety checker
-             image, has_nsfw_concept = self.run_safety_checker(
-                 image, device, prompt_embeds.dtype
-             )
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

              # 10. Convert to PIL
              image = self.numpy_to_pil(image)
@@ -591,20 +509,13 @@ class StableDiffusionControlNetInpaintPipeline(
              image = self.decode_latents(latents)

              # 9. Run safety checker
-             image, has_nsfw_concept = self.run_safety_checker(
-                 image, device, prompt_embeds.dtype
-             )
+             image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

          # Offload last model to CPU
-         if (
-             hasattr(self, "final_offload_hook")
-             and self.final_offload_hook is not None
-         ):
+         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
              self.final_offload_hook.offload()

          if not return_dict:
              return (image, has_nsfw_concept)

-         return StableDiffusionPipelineOutput(
-             images=image, nsfw_content_detected=has_nsfw_concept
-         )
+         return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
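
The public interface is unchanged by this reformatting: the class still subclasses StableDiffusionControlNetPipeline, and __call__ takes a prompt, the image to inpaint, a control_image for ControlNet conditioning, and a mask_image, returning a StableDiffusionPipelineOutput. Below is a minimal usage sketch; the model checkpoints, file names, and generation settings are illustrative assumptions and are not part of this commit.

import torch
from diffusers import ControlNetModel
from diffusers.utils import load_image

from diffusion_webui.diffusion_models.controlnet.controlnet_inpaint.pipeline_stable_diffusion_controlnet_inpaint import (
    StableDiffusionControlNetInpaintPipeline,
)

# Illustrative checkpoints: any SD 1.5 base model plus a matching ControlNet should work.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")

# image: the picture to inpaint; mask_image: white = repaint, black = keep;
# control_image: the ControlNet conditioning (e.g. a Canny edge map of the input).
# File names here are placeholders.
image = load_image("input.png")
mask_image = load_image("mask.png")
control_image = load_image("canny.png")

result = pipe(
    prompt="a red brick wall",
    image=image,
    control_image=control_image,
    mask_image=mask_image,
    num_inference_steps=30,
).images[0]
result.save("inpainted.png")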