yizhangliu commited on
Commit
5128046
1 Parent(s): d829f40

update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -0
app.py CHANGED
@@ -774,10 +774,13 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
774
  use_sam_predictor = True
775
  if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
776
  image = np.array(input_img)
 
777
  if task_type == 'remove' and remove_use_segment == False:
 
778
  use_sam_predictor = False
779
 
780
  if sam_predictor and use_sam_predictor:
 
781
  sam_predictor.set_image(image)
782
 
783
  for i in range(boxes_filt.size(0)):
@@ -786,6 +789,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
786
  boxes_filt[i][2:] += boxes_filt[i][:2]
787
 
788
  if sam_predictor and use_sam_predictor:
 
789
  boxes_filt = boxes_filt.to(sam_device)
790
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
791
 
@@ -798,6 +802,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
798
  # masks: [9, 1, 512, 512]
799
  assert sam_checkpoint, 'sam_checkpoint is not found!'
800
  else:
 
801
  masks = torch.zeros(len(boxes_filt), 1, H, W)
802
  mask_count = 0
803
  for box in boxes_filt:
@@ -806,6 +811,7 @@ def run_anything_task(input_image, text_prompt, task_type, inpaint_prompt, box_t
806
  masks = torch.where(masks > 0, True, False)
807
  run_mode = "rectangle"
808
 
 
809
  # draw output image
810
  plt.figure(figsize=(10, 10))
811
  plt.imshow(image)
 
774
  use_sam_predictor = True
775
  if task_type == 'segment' or ((task_type in ['inpainting', 'outpainting'] or task_type == 'remove') and mask_source_radio == mask_source_segment):
776
  image = np.array(input_img)
777
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_1_')
778
  if task_type == 'remove' and remove_use_segment == False:
779
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_2_')
780
  use_sam_predictor = False
781
 
782
  if sam_predictor and use_sam_predictor:
783
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_3_')
784
  sam_predictor.set_image(image)
785
 
786
  for i in range(boxes_filt.size(0)):
 
789
  boxes_filt[i][2:] += boxes_filt[i][:2]
790
 
791
  if sam_predictor and use_sam_predictor:
792
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_4_')
793
  boxes_filt = boxes_filt.to(sam_device)
794
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
795
 
 
802
  # masks: [9, 1, 512, 512]
803
  assert sam_checkpoint, 'sam_checkpoint is not found!'
804
  else:
805
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_5_')
806
  masks = torch.zeros(len(boxes_filt), 1, H, W)
807
  mask_count = 0
808
  for box in boxes_filt:
 
811
  masks = torch.where(masks > 0, True, False)
812
  run_mode = "rectangle"
813
 
814
+ logger.info(f'run_anything_task_[{file_temp}]_{task_type}_2_6_')
815
  # draw output image
816
  plt.figure(figsize=(10, 10))
817
  plt.imshow(image)