Linoy Tsaban commited on
Commit
5c86655
1 Parent(s): f58d8b8

Update app.py

Browse files

fix auto-inversion and auto-caption when removing an image

Files changed (1) hide show
  1. app.py +12 -7
app.py CHANGED
@@ -27,12 +27,14 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
27
  ## IMAGE CPATIONING ##
28
  def caption_image(input_image):
29
 
30
- inputs = blip_processor(images=input_image, return_tensors="pt").to(device)
31
- pixel_values = inputs.pixel_values
 
32
 
33
- generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
34
- generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
35
- return generated_caption
 
36
 
37
 
38
  ## DDPM INVERSION AND SAMPLING ##
@@ -361,8 +363,10 @@ with gr.Blocks(css="style.css") as demo:
361
 
362
 
363
  def reset_do_inversion():
364
- do_inversion = True
365
- return do_inversion
 
 
366
 
367
  def reset_do_reconstruction():
368
  do_reconstruction = True
@@ -597,6 +601,7 @@ with gr.Blocks(css="style.css") as demo:
597
  # Automatically start inverting upon input_image change
598
  input_image.change(
599
  fn = reset_do_inversion,
 
600
  outputs = [do_inversion],
601
  queue = False).then(fn = caption_image,
602
  inputs = [input_image],
 
27
  ## IMAGE CPATIONING ##
28
  def caption_image(input_image):
29
 
30
+ if not input_image is None:
31
+ inputs = blip_processor(images=input_image, return_tensors="pt").to(device)
32
+ pixel_values = inputs.pixel_values
33
 
34
+ generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
35
+ generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
36
+ return generated_caption
37
+ return ""
38
 
39
 
40
  ## DDPM INVERSION AND SAMPLING ##
 
363
 
364
 
365
  def reset_do_inversion():
366
+ if not input_image is None:
367
+ return True
368
+ else:
369
+ return False
370
 
371
  def reset_do_reconstruction():
372
  do_reconstruction = True
 
601
  # Automatically start inverting upon input_image change
602
  input_image.change(
603
  fn = reset_do_inversion,
604
+ inputs = [input_image],
605
  outputs = [do_inversion],
606
  queue = False).then(fn = caption_image,
607
  inputs = [input_image],