AAAAAAyq committed on
Commit
901ea42
1 Parent(s): 2f10180

Fix the everything mode bug and add point mode

Browse files
Files changed (3) hide show
  1. __pycache__/tools.cpython-39.pyc +0 -0
  2. app.py +196 -72
  3. tools.py +1 -35
__pycache__/tools.cpython-39.pyc CHANGED
Binary files a/__pycache__/tools.cpython-39.pyc and b/__pycache__/tools.cpython-39.pyc differ
 
app.py CHANGED
@@ -1,27 +1,49 @@
1
  from ultralytics import YOLO
2
  import gradio as gr
3
  import torch
4
- from tools import fast_process
 
 
5
 
6
  # Load the pre-trained model
7
  model = YOLO('checkpoints/FastSAM.pt')
8
 
 
 
9
  # Description
10
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
11
 
12
- news = """ # News
13
 
14
- 🔥 Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
 
 
 
15
  """
16
 
17
 
18
- # 🔥 Support the points mode and box mode, text mode will come soon.
19
 
20
- description = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
21
 
22
  🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
23
 
24
- ⌛️ It takes about 4~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
27
 
@@ -41,17 +63,14 @@ default_example = examples[0]
41
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
42
 
43
 
44
- def segment_image(
45
  input,
46
  input_size=1024,
47
  iou_threshold=0.7,
48
  conf_threshold=0.25,
49
  better_quality=False,
50
- mask_random_color=True,
51
  withContours=True,
52
- points=None,
53
- bbox=None,
54
- point_label=None,
55
  use_retina=True,
56
  ):
57
  input_size = int(input_size) # 确保 imgsz 是整数
@@ -69,19 +88,80 @@ def segment_image(
69
  iou=iou_threshold,
70
  conf=conf_threshold,
71
  imgsz=input_size,)
 
72
  fig = fast_process(annotations=results[0].masks.data,
73
  image=input,
74
  device=device,
75
  scale=(1024 // input_size),
76
  better_quality=better_quality,
77
  mask_random_color=mask_random_color,
78
- points=points,
79
- bbox=bbox,
80
- point_label=point_label,
81
  use_retina=use_retina,
82
  withContours=withContours,)
83
  return fig
84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
  # input_size=1024
87
  # high_quality_visual=True
@@ -93,75 +173,119 @@ def segment_image(
93
  # pil_image = fast_process(annotations=results[0].masks.data,
94
  # image=input, high_quality=high_quality_visual, device=device)
95
 
96
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
97
 
98
- cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
 
99
 
100
- segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
 
101
 
102
  input_size_slider = gr.components.Slider(minimum=512,
103
  maximum=1024,
104
  value=1024,
105
  step=64,
106
- label='Input_size (Our model was trained on a size of 1024)')
 
107
 
108
  with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
109
  with gr.Row():
110
- with gr.Column(scale=1):
111
- # Title
112
- gr.Markdown(title)
113
-
114
- with gr.Column(scale=1):
115
- # News
116
- gr.Markdown(news)
117
-
118
- # Images
119
- with gr.Row(variant="panel"):
120
- with gr.Column(scale=1):
121
- cond_img.render()
122
-
123
- with gr.Column(scale=1):
124
- segm_img.render()
125
-
126
- # Submit & Clear
127
- with gr.Row():
128
- with gr.Column():
129
- input_size_slider.render()
130
-
131
- with gr.Row():
132
- contour_check = gr.Checkbox(value=True, label='withContours')
133
-
134
- with gr.Column():
135
- segment_btn = gr.Button("Segment Anything", variant='primary')
136
-
137
- # with gr.Column():
138
- # clear_btn = gr.Button("Clear", variant="primary")
139
-
140
- gr.Markdown("Try some of the examples below ⬇️")
141
- gr.Examples(examples=examples,
142
- inputs=[cond_img],
143
- outputs=segm_img,
144
- fn=segment_image,
145
- cache_examples=True,
146
- examples_per_page=4)
147
-
148
- with gr.Column():
149
- with gr.Accordion("Advanced options", open=False):
150
- iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
151
- conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
152
- mor_check = gr.Checkbox(value=False, label='better_visual_quality')
153
 
154
- # Description
155
- gr.Markdown(description)
156
-
157
- segment_btn.click(segment_image,
158
- inputs=[cond_img, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check],
159
- outputs=segm_img)
160
-
161
- # def clear():
162
- # return None, None
163
-
164
- # clear_btn.click(fn=clear, inputs=None, outputs=None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  demo.queue()
167
  demo.launch()
 
1
  from ultralytics import YOLO
2
  import gradio as gr
3
  import torch
4
+ from tools import fast_process, format_results, box_prompt, point_prompt
5
+ from PIL import ImageDraw
6
+ import numpy as np
7
 
8
  # Load the pre-trained model
9
  model = YOLO('checkpoints/FastSAM.pt')
10
 
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
  # Description
14
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
15
 
16
+ news = """ # 📖 News
17
 
18
+ 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
19
+
20
+ 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
21
+
22
  """
23
 
24
 
 
25
 
26
+ description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
27
 
28
  🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
29
 
30
+ ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.
31
+
32
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
33
+
34
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
35
+
36
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant.
37
+
38
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
39
+
40
+ """
41
+
42
+ description_p = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
43
+
44
+ 🎯 Upload an Image, add points and segment it with Fast Segment Anything (Points mode).
45
+
46
+ ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.
47
 
48
  🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
49
 
 
63
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
64
 
65
 
66
+ def segment_everything(
67
  input,
68
  input_size=1024,
69
  iou_threshold=0.7,
70
  conf_threshold=0.25,
71
  better_quality=False,
 
72
  withContours=True,
73
+ mask_random_color=True,
 
 
74
  use_retina=True,
75
  ):
76
  input_size = int(input_size) # 确保 imgsz 是整数
 
88
  iou=iou_threshold,
89
  conf=conf_threshold,
90
  imgsz=input_size,)
91
+
92
  fig = fast_process(annotations=results[0].masks.data,
93
  image=input,
94
  device=device,
95
  scale=(1024 // input_size),
96
  better_quality=better_quality,
97
  mask_random_color=mask_random_color,
98
+ bbox=None,
 
 
99
  use_retina=use_retina,
100
  withContours=withContours,)
101
  return fig
102
 
103
+ def segment_with_points(
104
+ input,
105
+ input_size=1024,
106
+ iou_threshold=0.7,
107
+ conf_threshold=0.25,
108
+ better_quality=False,
109
+ withContours=True,
110
+ mask_random_color=True,
111
+ use_retina=True,
112
+ ):
113
+ global global_points
114
+ global global_point_label
115
+
116
+ input_size = int(input_size) # 确保 imgsz 是整数
117
+ # Thanks for the suggestion by hysts in HuggingFace.
118
+ w, h = input.size
119
+ scale = input_size / max(w, h)
120
+ new_w = int(w * scale)
121
+ new_h = int(h * scale)
122
+ input = input.resize((new_w, new_h))
123
+
124
+ scaled_points = [[int(x * scale) for x in point] for point in global_points]
125
+
126
+ results = model(input,
127
+ device=device,
128
+ retina_masks=True,
129
+ iou=iou_threshold,
130
+ conf=conf_threshold,
131
+ imgsz=input_size,)
132
+
133
+ results = format_results(results[0], 0)
134
+
135
+ annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
136
+ annotations = np.array([annotations])
137
+
138
+ fig = fast_process(annotations=annotations,
139
+ image=input,
140
+ device=device,
141
+ scale=(1024 // input_size),
142
+ better_quality=better_quality,
143
+ mask_random_color=mask_random_color,
144
+ bbox=None,
145
+ use_retina=use_retina,
146
+ withContours=withContours,)
147
+ global_points = []
148
+ global_point_label = []
149
+ return fig, None
150
+
151
+ def get_points_with_draw(image, label, evt: gr.SelectData):
152
+ x, y = evt.index[0], evt.index[1]
153
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
154
+ global global_points
155
+ global global_point_label
156
+ print((x, y))
157
+ global_points.append([x, y])
158
+ global_point_label.append(1 if label == 'Add Mask' else 0)
159
+
160
+ # 创建一个可以在图像上绘图的对象
161
+ draw = ImageDraw.Draw(image)
162
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
163
+ return image
164
+
165
 
166
  # input_size=1024
167
  # high_quality_visual=True
 
173
  # pil_image = fast_process(annotations=results[0].masks.data,
174
  # image=input, high_quality=high_quality_visual, device=device)
175
 
176
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
177
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
178
 
179
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
180
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
181
 
182
+ global_points = []
183
+ global_point_label = [] # TODO:Clear points each image
184
 
185
  input_size_slider = gr.components.Slider(minimum=512,
186
  maximum=1024,
187
  value=1024,
188
  step=64,
189
+ label='Input_size',
190
+ info='Our model was trained on a size of 1024')
191
 
192
  with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
193
  with gr.Row():
194
+ with gr.Column(scale=1):
195
+ # Title
196
+ gr.Markdown(title)
197
+
198
+ with gr.Column(scale=1):
199
+ # News
200
+ gr.Markdown(news)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
+ with gr.Tab("Everything mode"):
203
+ # Images
204
+ with gr.Row(variant="panel"):
205
+ with gr.Column(scale=1):
206
+ cond_img_e.render()
207
+
208
+ with gr.Column(scale=1):
209
+ segm_img_e.render()
210
+
211
+ # Submit & Clear
212
+ with gr.Row():
213
+ with gr.Column():
214
+ input_size_slider.render()
215
+
216
+ with gr.Row():
217
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
218
+
219
+ with gr.Column():
220
+ segment_btn_e = gr.Button("Segment Everything", variant='primary')
221
+ clear_btn_e = gr.Button("Clear", variant="secondary")
222
+
223
+ gr.Markdown("Try some of the examples below ⬇️")
224
+ gr.Examples(examples=examples,
225
+ inputs=[cond_img_e],
226
+ outputs=segm_img_e,
227
+ fn=segment_everything,
228
+ cache_examples=True,
229
+ examples_per_page=4)
230
+
231
+ with gr.Column():
232
+ with gr.Accordion("Advanced options", open=False):
233
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
234
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
235
+ with gr.Row():
236
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
237
+ with gr.Column():
238
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
239
+
240
+ # Description
241
+ gr.Markdown(description_e)
242
+
243
+ with gr.Tab("Points mode"):
244
+ # Images
245
+ with gr.Row(variant="panel"):
246
+ with gr.Column(scale=1):
247
+ cond_img_p.render()
248
+
249
+ with gr.Column(scale=1):
250
+ segm_img_p.render()
251
+
252
+ # Submit & Clear
253
+ with gr.Row():
254
+ with gr.Column():
255
+ with gr.Row():
256
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
257
+
258
+ with gr.Column():
259
+ segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
260
+ clear_btn_p = gr.Button("Clear points", variant='secondary')
261
+
262
+ gr.Markdown("Try some of the examples below ⬇️")
263
+ gr.Examples(examples=examples,
264
+ inputs=[cond_img_p],
265
+ outputs=segm_img_p,
266
+ fn=segment_with_points,
267
+ # cache_examples=True,
268
+ examples_per_page=4)
269
+
270
+ with gr.Column():
271
+ # Description
272
+ gr.Markdown(description_p)
273
+
274
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
275
+
276
+ segment_btn_e.click(segment_everything,
277
+ inputs=[cond_img_e, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check, retina_check],
278
+ outputs=segm_img_e)
279
+
280
+ segment_btn_p.click(segment_with_points,
281
+ inputs=[cond_img_p],
282
+ outputs=[segm_img_p, cond_img_p])
283
+
284
+ def clear():
285
+ return None, None
286
+
287
+ clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
288
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
289
 
290
  demo.queue()
291
  demo.launch()
tools.py CHANGED
@@ -93,9 +93,7 @@ def fast_process(
93
  scale,
94
  better_quality=False,
95
  mask_random_color=True,
96
- points=None,
97
  bbox=None,
98
- point_label=None,
99
  use_retina=True,
100
  withContours=True,
101
  ):
@@ -117,8 +115,6 @@ def fast_process(
117
  plt.gca(),
118
  random_color=mask_random_color,
119
  bbox=bbox,
120
- points=points,
121
- pointlabel=point_label,
122
  retinamask=use_retina,
123
  target_height=original_h,
124
  target_width=original_w,
@@ -131,8 +127,6 @@ def fast_process(
131
  plt.gca(),
132
  random_color=mask_random_color,
133
  bbox=bbox,
134
- points=points,
135
- pointlabel=point_label,
136
  retinamask=use_retina,
137
  target_height=original_h,
138
  target_width=original_w,
@@ -159,7 +153,7 @@ def fast_process(
159
  cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
160
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
161
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
162
- i
163
  image = image.convert('RGBA')
164
  overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
165
  image.paste(overlay_inner, (0, 0), overlay_inner)
@@ -177,8 +171,6 @@ def fast_show_mask(
177
  ax,
178
  random_color=False,
179
  bbox=None,
180
- points=None,
181
- pointlabel=None,
182
  retinamask=True,
183
  target_height=960,
184
  target_width=960,
@@ -209,16 +201,6 @@ def fast_show_mask(
209
  if bbox is not None:
210
  x1, y1, x2, y2 = bbox
211
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
212
- # draw point
213
- if points is not None:
214
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
215
- [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
216
- s=20,
217
- c='y')
218
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
219
- [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
220
- s=20,
221
- c='m')
222
 
223
  if retinamask == False:
224
  mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
@@ -231,8 +213,6 @@ def fast_show_mask_gpu(
231
  ax,
232
  random_color=False,
233
  bbox=None,
234
- points=None,
235
- pointlabel=None,
236
  retinamask=True,
237
  target_height=960,
238
  target_width=960,
@@ -269,20 +249,6 @@ def fast_show_mask_gpu(
269
  (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
270
  )
271
  )
272
- # draw point
273
- if points is not None:
274
- plt.scatter(
275
- [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
276
- [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
277
- s=20,
278
- c="y",
279
- )
280
- plt.scatter(
281
- [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
282
- [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
283
- s=20,
284
- c="m",
285
- )
286
  if retinamask == False:
287
  mask_cpu = cv2.resize(
288
  mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
 
93
  scale,
94
  better_quality=False,
95
  mask_random_color=True,
 
96
  bbox=None,
 
97
  use_retina=True,
98
  withContours=True,
99
  ):
 
115
  plt.gca(),
116
  random_color=mask_random_color,
117
  bbox=bbox,
 
 
118
  retinamask=use_retina,
119
  target_height=original_h,
120
  target_width=original_w,
 
127
  plt.gca(),
128
  random_color=mask_random_color,
129
  bbox=bbox,
 
 
130
  retinamask=use_retina,
131
  target_height=original_h,
132
  target_width=original_w,
 
153
  cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
154
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
155
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
156
+
157
  image = image.convert('RGBA')
158
  overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
159
  image.paste(overlay_inner, (0, 0), overlay_inner)
 
171
  ax,
172
  random_color=False,
173
  bbox=None,
 
 
174
  retinamask=True,
175
  target_height=960,
176
  target_width=960,
 
201
  if bbox is not None:
202
  x1, y1, x2, y2 = bbox
203
  ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
 
 
 
 
 
 
 
 
 
 
204
 
205
  if retinamask == False:
206
  mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
 
213
  ax,
214
  random_color=False,
215
  bbox=None,
 
 
216
  retinamask=True,
217
  target_height=960,
218
  target_width=960,
 
249
  (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
250
  )
251
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  if retinamask == False:
253
  mask_cpu = cv2.resize(
254
  mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST