Linoy Tsaban committed on
Commit
21466b8
1 Parent(s): ec15161

Update app.py

Browse files

attention store

Files changed (1) hide show
  1. app.py +18 -15
app.py CHANGED
@@ -35,9 +35,9 @@ def caption_image(input_image):
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
- def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
- img = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
@@ -45,9 +45,10 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
 
48
  zs=zs,
49
- ).images[0]
50
- return img
51
 
52
 
53
  def reconstruct(
@@ -57,6 +58,7 @@ def reconstruct(
57
  skip,
58
  wts,
59
  zs,
 
60
  do_reconstruction,
61
  reconstruction,
62
  reconstruct_button,
@@ -77,8 +79,8 @@ def reconstruct(
77
  ): # if image caption was not changed, run actual reconstruction
78
  tar_prompt = ""
79
  latents = wts[-1].expand(1, -1, -1, -1)
80
- reconstruction = sample(
81
- zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
82
  )
83
  do_reconstruction = False
84
  return (
@@ -128,7 +130,7 @@ def load_and_invert(
128
  ## SEGA ##
129
 
130
  def edit(input_image,
131
- wts, zs,
132
  tar_prompt,
133
  image_caption,
134
  steps,
@@ -195,27 +197,27 @@ def edit(input_image,
195
  )
196
 
197
  latnets = wts[-1].expand(1, -1, -1, -1)
198
- sega_out = pipe(prompt=tar_prompt,
199
  init_latents=latnets,
200
  guidance_scale = tar_cfg_scale,
201
  # num_images_per_prompt=1,
202
  # num_inference_steps=steps,
203
  # use_ddpm=True,
204
  # wts=wts.value,
205
- zs=zs, **editing_args)
206
 
207
- return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
208
 
209
 
210
  else: # if sega concepts were not added, performs regular ddpm sampling
211
 
212
  if do_reconstruction: # if ddpm sampling wasn't computed
213
- pure_ddpm_img = sample(zs, wts, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
214
  reconstruction = pure_ddpm_img
215
  do_reconstruction = False
216
- return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
217
 
218
- return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, do_inversion, show_share_button
219
 
220
 
221
  def randomize_seed_fn(seed, is_random):
@@ -458,6 +460,7 @@ with gr.Blocks(css="style.css") as demo:
458
  gr.HTML(intro)
459
  wts = gr.State()
460
  zs = gr.State()
 
461
  reconstruction = gr.State()
462
  do_inversion = gr.State(value=True)
463
  do_reconstruction = gr.State(value=True)
@@ -693,7 +696,7 @@ with gr.Blocks(css="style.css") as demo:
693
  run_button.click(
694
  fn=edit,
695
  inputs=[input_image,
696
- wts, zs,
697
  tar_prompt,
698
  image_caption,
699
  steps,
@@ -713,7 +716,7 @@ with gr.Blocks(css="style.css") as demo:
713
 
714
 
715
  ],
716
- outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs, do_inversion, share_btn_container])
717
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
718
 
719
 
 
35
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
  return generated_caption, generated_caption
37
 
38
+ def sample(zs, wts, attention_store, prompt_tar="", cfg_scale_tar=15, skip=36, eta=1):
39
  latents = wts[-1].expand(1, -1, -1, -1)
40
+ img, attention_store = pipe(
41
  prompt=prompt_tar,
42
  init_latents=latents,
43
  guidance_scale=cfg_scale_tar,
 
45
  # num_inference_steps=steps,
46
  # use_ddpm=True,
47
  # wts=wts.value,
48
+ attention_store = attention_store,
49
  zs=zs,
50
+ )
51
+ return img.images[0], attention_store
52
 
53
 
54
  def reconstruct(
 
58
  skip,
59
  wts,
60
  zs,
61
+ attention_store,
62
  do_reconstruction,
63
  reconstruction,
64
  reconstruct_button,
 
79
  ): # if image caption was not changed, run actual reconstruction
80
  tar_prompt = ""
81
  latents = wts[-1].expand(1, -1, -1, -1)
82
+ reconstruction, attention_store = sample(
83
+ zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale
84
  )
85
  do_reconstruction = False
86
  return (
 
130
  ## SEGA ##
131
 
132
  def edit(input_image,
133
+ wts, zs, attention_store,
134
  tar_prompt,
135
  image_caption,
136
  steps,
 
197
  )
198
 
199
  latnets = wts[-1].expand(1, -1, -1, -1)
200
+ sega_out, attention_store = pipe(prompt=tar_prompt,
201
  init_latents=latnets,
202
  guidance_scale = tar_cfg_scale,
203
  # num_images_per_prompt=1,
204
  # num_inference_steps=steps,
205
  # use_ddpm=True,
206
  # wts=wts.value,
207
+ zs=zs, attention_store=attention_store, **editing_args)
208
 
209
+ return sega_out.images[0], gr.update(visible=True), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
210
 
211
 
212
  else: # if sega concepts were not added, performs regular ddpm sampling
213
 
214
  if do_reconstruction: # if ddpm sampling wasn't computed
215
+ pure_ddpm_img, attention_store = sample(zs, wts, attention_store=attention_store, prompt_tar=tar_prompt, skip=skip, cfg_scale_tar=tar_cfg_scale)
216
  reconstruction = pure_ddpm_img
217
  do_reconstruction = False
218
+ return pure_ddpm_img, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
219
 
220
+ return reconstruction, gr.update(visible=False), do_reconstruction, reconstruction, wts, zs, attention_store, do_inversion, show_share_button
221
 
222
 
223
  def randomize_seed_fn(seed, is_random):
 
460
  gr.HTML(intro)
461
  wts = gr.State()
462
  zs = gr.State()
463
+ attention_store=gr.State()
464
  reconstruction = gr.State()
465
  do_inversion = gr.State(value=True)
466
  do_reconstruction = gr.State(value=True)
 
696
  run_button.click(
697
  fn=edit,
698
  inputs=[input_image,
699
+ wts, zs, attention_store,
700
  tar_prompt,
701
  image_caption,
702
  steps,
 
716
 
717
 
718
  ],
719
+ outputs=[sega_edited_image, reconstruct_button, do_reconstruction, reconstruction, wts, zs,attention_store, do_inversion, share_btn_container])
720
  # .success(fn=update_gallery_display, inputs= [prev_output_image, sega_edited_image], outputs = [gallery, gallery, prev_output_image])
721
 
722