jianyouli committed on
Commit
6d443fe
·
1 Parent(s): 327fcfc

Add application file0

Browse files
Files changed (2) hide show
  1. app.py +329 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import gradio as gr
4
+ import numpy as np
5
+ import torch
6
+ from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
7
+ from PIL import ImageDraw
8
+ from utils.tools import box_prompt, format_results, point_prompt
9
+ from utils.tools_gradio import fast_process
10
+
11
+ # Most of our demo code is from the [FastSAM Demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks to An-619.
12
+
13
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model.
# The checkpoint location can be overridden with the MOBILE_SAM_CHECKPOINT
# environment variable so the app is not tied to a single workstation; the
# previously hard-coded path is kept as the fallback for compatibility.
sam_checkpoint = os.environ.get(
    "MOBILE_SAM_CHECKPOINT",
    r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt",
)
model_type = "vit_t"

mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
mobile_sam = mobile_sam.to(device=device)
mobile_sam.eval()

# Both helpers wrap the same model instance.
mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)
25
+
26
# Description
# HTML heading rendered through gr.Markdown; every opened tag is closed.
title = "<center><strong><font size='8'>Faster Segment Anything(MobileSAM)</font></strong></center>"

# Markdown shown with the (currently disabled) "Everything mode" tab.
description_e = """This is a demo of [Faster Segment Anything(MobileSAM) Model](https://github.com/ChaoningZhang/MobileSAM).

We will provide box mode soon.

Enjoy!

"""

# Markdown shown next to the "Point mode" controls.
description_p = """ # Instructions for point mode

0. Restart by click the Restart button
1. Select a point with Add Mask for the foreground (Must)
2. Select a point with Remove Area for the background (Optional)
3. Click the Start Segmenting.

"""

# One single-element row per example, as passed to gr.Examples.
examples = [
    ["assets/picture3.jpg"],
    ["assets/picture4.jpg"],
    ["assets/picture5.jpg"],
    ["assets/picture6.jpg"],
    ["assets/picture1.jpg"],
    ["assets/picture2.jpg"],
]

# The first example doubles as the initial value of the input images.
default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
58
+
59
+
60
@torch.no_grad()
def segment_everything(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment the whole image and return the rendered overlay.

    The image is resized so that its longer side equals ``input_size``,
    handed to the module-level ``mask_generator`` as an array, and the
    resulting annotations are drawn by ``fast_process``.

    Args:
        image: Image object exposing ``size`` and ``resize``
            (presumably a PIL image -- confirm against the caller).
        input_size: Target length of the longer image side; coerced to int.
        better_quality, withContours, use_retina, mask_random_color:
            Forwarded unchanged to ``fast_process``.

    Returns:
        Whatever ``fast_process`` returns for the resized image.
    """
    target = int(input_size)
    width, height = image.size

    # Scale both sides by the same factor so the aspect ratio is preserved.
    ratio = target / max(width, height)
    resized = image.resize((int(width * ratio), int(height * ratio)))

    masks = mask_generator.generate(np.array(resized))

    return fast_process(
        annotations=masks,
        image=resized,
        device=device,
        scale=(1024 // target),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
93
+
94
+
95
def segment_with_points(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment ``image`` using the point prompts collected so far.

    Reads the module-level ``global_points`` / ``global_point_label`` lists,
    rescales the points to the resized image, runs ``predictor`` on them and
    renders the outcome with ``fast_process``. Both lists are emptied after a
    successful run.

    Args:
        image: Image object exposing ``size`` and ``resize`` (the UI passes
            the "Input with points" image), or ``None`` when that component
            is empty.
        input_size: Target length of the longer image side; coerced to int.
        better_quality, withContours, use_retina, mask_random_color:
            Forwarded unchanged to ``fast_process``.

    Returns:
        A 2-tuple for the two output components:
        ``(None, None)`` when ``image`` is ``None``;
        ``(resized, resized)`` when no points have been selected;
        otherwise ``(rendered_segmentation, resized)``.
    """
    # NOTE(review): the prompt lists are module-level state, so they look to
    # be shared by every session of the app -- confirm whether per-session
    # state is needed.
    global global_points
    global global_point_label

    if image is None:
        # Nothing to segment (e.g. the input was cleared with "Restart").
        # Previously this crashed on ``image.size``.
        return None, None

    input_size = int(input_size)
    w, h = image.size
    # Scale both sides by the same factor so the aspect ratio is preserved.
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    # Points were recorded in original-image pixels; map them to the resized image.
    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    scaled_point_label = np.array(global_point_label)

    if scaled_points.size == 0 and scaled_point_label.size == 0:
        print("No points selected")
        return image, image

    print(scaled_points, scaled_points is not None)
    print(scaled_point_label, scaled_point_label is not None)

    nd_image = np.array(image)
    predictor.set_image(nd_image)
    masks, scores, logits = predictor.predict(
        point_coords=scaled_points,
        point_labels=scaled_point_label,
        multimask_output=True,
    )

    results = format_results(masks, scores, logits, 0)

    annotations, _ = point_prompt(
        results, scaled_points, scaled_point_label, new_h, new_w
    )
    annotations = np.array([annotations])

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # The prompts have been consumed; start fresh for the next segmentation.
    global_points = []
    global_point_label = []
    # return fig, None
    return fig, image
156
+
157
+
158
def get_points_with_draw(image, label, evt: gr.SelectData):
    """Record a clicked point as a prompt and mark it on the image.

    Args:
        image: Image the click happened on; it is drawn on in place.
        label: Selected radio choice; ``"Add Mask"`` marks a foreground
            point (label 1), anything else a background point (label 0).
        evt: Selection event whose ``index`` holds the clicked x and y.

    Returns:
        The same ``image`` object with a filled circle drawn at the click.
    """
    x, y = evt.index[0], evt.index[1]
    is_foreground = label == "Add Mask"

    point_radius = 15
    if is_foreground:
        point_color = (255, 255, 0)
    else:
        point_color = (255, 0, 255)

    # The module-level lists are only mutated, never rebound, here.
    global_points.append([x, y])
    global_point_label.append(1 if is_foreground else 0)

    print(x, y, is_foreground)

    # Create an object that can draw on the image.
    top_left = (x - point_radius, y - point_radius)
    bottom_right = (x + point_radius, y + point_radius)
    draw = ImageDraw.Draw(image)
    draw.ellipse([top_left, bottom_right], fill=point_color)
    return image
180
+
181
+
182
# Components are created up front and placed into the layout with .render().
cond_img_e = gr.Image(label="Input", value=default_example[0], type="pil")
cond_img_p = gr.Image(label="Input with points", value=default_example[0], type="pil")

segm_img_e = gr.Image(label="Segmented Image", interactive=False, type="pil")
segm_img_p = gr.Image(
    label="Segmented Image with points", interactive=False, type="pil"
)

# Point prompts collected from clicks on the input image (module-level state).
global_points = []
global_point_label = []

input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label="Input_size",
    info="Our model was trained on a size of 1024",
)

# NOTE(review): the nesting of the layout below was reconstructed from a
# whitespace-stripped listing -- confirm it against the running UI.
with gr.Blocks(css=css, title="Faster Segment Anything(MobileSAM)") as demo:
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    # with gr.Tab("Everything mode"):
    #     # Images
    #     with gr.Row(variant="panel"):
    #         with gr.Column(scale=1):
    #             cond_img_e.render()
    #
    #         with gr.Column(scale=1):
    #             segm_img_e.render()
    #
    #     # Submit & Clear
    #     with gr.Row():
    #         with gr.Column():
    #             input_size_slider.render()
    #
    #             with gr.Row():
    #                 contour_check = gr.Checkbox(
    #                     value=True,
    #                     label="withContours",
    #                     info="draw the edges of the masks",
    #                 )
    #
    #                 with gr.Column():
    #                     segment_btn_e = gr.Button(
    #                         "Segment Everything", variant="primary"
    #                     )
    #                     clear_btn_e = gr.Button("Clear", variant="secondary")
    #
    #             gr.Markdown("Try some of the examples below ⬇️")
    #             gr.Examples(
    #                 examples=examples,
    #                 inputs=[cond_img_e],
    #                 outputs=segm_img_e,
    #                 fn=segment_everything,
    #                 cache_examples=True,
    #                 examples_per_page=4,
    #             )
    #
    #         with gr.Column():
    #             with gr.Accordion("Advanced options", open=False):
    #                 # text_box = gr.Textbox(label="text prompt")
    #                 with gr.Row():
    #                     mor_check = gr.Checkbox(
    #                         value=False,
    #                         label="better_visual_quality",
    #                         info="better quality using morphologyEx",
    #                     )
    #                     with gr.Column():
    #                         retina_check = gr.Checkbox(
    #                             value=True,
    #                             label="use_retina",
    #                             info="draw high-resolution segmentation masks",
    #                         )
    #             # Description
    #             gr.Markdown(description_e)
    #
    with gr.Tab("Point mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    add_or_remove = gr.Radio(
                        ["Add Mask", "Remove Area"],
                        value="Add Mask",
                    )

                    with gr.Column():
                        segment_btn_p = gr.Button(
                            "Start segmenting!", variant="primary"
                        )
                        clear_btn_p = gr.Button("Restart", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[cond_img_p],
                    # outputs=segm_img_p,
                    # fn=segment_with_points,
                    # cache_examples=True,
                    examples_per_page=4,
                )

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    # Each click on the input image records a prompt and draws its marker.
    cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)

    # segment_btn_e.click(
    #     segment_everything,
    #     inputs=[
    #         cond_img_e,
    #         input_size_slider,
    #         mor_check,
    #         contour_check,
    #         retina_check,
    #     ],
    #     outputs=segm_img_e,
    # )

    segment_btn_p.click(
        segment_with_points, inputs=[cond_img_p], outputs=[segm_img_p, cond_img_p]
    )

    def clear():
        """Reset the point prompts and blank both image components."""
        # Restart must also drop the recorded prompts; otherwise clicks made
        # on the previous image would be applied to the next one.
        global global_points
        global global_point_label
        global_points = []
        global_point_label = []
        return None, None

    def clear_text():
        """Return three ``None`` values; not wired to any event here."""
        return None, None, None

    # clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])

demo.queue()
demo.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ opencv-python
5
+ git+https://github.com/dhkim2810/MobileSAM.git