Xxuann commited on
Commit
1578923
·
verified ·
1 Parent(s): f7fc621

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +367 -3
app.py CHANGED
@@ -1,7 +1,371 @@
 
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  demo.launch()
 
1
+ from ultralytics import YOLO
2
  import gradio as gr
3
+ import torch
4
+ from utils.tools_gradio import fast_process
5
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6
+ from PIL import ImageDraw
7
+ import numpy as np
8
 
9
+ # Load the pre-trained model
10
+ model = YOLO('./weights/FastSAM.pt')
11
 
12
+ device = torch.device(
13
+ "cuda"
14
+ if torch.cuda.is_available()
15
+ else "mps"
16
+ if torch.backends.mps.is_available()
17
+ else "cpu"
18
+ )
19
+
20
+ # Description
21
+ title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
22
+
23
+ news = """ # 📖 News
24
+ 🔥 2023/07/14: Add a "wider result" button in text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/95)).
25
+ 🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
26
+ 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
27
+ 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
28
+ """
29
+
30
+ description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
31
+
32
+ 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
33
+
34
+ ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.
35
+
36
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
37
+
38
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
39
+
40
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant.
41
+
42
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
43
+
44
+ """
45
+
46
+ description_p = """ # 🎯 Instructions for points mode
47
+ This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
48
+
49
+ 1. Upload an image or choose an example.
50
+
51
+ 2. Choose the point label ('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented).
52
+
53
+ 3. Add points one by one on the image.
54
+
55
+ 4. Click the 'Segment with points prompt' button to get the segmentation results.
56
+
57
+ **5. If you get Error, click the 'Clear points' button and try again may help.**
58
+
59
+ """
60
+
61
+ examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
62
+ ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
63
+
64
+ default_example = examples[0]
65
+
66
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
67
+
68
+
69
+ def segment_everything(
70
+ input,
71
+ input_size=1024,
72
+ iou_threshold=0.7,
73
+ conf_threshold=0.25,
74
+ better_quality=False,
75
+ withContours=True,
76
+ use_retina=True,
77
+ text="",
78
+ wider=False,
79
+ mask_random_color=True,
80
+ ):
81
+ input_size = int(input_size) # 确保 imgsz 是整数
82
+ # Thanks for the suggestion by hysts in HuggingFace.
83
+ w, h = input.size
84
+ scale = input_size / max(w, h)
85
+ new_w = int(w * scale)
86
+ new_h = int(h * scale)
87
+ input = input.resize((new_w, new_h))
88
+
89
+ results = model(input,
90
+ device=device,
91
+ retina_masks=True,
92
+ iou=iou_threshold,
93
+ conf=conf_threshold,
94
+ imgsz=input_size,)
95
+
96
+ if len(text) > 0:
97
+ results = format_results(results[0], 0)
98
+ annotations, _ = text_prompt(results, text, input, device=device, wider=wider)
99
+ annotations = np.array([annotations])
100
+ else:
101
+ annotations = results[0].masks.data
102
+
103
+ fig = fast_process(annotations=annotations,
104
+ image=input,
105
+ device=device,
106
+ scale=(1024 // input_size),
107
+ better_quality=better_quality,
108
+ mask_random_color=mask_random_color,
109
+ bbox=None,
110
+ use_retina=use_retina,
111
+ withContours=withContours,)
112
+ return fig
113
+
114
+
115
+ def segment_with_points(
116
+ input,
117
+ input_size=1024,
118
+ iou_threshold=0.7,
119
+ conf_threshold=0.25,
120
+ better_quality=False,
121
+ withContours=True,
122
+ use_retina=True,
123
+ mask_random_color=True,
124
+ ):
125
+ global global_points
126
+ global global_point_label
127
+
128
+ input_size = int(input_size) # 确保 imgsz 是整数
129
+ # Thanks for the suggestion by hysts in HuggingFace.
130
+ w, h = input.size
131
+ scale = input_size / max(w, h)
132
+ new_w = int(w * scale)
133
+ new_h = int(h * scale)
134
+ input = input.resize((new_w, new_h))
135
+
136
+ scaled_points = [[int(x * scale) for x in point] for point in global_points]
137
+
138
+ results = model(input,
139
+ device=device,
140
+ retina_masks=True,
141
+ iou=iou_threshold,
142
+ conf=conf_threshold,
143
+ imgsz=input_size,)
144
+
145
+ results = format_results(results[0], 0)
146
+ annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
147
+ annotations = np.array([annotations])
148
+
149
+ fig = fast_process(annotations=annotations,
150
+ image=input,
151
+ device=device,
152
+ scale=(1024 // input_size),
153
+ better_quality=better_quality,
154
+ mask_random_color=mask_random_color,
155
+ bbox=None,
156
+ use_retina=use_retina,
157
+ withContours=withContours,)
158
+
159
+ global_points = []
160
+ global_point_label = []
161
+ return fig, None
162
+
163
+
164
+ def get_points_with_draw(image, label, evt: gr.SelectData):
165
+ global global_points
166
+ global global_point_label
167
+
168
+ x, y = evt.index[0], evt.index[1]
169
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
170
+ global_points.append([x, y])
171
+ global_point_label.append(1 if label == 'Add Mask' else 0)
172
+
173
+ print(x, y, label == 'Add Mask')
174
+
175
+ # 创建一个可以在图像上绘图的对象
176
+ draw = ImageDraw.Draw(image)
177
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
178
+ return image
179
+
180
+
181
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
182
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
183
+ cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
184
+
185
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
186
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
187
+ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
188
+
189
+ global_points = []
190
+ global_point_label = []
191
+
192
+ input_size_slider = gr.components.Slider(minimum=512,
193
+ maximum=1024,
194
+ value=1024,
195
+ step=64,
196
+ label='Input_size',
197
+ info='Our model was trained on a size of 1024')
198
+
199
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
200
+ with gr.Row():
201
+ with gr.Column(scale=1):
202
+ # Title
203
+ gr.Markdown(title)
204
+
205
+ with gr.Column(scale=1):
206
+ # News
207
+ gr.Markdown(news)
208
+
209
+ with gr.Tab("Everything mode"):
210
+ # Images
211
+ with gr.Row(variant="panel"):
212
+ with gr.Column(scale=1):
213
+ cond_img_e.render()
214
+
215
+ with gr.Column(scale=1):
216
+ segm_img_e.render()
217
+
218
+ # Submit & Clear
219
+ with gr.Row():
220
+ with gr.Column():
221
+ input_size_slider.render()
222
+
223
+ with gr.Row():
224
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
225
+
226
+ with gr.Column():
227
+ segment_btn_e = gr.Button("Segment Everything", variant='primary')
228
+ clear_btn_e = gr.Button("Clear", variant="secondary")
229
+
230
+ gr.Markdown("Try some of the examples below ⬇️")
231
+ gr.Examples(examples=examples,
232
+ inputs=[cond_img_e],
233
+ outputs=segm_img_e,
234
+ fn=segment_everything,
235
+ cache_examples=True,
236
+ examples_per_page=4)
237
+
238
+ with gr.Column():
239
+ with gr.Accordion("Advanced options", open=False):
240
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
241
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
242
+ with gr.Row():
243
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
244
+ with gr.Column():
245
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
246
+
247
+ # Description
248
+ gr.Markdown(description_e)
249
+
250
+ segment_btn_e.click(segment_everything,
251
+ inputs=[
252
+ cond_img_e,
253
+ input_size_slider,
254
+ iou_threshold,
255
+ conf_threshold,
256
+ mor_check,
257
+ contour_check,
258
+ retina_check,
259
+ ],
260
+ outputs=segm_img_e)
261
+
262
+ with gr.Tab("Points mode"):
263
+ # Images
264
+ with gr.Row(variant="panel"):
265
+ with gr.Column(scale=1):
266
+ cond_img_p.render()
267
+
268
+ with gr.Column(scale=1):
269
+ segm_img_p.render()
270
+
271
+ # Submit & Clear
272
+ with gr.Row():
273
+ with gr.Column():
274
+ with gr.Row():
275
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
276
+
277
+ with gr.Column():
278
+ segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
279
+ clear_btn_p = gr.Button("Clear points", variant='secondary')
280
+
281
+ gr.Markdown("Try some of the examples below ⬇️")
282
+ gr.Examples(examples=examples,
283
+ inputs=[cond_img_p],
284
+ # outputs=segm_img_p,
285
+ # fn=segment_with_points,
286
+ # cache_examples=True,
287
+ examples_per_page=4)
288
+
289
+ with gr.Column():
290
+ # Description
291
+ gr.Markdown(description_p)
292
+
293
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
294
+
295
+ segment_btn_p.click(segment_with_points,
296
+ inputs=[cond_img_p],
297
+ outputs=[segm_img_p, cond_img_p])
298
+
299
+ with gr.Tab("Text mode"):
300
+ # Images
301
+ with gr.Row(variant="panel"):
302
+ with gr.Column(scale=1):
303
+ cond_img_t.render()
304
+
305
+ with gr.Column(scale=1):
306
+ segm_img_t.render()
307
+
308
+ # Submit & Clear
309
+ with gr.Row():
310
+ with gr.Column():
311
+ input_size_slider_t = gr.components.Slider(minimum=512,
312
+ maximum=1024,
313
+ value=1024,
314
+ step=64,
315
+ label='Input_size',
316
+ info='Our model was trained on a size of 1024')
317
+ with gr.Row():
318
+ with gr.Column():
319
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
320
+ text_box = gr.Textbox(label="text prompt", value="a black dog")
321
+
322
+ with gr.Column():
323
+ segment_btn_t = gr.Button("Segment with text", variant='primary')
324
+ clear_btn_t = gr.Button("Clear", variant="secondary")
325
+
326
+ gr.Markdown("Try some of the examples below ⬇️")
327
+ gr.Examples(examples=[["examples/dogs.jpg"], ["examples/fruits.jpg"], ["examples/flowers.jpg"]],
328
+ inputs=[cond_img_t],
329
+ # outputs=segm_img_e,
330
+ # fn=segment_everything,
331
+ # cache_examples=True,
332
+ examples_per_page=4)
333
+
334
+ with gr.Column():
335
+ with gr.Accordion("Advanced options", open=False):
336
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
337
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
338
+ with gr.Row():
339
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
340
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
341
+ wider_check = gr.Checkbox(value=False, label='wider', info='wider result')
342
+
343
+ # Description
344
+ gr.Markdown(description_e)
345
+
346
+ segment_btn_t.click(segment_everything,
347
+ inputs=[
348
+ cond_img_t,
349
+ input_size_slider_t,
350
+ iou_threshold,
351
+ conf_threshold,
352
+ mor_check,
353
+ contour_check,
354
+ retina_check,
355
+ text_box,
356
+ wider_check,
357
+ ],
358
+ outputs=segm_img_t)
359
+
360
+ def clear():
361
+ return None, None
362
+
363
+ def clear_text():
364
+ return None, None, None
365
+
366
+ clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
367
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
368
+ clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
369
+
370
+ demo.queue()
371
  demo.launch()