AAAAAAAyq committed on
Commit eb724c1
2 Parent(s): 99c2e65 a5288e6

Merge branch 'main' of https://huggingface.co/spaces/An-619/FastSAM into main

Files changed (2)
  1. app.py +288 -288
  2. app_debug.py +0 -287
app.py CHANGED
@@ -1,288 +1,288 @@
from ultralytics import YOLO
import gradio as gr
import torch
from tools import fast_process, format_results, box_prompt, point_prompt
from PIL import ImageDraw
import numpy as np

# Load the pre-trained model
model = YOLO('checkpoints/FastSAM.pt')

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Description
title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"

news = """ # 📖 News

🔥 2023/06/24: Added the 'Advanced options' in Everything mode for more detailed adjustment.

🔥 2023/06/26: Added support for points mode. (Better and faster interaction will come soon!)

"""

description_e = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Feel free to give it a star ⭐️.

🎯 Upload an image and segment it with Fast Segment Anything (Everything mode). The other modes will come soon.

⌛️ It takes about 6 seconds to generate segmentation results. The queue's concurrency_count is 1, so please wait a moment when it is crowded.

🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.

📣 You can also obtain the segmentation results of any image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)

😚 A huge thanks goes out to the @HuggingFace team for supporting us with a GPU grant.

🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)

"""

description_p = """ # 🎯 Instructions for points mode
This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Feel free to give it a star ⭐️.

1. Upload an image or choose an example.

2. Choose the point label ('Add Mask' adds a positive point; 'Remove Area' adds a negative point that is excluded from the segmentation).

3. Add points one by one on the image.

4. Click the 'Segment with points prompt' button to get the segmentation results.

**5. If you get an error, clicking the 'Clear points' button and trying again may help.**

"""

examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"], ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
            ["assets/sa_561.jpg"], ["assets/sa_192.jpg"], ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]

default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"


def segment_everything(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    input_size = int(input_size)  # make sure imgsz is an integer

    # Resize so the longer side equals input_size, keeping the aspect ratio.
    # Thanks to hysts at HuggingFace for the suggestion.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    results = model(input,
                    device=device,
                    retina_masks=True,
                    iou=iou_threshold,
                    conf=conf_threshold,
                    imgsz=input_size,)

    fig = fast_process(annotations=results[0].masks.data,
                       image=input,
                       device=device,
                       scale=(1024 // input_size),
                       better_quality=better_quality,
                       mask_random_color=mask_random_color,
                       bbox=None,
                       use_retina=use_retina,
                       withContours=withContours,)
    return fig


def segment_with_points(
    input,
    input_size=1024,
    iou_threshold=0.7,
    conf_threshold=0.25,
    better_quality=False,
    withContours=True,
    mask_random_color=True,
    use_retina=True,
):
    global global_points
    global global_point_label

    input_size = int(input_size)  # make sure imgsz is an integer
    # Thanks to hysts at HuggingFace for the suggestion.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    # Rescale the clicked points to the resized image.
    scaled_points = [[int(x * scale) for x in point] for point in global_points]

    results = model(input,
                    device=device,
                    retina_masks=True,
                    iou=iou_threshold,
                    conf=conf_threshold,
                    imgsz=input_size,)

    results = format_results(results[0], 0)

    annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
    annotations = np.array([annotations])

    fig = fast_process(annotations=annotations,
                       image=input,
                       device=device,
                       scale=(1024 // input_size),
                       better_quality=better_quality,
                       mask_random_color=mask_random_color,
                       bbox=None,
                       use_retina=use_retina,
                       withContours=withContours,)
    global_points = []
    global_point_label = []
    return fig, None


def get_points_with_draw(image, label, evt: gr.SelectData):
    x, y = evt.index[0], evt.index[1]
    point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
    global global_points
    global global_point_label
    print((x, y))
    global_points.append([x, y])
    global_point_label.append(1 if label == 'Add Mask' else 0)

    # Create an object that can draw on the image
    draw = ImageDraw.Draw(image)
    draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
    return image

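# Points-mode flow, for reference: each click on cond_img_p (defined below)
# fires get_points_with_draw, which records the pixel coordinate in
# global_points and its label (1 for 'Add Mask', 0 for 'Remove Area') in
# global_point_label; segment_with_points then rescales those points to the
# resized image, hands them to point_prompt, and clears both lists for the
# next image.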

# input_size=1024
# high_quality_visual=True
# inp = 'assets/sa_192.jpg'
# input = Image.open(inp)
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# input_size = int(input_size)  # make sure imgsz is an integer
# results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
# pil_image = fast_process(annotations=results[0].masks.data,
#                          image=input, high_quality=high_quality_visual, device=device)

cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')

segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')

global_points = []
global_point_label = []  # TODO: clear the points for each new image

input_size_slider = gr.components.Slider(minimum=512,
                                         maximum=1024,
                                         value=1024,
                                         step=64,
                                         label='Input_size',
                                         info='Our model was trained on a size of 1024')

with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

        with gr.Column(scale=1):
            # News
            gr.Markdown(news)

    with gr.Tab("Everything mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_e.render()

            with gr.Column(scale=1):
                segm_img_e.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                input_size_slider.render()

                with gr.Row():
                    contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')

                    with gr.Column():
                        segment_btn_e = gr.Button("Segment Everything", variant='primary')
                        clear_btn_e = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(examples=examples,
                            inputs=[cond_img_e],
                            outputs=segm_img_e,
                            fn=segment_everything,
                            cache_examples=True,
                            examples_per_page=4)

            with gr.Column():
                with gr.Accordion("Advanced options", open=False):
                    iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
                    conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
                    with gr.Row():
                        mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
                        with gr.Column():
                            retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')

                # Description
                gr.Markdown(description_e)

    with gr.Tab("Points mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")

                    with gr.Column():
                        segment_btn_p = gr.Button("Segment with points prompt", variant='primary')
                        clear_btn_p = gr.Button("Clear points", variant='secondary')

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(examples=examples,
                            inputs=[cond_img_p],
                            outputs=segm_img_p,
                            fn=segment_with_points,
                            # cache_examples=True,
                            examples_per_page=4)

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

    segment_btn_e.click(segment_everything,
                        inputs=[cond_img_e, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check, retina_check],
                        outputs=segm_img_e)

    segment_btn_p.click(segment_with_points,
                        inputs=[cond_img_p],
                        outputs=[segm_img_p, cond_img_p])

    def clear():
        return None, None

    clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])

demo.queue()
demo.launch()
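
For reference, the entry point the buttons call can also be exercised without the UI. A minimal headless sketch, assuming the checkpoint and the bundled assets are available, that the definitions in app.py are loaded in the current session, and that tools.fast_process returns a PIL image (as the debug variant below does):

from PIL import Image

img = Image.open('assets/sa_8776.jpg')  # one of the bundled examples
# Smaller input_size trades accuracy for speed, as the description notes.
overlay = segment_everything(img, input_size=512)
overlay.save('segmented_preview.png')  # assumes fast_process returned a PIL image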
 
app_debug.py DELETED
@@ -1,287 +0,0 @@
from ultralytics import YOLO
import numpy as np
import matplotlib.pyplot as plt
import gradio as gr
import cv2
import torch
from PIL import Image

# Load the pre-trained model
model = YOLO('checkpoints/FastSAM.pt')

# Description
title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"

description = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).

🎯 Upload an image and segment it with Fast Segment Anything (Everything mode). The other modes will come soon.

⌛️ It takes about 4 seconds to generate segmentation results. The queue's concurrency_count is 1, so please wait a moment when it is crowded.

🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.

📣 You can also obtain the segmentation results of any image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)

😚 A huge thanks goes out to the @HuggingFace team for supporting us with a GPU grant.

🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)

"""

examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"],
            ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
            ["assets/sa_561.jpg"], ["assets/sa_192.jpg"],
            ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]

default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"

def fast_process(annotations, image, high_quality, device, scale):
    if isinstance(annotations[0], dict):
        annotations = [annotation['segmentation'] for annotation in annotations]

    original_h = image.height
    original_w = image.width
    if high_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        for i, mask in enumerate(annotations):
            # Smooth each mask with morphological closing, then opening.
            mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
            annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
    if device == 'cpu':
        annotations = np.array(annotations)
        inner_mask = fast_show_mask(annotations,
                                    plt.gca(),
                                    bbox=None,
                                    points=None,
                                    pointlabel=None,
                                    retinamask=True,
                                    target_height=original_h,
                                    target_width=original_w)
    else:
        if isinstance(annotations[0], np.ndarray):
            annotations = torch.from_numpy(annotations)
        inner_mask = fast_show_mask_gpu(annotations,
                                        plt.gca(),
                                        bbox=None,
                                        points=None,
                                        pointlabel=None)
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if high_quality:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for i, mask in enumerate(annotations):
            if type(mask) == dict:
                mask = mask['segmentation']
            annotation = mask.astype(np.uint8)
            contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
            for contour in contours:
                contour_all.append(contour)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)
    image = image.convert('RGBA')

    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if high_quality:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image

# CPU post process
def fast_show_mask(annotation, ax, bbox=None,
                   points=None, pointlabel=None,
                   retinamask=True, target_height=960,
                   target_width=960):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    width = annotation.shape[2]
    # Sort the annotations by area
    areas = np.sum(annotation, axis=(1, 2))
    sorted_indices = np.argsort(areas)
    annotation = annotation[sorted_indices]

    # For every pixel, the index of the first mask that covers it
    index = (annotation != 0).argmax(axis=0)
    color = np.random.random((mask_sum, 1, 1, 3))
    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    mask = np.zeros((height, width, 4))

    h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # Use vectorized indexing to fill in the values to show
    mask[h_indices, w_indices, :] = mask_image[indices]
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
    # Draw points
    if points is not None:
        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                    s=20, c='y')
        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                    s=20, c='m')

    if not retinamask:
        mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)

    return mask


def fast_show_mask_gpu(annotation, ax,
                       bbox=None, points=None,
                       pointlabel=None):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    width = annotation.shape[2]
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]
    # Find the index of the first non-zero value at each position
    index = (annotation != 0).to(torch.long).argmax(dim=0)
    color = torch.rand((mask_sum, 1, 1, 3)).to(annotation.device)
    transparency = torch.ones((mask_sum, 1, 1, 1)).to(annotation.device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual
    # Gather by index: for each position, index says which mask to take the
    # value from, collapsing mask_image into a single image.
    mask = torch.zeros((height, width, 4)).to(annotation.device)
    h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(width), indexing='ij')
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # Use vectorized indexing to fill in the values to show
    mask[h_indices, w_indices, :] = mask_image[indices]
    mask_cpu = mask.cpu().numpy()
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
    # Draw points
    if points is not None:
        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
                    s=20, c='y')
        plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
                    [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
                    s=20, c='m')
    return mask_cpu


device = 'cuda' if torch.cuda.is_available() else 'cpu'

def segment_image(input, evt: gr.SelectData = None, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
    if evt is not None:
        point = (evt.index[0], evt.index[1])  # the clicked point (not used yet)
    input_size = int(input_size)  # make sure imgsz is an integer

    # Thanks to hysts at HuggingFace for the suggestion.
    w, h = input.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    input = input.resize((new_w, new_h))

    results = model(input, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
    fig = fast_process(annotations=results[0].masks.data,
                       image=input, high_quality=high_visual_quality,
                       device=device, scale=(1024 // input_size))
    return fig

# input_size=1024
# high_quality_visual=True
# inp = 'assets/sa_192.jpg'
# input = Image.open(inp)
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# input_size = int(input_size)  # make sure imgsz is an integer
# results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
# pil_image = fast_process(annotations=results[0].masks.data,
#                          image=input, high_quality=high_quality_visual, device=device)

cond_img = gr.Image(label="Input", value=default_example[0], type='pil')

segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')

input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='Input_size (Our model was trained on a size of 1024)')

with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
    with gr.Row():
        # Title
        gr.Markdown(title)
        # # Description
        # gr.Markdown(description)

    # Images
    with gr.Row(variant="panel"):
        with gr.Column(scale=1):
            cond_img.render()

        with gr.Column(scale=1):
            segm_img.render()

    # Submit & Clear
    with gr.Row():
        with gr.Column():
            input_size_slider.render()

            with gr.Row():
                vis_check = gr.Checkbox(value=True, label='high_visual_quality')

                with gr.Column():
                    segment_btn = gr.Button("Segment Anything", variant='primary')

            # with gr.Column():
            #     clear_btn = gr.Button("Clear", variant="primary")

            gr.Markdown("Try some of the examples below ⬇️")
            gr.Examples(examples=examples,
                        inputs=[cond_img],
                        outputs=segm_img,
                        fn=segment_image,
                        cache_examples=True,
                        examples_per_page=4)
            # gr.Markdown("Try some of the examples below ⬇️")
            # gr.Examples(examples=examples,
            #             inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
            #             outputs=output,
            #             fn=segment_image,
            #             examples_per_page=4)

        with gr.Column():
            with gr.Accordion("Advanced options", open=False):
                iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
                conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')

            # Description
            gr.Markdown(description)

    # Re-run segmentation when the input image is clicked
    cond_img.select(segment_image, [cond_img], segm_img)

    segment_btn.click(segment_image,
                      inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
                      outputs=segm_img)

    # def clear():
    #     return None, None

    # clear_btn.click(fn=clear, inputs=None, outputs=None)

demo.queue()
demo.launch()

# app_interface = gr.Interface(fn=predict,
#                              inputs=[gr.Image(type='pil'),
#                                      gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
#                                      gr.components.Checkbox(value=True, label='high_visual_quality')],
#                              # outputs=['plot'],
#                              outputs=gr.Image(type='pil'),
#                              # examples=[["assets/sa_8776.jpg"]],
#                              # #          ["assets/sa_1309.jpg", 1024]],
#                              examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
#                                        ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
#                                        ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
#                                        ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"]],
#                              cache_examples=True,
#                              title="Fast Segment Anything (Everything mode)"
#                              )


# app_interface.queue(concurrency_count=1, max_size=20)
# app_interface.launch()
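
Both fast_show_mask variants above share one non-obvious step: after sorting the masks by area, (annotation != 0).argmax(axis=0) picks, for every pixel, the index of the first (smallest) mask that covers it, and the meshgrid-based fancy indexing then gathers that mask's RGBA value for every pixel in one vectorized operation. A standalone NumPy sketch of that gather, using hypothetical toy masks:

import numpy as np

# Two toy 2x2 masks, already sorted so the smaller mask comes first.
annotation = np.array([[[1, 0],
                        [0, 0]],   # small mask covers only pixel (0, 0)
                       [[1, 1],
                        [1, 0]]])  # larger mask covers three pixels
mask_sum, height, width = annotation.shape

# Per pixel: index of the first mask that is non-zero there
# (argmax over the mask axis returns the first True).
index = (annotation != 0).argmax(axis=0)

# One random RGBA colour per mask, broadcast over the pixel grid.
visual = np.random.random((mask_sum, 1, 1, 4))
mask_image = np.expand_dims(annotation, -1) * visual

# Gather each pixel's value from its selected mask in one vectorized step.
h_idx, w_idx = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
out = mask_image[index[h_idx, w_idx], h_idx, w_idx, :]
print(out.shape)  # (2, 2, 4): an RGBA image with the topmost mask per pixel

Pixels covered by no mask fall back to index 0, but since annotation is zero there, the gathered value is zero (fully transparent), which is exactly the behaviour the overlay rendering depends on.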