TheFrenchLeaf commited on
Commit
7342297
1 Parent(s): eb724c1
Files changed (1) hide show
  1. app.py +288 -288
app.py CHANGED
@@ -1,288 +1,288 @@
1
- from ultralytics import YOLO
2
- import gradio as gr
3
- import torch
4
- from tools import fast_process, format_results, box_prompt, point_prompt
5
- from PIL import ImageDraw
6
- import numpy as np
7
-
8
- # Load the pre-trained model
9
- model = YOLO('checkpoints/FastSAM.pt')
10
-
11
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
-
13
- # Description
14
- title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
15
-
16
- news = """ # 📖 News
17
-
18
- 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
19
-
20
- 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
21
-
22
- """
23
-
24
- description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
25
-
26
- 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
27
-
28
- ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.
29
-
30
- 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
31
-
32
- 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
33
-
34
- 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant.
35
-
36
- 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
37
-
38
- """
39
-
40
- description_p = """ # 🎯 Instructions for points mode
41
- This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
42
-
43
- 1. Upload an image or choose an example.
44
-
45
- 2. Choose the point label ('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented).
46
-
47
- 3. Add points one by one on the image.
48
-
49
- 4. Click the 'Segemnt with points prompt' button to get the segmentation results.
50
-
51
- **5. If you get Error, click the 'Clear points' button and try again may help.**
52
-
53
- """
54
-
55
- examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"], ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
56
- ["assets/sa_561.jpg"], ["assets/sa_192.jpg"], ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
57
-
58
- default_example = examples[0]
59
-
60
- css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
61
-
62
-
63
- def segment_everything(
64
- input,
65
- input_size=1024,
66
- iou_threshold=0.7,
67
- conf_threshold=0.25,
68
- better_quality=False,
69
- withContours=True,
70
- use_retina=True,
71
- mask_random_color=True,
72
- ):
73
- input_size = int(input_size) # 确保 imgsz 是整数
74
-
75
- # Thanks for the suggestion by hysts in HuggingFace.
76
- w, h = input.size
77
- scale = input_size / max(w, h)
78
- new_w = int(w * scale)
79
- new_h = int(h * scale)
80
- input = input.resize((new_w, new_h))
81
-
82
- results = model(input,
83
- device=device,
84
- retina_masks=True,
85
- iou=iou_threshold,
86
- conf=conf_threshold,
87
- imgsz=input_size,)
88
-
89
- fig = fast_process(annotations=results[0].masks.data,
90
- image=input,
91
- device=device,
92
- scale=(1024 // input_size),
93
- better_quality=better_quality,
94
- mask_random_color=mask_random_color,
95
- bbox=None,
96
- use_retina=use_retina,
97
- withContours=withContours,)
98
- return fig
99
-
100
- def segment_with_points(
101
- input,
102
- input_size=1024,
103
- iou_threshold=0.7,
104
- conf_threshold=0.25,
105
- better_quality=False,
106
- withContours=True,
107
- mask_random_color=True,
108
- use_retina=True,
109
- ):
110
- global global_points
111
- global global_point_label
112
-
113
- input_size = int(input_size) # 确保 imgsz 是整数
114
- # Thanks for the suggestion by hysts in HuggingFace.
115
- w, h = input.size
116
- scale = input_size / max(w, h)
117
- new_w = int(w * scale)
118
- new_h = int(h * scale)
119
- input = input.resize((new_w, new_h))
120
-
121
- scaled_points = [[int(x * scale) for x in point] for point in global_points]
122
-
123
- results = model(input,
124
- device=device,
125
- retina_masks=True,
126
- iou=iou_threshold,
127
- conf=conf_threshold,
128
- imgsz=input_size,)
129
-
130
- results = format_results(results[0], 0)
131
-
132
- annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
133
- annotations = np.array([annotations])
134
-
135
- fig = fast_process(annotations=annotations,
136
- image=input,
137
- device=device,
138
- scale=(1024 // input_size),
139
- better_quality=better_quality,
140
- mask_random_color=mask_random_color,
141
- bbox=None,
142
- use_retina=use_retina,
143
- withContours=withContours,)
144
- global_points = []
145
- global_point_label = []
146
- return fig, None
147
-
148
- def get_points_with_draw(image, label, evt: gr.SelectData):
149
- x, y = evt.index[0], evt.index[1]
150
- point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
151
- global global_points
152
- global global_point_label
153
- print((x, y))
154
- global_points.append([x, y])
155
- global_point_label.append(1 if label == 'Add Mask' else 0)
156
-
157
- # 创建一个可以在图像上绘图的对象
158
- draw = ImageDraw.Draw(image)
159
- draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
160
- return image
161
-
162
-
163
- # input_size=1024
164
- # high_quality_visual=True
165
- # inp = 'assets/sa_192.jpg'
166
- # input = Image.open(inp)
167
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
168
- # input_size = int(input_size) # 确保 imgsz 是整数
169
- # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
170
- # pil_image = fast_process(annotations=results[0].masks.data,
171
- # image=input, high_quality=high_quality_visual, device=device)
172
-
173
- cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
174
- cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
175
-
176
- segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
177
- segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
178
-
179
- global_points = []
180
- global_point_label = [] # TODO:Clear points each image
181
-
182
- input_size_slider = gr.components.Slider(minimum=512,
183
- maximum=1024,
184
- value=1024,
185
- step=64,
186
- label='Input_size',
187
- info='Our model was trained on a size of 1024')
188
-
189
- with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
190
- with gr.Row():
191
- with gr.Column(scale=1):
192
- # Title
193
- gr.Markdown(title)
194
-
195
- with gr.Column(scale=1):
196
- # News
197
- gr.Markdown(news)
198
-
199
- with gr.Tab("Everything mode"):
200
- # Images
201
- with gr.Row(variant="panel"):
202
- with gr.Column(scale=1):
203
- cond_img_e.render()
204
-
205
- with gr.Column(scale=1):
206
- segm_img_e.render()
207
-
208
- # Submit & Clear
209
- with gr.Row():
210
- with gr.Column():
211
- input_size_slider.render()
212
-
213
- with gr.Row():
214
- contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
215
-
216
- with gr.Column():
217
- segment_btn_e = gr.Button("Segment Everything", variant='primary')
218
- clear_btn_e = gr.Button("Clear", variant="secondary")
219
-
220
- gr.Markdown("Try some of the examples below ⬇️")
221
- gr.Examples(examples=examples,
222
- inputs=[cond_img_e],
223
- outputs=segm_img_e,
224
- fn=segment_everything,
225
- cache_examples=True,
226
- examples_per_page=4)
227
-
228
- with gr.Column():
229
- with gr.Accordion("Advanced options", open=False):
230
- iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
231
- conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
232
- with gr.Row():
233
- mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
234
- with gr.Column():
235
- retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
236
-
237
- # Description
238
- gr.Markdown(description_e)
239
-
240
- with gr.Tab("Points mode"):
241
- # Images
242
- with gr.Row(variant="panel"):
243
- with gr.Column(scale=1):
244
- cond_img_p.render()
245
-
246
- with gr.Column(scale=1):
247
- segm_img_p.render()
248
-
249
- # Submit & Clear
250
- with gr.Row():
251
- with gr.Column():
252
- with gr.Row():
253
- add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
254
-
255
- with gr.Column():
256
- segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
257
- clear_btn_p = gr.Button("Clear points", variant='secondary')
258
-
259
- gr.Markdown("Try some of the examples below ⬇️")
260
- gr.Examples(examples=examples,
261
- inputs=[cond_img_p],
262
- outputs=segm_img_p,
263
- fn=segment_with_points,
264
- # cache_examples=True,
265
- examples_per_page=4)
266
-
267
- with gr.Column():
268
- # Description
269
- gr.Markdown(description_p)
270
-
271
- cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
272
-
273
- segment_btn_e.click(segment_everything,
274
- inputs=[cond_img_e, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check, retina_check],
275
- outputs=segm_img_e)
276
-
277
- segment_btn_p.click(segment_with_points,
278
- inputs=[cond_img_p],
279
- outputs=[segm_img_p, cond_img_p])
280
-
281
- def clear():
282
- return None, None
283
-
284
- clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
285
- clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
286
-
287
- demo.queue()
288
- demo.launch()
 
1
+ from ultralytics import YOLO
2
+ import gradio as gr
3
+ import torch
4
+ from tools import fast_process, format_results, box_prompt, point_prompt
5
+ from PIL import ImageDraw
6
+ import numpy as np
7
+
8
+ # Load the pre-trained model
9
+ model = YOLO('checkpoints/FastSAM.pt')
10
+
11
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
12
+
13
+ # Description
14
+ title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
15
+
16
+ news = """ # 📖 News
17
+
18
+ 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
19
+
20
+ 🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
21
+
22
+ """
23
+
24
+ description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
25
+
26
+ 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
27
+
28
+ ⌛️ It takes about 6~ seconds to generate segment results. The concurrency_count of queue is 1, please wait for a moment when it is crowded.
29
+
30
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
31
+
32
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
33
+
34
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with GPU grant.
35
+
36
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
37
+
38
+ """
39
+
40
+ description_p = """ # 🎯 Instructions for points mode
41
+ This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
42
+
43
+ 1. Upload an image or choose an example.
44
+
45
+ 2. Choose the point label ('Add mask' means a positive point. 'Remove' Area means a negative point that is not segmented).
46
+
47
+ 3. Add points one by one on the image.
48
+
49
+ 4. Click the 'Segment with points prompt' button to get the segmentation results.
50
+
51
+ **5. If you get Error, click the 'Clear points' button and try again may help.**
52
+
53
+ """
54
+
55
+ examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"], ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
56
+ ["assets/sa_561.jpg"], ["assets/sa_192.jpg"], ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
57
+
58
+ default_example = examples[0]
59
+
60
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
61
+
62
+
63
+ def segment_everything(
64
+ input,
65
+ input_size=1024,
66
+ iou_threshold=0.7,
67
+ conf_threshold=0.25,
68
+ better_quality=False,
69
+ withContours=True,
70
+ use_retina=True,
71
+ mask_random_color=True,
72
+ ):
73
+ input_size = int(input_size) # 确保 imgsz 是整数
74
+
75
+ # Thanks for the suggestion by hysts in HuggingFace.
76
+ w, h = input.size
77
+ scale = input_size / max(w, h)
78
+ new_w = int(w * scale)
79
+ new_h = int(h * scale)
80
+ input = input.resize((new_w, new_h))
81
+
82
+ results = model(input,
83
+ device=device,
84
+ retina_masks=True,
85
+ iou=iou_threshold,
86
+ conf=conf_threshold,
87
+ imgsz=input_size,)
88
+
89
+ fig = fast_process(annotations=results[0].masks.data,
90
+ image=input,
91
+ device=device,
92
+ scale=(1024 // input_size),
93
+ better_quality=better_quality,
94
+ mask_random_color=mask_random_color,
95
+ bbox=None,
96
+ use_retina=use_retina,
97
+ withContours=withContours,)
98
+ return fig
99
+
100
+ def segment_with_points(
101
+ input,
102
+ input_size=1024,
103
+ iou_threshold=0.7,
104
+ conf_threshold=0.25,
105
+ better_quality=False,
106
+ withContours=True,
107
+ mask_random_color=True,
108
+ use_retina=True,
109
+ ):
110
+ global global_points
111
+ global global_point_label
112
+
113
+ input_size = int(input_size) # 确保 imgsz 是整数
114
+ # Thanks for the suggestion by hysts in HuggingFace.
115
+ w, h = input.size
116
+ scale = input_size / max(w, h)
117
+ new_w = int(w * scale)
118
+ new_h = int(h * scale)
119
+ input = input.resize((new_w, new_h))
120
+
121
+ scaled_points = [[int(x * scale) for x in point] for point in global_points]
122
+
123
+ results = model(input,
124
+ device=device,
125
+ retina_masks=True,
126
+ iou=iou_threshold,
127
+ conf=conf_threshold,
128
+ imgsz=input_size,)
129
+
130
+ results = format_results(results[0], 0)
131
+
132
+ annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
133
+ annotations = np.array([annotations])
134
+
135
+ fig = fast_process(annotations=annotations,
136
+ image=input,
137
+ device=device,
138
+ scale=(1024 // input_size),
139
+ better_quality=better_quality,
140
+ mask_random_color=mask_random_color,
141
+ bbox=None,
142
+ use_retina=use_retina,
143
+ withContours=withContours,)
144
+ global_points = []
145
+ global_point_label = []
146
+ return fig, None
147
+
148
+ def get_points_with_draw(image, label, evt: gr.SelectData):
149
+ x, y = evt.index[0], evt.index[1]
150
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
151
+ global global_points
152
+ global global_point_label
153
+ print((x, y))
154
+ global_points.append([x, y])
155
+ global_point_label.append(1 if label == 'Add Mask' else 0)
156
+
157
+ # 创建一个可以在图像上绘图的对象
158
+ draw = ImageDraw.Draw(image)
159
+ draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
160
+ return image
161
+
162
+
163
+ # input_size=1024
164
+ # high_quality_visual=True
165
+ # inp = 'assets/sa_192.jpg'
166
+ # input = Image.open(inp)
167
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
168
+ # input_size = int(input_size) # 确保 imgsz 是整数
169
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
170
+ # pil_image = fast_process(annotations=results[0].masks.data,
171
+ # image=input, high_quality=high_quality_visual, device=device)
172
+
173
+ cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
174
+ cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
175
+
176
+ segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
177
+ segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
178
+
179
+ global_points = []
180
+ global_point_label = [] # TODO:Clear points each image
181
+
182
+ input_size_slider = gr.components.Slider(minimum=512,
183
+ maximum=1024,
184
+ value=1024,
185
+ step=64,
186
+ label='Input_size',
187
+ info='Our model was trained on a size of 1024')
188
+
189
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
190
+ with gr.Row():
191
+ with gr.Column(scale=1):
192
+ # Title
193
+ gr.Markdown(title)
194
+
195
+ with gr.Column(scale=1):
196
+ # News
197
+ gr.Markdown(news)
198
+
199
+ with gr.Tab("Everything mode"):
200
+ # Images
201
+ with gr.Row(variant="panel"):
202
+ with gr.Column(scale=1):
203
+ cond_img_e.render()
204
+
205
+ with gr.Column(scale=1):
206
+ segm_img_e.render()
207
+
208
+ # Submit & Clear
209
+ with gr.Row():
210
+ with gr.Column():
211
+ input_size_slider.render()
212
+
213
+ with gr.Row():
214
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
215
+
216
+ with gr.Column():
217
+ segment_btn_e = gr.Button("Segment Everything", variant='primary')
218
+ clear_btn_e = gr.Button("Clear", variant="secondary")
219
+
220
+ gr.Markdown("Try some of the examples below ⬇️")
221
+ gr.Examples(examples=examples,
222
+ inputs=[cond_img_e],
223
+ outputs=segm_img_e,
224
+ fn=segment_everything,
225
+ cache_examples=True,
226
+ examples_per_page=4)
227
+
228
+ with gr.Column():
229
+ with gr.Accordion("Advanced options", open=False):
230
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
231
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
232
+ with gr.Row():
233
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
234
+ with gr.Column():
235
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
236
+
237
+ # Description
238
+ gr.Markdown(description_e)
239
+
240
+ with gr.Tab("Points mode"):
241
+ # Images
242
+ with gr.Row(variant="panel"):
243
+ with gr.Column(scale=1):
244
+ cond_img_p.render()
245
+
246
+ with gr.Column(scale=1):
247
+ segm_img_p.render()
248
+
249
+ # Submit & Clear
250
+ with gr.Row():
251
+ with gr.Column():
252
+ with gr.Row():
253
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
254
+
255
+ with gr.Column():
256
+ segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
257
+ clear_btn_p = gr.Button("Clear points", variant='secondary')
258
+
259
+ gr.Markdown("Try some of the examples below ⬇️")
260
+ gr.Examples(examples=examples,
261
+ inputs=[cond_img_p],
262
+ outputs=segm_img_p,
263
+ fn=segment_with_points,
264
+ # cache_examples=True,
265
+ examples_per_page=4)
266
+
267
+ with gr.Column():
268
+ # Description
269
+ gr.Markdown(description_p)
270
+
271
+ cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
272
+
273
+ segment_btn_e.click(segment_everything,
274
+ inputs=[cond_img_e, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check, retina_check],
275
+ outputs=segm_img_e)
276
+
277
+ segment_btn_p.click(segment_with_points,
278
+ inputs=[cond_img_p],
279
+ outputs=[segm_img_p, cond_img_p])
280
+
281
+ def clear():
282
+ return None, None
283
+
284
+ clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
285
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
286
+
287
+ demo.queue()
288
+ demo.launch()