Linoy Tsaban commited on
Commit
d2ec8a3
1 Parent(s): af2b22f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -2
app.py CHANGED
@@ -31,7 +31,7 @@ def caption_image(input_image):
31
 
32
  generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
33
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
- return generated_caption
35
 
36
 
37
 
@@ -123,6 +123,7 @@ def load_and_invert(
123
  def edit(input_image,
124
  wts, zs,
125
  tar_prompt,
 
126
  steps,
127
  skip,
128
  tar_cfg_scale,
@@ -162,6 +163,8 @@ def edit(input_image,
162
  eta=1,)
163
 
164
  latnets = wts.value[skip].expand(1, -1, -1, -1)
 
 
165
  sega_out = sem_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
166
  num_images_per_prompt=1,
167
  num_inference_steps=steps,
@@ -426,6 +429,9 @@ with gr.Blocks(css="style.css") as demo:
426
  do_reconstruction = True
427
  return do_reconstruction
428
 
 
 
 
429
  def update_inversion_progress_visibility(input_image, do_inversion):
430
  if do_inversion and not input_image is None:
431
  return inversion_progress.update(visible=True)
@@ -446,6 +452,7 @@ with gr.Blocks(css="style.css") as demo:
446
  do_inversion = gr.State(value=True)
447
  do_reconstruction = gr.State(value=True)
448
  sega_concepts_counter = gr.State(0)
 
449
 
450
 
451
 
@@ -659,6 +666,7 @@ with gr.Blocks(css="style.css") as demo:
659
  inputs=[input_image,
660
  wts, zs,
661
  tar_prompt,
 
662
  steps,
663
  skip,
664
  tar_cfg_scale,
@@ -689,7 +697,7 @@ with gr.Blocks(css="style.css") as demo:
689
  outputs = [do_inversion],
690
  queue = False).then(fn = caption_image,
691
  inputs = [input_image],
692
- outputs = [tar_prompt]).then(fn = update_inversion_progress_visibility, inputs =[input_image,do_inversion],
693
  outputs=[inversion_progress],queue=False).then(
694
  fn=load_and_invert,
695
  inputs=[input_image,
 
31
 
32
  generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
33
  generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
+ return generated_caption, generated_caption
35
 
36
 
37
 
 
123
  def edit(input_image,
124
  wts, zs,
125
  tar_prompt,
126
+ image_caption,
127
  steps,
128
  skip,
129
  tar_cfg_scale,
 
163
  eta=1,)
164
 
165
  latnets = wts.value[skip].expand(1, -1, -1, -1)
166
+ if image_caption == tar_prompt:
167
+ tar_prompt = ""
168
  sega_out = sem_pipe(prompt=tar_prompt, latents=latnets, guidance_scale = tar_cfg_scale,
169
  num_images_per_prompt=1,
170
  num_inference_steps=steps,
 
429
  do_reconstruction = True
430
  return do_reconstruction
431
 
432
+ def reset_image_caption():
433
+ return ""
434
+
435
  def update_inversion_progress_visibility(input_image, do_inversion):
436
  if do_inversion and not input_image is None:
437
  return inversion_progress.update(visible=True)
 
452
  do_inversion = gr.State(value=True)
453
  do_reconstruction = gr.State(value=True)
454
  sega_concepts_counter = gr.State(0)
455
+ image_caption = gr.State(value="")
456
 
457
 
458
 
 
666
  inputs=[input_image,
667
  wts, zs,
668
  tar_prompt,
669
+ image_caption,
670
  steps,
671
  skip,
672
  tar_cfg_scale,
 
697
  outputs = [do_inversion],
698
  queue = False).then(fn = caption_image,
699
  inputs = [input_image],
700
+ outputs = [tar_prompt, image_caption]).then(fn = update_inversion_progress_visibility, inputs =[input_image,do_inversion],
701
  outputs=[inversion_progress],queue=False).then(
702
  fn=load_and_invert,
703
  inputs=[input_image,