AlanB commited on
Commit
900bfd5
1 Parent(s): 46fd08c

Updates from latest commit, looks like good fixes

Browse files

https://github.com/huggingface/diffusers/commit/618260409f5c0ac6b6cbf79ed21ef51ba57db1c7

Files changed (1) hide show
  1. pipeline.py +609 -21
pipeline.py CHANGED
@@ -16,6 +16,7 @@
16
 
17
  import ast
18
  import gc
 
19
  import math
20
  import warnings
21
  from collections.abc import Iterable
@@ -23,16 +24,29 @@ from typing import Any, Callable, Dict, List, Optional, Union
23
 
24
  import torch
25
  import torch.nn.functional as F
 
26
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
27
 
 
 
 
28
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
29
  from diffusers.models.attention import Attention, GatedSelfAttentionDense
30
  from diffusers.models.attention_processor import AttnProcessor2_0
31
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
 
32
  from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
33
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
34
  from diffusers.schedulers import KarrasDiffusionSchedulers
35
- from diffusers.utils import logging, replace_example_docstring, USE_PEFT_BACKEND
 
 
 
 
 
 
 
 
36
 
37
 
38
  EXAMPLE_DOC_STRING = """
@@ -44,6 +58,7 @@ EXAMPLE_DOC_STRING = """
44
  >>> pipe = DiffusionPipeline.from_pretrained(
45
  ... "longlian/lmd_plus",
46
  ... custom_pipeline="llm_grounded_diffusion",
 
47
  ... variant="fp16", torch_dtype=torch.float16
48
  ... )
49
  >>> pipe.enable_model_cpu_offload()
@@ -96,7 +111,12 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
96
 
97
  # All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
98
  # Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
99
- DEFAULT_GUIDANCE_ATTN_KEYS = [("mid", 0, 0, 0), ("up", 1, 0, 0), ("up", 1, 1, 0), ("up", 1, 2, 0)]
 
 
 
 
 
100
 
101
 
102
  def convert_attn_keys(key):
@@ -126,7 +146,15 @@ def scale_proportion(obj_box, H, W):
126
 
127
  # Adapted from the parent class `AttnProcessor2_0`
128
  class AttnProcessorWithHook(AttnProcessor2_0):
129
- def __init__(self, attn_processor_key, hidden_size, cross_attention_dim, hook=None, fast_attn=True, enabled=True):
 
 
 
 
 
 
 
 
130
  super().__init__()
131
  self.attn_processor_key = attn_processor_key
132
  self.hidden_size = hidden_size
@@ -187,7 +215,13 @@ class AttnProcessorWithHook(AttnProcessor2_0):
187
 
188
  if self.hook is not None and self.enabled:
189
  # Call the hook with query, key, value, and attention maps
190
- self.hook(self.attn_processor_key, query_batch_dim, key_batch_dim, value_batch_dim, attention_probs)
 
 
 
 
 
 
191
 
192
  if self.fast_attn:
193
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
@@ -203,7 +237,12 @@ class AttnProcessorWithHook(AttnProcessor2_0):
203
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
204
  # TODO: add support for attn.scale when we move to Torch 2.1
205
  hidden_states = F.scaled_dot_product_attention(
206
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
 
 
 
 
 
207
  )
208
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
209
  hidden_states = hidden_states.to(query.dtype)
@@ -227,7 +266,9 @@ class AttnProcessorWithHook(AttnProcessor2_0):
227
  return hidden_states
228
 
229
 
230
- class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
 
 
231
  r"""
232
  Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
233
 
@@ -258,6 +299,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
258
  Whether a safety checker is needed for this pipeline.
259
  """
260
 
 
 
 
 
 
261
  objects_text = "Objects: "
262
  bg_prompt_text = "Background prompt: "
263
  bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
@@ -276,18 +322,88 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
276
  image_encoder: CLIPVisionModelWithProjection = None,
277
  requires_safety_checker: bool = True,
278
  ):
279
- super().__init__(
280
- vae,
281
- text_encoder,
282
- tokenizer,
283
- unet,
284
- scheduler,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
285
  safety_checker=safety_checker,
286
  feature_extractor=feature_extractor,
287
  image_encoder=image_encoder,
288
- requires_safety_checker=requires_safety_checker,
289
  )
 
 
 
290
 
 
291
  self.register_attn_hooks(unet)
292
  self._saved_attn = None
293
 
@@ -474,7 +590,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
474
 
475
  return token_map
476
 
477
- def get_phrase_indices(self, prompt, phrases, token_map=None, add_suffix_if_not_found=False, verbose=False):
 
 
 
 
 
 
 
478
  for obj in phrases:
479
  # Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
480
  if obj not in prompt:
@@ -495,7 +618,14 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
495
  phrase_token_map_str = " ".join(phrase_token_map)
496
 
497
  if verbose:
498
- logger.info("Full str:", token_map_str, "Substr:", phrase_token_map_str, "Phrase:", phrases)
 
 
 
 
 
 
 
499
 
500
  # Count the number of token before substr
501
  # The substring comes with a trailing space that needs to be removed by minus one in the index.
@@ -562,7 +692,15 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
562
 
563
  return loss
564
 
565
- def compute_ca_loss(self, saved_attn, bboxes, phrase_indices, guidance_attn_keys, verbose=False, **kwargs):
 
 
 
 
 
 
 
 
566
  """
567
  The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
568
  `AttnProcessor` will put attention maps into the `save_attn_to_dict`.
@@ -615,6 +753,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
615
  latents: Optional[torch.FloatTensor] = None,
616
  prompt_embeds: Optional[torch.FloatTensor] = None,
617
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
 
618
  output_type: Optional[str] = "pil",
619
  return_dict: bool = True,
620
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -672,6 +811,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
672
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
673
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
674
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
 
675
  output_type (`str`, *optional*, defaults to `"pil"`):
676
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
677
  return_dict (`bool`, *optional*, defaults to `True`):
@@ -734,9 +874,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
734
  phrase_indices = []
735
  prompt_parsed = []
736
  for prompt_item in prompt:
737
- phrase_indices_parsed_item, prompt_parsed_item = self.get_phrase_indices(
738
- prompt_item, add_suffix_if_not_found=True
739
- )
 
740
  phrase_indices.append(phrase_indices_parsed_item)
741
  prompt_parsed.append(prompt_parsed_item)
742
  prompt = prompt_parsed
@@ -769,6 +910,11 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
769
  if do_classifier_free_guidance:
770
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
771
 
 
 
 
 
 
772
  # 4. Prepare timesteps
773
  self.scheduler.set_timesteps(num_inference_steps, device=device)
774
  timesteps = self.scheduler.timesteps
@@ -811,7 +957,10 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
811
  if n_objs:
812
  cond_boxes[:n_objs] = torch.tensor(boxes)
813
  text_embeddings = torch.zeros(
814
- max_objs, self.unet.config.cross_attention_dim, device=device, dtype=self.text_encoder.dtype
 
 
 
815
  )
816
  if n_objs:
817
  text_embeddings[:n_objs] = _text_embeddings
@@ -843,6 +992,9 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
843
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
844
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
845
 
 
 
 
846
  loss_attn = torch.tensor(10000.0)
847
 
848
  # 7. Denoising loop
@@ -879,6 +1031,7 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
879
  t,
880
  encoder_hidden_states=prompt_embeds,
881
  cross_attention_kwargs=cross_attention_kwargs,
 
882
  ).sample
883
 
884
  # perform guidance
@@ -1023,3 +1176,438 @@ class LLMGroundedDiffusionPipeline(StableDiffusionPipeline):
1023
  self.enable_attn_hook(enabled=False)
1024
 
1025
  return latents, loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  import ast
18
  import gc
19
+ import inspect
20
  import math
21
  import warnings
22
  from collections.abc import Iterable
 
24
 
25
  import torch
26
  import torch.nn.functional as F
27
+ from packaging import version
28
  from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
29
 
30
+ from diffusers.configuration_utils import FrozenDict
31
+ from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
32
+ from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
33
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
34
  from diffusers.models.attention import Attention, GatedSelfAttentionDense
35
  from diffusers.models.attention_processor import AttnProcessor2_0
36
+ from diffusers.models.lora import adjust_lora_scale_text_encoder
37
+ from diffusers.pipelines import DiffusionPipeline
38
  from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
39
  from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
40
  from diffusers.schedulers import KarrasDiffusionSchedulers
41
+ from diffusers.utils import (
42
+ USE_PEFT_BACKEND,
43
+ deprecate,
44
+ logging,
45
+ replace_example_docstring,
46
+ scale_lora_layers,
47
+ unscale_lora_layers,
48
+ )
49
+ from diffusers.utils.torch_utils import randn_tensor
50
 
51
 
52
  EXAMPLE_DOC_STRING = """
 
58
  >>> pipe = DiffusionPipeline.from_pretrained(
59
  ... "longlian/lmd_plus",
60
  ... custom_pipeline="llm_grounded_diffusion",
61
+ ... custom_revision="main",
62
  ... variant="fp16", torch_dtype=torch.float16
63
  ... )
64
  >>> pipe.enable_model_cpu_offload()
 
111
 
112
  # All keys in Stable Diffusion models: [('down', 0, 0, 0), ('down', 0, 1, 0), ('down', 1, 0, 0), ('down', 1, 1, 0), ('down', 2, 0, 0), ('down', 2, 1, 0), ('mid', 0, 0, 0), ('up', 1, 0, 0), ('up', 1, 1, 0), ('up', 1, 2, 0), ('up', 2, 0, 0), ('up', 2, 1, 0), ('up', 2, 2, 0), ('up', 3, 0, 0), ('up', 3, 1, 0), ('up', 3, 2, 0)]
113
  # Note that the first up block is `UpBlock2D` rather than `CrossAttnUpBlock2D` and does not have attention. The last index is always 0 in our case since we have one `BasicTransformerBlock` in each `Transformer2DModel`.
114
+ DEFAULT_GUIDANCE_ATTN_KEYS = [
115
+ ("mid", 0, 0, 0),
116
+ ("up", 1, 0, 0),
117
+ ("up", 1, 1, 0),
118
+ ("up", 1, 2, 0),
119
+ ]
120
 
121
 
122
  def convert_attn_keys(key):
 
146
 
147
  # Adapted from the parent class `AttnProcessor2_0`
148
  class AttnProcessorWithHook(AttnProcessor2_0):
149
+ def __init__(
150
+ self,
151
+ attn_processor_key,
152
+ hidden_size,
153
+ cross_attention_dim,
154
+ hook=None,
155
+ fast_attn=True,
156
+ enabled=True,
157
+ ):
158
  super().__init__()
159
  self.attn_processor_key = attn_processor_key
160
  self.hidden_size = hidden_size
 
215
 
216
  if self.hook is not None and self.enabled:
217
  # Call the hook with query, key, value, and attention maps
218
+ self.hook(
219
+ self.attn_processor_key,
220
+ query_batch_dim,
221
+ key_batch_dim,
222
+ value_batch_dim,
223
+ attention_probs,
224
+ )
225
 
226
  if self.fast_attn:
227
  query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
 
237
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
238
  # TODO: add support for attn.scale when we move to Torch 2.1
239
  hidden_states = F.scaled_dot_product_attention(
240
+ query,
241
+ key,
242
+ value,
243
+ attn_mask=attention_mask,
244
+ dropout_p=0.0,
245
+ is_causal=False,
246
  )
247
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
248
  hidden_states = hidden_states.to(query.dtype)
 
266
  return hidden_states
267
 
268
 
269
+ class LLMGroundedDiffusionPipeline(
270
+ DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, IPAdapterMixin, FromSingleFileMixin
271
+ ):
272
  r"""
273
  Pipeline for layout-grounded text-to-image generation using LLM-grounded Diffusion (LMD+): https://arxiv.org/pdf/2305.13655.pdf.
274
 
 
299
  Whether a safety checker is needed for this pipeline.
300
  """
301
 
302
+ model_cpu_offload_seq = "text_encoder->unet->vae"
303
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
304
+ _exclude_from_cpu_offload = ["safety_checker"]
305
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
306
+
307
  objects_text = "Objects: "
308
  bg_prompt_text = "Background prompt: "
309
  bg_prompt_text_no_trailing_space = bg_prompt_text.rstrip()
 
322
  image_encoder: CLIPVisionModelWithProjection = None,
323
  requires_safety_checker: bool = True,
324
  ):
325
+ # This is copied from StableDiffusionPipeline, with hook initizations for LMD+.
326
+ super().__init__()
327
+
328
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
329
+ deprecation_message = (
330
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
331
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
332
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
333
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
334
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
335
+ " file"
336
+ )
337
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
338
+ new_config = dict(scheduler.config)
339
+ new_config["steps_offset"] = 1
340
+ scheduler._internal_dict = FrozenDict(new_config)
341
+
342
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
343
+ deprecation_message = (
344
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
345
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
346
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
347
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
348
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
349
+ )
350
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
351
+ new_config = dict(scheduler.config)
352
+ new_config["clip_sample"] = False
353
+ scheduler._internal_dict = FrozenDict(new_config)
354
+
355
+ if safety_checker is None and requires_safety_checker:
356
+ logger.warning(
357
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
358
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
359
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
360
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
361
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
362
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
363
+ )
364
+
365
+ if safety_checker is not None and feature_extractor is None:
366
+ raise ValueError(
367
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
368
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
369
+ )
370
+
371
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
372
+ version.parse(unet.config._diffusers_version).base_version
373
+ ) < version.parse("0.9.0.dev0")
374
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
375
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
376
+ deprecation_message = (
377
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
378
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
379
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
380
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
381
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
382
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
383
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
384
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
385
+ " the `unet/config.json` file"
386
+ )
387
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
388
+ new_config = dict(unet.config)
389
+ new_config["sample_size"] = 64
390
+ unet._internal_dict = FrozenDict(new_config)
391
+
392
+ self.register_modules(
393
+ vae=vae,
394
+ text_encoder=text_encoder,
395
+ tokenizer=tokenizer,
396
+ unet=unet,
397
+ scheduler=scheduler,
398
  safety_checker=safety_checker,
399
  feature_extractor=feature_extractor,
400
  image_encoder=image_encoder,
 
401
  )
402
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
403
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
404
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
405
 
406
+ # Initialize the attention hooks for LLM-grounded Diffusion
407
  self.register_attn_hooks(unet)
408
  self._saved_attn = None
409
 
 
590
 
591
  return token_map
592
 
593
+ def get_phrase_indices(
594
+ self,
595
+ prompt,
596
+ phrases,
597
+ token_map=None,
598
+ add_suffix_if_not_found=False,
599
+ verbose=False,
600
+ ):
601
  for obj in phrases:
602
  # Suffix the prompt with object name for attention guidance if object is not in the prompt, using "|" to separate the prompt and the suffix
603
  if obj not in prompt:
 
618
  phrase_token_map_str = " ".join(phrase_token_map)
619
 
620
  if verbose:
621
+ logger.info(
622
+ "Full str:",
623
+ token_map_str,
624
+ "Substr:",
625
+ phrase_token_map_str,
626
+ "Phrase:",
627
+ phrases,
628
+ )
629
 
630
  # Count the number of token before substr
631
  # The substring comes with a trailing space that needs to be removed by minus one in the index.
 
692
 
693
  return loss
694
 
695
+ def compute_ca_loss(
696
+ self,
697
+ saved_attn,
698
+ bboxes,
699
+ phrase_indices,
700
+ guidance_attn_keys,
701
+ verbose=False,
702
+ **kwargs,
703
+ ):
704
  """
705
  The `saved_attn` is supposed to be passed to `save_attn_to_dict` in `cross_attention_kwargs` prior to computing ths loss.
706
  `AttnProcessor` will put attention maps into the `save_attn_to_dict`.
 
753
  latents: Optional[torch.FloatTensor] = None,
754
  prompt_embeds: Optional[torch.FloatTensor] = None,
755
  negative_prompt_embeds: Optional[torch.FloatTensor] = None,
756
+ ip_adapter_image: Optional[PipelineImageInput] = None,
757
  output_type: Optional[str] = "pil",
758
  return_dict: bool = True,
759
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
811
  negative_prompt_embeds (`torch.FloatTensor`, *optional*):
812
  Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
813
  not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
814
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
815
  output_type (`str`, *optional*, defaults to `"pil"`):
816
  The output format of the generated image. Choose between `PIL.Image` or `np.array`.
817
  return_dict (`bool`, *optional*, defaults to `True`):
 
874
  phrase_indices = []
875
  prompt_parsed = []
876
  for prompt_item in prompt:
877
+ (
878
+ phrase_indices_parsed_item,
879
+ prompt_parsed_item,
880
+ ) = self.get_phrase_indices(prompt_item, add_suffix_if_not_found=True)
881
  phrase_indices.append(phrase_indices_parsed_item)
882
  prompt_parsed.append(prompt_parsed_item)
883
  prompt = prompt_parsed
 
910
  if do_classifier_free_guidance:
911
  prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
912
 
913
+ if ip_adapter_image is not None:
914
+ image_embeds, negative_image_embeds = self.encode_image(ip_adapter_image, device, num_images_per_prompt)
915
+ if self.do_classifier_free_guidance:
916
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
917
+
918
  # 4. Prepare timesteps
919
  self.scheduler.set_timesteps(num_inference_steps, device=device)
920
  timesteps = self.scheduler.timesteps
 
957
  if n_objs:
958
  cond_boxes[:n_objs] = torch.tensor(boxes)
959
  text_embeddings = torch.zeros(
960
+ max_objs,
961
+ self.unet.config.cross_attention_dim,
962
+ device=device,
963
+ dtype=self.text_encoder.dtype,
964
  )
965
  if n_objs:
966
  text_embeddings[:n_objs] = _text_embeddings
 
992
  # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
993
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
994
 
995
+ # 6.1 Add image embeds for IP-Adapter
996
+ added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
997
+
998
  loss_attn = torch.tensor(10000.0)
999
 
1000
  # 7. Denoising loop
 
1031
  t,
1032
  encoder_hidden_states=prompt_embeds,
1033
  cross_attention_kwargs=cross_attention_kwargs,
1034
+ added_cond_kwargs=added_cond_kwargs,
1035
  ).sample
1036
 
1037
  # perform guidance
 
1176
  self.enable_attn_hook(enabled=False)
1177
 
1178
  return latents, loss
1179
+
1180
+ # Below are methods copied from StableDiffusionPipeline
1181
+ # The design choice of not inheriting from StableDiffusionPipeline is discussed here: https://github.com/huggingface/diffusers/pull/5993#issuecomment-1834258517
1182
+
1183
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
1184
+ def enable_vae_slicing(self):
1185
+ r"""
1186
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1187
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1188
+ """
1189
+ self.vae.enable_slicing()
1190
+
1191
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
1192
+ def disable_vae_slicing(self):
1193
+ r"""
1194
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
1195
+ computing decoding in one step.
1196
+ """
1197
+ self.vae.disable_slicing()
1198
+
1199
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
1200
+ def enable_vae_tiling(self):
1201
+ r"""
1202
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1203
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1204
+ processing larger images.
1205
+ """
1206
+ self.vae.enable_tiling()
1207
+
1208
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
1209
+ def disable_vae_tiling(self):
1210
+ r"""
1211
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
1212
+ computing decoding in one step.
1213
+ """
1214
+ self.vae.disable_tiling()
1215
+
1216
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
1217
+ def _encode_prompt(
1218
+ self,
1219
+ prompt,
1220
+ device,
1221
+ num_images_per_prompt,
1222
+ do_classifier_free_guidance,
1223
+ negative_prompt=None,
1224
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1225
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1226
+ lora_scale: Optional[float] = None,
1227
+ **kwargs,
1228
+ ):
1229
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
1230
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
1231
+
1232
+ prompt_embeds_tuple = self.encode_prompt(
1233
+ prompt=prompt,
1234
+ device=device,
1235
+ num_images_per_prompt=num_images_per_prompt,
1236
+ do_classifier_free_guidance=do_classifier_free_guidance,
1237
+ negative_prompt=negative_prompt,
1238
+ prompt_embeds=prompt_embeds,
1239
+ negative_prompt_embeds=negative_prompt_embeds,
1240
+ lora_scale=lora_scale,
1241
+ **kwargs,
1242
+ )
1243
+
1244
+ # concatenate for backwards comp
1245
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
1246
+
1247
+ return prompt_embeds
1248
+
1249
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
1250
+ def encode_prompt(
1251
+ self,
1252
+ prompt,
1253
+ device,
1254
+ num_images_per_prompt,
1255
+ do_classifier_free_guidance,
1256
+ negative_prompt=None,
1257
+ prompt_embeds: Optional[torch.FloatTensor] = None,
1258
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
1259
+ lora_scale: Optional[float] = None,
1260
+ clip_skip: Optional[int] = None,
1261
+ ):
1262
+ r"""
1263
+ Encodes the prompt into text encoder hidden states.
1264
+
1265
+ Args:
1266
+ prompt (`str` or `List[str]`, *optional*):
1267
+ prompt to be encoded
1268
+ device: (`torch.device`):
1269
+ torch device
1270
+ num_images_per_prompt (`int`):
1271
+ number of images that should be generated per prompt
1272
+ do_classifier_free_guidance (`bool`):
1273
+ whether to use classifier free guidance or not
1274
+ negative_prompt (`str` or `List[str]`, *optional*):
1275
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
1276
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
1277
+ less than `1`).
1278
+ prompt_embeds (`torch.FloatTensor`, *optional*):
1279
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
1280
+ provided, text embeddings will be generated from `prompt` input argument.
1281
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
1282
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
1283
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
1284
+ argument.
1285
+ lora_scale (`float`, *optional*):
1286
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
1287
+ clip_skip (`int`, *optional*):
1288
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
1289
+ the output of the pre-final layer will be used for computing the prompt embeddings.
1290
+ """
1291
+ # set lora scale so that monkey patched LoRA
1292
+ # function of text encoder can correctly access it
1293
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
1294
+ self._lora_scale = lora_scale
1295
+
1296
+ # dynamically adjust the LoRA scale
1297
+ if not USE_PEFT_BACKEND:
1298
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
1299
+ else:
1300
+ scale_lora_layers(self.text_encoder, lora_scale)
1301
+
1302
+ if prompt is not None and isinstance(prompt, str):
1303
+ batch_size = 1
1304
+ elif prompt is not None and isinstance(prompt, list):
1305
+ batch_size = len(prompt)
1306
+ else:
1307
+ batch_size = prompt_embeds.shape[0]
1308
+
1309
+ if prompt_embeds is None:
1310
+ # textual inversion: procecss multi-vector tokens if necessary
1311
+ if isinstance(self, TextualInversionLoaderMixin):
1312
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
1313
+
1314
+ text_inputs = self.tokenizer(
1315
+ prompt,
1316
+ padding="max_length",
1317
+ max_length=self.tokenizer.model_max_length,
1318
+ truncation=True,
1319
+ return_tensors="pt",
1320
+ )
1321
+ text_input_ids = text_inputs.input_ids
1322
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
1323
+
1324
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
1325
+ text_input_ids, untruncated_ids
1326
+ ):
1327
+ removed_text = self.tokenizer.batch_decode(
1328
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
1329
+ )
1330
+ logger.warning(
1331
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
1332
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
1333
+ )
1334
+
1335
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
1336
+ attention_mask = text_inputs.attention_mask.to(device)
1337
+ else:
1338
+ attention_mask = None
1339
+
1340
+ if clip_skip is None:
1341
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
1342
+ prompt_embeds = prompt_embeds[0]
1343
+ else:
1344
+ prompt_embeds = self.text_encoder(
1345
+ text_input_ids.to(device),
1346
+ attention_mask=attention_mask,
1347
+ output_hidden_states=True,
1348
+ )
1349
+ # Access the `hidden_states` first, that contains a tuple of
1350
+ # all the hidden states from the encoder layers. Then index into
1351
+ # the tuple to access the hidden states from the desired layer.
1352
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
1353
+ # We also need to apply the final LayerNorm here to not mess with the
1354
+ # representations. The `last_hidden_states` that we typically use for
1355
+ # obtaining the final prompt representations passes through the LayerNorm
1356
+ # layer.
1357
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
1358
+
1359
+ if self.text_encoder is not None:
1360
+ prompt_embeds_dtype = self.text_encoder.dtype
1361
+ elif self.unet is not None:
1362
+ prompt_embeds_dtype = self.unet.dtype
1363
+ else:
1364
+ prompt_embeds_dtype = prompt_embeds.dtype
1365
+
1366
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
1367
+
1368
+ bs_embed, seq_len, _ = prompt_embeds.shape
1369
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
1370
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
1371
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
1372
+
1373
+ # get unconditional embeddings for classifier free guidance
1374
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
1375
+ uncond_tokens: List[str]
1376
+ if negative_prompt is None:
1377
+ uncond_tokens = [""] * batch_size
1378
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
1379
+ raise TypeError(
1380
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
1381
+ f" {type(prompt)}."
1382
+ )
1383
+ elif isinstance(negative_prompt, str):
1384
+ uncond_tokens = [negative_prompt]
1385
+ elif batch_size != len(negative_prompt):
1386
+ raise ValueError(
1387
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
1388
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
1389
+ " the batch size of `prompt`."
1390
+ )
1391
+ else:
1392
+ uncond_tokens = negative_prompt
1393
+
1394
+ # textual inversion: procecss multi-vector tokens if necessary
1395
+ if isinstance(self, TextualInversionLoaderMixin):
1396
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
1397
+
1398
+ max_length = prompt_embeds.shape[1]
1399
+ uncond_input = self.tokenizer(
1400
+ uncond_tokens,
1401
+ padding="max_length",
1402
+ max_length=max_length,
1403
+ truncation=True,
1404
+ return_tensors="pt",
1405
+ )
1406
+
1407
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
1408
+ attention_mask = uncond_input.attention_mask.to(device)
1409
+ else:
1410
+ attention_mask = None
1411
+
1412
+ negative_prompt_embeds = self.text_encoder(
1413
+ uncond_input.input_ids.to(device),
1414
+ attention_mask=attention_mask,
1415
+ )
1416
+ negative_prompt_embeds = negative_prompt_embeds[0]
1417
+
1418
+ if do_classifier_free_guidance:
1419
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
1420
+ seq_len = negative_prompt_embeds.shape[1]
1421
+
1422
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
1423
+
1424
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
1425
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
1426
+
1427
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
1428
+ # Retrieve the original scale by scaling back the LoRA layers
1429
+ unscale_lora_layers(self.text_encoder, lora_scale)
1430
+
1431
+ return prompt_embeds, negative_prompt_embeds
1432
+
1433
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
1434
+ def encode_image(self, image, device, num_images_per_prompt):
1435
+ dtype = next(self.image_encoder.parameters()).dtype
1436
+
1437
+ if not isinstance(image, torch.Tensor):
1438
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
1439
+
1440
+ image = image.to(device=device, dtype=dtype)
1441
+ image_embeds = self.image_encoder(image).image_embeds
1442
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
1443
+
1444
+ uncond_image_embeds = torch.zeros_like(image_embeds)
1445
+ return image_embeds, uncond_image_embeds
1446
+
1447
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
1448
+ def run_safety_checker(self, image, device, dtype):
1449
+ if self.safety_checker is None:
1450
+ has_nsfw_concept = None
1451
+ else:
1452
+ if torch.is_tensor(image):
1453
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
1454
+ else:
1455
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
1456
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
1457
+ image, has_nsfw_concept = self.safety_checker(
1458
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
1459
+ )
1460
+ return image, has_nsfw_concept
1461
+
1462
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
1463
+ def decode_latents(self, latents):
1464
+ deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
1465
+ deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
1466
+
1467
+ latents = 1 / self.vae.config.scaling_factor * latents
1468
+ image = self.vae.decode(latents, return_dict=False)[0]
1469
+ image = (image / 2 + 0.5).clamp(0, 1)
1470
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
1471
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
1472
+ return image
1473
+
1474
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
1475
+ def prepare_extra_step_kwargs(self, generator, eta):
1476
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
1477
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
1478
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
1479
+ # and should be between [0, 1]
1480
+
1481
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
1482
+ extra_step_kwargs = {}
1483
+ if accepts_eta:
1484
+ extra_step_kwargs["eta"] = eta
1485
+
1486
+ # check if the scheduler accepts generator
1487
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
1488
+ if accepts_generator:
1489
+ extra_step_kwargs["generator"] = generator
1490
+ return extra_step_kwargs
1491
+
1492
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
1493
+ def prepare_latents(
1494
+ self,
1495
+ batch_size,
1496
+ num_channels_latents,
1497
+ height,
1498
+ width,
1499
+ dtype,
1500
+ device,
1501
+ generator,
1502
+ latents=None,
1503
+ ):
1504
+ shape = (
1505
+ batch_size,
1506
+ num_channels_latents,
1507
+ height // self.vae_scale_factor,
1508
+ width // self.vae_scale_factor,
1509
+ )
1510
+ if isinstance(generator, list) and len(generator) != batch_size:
1511
+ raise ValueError(
1512
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
1513
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
1514
+ )
1515
+
1516
+ if latents is None:
1517
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
1518
+ else:
1519
+ latents = latents.to(device)
1520
+
1521
+ # scale the initial noise by the standard deviation required by the scheduler
1522
+ latents = latents * self.scheduler.init_noise_sigma
1523
+ return latents
1524
+
1525
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_freeu
1526
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
1527
+ r"""Enables the FreeU mechanism as in https://arxiv.org/abs/2309.11497.
1528
+
1529
+ The suffixes after the scaling factors represent the stages where they are being applied.
1530
+
1531
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of the values
1532
+ that are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
1533
+
1534
+ Args:
1535
+ s1 (`float`):
1536
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
1537
+ mitigate "oversmoothing effect" in the enhanced denoising process.
1538
+ s2 (`float`):
1539
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
1540
+ mitigate "oversmoothing effect" in the enhanced denoising process.
1541
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
1542
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
1543
+ """
1544
+ if not hasattr(self, "unet"):
1545
+ raise ValueError("The pipeline must have `unet` for using FreeU.")
1546
+ self.unet.enable_freeu(s1=s1, s2=s2, b1=b1, b2=b2)
1547
+
1548
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
1549
+ def disable_freeu(self):
1550
+ """Disables the FreeU mechanism if enabled."""
1551
+ self.unet.disable_freeu()
1552
+
1553
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
1554
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
1555
+ """
1556
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
1557
+
1558
+ Args:
1559
+ timesteps (`torch.Tensor`):
1560
+ generate embedding vectors at these timesteps
1561
+ embedding_dim (`int`, *optional*, defaults to 512):
1562
+ dimension of the embeddings to generate
1563
+ dtype:
1564
+ data type of the generated embeddings
1565
+
1566
+ Returns:
1567
+ `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)`
1568
+ """
1569
+ assert len(w.shape) == 1
1570
+ w = w * 1000.0
1571
+
1572
+ half_dim = embedding_dim // 2
1573
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
1574
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
1575
+ emb = w.to(dtype)[:, None] * emb[None, :]
1576
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
1577
+ if embedding_dim % 2 == 1: # zero pad
1578
+ emb = torch.nn.functional.pad(emb, (0, 1))
1579
+ assert emb.shape == (w.shape[0], embedding_dim)
1580
+ return emb
1581
+
1582
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_scale
1583
+ @property
1584
+ def guidance_scale(self):
1585
+ return self._guidance_scale
1586
+
1587
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.guidance_rescale
1588
+ @property
1589
+ def guidance_rescale(self):
1590
+ return self._guidance_rescale
1591
+
1592
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.clip_skip
1593
+ @property
1594
+ def clip_skip(self):
1595
+ return self._clip_skip
1596
+
1597
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1598
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1599
+ # corresponds to doing no classifier free guidance.
1600
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.do_classifier_free_guidance
1601
+ @property
1602
+ def do_classifier_free_guidance(self):
1603
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
1604
+
1605
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.cross_attention_kwargs
1606
+ @property
1607
+ def cross_attention_kwargs(self):
1608
+ return self._cross_attention_kwargs
1609
+
1610
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.num_timesteps
1611
+ @property
1612
+ def num_timesteps(self):
1613
+ return self._num_timesteps