Linoy Tsaban committed on
Commit c71b83b
1 Parent(s): 8f22003

Update pipeline_semantic_stable_diffusion_img2img_solver.py

pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -36,20 +36,21 @@ class AttentionStore():
36
 
37
  def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
38
  # attn.shape = batch_size * head_size, seq_len query, seq_len_key
39
- if attn.shape[1] <= self.max_size:
40
- bs = 1 + int(PnP) + editing_prompts
41
- skip = 2 if PnP else 1 # skip PnP & unconditional
42
- attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
43
- source_batch_size = int(attn.shape[1] // bs)
44
- self.forward(
45
- attn[:, skip * source_batch_size:],
46
- is_cross,
47
- place_in_unet)
 
48
 
49
  def forward(self, attn, is_cross: bool, place_in_unet: str):
50
  key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
51
-
52
- self.step_store[key].append(attn)
53
 
54
  def between_steps(self, store_step=True):
55
  if store_step:
@@ -95,13 +96,12 @@ class AttentionStore():
95
  out = out.sum(1) / out.shape[1]
96
  return out
97
 
98
- def __init__(self, average: bool, batch_size=1, max_resolution=16):
99
  self.step_store = self.get_empty_store()
100
  self.attention_store = []
101
  self.cur_step = 0
102
  self.average = average
103
  self.batch_size = batch_size
104
- self.max_size = max_resolution ** 2
105
 
106
 
107
  class CrossAttnProcessor:
@@ -433,10 +433,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
433
 
434
  # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
435
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
436
- #shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
437
 
438
- #if latents.shape != shape:
439
- # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
440
 
441
  latents = latents.to(device)
442
 
@@ -456,7 +456,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
456
  else:
457
  continue
458
 
459
- if "attn2" in name and place_in_unet != 'mid':
460
  attn_procs[name] = CrossAttnProcessor(
461
  attention_store=attention_store,
462
  place_in_unet=place_in_unet,
@@ -470,8 +470,16 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
470
  @torch.no_grad()
471
  def __call__(
472
  self,
473
- eta: Optional[float] = 1.0,
474
  negative_prompt: Optional[Union[str, List[str]]] = None,
475
  output_type: Optional[str] = "pil",
476
  return_dict: bool = True,
477
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -480,10 +488,12 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
480
  editing_prompt_embeddings: Optional[torch.Tensor] = None,
481
  reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
482
  edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
483
- edit_warmup_steps: Optional[Union[int, List[int]]] = 0,
484
  edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
485
  edit_threshold: Optional[Union[float, List[float]]] = 0.9,
486
  user_mask: Optional[torch.FloatTensor] = None,
487
  edit_weights: Optional[List[float]] = None,
488
  sem_guidance: Optional[List[torch.Tensor]] = None,
489
  verbose=True,
@@ -494,7 +504,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
494
  use_intersect_mask: bool = False,
495
  init_latents = None,
496
  zs = None,
497
-
498
  ):
499
  r"""
500
  Function invoked when calling the pipeline for generation.
@@ -589,7 +599,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
589
  second element is a list of `bool`s denoting whether the corresponding generated image likely represents
590
  "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
591
  """
592
- eta = 1.0
593
  num_images_per_prompt = 1
594
  # latents = self.init_latents
595
  latents = init_latents
@@ -604,7 +614,18 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
604
  if use_cross_attn_mask:
605
  self.smoothing = GaussianSmoothing(self.device)
606
 
607
- org_prompt = ""
608
 
609
  # 2. Define call parameters
610
  batch_size = self.batch_size
@@ -621,6 +642,35 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
621
  self.enabled_editing_prompts = 0
622
  enable_edit_guidance = False
623
 
624
  if enable_edit_guidance:
625
  # get safety text embeddings
626
  if editing_prompt_embeddings is None:
@@ -663,47 +713,54 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
663
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
664
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
665
  # corresponds to doing no classifier free guidance.
 
666
  # get unconditional embeddings for classifier free guidance
667
 
668
-
669
- uncond_tokens: List[str]
670
- if negative_prompt is None:
671
- uncond_tokens = [""]
672
- elif isinstance(negative_prompt, str):
673
- uncond_tokens = [negative_prompt]
674
- elif batch_size != len(negative_prompt):
675
- raise ValueError(
676
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
677
- f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
678
- " the batch size of `prompt`."
679
  )
680
- else:
681
- uncond_tokens = negative_prompt
682
 
683
- max_length = self.tokenizer.model_max_length
684
- uncond_input = self.tokenizer(
685
- uncond_tokens,
686
- padding="max_length",
687
- max_length=max_length,
688
- truncation=True,
689
- return_tensors="pt",
690
- )
691
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
692
 
693
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
694
- seq_len = uncond_embeddings.shape[1]
695
- uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
696
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
697
-
698
- # For classifier free guidance, we need to do two forward passes.
699
- # Here we concatenate the unconditional and text embeddings into a single batch
700
- # to avoid doing two forward passes
701
- if enable_edit_guidance:
702
- text_embeddings = torch.cat([uncond_embeddings, edit_concepts])
703
- self.text_cross_attention_maps = \
704
- ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
705
- else:
706
- text_embeddings = torch.cat([uncond_embeddings])
707
 
708
  # 4. Prepare timesteps
709
  #self.scheduler.set_timesteps(num_inference_steps, device=self.device)
@@ -721,8 +778,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
721
  latents = self.prepare_latents(
722
  batch_size * num_images_per_prompt,
723
  num_channels_latents,
724
- None,
725
- None,
726
  text_embeddings.dtype,
727
  self.device,
728
  latents,
@@ -731,16 +788,23 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
731
  # 6. Prepare extra step kwargs.
732
  extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
733
 
734
  self.uncond_estimates = None
 
735
  self.edit_estimates = None
736
  self.sem_guidance = None
737
  self.activation_mask = None
738
 
739
  for i, t in enumerate(self.progress_bar(timesteps, verbose=verbose)):
740
  # expand the latents if we are doing classifier free guidance
741
 
742
- if enable_edit_guidance:
743
- latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts))
744
  else:
745
  latent_model_input = latents
746
 
@@ -751,219 +815,256 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
751
  # predict the noise residual
752
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample
753
 
754
 
755
- noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64]
756
- noise_pred_uncond = noise_pred_out[0]
757
- noise_pred_edit_concepts = noise_pred_out[1:]
758
 
759
- # default text guidance
760
- noise_guidance = torch.zeros_like(noise_pred_uncond)
761
 
762
- if self.uncond_estimates is None:
763
- self.uncond_estimates = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
764
- self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()
765
 
766
- if sem_guidance is not None and len(sem_guidance) > i:
767
- edit_guidance = sem_guidance[i].to(self.device)
768
- noise_guidance = noise_guidance + edit_guidance
769
 
770
- elif enable_edit_guidance:
771
- if self.activation_mask is None:
772
- self.activation_mask = torch.zeros(
773
- (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
774
- )
775
- if self.edit_estimates is None and enable_edit_guidance:
776
- self.edit_estimates = torch.zeros(
777
- (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
778
- )
779
 
780
- if self.sem_guidance is None:
781
- self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
782
 
783
- concept_weights = torch.zeros(
784
- (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
785
- device=self.device,
786
- dtype=noise_guidance.dtype,
787
- )
788
- noise_guidance_edit = torch.zeros(
789
- (len(noise_pred_edit_concepts), *noise_guidance.shape),
790
- device=self.device,
791
- dtype=noise_guidance.dtype,
792
- )
793
- warmup_inds = []
794
- # noise_guidance_edit = torch.zeros_like(noise_guidance)
795
- for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
796
- self.edit_estimates[i, c] = noise_pred_edit_concept
797
- if isinstance(edit_warmup_steps, list):
798
- edit_warmup_steps_c = edit_warmup_steps[c]
799
- else:
800
- edit_warmup_steps_c = edit_warmup_steps
801
- if i >= edit_warmup_steps_c:
802
- warmup_inds.append(c)
803
- else:
804
- continue
805
-
806
- if isinstance(edit_guidance_scale, list):
807
- edit_guidance_scale_c = edit_guidance_scale[c]
808
- else:
809
- edit_guidance_scale_c = edit_guidance_scale
810
-
811
- if isinstance(edit_threshold, list):
812
- edit_threshold_c = edit_threshold[c]
813
- else:
814
- edit_threshold_c = edit_threshold
815
- if isinstance(reverse_editing_direction, list):
816
- reverse_editing_direction_c = reverse_editing_direction[c]
817
- else:
818
- reverse_editing_direction_c = reverse_editing_direction
819
- if edit_weights:
820
- edit_weight_c = edit_weights[c]
821
- else:
822
- edit_weight_c = 1.0
823
-
824
- if isinstance(edit_cooldown_steps, list):
825
- edit_cooldown_steps_c = edit_cooldown_steps[c]
826
- elif edit_cooldown_steps is None:
827
- edit_cooldown_steps_c = i + 1
828
- else:
829
- edit_cooldown_steps_c = edit_cooldown_steps
830
-
831
- if i >= edit_cooldown_steps_c:
832
- noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
833
- continue
834
-
835
- noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
836
- # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
837
- tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))
838
-
839
- tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
840
- if reverse_editing_direction_c:
841
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
842
- concept_weights[c, :] = tmp_weights
843
-
844
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
845
-
846
- if user_mask is not None:
847
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
848
-
849
- if use_cross_attn_mask:
850
- out = self.attention_store.aggregate_attention(
851
- attention_maps=self.attention_store.step_store,
852
- prompts=self.text_cross_attention_maps,
853
- res=16,
854
- from_where=["up", "down"],
855
- is_cross=True,
856
- select=self.text_cross_attention_maps.index(editing_prompt[c]),
857
  )
858
- attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
859
 
860
- # average over all tokens
861
- assert (attn_map.shape[3] == num_edit_tokens[c])
862
- attn_map = torch.sum(attn_map, dim=3)
863
 
864
- # gaussian_smoothing
865
- attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
866
- attn_map = self.smoothing(attn_map).squeeze(1)
867
 
868
- # create binary mask
869
- if attn_map.dtype == torch.float32:
870
- tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
871
  else:
872
- tmp = torch.quantile(attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1).to(attn_map.dtype)
873
- attn_mask = torch.where(attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1,16,16), 1.0, 0.0)
874
-
875
- # resolution must match latent space dimension
876
- attn_mask = F.interpolate(
877
- attn_mask.unsqueeze(1),
878
- noise_guidance_edit_tmp.shape[-2:] # 64,64
879
- ).repeat(1, 4, 1, 1)
880
- self.activation_mask[i, c] = attn_mask.detach().cpu()
881
- if not use_intersect_mask:
882
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
883
-
884
- if use_intersect_mask:
885
- noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
886
- noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
887
- keepdim=True)
888
- noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
889
-
890
- # torch.quantile function expects float32
891
- if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
892
- tmp = torch.quantile(
893
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
894
- edit_threshold_c,
895
- dim=2,
896
- keepdim=False,
897
- )
898
  else:
899
- tmp = torch.quantile(
900
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
901
- edit_threshold_c,
902
- dim=2,
903
- keepdim=False,
904
- ).to(noise_guidance_edit_tmp_quantile.dtype)
905
-
906
- intersect_mask = torch.where(
907
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
908
- torch.ones_like(noise_guidance_edit_tmp),
909
- torch.zeros_like(noise_guidance_edit_tmp),
910
- ) * attn_mask
911
-
912
- self.activation_mask[i, c] = intersect_mask.detach().cpu()
913
-
914
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
915
-
916
- elif not use_cross_attn_mask:
917
- # calculate quantile
918
- noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
919
- noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
920
- keepdim=True)
921
- noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
922
-
923
- # torch.quantile function expects float32
924
- if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
925
- tmp = torch.quantile(
926
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
927
- edit_threshold_c,
928
- dim=2,
929
- keepdim=False,
930
- )
931
  else:
932
- tmp = torch.quantile(
933
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
934
- edit_threshold_c,
935
- dim=2,
936
- keepdim=False,
937
- ).to(noise_guidance_edit_tmp_quantile.dtype)
938
-
939
- self.activation_mask[i, c] = torch.where(
940
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
941
- torch.ones_like(noise_guidance_edit_tmp),
942
- torch.zeros_like(noise_guidance_edit_tmp),
943
- ).detach().cpu()
944
-
945
- noise_guidance_edit_tmp = torch.where(
946
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
947
- noise_guidance_edit_tmp,
948
- torch.zeros_like(noise_guidance_edit_tmp),
949
  )
950
 
951
- noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
952
 
953
- warmup_inds = torch.tensor(warmup_inds).to(self.device)
954
- concept_weights = torch.index_select(concept_weights, 0, warmup_inds)
955
- concept_weights = torch.where(
956
- concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
957
- )
958
 
959
- concept_weights = torch.nan_to_num(concept_weights)
960
 
961
- noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
962
 
963
- noise_guidance = noise_guidance + noise_guidance_edit
964
- self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
965
 
966
- noise_pred = noise_pred_uncond + noise_guidance
967
 
968
  # compute the previous noisy sample x_t -> x_t-1
969
  if use_ddpm:
@@ -971,7 +1072,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
971
  latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[idx],
972
  **extra_step_kwargs).prev_sample
973
 
974
- else: # if not use_ddpm:
975
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
976
 
977
  # step callback
@@ -1031,7 +1132,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1031
  source_prompt: str = "",
1032
  source_guidance_scale=3.5,
1033
  num_inversion_steps: int = 30,
1034
- skip: int = 15,
1035
  eta: float = 1.0,
1036
  generator: Optional[torch.Generator] = None,
1037
  verbose=True,
@@ -1048,7 +1149,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1048
  # self.eta = eta
1049
  # assert (self.eta > 0)
1050
  skip = skip/100
1051
-
1052
  train_steps = self.scheduler.config.num_train_timesteps
1053
  timesteps = torch.from_numpy(
1054
  np.linspace(train_steps - skip * train_steps - 1, 1, num_inversion_steps).astype(np.int64)).to(self.device)
@@ -1057,8 +1158,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1057
  self.num_inversion_steps = timesteps.shape[0]
1058
  self.scheduler.num_inference_steps = timesteps.shape[0]
1059
  self.scheduler.timesteps = timesteps
1060
-
1061
  self.unet.set_attn_processor(AttnProcessor())
 
1062
  # 1. get embeddings
1063
 
1064
  uncond_embedding = self.encode_text("")
@@ -1073,7 +1177,6 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1073
  # autoencoder reconstruction
1074
  # image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False)[0]
1075
  # image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
1076
-
1077
  # 3. find zs and xts
1078
  variance_noise_shape = (
1079
  self.num_inversion_steps,
@@ -1123,8 +1226,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1123
  # self.zs = zs
1124
 
1125
 
1126
-
1127
  return zs, xts
 
1128
 
1129
  @torch.no_grad()
1130
  def encode_image(self, image_path, dtype=None):
 
36
 
37
  def __call__(self, attn, is_cross: bool, place_in_unet: str, editing_prompts, PnP=False):
38
  # attn.shape = batch_size * head_size, seq_len query, seq_len_key
39
+ bs = 2 + int(PnP) + editing_prompts
40
+ skip = 2 if PnP else 1 # skip PnP & unconditional
41
+
42
+ head_size = int(attn.shape[0] / self.batch_size)
43
+ attn = torch.stack(attn.split(self.batch_size)).permute(1, 0, 2, 3)
44
+ source_batch_size = int(attn.shape[1] // bs)
45
+ self.forward(
46
+ attn[:, skip * source_batch_size:],
47
+ is_cross,
48
+ place_in_unet)
49
 
50
  def forward(self, attn, is_cross: bool, place_in_unet: str):
51
  key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
52
+ if attn.shape[1] <= 32 ** 2: # avoid memory overhead
53
+ self.step_store[key].append(attn)
54
 
55
  def between_steps(self, store_step=True):
56
  if store_step:
 
96
  out = out.sum(1) / out.shape[1]
97
  return out
98
 
99
+ def __init__(self, average: bool, batch_size=1):
100
  self.step_store = self.get_empty_store()
101
  self.attention_store = []
102
  self.cur_step = 0
103
  self.average = average
104
  self.batch_size = batch_size
 
105
 
106
 
107
  class CrossAttnProcessor:
 
433
 
434
  # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
435
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
436
+ shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
437
 
438
+ if latents.shape != shape:
439
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
440
 
441
  latents = latents.to(device)
442
 
 
456
  else:
457
  continue
458
 
459
+ if "attn2" in name:
460
  attn_procs[name] = CrossAttnProcessor(
461
  attention_store=attention_store,
462
  place_in_unet=place_in_unet,
 
470
  @torch.no_grad()
471
  def __call__(
472
  self,
473
+ prompt: Union[str, List[str]] = "",
474
+ height: Optional[int] = None,
475
+ width: Optional[int] = None,
476
+ # num_inference_steps: int = 50,
477
+ guidance_scale: float = 7.5,
478
  negative_prompt: Optional[Union[str, List[str]]] = None,
479
+ # num_images_per_prompt: int = 1,
480
+ eta: float = 1.0,
481
+ # generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
482
+ # latents: Optional[torch.FloatTensor] = None,
483
  output_type: Optional[str] = "pil",
484
  return_dict: bool = True,
485
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
488
  editing_prompt_embeddings: Optional[torch.Tensor] = None,
489
  reverse_editing_direction: Optional[Union[bool, List[bool]]] = False,
490
  edit_guidance_scale: Optional[Union[float, List[float]]] = 5,
491
+ edit_warmup_steps: Optional[Union[int, List[int]]] = 10,
492
  edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
493
  edit_threshold: Optional[Union[float, List[float]]] = 0.9,
494
  user_mask: Optional[torch.FloatTensor] = None,
495
+ edit_momentum_scale: Optional[float] = 0.1,
496
+ edit_mom_beta: Optional[float] = 0.4,
497
  edit_weights: Optional[List[float]] = None,
498
  sem_guidance: Optional[List[torch.Tensor]] = None,
499
  verbose=True,
 
504
  use_intersect_mask: bool = False,
505
  init_latents = None,
506
  zs = None,
507
+
508
  ):
509
  r"""
510
  Function invoked when calling the pipeline for generation.
 
599
  second element is a list of `bool`s denoting whether the corresponding generated image likely represents
600
  "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
601
  """
602
+ # eta = self.eta
603
  num_images_per_prompt = 1
604
  # latents = self.init_latents
605
  latents = init_latents
 
614
  if use_cross_attn_mask:
615
  self.smoothing = GaussianSmoothing(self.device)
616
 
617
+ # 0. Default height and width to unet
618
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
619
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
620
+
621
+ # 1. Check inputs. Raise error if not correct
622
+ self.check_inputs(prompt, height, width, callback_steps)
623
+
624
+ org_prompt = prompt
625
+ if isinstance(prompt, list):
626
+ assert len(prompt) == self.batch_size
627
+ elif isinstance(prompt, str):
628
+ prompt = list(repeat(prompt, self.batch_size))
629
 
630
  # 2. Define call parameters
631
  batch_size = self.batch_size
 
642
  self.enabled_editing_prompts = 0
643
  enable_edit_guidance = False
644
 
645
+ # get prompt text embeddings
646
+ text_inputs = self.tokenizer(
647
+ prompt,
648
+ padding="max_length",
649
+ max_length=self.tokenizer.model_max_length,
650
+ truncation=True,
651
+ return_tensors="pt",
652
+ )
653
+ text_input_ids = text_inputs.input_ids
654
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
655
+
656
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
657
+ text_input_ids, untruncated_ids
658
+ ):
659
+ removed_text = self.tokenizer.batch_decode(
660
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
661
+ )
662
+ logger.warning(
663
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
664
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
665
+ )
666
+
667
+ text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
668
+
669
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
670
+ bs_embed, seq_len, _ = text_embeddings.shape
671
+ text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
672
+ text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
673
+
674
  if enable_edit_guidance:
675
  # get safety text embeddings
676
  if editing_prompt_embeddings is None:
 
713
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
714
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
715
  # corresponds to doing no classifier free guidance.
716
+ do_classifier_free_guidance = guidance_scale > 1.0
717
  # get unconditional embeddings for classifier free guidance
718
 
719
+ if do_classifier_free_guidance:
720
+ uncond_tokens: List[str]
721
+ if negative_prompt is None:
722
+ uncond_tokens = [""]
723
+ elif type(prompt) is not type(negative_prompt):
724
+ raise TypeError(
725
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
726
+ f" {type(prompt)}."
727
+ )
728
+ elif isinstance(negative_prompt, str):
729
+ uncond_tokens = [negative_prompt]
730
+ elif batch_size != len(negative_prompt):
731
+ raise ValueError(
732
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
733
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
734
+ " the batch size of `prompt`."
735
+ )
736
+ else:
737
+ uncond_tokens = negative_prompt
738
+
739
+ max_length = text_input_ids.shape[-1]
740
+ uncond_input = self.tokenizer(
741
+ uncond_tokens,
742
+ padding="max_length",
743
+ max_length=max_length,
744
+ truncation=True,
745
+ return_tensors="pt",
746
  )
747
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
748
 
749
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
750
+ seq_len = uncond_embeddings.shape[1]
751
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
752
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
753
 
754
+ # For classifier free guidance, we need to do two forward passes.
755
+ # Here we concatenate the unconditional and text embeddings into a single batch
756
+ # to avoid doing two forward passes
757
+ self.text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
758
+ if enable_edit_guidance:
759
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
760
+ self.text_cross_attention_maps += \
761
+ ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
762
+ else:
763
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
764
 
765
  # 4. Prepare timesteps
766
  #self.scheduler.set_timesteps(num_inference_steps, device=self.device)
 
778
  latents = self.prepare_latents(
779
  batch_size * num_images_per_prompt,
780
  num_channels_latents,
781
+ height,
782
+ width,
783
  text_embeddings.dtype,
784
  self.device,
785
  latents,
 
788
  # 6. Prepare extra step kwargs.
789
  extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
790
 
791
+ # Initialize edit_momentum to None
792
+ edit_momentum = None
793
+
794
  self.uncond_estimates = None
795
+ self.text_estimates = None
796
  self.edit_estimates = None
797
  self.sem_guidance = None
798
  self.activation_mask = None
799
 
800
  for i, t in enumerate(self.progress_bar(timesteps, verbose=verbose)):
801
+ idx = t_to_idx[int(t)]
802
+
803
+
804
  # expand the latents if we are doing classifier free guidance
805
 
806
+ if do_classifier_free_guidance:
807
+ latent_model_input = torch.cat([latents] * (2 + self.enabled_editing_prompts))
808
  else:
809
  latent_model_input = latents
810
 
 
815
  # predict the noise residual
816
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample
817
 
818
+ # perform guidance
819
+ if do_classifier_free_guidance:
820
 
821
+ noise_pred_out = noise_pred.chunk(2 + self.enabled_editing_prompts) # [b,4, 64, 64]
822
+ noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
823
+ noise_pred_edit_concepts = noise_pred_out[2:]
824
 
825
+ # default text guidance
826
+ noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond)
827
 
828
+ if self.uncond_estimates is None:
829
+ self.uncond_estimates = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
830
+ self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()
831
 
832
+ if self.text_estimates is None:
833
+ self.text_estimates = torch.zeros((len(timesteps), *noise_pred_text.shape))
834
+ self.text_estimates[i] = noise_pred_text.detach().cpu()
835
 
836
+ if edit_momentum is None:
837
+ edit_momentum = torch.zeros_like(noise_guidance)
 
838
 
839
+ if sem_guidance is not None and len(sem_guidance) > i:
840
+ edit_guidance = sem_guidance[i].to(self.device)
841
+ edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * edit_guidance
842
+ noise_guidance = noise_guidance + edit_guidance
843
 
844
+ elif enable_edit_guidance:
845
+ if self.activation_mask is None:
846
+ self.activation_mask = torch.zeros(
847
+ (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
848
+ )
849
+ if self.edit_estimates is None and enable_edit_guidance:
850
+ self.edit_estimates = torch.zeros(
851
+ (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
852
  )
 
853
 
854
+ if self.sem_guidance is None:
855
+ self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_text.shape))
 
856
 
857
+ concept_weights = torch.zeros(
858
+ (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
859
+ device=self.device,
860
+ dtype=noise_guidance.dtype,
861
+ )
862
+ noise_guidance_edit = torch.zeros(
863
+ (len(noise_pred_edit_concepts), *noise_guidance.shape),
864
+ device=self.device,
865
+ dtype=noise_guidance.dtype,
866
+ )
867
+ # noise_guidance_edit = torch.zeros_like(noise_guidance)
868
+ warmup_inds = []
869
+ for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
870
+ self.edit_estimates[i, c] = noise_pred_edit_concept
871
+ if isinstance(edit_guidance_scale, list):
872
+ edit_guidance_scale_c = edit_guidance_scale[c]
873
+ else:
874
+ edit_guidance_scale_c = edit_guidance_scale
875
 
876
+ if isinstance(edit_threshold, list):
877
+ edit_threshold_c = edit_threshold[c]
 
878
  else:
879
+ edit_threshold_c = edit_threshold
880
+ if isinstance(reverse_editing_direction, list):
881
+ reverse_editing_direction_c = reverse_editing_direction[c]
882
  else:
883
+ reverse_editing_direction_c = reverse_editing_direction
884
+ if edit_weights:
885
+ edit_weight_c = edit_weights[c]
886
+ else:
887
+ edit_weight_c = 1.0
888
+ if isinstance(edit_warmup_steps, list):
889
+ edit_warmup_steps_c = edit_warmup_steps[c]
890
+ else:
891
+ edit_warmup_steps_c = edit_warmup_steps
892
+
893
+ if isinstance(edit_cooldown_steps, list):
894
+ edit_cooldown_steps_c = edit_cooldown_steps[c]
895
+ elif edit_cooldown_steps is None:
896
+ edit_cooldown_steps_c = i + 1
897
  else:
898
+ edit_cooldown_steps_c = edit_cooldown_steps
899
+ if i >= edit_warmup_steps_c:
900
+ warmup_inds.append(c)
901
+ if i >= edit_cooldown_steps_c:
902
+ noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
903
+ continue
904
+
905
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
906
+ # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
907
+ tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))
908
+
909
+ tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
910
+ if reverse_editing_direction_c:
911
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
912
+ concept_weights[c, :] = tmp_weights
913
+
914
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
915
+
916
+ if user_mask is not None:
917
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
918
+
919
+ if use_cross_attn_mask:
920
+ out = self.attention_store.aggregate_attention(
921
+ attention_maps=self.attention_store.step_store,
922
+ prompts=self.text_cross_attention_maps,
923
+ res=16,
924
+ from_where=["up", "down"],
925
+ is_cross=True,
926
+ select=self.text_cross_attention_maps.index(editing_prompt[c]),
927
+ )
928
+ attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
929
+
930
+ # average over all tokens
931
+ assert (attn_map.shape[3] == num_edit_tokens[c])
932
+ attn_map = torch.sum(attn_map, dim=3)
933
+
934
+ # gaussian_smoothing
935
+ attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
936
+ attn_map = self.smoothing(attn_map).squeeze(1)
937
+
938
+ # create binary mask
939
+ if attn_map.dtype == torch.float32:
940
+ tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
941
+ else:
942
+ tmp = torch.quantile(attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1).to(attn_map.dtype)
943
+ attn_mask = torch.where(attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1,16,16), 1.0, 0.0)
944
+
945
+ # resolution must match latent space dimension
946
+ attn_mask = F.interpolate(
947
+ attn_mask.unsqueeze(1),
948
+ noise_guidance_edit_tmp.shape[-2:] # 64,64
949
+ ).repeat(1, 4, 1, 1)
950
+ self.activation_mask[i, c] = attn_mask.detach().cpu()
951
+ if not use_intersect_mask:
952
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
953
+
954
+ if use_intersect_mask:
955
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
956
+ noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
957
+ keepdim=True)
958
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
959
+
960
+ # torch.quantile function expects float32
961
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
962
+ tmp = torch.quantile(
963
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
964
+ edit_threshold_c,
965
+ dim=2,
966
+ keepdim=False,
967
+ )
968
+ else:
969
+ tmp = torch.quantile(
970
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
971
+ edit_threshold_c,
972
+ dim=2,
973
+ keepdim=False,
974
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
975
+
976
+ intersect_mask = torch.where(
977
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
978
+ torch.ones_like(noise_guidance_edit_tmp),
979
+ torch.zeros_like(noise_guidance_edit_tmp),
980
+ ) * attn_mask
981
+
982
+ self.activation_mask[i, c] = intersect_mask.detach().cpu()
983
+
984
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
985
+
986
+ elif not use_cross_attn_mask:
987
+ # calculate quantile
988
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
989
+ noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
990
+ keepdim=True)
991
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
992
+
993
+ # torch.quantile function expects float32
994
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
995
+ tmp = torch.quantile(
996
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
997
+ edit_threshold_c,
998
+ dim=2,
999
+ keepdim=False,
1000
+ )
1001
+ else:
1002
+ tmp = torch.quantile(
1003
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
1004
+ edit_threshold_c,
1005
+ dim=2,
1006
+ keepdim=False,
1007
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
1008
+
1009
+ self.activation_mask[i, c] = torch.where(
1010
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1011
+ torch.ones_like(noise_guidance_edit_tmp),
1012
+ torch.zeros_like(noise_guidance_edit_tmp),
1013
+ ).detach().cpu()
1014
+
1015
+ noise_guidance_edit_tmp = torch.where(
1016
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1017
+ noise_guidance_edit_tmp,
1018
+ torch.zeros_like(noise_guidance_edit_tmp),
1019
+ )
1020
+
1021
+ noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
1022
+
1023
+ warmup_inds = torch.tensor(warmup_inds).to(self.device)
1024
+ if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
1025
+ concept_weights = concept_weights.to("cpu") # Offload to cpu
1026
+ noise_guidance_edit = noise_guidance_edit.to("cpu")
1027
+
1028
+ concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
1029
+ concept_weights_tmp = torch.where(
1030
+ concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
1031
  )
1032
+ concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
1033
+ # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
1034
 
1035
+ noise_guidance_edit_tmp = torch.index_select(
1036
+ noise_guidance_edit.to(self.device), 0, warmup_inds
1037
+ )
1038
+ noise_guidance_edit_tmp = torch.einsum(
1039
+ "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
1040
+ )
1041
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp
1042
+ noise_guidance = noise_guidance + noise_guidance_edit_tmp
1043
 
1044
+ self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
1045
 
1046
+ del noise_guidance_edit_tmp
1047
+ del concept_weights_tmp
1048
+ concept_weights = concept_weights.to(self.device)
1049
+ noise_guidance_edit = noise_guidance_edit.to(self.device)
1050
 
1051
+ concept_weights = torch.where(
1052
+ concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
1053
+ )
1054
 
1055
+ concept_weights = torch.nan_to_num(concept_weights)
 
1056
 
1057
+ noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
1058
+
1059
+ noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
1060
+
1061
+ edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit
1062
+
1063
+ if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
1064
+ noise_guidance = noise_guidance + noise_guidance_edit
1065
+ self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
1066
+
1067
+ noise_pred = noise_pred_uncond + noise_guidance
1068
 
1069
  # compute the previous noisy sample x_t -> x_t-1
1070
  if use_ddpm:
 
1072
  latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[idx],
1073
  **extra_step_kwargs).prev_sample
1074
 
1075
+ else: #if not use_ddpm:
1076
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1077
 
1078
  # step callback
 
1132
  source_prompt: str = "",
1133
  source_guidance_scale=3.5,
1134
  num_inversion_steps: int = 30,
1135
+ skip: float = 0.15,
1136
  eta: float = 1.0,
1137
  generator: Optional[torch.Generator] = None,
1138
  verbose=True,
 
1149
  # self.eta = eta
1150
  # assert (self.eta > 0)
1151
  skip = skip/100
1152
+ print("YOOOOOOOOOOOOOOOOO ", skip, num_inversion_steps)
1153
  train_steps = self.scheduler.config.num_train_timesteps
1154
  timesteps = torch.from_numpy(
1155
  np.linspace(train_steps - skip * train_steps - 1, 1, num_inversion_steps).astype(np.int64)).to(self.device)
 
1158
  self.num_inversion_steps = timesteps.shape[0]
1159
  self.scheduler.num_inference_steps = timesteps.shape[0]
1160
  self.scheduler.timesteps = timesteps
1161
+
1162
+ # Reset attn processor, we do not want to store attn maps during inversion
1163
+ # self.unet.set_default_attn_processor()
1164
  self.unet.set_attn_processor(AttnProcessor())
1165
+
1166
  # 1. get embeddings
1167
 
1168
  uncond_embedding = self.encode_text("")
 
1177
  # autoencoder reconstruction
1178
  # image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False)[0]
1179
  # image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
 
1180
  # 3. find zs and xts
1181
  variance_noise_shape = (
1182
  self.num_inversion_steps,
 
1226
  # self.zs = zs
1227
 
1228
 
 
1229
  return zs, xts
1230
+ # return zs, xts, image_rec
1231
 
1232
  @torch.no_grad()
1233
  def encode_image(self, image_path, dtype=None):
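
Taken together, the updated `__call__` and `invert` signatures in this diff imply an invert-then-edit workflow. The snippet below is a minimal usage sketch, not code from the repository: it assumes the class can be imported from this module and loaded with standard Stable Diffusion v1.5 weights, and the `image_path` argument to `invert` as well as the way `xts`/`zs` are fed back into `__call__` are assumptions inferred from the surrounding code; parameter names such as `edit_momentum_scale` and `edit_mom_beta` are taken from the diff itself.

import torch
from pipeline_semantic_stable_diffusion_img2img_solver import SemanticStableDiffusionImg2ImgPipeline_DPMSolver

# Hypothetical loading; assumes the pipeline components match Stable Diffusion v1.5.
pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Invert the input image; `skip` is now a fraction of the schedule (see the change above).
zs, xts = pipe.invert(
    image_path="input.jpg",        # assumed argument, mirroring encode_image()
    source_prompt="",
    num_inversion_steps=50,
    skip=0.15,
)

# Run the semantic edit with the extended call signature from this commit.
out = pipe(
    prompt="",                      # new source-prompt argument
    guidance_scale=7.5,             # new classifier-free guidance weight
    editing_prompt=["a smiling face"],
    edit_guidance_scale=5,
    edit_warmup_steps=10,
    edit_threshold=0.9,
    edit_momentum_scale=0.1,        # new momentum term on the edit guidance
    edit_mom_beta=0.4,
    eta=1.0,
    init_latents=xts[-1],           # assumption: starting latent from the inversion trajectory
    zs=zs,
)
image = out.images[0]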