Linoy Tsaban commited on
Commit
54787fd
·
1 Parent(s): c02332f

Update pipeline_semantic_stable_diffusion_img2img_solver.py

Browse files
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -1,33 +1,3 @@
1
- import inspect
2
- import warnings
3
- from itertools import repeat
4
- from typing import Callable, List, Optional, Union
5
-
6
- import torch
7
- from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
8
-
9
- from diffusers.image_processor import VaeImageProcessor
10
- from diffusers.models import AutoencoderKL, UNet2DConditionModel
11
- from diffusers.models.attention_processor import AttnProcessor, Attention
12
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
13
- from diffusers.schedulers import DDIMScheduler
14
- from scheduling_dpmsolver_multistep_inject import DPMSolverMultistepSchedulerInject
15
- # from diffusers.utils import logging, randn_tensor
16
- from diffusers.utils import logging
17
- from diffusers.utils.torch_utils import randn_tensor
18
- from diffusers.pipelines.pipeline_utils import DiffusionPipeline
19
- from diffusers.pipelines.semantic_stable_diffusion import SemanticStableDiffusionPipelineOutput
20
-
21
- import numpy as np
22
- from PIL import Image
23
- from tqdm import tqdm
24
- import torch.nn.functional as F
25
- import math
26
- from collections.abc import Iterable
27
-
28
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
-
30
-
31
  class AttentionStore():
32
  @staticmethod
33
  def get_empty_store():
@@ -48,6 +18,7 @@ class AttentionStore():
48
 
49
  def forward(self, attn, is_cross: bool, place_in_unet: str):
50
  key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
 
51
  self.step_store[key].append(attn)
52
 
53
  def between_steps(self, store_step=True):
@@ -432,10 +403,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
432
 
433
  # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
434
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
435
- # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
436
 
437
- # if latents.shape != shape:
438
- # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
439
 
440
  latents = latents.to(device)
441
 
@@ -469,16 +440,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
469
  @torch.no_grad()
470
  def __call__(
471
  self,
472
- prompt: Union[str, List[str]] = "",
473
- height: Optional[int] = None,
474
- width: Optional[int] = None,
475
- # num_inference_steps: int = 50,
476
- guidance_scale: float = 7.5,
477
  negative_prompt: Optional[Union[str, List[str]]] = None,
478
- # num_images_per_prompt: int = 1,
479
- eta: float = 1.0,
480
- # generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
481
- # latents: Optional[torch.FloatTensor] = None,
482
  output_type: Optional[str] = "pil",
483
  return_dict: bool = True,
484
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
@@ -491,7 +454,6 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
491
  edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
492
  edit_threshold: Optional[Union[float, List[float]]] = 0.9,
493
  user_mask: Optional[torch.FloatTensor] = None,
494
-
495
  edit_weights: Optional[List[float]] = None,
496
  sem_guidance: Optional[List[torch.Tensor]] = None,
497
  verbose=True,
@@ -502,7 +464,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
502
  use_intersect_mask: bool = False,
503
  init_latents = None,
504
  zs = None,
505
-
506
  ):
507
  r"""
508
  Function invoked when calling the pipeline for generation.
@@ -597,7 +559,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
597
  second element is a list of `bool`s denoting whether the corresponding generated image likely represents
598
  "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
599
  """
600
- # eta = self.eta
601
  num_images_per_prompt = 1
602
  # latents = self.init_latents
603
  latents = init_latents
@@ -612,18 +574,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
612
  if use_cross_attn_mask:
613
  self.smoothing = GaussianSmoothing(self.device)
614
 
615
- # 0. Default height and width to unet
616
- height = height or self.unet.config.sample_size * self.vae_scale_factor
617
- width = width or self.unet.config.sample_size * self.vae_scale_factor
618
-
619
- # 1. Check inputs. Raise error if not correct
620
- self.check_inputs(prompt, height, width, callback_steps)
621
-
622
- org_prompt = prompt
623
- if isinstance(prompt, list):
624
- assert len(prompt) == self.batch_size
625
- elif isinstance(prompt, str):
626
- prompt = list(repeat(prompt, self.batch_size))
627
 
628
  # 2. Define call parameters
629
  batch_size = self.batch_size
@@ -640,35 +591,6 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
640
  self.enabled_editing_prompts = 0
641
  enable_edit_guidance = False
642
 
643
- # get prompt text embeddings
644
- text_inputs = self.tokenizer(
645
- prompt,
646
- padding="max_length",
647
- max_length=self.tokenizer.model_max_length,
648
- truncation=True,
649
- return_tensors="pt",
650
- )
651
- text_input_ids = text_inputs.input_ids
652
- untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
653
-
654
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
655
- text_input_ids, untruncated_ids
656
- ):
657
- removed_text = self.tokenizer.batch_decode(
658
- untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
659
- )
660
- logger.warning(
661
- "The following part of your input was truncated because CLIP can only handle sequences up to"
662
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
663
- )
664
-
665
- text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
666
-
667
- # duplicate text embeddings for each generation per prompt, using mps friendly method
668
- bs_embed, seq_len, _ = text_embeddings.shape
669
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
670
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
671
-
672
  if enable_edit_guidance:
673
  # get safety text embeddings
674
  if editing_prompt_embeddings is None:
@@ -711,54 +633,47 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
711
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
712
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
713
  # corresponds to doing no classifier free guidance.
714
- do_classifier_free_guidance = guidance_scale > 1.0
715
  # get unconditional embeddings for classifier free guidance
716
 
717
- if do_classifier_free_guidance:
718
- uncond_tokens: List[str]
719
- if negative_prompt is None:
720
- uncond_tokens = [""]
721
- elif type(prompt) is not type(negative_prompt):
722
- raise TypeError(
723
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
724
- f" {type(prompt)}."
725
- )
726
- elif isinstance(negative_prompt, str):
727
- uncond_tokens = [negative_prompt]
728
- elif batch_size != len(negative_prompt):
729
- raise ValueError(
730
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
731
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
732
- " the batch size of `prompt`."
733
- )
734
- else:
735
- uncond_tokens = negative_prompt
736
-
737
- max_length = text_input_ids.shape[-1]
738
- uncond_input = self.tokenizer(
739
- uncond_tokens,
740
- padding="max_length",
741
- max_length=max_length,
742
- truncation=True,
743
- return_tensors="pt",
744
  )
745
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
 
746
 
747
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
748
- seq_len = uncond_embeddings.shape[1]
749
- uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
750
- uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
 
 
 
 
 
751
 
752
- # For classifier free guidance, we need to do two forward passes.
753
- # Here we concatenate the unconditional and text embeddings into a single batch
754
- # to avoid doing two forward passes
755
- self.text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
756
- if enable_edit_guidance:
757
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
758
- self.text_cross_attention_maps += \
759
- ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
760
- else:
761
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
 
 
 
762
 
763
  # 4. Prepare timesteps
764
  #self.scheduler.set_timesteps(num_inference_steps, device=self.device)
@@ -776,8 +691,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
776
  latents = self.prepare_latents(
777
  batch_size * num_images_per_prompt,
778
  num_channels_latents,
779
- height,
780
- width,
781
  text_embeddings.dtype,
782
  self.device,
783
  latents,
@@ -786,21 +701,16 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
786
  # 6. Prepare extra step kwargs.
787
  extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
788
 
789
-
790
  self.uncond_estimates = None
791
- self.text_estimates = None
792
  self.edit_estimates = None
793
  self.sem_guidance = None
794
  self.activation_mask = None
795
 
796
  for i, t in enumerate(self.progress_bar(timesteps, verbose=verbose)):
797
- idx = t_to_idx[int(t)]
798
-
799
-
800
  # expand the latents if we are doing classifier free guidance
801
 
802
- if do_classifier_free_guidance:
803
- latent_model_input = torch.cat([latents] * (2 + self.enabled_editing_prompts))
804
  else:
805
  latent_model_input = latents
806
 
@@ -811,254 +721,219 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
811
  # predict the noise residual
812
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample
813
 
814
- # perform guidance
815
- if do_classifier_free_guidance:
816
 
817
- noise_pred_out = noise_pred.chunk(2 + self.enabled_editing_prompts) # [b,4, 64, 64]
818
- noise_pred_uncond, noise_pred_text = noise_pred_out[0], noise_pred_out[1]
819
- noise_pred_edit_concepts = noise_pred_out[2:]
820
 
821
- # default text guidance
822
- noise_guidance = guidance_scale * (noise_pred_text - noise_pred_uncond)
823
 
824
- if self.uncond_estimates is None:
825
- self.uncond_estimates = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
826
- self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()
827
 
828
- if self.text_estimates is None:
829
- self.text_estimates = torch.zeros((len(timesteps), *noise_pred_text.shape))
830
- self.text_estimates[i] = noise_pred_text.detach().cpu()
831
 
832
-
 
 
 
 
 
 
 
 
833
 
834
- if sem_guidance is not None and len(sem_guidance) > i:
835
- edit_guidance = sem_guidance[i].to(self.device)
836
- noise_guidance = noise_guidance + edit_guidance
837
 
838
- elif enable_edit_guidance:
839
- if self.activation_mask is None:
840
- self.activation_mask = torch.zeros(
841
- (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
842
- )
843
- if self.edit_estimates is None and enable_edit_guidance:
844
- self.edit_estimates = torch.zeros(
845
- (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
846
  )
 
847
 
848
- if self.sem_guidance is None:
849
- self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_text.shape))
 
850
 
851
- concept_weights = torch.zeros(
852
- (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
853
- device=self.device,
854
- dtype=noise_guidance.dtype,
855
- )
856
- noise_guidance_edit = torch.zeros(
857
- (len(noise_pred_edit_concepts), *noise_guidance.shape),
858
- device=self.device,
859
- dtype=noise_guidance.dtype,
860
- )
861
- # noise_guidance_edit = torch.zeros_like(noise_guidance)
862
- warmup_inds = []
863
- for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
864
- self.edit_estimates[i, c] = noise_pred_edit_concept
865
- if isinstance(edit_guidance_scale, list):
866
- edit_guidance_scale_c = edit_guidance_scale[c]
867
- else:
868
- edit_guidance_scale_c = edit_guidance_scale
869
-
870
- if isinstance(edit_threshold, list):
871
- edit_threshold_c = edit_threshold[c]
872
- else:
873
- edit_threshold_c = edit_threshold
874
- if isinstance(reverse_editing_direction, list):
875
- reverse_editing_direction_c = reverse_editing_direction[c]
876
- else:
877
- reverse_editing_direction_c = reverse_editing_direction
878
- if edit_weights:
879
- edit_weight_c = edit_weights[c]
880
- else:
881
- edit_weight_c = 1.0
882
- if isinstance(edit_warmup_steps, list):
883
- edit_warmup_steps_c = edit_warmup_steps[c]
884
- else:
885
- edit_warmup_steps_c = edit_warmup_steps
886
 
887
- if isinstance(edit_cooldown_steps, list):
888
- edit_cooldown_steps_c = edit_cooldown_steps[c]
889
- elif edit_cooldown_steps is None:
890
- edit_cooldown_steps_c = i + 1
891
  else:
892
- edit_cooldown_steps_c = edit_cooldown_steps
893
- if i >= edit_warmup_steps_c:
894
- warmup_inds.append(c)
895
- if i >= edit_cooldown_steps_c:
896
- noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
897
- continue
898
-
899
- noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
900
- # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
901
- tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))
902
-
903
- tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
904
- if reverse_editing_direction_c:
905
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
906
- concept_weights[c, :] = tmp_weights
907
-
908
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
909
-
910
- if user_mask is not None:
911
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
912
-
913
- if use_cross_attn_mask:
914
- out = self.attention_store.aggregate_attention(
915
- attention_maps=self.attention_store.step_store,
916
- prompts=self.text_cross_attention_maps,
917
- res=16,
918
- from_where=["up", "down"],
919
- is_cross=True,
920
- select=self.text_cross_attention_maps.index(editing_prompt[c]),
921
  )
922
- attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
923
-
924
- # average over all tokens
925
- assert (attn_map.shape[3] == num_edit_tokens[c])
926
- attn_map = torch.sum(attn_map, dim=3)
927
-
928
- # gaussian_smoothing
929
- attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
930
- attn_map = self.smoothing(attn_map).squeeze(1)
931
-
932
- # create binary mask
933
- if attn_map.dtype == torch.float32:
934
- tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
935
- else:
936
- tmp = torch.quantile(attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1).to(attn_map.dtype)
937
- attn_mask = torch.where(attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1,16,16), 1.0, 0.0)
938
-
939
- # resolution must match latent space dimension
940
- attn_mask = F.interpolate(
941
- attn_mask.unsqueeze(1),
942
- noise_guidance_edit_tmp.shape[-2:] # 64,64
943
- ).repeat(1, 4, 1, 1)
944
- self.activation_mask[i, c] = attn_mask.detach().cpu()
945
- if not use_intersect_mask:
946
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
947
-
948
- if use_intersect_mask:
949
- noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
950
- noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
951
- keepdim=True)
952
- noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
953
-
954
- # torch.quantile function expects float32
955
- if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
956
- tmp = torch.quantile(
957
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
958
- edit_threshold_c,
959
- dim=2,
960
- keepdim=False,
961
- )
962
- else:
963
- tmp = torch.quantile(
964
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
965
- edit_threshold_c,
966
- dim=2,
967
- keepdim=False,
968
- ).to(noise_guidance_edit_tmp_quantile.dtype)
969
-
970
- intersect_mask = torch.where(
971
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
972
- torch.ones_like(noise_guidance_edit_tmp),
973
- torch.zeros_like(noise_guidance_edit_tmp),
974
- ) * attn_mask
975
-
976
- self.activation_mask[i, c] = intersect_mask.detach().cpu()
977
-
978
- noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
979
-
980
- elif not use_cross_attn_mask:
981
- # calculate quantile
982
- noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
983
- noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
984
- keepdim=True)
985
- noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
986
-
987
- # torch.quantile function expects float32
988
- if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
989
- tmp = torch.quantile(
990
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
991
- edit_threshold_c,
992
- dim=2,
993
- keepdim=False,
994
- )
995
- else:
996
- tmp = torch.quantile(
997
- noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
998
- edit_threshold_c,
999
- dim=2,
1000
- keepdim=False,
1001
- ).to(noise_guidance_edit_tmp_quantile.dtype)
1002
-
1003
- self.activation_mask[i, c] = torch.where(
1004
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1005
- torch.ones_like(noise_guidance_edit_tmp),
1006
- torch.zeros_like(noise_guidance_edit_tmp),
1007
- ).detach().cpu()
1008
-
1009
- noise_guidance_edit_tmp = torch.where(
1010
- noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
1011
- noise_guidance_edit_tmp,
1012
- torch.zeros_like(noise_guidance_edit_tmp),
1013
  )
1014
-
1015
- noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
1016
-
1017
- warmup_inds = torch.tensor(warmup_inds).to(self.device)
1018
- if len(noise_pred_edit_concepts) > warmup_inds.shape[0] > 0:
1019
- concept_weights = concept_weights.to("cpu") # Offload to cpu
1020
- noise_guidance_edit = noise_guidance_edit.to("cpu")
1021
-
1022
- concept_weights_tmp = torch.index_select(concept_weights.to(self.device), 0, warmup_inds)
1023
- concept_weights_tmp = torch.where(
1024
- concept_weights_tmp < 0, torch.zeros_like(concept_weights_tmp), concept_weights_tmp
1025
- )
1026
- concept_weights_tmp = concept_weights_tmp / concept_weights_tmp.sum(dim=0)
1027
- # concept_weights_tmp = torch.nan_to_num(concept_weights_tmp)
1028
-
1029
- noise_guidance_edit_tmp = torch.index_select(
1030
- noise_guidance_edit.to(self.device), 0, warmup_inds
1031
- )
1032
- noise_guidance_edit_tmp = torch.einsum(
1033
- "cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
1034
  )
1035
- noise_guidance_edit_tmp = noise_guidance_edit_tmp
1036
- noise_guidance = noise_guidance + noise_guidance_edit_tmp
1037
 
1038
- self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
1039
 
1040
- del noise_guidance_edit_tmp
1041
- del concept_weights_tmp
1042
- concept_weights = concept_weights.to(self.device)
1043
- noise_guidance_edit = noise_guidance_edit.to(self.device)
1044
-
1045
- concept_weights = torch.where(
1046
- concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
1047
- )
1048
-
1049
- concept_weights = torch.nan_to_num(concept_weights)
1050
-
1051
- noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
1052
 
1053
- noise_guidance_edit = noise_guidance_edit + edit_momentum_scale * edit_momentum
1054
 
1055
- edit_momentum = edit_mom_beta * edit_momentum + (1 - edit_mom_beta) * noise_guidance_edit
1056
 
1057
- if warmup_inds.shape[0] == len(noise_pred_edit_concepts):
1058
- noise_guidance = noise_guidance + noise_guidance_edit
1059
- self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
1060
 
1061
- noise_pred = noise_pred_uncond + noise_guidance
1062
 
1063
  # compute the previous noisy sample x_t -> x_t-1
1064
  if use_ddpm:
@@ -1066,7 +941,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1066
  latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[idx],
1067
  **extra_step_kwargs).prev_sample
1068
 
1069
- else: #if not use_ddpm:
1070
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1071
 
1072
  # step callback
@@ -1126,7 +1001,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1126
  source_prompt: str = "",
1127
  source_guidance_scale=3.5,
1128
  num_inversion_steps: int = 30,
1129
- skip: float = 0.15,
1130
  eta: float = 1.0,
1131
  generator: Optional[torch.Generator] = None,
1132
  verbose=True,
@@ -1143,7 +1018,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1143
  # self.eta = eta
1144
  # assert (self.eta > 0)
1145
  skip = skip/100
1146
- print("YOOOOOOOOOOOOOOOOO ", skip, num_inversion_steps)
1147
  train_steps = self.scheduler.config.num_train_timesteps
1148
  timesteps = torch.from_numpy(
1149
  np.linspace(train_steps - skip * train_steps - 1, 1, num_inversion_steps).astype(np.int64)).to(self.device)
@@ -1152,10 +1027,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1152
  self.num_inversion_steps = timesteps.shape[0]
1153
  self.scheduler.num_inference_steps = timesteps.shape[0]
1154
  self.scheduler.timesteps = timesteps
1155
-
1156
- # Reset attn processor, we do not want to store attn maps during inversion
1157
- # self.unet.set_default_attn_processor()
1158
- self.unet.set_attn_processor(AttnProcessor())
1159
 
1160
  # 1. get embeddings
1161
 
@@ -1171,6 +1043,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1171
  # autoencoder reconstruction
1172
  # image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False)[0]
1173
  # image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
 
1174
  # 3. find zs and xts
1175
  variance_noise_shape = (
1176
  self.num_inversion_steps,
@@ -1220,8 +1093,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1220
  # self.zs = zs
1221
 
1222
 
 
1223
  return zs, xts
1224
- # return zs, xts, image_rec
1225
 
1226
  @torch.no_grad()
1227
  def encode_image(self, image_path, dtype=None):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  class AttentionStore():
2
  @staticmethod
3
  def get_empty_store():
 
18
 
19
  def forward(self, attn, is_cross: bool, place_in_unet: str):
20
  key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
21
+
22
  self.step_store[key].append(attn)
23
 
24
  def between_steps(self, store_step=True):
 
403
 
404
  # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
405
  def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, latents):
406
+ #shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
407
 
408
+ #if latents.shape != shape:
409
+ # raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
410
 
411
  latents = latents.to(device)
412
 
 
440
  @torch.no_grad()
441
  def __call__(
442
  self,
443
+ eta: Optional[float] = 1.0,
 
 
 
 
444
  negative_prompt: Optional[Union[str, List[str]]] = None,
 
 
 
 
445
  output_type: Optional[str] = "pil",
446
  return_dict: bool = True,
447
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
454
  edit_cooldown_steps: Optional[Union[int, List[int]]] = None,
455
  edit_threshold: Optional[Union[float, List[float]]] = 0.9,
456
  user_mask: Optional[torch.FloatTensor] = None,
 
457
  edit_weights: Optional[List[float]] = None,
458
  sem_guidance: Optional[List[torch.Tensor]] = None,
459
  verbose=True,
 
464
  use_intersect_mask: bool = False,
465
  init_latents = None,
466
  zs = None,
467
+
468
  ):
469
  r"""
470
  Function invoked when calling the pipeline for generation.
 
559
  second element is a list of `bool`s denoting whether the corresponding generated image likely represents
560
  "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
561
  """
562
+ eta = 1.0
563
  num_images_per_prompt = 1
564
  # latents = self.init_latents
565
  latents = init_latents
 
574
  if use_cross_attn_mask:
575
  self.smoothing = GaussianSmoothing(self.device)
576
 
577
+ org_prompt = ""
 
 
 
 
 
 
 
 
 
 
 
578
 
579
  # 2. Define call parameters
580
  batch_size = self.batch_size
 
591
  self.enabled_editing_prompts = 0
592
  enable_edit_guidance = False
593
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
594
  if enable_edit_guidance:
595
  # get safety text embeddings
596
  if editing_prompt_embeddings is None:
 
633
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
634
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
635
  # corresponds to doing no classifier free guidance.
 
636
  # get unconditional embeddings for classifier free guidance
637
 
638
+
639
+ uncond_tokens: List[str]
640
+ if negative_prompt is None:
641
+ uncond_tokens = [""]
642
+ elif isinstance(negative_prompt, str):
643
+ uncond_tokens = [negative_prompt]
644
+ elif batch_size != len(negative_prompt):
645
+ raise ValueError(
646
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
647
+ f" has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
648
+ " the batch size of `prompt`."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
649
  )
650
+ else:
651
+ uncond_tokens = negative_prompt
652
 
653
+ max_length = self.tokenizer.model_max_length
654
+ uncond_input = self.tokenizer(
655
+ uncond_tokens,
656
+ padding="max_length",
657
+ max_length=max_length,
658
+ truncation=True,
659
+ return_tensors="pt",
660
+ )
661
+ uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
662
 
663
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
664
+ seq_len = uncond_embeddings.shape[1]
665
+ uncond_embeddings = uncond_embeddings.repeat(batch_size, num_images_per_prompt, 1)
666
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len, -1)
667
+
668
+ # For classifier free guidance, we need to do two forward passes.
669
+ # Here we concatenate the unconditional and text embeddings into a single batch
670
+ # to avoid doing two forward passes
671
+ if enable_edit_guidance:
672
+ text_embeddings = torch.cat([uncond_embeddings, edit_concepts])
673
+ self.text_cross_attention_maps = \
674
+ ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
675
+ else:
676
+ text_embeddings = torch.cat([uncond_embeddings])
677
 
678
  # 4. Prepare timesteps
679
  #self.scheduler.set_timesteps(num_inference_steps, device=self.device)
 
691
  latents = self.prepare_latents(
692
  batch_size * num_images_per_prompt,
693
  num_channels_latents,
694
+ None,
695
+ None,
696
  text_embeddings.dtype,
697
  self.device,
698
  latents,
 
701
  # 6. Prepare extra step kwargs.
702
  extra_step_kwargs = self.prepare_extra_step_kwargs(eta)
703
 
 
704
  self.uncond_estimates = None
 
705
  self.edit_estimates = None
706
  self.sem_guidance = None
707
  self.activation_mask = None
708
 
709
  for i, t in enumerate(self.progress_bar(timesteps, verbose=verbose)):
 
 
 
710
  # expand the latents if we are doing classifier free guidance
711
 
712
+ if enable_edit_guidance:
713
+ latent_model_input = torch.cat([latents] * (1 + self.enabled_editing_prompts))
714
  else:
715
  latent_model_input = latents
716
 
 
721
  # predict the noise residual
722
  noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embed_input).sample
723
 
 
 
724
 
725
+ noise_pred_out = noise_pred.chunk(1 + self.enabled_editing_prompts) # [b,4, 64, 64]
726
+ noise_pred_uncond = noise_pred_out[0]
727
+ noise_pred_edit_concepts = noise_pred_out[1:]
728
 
729
+ # default text guidance
730
+ noise_guidance = torch.zeros_like(noise_pred_uncond)
731
 
732
+ if self.uncond_estimates is None:
733
+ self.uncond_estimates = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
734
+ self.uncond_estimates[i] = noise_pred_uncond.detach().cpu()
735
 
736
+ if sem_guidance is not None and len(sem_guidance) > i:
737
+ edit_guidance = sem_guidance[i].to(self.device)
738
+ noise_guidance = noise_guidance + edit_guidance
739
 
740
+ elif enable_edit_guidance:
741
+ if self.activation_mask is None:
742
+ self.activation_mask = torch.zeros(
743
+ (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
744
+ )
745
+ if self.edit_estimates is None and enable_edit_guidance:
746
+ self.edit_estimates = torch.zeros(
747
+ (len(timesteps), len(noise_pred_edit_concepts), *noise_pred_edit_concepts[0].shape)
748
+ )
749
 
750
+ if self.sem_guidance is None:
751
+ self.sem_guidance = torch.zeros((len(timesteps), *noise_pred_uncond.shape))
 
752
 
753
+ concept_weights = torch.zeros(
754
+ (len(noise_pred_edit_concepts), noise_guidance.shape[0]),
755
+ device=self.device,
756
+ dtype=noise_guidance.dtype,
757
+ )
758
+ noise_guidance_edit = torch.zeros(
759
+ (len(noise_pred_edit_concepts), *noise_guidance.shape),
760
+ device=self.device,
761
+ dtype=noise_guidance.dtype,
762
+ )
763
+ warmup_inds = []
764
+ # noise_guidance_edit = torch.zeros_like(noise_guidance)
765
+ for c, noise_pred_edit_concept in enumerate(noise_pred_edit_concepts):
766
+ self.edit_estimates[i, c] = noise_pred_edit_concept
767
+ if isinstance(edit_warmup_steps, list):
768
+ edit_warmup_steps_c = edit_warmup_steps[c]
769
+ else:
770
+ edit_warmup_steps_c = edit_warmup_steps
771
+ if i >= edit_warmup_steps_c:
772
+ warmup_inds.append(c)
773
+ else:
774
+ continue
775
+
776
+ if isinstance(edit_guidance_scale, list):
777
+ edit_guidance_scale_c = edit_guidance_scale[c]
778
+ else:
779
+ edit_guidance_scale_c = edit_guidance_scale
780
+
781
+ if isinstance(edit_threshold, list):
782
+ edit_threshold_c = edit_threshold[c]
783
+ else:
784
+ edit_threshold_c = edit_threshold
785
+ if isinstance(reverse_editing_direction, list):
786
+ reverse_editing_direction_c = reverse_editing_direction[c]
787
+ else:
788
+ reverse_editing_direction_c = reverse_editing_direction
789
+ if edit_weights:
790
+ edit_weight_c = edit_weights[c]
791
+ else:
792
+ edit_weight_c = 1.0
793
+
794
+ if isinstance(edit_cooldown_steps, list):
795
+ edit_cooldown_steps_c = edit_cooldown_steps[c]
796
+ elif edit_cooldown_steps is None:
797
+ edit_cooldown_steps_c = i + 1
798
+ else:
799
+ edit_cooldown_steps_c = edit_cooldown_steps
800
+
801
+ if i >= edit_cooldown_steps_c:
802
+ noise_guidance_edit[c, :, :, :, :] = torch.zeros_like(noise_pred_edit_concept)
803
+ continue
804
+
805
+ noise_guidance_edit_tmp = noise_pred_edit_concept - noise_pred_uncond
806
+ # tmp_weights = (noise_pred_text - noise_pred_edit_concept).sum(dim=(1, 2, 3))
807
+ tmp_weights = (noise_guidance - noise_pred_edit_concept).sum(dim=(1, 2, 3))
808
+
809
+ tmp_weights = torch.full_like(tmp_weights, edit_weight_c) # * (1 / enabled_editing_prompts)
810
+ if reverse_editing_direction_c:
811
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * -1
812
+ concept_weights[c, :] = tmp_weights
813
+
814
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * edit_guidance_scale_c
815
+
816
+ if user_mask is not None:
817
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
818
+
819
+ if use_cross_attn_mask:
820
+ out = self.attention_store.aggregate_attention(
821
+ attention_maps=self.attention_store.step_store,
822
+ prompts=self.text_cross_attention_maps,
823
+ res=16,
824
+ from_where=["up", "down"],
825
+ is_cross=True,
826
+ select=self.text_cross_attention_maps.index(editing_prompt[c]),
827
  )
828
+ attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
829
 
830
+ # average over all tokens
831
+ assert (attn_map.shape[3] == num_edit_tokens[c])
832
+ attn_map = torch.sum(attn_map, dim=3)
833
 
834
+ # gaussian_smoothing
835
+ attn_map = F.pad(attn_map.unsqueeze(1), (1, 1, 1, 1), mode="reflect")
836
+ attn_map = self.smoothing(attn_map).squeeze(1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
837
 
838
+ # create binary mask
839
+ if attn_map.dtype == torch.float32:
840
+ tmp = torch.quantile(attn_map.flatten(start_dim=1), edit_threshold_c, dim=1)
 
841
  else:
842
+ tmp = torch.quantile(attn_map.flatten(start_dim=1).to(torch.float32), edit_threshold_c, dim=1).to(attn_map.dtype)
843
+ attn_mask = torch.where(attn_map >= tmp.unsqueeze(1).unsqueeze(1).repeat(1,16,16), 1.0, 0.0)
844
+
845
+ # resolution must match latent space dimension
846
+ attn_mask = F.interpolate(
847
+ attn_mask.unsqueeze(1),
848
+ noise_guidance_edit_tmp.shape[-2:] # 64,64
849
+ ).repeat(1, 4, 1, 1)
850
+ self.activation_mask[i, c] = attn_mask.detach().cpu()
851
+ if not use_intersect_mask:
852
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * attn_mask
853
+
854
+ if use_intersect_mask:
855
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
856
+ noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
857
+ keepdim=True)
858
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
859
+
860
+ # torch.quantile function expects float32
861
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
862
+ tmp = torch.quantile(
863
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
864
+ edit_threshold_c,
865
+ dim=2,
866
+ keepdim=False,
 
 
 
 
867
  )
868
+ else:
869
+ tmp = torch.quantile(
870
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
871
+ edit_threshold_c,
872
+ dim=2,
873
+ keepdim=False,
874
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
875
+
876
+ intersect_mask = torch.where(
877
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
878
+ torch.ones_like(noise_guidance_edit_tmp),
879
+ torch.zeros_like(noise_guidance_edit_tmp),
880
+ ) * attn_mask
881
+
882
+ self.activation_mask[i, c] = intersect_mask.detach().cpu()
883
+
884
+ noise_guidance_edit_tmp = noise_guidance_edit_tmp * intersect_mask
885
+
886
+ elif not use_cross_attn_mask:
887
+ # calculate quantile
888
+ noise_guidance_edit_tmp_quantile = torch.abs(noise_guidance_edit_tmp)
889
+ noise_guidance_edit_tmp_quantile = torch.sum(noise_guidance_edit_tmp_quantile, dim=1,
890
+ keepdim=True)
891
+ noise_guidance_edit_tmp_quantile = noise_guidance_edit_tmp_quantile.repeat(1, 4, 1, 1)
892
+
893
+ # torch.quantile function expects float32
894
+ if noise_guidance_edit_tmp_quantile.dtype == torch.float32:
895
+ tmp = torch.quantile(
896
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2),
897
+ edit_threshold_c,
898
+ dim=2,
899
+ keepdim=False,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
  )
901
+ else:
902
+ tmp = torch.quantile(
903
+ noise_guidance_edit_tmp_quantile.flatten(start_dim=2).to(torch.float32),
904
+ edit_threshold_c,
905
+ dim=2,
906
+ keepdim=False,
907
+ ).to(noise_guidance_edit_tmp_quantile.dtype)
908
+
909
+ self.activation_mask[i, c] = torch.where(
910
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
911
+ torch.ones_like(noise_guidance_edit_tmp),
912
+ torch.zeros_like(noise_guidance_edit_tmp),
913
+ ).detach().cpu()
914
+
915
+ noise_guidance_edit_tmp = torch.where(
916
+ noise_guidance_edit_tmp_quantile >= tmp[:, :, None, None],
917
+ noise_guidance_edit_tmp,
918
+ torch.zeros_like(noise_guidance_edit_tmp),
 
 
919
  )
 
 
920
 
921
+ noise_guidance_edit[c, :, :, :, :] = noise_guidance_edit_tmp
922
 
923
+ warmup_inds = torch.tensor(warmup_inds).to(self.device)
924
+ concept_weights = torch.index_select(concept_weights, 0, warmup_inds)
925
+ concept_weights = torch.where(
926
+ concept_weights < 0, torch.zeros_like(concept_weights), concept_weights
927
+ )
 
 
 
 
 
 
 
928
 
929
+ concept_weights = torch.nan_to_num(concept_weights)
930
 
931
+ noise_guidance_edit = torch.einsum("cb,cbijk->bijk", concept_weights, noise_guidance_edit)
932
 
933
+ noise_guidance = noise_guidance + noise_guidance_edit
934
+ self.sem_guidance[i] = noise_guidance_edit.detach().cpu()
 
935
 
936
+ noise_pred = noise_pred_uncond + noise_guidance
937
 
938
  # compute the previous noisy sample x_t -> x_t-1
939
  if use_ddpm:
 
941
  latents = self.scheduler.step(noise_pred, t, latents, variance_noise=zs[idx],
942
  **extra_step_kwargs).prev_sample
943
 
944
+ else: # if not use_ddpm:
945
  latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
946
 
947
  # step callback
 
1001
  source_prompt: str = "",
1002
  source_guidance_scale=3.5,
1003
  num_inversion_steps: int = 30,
1004
+ skip: int = 15,
1005
  eta: float = 1.0,
1006
  generator: Optional[torch.Generator] = None,
1007
  verbose=True,
 
1018
  # self.eta = eta
1019
  # assert (self.eta > 0)
1020
  skip = skip/100
1021
+
1022
  train_steps = self.scheduler.config.num_train_timesteps
1023
  timesteps = torch.from_numpy(
1024
  np.linspace(train_steps - skip * train_steps - 1, 1, num_inversion_steps).astype(np.int64)).to(self.device)
 
1027
  self.num_inversion_steps = timesteps.shape[0]
1028
  self.scheduler.num_inference_steps = timesteps.shape[0]
1029
  self.scheduler.timesteps = timesteps
1030
+
 
 
 
1031
 
1032
  # 1. get embeddings
1033
 
 
1043
  # autoencoder reconstruction
1044
  # image_rec = self.vae.decode(x0 / self.vae.config.scaling_factor, return_dict=False)[0]
1045
  # image_rec = self.image_processor.postprocess(image_rec, output_type="pil")
1046
+
1047
  # 3. find zs and xts
1048
  variance_noise_shape = (
1049
  self.num_inversion_steps,
 
1093
  # self.zs = zs
1094
 
1095
 
1096
+
1097
  return zs, xts
 
1098
 
1099
  @torch.no_grad()
1100
  def encode_image(self, image_path, dtype=None):