liuyizhang commited on
Commit
2a71ebd
β€’
1 Parent(s): 5c28041

add time cost by step (ms)

Browse files
Files changed (3) hide show
  1. app.py +40 -12
  2. kosmos_utils.py +1 -1
  3. requirements.txt +1 -1
app.py CHANGED
@@ -519,24 +519,42 @@ def relate_anything(input_image, k):
519
  mask_source_draw = "draw a mask on input image"
520
  mask_source_segment = "type what to detect below"
521
 
 
 
 
 
 
 
 
 
 
 
 
522
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
523
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
 
 
 
 
 
524
  if (task_type == 'Kosmos-2'):
525
  global kosmos_model, kosmos_processor
526
  kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
527
- return None, None, kosmos_image, kosmos_text, kosmos_entities
 
528
 
529
  if (task_type == 'relate anything'):
530
  output_images = relate_anything(input_image['image'], num_relation)
531
- return output_images, gr.Gallery.update(label='relate images'), None, None, None
 
532
 
533
  text_prompt = text_prompt.strip()
534
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
535
  if text_prompt == '':
536
- return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), None, None, None
537
 
538
  if input_image is None:
539
- return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), None, None, None
540
 
541
  file_temp = int(time.time())
542
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
@@ -552,10 +570,12 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
552
  image_pil, image = load_image(input_image['image'].convert("RGB"))
553
  input_img = input_image['image']
554
  output_images.append(input_image['image'])
 
555
  else:
556
  image_pil, image = load_image(input_image.convert("RGB"))
557
  input_img = input_image
558
  output_images.append(input_image)
 
559
 
560
  size = image_pil.size
561
 
@@ -576,7 +596,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
576
  )
577
  if boxes_filt.size(0) == 0:
578
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
579
- return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), None, None, None
580
  boxes_filt_ori = copy.deepcopy(boxes_filt)
581
 
582
  pred_dict = {
@@ -587,6 +607,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
587
 
588
  image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
589
  output_images.append(image_with_box)
 
590
 
591
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
592
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
@@ -622,12 +643,13 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
622
  plt.savefig(image_path, bbox_inches="tight")
623
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
624
  os.remove(image_path)
625
- output_images.append(segment_image_result)
 
626
 
627
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
628
  if task_type == 'detection' or task_type == 'segment':
629
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
630
- return output_images, gr.Gallery.update(label='result images'), None, None, None
631
  elif task_type == 'inpainting' or task_type == 'remove':
632
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
633
  task_type = 'remove'
@@ -644,6 +666,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
644
  mask = masks[0][0].cpu().numpy()
645
  mask_pil = Image.fromarray(mask)
646
  output_images.append(mask_pil.convert("RGB"))
 
647
 
648
  if task_type == 'inpainting':
649
  # inpainting pipeline
@@ -682,21 +705,24 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
682
  extend_pixels=remove_mask_extend, useRectangle=useRectangle)
683
  mask_imgs.append(mask_pil_exp)
684
  mask_pil = mix_masks(mask_imgs)
685
- output_images.append(mask_pil.convert("RGB"))
 
686
 
687
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
688
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
689
  # output_images.append(image_inpainting)
 
690
 
691
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
692
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
693
  output_images.append(image_inpainting)
 
694
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
695
- return output_images, gr.Gallery.update(label='result images'), None, None, None
696
  else:
697
  logger.info(f"task_type:{task_type} error!")
698
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
699
- return output_images, gr.Gallery.update(label='result images'), None, None, None
700
 
701
  def change_radio_display(task_type, mask_source_radio):
702
  text_prompt_visible = True
@@ -828,7 +854,9 @@ if __name__ == "__main__":
828
 
829
  with gr.Column():
830
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
831
- ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
 
 
832
  kosmos_output = gr.Image(type="pil", label="result images", visible=False)
833
  kosmos_text_output = gr.HighlightedText(
834
  label="Generated Description",
@@ -860,7 +888,7 @@ if __name__ == "__main__":
860
  run_button.click(fn=run_anything_task, inputs=[
861
  input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
862
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
863
- outputs=[image_gallery, image_gallery, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
864
 
865
  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
866
  outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
 
519
  mask_source_draw = "draw a mask on input image"
520
  mask_source_segment = "type what to detect below"
521
 
522
+ def get_time_cost(run_task_time, time_cost_str):
523
+ now_time = int(time.time()*1000)
524
+ if run_task_time == 0:
525
+ time_cost_str = 'start'
526
+ else:
527
+ if time_cost_str != '':
528
+ time_cost_str += f'-->'
529
+ time_cost_str += f'{now_time - run_task_time}'
530
+ run_task_time = now_time
531
+ return run_task_time, time_cost_str
532
+
533
  def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
534
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input, cleaner_size_limit=1080):
535
+
536
+ run_task_time = 0
537
+ time_cost_str = ''
538
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
539
+
540
  if (task_type == 'Kosmos-2'):
541
  global kosmos_model, kosmos_processor
542
  kosmos_image, kosmos_text, kosmos_entities = kosmos_generate_predictions(input_image, kosmos_input, kosmos_model, kosmos_processor)
543
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
544
+ return None, None, time_cost_str, kosmos_image, gr.Textbox.update(visible=(time_cost_str !='')), kosmos_text, kosmos_entities
545
 
546
  if (task_type == 'relate anything'):
547
  output_images = relate_anything(input_image['image'], num_relation)
548
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
549
+ return output_images, gr.Gallery.update(label='relate images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
550
 
551
  text_prompt = text_prompt.strip()
552
  if not ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_draw):
553
  if text_prompt == '':
554
+ return [], gr.Gallery.update(label='Detection prompt is not found!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
555
 
556
  if input_image is None:
557
+ return [], gr.Gallery.update(label='Please upload a image!πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
558
 
559
  file_temp = int(time.time())
560
  logger.info(f'run_anything_task_002/{device}_[{file_temp}]_{task_type}/{inpaint_mode}/[{mask_source_radio}]/{remove_mode}/{remove_mask_extend}_[{text_prompt}]/[{inpaint_prompt}]___1_')
 
570
  image_pil, image = load_image(input_image['image'].convert("RGB"))
571
  input_img = input_image['image']
572
  output_images.append(input_image['image'])
573
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
574
  else:
575
  image_pil, image = load_image(input_image.convert("RGB"))
576
  input_img = input_image
577
  output_images.append(input_image)
578
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
579
 
580
  size = image_pil.size
581
 
 
596
  )
597
  if boxes_filt.size(0) == 0:
598
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_[{text_prompt}]_1___{groundingdino_device}/[No objects detected, please try others.]_')
599
+ return [], gr.Gallery.update(label='No objects detected, please try others.πŸ˜‚πŸ˜‚πŸ˜‚πŸ˜‚'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
600
  boxes_filt_ori = copy.deepcopy(boxes_filt)
601
 
602
  pred_dict = {
 
607
 
608
  image_with_box = plot_boxes_to_image(copy.deepcopy(image_pil), pred_dict)[0]
609
  output_images.append(image_with_box)
610
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
611
 
612
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_')
613
  if task_type == 'segment' or ((task_type == 'inpainting' or task_type == 'remove') and mask_source_radio == mask_source_segment):
 
643
  plt.savefig(image_path, bbox_inches="tight")
644
  segment_image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
645
  os.remove(image_path)
646
+ output_images.append(segment_image_result)
647
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
648
 
649
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_3_')
650
  if task_type == 'detection' or task_type == 'segment':
651
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
652
+ return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
653
  elif task_type == 'inpainting' or task_type == 'remove':
654
  if inpaint_prompt.strip() == '' and mask_source_radio == mask_source_segment:
655
  task_type = 'remove'
 
666
  mask = masks[0][0].cpu().numpy()
667
  mask_pil = Image.fromarray(mask)
668
  output_images.append(mask_pil.convert("RGB"))
669
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
670
 
671
  if task_type == 'inpainting':
672
  # inpainting pipeline
 
705
  extend_pixels=remove_mask_extend, useRectangle=useRectangle)
706
  mask_imgs.append(mask_pil_exp)
707
  mask_pil = mix_masks(mask_imgs)
708
+ output_images.append(mask_pil.convert("RGB"))
709
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
710
 
711
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_6_')
712
  image_inpainting = lama_cleaner_process(np.array(image_pil), np.array(mask_pil.convert("L")), cleaner_size_limit)
713
  # output_images.append(image_inpainting)
714
+ # run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
715
 
716
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_7_')
717
  image_inpainting = image_inpainting.resize((image_pil.size[0], image_pil.size[1]))
718
  output_images.append(image_inpainting)
719
+ run_task_time, time_cost_str = get_time_cost(run_task_time, time_cost_str)
720
  logger.info(f'run_anything_task_[{file_temp}]_{task_type}_9_')
721
+ return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
722
  else:
723
  logger.info(f"task_type:{task_type} error!")
724
  logger.info(f'run_anything_task_[{file_temp}]_9_9_')
725
+ return output_images, gr.Gallery.update(label='result images'), time_cost_str, gr.Textbox.update(visible=(time_cost_str !='')), None, None, None
726
 
727
  def change_radio_display(task_type, mask_source_radio):
728
  text_prompt_visible = True
 
854
 
855
  with gr.Column():
856
  image_gallery = gr.Gallery(label="result images", show_label=True, elem_id="gallery", visible=True
857
+ ).style(preview=True, columns=[5], object_fit="scale-down", height="auto")
858
+ time_cost = gr.Textbox(label="Time cost by step (ms):", visible=False, interactive=False)
859
+
860
  kosmos_output = gr.Image(type="pil", label="result images", visible=False)
861
  kosmos_text_output = gr.HighlightedText(
862
  label="Generated Description",
 
888
  run_button.click(fn=run_anything_task, inputs=[
889
  input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
890
  iou_threshold, inpaint_mode, mask_source_radio, remove_mode, remove_mask_extend, num_relation, kosmos_input],
891
+ outputs=[image_gallery, image_gallery, time_cost, time_cost, kosmos_output, kosmos_text_output, entity_output], show_progress=True, queue=True)
892
 
893
  mask_source_radio.change(fn=change_radio_display, inputs=[task_type, mask_source_radio],
894
  outputs=[text_prompt, inpaint_prompt, mask_source_radio, num_relation])
kosmos_utils.py CHANGED
@@ -230,4 +230,4 @@ def kosmos_generate_predictions(image_input, text_input, kosmos_model, kosmos_pr
230
  if end < len(processed_text):
231
  colored_text.append((processed_text[end:len(processed_text)], None))
232
 
233
- return annotated_image, colored_text, str(filtered_entities)
 
230
  if end < len(processed_text):
231
  colored_text.append((processed_text[end:len(processed_text)], None))
232
 
233
+ return annotated_image, colored_text, str(filtered_entities)
requirements.txt CHANGED
@@ -17,7 +17,7 @@ termcolor
17
  timm
18
  torch
19
  torchvision
20
- transformers
21
  yapf
22
  numba
23
  scipy
 
17
  timm
18
  torch
19
  torchvision
20
+ transformers==4.27.4
21
  yapf
22
  numba
23
  scipy