Linoy Tsaban committed on
Commit
4065064
1 Parent(s): 45e73ca

Update pipeline_semantic_stable_diffusion_img2img_solver.py

Browse files
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -500,6 +500,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
502
  attention_store = None,
 
503
  attn_store_steps: Optional[List[int]] = [],
504
  store_averaged_over_steps: bool = True,
505
  use_intersect_mask: bool = False,
@@ -755,10 +756,10 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
755
  # For classifier free guidance, we need to do two forward passes.
756
  # Here we concatenate the unconditional and text embeddings into a single batch
757
  # to avoid doing two forward passes
758
- self.text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
759
  if enable_edit_guidance:
760
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
761
- self.text_cross_attention_maps += \
762
  ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
763
  else:
764
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
@@ -920,11 +921,11 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
920
  if use_cross_attn_mask:
921
  out = attention_store.aggregate_attention(
922
  attention_maps=attention_store.step_store,
923
- prompts=self.text_cross_attention_maps,
924
  res=16,
925
  from_where=["up", "down"],
926
  is_cross=True,
927
- select=self.text_cross_attention_maps.index(editing_prompt[c]),
928
  )
929
  attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
930
 
@@ -1105,7 +1106,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1105
  if not return_dict:
1106
  return (image, has_nsfw_concept), attention_store
1107
 
1108
- return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
1109
 
1110
  def encode_text(self, prompts):
1111
  text_inputs = self.tokenizer(
 
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
502
  attention_store = None,
503
+ text_cross_attention_maps = None,
504
  attn_store_steps: Optional[List[int]] = [],
505
  store_averaged_over_steps: bool = True,
506
  use_intersect_mask: bool = False,
 
756
  # For classifier free guidance, we need to do two forward passes.
757
  # Here we concatenate the unconditional and text embeddings into a single batch
758
  # to avoid doing two forward passes
759
+ text_cross_attention_maps = [org_prompt] if isinstance(org_prompt, str) else org_prompt
760
  if enable_edit_guidance:
761
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings, edit_concepts])
762
+ text_cross_attention_maps += \
763
  ([editing_prompt] if isinstance(editing_prompt, str) else editing_prompt)
764
  else:
765
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
 
921
  if use_cross_attn_mask:
922
  out = attention_store.aggregate_attention(
923
  attention_maps=attention_store.step_store,
924
+ prompts=text_cross_attention_maps,
925
  res=16,
926
  from_where=["up", "down"],
927
  is_cross=True,
928
+ select=text_cross_attention_maps.index(editing_prompt[c]),
929
  )
930
  attn_map = out[:, :, :, 1:1 + num_edit_tokens[c]] # 0 -> startoftext
931
 
 
1106
  if not return_dict:
1107
  return (image, has_nsfw_concept), attention_store
1108
 
1109
+ return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store, text_cross_attention_maps
1110
 
1111
  def encode_text(self, prompts):
1112
  text_inputs = self.tokenizer(