liuyizhang commited on
Commit
63e6e86
1 Parent(s): b269211

update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -49
app.py CHANGED
@@ -434,8 +434,7 @@ def concatenate_images_vertical(image1, image2):
434
 
435
  return new_image
436
 
437
- def relate_anything(input_image_mask, k):
438
- input_image = input_image_mask['image']
439
  logger.info(f'relate_anything_1_{input_image.size}_')
440
  w, h = input_image.size
441
  max_edge = 1500
@@ -478,15 +477,17 @@ def relate_anything(input_image_mask, k):
478
  concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
479
  pil_image_list.append(concate_pil_image)
480
 
481
- logger.info(f'relate_anything_5_')
482
- yield pil_image_list
483
-
484
 
485
  mask_source_draw = "draw a mask on input image"
486
  mask_source_segment = "type what to detect below"
487
 
488
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
489
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
 
 
 
 
490
  text_prompt = text_prompt.strip()
491
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
492
  if text_prompt == '':
@@ -510,7 +511,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
510
  size = image_pil.size
511
 
512
  output_images = []
513
- # output_images.append(input_image['image'])
514
  # run grounding dino model
515
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
516
  pass
@@ -538,11 +539,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
538
  "labels": pred_phrases,
539
  }
540
  image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
541
- image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
542
- image_with_box.save(image_path)
543
- detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
544
- os.remove(image_path)
545
- output_images.append(detection_image_result)
 
546
 
547
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
548
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
@@ -600,13 +602,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
600
  mask = masks[0][0].cpu().numpy()
601
  mask_pil = Image.fromarray(mask)
602
 
603
- image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
604
- # if reverse_mask:
605
- # mask_pil = mask_pil.point(lambda _: 255-_)
606
- mask_pil.convert("RGB").save(image_path)
607
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
608
- os.remove(image_path)
609
- output_images.append(image_result)
610
 
611
  if task_type == 'inpainting':
612
  # inpainting pipeline
@@ -645,24 +646,23 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
645
  mask_imgs.append(mask_pil_exp)
646
  mask_pil = mix_masks(mask_imgs)
647
 
648
- image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
649
- # if reverse_mask:
650
- # mask_pil = mask_pil.point(lambda _: 255-_)
651
- mask_pil.convert("RGB").save(image_path)
652
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
653
- os.remove(image_path)
654
- output_images.append(image_result)
655
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
656
 
657
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
658
 
659
- image_path = os.path.join(output_dir, f"grounded_sam_inpainting_output_{file_temp}.jpg")
660
- image_inpainting.save(image_path)
661
- image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
662
- os.remove(image_path)
663
- logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9__{type(image_inpainting)} / {type(image_result)}')
664
- output_images.append(image_inpainting)
665
  # output_images.append(image_result)
 
666
  return output_images, gr.Gallery.update(label='result images')
667
  else:
668
  logger.info(f"task_type:{task_type} error!")
@@ -674,10 +674,10 @@ def change_radio_display(task_type, mask_source_radio):
674
  inpaint_prompt_visible = False
675
  mask_source_radio_visible = False
676
  num_relation_visible = False
677
- run_button_visible = True
678
- relate_all_button_visible = False
679
- gsa_gallery_visible = True
680
- ram_gallery_visible = False
681
  if task_type == "inpainting":
682
  inpaint_prompt_visible = True
683
  if task_type == "inpainting" or task_type == "remove":
@@ -687,11 +687,12 @@ def change_radio_display(task_type, mask_source_radio):
687
  if task_type == "relate anything":
688
  text_prompt_visible = False
689
  num_relation_visible = True
690
- run_button_visible = False
691
- relate_all_button_visible = True
692
- gsa_gallery_visible = False
693
- ram_gallery_visible = True
694
- return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible), gr.Button.update(visible=run_button_visible), gr.Button.update(visible=relate_all_button_visible), gr.Gallery.update(visible=gsa_gallery_visible), gr.Gallery.update(visible=ram_gallery_visible)
 
695
 
696
  if __name__ == "__main__":
697
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
@@ -715,7 +716,7 @@ if __name__ == "__main__":
715
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
716
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
717
  run_button = gr.Button(label="Run", visible=True)
718
- relate_all_button = gr.Button(label="Run", visible=False)
719
  with gr.Accordion("Advanced options", open=False) as advanced_options:
720
  box_threshold = gr.Slider(
721
  label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
@@ -734,17 +735,19 @@ if __name__ == "__main__":
734
  remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
735
 
736
  with gr.Column():
737
- gsa_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gsa_allery", visible=True
738
- ).style(preview=True, grid=[2], full_width=True, full_height=True)
739
- ram_gallery = gr.Gallery(label="Your Result", show_label=True, elem_id="ram_gallery", visible=False
740
- ).style(preview=True, columns=5, object_fit="scale-down")
 
 
741
 
742
  run_button.click(fn=run_anything_task, inputs=[
743
- input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[gsa_gallery, gsa_gallery], show_progress=True, queue=True)
744
- relate_all_button.click(fn=relate_anything, inputs=[input_image, num_relation], outputs=[ram_gallery], show_progress=True, queue=True)
745
 
746
- task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation, run_button, relate_all_button, gsa_gallery, ram_gallery])
747
- mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation, run_button, relate_all_button, gsa_gallery, ram_gallery])
748
 
749
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
750
  DESCRIPTION += 'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'
 
434
 
435
  return new_image
436
 
437
+ def relate_anything(input_image, k):
 
438
  logger.info(f'relate_anything_1_{input_image.size}_')
439
  w, h = input_image.size
440
  max_edge = 1500
 
477
  concate_pil_image = concatenate_images_vertical(current_pil_image, title_image)
478
  pil_image_list.append(concate_pil_image)
479
 
480
+ return pil_image_list
 
 
481
 
482
  mask_source_draw = "draw a mask on input image"
483
  mask_source_segment = "type what to detect below"
484
 
485
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
486
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation):
487
+ if (task_type == 'relate anything'):
488
+ output_images = relate_anything(input_image['image'], num_relation)
489
+ return output_images, gr.Gallery.update(label='relate images')
490
+
491
  text_prompt = text_prompt.strip()
492
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
493
  if text_prompt == '':
 
511
  size = image_pil.size
512
 
513
  output_images = []
514
+ output_images.append(input_image['image'])
515
  # run grounding dino model
516
  if (task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw:
517
  pass
 
539
  "labels": pred_phrases,
540
  }
541
  image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
542
+ # image_path = os.path.join(output_dir, f"grounding_dino_output_{file_temp}.jpg")
543
+ # image_with_box.save(image_path)
544
+ # detection_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
545
+ # os.remove(image_path)
546
+ # output_images.append(detection_image_result)
547
+ output_images.append(image_with_box)
548
 
549
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
550
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
 
602
  mask = masks[0][0].cpu().numpy()
603
  mask_pil = Image.fromarray(mask)
604
 
605
+ # image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
606
+ # mask_pil.convert("RGB").save(image_path)
607
+ # image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
608
+ # os.remove(image_path)
609
+ # output_images.append(image_result)
610
+ output_images.append(mask_pil.convert("RGB"))
 
611
 
612
  if task_type == 'inpainting':
613
  # inpainting pipeline
 
646
  mask_imgs.append(mask_pil_exp)
647
  mask_pil = mix_masks(mask_imgs)
648
 
649
+ # image_path = os.path.join(output_dir, f"image_mask_{file_temp}.jpg")
650
+ # mask_pil.convert("RGB").save(image_path)
651
+ # image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
652
+ # os.remove(image_path)
653
+ # output_images.append(image_result)
654
+ output_images.append(mask_pil.convert("RGB"))
 
655
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")))
656
 
657
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
658
 
659
+ # image_path = os.path.join(output_dir, f"grounded_sam_inpainting_output_{file_temp}.jpg")
660
+ # image_inpainting.save(image_path)
661
+ # image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
662
+ # os.remove(image_path)
663
+ # logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
 
664
  # output_images.append(image_result)
665
+ output_images.append(image_inpainting)
666
  return output_images, gr.Gallery.update(label='result images')
667
  else:
668
  logger.info(f"task_type:{task_type} error!")
 
674
  inpaint_prompt_visible = False
675
  mask_source_radio_visible = False
676
  num_relation_visible = False
677
+ # run_button_visible = True
678
+ # relate_all_button_visible = False
679
+ # gsa_gallery_visible = True
680
+ # ram_gallery_visible = False
681
  if task_type == "inpainting":
682
  inpaint_prompt_visible = True
683
  if task_type == "inpainting" or task_type == "remove":
 
687
  if task_type == "relate anything":
688
  text_prompt_visible = False
689
  num_relation_visible = True
690
+ # run_button_visible = False
691
+ # relate_all_button_visible = True
692
+ # gsa_gallery_visible = False
693
+ # ram_gallery_visible = True
694
+ return gr.Textbox.update(visible=text_prompt_visible), gr.Textbox.update(visible=inpaint_prompt_visible), gr.Radio.update(visible=mask_source_radio_visible), gr.Slider.update(visible=num_relation_visible)
695
+ #, gr.Button.update(visible=run_button_visible), gr.Button.update(visible=relate_all_button_visible), gr.Gallery.update(visible=gsa_gallery_visible), gr.Gallery.update(visible=ram_gallery_visible)
696
 
697
  if __name__ == "__main__":
698
  parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
 
716
  inpaint_prompt = gr.Textbox(label="Inpaint Prompt (if this is empty, then remove)", visible=False)
717
  num_relation = gr.Slider(label="How many relations do you want to see", minimum=1, maximum=20, value=5, step=1, visible=False)
718
  run_button = gr.Button(label="Run", visible=True)
719
+ # relate_all_button = gr.Button(label="Run", visible=False)
720
  with gr.Accordion("Advanced options", open=False) as advanced_options:
721
  box_threshold = gr.Slider(
722
  label="Box Threshold", minimum=0.0, maximum=1.0, value=0.3, step=0.001
 
735
  remove_mask_extend = gr.Textbox(label="remove_mask_extend", value='10')
736
 
737
  with gr.Column():
738
+ image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gsa_allery", visible=True
739
+ ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
740
+ # gsa_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gsa_allery", visible=True
741
+ # ).style(preview=True, grid=[2], full_width=True, full_height=True)
742
+ # ram_gallery = gr.Gallery(label="Your Result", show_label=True, elem_id="ram_gallery", visible=False
743
+ # ).style(preview=True, columns=5, object_fit="scale-down")
744
 
745
  run_button.click(fn=run_anything_task, inputs=[
746
+ input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation], outputs=[image_gallery, image_gallery], show_progress=True, queue=True)
747
+ # relate_all_button.click(fn=relate_anything, inputs=[input_image, num_relation], outputs=[ram_gallery], show_progress=True, queue=True)
748
 
749
+ task_type.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
750
+ mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio], outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
751
 
752
  DESCRIPTION = '### This demo from [Grounded-Segment-Anything](https://github.com/IDEA-Research/Grounded-Segment-Anything). <br>'
753
  DESCRIPTION += 'RAM from [RelateAnything](https://github.com/Luodian/RelateAnything). <br>'