tolgacangoz committed
Commit dc031b3 (1 parent: 98b1844)

Upload matryoshka.py

Files changed (1)
  1. scheduler/matryoshka.py +116 -84
scheduler/matryoshka.py CHANGED
@@ -664,9 +664,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
664
  variance_noise = []
665
  for m_o in model_output:
666
  variance_noise.append(
667
- randn_tensor(
668
- m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype
669
- )
670
  )
671
  else:
672
  variance_noise = randn_tensor(
@@ -1897,6 +1895,8 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
1897
  dim=1, keepdim=True
1898
  )
1899
  cond_emb = self.cond_emb(y)
 
 
1900
 
1901
  if not masked_cross_attention:
1902
  conditioning_mask = None
@@ -1905,11 +1905,8 @@ class MatryoshkaCombinedTimestepTextEmbedding(nn.Module):
1905
  if micro is not None:
1906
  temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
1907
  temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
1908
- if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
1909
- cond_emb_micro = cond_emb + temb_micro_conditioning
1910
- return cond_emb_micro, conditioning_mask, cond_emb
1911
- else:
1912
- return temb_micro_conditioning, conditioning_mask, None
1913
 
1914
  return cond_emb, conditioning_mask, cond_emb
1915
 
@@ -3035,11 +3032,6 @@ class MatryoshkaUNet2DConditionModel(
3035
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3036
  attention_mask = attention_mask.unsqueeze(1)
3037
 
3038
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
3039
- if encoder_attention_mask is not None:
3040
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
3041
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3042
-
3043
  # 0. center input if necessary
3044
  if self.config.center_input_sample:
3045
  sample = 2 * sample - 1.0
@@ -3059,6 +3051,7 @@ class MatryoshkaUNet2DConditionModel(
3059
  added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
3060
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3061
  added_cond_kwargs["from_nested"] = from_nested
 
3062
 
3063
  if not from_nested:
3064
  encoder_hidden_states = self.process_encoder_hidden_states(
@@ -3073,6 +3066,11 @@ class MatryoshkaUNet2DConditionModel(
3073
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3074
  )
3075
3076
  if self.config.addition_embed_type == "image_hint":
3077
  aug_emb, hint = aug_emb
3078
  sample = torch.cat([sample, hint], dim=1)
@@ -3483,11 +3481,6 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3483
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3484
  attention_mask = attention_mask.unsqueeze(1)
3485
 
3486
- # convert encoder_attention_mask to a bias the same way we do for attention_mask
3487
- if encoder_attention_mask is not None:
3488
- encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
3489
- encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3490
-
3491
  # 0. center input if necessary
3492
  if self.config.center_input_sample:
3493
  sample = 2 * sample - 1.0
@@ -3507,21 +3500,22 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3507
  added_cond_kwargs = added_cond_kwargs or {}
3508
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
3509
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
 
3510
 
3511
  if not self.config.nesting:
3512
  encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
3513
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3514
  )
3515
 
3516
- aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.get_aug_embed(
3517
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3518
  )
3519
-
3520
- aug_emb, cond_mask, _ = self.get_aug_embed(
3521
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3522
  )
3523
  else:
3524
- aug_emb, cond_mask_inner_unet, _ = self.get_aug_embed(
3525
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3526
  )
3527
 
@@ -3529,19 +3523,25 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3529
  added_cond_kwargs = added_cond_kwargs or {}
3530
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
3531
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
 
3532
 
3533
  encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
3534
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3535
  )
3536
 
3537
- aug_emb_inner_unet, cond_mask_inner_unet, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
3538
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3539
  )
3540
 
3541
- aug_emb, cond_mask, _ = self.get_aug_embed(
3542
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3543
  )
3544
3545
  if self.config.addition_embed_type == "image_hint":
3546
  aug_emb, hint = aug_emb
3547
  sample = torch.cat([sample, hint], dim=1)
@@ -3623,7 +3623,7 @@ class NestedUNet2DConditionModel(MatryoshkaUNet2DConditionModel):
3623
  timestep,
3624
  cond_emb=cond_emb,
3625
  encoder_hidden_states=encoder_hidden_states,
3626
- encoder_attention_mask=cond_mask_inner_unet,
3627
  from_nested=True,
3628
  )
3629
  x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
@@ -3911,9 +3911,6 @@ class MatryoshkaPipeline(
3911
 
3912
  text_inputs = self.tokenizer(
3913
  prompt,
3914
- padding="max_length",
3915
- max_length=self.tokenizer.model_max_length,
3916
- truncation=True,
3917
  return_tensors="pt",
3918
  )
3919
  text_input_ids = text_inputs.input_ids
@@ -3931,26 +3928,9 @@ class MatryoshkaPipeline(
3931
  )
3932
 
3933
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
3934
- attention_mask = text_inputs.attention_mask.to(device)
3935
  else:
3936
- attention_mask = None
3937
-
3938
- if clip_skip is None:
3939
- prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
3940
- prompt_embeds = prompt_embeds[0]
3941
- else:
3942
- prompt_embeds = self.text_encoder(
3943
- text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
3944
- )
3945
- # Access the `hidden_states` first, that contains a tuple of
3946
- # all the hidden states from the encoder layers. Then index into
3947
- # the tuple to access the hidden states from the desired layer.
3948
- prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
3949
- # We also need to apply the final LayerNorm here to not mess with the
3950
- # representations. The `last_hidden_states` that we typically use for
3951
- # obtaining the final prompt representations passes through the LayerNorm
3952
- # layer.
3953
- prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
3954
 
3955
  if self.text_encoder is not None:
3956
  prompt_embeds_dtype = self.text_encoder.dtype
@@ -3959,13 +3939,6 @@ class MatryoshkaPipeline(
3959
  else:
3960
  prompt_embeds_dtype = prompt_embeds.dtype
3961
 
3962
- prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
3963
-
3964
- bs_embed, seq_len, _ = prompt_embeds.shape
3965
- # duplicate text embeddings for each generation per prompt, using mps friendly method
3966
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
3967
- prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
3968
-
3969
  # get unconditional embeddings for classifier free guidance
3970
  if do_classifier_free_guidance and negative_prompt_embeds is None:
3971
  uncond_tokens: List[str]
@@ -3991,41 +3964,78 @@ class MatryoshkaPipeline(
3991
  if isinstance(self, TextualInversionLoaderMixin):
3992
  uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
3993
 
3994
- max_length = prompt_embeds.shape[1]
3995
  uncond_input = self.tokenizer(
3996
  uncond_tokens,
3997
- padding="max_length",
3998
- max_length=max_length,
3999
- truncation=True,
4000
  return_tensors="pt",
4001
  )
 
4002
 
4003
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
4004
- attention_mask = uncond_input.attention_mask.to(device)
4005
  else:
4006
- attention_mask = None
4007
 
4008
- negative_prompt_embeds = self.text_encoder(
4009
- uncond_input.input_ids.to(device),
4010
- attention_mask=attention_mask,
4011
  )
4012
- negative_prompt_embeds = negative_prompt_embeds[0]
4013
-
4014
- if do_classifier_free_guidance:
4015
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
4016
- seq_len = negative_prompt_embeds.shape[1]
4017
 
4018
- negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
4019
-
4020
- negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
4021
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
4022
 
4023
  if self.text_encoder is not None:
4024
  if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
4025
  # Retrieve the original scale by scaling back the LoRA layers
4026
  unscale_lora_layers(self.text_encoder, lora_scale)
4027
 
4028
- return prompt_embeds, negative_prompt_embeds
4029
 
4030
  def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
4031
  dtype = next(self.image_encoder.parameters()).dtype
@@ -4282,10 +4292,6 @@ class MatryoshkaPipeline(
4282
  def interrupt(self):
4283
  return self._interrupt
4284
 
4285
- @property
4286
- def model_type(self):
4287
- return "nested_unet"
4288
-
4289
  @torch.no_grad()
4290
  @replace_example_docstring(EXAMPLE_DOC_STRING)
4291
  def __call__(
@@ -4462,7 +4468,12 @@ class MatryoshkaPipeline(
4462
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
4463
  )
4464
 
4465
- prompt_embeds, negative_prompt_embeds = self.encode_prompt(
4466
  prompt,
4467
  device,
4468
  num_images_per_prompt,
@@ -4478,7 +4489,12 @@ class MatryoshkaPipeline(
4478
  # Here we concatenate the unconditional and text embeddings into a single batch
4479
  # to avoid doing two forward passes
4480
  if self.do_classifier_free_guidance:
4481
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
4482
 
4483
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
4484
  image_embeds = self.prepare_ip_adapter_image_embeds(
@@ -4490,10 +4506,13 @@ class MatryoshkaPipeline(
4490
  )
4491
 
4492
  # 4. Prepare timesteps
4493
- timesteps, num_inference_steps = retrieve_timesteps(
4494
- self.scheduler, num_inference_steps, device, timesteps, sigmas
4495
- )
4496
- timesteps = timesteps[:-1]
4497
 
4498
  # 5. Prepare latent variables
4499
  num_channels_latents = self.unet.config.in_channels
@@ -4552,6 +4571,7 @@ class MatryoshkaPipeline(
4552
  timestep_cond=timestep_cond,
4553
  cross_attention_kwargs=self.cross_attention_kwargs,
4554
  added_cond_kwargs=added_cond_kwargs,
 
4555
  return_dict=False,
4556
  )[0]
4557
 
@@ -4568,7 +4588,19 @@ class MatryoshkaPipeline(
4568
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
4569
 
4570
  # compute the previous noisy sample x_t -> x_t-1
4571
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
4572
 
4573
  if callback_on_step_end is not None:
4574
  callback_kwargs = {}
 
664
  variance_noise = []
665
  for m_o in model_output:
666
  variance_noise.append(
667
+ randn_tensor(m_o.shape, generator=generator, device=m_o.device, dtype=m_o.dtype)
 
 
668
  )
669
  else:
670
  variance_noise = randn_tensor(
 
1895
  dim=1, keepdim=True
1896
  )
1897
  cond_emb = self.cond_emb(y)
1898
+ else:
1899
+ cond_emb = None
1900
 
1901
  if not masked_cross_attention:
1902
  conditioning_mask = None
 
1905
  if micro is not None:
1906
  temb = self.add_time_proj(torch.tensor([micro], device=emb.device, dtype=emb.dtype))
1907
  temb_micro_conditioning = self.add_timestep_embedder(temb.to(emb.dtype))
1908
+ # if self.cond_emb is not None and not added_cond_kwargs.get("from_nested", False):
1909
+ return temb_micro_conditioning, conditioning_mask, cond_emb
1910
 
1911
  return cond_emb, conditioning_mask, cond_emb
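Note: the hunk above makes `get_aug_embed` always return the micro-conditioning embedding next to the (possibly None) `cond_emb` instead of branching on `from_nested`. A minimal sketch of that micro-conditioning path, built from the standard diffusers embedding modules; the 256/1024 channel sizes are illustrative assumptions, not values read from the model config:

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

# assumed sizes: 256-dim sinusoidal features, 1024-dim conditioning embedding
add_time_proj = Timesteps(256, True, 0)
add_timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=1024)

micro = 64  # e.g. the configured micro_conditioning_scale
temb = add_time_proj(torch.tensor([micro], dtype=torch.float32))
temb_micro_conditioning = add_timestep_embedder(temb)  # shape (1, 1024)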
1912
 
 
3032
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3033
  attention_mask = attention_mask.unsqueeze(1)
3034
3035
  # 0. center input if necessary
3036
  if self.config.center_input_sample:
3037
  sample = 2 * sample - 1.0
 
3051
  added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
3052
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3053
  added_cond_kwargs["from_nested"] = from_nested
3054
+ added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
3055
 
3056
  if not from_nested:
3057
  encoder_hidden_states = self.process_encoder_hidden_states(
 
3066
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3067
  )
3068
 
3069
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
3070
+ if encoder_attention_mask is not None:
3071
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample[0][0].dtype)) * -10000.0
3072
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3073
+
3074
  if self.config.addition_embed_type == "image_hint":
3075
  aug_emb, hint = aug_emb
3076
  sample = torch.cat([sample, hint], dim=1)
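Note: this hunk delays converting `encoder_attention_mask` into an attention bias until after `get_aug_embed` has consumed the raw 0/1 mask (handed over via `added_cond_kwargs["conditioning_mask"]`). A small self-contained sketch of the bias conversion itself:

import torch

mask = torch.tensor([[1, 1, 1, 0, 0]])  # 1 = real token, 0 = padding
bias = (1 - mask.float()) * -10000.0    # large negative bias on padded positions
bias = bias.unsqueeze(1)                # broadcast over heads/queries: (batch, 1, seq_len)
# adding `bias` to the attention logits drives masked positions to ~0 after softmax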
 
3481
  attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
3482
  attention_mask = attention_mask.unsqueeze(1)
3483
3484
  # 0. center input if necessary
3485
  if self.config.center_input_sample:
3486
  sample = 2 * sample - 1.0
 
3500
  added_cond_kwargs = added_cond_kwargs or {}
3501
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.config.masked_cross_attention
3502
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3503
+ added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
3504
 
3505
  if not self.config.nesting:
3506
  encoder_hidden_states = self.inner_unet.process_encoder_hidden_states(
3507
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3508
  )
3509
 
3510
+ aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.get_aug_embed(
3511
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3512
  )
3513
+ added_cond_kwargs["masked_cross_attention"] = self.config.masked_cross_attention
3514
+ aug_emb, __, _ = self.get_aug_embed(
3515
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3516
  )
3517
  else:
3518
+ aug_emb, cond_mask, _ = self.get_aug_embed(
3519
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3520
  )
3521
 
 
3523
  added_cond_kwargs = added_cond_kwargs or {}
3524
  added_cond_kwargs["masked_cross_attention"] = self.inner_unet.inner_unet.config.masked_cross_attention
3525
  added_cond_kwargs["micro_conditioning_scale"] = self.config.micro_conditioning_scale
3526
+ added_cond_kwargs["conditioning_mask"] = encoder_attention_mask
3527
 
3528
  encoder_hidden_states = self.inner_unet.inner_unet.process_encoder_hidden_states(
3529
  encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3530
  )
3531
 
3532
+ aug_emb_inner_unet, cond_mask, cond_emb = self.inner_unet.inner_unet.get_aug_embed(
3533
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3534
  )
3535
 
3536
+ aug_emb, __, _ = self.get_aug_embed(
3537
  emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
3538
  )
3539
 
3540
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
3541
+ if encoder_attention_mask is not None:
3542
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
3543
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
3544
+
3545
  if self.config.addition_embed_type == "image_hint":
3546
  aug_emb, hint = aug_emb
3547
  sample = torch.cat([sample, hint], dim=1)
 
3623
  timestep,
3624
  cond_emb=cond_emb,
3625
  encoder_hidden_states=encoder_hidden_states,
3626
+ encoder_attention_mask=cond_mask,
3627
  from_nested=True,
3628
  )
3629
  x_low, x_inner = inner_unet_output.sample, inner_unet_output.sample_inner
 
3911
 
3912
  text_inputs = self.tokenizer(
3913
  prompt,
3914
  return_tensors="pt",
3915
  )
3916
  text_input_ids = text_inputs.input_ids
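Note: with `padding="max_length"` and `truncation` removed, each prompt keeps its natural token length, which is why the classifier-free-guidance branch further down pads the conditional and unconditional sequences to a common length by hand. A brief illustration, assuming a standard Hugging Face tokenizer (`google/flan-t5-base` is only a stand-in here):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")  # stand-in for the pipeline's tokenizer
cond = tokenizer("a photo of an astronaut riding a horse", return_tensors="pt")
uncond = tokenizer("", return_tensors="pt")
print(cond.input_ids.shape, uncond.input_ids.shape)  # different lengths, no fixed max_length padding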
 
3928
  )
3929
 
3930
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
3931
+ prompt_attention_mask = text_inputs.attention_mask.to(device)
3932
  else:
3933
+ prompt_attention_mask = None
3934
 
3935
  if self.text_encoder is not None:
3936
  prompt_embeds_dtype = self.text_encoder.dtype
 
3939
  else:
3940
  prompt_embeds_dtype = prompt_embeds.dtype
3941
3942
  # get unconditional embeddings for classifier free guidance
3943
  if do_classifier_free_guidance and negative_prompt_embeds is None:
3944
  uncond_tokens: List[str]
 
3964
  if isinstance(self, TextualInversionLoaderMixin):
3965
  uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
3966
 
 
3967
  uncond_input = self.tokenizer(
3968
  uncond_tokens,
3969
  return_tensors="pt",
3970
  )
3971
+ uncond_input_ids = uncond_input.input_ids
3972
 
3973
  if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
3974
+ negative_prompt_attention_mask = uncond_input.attention_mask.to(device)
3975
  else:
3976
+ negative_prompt_attention_mask = None
3977
 
3978
+ if not do_classifier_free_guidance:
3979
+ if clip_skip is None:
3980
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
3981
+ prompt_embeds = prompt_embeds[0]
3982
+ else:
3983
+ prompt_embeds = self.text_encoder(
3984
+ text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=True
3985
+ )
3986
+ # Access the `hidden_states` first, that contains a tuple of
3987
+ # all the hidden states from the encoder layers. Then index into
3988
+ # the tuple to access the hidden states from the desired layer.
3989
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
3990
+ # We also need to apply the final LayerNorm here to not mess with the
3991
+ # representations. The `last_hidden_states` that we typically use for
3992
+ # obtaining the final prompt representations passes through the LayerNorm
3993
+ # layer.
3994
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
3995
+ else:
3996
+ max_len = max(len(text_input_ids[0]), len(uncond_input_ids[0]))
3997
+ if len(text_input_ids[0]) < max_len:
3998
+ text_input_ids = torch.cat(
3999
+ [text_input_ids, torch.zeros(batch_size, max_len - len(text_input_ids[0]), dtype=torch.long)],
4000
+ dim=1,
4001
+ )
4002
+ prompt_attention_mask = torch.cat(
4003
+ [
4004
+ prompt_attention_mask,
4005
+ torch.zeros(batch_size, max_len - len(prompt_attention_mask[0]), dtype=torch.long),
4006
+ ],
4007
+ dim=1,
4008
+ )
4009
+ elif len(uncond_input_ids[0]) < max_len:
4010
+ uncond_input_ids = torch.cat(
4011
+ [uncond_input_ids, torch.zeros(batch_size, max_len - len(uncond_input_ids[0]), dtype=torch.long)],
4012
+ dim=1,
4013
+ )
4014
+ negative_prompt_attention_mask = torch.cat(
4015
+ [
4016
+ negative_prompt_attention_mask,
4017
+ torch.zeros(batch_size, max_len - len(negative_prompt_attention_mask[0]), dtype=torch.long),
4018
+ ],
4019
+ dim=1,
4020
+ )
4021
+ cfg_input_ids = torch.cat([uncond_input_ids, text_input_ids], dim=0)
4022
+ cfg_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
4023
+ prompt_embeds = self.text_encoder(
4024
+ cfg_input_ids.to(device),
4025
+ attention_mask=cfg_attention_mask,
4026
  )
4027
+ prompt_embeds = prompt_embeds[0]
4028
 
4029
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
4030
 
4031
  if self.text_encoder is not None:
4032
  if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
4033
  # Retrieve the original scale by scaling back the LoRA layers
4034
  unscale_lora_layers(self.text_encoder, lora_scale)
4035
 
4036
+ if not do_classifier_free_guidance:
4037
+ return prompt_embeds, None, prompt_attention_mask, None
4038
+ return prompt_embeds[1], prompt_embeds[0], prompt_attention_mask, negative_prompt_attention_mask
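Note: the rewritten `encode_prompt` encodes both prompts in a single text-encoder pass when classifier-free guidance is active: the shorter token sequence is zero-padded, the pair is stacked as [uncond, cond], encoded once, and split again on return (index 1 is the conditional embedding, index 0 the unconditional one) together with both attention masks. A condensed sketch of the same flow; `encode_cfg_pair` is a hypothetical helper and the tokenizer/text encoder are assumed Hugging Face models, not the pipeline's own objects:

import torch

def encode_cfg_pair(tokenizer, text_encoder, prompt, negative_prompt, device):
    cond = tokenizer(prompt, return_tensors="pt")
    uncond = tokenizer(negative_prompt, return_tensors="pt")
    max_len = max(cond.input_ids.shape[1], uncond.input_ids.shape[1])

    def pad(ids, mask):
        # zero-pad ids and mask up to the shared length, mirroring the change above
        extra = max_len - ids.shape[1]
        if extra > 0:
            ids = torch.cat([ids, torch.zeros(ids.shape[0], extra, dtype=torch.long)], dim=1)
            mask = torch.cat([mask, torch.zeros(mask.shape[0], extra, dtype=torch.long)], dim=1)
        return ids, mask

    cond_ids, cond_mask = pad(cond.input_ids, cond.attention_mask)
    uncond_ids, uncond_mask = pad(uncond.input_ids, uncond.attention_mask)

    ids = torch.cat([uncond_ids, cond_ids]).to(device)    # batch order: [uncond, cond]
    mask = torch.cat([uncond_mask, cond_mask]).to(device)
    embeds = text_encoder(ids, attention_mask=mask)[0]
    return embeds[1], embeds[0], cond_mask, uncond_mask   # cond, uncond, and their masks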
4039
 
4040
  def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
4041
  dtype = next(self.image_encoder.parameters()).dtype
 
4292
  def interrupt(self):
4293
  return self._interrupt
4294
4295
  @torch.no_grad()
4296
  @replace_example_docstring(EXAMPLE_DOC_STRING)
4297
  def __call__(
 
4468
  self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
4469
  )
4470
 
4471
+ (
4472
+ prompt_embeds,
4473
+ negative_prompt_embeds,
4474
+ prompt_attention_mask,
4475
+ negative_prompt_attention_mask,
4476
+ ) = self.encode_prompt(
4477
  prompt,
4478
  device,
4479
  num_images_per_prompt,
 
4489
  # Here we concatenate the unconditional and text embeddings into a single batch
4490
  # to avoid doing two forward passes
4491
  if self.do_classifier_free_guidance:
4492
+ prompt_embeds = torch.cat([negative_prompt_embeds.unsqueeze(0), prompt_embeds.unsqueeze(0)])
4493
+ attention_masks = torch.cat([negative_prompt_attention_mask, prompt_attention_mask])
4494
+ else:
4495
+ attention_masks = prompt_attention_mask
4496
+
4497
+ prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)
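Note: for guidance, the negative and positive embeddings are stacked along a new leading dimension and padded token positions are zeroed by broadcasting the 0/1 attention mask over the hidden dimension. A minimal sketch of that masking step with assumed shapes:

import torch

prompt_embeds = torch.randn(2, 77, 1024)               # [uncond, cond], assumed (batch, seq, dim)
attention_masks = torch.ones(2, 77, dtype=torch.long)
attention_masks[:, 60:] = 0                             # pretend the tail is padding
prompt_embeds = prompt_embeds * attention_masks.unsqueeze(-1)  # zero embeddings at padded positions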
4498
 
4499
  if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
4500
  image_embeds = self.prepare_ip_adapter_image_embeds(
 
4506
  )
4507
 
4508
  # 4. Prepare timesteps
4509
+ if isinstance(self.scheduler, MatryoshkaDDIMScheduler):
4510
+ timesteps, num_inference_steps = retrieve_timesteps(
4511
+ self.scheduler, num_inference_steps, device, timesteps, sigmas
4512
+ )
4513
+ timesteps = timesteps[:-1] # is this correct???
4514
+ else:
4515
+ timesteps = self.scheduler.timesteps
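Note: only the MatryoshkaDDIMScheduler path goes through `retrieve_timesteps` (and then drops the final timestep); other schedulers expose their schedule directly. For the plain case with no custom `timesteps`/`sigmas`, the helper behaves roughly like the sketch below (assuming the usual diffusers `retrieve_timesteps` semantics):

def retrieve_timesteps_sketch(scheduler, num_inference_steps, device):
    # set the schedule on the scheduler, then hand it back along with its length
    scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = scheduler.timesteps
    return timesteps, len(timesteps)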
4516
 
4517
  # 5. Prepare latent variables
4518
  num_channels_latents = self.unet.config.in_channels
 
4571
  timestep_cond=timestep_cond,
4572
  cross_attention_kwargs=self.cross_attention_kwargs,
4573
  added_cond_kwargs=added_cond_kwargs,
4574
+ encoder_attention_mask=attention_masks,
4575
  return_dict=False,
4576
  )[0]
4577
 
 
4588
  noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
4589
 
4590
  # compute the previous noisy sample x_t -> x_t-1
4591
+ if self.scheduler.scales is not None and not isinstance(self.scheduler, MatryoshkaDDIMScheduler):
4592
+ latents[0] = self.scheduler.step(
4593
+ noise_pred[0], t, latents[0], **extra_step_kwargs, return_dict=False
4594
+ )[0]
4595
+ latents[1] = self.scheduler.inner_scheduler.step(
4596
+ noise_pred[1], t, latents[1], **extra_step_kwargs, return_dict=False
4597
+ )[0]
4598
+ if len(latents) > 2:
4599
+ latents[2] = self.scheduler.inner_scheduler.inner_scheduler.step(
4600
+ noise_pred[2], t, latents[2], **extra_step_kwargs, return_dict=False
4601
+ )[0]
4602
+ else:
4603
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
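Note: the stepping block above hard-codes up to three nesting levels, advancing each resolution's latent with the scheduler that matches its scale. A hypothetical, more compact form of the same idea, written for the same scope as the block above (it assumes every nesting level exposes the next one as `.inner_scheduler`):

scheduler = self.scheduler
for level, (pred, lat) in enumerate(zip(noise_pred, latents)):
    # step the latent at this resolution with the scheduler for the matching scale
    latents[level] = scheduler.step(pred, t, lat, **extra_step_kwargs, return_dict=False)[0]
    scheduler = getattr(scheduler, "inner_scheduler", scheduler)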
4604
 
4605
  if callback_on_step_end is not None:
4606
  callback_kwargs = {}