AAAAAAAyq commited on
Commit
4b45202
1 Parent(s): e03ed2b

Better points mode & Fix the Contours button bug

Browse files
Files changed (2) hide show
  1. app_gradio.py +4 -4
  2. utils/tools.py +4 -3
app_gradio.py CHANGED
@@ -221,7 +221,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
221
  input_size_slider.render()
222
 
223
  with gr.Row():
224
- contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
225
 
226
  with gr.Column():
227
  segment_btn_e = gr.Button("Segment Everything", variant='primary')
@@ -298,7 +298,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
298
  info='Our model was trained on a size of 1024')
299
  with gr.Row():
300
  with gr.Column():
301
- contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
302
  text_box = gr.Textbox(label="text prompt", value="a black dog")
303
 
304
  with gr.Column():
@@ -334,7 +334,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
334
  iou_threshold,
335
  conf_threshold,
336
  mor_check,
337
- contour_check,
338
  retina_check,
339
  ],
340
  outputs=segm_img_e)
@@ -350,7 +350,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
350
  iou_threshold,
351
  conf_threshold,
352
  mor_check,
353
- contour_check,
354
  retina_check,
355
  text_box,
356
  ],
 
221
  input_size_slider.render()
222
 
223
  with gr.Row():
224
+ contour_check_e = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
225
 
226
  with gr.Column():
227
  segment_btn_e = gr.Button("Segment Everything", variant='primary')
 
298
  info='Our model was trained on a size of 1024')
299
  with gr.Row():
300
  with gr.Column():
301
+ contour_check_t = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
302
  text_box = gr.Textbox(label="text prompt", value="a black dog")
303
 
304
  with gr.Column():
 
334
  iou_threshold,
335
  conf_threshold,
336
  mor_check,
337
+ contour_check_e,
338
  retina_check,
339
  ],
340
  outputs=segm_img_e)
 
350
  iou_threshold,
351
  conf_threshold,
352
  mor_check,
353
+ contour_check_t,
354
  retina_check,
355
  text_box,
356
  ],
utils/tools.py CHANGED
@@ -400,16 +400,17 @@ def point_prompt(masks, points, point_label, target_height, target_width): # nu
400
  for point in points
401
  ]
402
  onemask = np.zeros((h, w))
 
403
  for i, annotation in enumerate(masks):
404
  if type(annotation) == dict:
405
- mask = annotation["segmentation"]
406
  else:
407
  mask = annotation
408
  for i, point in enumerate(points):
409
  if mask[point[1], point[0]] == 1 and point_label[i] == 1:
410
- onemask += mask
411
  if mask[point[1], point[0]] == 1 and point_label[i] == 0:
412
- onemask -= mask
413
  onemask = onemask >= 1
414
  return onemask, 0
415
 
 
400
  for point in points
401
  ]
402
  onemask = np.zeros((h, w))
403
+ masks = sorted(masks, key=lambda x: x['area'], reverse=True)
404
  for i, annotation in enumerate(masks):
405
  if type(annotation) == dict:
406
+ mask = annotation['segmentation']
407
  else:
408
  mask = annotation
409
  for i, point in enumerate(points):
410
  if mask[point[1], point[0]] == 1 and point_label[i] == 1:
411
+ onemask[mask] = 1
412
  if mask[point[1], point[0]] == 1 and point_label[i] == 0:
413
+ onemask[mask] = 0
414
  onemask = onemask >= 1
415
  return onemask, 0
416