Update pipeline_semantic_stable_diffusion_img2img_solver.py

#8
by linoyts HF staff - opened
app.py CHANGED
@@ -35,9 +35,9 @@ def caption_image(input_image):
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
- def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
- img = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
@@ -45,9 +45,10 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
 
48
  zs=zs,
49
- ).images[0]
50
- return img
51
 
52
 
53
  def reconstruct(
@@ -57,6 +58,7 @@ def reconstruct(
57
  skip,
58
  wts,
59
  zs,
 
60
  do_reconstruction,
61
  reconstruction,
62
  reconstruct_button,
@@ -77,8 +79,8 @@ def reconstruct(
77
  ): # if image caption was not changed, run actual reconstruction
78
  tar_prompt = ""
79
  latents = wts[-1].expand(1, -1, -1, -1)
80
- reconstruction = sample(
81
- zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
82
  )
83
  do_reconstruction = False
84
  return (
@@ -128,7 +130,7 @@ def load_and_invert(
128
  ## SEGA ##
129
 
130
  def edit(input_image,
131
- wts, zs,
132
  tar_prompt,
133
  image_caption,
134
  steps,
@@ -195,27 +197,27 @@ def edit(input_image,
195
  )
196
 
197
  latnets = wts[-1].expand(1, -1, -1, -1)
198
- sega_out = pipe(prompt=tar_prompt,
199
  init_latents=latnets,
200
  guidance_scale = tar_cfg_scale,
201
  # num_images_per_prompt=1,
202
  # num_inference_steps=steps,
203
  # use_ddpm=True,
204
  # wts=wts.value,
205
- zs=zs, **editing_args)
206
 
207
- return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
208
 
209
 
210
  else: # if sega concepts were not added, performs regular ddpm sampling
211
 
212
  if do_reconstruction: # if ddpm sampling wasn't computed
213
- pure_ddpm_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
214
  reconstruction = pure_ddpm_img
215
  do_reconstruction = False
216
- return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
217
 
218
- return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
219
 
220
 
221
  def randomize_seed_fn(seed, is_random):
@@ -458,6 +460,7 @@ with gr.Blocks(css="style.css") as demo:
458
  gr.HTML(intro)
459
  wts = gr.State()
460
  zs = gr.State()
 
461
  reconstruction = gr.State()
462
  do_inversion = gr.State(value=True)
463
  do_reconstruction = gr.State(value=True)
@@ -693,7 +696,7 @@ with gr.Blocks(css="style.css") as demo:
693
  run_button.click(
694
  fn=edit,
695
  inputs=[input_image,
696
- wts, zs,
697
  tar_prompt,
698
  image_caption,
699
  steps,
@@ -713,7 +716,7 @@ with gr.Blocks(css="style.css") as demo:
713
 
714
 
715
  ],
716
- outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs, do_inversion, share_btn_container])
717
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
718
 
719
 
 
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
+ def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
+ img, attention_store = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
 
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
48
+ attention_store = attention_store,
49
  zs=zs,
50
+ )
51
+ return img.images[0], attention_store
52
 
53
 
54
  def reconstruct(
 
58
  skip,
59
  wts,
60
  zs,
61
+ attention_store,
62
  do_reconstruction,
63
  reconstruction,
64
  reconstruct_button,
 
79
  ): # if image caption was not changed, run actual reconstruction
80
  tar_prompt = ""
81
  latents = wts[-1].expand(1, -1, -1, -1)
82
+ reconstruction, attention_store = sample(
83
+ zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
84
  )
85
  do_reconstruction = False
86
  return (
 
130
  ## SEGA ##
131
 
132
  def edit(input_image,
133
+ wts, zs, attention_store,
134
  tar_prompt,
135
  image_caption,
136
  steps,
 
197
  )
198
 
199
  latnets = wts[-1].expand(1, -1, -1, -1)
200
+ sega_out, attention_store = pipe(prompt=tar_prompt,
201
  init_latents=latnets,
202
  guidance_scale = tar_cfg_scale,
203
  # num_images_per_prompt=1,
204
  # num_inference_steps=steps,
205
  # use_ddpm=True,
206
  # wts=wts.value,
207
+ zs=zs, attention_store=attention_store, **editing_args)
208
 
209
+ return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
210
 
211
 
212
  else: # if sega concepts were not added, performs regular ddpm sampling
213
 
214
  if do_reconstruction: # if ddpm sampling wasn't computed
215
+ pure_ddpm_img, attention_store = sample(zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
216
  reconstruction = pure_ddpm_img
217
  do_reconstruction = False
218
+ return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
219
 
220
+ return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
221
 
222
 
223
  def randomize_seed_fn(seed, is_random):
 
460
  gr.HTML(intro)
461
  wts = gr.State()
462
  zs = gr.State()
463
+ attention_store=gr.State()
464
  reconstruction = gr.State()
465
  do_inversion = gr.State(value=True)
466
  do_reconstruction = gr.State(value=True)
 
696
  run_button.click(
697
  fn=edit,
698
  inputs=[input_image,
699
+ wts, zs, attention_store,
700
  tar_prompt,
701
  image_caption,
702
  steps,
 
716
 
717
 
718
  ],
719
+ outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, do_inversion, share_btn_container])
720
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
721
 
722
 
pipeline_semantic_stable_diffusion_img2img_solver.py CHANGED
@@ -499,6 +499,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
499
  verbose=True,
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
 
502
  attn_store_steps: Optional[List[int]] = [],
503
  store_averaged_over_steps: bool = True,
504
  use_intersect_mask: bool = False,
@@ -771,8 +772,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
771
  timesteps = timesteps[-zs.shape[0]:]
772
 
773
  if use_cross_attn_mask:
774
- self.attention_store = AttentionStore(average=store_averaged_over_steps, batch_size=batch_size)
775
- self.prepare_unet(self.attention_store, PnP=False)
776
  # 5. Prepare latent variables
777
  num_channels_latents = self.unet.config.in_channels
778
  latents = self.prepare_latents(
@@ -917,8 +918,8 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
917
  noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
918
 
919
  if use_cross_attn_mask:
920
- out = self.attention_store.aggregate_attention(
921
- attention_maps=self.attention_store.step_store,
922
  prompts=self.text_cross_attention_maps,
923
  res=16,
924
  from_where=["up", "down"],
@@ -1080,7 +1081,7 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1080
  store_step = i in attn_store_steps
1081
  if store_step:
1082
  print(f"storing attention for step {i}")
1083
- self.attention_store.between_steps(store_step)
1084
 
1085
  # call the callback, if provided
1086
  if callback is not None and i % callback_steps == 0:
@@ -1102,9 +1103,9 @@ class SemanticStableDiffusionImg2ImgPipeline_DPMSolver(DiffusionPipeline):
1102
  image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1103
 
1104
  if not return_dict:
1105
- return (image, has_nsfw_concept)
1106
 
1107
- return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
1108
 
1109
  def encode_text(self, prompts):
1110
  text_inputs = self.tokenizer(
 
499
  verbose=True,
500
  use_cross_attn_mask: bool = False,
501
  # Attention store (just for visualization purposes)
502
+ attention_store = None,
503
  attn_store_steps: Optional[List[int]] = [],
504
  store_averaged_over_steps: bool = True,
505
  use_intersect_mask: bool = False,
 
772
  timesteps = timesteps[-zs.shape[0]:]
773
 
774
  if use_cross_attn_mask:
775
+ attention_store = AttentionStore(average=store_averaged_over_steps, batch_size=batch_size)
776
+ self.prepare_unet(attention_store, PnP=False)
777
  # 5. Prepare latent variables
778
  num_channels_latents = self.unet.config.in_channels
779
  latents = self.prepare_latents(
 
918
  noise_guidance_edit_tmp = noise_guidance_edit_tmp * user_mask
919
 
920
  if use_cross_attn_mask:
921
+ out = attention_store.aggregate_attention(
922
+ attention_maps=attention_store.step_store,
923
  prompts=self.text_cross_attention_maps,
924
  res=16,
925
  from_where=["up", "down"],
 
1081
  store_step = i in attn_store_steps
1082
  if store_step:
1083
  print(f"storing attention for step {i}")
1084
+ attention_store.between_steps(store_step)
1085
 
1086
  # call the callback, if provided
1087
  if callback is not None and i % callback_steps == 0:
 
1103
  image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
1104
 
1105
  if not return_dict:
1106
+ return (image, has_nsfw_concept), attention_store
1107
 
1108
+ return SemanticStableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept), attention_store
1109
 
1110
  def encode_text(self, prompts):
1111
  text_inputs = self.tokenizer(