omer11a committed on
Commit 77e9eec
1 Parent(s): 72df8c8

Used diffusers==0.20.0
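
This pins the custom pipeline file to the diffusers 0.20.0 API: imports go back to the older module paths (`diffusers.pipeline_utils.DiffusionPipeline`, `LoraLoaderMixin`, `randn_tensor` from `diffusers.utils`), `enable_model_cpu_offload` is defined locally via accelerate's `cpu_offload_with_hook`, and arguments that only exist in later releases (`ip_adapter_image`, `clip_skip`, custom `timesteps`, `callback_on_step_end`) are dropped in favour of the older `callback`/`callback_steps` interface. A matching environment would pin the dependency accordingly, e.g. `pip install diffusers==0.20.0`; the transformers/torch versions to use alongside it are not stated in the commit. A quick sanity check, as a sketch (only the version string comes from the commit message):

    import diffusers

    # The file below assumes the 0.20.0 module layout (e.g. diffusers.pipeline_utils).
    assert diffusers.__version__ == "0.20.0", diffusers.__version__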

pipeline_stable_diffusion_xl_opt.py CHANGED
@@ -13,24 +13,14 @@
13
  # limitations under the License.
14
 
15
  import inspect
 
16
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
17
 
18
  import torch
19
- from transformers import (
20
- CLIPImageProcessor,
21
- CLIPTextModel,
22
- CLIPTextModelWithProjection,
23
- CLIPTokenizer,
24
- CLIPVisionModelWithProjection,
25
- )
26
 
27
- from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
28
- from diffusers.loaders import (
29
- FromSingleFileMixin,
30
- IPAdapterMixin,
31
- StableDiffusionXLLoraLoaderMixin,
32
- TextualInversionLoaderMixin,
33
- )
34
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
35
  from diffusers.models.attention_processor import (
36
  AttnProcessor2_0,
@@ -38,33 +28,22 @@ from diffusers.models.attention_processor import (
38
  LoRAXFormersAttnProcessor,
39
  XFormersAttnProcessor,
40
  )
41
- from diffusers.models.lora import adjust_lora_scale_text_encoder
42
  from diffusers.schedulers import KarrasDiffusionSchedulers
43
  from diffusers.utils import (
44
- USE_PEFT_BACKEND,
45
- deprecate,
46
  is_invisible_watermark_available,
47
- is_torch_xla_available,
48
  logging,
 
49
  replace_example_docstring,
50
- scale_lora_layers,
51
- unscale_lora_layers,
52
  )
53
- from diffusers.utils.torch_utils import randn_tensor
54
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
55
- from diffusers.pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput
56
 
57
 
58
  if is_invisible_watermark_available():
59
  from .watermark import StableDiffusionXLWatermarker
60
 
61
- if is_torch_xla_available():
62
- import torch_xla.core.xla_model as xm
63
-
64
- XLA_AVAILABLE = True
65
- else:
66
- XLA_AVAILABLE = False
67
-
68
 
69
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
70
 
@@ -100,58 +79,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100
  return noise_cfg
101
 
102
 
103
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104
- def retrieve_timesteps(
105
- scheduler,
106
- num_inference_steps: Optional[int] = None,
107
- device: Optional[Union[str, torch.device]] = None,
108
- timesteps: Optional[List[int]] = None,
109
- **kwargs,
110
- ):
111
- """
112
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
113
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
114
-
115
- Args:
116
- scheduler (`SchedulerMixin`):
117
- The scheduler to get timesteps from.
118
- num_inference_steps (`int`):
119
- The number of diffusion steps used when generating samples with a pre-trained model. If used,
120
- `timesteps` must be `None`.
121
- device (`str` or `torch.device`, *optional*):
122
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
123
- timesteps (`List[int]`, *optional*):
124
- Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
125
- timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
126
- must be `None`.
127
-
128
- Returns:
129
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
130
- second element is the number of inference steps.
131
- """
132
- if timesteps is not None:
133
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
134
- if not accepts_timesteps:
135
- raise ValueError(
136
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
137
- f" timestep schedules. Please check whether you are using the correct scheduler."
138
- )
139
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
140
- timesteps = scheduler.timesteps
141
- num_inference_steps = len(timesteps)
142
- else:
143
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
144
- timesteps = scheduler.timesteps
145
- return timesteps, num_inference_steps
146
-
147
-
148
- class StableDiffusionXLPipeline(
149
- DiffusionPipeline,
150
- FromSingleFileMixin,
151
- StableDiffusionXLLoraLoaderMixin,
152
- TextualInversionLoaderMixin,
153
- IPAdapterMixin,
154
- ):
155
  r"""
156
  Pipeline for text-to-image generation using Stable Diffusion XL.
157
 
@@ -159,11 +87,11 @@ class StableDiffusionXLPipeline(
159
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
160
 
161
  In addition the pipeline inherits the following loading methods:
162
- - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`]
163
  - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
164
 
165
  as well as the following saving methods:
166
- - *LoRA*: [`loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`]
167
 
168
  Args:
169
  vae ([`AutoencoderKL`]):
@@ -188,34 +116,8 @@ class StableDiffusionXLPipeline(
188
  scheduler ([`SchedulerMixin`]):
189
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
190
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
191
- force_zeros_for_empty_prompt (`bool`, *optional*, defaults to `"True"`):
192
- Whether the negative prompt embeddings shall be forced to always be set to 0. Also see the config of
193
- `stabilityai/stable-diffusion-xl-base-1-0`.
194
- add_watermarker (`bool`, *optional*):
195
- Whether to use the [invisible_watermark library](https://github.com/ShieldMnt/invisible-watermark/) to
196
- watermark output images. If not defined, it will default to True if the package is installed, otherwise no
197
- watermarker will be used.
198
  """
199
 
200
- model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae"
201
- _optional_components = [
202
- "tokenizer",
203
- "tokenizer_2",
204
- "text_encoder",
205
- "text_encoder_2",
206
- "image_encoder",
207
- "feature_extractor",
208
- ]
209
- _callback_tensor_inputs = [
210
- "latents",
211
- "prompt_embeds",
212
- "negative_prompt_embeds",
213
- "add_text_embeds",
214
- "add_time_ids",
215
- "negative_pooled_prompt_embeds",
216
- "negative_add_time_ids",
217
- ]
218
-
219
  def __init__(
220
  self,
221
  vae: AutoencoderKL,
@@ -225,8 +127,6 @@ class StableDiffusionXLPipeline(
225
  tokenizer_2: CLIPTokenizer,
226
  unet: UNet2DConditionModel,
227
  scheduler: KarrasDiffusionSchedulers,
228
- image_encoder: CLIPVisionModelWithProjection = None,
229
- feature_extractor: CLIPImageProcessor = None,
230
  force_zeros_for_empty_prompt: bool = True,
231
  add_watermarker: Optional[bool] = None,
232
  ):
@@ -240,13 +140,10 @@ class StableDiffusionXLPipeline(
240
  tokenizer_2=tokenizer_2,
241
  unet=unet,
242
  scheduler=scheduler,
243
- image_encoder=image_encoder,
244
- feature_extractor=feature_extractor,
245
  )
246
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
247
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
248
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
249
-
250
  self.default_sample_size = self.unet.config.sample_size
251
 
252
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
@@ -289,6 +186,36 @@ class StableDiffusionXLPipeline(
289
  """
290
  self.vae.disable_tiling()
291
 
292
  def encode_prompt(
293
  self,
294
  prompt: str,
@@ -303,7 +230,6 @@ class StableDiffusionXLPipeline(
303
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
304
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
305
  lora_scale: Optional[float] = None,
306
- clip_skip: Optional[int] = None,
307
  ):
308
  r"""
309
  Encodes the prompt into text encoder hidden states.
@@ -343,33 +269,17 @@ class StableDiffusionXLPipeline(
343
  input argument.
344
  lora_scale (`float`, *optional*):
345
  A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
346
- clip_skip (`int`, *optional*):
347
- Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
348
- the output of the pre-final layer will be used for computing the prompt embeddings.
349
  """
350
  device = device or self._execution_device
351
 
352
  # set lora scale so that monkey patched LoRA
353
  # function of text encoder can correctly access it
354
- if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
355
  self._lora_scale = lora_scale
356
 
357
- # dynamically adjust the LoRA scale
358
- if self.text_encoder is not None:
359
- if not USE_PEFT_BACKEND:
360
- adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
361
- else:
362
- scale_lora_layers(self.text_encoder, lora_scale)
363
-
364
- if self.text_encoder_2 is not None:
365
- if not USE_PEFT_BACKEND:
366
- adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
367
- else:
368
- scale_lora_layers(self.text_encoder_2, lora_scale)
369
-
370
- prompt = [prompt] if isinstance(prompt, str) else prompt
371
-
372
- if prompt is not None:
373
  batch_size = len(prompt)
374
  else:
375
  batch_size = prompt_embeds.shape[0]
@@ -382,8 +292,6 @@ class StableDiffusionXLPipeline(
382
 
383
  if prompt_embeds is None:
384
  prompt_2 = prompt_2 or prompt
385
- prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
386
-
387
  # textual inversion: process multi-vector tokens if necessary
388
  prompt_embeds_list = []
389
  prompts = [prompt, prompt_2]
@@ -411,15 +319,29 @@ class StableDiffusionXLPipeline(
411
  f" {tokenizer.model_max_length} tokens: {removed_text}"
412
  )
413
 
414
- prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)
 
 
 
415
 
416
  # We are only ALWAYS interested in the pooled output of the final text encoder
417
  pooled_prompt_embeds = prompt_embeds[0]
418
- if clip_skip is None:
419
- prompt_embeds = prompt_embeds.hidden_states[-2]
420
- else:
421
- # "2" because SDXL always indexes from the penultimate layer.
422
- prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)]
423
 
424
  prompt_embeds_list.append(prompt_embeds)
425
 
@@ -434,18 +356,14 @@ class StableDiffusionXLPipeline(
434
  negative_prompt = negative_prompt or ""
435
  negative_prompt_2 = negative_prompt_2 or negative_prompt
436
 
437
- # normalize str to list
438
- negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
439
- negative_prompt_2 = (
440
- batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
441
- )
442
-
443
  uncond_tokens: List[str]
444
  if prompt is not None and type(prompt) is not type(negative_prompt):
445
  raise TypeError(
446
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
447
  f" {type(prompt)}."
448
  )
 
 
449
  elif batch_size != len(negative_prompt):
450
  raise ValueError(
451
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -481,11 +399,7 @@ class StableDiffusionXLPipeline(
481
 
482
  negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
483
 
484
- if self.text_encoder_2 is not None:
485
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
486
- else:
487
- prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
488
-
489
  bs_embed, seq_len, _ = prompt_embeds.shape
490
  # duplicate text embeddings for each generation per prompt, using mps friendly method
491
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -494,12 +408,7 @@ class StableDiffusionXLPipeline(
494
  if do_classifier_free_guidance:
495
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
496
  seq_len = negative_prompt_embeds.shape[1]
497
-
498
- if self.text_encoder_2 is not None:
499
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
500
- else:
501
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
502
-
503
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
504
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
505
 
@@ -511,32 +420,8 @@ class StableDiffusionXLPipeline(
511
  bs_embed * num_images_per_prompt, -1
512
  )
513
 
514
- if self.text_encoder is not None:
515
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
516
- # Retrieve the original scale by scaling back the LoRA layers
517
- unscale_lora_layers(self.text_encoder, lora_scale)
518
-
519
- if self.text_encoder_2 is not None:
520
- if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
521
- # Retrieve the original scale by scaling back the LoRA layers
522
- unscale_lora_layers(self.text_encoder_2, lora_scale)
523
-
524
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
525
 
526
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
527
- def encode_image(self, image, device, num_images_per_prompt):
528
- dtype = next(self.image_encoder.parameters()).dtype
529
-
530
- if not isinstance(image, torch.Tensor):
531
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
532
-
533
- image = image.to(device=device, dtype=dtype)
534
- image_embeds = self.image_encoder(image).image_embeds
535
- image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
536
-
537
- uncond_image_embeds = torch.zeros_like(image_embeds)
538
- return image_embeds, uncond_image_embeds
539
-
540
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
541
  def prepare_extra_step_kwargs(self, generator, eta):
542
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -568,24 +453,18 @@ class StableDiffusionXLPipeline(
568
  negative_prompt_embeds=None,
569
  pooled_prompt_embeds=None,
570
  negative_pooled_prompt_embeds=None,
571
- callback_on_step_end_tensor_inputs=None,
572
  ):
573
  if height % 8 != 0 or width % 8 != 0:
574
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
575
 
576
- if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
 
 
577
  raise ValueError(
578
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
579
  f" {type(callback_steps)}."
580
  )
581
 
582
- if callback_on_step_end_tensor_inputs is not None and not all(
583
- k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
584
- ):
585
- raise ValueError(
586
- f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
587
- )
588
-
589
  if prompt is not None and prompt_embeds is not None:
590
  raise ValueError(
591
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -652,13 +531,11 @@ class StableDiffusionXLPipeline(
652
  latents = latents * self.scheduler.init_noise_sigma
653
  return latents
654
 
655
- def _get_add_time_ids(
656
- self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
657
- ):
658
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
659
 
660
  passed_add_embed_dim = (
661
- self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
662
  )
663
  expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
664
 
@@ -690,7 +567,7 @@ class StableDiffusionXLPipeline(
690
  self.vae.decoder.conv_in.to(dtype)
691
  self.vae.decoder.mid_block.to(dtype)
692
 
693
- def update_loss(self, latents, i, t, prompt_embeds, timestep_cond, add_text_embeds, add_time_ids):
694
  def forward_pass(latent_model_input):
695
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
696
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
@@ -698,8 +575,7 @@ class StableDiffusionXLPipeline(
698
  latent_model_input,
699
  t,
700
  encoder_hidden_states=prompt_embeds,
701
- timestep_cond=timestep_cond,
702
- cross_attention_kwargs=self.cross_attention_kwargs,
703
  added_cond_kwargs=added_cond_kwargs,
704
  return_dict=False,
705
  )
@@ -707,94 +583,6 @@ class StableDiffusionXLPipeline(
707
 
708
  return self.editor.update_loss(forward_pass, latents, i)
709
 
710
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
711
- def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
712
- r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
713
-
714
- The suffixes after the scaling factors represent the stages where they are being applied.
715
-
716
- Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
717
- that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
718
-
719
- Args:
720
- s1 (`float`):
721
- Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
722
- mitigate "oversmoothing effect" in the enhanced denoising process.
723
- s2 (`float`):
724
- Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
725
- mitigate "oversmoothing effect" in the enhanced denoising process.
726
- b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
727
- b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
728
- """
729
- if not hasattr(self, "unet"):
730
- raise ValueError("The pipeline must have `unet` for using FreeU.")
731
- self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
732
-
733
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_freeu
734
- def disable_freeu(self):
735
- """Disables the FreeU mechanism if enabled."""
736
- self.unet.disable_freeu()
737
-
738
- # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
739
- def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
740
- """
741
- See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
742
-
743
- Args:
744
- timesteps (`torch.Tensor`):
745
- generate embedding vectors at these timesteps
746
- embedding_dim (`int`, *optional*, defaults to 512):
747
- dimension of the embeddings to generate
748
- dtype:
749
- data type of the generated embeddings
750
-
751
- Returns:
752
- `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
753
- """
754
- assert len(w.shape) == 1
755
- w = w * 1000.0
756
-
757
- half_dim = embedding_dim // 2
758
- emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
759
- emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
760
- emb = w.to(dtype)[:, None] * emb[None, :]
761
- emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
762
- if embedding_dim % 2 == 1: # zero pad
763
- emb = torch.nn.functional.pad(emb, (0, 1))
764
- assert emb.shape == (w.shape[0], embedding_dim)
765
- return emb
766
-
767
- @property
768
- def guidance_scale(self):
769
- return self._guidance_scale
770
-
771
- @property
772
- def guidance_rescale(self):
773
- return self._guidance_rescale
774
-
775
- @property
776
- def clip_skip(self):
777
- return self._clip_skip
778
-
779
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
780
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
781
- # corresponds to doing no classifier free guidance.
782
- @property
783
- def do_classifier_free_guidance(self):
784
- return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
785
-
786
- @property
787
- def cross_attention_kwargs(self):
788
- return self._cross_attention_kwargs
789
-
790
- @property
791
- def denoising_end(self):
792
- return self._denoising_end
793
-
794
- @property
795
- def num_timesteps(self):
796
- return self._num_timesteps
797
-
798
  @torch.no_grad()
799
  @replace_example_docstring(EXAMPLE_DOC_STRING)
800
  def __call__(
@@ -804,7 +592,6 @@ class StableDiffusionXLPipeline(
804
  height: Optional[int] = None,
805
  width: Optional[int] = None,
806
  num_inference_steps: int = 50,
807
- timesteps: List[int] = None,
808
  denoising_end: Optional[float] = None,
809
  guidance_scale: float = 5.0,
810
  negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -817,21 +604,15 @@ class StableDiffusionXLPipeline(
817
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
818
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
819
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
820
- ip_adapter_image: Optional[PipelineImageInput] = None,
821
  output_type: Optional[str] = "pil",
822
  return_dict: bool = True,
 
 
823
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
824
  guidance_rescale: float = 0.0,
825
  original_size: Optional[Tuple[int, int]] = None,
826
  crops_coords_top_left: Tuple[int, int] = (0, 0),
827
  target_size: Optional[Tuple[int, int]] = None,
828
- negative_original_size: Optional[Tuple[int, int]] = None,
829
- negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
830
- negative_target_size: Optional[Tuple[int, int]] = None,
831
- clip_skip: Optional[int] = None,
832
- callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
833
- callback_on_step_end_tensor_inputs: List[str] = ["latents"],
834
- **kwargs,
835
  ):
836
  r"""
837
  Function invoked when calling the pipeline for generation.
@@ -844,22 +625,12 @@ class StableDiffusionXLPipeline(
844
  The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
845
  used in both text-encoders
846
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
847
- The height in pixels of the generated image. This is set to 1024 by default for the best results.
848
- Anything below 512 pixels won't work well for
849
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
850
- and checkpoints that are not specifically fine-tuned on low resolutions.
851
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
852
- The width in pixels of the generated image. This is set to 1024 by default for the best results.
853
- Anything below 512 pixels won't work well for
854
- [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
855
- and checkpoints that are not specifically fine-tuned on low resolutions.
856
  num_inference_steps (`int`, *optional*, defaults to 50):
857
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
858
  expense of slower inference.
859
- timesteps (`List[int]`, *optional*):
860
- Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
861
- in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
862
- passed will be used. Must be in descending order.
863
  denoising_end (`float`, *optional*):
864
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
865
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
@@ -906,25 +677,30 @@ class StableDiffusionXLPipeline(
906
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
907
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
908
  input argument.
909
- ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
910
  output_type (`str`, *optional*, defaults to `"pil"`):
911
  The output format of the generated image. Choose between
912
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
913
  return_dict (`bool`, *optional*, defaults to `True`):
914
  Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
915
  of a plain tuple.
916
  cross_attention_kwargs (`dict`, *optional*):
917
  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
918
  `self.processor` in
919
  [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
920
- guidance_rescale (`float`, *optional*, defaults to 0.0):
921
  Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
922
  Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
923
  [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
924
  Guidance rescale factor should fix overexposure when using zero terminal SNR.
925
  original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
926
  If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
927
- `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
928
  explained in section 2.2 of
929
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
930
  crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
@@ -934,32 +710,8 @@ class StableDiffusionXLPipeline(
934
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
935
  target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
936
  For most cases, `target_size` should be set to the desired height and width of the generated image. If
937
- not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
938
  section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
939
- negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
940
- To negatively condition the generation process based on a specific image resolution. Part of SDXL's
941
- micro-conditioning as explained in section 2.2 of
942
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
943
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
944
- negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
945
- To negatively condition the generation process based on a specific crop coordinates. Part of SDXL's
946
- micro-conditioning as explained in section 2.2 of
947
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
948
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
949
- negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
950
- To negatively condition the generation process based on a target image resolution. It should be as same
951
- as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
952
- [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
953
- information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
954
- callback_on_step_end (`Callable`, *optional*):
955
- A function that calls at the end of each denoising steps during the inference. The function is called
956
- with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
957
- callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
958
- `callback_on_step_end_tensor_inputs`.
959
- callback_on_step_end_tensor_inputs (`List`, *optional*):
960
- The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
961
- will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
962
- `._callback_tensor_inputs` attribute of your pipeline class.
963
 
964
  Examples:
965
 
@@ -968,23 +720,6 @@ class StableDiffusionXLPipeline(
968
  [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
969
  `tuple`. When returning a tuple, the first element is a list with the generated images.
970
  """
971
-
972
- callback = kwargs.pop("callback", None)
973
- callback_steps = kwargs.pop("callback_steps", None)
974
-
975
- if callback is not None:
976
- deprecate(
977
- "callback",
978
- "1.0.0",
979
- "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
980
- )
981
- if callback_steps is not None:
982
- deprecate(
983
- "callback_steps",
984
- "1.0.0",
985
- "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
986
- )
987
-
988
  # 0. Default height and width to unet
989
  height = height or self.default_sample_size * self.vae_scale_factor
990
  width = width or self.default_sample_size * self.vae_scale_factor
@@ -1005,15 +740,8 @@ class StableDiffusionXLPipeline(
1005
  negative_prompt_embeds,
1006
  pooled_prompt_embeds,
1007
  negative_pooled_prompt_embeds,
1008
- callback_on_step_end_tensor_inputs,
1009
  )
1010
 
1011
- self._guidance_scale = guidance_scale
1012
- self._guidance_rescale = guidance_rescale
1013
- self._clip_skip = clip_skip
1014
- self._cross_attention_kwargs = cross_attention_kwargs
1015
- self._denoising_end = denoising_end
1016
-
1017
  # 2. Define call parameters
1018
  if prompt is not None and isinstance(prompt, str):
1019
  batch_size = 1
@@ -1024,11 +752,15 @@ class StableDiffusionXLPipeline(
1024
 
1025
  device = self._execution_device
1026
 
1027
  # 3. Encode input prompt
1028
- lora_scale = (
1029
- self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
1030
  )
1031
-
1032
  (
1033
  prompt_embeds,
1034
  negative_prompt_embeds,
@@ -1039,19 +771,20 @@ class StableDiffusionXLPipeline(
1039
  prompt_2=prompt_2,
1040
  device=device,
1041
  num_images_per_prompt=num_images_per_prompt,
1042
- do_classifier_free_guidance=self.do_classifier_free_guidance,
1043
  negative_prompt=negative_prompt,
1044
  negative_prompt_2=negative_prompt_2,
1045
  prompt_embeds=prompt_embeds,
1046
  negative_prompt_embeds=negative_prompt_embeds,
1047
  pooled_prompt_embeds=pooled_prompt_embeds,
1048
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1049
- lora_scale=lora_scale,
1050
- clip_skip=self.clip_skip,
1051
  )
1052
 
1053
  # 4. Prepare timesteps
1054
- timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
 
 
1055
 
1056
  # 5. Prepare latent variables
1057
  num_channels_latents = self.unet.config.in_channels
@@ -1071,162 +804,165 @@ class StableDiffusionXLPipeline(
1071
 
1072
  # 7. Prepare added time ids & embeddings
1073
  add_text_embeds = pooled_prompt_embeds
1074
- if self.text_encoder_2 is None:
1075
- text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
1076
- else:
1077
- text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
1078
-
1079
  add_time_ids = self._get_add_time_ids(
1080
- original_size,
1081
- crops_coords_top_left,
1082
- target_size,
1083
- dtype=prompt_embeds.dtype,
1084
- text_encoder_projection_dim=text_encoder_projection_dim,
1085
  )
1086
- if negative_original_size is not None and negative_target_size is not None:
1087
- negative_add_time_ids = self._get_add_time_ids(
1088
- negative_original_size,
1089
- negative_crops_coords_top_left,
1090
- negative_target_size,
1091
- dtype=prompt_embeds.dtype,
1092
- text_encoder_projection_dim=text_encoder_projection_dim,
1093
- )
1094
- else:
1095
- negative_add_time_ids = add_time_ids
1096
 
1097
- if self.do_classifier_free_guidance:
1098
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
1099
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
1100
- add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
1101
 
1102
  prompt_embeds = prompt_embeds.to(device)
1103
  add_text_embeds = add_text_embeds.to(device)
1104
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
1105
 
1106
- if ip_adapter_image is not None:
1107
- image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
1108
- if self.do_classifier_free_guidance:
1109
- image_embeds = torch.cat([negative_image_embeds, image_embeds])
1110
- image_embeds = image_embeds.to(device)
1111
-
1112
  # 8. Denoising loop
1113
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1114
 
1115
- # 8.1 Apply denoising_end
1116
- if (
1117
- self.denoising_end is not None
1118
- and isinstance(self.denoising_end, float)
1119
- and self.denoising_end > 0
1120
- and self.denoising_end < 1
1121
- ):
1122
  discrete_timestep_cutoff = int(
1123
  round(
1124
  self.scheduler.config.num_train_timesteps
1125
- - (self.denoising_end * self.scheduler.config.num_train_timesteps)
1126
  )
1127
  )
1128
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
1129
  timesteps = timesteps[:num_inference_steps]
1130
 
1131
- # 9. Optionally get Guidance Scale Embedding
1132
- timestep_cond = None
1133
- if self.unet.config.time_cond_proj_dim is not None:
1134
- guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
1135
- timestep_cond = self.get_guidance_scale_embedding(
1136
- guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
1137
- ).to(device=device, dtype=latents.dtype)
1138
-
1139
- self._num_timesteps = len(timesteps)
1140
  latents = latents.half()
1141
  prompt_embeds = prompt_embeds.half()
1142
  with self.progress_bar(total=num_inference_steps) as progress_bar:
1143
  for i, t in enumerate(timesteps):
1144
- latents = self.update_loss(latents, i, t, prompt_embeds, timestep_cond, add_text_embeds, add_time_ids)
1145
 
1146
  # expand the latents if we are doing classifier free guidance
1147
- latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
1148
 
1149
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1150
 
1151
  # predict the noise residual
1152
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
1153
- if ip_adapter_image is not None:
1154
- added_cond_kwargs["image_embeds"] = image_embeds
1155
  noise_pred = self.unet(
1156
  latent_model_input,
1157
  t,
1158
  encoder_hidden_states=prompt_embeds,
1159
- timestep_cond=timestep_cond,
1160
- cross_attention_kwargs=self.cross_attention_kwargs,
1161
  added_cond_kwargs=added_cond_kwargs,
1162
  return_dict=False,
1163
  )[0]
1164
 
1165
  # perform guidance
1166
- if self.do_classifier_free_guidance:
1167
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1168
- noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1169
 
1170
- if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
1171
  # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
1172
- noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
1173
 
1174
  # compute the previous noisy sample x_t -> x_t-1
1175
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
1176
 
1177
- if callback_on_step_end is not None:
1178
- callback_kwargs = {}
1179
- for k in callback_on_step_end_tensor_inputs:
1180
- callback_kwargs[k] = locals()[k]
1181
- callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1182
-
1183
- latents = callback_outputs.pop("latents", latents)
1184
- prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1185
- negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
1186
- add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
1187
- negative_pooled_prompt_embeds = callback_outputs.pop(
1188
- "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
1189
- )
1190
- add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
1191
- negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
1192
-
1193
  # call the callback, if provided
1194
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1195
  progress_bar.update()
1196
  if callback is not None and i % callback_steps == 0:
1197
- step_idx = i // getattr(self.scheduler, "order", 1)
1198
- callback(step_idx, t, latents)
1199
 
1200
- if XLA_AVAILABLE:
1201
- xm.mark_step()
 
 
1202
 
1203
  if not output_type == "latent":
1204
- # make sure the VAE is in float32 mode, as it overflows in float16
1205
- needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
1206
-
1207
- if needs_upcasting:
1208
- self.upcast_vae()
1209
- latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
1210
-
1211
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
1212
-
1213
- # cast back to fp16 if needed
1214
- if needs_upcasting:
1215
- self.vae.to(dtype=torch.float16)
1216
  else:
1217
  image = latents
 
1218
 
1219
- if not output_type == "latent":
1220
- # apply watermark if available
1221
- if self.watermark is not None:
1222
- image = self.watermark.apply_watermark(image)
1223
 
1224
- image = self.image_processor.postprocess(image, output_type=output_type)
1225
 
1226
- # Offload all models
1227
- self.maybe_free_model_hooks()
 
1228
 
1229
  if not return_dict:
1230
  return (image,)
1231
 
1232
  return StableDiffusionXLPipelineOutput(images=image)
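
For context, a minimal sketch of how a pipeline defined in this file is typically driven once it targets diffusers 0.20.0. Everything outside the class itself is an assumption for illustration rather than part of the commit: the checkpoint id, the fp16 setting, and the `NoOpEditor` stand-in for the `editor` object that `__call__` uses through `self.editor.update_loss(...)`.

    import torch

    from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline  # the file changed in this commit


    class NoOpEditor:
        # Minimal stand-in matching the interface used inside __call__:
        # update_loss(forward_pass, latents, i) must return the (possibly updated) latents.
        def update_loss(self, forward_pass, latents, i):
            return latents


    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",  # assumed SDXL checkpoint
        torch_dtype=torch.float16,
    )
    pipe.editor = NoOpEditor()
    pipe.enable_model_cpu_offload()  # defined in this file for the 0.20.0 API; needs accelerate >= 0.17
    result = pipe(prompt="a photo of an astronaut riding a horse", num_inference_steps=50)
    # result is a StableDiffusionXLPipelineOutput; result.images holds the generated images.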
13
  # limitations under the License.
14
 
15
  import inspect
16
+ import os
17
  from typing import Any, Callable, Dict, List, Optional, Tuple, Union
18
 
19
  import torch
20
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
21
 
22
+ from diffusers.image_processor import VaeImageProcessor
23
+ from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
24
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
25
  from diffusers.models.attention_processor import (
26
  AttnProcessor2_0,
 
28
  LoRAXFormersAttnProcessor,
29
  XFormersAttnProcessor,
30
  )
 
31
  from diffusers.schedulers import KarrasDiffusionSchedulers
32
  from diffusers.utils import (
33
+ is_accelerate_available,
34
+ is_accelerate_version,
35
  is_invisible_watermark_available,
 
36
  logging,
37
+ randn_tensor,
38
  replace_example_docstring,
 
 
39
  )
40
+ from diffusers.pipeline_utils import DiffusionPipeline
41
+ from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput
 
42
 
43
 
44
  if is_invisible_watermark_available():
45
  from .watermark import StableDiffusionXLWatermarker
46
 
 
47
 
48
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
 
 
79
  return noise_cfg
80
 
81
 
82
+ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
83
  r"""
84
  Pipeline for text-to-image generation using Stable Diffusion XL.
85
 
 
87
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
88
 
89
  In addition the pipeline inherits the following loading methods:
90
+ - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
91
  - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
92
 
93
  as well as the following saving methods:
94
+ - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
95
 
96
  Args:
97
  vae ([`AutoencoderKL`]):
 
116
  scheduler ([`SchedulerMixin`]):
117
  A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
118
  [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
119
  """
120
 
121
  def __init__(
122
  self,
123
  vae: AutoencoderKL,
 
127
  tokenizer_2: CLIPTokenizer,
128
  unet: UNet2DConditionModel,
129
  scheduler: KarrasDiffusionSchedulers,
 
 
130
  force_zeros_for_empty_prompt: bool = True,
131
  add_watermarker: Optional[bool] = None,
132
  ):
 
140
  tokenizer_2=tokenizer_2,
141
  unet=unet,
142
  scheduler=scheduler,
 
 
143
  )
144
  self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
145
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
146
  self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
147
  self.default_sample_size = self.unet.config.sample_size
148
 
149
  add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
 
186
  """
187
  self.vae.disable_tiling()
188
 
189
+ def enable_model_cpu_offload(self, gpu_id=0):
190
+ r"""
191
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
192
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
193
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
194
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
195
+ """
196
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
197
+ from accelerate import cpu_offload_with_hook
198
+ else:
199
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
200
+
201
+ device = torch.device(f"cuda:{gpu_id}")
202
+
203
+ if self.device.type != "cpu":
204
+ self.to("cpu", silence_dtype_warnings=True)
205
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
206
+
207
+ model_sequence = (
208
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
209
+ )
210
+ model_sequence.extend([self.unet, self.vae])
211
+
212
+ hook = None
213
+ for cpu_offloaded_model in model_sequence:
214
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
215
+
216
+ # We'll offload the last model manually.
217
+ self.final_offload_hook = hook
218
+
219
  def encode_prompt(
220
  self,
221
  prompt: str,
 
230
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
231
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
232
  lora_scale: Optional[float] = None,
 
233
  ):
234
  r"""
235
  Encodes the prompt into text encoder hidden states.
 
269
  input argument.
270
  lora_scale (`float`, *optional*):
271
  A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
 
 
 
272
  """
273
  device = device or self._execution_device
274
 
275
  # set lora scale so that monkey patched LoRA
276
  # function of text encoder can correctly access it
277
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
278
  self._lora_scale = lora_scale
279
 
280
+ if prompt is not None and isinstance(prompt, str):
281
+ batch_size = 1
282
+ elif prompt is not None and isinstance(prompt, list):
283
  batch_size = len(prompt)
284
  else:
285
  batch_size = prompt_embeds.shape[0]
 
292
 
293
  if prompt_embeds is None:
294
  prompt_2 = prompt_2 or prompt
 
 
295
  # textual inversion: process multi-vector tokens if necessary
296
  prompt_embeds_list = []
297
  prompts = [prompt, prompt_2]
 
319
  f" {tokenizer.model_max_length} tokens: {removed_text}"
320
  )
321
 
322
+ prompt_embeds = text_encoder(
323
+ text_input_ids.to(device),
324
+ output_hidden_states=True,
325
+ )
326
 
327
  # We are only ALWAYS interested in the pooled output of the final text encoder
328
  pooled_prompt_embeds = prompt_embeds[0]
329
+ ### TODO: remove
330
+ null_text_inputs = tokenizer(
331
+ ['a realistic photo of an empty background'] * batch_size,
332
+ padding="max_length",
333
+ max_length=tokenizer.model_max_length,
334
+ truncation=True,
335
+ return_tensors="pt",
336
+ )
337
+ null_input_ids = null_text_inputs.input_ids
338
+ null_prompt_embeds = text_encoder(
339
+ null_input_ids.to(device),
340
+ output_hidden_states=True,
341
+ )
342
+ pooled_prompt_embeds = null_prompt_embeds[0]
343
+ ### TODO: remove
344
+ prompt_embeds = prompt_embeds.hidden_states[-2]
345
 
346
  prompt_embeds_list.append(prompt_embeds)
347
 
 
356
  negative_prompt = negative_prompt or ""
357
  negative_prompt_2 = negative_prompt_2 or negative_prompt
358
 
 
359
  uncond_tokens: List[str]
360
  if prompt is not None and type(prompt) is not type(negative_prompt):
361
  raise TypeError(
362
  f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
363
  f" {type(prompt)}."
364
  )
365
+ elif isinstance(negative_prompt, str):
366
+ uncond_tokens = [negative_prompt, negative_prompt_2]
367
  elif batch_size != len(negative_prompt):
368
  raise ValueError(
369
  f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
 
399
 
400
  negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
401
 
402
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
 
 
 
 
403
  bs_embed, seq_len, _ = prompt_embeds.shape
404
  # duplicate text embeddings for each generation per prompt, using mps friendly method
405
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
 
408
  if do_classifier_free_guidance:
409
  # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
410
  seq_len = negative_prompt_embeds.shape[1]
411
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
 
412
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
413
  negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
414
 
 
420
  bs_embed * num_images_per_prompt, -1
421
  )
422
 
 
423
  return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
424
 
 
425
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
426
  def prepare_extra_step_kwargs(self, generator, eta):
427
  # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
 
453
  negative_prompt_embeds=None,
454
  pooled_prompt_embeds=None,
455
  negative_pooled_prompt_embeds=None,
 
456
  ):
457
  if height % 8 != 0 or width % 8 != 0:
458
  raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
459
 
460
+ if (callback_steps is None) or (
461
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
462
+ ):
463
  raise ValueError(
464
  f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
465
  f" {type(callback_steps)}."
466
  )
467
 
 
468
  if prompt is not None and prompt_embeds is not None:
469
  raise ValueError(
470
  f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
 
531
  latents = latents * self.scheduler.init_noise_sigma
532
  return latents
533
 
534
+ def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
 
 
535
  add_time_ids = list(original_size + crops_coords_top_left + target_size)
536
 
537
  passed_add_embed_dim = (
538
+ self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
539
  )
540
  expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
541
 
 
567
  self.vae.decoder.conv_in.to(dtype)
568
  self.vae.decoder.mid_block.to(dtype)
569
 
570
+ def update_loss(self, latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids):
571
  def forward_pass(latent_model_input):
572
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
573
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
 
575
  latent_model_input,
576
  t,
577
  encoder_hidden_states=prompt_embeds,
578
+ cross_attention_kwargs=cross_attention_kwargs,
 
579
  added_cond_kwargs=added_cond_kwargs,
580
  return_dict=False,
581
  )
 
583
 
584
  return self.editor.update_loss(forward_pass, latents, i)
585
 
 
586
  @torch.no_grad()
587
  @replace_example_docstring(EXAMPLE_DOC_STRING)
588
  def __call__(
 
592
  height: Optional[int] = None,
593
  width: Optional[int] = None,
594
  num_inference_steps: int = 50,
 
595
  denoising_end: Optional[float] = None,
596
  guidance_scale: float = 5.0,
597
  negative_prompt: Optional[Union[str, List[str]]] = None,
 
604
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
605
  pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
606
  negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
 
607
  output_type: Optional[str] = "pil",
608
  return_dict: bool = True,
609
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
610
+ callback_steps: int = 1,
611
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
612
  guidance_rescale: float = 0.0,
613
  original_size: Optional[Tuple[int, int]] = None,
614
  crops_coords_top_left: Tuple[int, int] = (0, 0),
615
  target_size: Optional[Tuple[int, int]] = None,
 
616
  ):
617
  r"""
618
  Function invoked when calling the pipeline for generation.
 
625
  The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
626
  used in both text-encoders
627
  height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
628
+ The height in pixels of the generated image.
 
 
 
629
  width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
630
+ The width in pixels of the generated image.
 
 
 
631
  num_inference_steps (`int`, *optional*, defaults to 50):
632
  The number of denoising steps. More denoising steps usually lead to a higher quality image at the
633
  expense of slower inference.
 
 
 
 
634
  denoising_end (`float`, *optional*):
635
  When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
636
  completed before it is intentionally prematurely terminated. As a result, the returned sample will
 
677
  Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
678
  weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
679
  input argument.
 
680
  output_type (`str`, *optional*, defaults to `"pil"`):
681
  The output format of the generated image. Choose between
682
  [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
683
  return_dict (`bool`, *optional*, defaults to `True`):
684
  Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
685
  of a plain tuple.
686
+ callback (`Callable`, *optional*):
687
+ A function that will be called every `callback_steps` steps during inference. The function will be
688
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
689
+ callback_steps (`int`, *optional*, defaults to 1):
690
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
691
+ called at every step.
692
  cross_attention_kwargs (`dict`, *optional*):
693
  A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
694
  `self.processor` in
695
  [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
696
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
697
  Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
698
  Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
699
  [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
700
  Guidance rescale factor should fix overexposure when using zero terminal SNR.
701
  original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
702
  If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
703
+ `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
704
  explained in section 2.2 of
705
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
706
  crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
 
710
  [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
711
  target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
712
  For most cases, `target_size` should be set to the desired height and width of the generated image. If
713
+ not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
714
  section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
 
715
 
716
  Examples:
717
 
 
720
  [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
721
  `tuple`. When returning a tuple, the first element is a list with the generated images.
722
  """
 
723
  # 0. Default height and width to unet
724
  height = height or self.default_sample_size * self.vae_scale_factor
725
  width = width or self.default_sample_size * self.vae_scale_factor
 
740
  negative_prompt_embeds,
741
  pooled_prompt_embeds,
742
  negative_pooled_prompt_embeds,
 
743
  )
744
 
 
 
 
 
 
 
745
  # 2. Define call parameters
746
  if prompt is not None and isinstance(prompt, str):
747
  batch_size = 1
 
752
 
753
  device = self._execution_device
754
 
755
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
  # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
  )
  (
  prompt_embeds,
  negative_prompt_embeds,

  prompt_2=prompt_2,
  device=device,
  num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
  negative_prompt=negative_prompt,
  negative_prompt_2=negative_prompt_2,
  prompt_embeds=prompt_embeds,
  negative_prompt_embeds=negative_prompt_embeds,
  pooled_prompt_embeds=pooled_prompt_embeds,
  negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
  )

  # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+ timesteps = self.scheduler.timesteps

  # 5. Prepare latent variables
  num_channels_latents = self.unet.config.in_channels

  # 7. Prepare added time ids & embeddings
  add_text_embeds = pooled_prompt_embeds
  add_time_ids = self._get_add_time_ids(
+ original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
  )

+ if do_classifier_free_guidance:
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
  add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
+ add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

  prompt_embeds = prompt_embeds.to(device)
  add_text_embeds = add_text_embeds.to(device)
  add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)

  # 8. Denoising loop
  num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)

+ # 7.1 Apply denoising_end
+ if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
  discrete_timestep_cutoff = int(
  round(
  self.scheduler.config.num_train_timesteps
+ - (denoising_end * self.scheduler.config.num_train_timesteps)
  )
  )
  num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
  timesteps = timesteps[:num_inference_steps]

  latents = latents.half()
  prompt_embeds = prompt_embeds.half()
  with self.progress_bar(total=num_inference_steps) as progress_bar:
  for i, t in enumerate(timesteps):
+ latents = self.update_loss(latents, i, t, prompt_embeds, cross_attention_kwargs, add_text_embeds, add_time_ids)

  # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

  # predict the noise residual
  added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
  noise_pred = self.unet(
  latent_model_input,
  t,
  encoder_hidden_states=prompt_embeds,
+ cross_attention_kwargs=cross_attention_kwargs,
  added_cond_kwargs=added_cond_kwargs,
  return_dict=False,
  )[0]

  # perform guidance
+ if do_classifier_free_guidance:
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

+ if do_classifier_free_guidance and guidance_rescale > 0.0:
  # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

  # compute the previous noisy sample x_t -> x_t-1
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

  # call the callback, if provided
  if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
  progress_bar.update()
  if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)

+ # make sure the VAE is in float32 mode, as it overflows in float16
+ if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+ self.upcast_vae()
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)

  if not output_type == "latent":
  image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  else:
  image = latents
+ return StableDiffusionXLPipelineOutput(images=image)

+ # apply watermark if available
+ if self.watermark is not None:
+ image = self.watermark.apply_watermark(image)

+ image = self.image_processor.postprocess(image, output_type=output_type)

+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()

  if not return_dict:
  return (image,)

  return StableDiffusionXLPipelineOutput(images=image)
+
+ # Override to properly handle the loading and unloading of the additional text encoder.
+ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+ # We could have accessed the unet config from `lora_state_dict()` too. We pass
+ # it here explicitly to be able to tell that it's coming from an SDXL
+ # pipeline.
+ state_dict, network_alphas = self.lora_state_dict(
+ pretrained_model_name_or_path_or_dict,
+ unet_config=self.unet.config,
+ **kwargs,
+ )
+ self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+ text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+ if len(text_encoder_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder,
+ prefix="text_encoder",
+ lora_scale=self.lora_scale,
+ )
+
+ text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+ if len(text_encoder_2_state_dict) > 0:
+ self.load_lora_into_text_encoder(
+ text_encoder_2_state_dict,
+ network_alphas=network_alphas,
+ text_encoder=self.text_encoder_2,
+ prefix="text_encoder_2",
+ lora_scale=self.lora_scale,
+ )
+
+ @classmethod
+ def save_lora_weights(
+ self,
+ save_directory: Union[str, os.PathLike],
+ unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+ is_main_process: bool = True,
+ weight_name: str = None,
+ save_function: Callable = None,
+ safe_serialization: bool = True,
+ ):
+ state_dict = {}
+
+ def pack_weights(layers, prefix):
+ layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+ layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+ return layers_state_dict
+
+ state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+ if text_encoder_lora_layers and text_encoder_2_lora_layers:
+ state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+ state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+ self.write_lora_layers(
+ state_dict=state_dict,
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ weight_name=weight_name,
+ save_function=save_function,
+ safe_serialization=safe_serialization,
+ )
+
+ def _remove_text_encoder_monkey_patch(self):
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+ self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
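
For orientation, a minimal usage sketch of the modified `__call__` and the overridden `load_lora_weights` above (not part of this commit). It assumes the class in `pipeline_stable_diffusion_xl_opt.py` keeps the upstream name `StableDiffusionXLPipeline`, treats the checkpoint id, LoRA path, and prompt as placeholders, and ignores any Space-specific setup that `self.update_loss` may rely on; it only shows how the documented arguments map onto a call.

import torch
from pipeline_stable_diffusion_xl_opt import StableDiffusionXLPipeline  # assumed class name

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",  # placeholder checkpoint id
    torch_dtype=torch.float16,
    variant="fp16",
).to("cuda")

# The overridden load_lora_weights also routes LoRA weights into both text encoders.
# pipe.load_lora_weights("path/to/sdxl_lora")  # placeholder path

def log_step(step, timestep, latents):
    # called every `callback_steps` steps with the current latents
    print(f"step {step}: t={int(timestep)}, latents shape={tuple(latents.shape)}")

image = pipe(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=50,
    guidance_scale=7.5,
    guidance_rescale=0.7,          # CFG rescale for zero-terminal-SNR schedules
    original_size=(1024, 1024),    # SDXL micro-conditioning
    crops_coords_top_left=(0, 0),
    target_size=(1024, 1024),
    callback=log_step,
    callback_steps=10,
).images[0]
image.save("astronaut.png")
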
requirements.txt CHANGED
@@ -1,5 +1,5 @@
  accelerate==0.25.0
- diffusers==0.24.0
+ diffusers==0.20.0
  einops==0.6.1
  lightning-utilities==0.9.0
  matplotlib==3.7.3
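
Two small pieces of arithmetic from the denoising loop in `pipeline_stable_diffusion_xl_opt.py` above, spelled out with dummy values (illustrative only, not part of the commit): the classifier-free guidance combination and the `denoising_end` timestep cutoff.

import torch

# (a) classifier-free guidance: the text-conditional prediction is pushed away from
# the unconditional one by `guidance_scale`; guidance_scale <= 1.0 disables CFG.
guidance_scale = 7.5
noise_pred_uncond = torch.zeros(1, 4, 128, 128)
noise_pred_text = torch.ones(1, 4, 128, 128)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# every element is now 7.5: 0 + 7.5 * (1 - 0)

# (b) `denoising_end` cutoff: with num_train_timesteps=1000 and denoising_end=0.8,
# only timesteps >= 200 survive, i.e. the first 80% of the inference schedule runs.
num_train_timesteps = 1000
denoising_end = 0.8
discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
assert discrete_timestep_cutoff == 200
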