ShilongLiu committed on
Commit
0e8e9e2
1 Parent(s): f45dd9b

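add nms by default (run NMS before the task-type branch instead of only for the 'automatic' task; raise the default IOU threshold from 0.5 to 0.8)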

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -202,6 +202,14 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
202
 
203
  boxes_filt = boxes_filt.cpu()
204
 
 
 
 
 
 
 
 
 
205
  if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
206
  if sam_predictor is None:
207
  # initialize SAM
@@ -215,12 +223,6 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
215
 
216
  if task_type == 'automatic':
217
  # use NMS to handle overlapped boxes
218
- print(f"Before NMS: {boxes_filt.shape[0]} boxes")
219
- nms_idx = torchvision.ops.nms(
220
- boxes_filt, scores, iou_threshold).numpy().tolist()
221
- boxes_filt = boxes_filt[nms_idx]
222
- pred_phrases = [pred_phrases[idx] for idx in nms_idx]
223
- print(f"After NMS: {boxes_filt.shape[0]} boxes")
224
  print(f"Revise caption with number: {text_prompt}")
225
 
226
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(
@@ -318,7 +320,7 @@ if __name__ == "__main__":
318
  label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
319
  )
320
  iou_threshold = gr.Slider(
321
- label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
322
  )
323
  inpaint_mode = gr.Dropdown(
324
  ["merge", "first"], value="merge", label="inpaint_mode")
 
202
 
203
  boxes_filt = boxes_filt.cpu()
204
 
205
+ # nms
206
+ print(f"Before NMS: {boxes_filt.shape[0]} boxes")
207
+ nms_idx = torchvision.ops.nms(
208
+ boxes_filt, scores, iou_threshold).numpy().tolist()
209
+ boxes_filt = boxes_filt[nms_idx]
210
+ pred_phrases = [pred_phrases[idx] for idx in nms_idx]
211
+ print(f"After NMS: {boxes_filt.shape[0]} boxes")
212
+
213
  if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
214
  if sam_predictor is None:
215
  # initialize SAM
 
223
 
224
  if task_type == 'automatic':
225
  # use NMS to handle overlapped boxes
 
 
 
 
 
 
226
  print(f"Revise caption with number: {text_prompt}")
227
 
228
  transformed_boxes = sam_predictor.transform.apply_boxes_torch(
 
320
  label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
321
  )
322
  iou_threshold = gr.Slider(
323
+ label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.8, step=0.001
324
  )
325
  inpaint_mode = gr.Dropdown(
326
  ["merge", "first"], value="merge", label="inpaint_mode")