AAAAAAAyq commited on
Commit
e4f0b1d
1 Parent(s): 5778432

Add a wider result

Browse files
Files changed (3) hide show
  1. app_gradio.py +50 -46
  2. utils/tools.py +29 -16
  3. utils/tools_gradio.py +4 -4
app_gradio.py CHANGED
@@ -21,6 +21,8 @@ device = torch.device(
21
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
22
 
23
  news = """ # 📖 News
 
 
24
  🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
25
 
26
  🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
@@ -76,6 +78,7 @@ def segment_everything(
76
  withContours=True,
77
  use_retina=True,
78
  text="",
 
79
  mask_random_color=True,
80
  ):
81
  input_size = int(input_size) # 确保 imgsz 是整数
@@ -95,7 +98,7 @@ def segment_everything(
95
 
96
  if len(text) > 0:
97
  results = format_results(results[0], 0)
98
- annotations, _ = text_prompt(results, text, input, device=device)
99
  annotations = np.array([annotations])
100
  else:
101
  annotations = results[0].masks.data
@@ -189,7 +192,7 @@ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type
189
  global_points = []
190
  global_point_label = []
191
 
192
- input_size_slider_e = gr.components.Slider(minimum=512,
193
  maximum=1024,
194
  value=1024,
195
  step=64,
@@ -218,10 +221,10 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
218
  # Submit & Clear
219
  with gr.Row():
220
  with gr.Column():
221
- input_size_slider_e.render()
222
 
223
  with gr.Row():
224
- contour_check_e = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
225
 
226
  with gr.Column():
227
  segment_btn_e = gr.Button("Segment Everything", variant='primary')
@@ -237,16 +240,28 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
237
 
238
  with gr.Column():
239
  with gr.Accordion("Advanced options", open=False):
240
- # text_box = gr.Textbox(label="text prompt")
241
- iou_threshold_e = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
242
- conf_threshold_e = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
243
  with gr.Row():
244
- mor_check_e = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
245
  with gr.Column():
246
- retina_check_e = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
 
247
  # Description
248
  gr.Markdown(description_e)
249
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  with gr.Tab("Points mode"):
251
  # Images
252
  with gr.Row(variant="panel"):
@@ -277,7 +292,13 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
277
  with gr.Column():
278
  # Description
279
  gr.Markdown(description_p)
280
-
 
 
 
 
 
 
281
  with gr.Tab("Text mode"):
282
  # Images
283
  with gr.Row(variant="panel"):
@@ -291,14 +312,14 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
291
  with gr.Row():
292
  with gr.Column():
293
  input_size_slider_t = gr.components.Slider(minimum=512,
294
- maximum=1024,
295
- value=1024,
296
- step=64,
297
- label='Input_size',
298
- info='Our model was trained on a size of 1024')
299
  with gr.Row():
300
  with gr.Column():
301
- contour_check_t = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
302
  text_box = gr.Textbox(label="text prompt", value="a black dog")
303
 
304
  with gr.Column():
@@ -306,7 +327,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
306
  clear_btn_t = gr.Button("Clear", variant="secondary")
307
 
308
  gr.Markdown("Try some of the examples below ⬇️")
309
- gr.Examples(examples=["examples/dogs.jpg"],
310
  inputs=[cond_img_e],
311
  # outputs=segm_img_e,
312
  # fn=segment_everything,
@@ -315,44 +336,27 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
315
 
316
  with gr.Column():
317
  with gr.Accordion("Advanced options", open=False):
318
- iou_threshold_t = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
319
- conf_threshold_t = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
320
  with gr.Row():
321
- mor_check_t = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
322
- with gr.Column():
323
- retina_check_t = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
324
 
325
  # Description
326
  gr.Markdown(description_e)
327
-
328
- cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
329
-
330
- segment_btn_e.click(segment_everything,
331
- inputs=[
332
- cond_img_e,
333
- input_size_slider_e,
334
- iou_threshold_e,
335
- conf_threshold_e,
336
- mor_check_e,
337
- contour_check_e,
338
- retina_check_e,
339
- ],
340
- outputs=segm_img_e)
341
-
342
- segment_btn_p.click(segment_with_points,
343
- inputs=[cond_img_p],
344
- outputs=[segm_img_p, cond_img_p])
345
 
346
  segment_btn_t.click(segment_everything,
347
  inputs=[
348
  cond_img_t,
349
  input_size_slider_t,
350
- iou_threshold_t,
351
- conf_threshold_t,
352
- mor_check_t,
353
- contour_check_t,
354
- retina_check_t,
355
  text_box,
 
356
  ],
357
  outputs=segm_img_t)
358
 
@@ -361,7 +365,7 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
361
 
362
  def clear_text():
363
  return None, None, None
364
-
365
  clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
366
  clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
367
  clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
 
21
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
22
 
23
  news = """ # 📖 News
24
+ 🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)).
25
+
26
  🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
27
 
28
  🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
 
78
  withContours=True,
79
  use_retina=True,
80
  text="",
81
+ wider=False,
82
  mask_random_color=True,
83
  ):
84
  input_size = int(input_size) # 确保 imgsz 是整数
 
98
 
99
  if len(text) > 0:
100
  results = format_results(results[0], 0)
101
+ annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
102
  annotations = np.array([annotations])
103
  else:
104
  annotations = results[0].masks.data
 
192
  global_points = []
193
  global_point_label = []
194
 
195
+ input_size_slider = gr.components.Slider(minimum=512,
196
  maximum=1024,
197
  value=1024,
198
  step=64,
 
221
  # Submit & Clear
222
  with gr.Row():
223
  with gr.Column():
224
+ input_size_slider.render()
225
 
226
  with gr.Row():
227
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
228
 
229
  with gr.Column():
230
  segment_btn_e = gr.Button("Segment Everything", variant='primary')
 
240
 
241
  with gr.Column():
242
  with gr.Accordion("Advanced options", open=False):
243
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
244
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
 
245
  with gr.Row():
246
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
247
  with gr.Column():
248
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
249
+
250
  # Description
251
  gr.Markdown(description_e)
252
 
253
+ segment_btn_e.click(segment_everything,
254
+ inputs=[
255
+ cond_img_e,
256
+ input_size_slider,
257
+ iou_threshold,
258
+ conf_threshold,
259
+ mor_check,
260
+ contour_check,
261
+ retina_check,
262
+ ],
263
+ outputs=segm_img_e)
264
+
265
  with gr.Tab("Points mode"):
266
  # Images
267
  with gr.Row(variant="panel"):
 
292
  with gr.Column():
293
  # Description
294
  gr.Markdown(description_p)
295
+
296
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
297
+
298
+ segment_btn_p.click(segment_with_points,
299
+ inputs=[cond_img_p],
300
+ outputs=[segm_img_p, cond_img_p])
301
+
302
  with gr.Tab("Text mode"):
303
  # Images
304
  with gr.Row(variant="panel"):
 
312
  with gr.Row():
313
  with gr.Column():
314
  input_size_slider_t = gr.components.Slider(minimum=512,
315
+ maximum=1024,
316
+ value=1024,
317
+ step=64,
318
+ label='Input_size',
319
+ info='Our model was trained on a size of 1024')
320
  with gr.Row():
321
  with gr.Column():
322
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
323
  text_box = gr.Textbox(label="text prompt", value="a black dog")
324
 
325
  with gr.Column():
 
327
  clear_btn_t = gr.Button("Clear", variant="secondary")
328
 
329
  gr.Markdown("Try some of the examples below ⬇️")
330
+ gr.Examples(examples=[["examples/dogs.jpg"]] + examples,
331
  inputs=[cond_img_e],
332
  # outputs=segm_img_e,
333
  # fn=segment_everything,
 
336
 
337
  with gr.Column():
338
  with gr.Accordion("Advanced options", open=False):
339
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
340
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
341
  with gr.Row():
342
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
343
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
344
+ wider_check = gr.Checkbox(value=False, label='wider', info='wider result')
345
 
346
  # Description
347
  gr.Markdown(description_e)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
348
 
349
  segment_btn_t.click(segment_everything,
350
  inputs=[
351
  cond_img_t,
352
  input_size_slider_t,
353
+ iou_threshold,
354
+ conf_threshold,
355
+ mor_check,
356
+ contour_check,
357
+ retina_check,
358
  text_box,
359
+ wider_check,
360
  ],
361
  outputs=segm_img_t)
362
 
 
365
 
366
  def clear_text():
367
  return None, None, None
368
+
369
  clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
370
  clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
371
  clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
utils/tools.py CHANGED
@@ -9,11 +9,14 @@ import clip
9
 
10
 
11
  def convert_box_xywh_to_xyxy(box):
12
- x1 = box[0]
13
- y1 = box[1]
14
- x2 = box[0] + box[2]
15
- y2 = box[1] + box[3]
16
- return [x1, y1, x2, y2]
 
 
 
17
 
18
 
19
  def segment_image(image, bbox):
@@ -323,8 +326,8 @@ def fast_show_mask_gpu(
323
  # clip
324
  @torch.no_grad()
325
  def retriev(
326
- model, preprocess, elements, search_text: str, device
327
- ) -> int:
328
  preprocessed_images = [preprocess(image).to(device) for image in elements]
329
  tokenized_text = clip.tokenize([search_text]).to(device)
330
  stacked_images = torch.stack(preprocessed_images)
@@ -348,19 +351,16 @@ def crop_image(annotations, image_like):
348
  cropped_boxes = []
349
  cropped_images = []
350
  not_crop = []
351
- filter_id = []
352
- # annotations, _ = filter_masks(annotations)
353
- # filter_id = list(_)
354
  for _, mask in enumerate(annotations):
355
  if np.sum(mask["segmentation"]) <= 100:
356
- filter_id.append(_)
357
  continue
 
358
  bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
359
  cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
360
  # cropped_boxes.append(segment_image(image,mask["segmentation"]))
361
  cropped_images.append(bbox) # 保存裁剪的图片的bbox
362
-
363
- return cropped_boxes, cropped_images, not_crop, filter_id, annotations
364
 
365
 
366
  def box_prompt(masks, bbox, target_height, target_width):
@@ -415,8 +415,8 @@ def point_prompt(masks, points, point_label, target_height, target_width): # nu
415
  return onemask, 0
416
 
417
 
418
- def text_prompt(annotations, text, img_path, device):
419
- cropped_boxes, cropped_images, not_crop, filter_id, annotations_ = crop_image(
420
  annotations, img_path
421
  )
422
  clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
@@ -425,5 +425,18 @@ def text_prompt(annotations, text, img_path, device):
425
  )
426
  max_idx = scores.argsort()
427
  max_idx = max_idx[-1]
428
- max_idx += sum(np.array(filter_id) <= int(max_idx))
 
 
 
 
 
 
 
 
 
 
 
 
 
429
  return annotations_[max_idx]["segmentation"], max_idx
 
9
 
10
 
11
  def convert_box_xywh_to_xyxy(box):
12
+ if len(box) == 4:
13
+ return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
14
+ else:
15
+ result = []
16
+ for b in box:
17
+ b = convert_box_xywh_to_xyxy(b)
18
+ result.append(b)
19
+ return result
20
 
21
 
22
  def segment_image(image, bbox):
 
326
  # clip
327
  @torch.no_grad()
328
  def retriev(
329
+ model, preprocess, elements: [Image.Image], search_text: str, device
330
+ ):
331
  preprocessed_images = [preprocess(image).to(device) for image in elements]
332
  tokenized_text = clip.tokenize([search_text]).to(device)
333
  stacked_images = torch.stack(preprocessed_images)
 
351
  cropped_boxes = []
352
  cropped_images = []
353
  not_crop = []
354
+ origin_id = []
 
 
355
  for _, mask in enumerate(annotations):
356
  if np.sum(mask["segmentation"]) <= 100:
 
357
  continue
358
+ origin_id.append(_)
359
  bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
360
  cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
361
  # cropped_boxes.append(segment_image(image,mask["segmentation"]))
362
  cropped_images.append(bbox) # 保存裁剪的图片的bbox
363
+ return cropped_boxes, cropped_images, not_crop, origin_id, annotations
 
364
 
365
 
366
  def box_prompt(masks, bbox, target_height, target_width):
 
415
  return onemask, 0
416
 
417
 
418
+ def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
419
+ cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
420
  annotations, img_path
421
  )
422
  clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
 
425
  )
426
  max_idx = scores.argsort()
427
  max_idx = max_idx[-1]
428
+ max_idx = origin_id[int(max_idx)]
429
+
430
+ # find the biggest mask which contains the mask with max score
431
+ if wider:
432
+ mask0 = annotations_[max_idx]["segmentation"]
433
+ area0 = np.sum(mask0)
434
+ areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
435
+ areas = sorted(areas, key=lambda area: area[1], reverse=True)
436
+ indices = [area[0] for area in areas]
437
+ for index in indices:
438
+ if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
439
+ max_idx = index
440
+ break
441
+
442
  return annotations_[max_idx]["segmentation"], max_idx
utils/tools_gradio.py CHANGED
@@ -103,7 +103,7 @@ def fast_show_mask(
103
  annotation = annotation[sorted_indices]
104
 
105
  index = (annotation != 0).argmax(axis=0)
106
- if random_color == True:
107
  color = np.random.random((mask_sum, 1, 1, 3))
108
  else:
109
  color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
@@ -121,7 +121,7 @@ def fast_show_mask(
121
  x1, y1, x2, y2 = bbox
122
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123
 
124
- if retinamask == False:
125
  mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126
 
127
  return mask
@@ -145,7 +145,7 @@ def fast_show_mask_gpu(
145
  annotation = annotation[sorted_indices]
146
  # 找每个位置第一个非零值下标
147
  index = (annotation != 0).to(torch.long).argmax(dim=0)
148
- if random_color == True:
149
  color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150
  else:
151
  color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
@@ -168,7 +168,7 @@ def fast_show_mask_gpu(
168
  (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169
  )
170
  )
171
- if retinamask == False:
172
  mask_cpu = cv2.resize(
173
  mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174
  )
 
103
  annotation = annotation[sorted_indices]
104
 
105
  index = (annotation != 0).argmax(axis=0)
106
+ if random_color:
107
  color = np.random.random((mask_sum, 1, 1, 3))
108
  else:
109
  color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
 
121
  x1, y1, x2, y2 = bbox
122
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123
 
124
+ if not retinamask:
125
  mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126
 
127
  return mask
 
145
  annotation = annotation[sorted_indices]
146
  # 找每个位置第一个非零值下标
147
  index = (annotation != 0).to(torch.long).argmax(dim=0)
148
+ if random_color:
149
  color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150
  else:
151
  color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
 
168
  (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169
  )
170
  )
171
+ if not retinamask:
172
  mask_cpu = cv2.resize(
173
  mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174
  )