xinghaochen committed on
Commit
5521308
1 Parent(s): d86ac0d

update new app

Files changed (7)
  1. app.py +326 -0
  2. assets/1.jpg +0 -0
  3. assets/2.jpg +0 -0
  4. assets/3.jpg +0 -0
  5. assets/4.jpeg +0 -0
  6. assets/5.jpg +0 -0
  7. assets/6.jpeg +0 -0
app.py ADDED
@@ -0,0 +1,326 @@
+ # Code credit: [EdgeSAM Demo](https://huggingface.co/spaces/chongzhou/EdgeSAM).
+
+ import copy
+
+ import torch
+ import gradio as gr
+ import numpy as np
+ from huggingface_hub import snapshot_download
+ from PIL import ImageDraw
+
+ from tinysam import sam_model_registry, SamPredictor
+ from utils.tools_gradio import fast_process
+
+ # Download the TinySAM checkpoint and build a single shared predictor.
+ snapshot_download("merve/tinysam", local_dir="tinysam")
+
+ model_type = "vit_t"
+ sam = sam_model_registry[model_type](checkpoint="./tinysam/tinysam.pth")
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ sam.to(device=device)
+ sam.eval()
+ predictor = SamPredictor(sam)
+
+ examples = [
+     ["assets/1.jpg"],
+     ["assets/2.jpg"],
+     ["assets/3.jpg"],
+     ["assets/4.jpeg"],
+     ["assets/5.jpg"],
+     ["assets/6.jpeg"],
+ ]
+
+ # Description
+ title = "<center><strong><font size='8'>TinySAM</font></strong> <a href='https://github.com/xinghaochen/TinySAM'><font size='6'>[GitHub]</font></a></center>"
+
+ description_p = """ # Instructions for point mode
+
+ 1. Upload an image or click one of the provided examples.
+ 2. Select the point type.
+ 3. Click once or multiple times on the image to indicate the object of interest.
+ 4. The Clear button clears all the points.
+ 5. The Reset button resets both the points and the image.
+
+ """
+
+ description_b = """ # Instructions for box mode
+
+ 1. Upload an image or click one of the provided examples.
+ 2. Click twice on the image (the two diagonal corners of the box).
+ 3. The Clear button clears the box.
+ 4. The Reset button resets both the box and the image.
+
+ """
+
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
+
+
+ def reset(session_state):
+     session_state['coord_list'] = []
+     session_state['label_list'] = []
+     session_state['box_list'] = []
+     session_state['ori_image'] = None
+     session_state['image_with_prompt'] = None
+     session_state['feature'] = None
+     return None, session_state
+
+
+ def reset_all(session_state):
+     session_state['coord_list'] = []
+     session_state['label_list'] = []
+     session_state['box_list'] = []
+     session_state['ori_image'] = None
+     session_state['image_with_prompt'] = None
+     session_state['feature'] = None
+     return None, None, session_state
+
+
+ def clear(session_state):
+     session_state['coord_list'] = []
+     session_state['label_list'] = []
+     session_state['box_list'] = []
+     session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
+     return session_state['ori_image'], session_state
+
+
+ def on_image_upload(
+     image,
+     session_state,
+     input_size=1024
+ ):
+     session_state['coord_list'] = []
+     session_state['label_list'] = []
+     session_state['box_list'] = []
+
+     # Resize so that the longer side equals input_size, keeping the aspect ratio.
+     input_size = int(input_size)
+     w, h = image.size
+     scale = input_size / max(w, h)
+     new_w = int(w * scale)
+     new_h = int(h * scale)
+     image = image.resize((new_w, new_h))
+     session_state['ori_image'] = copy.deepcopy(image)
+     session_state['image_with_prompt'] = copy.deepcopy(image)
+     print("Image changed")
+     # Compute the image embedding once per upload so the click handlers can
+     # call predictor.predict() directly.
+     nd_image = np.array(image)
+     predictor.set_image(nd_image)
+     session_state['feature'] = None
+
+     return image, session_state
+
+
+ def convert_box(xyxy):
+     # Reorder the two clicked points into (top-left, bottom-right) form.
+     min_x = min(xyxy[0][0], xyxy[1][0])
+     max_x = max(xyxy[0][0], xyxy[1][0])
+     min_y = min(xyxy[0][1], xyxy[1][1])
+     max_y = max(xyxy[0][1], xyxy[1][1])
+     xyxy[0][0] = min_x
+     xyxy[1][0] = max_x
+     xyxy[0][1] = min_y
+     xyxy[1][1] = max_y
+     return xyxy
+
+
+ def segment_with_points(
+     label,
+     session_state,
+     evt: gr.SelectData,
+     input_size=1024,
+     better_quality=False,
+     withContours=True,
+     use_retina=True,
+     mask_random_color=False,
+ ):
+     x, y = evt.index[0], evt.index[1]
+     point_radius, point_color = 5, (97, 217, 54) if label == "Positive" else (237, 34, 13)
+     session_state['coord_list'].append([x, y])
+     session_state['label_list'].append(1 if label == "Positive" else 0)
+
+     print(f"coord_list: {session_state['coord_list']}")
+     print(f"label_list: {session_state['label_list']}")
+
+     # Draw the clicked point on top of the prompt image.
+     draw = ImageDraw.Draw(session_state['image_with_prompt'])
+     draw.ellipse(
+         [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
+         fill=point_color,
+     )
+     image = session_state['image_with_prompt']
+
+     # Predict with all accumulated point prompts and keep the highest-scoring mask.
+     coord_np = np.array(session_state['coord_list'])
+     label_np = np.array(session_state['label_list'])
+     masks, scores, logits = predictor.predict(
+         point_coords=coord_np,
+         point_labels=label_np,
+     )
+     print(f'scores: {scores}')
+     area = masks.sum(axis=(1, 2))
+     print(f'area: {area}')
+
+     annotations = np.expand_dims(masks[scores.argmax()], axis=0)
+
+     seg = fast_process(
+         annotations=annotations,
+         image=image,
+         device=device,
+         scale=(1024 // input_size),
+         better_quality=better_quality,
+         mask_random_color=mask_random_color,
+         bbox=None,
+         use_retina=use_retina,
+         withContours=withContours,
+     )
+
+     return seg, session_state
+
+
+ def segment_with_box(
+     session_state,
+     evt: gr.SelectData,
+     input_size=1024,
+     better_quality=False,
+     withContours=True,
+     use_retina=True,
+     mask_random_color=False,
+ ):
+     x, y = evt.index[0], evt.index[1]
+     point_radius, point_color, box_outline = 5, (97, 217, 54), 5
+     box_color = (0, 255, 0)
+
+     # Collect two clicks per box; a third click starts a new box from scratch.
+     if len(session_state['box_list']) == 0:
+         session_state['box_list'].append([x, y])
+     elif len(session_state['box_list']) == 1:
+         session_state['box_list'].append([x, y])
+     elif len(session_state['box_list']) == 2:
+         session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
+         session_state['box_list'] = [[x, y]]
+
+     print(f"box_list: {session_state['box_list']}")
+
+     draw = ImageDraw.Draw(session_state['image_with_prompt'])
+     draw.ellipse(
+         [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
+         fill=point_color,
+     )
+     image = session_state['image_with_prompt']
+
+     # Once both corners are set, draw the box and run box-prompted prediction.
+     if len(session_state['box_list']) == 2:
+         box = convert_box(session_state['box_list'])
+         xy = (box[0][0], box[0][1], box[1][0], box[1][1])
+         draw.rectangle(
+             xy,
+             outline=box_color,
+             width=box_outline
+         )
+
+         box_np = np.array(xy)
+         masks, scores, _ = predictor.predict(
+             point_coords=None,
+             point_labels=None,
+             box=box_np[None, :],
+         )
+         annotations = np.expand_dims(masks[scores.argmax()], axis=0)
+
+         seg = fast_process(
+             annotations=annotations,
+             image=image,
+             device=device,
+             scale=(1024 // input_size),
+             better_quality=better_quality,
+             mask_random_color=mask_random_color,
+             bbox=None,
+             use_retina=use_retina,
+             withContours=withContours,
+         )
+         return seg, session_state
+     return image, session_state
+
+
+ img_p = gr.Image(label="Input with points", type="pil")
+ img_b = gr.Image(label="Input with box", type="pil")
+
+ with gr.Blocks(css=css, title="TinySAM") as demo:
+     # Per-session prompt state so concurrent users do not overwrite each other.
+     session_state = gr.State({
+         'coord_list': [],
+         'label_list': [],
+         'box_list': [],
+         'ori_image': None,
+         'image_with_prompt': None,
+         'feature': None
+     })
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             # Title
+             gr.Markdown(title)
+
+     with gr.Tab("Point mode") as tab_p:
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 img_p.render()
+             with gr.Column(scale=1):
+                 with gr.Row():
+                     add_or_remove = gr.Radio(
+                         ["Positive", "Negative"],
+                         value="Positive",
+                         label="Point Type"
+                     )
+
+                     with gr.Column():
+                         clear_btn_p = gr.Button("Clear", variant="secondary")
+                         reset_btn_p = gr.Button("Reset", variant="secondary")
+                 with gr.Row():
+                     gr.Markdown(description_p)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(
+                     examples=examples,
+                     inputs=[img_p, session_state],
+                     outputs=[img_p, session_state],
+                     examples_per_page=8,
+                     fn=on_image_upload,
+                     run_on_click=True
+                 )
+
+     with gr.Tab("Box mode") as tab_b:
+         # Images
+         with gr.Row(variant="panel"):
+             with gr.Column(scale=1):
+                 img_b.render()
+             with gr.Row():
+                 with gr.Column():
+                     clear_btn_b = gr.Button("Clear", variant="secondary")
+                     reset_btn_b = gr.Button("Reset", variant="secondary")
+                     gr.Markdown(description_b)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("Try some of the examples below ⬇️")
+                 gr.Examples(
+                     examples=examples,
+                     inputs=[img_b, session_state],
+                     outputs=[img_b, session_state],
+                     examples_per_page=8,
+                     fn=on_image_upload,
+                     run_on_click=True
+                 )
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             gr.Markdown(
+                 "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=chongzhou/edgesam' alt='visitors'></center>")
+
+     # Event wiring: uploads re-embed the image, clicks add prompts and segment.
+     img_p.upload(on_image_upload, [img_p, session_state], [img_p, session_state])
+     img_p.select(segment_with_points, [add_or_remove, session_state], [img_p, session_state])
+
+     clear_btn_p.click(clear, [session_state], [img_p, session_state])
+     reset_btn_p.click(reset, [session_state], [img_p, session_state])
+     tab_p.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
+
+     img_b.upload(on_image_upload, [img_b, session_state], [img_b, session_state])
+     img_b.select(segment_with_box, [session_state], [img_b, session_state])
+
+     clear_btn_b.click(clear, [session_state], [img_b, session_state])
+     reset_btn_b.click(reset, [session_state], [img_b, session_state])
+     tab_b.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
+
+ demo.queue()
+ demo.launch()
assets/1.jpg ADDED
assets/2.jpg ADDED
assets/3.jpg ADDED
assets/4.jpeg ADDED
assets/5.jpg ADDED
assets/6.jpeg ADDED
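
For readers who want to reproduce the prompting flow outside Gradio, here is a minimal sketch of the same point- and box-prompt calls that `app.py` makes, assuming the same `tinysam` package and checkpoint path as the app; the example image path, click coordinates, and box are placeholders.

```python
import numpy as np
import torch
from PIL import Image
from tinysam import sam_model_registry, SamPredictor

# Load TinySAM exactly as app.py does.
sam = sam_model_registry["vit_t"](checkpoint="./tinysam/tinysam.pth")
sam.to("cuda" if torch.cuda.is_available() else "cpu")
sam.eval()
predictor = SamPredictor(sam)

# Embed the image once, then prompt it as many times as needed.
image = np.array(Image.open("assets/1.jpg").convert("RGB"))
predictor.set_image(image)

# Point prompt: label 1 = foreground click, 0 = background click.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[300, 200]]),  # placeholder pixel coordinates
    point_labels=np.array([1]),
)
best_mask = masks[scores.argmax()]  # app.py keeps the highest-scoring candidate

# Box prompt: (x_min, y_min, x_max, y_max), as built by convert_box in app.py.
masks, scores, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=np.array([100, 100, 400, 400])[None, :],  # placeholder box
)
```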