skytnt committed
Commit b3a152e
1 Parent(s): 4223034

Update pipeline.py

Files changed (1):
  1. pipeline.py +294 -325
pipeline.py CHANGED
@@ -6,13 +6,10 @@ import numpy as np
 import torch
 
 import PIL
- from diffusers.configuration_utils import FrozenDict
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
- from diffusers.pipeline_utils import DiffusionPipeline
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
- from diffusers.utils import deprecate, is_accelerate_available, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
@@ -124,7 +121,7 @@ def parse_prompt_attention(text):
 return res
 
 
- def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
 r"""
 Tokenize a list of prompts and return its tokens with weights of each token.
 
@@ -185,7 +182,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd
 
 
 def get_unweighted_text_embeddings(
- pipe: DiffusionPipeline,
 text_input: torch.Tensor,
 chunk_length: int,
 no_boseos_middle: Optional[bool] = True,
@@ -225,10 +222,10 @@ def get_unweighted_text_embeddings(
 
 
 def get_weighted_text_embeddings(
- pipe: DiffusionPipeline,
 prompt: Union[str, List[str]],
 uncond_prompt: Optional[Union[str, List[str]]] = None,
- max_embeddings_multiples: Optional[int] = 1,
 no_boseos_middle: Optional[bool] = False,
 skip_parsing: Optional[bool] = False,
 skip_weighting: Optional[bool] = False,
@@ -242,14 +239,14 @@ def get_weighted_text_embeddings(
 Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
 
 Args:
- pipe (`DiffusionPipeline`):
 Pipe to provide access to the tokenizer and the text encoder.
 prompt (`str` or `List[str]`):
 The prompt or prompts to guide the image generation.
 uncond_prompt (`str` or `List[str]`):
 The unconditional prompt or prompts for guide the image generation. If unconditional prompt
 is provided, the embeddings of prompt and uncond_prompt are concatenated.
- max_embeddings_multiples (`int`, *optional*, defaults to `1`):
 The max multiple length of prompt embeddings compared to the max output length of text encoder.
 no_boseos_middle (`bool`, *optional*, defaults to `False`):
 If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
@@ -358,18 +355,18 @@ def get_weighted_text_embeddings(
 def preprocess_image(image):
 w, h = image.size
 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- image = image.resize((w, h), resample=PIL.Image.LANCZOS)
 image = np.array(image).astype(np.float32) / 255.0
 image = image[None].transpose(0, 3, 1, 2)
 image = torch.from_numpy(image)
 return 2.0 * image - 1.0
 
 
- def preprocess_mask(mask):
 mask = mask.convert("L")
 w, h = mask.size
 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
- mask = mask.resize((w // 8, h // 8), resample=PIL.Image.NEAREST)
 mask = np.array(mask).astype(np.float32) / 255.0
 mask = np.tile(mask, (4, 1, 1))
 mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
@@ -378,7 +375,7 @@ def preprocess_mask(mask):
 return mask
 
 
- class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 r"""
 Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
 weighting in prompt.
@@ -398,7 +395,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
 unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
 scheduler ([`SchedulerMixin`]):
- A scheduler to be used in combination with `unet` to denoise the encoded image latens. Can be one of
 [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
 safety_checker ([`StableDiffusionSafetyChecker`]):
 Classification module that estimates whether generated images could be considered offensive or harmful.
@@ -413,50 +410,12 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 text_encoder: CLIPTextModel,
 tokenizer: CLIPTokenizer,
 unet: UNet2DConditionModel,
- scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
 safety_checker: StableDiffusionSafetyChecker,
 feature_extractor: CLIPFeatureExtractor,
 ):
- super().__init__()
-
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
- deprecation_message = (
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
- " file"
- )
- deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
- new_config = dict(scheduler.config)
- new_config["steps_offset"] = 1
- scheduler._internal_dict = FrozenDict(new_config)
-
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
- deprecation_message = (
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
- )
- deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
- new_config = dict(scheduler.config)
- new_config["clip_sample"] = False
- scheduler._internal_dict = FrozenDict(new_config)
-
- if safety_checker is None:
- logger.warn(
- f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
- " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
- " results in services or applications open to the public. Both the diffusers team and Hugging Face"
- " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
- " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
- " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
- )
-
- self.register_modules(
 vae=vae,
 text_encoder=text_encoder,
 tokenizer=tokenizer,
@@ -464,76 +423,178 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 scheduler=scheduler,
 safety_checker=safety_checker,
 feature_extractor=feature_extractor,
 )
 
- def enable_xformers_memory_efficient_attention(self):
 r"""
- Enable memory efficient attention as implemented in xformers.
 
- When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
- time. Speed up at training time is not guaranteed.
-
- Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
- is used.
 """
- self.unet.set_use_memory_efficient_attention_xformers(True)
 
- def disable_xformers_memory_efficient_attention(self):
- r"""
- Disable memory efficient attention as implemented in xformers.
- """
- self.unet.set_use_memory_efficient_attention_xformers(False)
 
- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
- r"""
- Enable sliced attention computation.
 
- When this option is enabled, the attention module will split the input tensor in slices, to compute attention
- in several steps. This is useful to save some memory in exchange for a small speed decrease.
 
- Args:
- slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
- When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
- a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
- `attention_head_dim` must be a multiple of `slice_size`.
- """
- if slice_size == "auto":
- # half the attention head size is usually a good trade-off between
- # speed and memory
- slice_size = self.unet.config.attention_head_dim // 2
- self.unet.set_attention_slice(slice_size)
 
- def disable_attention_slicing(self):
- r"""
- Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
- back to computing attention in one step.
- """
- # set slice_size = `None` to disable `attention slicing`
- self.enable_attention_slicing(None)
 
- def enable_sequential_cpu_offload(self):
- r"""
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
- """
- if is_accelerate_available():
- from accelerate import cpu_offload
 else:
- raise ImportError("Please install accelerate via `pip install accelerate`")
 
- device = self.device
 
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
- if cpu_offloaded_model is not None:
- cpu_offload(cpu_offloaded_model, device)
 
 @torch.no_grad()
 def __call__(
 self,
 prompt: Union[str, List[str]],
 negative_prompt: Optional[Union[str, List[str]]] = None,
- init_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
 mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
 height: int = 512,
 width: int = 512,
@@ -561,11 +622,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 negative_prompt (`str` or `List[str]`, *optional*):
 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
 if `guidance_scale` is less than `1`).
- init_image (`torch.FloatTensor` or `PIL.Image.Image`):
 `Image`, or tensor representing an image batch, that will be used as the starting point for the
 process.
 mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
- `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
 PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -583,11 +644,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
 usually at the expense of lower image quality.
 strength (`float`, *optional*, defaults to 0.8):
- Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
- `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
 noise will be maximum and the denoising process will run for the full number of iterations specified in
- `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
 num_images_per_prompt (`int`, *optional*, defaults to 1):
 The number of images to generate per prompt.
 eta (`float`, *optional*, defaults to 0.0):
@@ -626,222 +687,115 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 (nsfw) content, according to the `safety_checker`.
 """
 
- if isinstance(prompt, str):
- batch_size = 1
- prompt = [prompt]
- elif isinstance(prompt, list):
- batch_size = len(prompt)
- else:
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
- if strength < 0 or strength > 1:
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
-
- if height % 8 != 0 or width % 8 != 0:
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
-
- if (callback_steps is None) or (
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
- ):
- raise ValueError(
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
- f" {type(callback_steps)}."
- )
-
- # get prompt text embeddings
 
 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
 # corresponds to doing no classifier free guidance.
 do_classifier_free_guidance = guidance_scale > 1.0
- # get unconditional embeddings for classifier free guidance
- if negative_prompt is None:
- negative_prompt = [""] * batch_size
- elif isinstance(negative_prompt, str):
- negative_prompt = [negative_prompt] * batch_size
- if batch_size != len(negative_prompt):
- raise ValueError(
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
- " the batch size of `prompt`."
- )
 
- text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
- pipe=self,
- prompt=prompt,
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
- max_embeddings_multiples=max_embeddings_multiples,
- **kwargs,
 )
- bs_embed, seq_len, _ = text_embeddings.shape
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
- if do_classifier_free_guidance:
- bs_embed, seq_len, _ = uncond_embeddings.shape
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
- uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
- # set timesteps
- self.scheduler.set_timesteps(num_inference_steps)
-
- latents_dtype = text_embeddings.dtype
- init_latents_orig = None
- mask = None
- noise = None
-
- if init_image is None:
- # get the initial random noise unless the user supplied it
-
- # Unlike in other pipelines, latents need to be generated in the target device
- # for 1-to-1 results reproducibility with the CompVis implementation.
- # However this currently doesn't work in `mps`.
- latents_shape = (
- batch_size * num_images_per_prompt,
- self.unet.in_channels,
- height // 8,
- width // 8,
- )
-
- if latents is None:
- if self.device.type == "mps":
- # randn does not exist on mps
- latents = torch.randn(
- latents_shape,
- generator=generator,
- device="cpu",
- dtype=latents_dtype,
- ).to(self.device)
- else:
- latents = torch.randn(
- latents_shape,
- generator=generator,
- device=self.device,
- dtype=latents_dtype,
- )
- else:
- if latents.shape != latents_shape:
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
- latents = latents.to(self.device)
-
- timesteps = self.scheduler.timesteps.to(self.device)
-
- # scale the initial noise by the standard deviation required by the scheduler
- latents = latents * self.scheduler.init_noise_sigma
 else:
- if isinstance(init_image, PIL.Image.Image):
- init_image = preprocess_image(init_image)
- # encode the init image into latents and scale the latents
- init_image = init_image.to(device=self.device, dtype=latents_dtype)
- init_latent_dist = self.vae.encode(init_image).latent_dist
- init_latents = init_latent_dist.sample(generator=generator)
- init_latents = 0.18215 * init_latents
- init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
- init_latents_orig = init_latents
-
- # preprocess mask
- if mask_image is not None:
- if isinstance(mask_image, PIL.Image.Image):
- mask_image = preprocess_mask(mask_image)
- mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
- mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
-
- # check sizes
- if not mask.shape == init_latents.shape:
- raise ValueError("The mask and init_image should be the same size!")
-
- # get the original timestep using init_timestep
- offset = self.scheduler.config.get("steps_offset", 0)
- init_timestep = int(num_inference_steps * strength) + offset
- init_timestep = min(init_timestep, num_inference_steps)
-
- timesteps = self.scheduler.timesteps[-init_timestep]
- timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
-
- # add noise to latents using the timesteps
- if self.device.type == "mps":
- # randn does not exist on mps
- noise = torch.randn(
- init_latents.shape,
- generator=generator,
- device="cpu",
- dtype=latents_dtype,
- ).to(self.device)
- else:
- noise = torch.randn(
- init_latents.shape,
- generator=generator,
- device=self.device,
- dtype=latents_dtype,
- )
- latents = self.scheduler.add_noise(init_latents, noise, timesteps)
-
- t_start = max(num_inference_steps - init_timestep + offset, 0)
- timesteps = self.scheduler.timesteps[t_start:].to(self.device)
-
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
- # and should be between [0, 1]
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
- extra_step_kwargs = {}
- if accepts_eta:
- extra_step_kwargs["eta"] = eta
-
- for i, t in enumerate(self.progress_bar(timesteps)):
- # expand the latents if we are doing classifier free guidance
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
- # predict the noise residual
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
- # perform guidance
- if do_classifier_free_guidance:
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
- # compute the previous noisy sample x_t -> x_t-1
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
- if mask is not None:
- # masking
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
-
- # call the callback, if provided
- if i % callback_steps == 0:
- if callback is not None:
- callback(i, t, latents)
- if is_cancelled_callback is not None and is_cancelled_callback():
- return None
-
- latents = 1 / 0.18215 * latents
- image = self.vae.decode(latents).sample
-
- image = (image / 2 + 0.5).clamp(0, 1)
-
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
-
- if self.safety_checker is not None:
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
- self.device
- )
- image, has_nsfw_concept = self.safety_checker(
- images=image,
- clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
- )
- else:
- has_nsfw_concept = None
 
 if output_type == "pil":
 image = self.numpy_to_pil(image)
 
 if not return_dict:
- return (image, has_nsfw_concept)
 
 return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
 
@@ -861,6 +815,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 output_type: Optional[str] = "pil",
 return_dict: bool = True,
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 callback_steps: Optional[int] = 1,
 **kwargs,
 ):
@@ -908,6 +863,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 callback (`Callable`, *optional*):
 A function that will be called every `callback_steps` steps during inference. The function will be
 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
 callback_steps (`int`, *optional*, defaults to 1):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
@@ -933,13 +891,14 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 output_type=output_type,
 return_dict=return_dict,
 callback=callback,
 callback_steps=callback_steps,
 **kwargs,
 )
 
 def img2img(
 self,
- init_image: Union[torch.FloatTensor, PIL.Image.Image],
 prompt: Union[str, List[str]],
 negative_prompt: Optional[Union[str, List[str]]] = None,
 strength: float = 0.8,
@@ -952,13 +911,14 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 output_type: Optional[str] = "pil",
 return_dict: bool = True,
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 callback_steps: Optional[int] = 1,
 **kwargs,
 ):
 r"""
 Function for image-to-image generation.
 Args:
- init_image (`torch.FloatTensor` or `PIL.Image.Image`):
 `Image`, or tensor representing an image batch, that will be used as the starting point for the
 process.
 prompt (`str` or `List[str]`):
@@ -967,11 +927,11 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
 if `guidance_scale` is less than `1`).
 strength (`float`, *optional*, defaults to 0.8):
- Conceptually, indicates how much to transform the reference `init_image`. Must be between 0 and 1.
- `init_image` will be used as a starting point, adding more noise to it the larger the `strength`. The
 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
 noise will be maximum and the denoising process will run for the full number of iterations specified in
- `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
 num_inference_steps (`int`, *optional*, defaults to 50):
 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
 expense of slower inference. This parameter will be modulated by `strength`.
@@ -1000,6 +960,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 callback (`Callable`, *optional*):
 A function that will be called every `callback_steps` steps during inference. The function will be
 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
 callback_steps (`int`, *optional*, defaults to 1):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
@@ -1013,7 +976,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 return self.__call__(
 prompt=prompt,
 negative_prompt=negative_prompt,
- init_image=init_image,
 num_inference_steps=num_inference_steps,
 guidance_scale=guidance_scale,
 strength=strength,
@@ -1024,13 +987,14 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 output_type=output_type,
 return_dict=return_dict,
 callback=callback,
 callback_steps=callback_steps,
 **kwargs,
 )
 
 def inpaint(
 self,
- init_image: Union[torch.FloatTensor, PIL.Image.Image],
 mask_image: Union[torch.FloatTensor, PIL.Image.Image],
 prompt: Union[str, List[str]],
 negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -1044,17 +1008,18 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 output_type: Optional[str] = "pil",
 return_dict: bool = True,
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 callback_steps: Optional[int] = 1,
 **kwargs,
 ):
 r"""
 Function for inpaint.
 Args:
- init_image (`torch.FloatTensor` or `PIL.Image.Image`):
 `Image`, or tensor representing an image batch, that will be used as the starting point for the
 process. This is the image whose masked region will be inpainted.
 mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
- `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
 PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
@@ -1066,7 +1031,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 strength (`float`, *optional*, defaults to 0.8):
 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
 is 1, the denoising process will be run on the masked area for the full number of iterations specified
- in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
 noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
 num_inference_steps (`int`, *optional*, defaults to 50):
 The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
@@ -1096,6 +1061,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 callback (`Callable`, *optional*):
 A function that will be called every `callback_steps` steps during inference. The function will be
 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
 callback_steps (`int`, *optional*, defaults to 1):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
@@ -1109,7 +1077,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 return self.__call__(
 prompt=prompt,
 negative_prompt=negative_prompt,
- init_image=init_image,
 mask_image=mask_image,
 num_inference_steps=num_inference_steps,
 guidance_scale=guidance_scale,
@@ -1121,6 +1089,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
 output_type=output_type,
 return_dict=return_dict,
 callback=callback,
 callback_steps=callback_steps,
 **kwargs,
 )
 
pipeline.py (updated version):

 import torch
 
 import PIL
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+ from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 
 return res
 
 
+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
 r"""
 Tokenize a list of prompts and return its tokens with weights of each token.
 
 
 def get_unweighted_text_embeddings(
+ pipe: StableDiffusionPipeline,
 text_input: torch.Tensor,
 chunk_length: int,
 no_boseos_middle: Optional[bool] = True,
 
 
 def get_weighted_text_embeddings(
+ pipe: StableDiffusionPipeline,
 prompt: Union[str, List[str]],
 uncond_prompt: Optional[Union[str, List[str]]] = None,
+ max_embeddings_multiples: Optional[int] = 3,
 no_boseos_middle: Optional[bool] = False,
 skip_parsing: Optional[bool] = False,
 skip_weighting: Optional[bool] = False,
 
 Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.
 
 Args:
+ pipe (`StableDiffusionPipeline`):
 Pipe to provide access to the tokenizer and the text encoder.
 prompt (`str` or `List[str]`):
 The prompt or prompts to guide the image generation.
 uncond_prompt (`str` or `List[str]`):
 The unconditional prompt or prompts for guide the image generation. If unconditional prompt
 is provided, the embeddings of prompt and uncond_prompt are concatenated.
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
 The max multiple length of prompt embeddings compared to the max output length of text encoder.
 no_boseos_middle (`bool`, *optional*, defaults to `False`):
 If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
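The bump of `max_embeddings_multiples` from 1 to 3 means weighted prompts may now span up to three 77-token chunks by default. A minimal sketch of exercising the helper directly is shown below; it assumes this file is importable as a module (e.g. `from pipeline import get_weighted_text_embeddings`) and uses an illustrative model id and prompt.

    # Hedged sketch, not part of the commit: calling the weighted-prompt helper directly.
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # "(word:1.3)" raises the attention weight of "word", "[word]" lowers it.
    cond, uncond = get_weighted_text_embeddings(
        pipe=pipe,
        prompt="a photo of an astronaut riding a (white:1.3) horse",
        uncond_prompt="lowres, (bad anatomy:1.3)",
        max_embeddings_multiples=3,
    )
    print(cond.shape)  # e.g. torch.Size([1, 77, 768]) for a prompt that fits in one chunk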
 
 def preprocess_image(image):
 w, h = image.size
 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ image = image.resize((w, h), resample=PIL_INTERPOLATION["lanczos"])
 image = np.array(image).astype(np.float32) / 255.0
 image = image[None].transpose(0, 3, 1, 2)
 image = torch.from_numpy(image)
 return 2.0 * image - 1.0
 
 
+ def preprocess_mask(mask, scale_factor=8):
 mask = mask.convert("L")
 w, h = mask.size
 w, h = map(lambda x: x - x % 32, (w, h)) # resize to integer multiple of 32
+ mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
 mask = np.array(mask).astype(np.float32) / 255.0
 mask = np.tile(mask, (4, 1, 1))
 mask = mask[None].transpose(0, 1, 2, 3) # what does this step do?
 
 return mask
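As a rough illustration of the new `scale_factor` argument, the sketch below (an assumption-laden example, not part of the commit) feeds a 512x512 PIL mask through `preprocess_mask`, assuming the function is imported from this file.

    import PIL.Image

    mask = PIL.Image.new("L", (512, 512), color=255)  # all-white mask: repaint everything
    mask_tensor = preprocess_mask(mask, scale_factor=8)
    # 512 // 8 = 64 per side, tiled over the 4 latent channels, plus a batch dimension:
    print(tuple(mask_tensor.shape))  # (1, 4, 64, 64)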
+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
 r"""
 Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
 weighting in prompt.
 
 [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
 unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
 scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
 [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
 safety_checker ([`StableDiffusionSafetyChecker`]):
 Classification module that estimates whether generated images could be considered offensive or harmful.
 
 text_encoder: CLIPTextModel,
 tokenizer: CLIPTokenizer,
 unet: UNet2DConditionModel,
+ scheduler: SchedulerMixin,
 safety_checker: StableDiffusionSafetyChecker,
 feature_extractor: CLIPFeatureExtractor,
+ requires_safety_checker: bool = True,
 ):
+ super().__init__(
 vae=vae,
 text_encoder=text_encoder,
 tokenizer=tokenizer,
 
 scheduler=scheduler,
 safety_checker=safety_checker,
 feature_extractor=feature_extractor,
+ requires_safety_checker=requires_safety_checker,
 )
 
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
+ ):
 r"""
+ Encodes the prompt into text encoder hidden states.
 
+ Args:
+ prompt (`str` or `list(int)`):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`):
+ The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+ if `guidance_scale` is less than `1`).
+ max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+ The max multiple length of prompt embeddings compared to the max output length of text encoder.
 """
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
 
+ if negative_prompt is None:
+ negative_prompt = [""] * batch_size
+ elif isinstance(negative_prompt, str):
+ negative_prompt = [negative_prompt] * batch_size
+ if batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
 
+ text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+ pipe=self,
+ prompt=prompt,
+ uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+ max_embeddings_multiples=max_embeddings_multiples,
+ )
+ bs_embed, seq_len, _ = text_embeddings.shape
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
+ if do_classifier_free_guidance:
+ bs_embed, seq_len, _ = uncond_embeddings.shape
+ uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+ uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
+ return text_embeddings
 
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
+ if strength < 0 or strength > 1:
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
+ if is_text2img:
+ return self.scheduler.timesteps.to(device), num_inference_steps
+ else:
+ # get the original timestep using init_timestep
+ offset = self.scheduler.config.get("steps_offset", 0)
+ init_timestep = int(num_inference_steps * strength) + offset
+ init_timestep = min(init_timestep, num_inference_steps)
+
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
+ return timesteps, num_inference_steps - t_start
+
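For img2img and inpainting, `get_timesteps` trims the schedule according to `strength`; the plain-number sketch below (not part of the commit, and assuming a scheduler config with `steps_offset` of 1) works through the default case.

    num_inference_steps = 50
    strength = 0.8
    offset = 1  # scheduler.config steps_offset, assumed to be 1 here

    init_timestep = min(int(num_inference_steps * strength) + offset, num_inference_steps)  # 41
    t_start = max(num_inference_steps - init_timestep + offset, 0)                          # 10
    # the denoising loop then runs over scheduler.timesteps[t_start:], i.e. 40 steps
    print(init_timestep, t_start, num_inference_steps - t_start)  # 41 10 40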
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is not None:
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
 else:
+ has_nsfw_concept = None
+ return image, has_nsfw_concept
 
+ def decode_latents(self, latents):
+ latents = 1 / 0.18215 * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
 
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
+ if image is None:
+ shape = (
+ batch_size,
+ self.unet.in_channels,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor,
+ )
+
+ if latents is None:
+ if device.type == "mps":
+ # randn does not work reproducibly on mps
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents, None, None
+ else:
+ init_latent_dist = self.vae.encode(image).latent_dist
+ init_latents = init_latent_dist.sample(generator=generator)
+ init_latents = 0.18215 * init_latents
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
+ init_latents_orig = init_latents
+ shape = init_latents.shape
+
+ # add noise to latents using the timesteps
+ if device.type == "mps":
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
+ else:
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ return latents, init_latents_orig, noise
 
 @torch.no_grad()
 def __call__(
 self,
 prompt: Union[str, List[str]],
 negative_prompt: Optional[Union[str, List[str]]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image] = None,
 mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
 height: int = 512,
 width: int = 512,
 
 negative_prompt (`str` or `List[str]`, *optional*):
 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
 if `guidance_scale` is less than `1`).
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
 `Image`, or tensor representing an image batch, that will be used as the starting point for the
 process.
 mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
 PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
 
 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
 usually at the expense of lower image quality.
 strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
 noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
 num_images_per_prompt (`int`, *optional*, defaults to 1):
 The number of images to generate per prompt.
 eta (`float`, *optional*, defaults to 0.0):
 
 list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
 (nsfw) content, according to the `safety_checker`.
 """
+ message = "Please use `image` instead of `init_image`."
+ init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
+ image = init_image or image
 
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
 
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(prompt, height, width, strength, callback_steps)
 
+ # 2. Define call parameters
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
+ device = self._execution_device
 # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
 # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
 # corresponds to doing no classifier free guidance.
 do_classifier_free_guidance = guidance_scale > 1.0
 
+ # 3. Encode input prompt
+ text_embeddings = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ max_embeddings_multiples,
 )
+ dtype = text_embeddings.dtype
+
+ # 4. Preprocess image and mask
+ if isinstance(image, PIL.Image.Image):
+ image = preprocess_image(image)
+ if image is not None:
+ image = image.to(device=self.device, dtype=dtype)
+ if isinstance(mask_image, PIL.Image.Image):
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
+ if mask_image is not None:
+ mask = mask_image.to(device=self.device, dtype=dtype)
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
 else:
+ mask = None
+
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # 6. Prepare latent variables
+ latents, init_latents_orig, noise = self.prepare_latents(
+ image,
+ latent_timestep,
+ batch_size * num_images_per_prompt,
+ height,
+ width,
+ dtype,
+ device,
+ generator,
+ latents,
+ )
 
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ if mask is not None:
+ # masking
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if i % callback_steps == 0:
+ if callback is not None:
+ callback(i, t, latents)
+ if is_cancelled_callback is not None and is_cancelled_callback():
+ return None
+
+ # 9. Post-processing
+ image = self.decode_latents(latents)
+
+ # 10. Run safety checker
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
+
+ # 11. Convert to PIL
 if output_type == "pil":
 image = self.numpy_to_pil(image)
 
 if not return_dict:
+ return image, has_nsfw_concept
 
 return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
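Because the old keyword is routed through `deprecate(..., take_from=kwargs)`, existing callers should keep working while emitting a warning; the sketch below (with hypothetical `pipe` and `init_img` objects, not part of the commit) shows the intended migration.

    # deprecated spelling, still accepted until the stated removal version:
    out = pipe(prompt="a fantasy landscape", init_image=init_img, strength=0.75)
    # preferred spelling after this change:
    out = pipe(prompt="a fantasy landscape", image=init_img, strength=0.75)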
 output_type: Optional[str] = "pil",
 return_dict: bool = True,
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
 callback_steps: Optional[int] = 1,
 **kwargs,
 ):
 
 callback (`Callable`, *optional*):
 A function that will be called every `callback_steps` steps during inference. The function will be
 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
 callback_steps (`int`, *optional*, defaults to 1):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
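A minimal sketch of the new cancellation hook, assuming `pipe` is an instance of this pipeline: the callable is polled every `callback_steps` steps, and a `True` return makes the call return `None`.

    import threading

    cancel_event = threading.Event()  # e.g. set from a UI button handler to stop generation

    result = pipe.text2img(
        "a watercolor painting of a lighthouse",
        is_cancelled_callback=cancel_event.is_set,
        callback_steps=1,
    )
    if result is None:
        print("generation was cancelled")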
 
 output_type=output_type,
 return_dict=return_dict,
 callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
 callback_steps=callback_steps,
 **kwargs,
 )
 
 def img2img(
 self,
+ image: Union[torch.FloatTensor, PIL.Image.Image],
 prompt: Union[str, List[str]],
 negative_prompt: Optional[Union[str, List[str]]] = None,
 strength: float = 0.8,
 
 output_type: Optional[str] = "pil",
 return_dict: bool = True,
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
 callback_steps: Optional[int] = 1,
 **kwargs,
 ):
 r"""
 Function for image-to-image generation.
 Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
 `Image`, or tensor representing an image batch, that will be used as the starting point for the
 process.
 prompt (`str` or `List[str]`):
 
 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
 if `guidance_scale` is less than `1`).
 strength (`float`, *optional*, defaults to 0.8):
+ Conceptually, indicates how much to transform the reference `image`. Must be between 0 and 1.
+ `image` will be used as a starting point, adding more noise to it the larger the `strength`. The
 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
 noise will be maximum and the denoising process will run for the full number of iterations specified in
+ `num_inference_steps`. A value of 1, therefore, essentially ignores `image`.
 num_inference_steps (`int`, *optional*, defaults to 50):
 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
 expense of slower inference. This parameter will be modulated by `strength`.
 
 callback (`Callable`, *optional*):
 A function that will be called every `callback_steps` steps during inference. The function will be
 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
 callback_steps (`int`, *optional*, defaults to 1):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 
 return self.__call__(
 prompt=prompt,
 negative_prompt=negative_prompt,
+ image=image,
 num_inference_steps=num_inference_steps,
 guidance_scale=guidance_scale,
 strength=strength,
 
 output_type=output_type,
 return_dict=return_dict,
 callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
 callback_steps=callback_steps,
 **kwargs,
 )
 
 def inpaint(
 self,
+ image: Union[torch.FloatTensor, PIL.Image.Image],
 mask_image: Union[torch.FloatTensor, PIL.Image.Image],
 prompt: Union[str, List[str]],
 negative_prompt: Optional[Union[str, List[str]]] = None,
 
 output_type: Optional[str] = "pil",
 return_dict: bool = True,
 callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ is_cancelled_callback: Optional[Callable[[], bool]] = None,
 callback_steps: Optional[int] = 1,
 **kwargs,
 ):
 r"""
 Function for inpaint.
 Args:
+ image (`torch.FloatTensor` or `PIL.Image.Image`):
 `Image`, or tensor representing an image batch, that will be used as the starting point for the
 process. This is the image whose masked region will be inpainted.
 mask_image (`torch.FloatTensor` or `PIL.Image.Image`):
+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
 replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
 PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
 contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
 
 strength (`float`, *optional*, defaults to 0.8):
 Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
 is 1, the denoising process will be run on the masked area for the full number of iterations specified
+ in `num_inference_steps`. `image` will be used as a reference for the masked area, adding more
 noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
 num_inference_steps (`int`, *optional*, defaults to 50):
 The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
 
 callback (`Callable`, *optional*):
 A function that will be called every `callback_steps` steps during inference. The function will be
 called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ is_cancelled_callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. If the function returns
+ `True`, the inference will be cancelled.
 callback_steps (`int`, *optional*, defaults to 1):
 The frequency at which the `callback` function will be called. If not specified, the callback will be
 called at every step.
 
 return self.__call__(
 prompt=prompt,
 negative_prompt=negative_prompt,
+ image=image,
 mask_image=mask_image,
 num_inference_steps=num_inference_steps,
 guidance_scale=guidance_scale,
 
 output_type=output_type,
 return_dict=return_dict,
 callback=callback,
+ is_cancelled_callback=is_cancelled_callback,
 callback_steps=callback_steps,
 **kwargs,
 )
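For context, a hedged end-to-end sketch of how this file is typically consumed: the model id is illustrative, and `custom_pipeline="lpw_stable_diffusion"` is assumed to resolve to a copy of this pipeline (a local path to this pipeline.py should work as well).

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        custom_pipeline="lpw_stable_diffusion",  # assumption: community copy of this file
        torch_dtype=torch.float16,
    ).to("cuda")

    # text-to-image with a weighted prompt that can exceed 77 tokens
    image = pipe.text2img(
        "a masterpiece photo of a (red:1.3) vintage car by the sea, golden hour",
        negative_prompt="lowres, (bad anatomy:1.3), blurry",
        width=512,
        height=512,
        max_embeddings_multiples=3,
    ).images[0]

    # img2img and inpaint reuse the same machinery via the new `image` argument
    # image2 = pipe.img2img(image=image, prompt="the same car in (snowy:1.2) weather", strength=0.7).images[0]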