xinghaochen committed on
Commit
7d0a798
1 Parent(s): 507e5e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +310 -213
app.py CHANGED
@@ -1,14 +1,18 @@
1
- # Code credit: [EdgeSAM Demo](https://huggingface.co/spaces/chongzhou/EdgeSAM).
 
 
 
2
 
3
- import torch
4
  import gradio as gr
5
- from huggingface_hub import snapshot_download
6
  import numpy as np
7
- from tinysam import sam_model_registry, SamPredictor
8
  from PIL import ImageDraw
 
 
 
9
  from utils.tools_gradio import fast_process
10
- import copy
11
- import argparse
12
 
13
  snapshot_download("merve/tinysam", local_dir="tinysam")
14
 
@@ -19,6 +23,28 @@ sam.to(device=device)
19
  sam.eval()
20
  predictor = SamPredictor(sam)
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  examples = [
23
  ["assets/1.jpg"],
24
  ["assets/2.jpg"],
@@ -28,67 +54,126 @@ examples = [
28
  ["assets/6.jpeg"]
29
  ]
30
 
31
- # Description
32
- title = "<center><strong><font size='8'>TinySAM<font></strong> <a href='https://github.com/xinghaochen/TinySAM'><font size='6'>[GitHub]</font></a> </center>"
33
-
34
- description_p = """ # Instructions for point mode
35
 
36
- 1. Upload an image or click one of the provided examples.
37
- 2. Select the point type.
38
- 3. Click once or multiple times on the image to indicate the object of interest.
39
- 4. The Clear button clears all the points.
40
- 5. The Reset button resets both points and the image.
41
-
42
- """
43
 
44
- description_b = """ # Instructions for box mode
45
 
46
- 1. Upload an image or click one of the provided examples.
47
- 2. Click twice on the image (diagonal points of the box).
48
- 3. The Clear button clears the box.
49
- 4. The Reset button resets both the box and the image.
 
 
 
 
 
 
 
 
 
 
50
 
51
- """
 
 
 
 
 
52
 
53
- css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
 
54
 
 
 
 
 
 
55
 
56
def reset(session_state):
    """Wipe all prompts, images and cached features from the session.

    Args:
        session_state: mutable mapping holding the per-session data; it is
            modified in place.

    Returns:
        ``(None, session_state)`` -- ``None`` for the image output, followed
        by the same (now cleared) state object.
    """
    # Prompt collections each get their own fresh list.
    for list_key in ('coord_list', 'label_list', 'box_list'):
        session_state[list_key] = []
    # Image-related entries are dropped back to "nothing loaded".
    for empty_key in ('ori_image', 'image_with_prompt', 'feature'):
        session_state[empty_key] = None
    return None, session_state
64
 
 
 
 
65
 
66
def reset_all(session_state):
    """Clear the session and blank two image outputs at once.

    Args:
        session_state: mutable mapping holding the per-session data; it is
            modified in place.

    Returns:
        ``(None, None, session_state)`` -- one ``None`` per image output,
        followed by the same (now cleared) state object.
    """
    # Prompt collections each get their own fresh list.
    for list_key in ('coord_list', 'label_list', 'box_list'):
        session_state[list_key] = []
    # Image-related entries are dropped back to "nothing loaded".
    for empty_key in ('ori_image', 'image_with_prompt', 'feature'):
        session_state[empty_key] = None
    return None, None, session_state
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
def clear(session_state):
    """Drop all prompts but keep the currently loaded image.

    The annotated copy of the image is replaced by a fresh deep copy of the
    original, so previously drawn markers disappear.

    Args:
        session_state: mutable mapping holding the per-session data; it is
            modified in place.

    Returns:
        ``(original_image, session_state)``.
    """
    original = session_state['ori_image']
    for list_key in ('coord_list', 'label_list', 'box_list'):
        session_state[list_key] = []
    # Deep copy so later drawing never touches the pristine original.
    session_state['image_with_prompt'] = copy.deepcopy(original)
    return session_state['ori_image'], session_state
82
 
83
 
84
- def on_image_upload(
85
  image,
86
- session_state,
87
- input_size=1024
 
 
 
 
 
88
  ):
89
- session_state['coord_list'] = []
90
- session_state['label_list'] = []
91
- session_state['box_list'] = []
92
 
93
  input_size = int(input_size)
94
  w, h = image.size
@@ -96,232 +181,244 @@ def on_image_upload(
96
  new_w = int(w * scale)
97
  new_h = int(h * scale)
98
  image = image.resize((new_w, new_h))
99
- session_state['ori_image'] = copy.deepcopy(image)
100
- session_state['image_with_prompt'] = copy.deepcopy(image)
101
- print("Image changed")
102
- nd_image = np.array(image)
103
- session_state['feature'] = predictor.set_image(nd_image)
104
 
105
- return image, session_state
 
106
 
 
 
 
 
 
 
 
 
107
 
108
def convert_box(xyxy):
    """Normalise two corner points into top-left / bottom-right order.

    Args:
        xyxy: sequence of two mutable ``[x, y]`` points given in any order.
            It is rewritten in place.

    Returns:
        The same ``xyxy`` object, with ``xyxy[0]`` holding the minimum x/y
        and ``xyxy[1]`` holding the maximum x/y.
    """
    first, second = xyxy[0], xyxy[1]
    # Read every coordinate before writing anything back.
    xs = (first[0], second[0])
    ys = (first[1], second[1])
    low_x, high_x = min(xs), max(xs)
    low_y, high_y = min(ys), max(ys)
    first[0], first[1] = low_x, low_y
    second[0], second[1] = high_x, high_y
    return xyxy
118
 
 
 
 
119
 
120
def segment_with_points(
    label,
    session_state,
    evt: gr.SelectData,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=False,
):
    """Record a clicked point prompt and re-run point-based segmentation.

    Args:
        label: ``"Positive"`` stores the point with label 1; any other value
            stores it with label 0.
        session_state: per-session mapping; the point is appended to its
            ``coord_list`` / ``label_list`` and drawn onto its
            ``image_with_prompt``.
        evt: Gradio select event; ``evt.index`` supplies the click position.
        input_size: only used to derive the ``scale`` passed to
            ``fast_process``.
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.

    Returns:
        ``(seg, session_state)`` where ``seg`` is whatever ``fast_process``
        returns for the highest-scoring mask.
    """
    x, y = evt.index[0], evt.index[1]
    print(f"x y: {x,y}")

    is_positive = label == "Positive"
    point_radius = 5
    point_color = (97, 217, 54) if is_positive else (237, 34, 13)

    coords = session_state['coord_list']
    labels = session_state['label_list']
    coords.append([x, y])
    labels.append(1 if is_positive else 0)

    print(f"coord_list: {session_state['coord_list']}")
    print(f"label_list: {session_state['label_list']}")

    # Markers accumulate on the annotated copy kept in the session.
    image = session_state['image_with_prompt']
    top_left = (x - point_radius, y - point_radius)
    bottom_right = (x + point_radius, y + point_radius)
    ImageDraw.Draw(image).ellipse([top_left, bottom_right], fill=point_color)

    # The predictor is queried with every point collected so far.
    masks, scores, logits = predictor.predict(
        point_coords=np.array(coords),
        point_labels=np.array(labels),
    )
    print(f'scores: {scores}')
    area = masks.sum(axis=(1, 2))
    print(f'area: {area}')

    # Keep only the best-scoring mask, with a leading axis of length one.
    best = scores.argmax()
    annotations = np.expand_dims(masks[best], axis=0)

    seg = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    return seg, session_state
 
 
 
171
 
172
 
173
def segment_with_box(
    session_state,
    evt: gr.SelectData,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=False,
):
    """Collect box corners from clicks and segment once two are present.

    The first and second click are stored as corners. A click made while two
    corners are already stored restores the clean image and starts a new box
    from that click.

    Args:
        session_state: per-session mapping; reads and updates ``box_list``
            and ``image_with_prompt`` and reads ``ori_image``.
        evt: Gradio select event; ``evt.index`` supplies the click position.
        input_size: only used to derive the ``scale`` passed to
            ``fast_process``.
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.

    Returns:
        ``(image, session_state)`` with the annotated image while the box is
        incomplete, otherwise ``(seg, session_state)`` where ``seg`` is the
        ``fast_process`` result for the highest-scoring mask.
    """
    x, y = evt.index[0], evt.index[1]
    point_radius = 5
    point_color = (97, 217, 54)
    box_outline = 5
    box_color = (0, 255, 0)

    corner_count = len(session_state['box_list'])
    if corner_count in (0, 1):
        session_state['box_list'].append([x, y])
    elif corner_count == 2:
        # A third click discards the old box and its drawings.
        session_state['image_with_prompt'] = copy.deepcopy(session_state['ori_image'])
        session_state['box_list'] = [[x, y]]

    print(f"box_list: {session_state['box_list']}")

    image = session_state['image_with_prompt']
    draw = ImageDraw.Draw(image)
    draw.ellipse(
        [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
        fill=point_color,
    )

    # Only one corner so far: show the marker and wait for the second click.
    if len(session_state['box_list']) != 2:
        return image, session_state

    box = convert_box(session_state['box_list'])
    xy = (box[0][0], box[0][1], box[1][0], box[1][1])
    draw.rectangle(xy, outline=box_color, width=box_outline)

    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=np.array(xy)[None, :],
    )
    best = scores.argmax()
    annotations = np.expand_dims(masks[best], axis=0)

    seg = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    return seg, session_state
233
 
 
234
 
235
- img_p = gr.Image(label="Input with points", type="pil", interactive=True)
236
- img_b = gr.Image(label="Input with box", type="pil", interactive=True)
237
 
238
- with gr.Blocks(css=css, title="TinySAM") as demo:
239
- session_state = gr.State({
240
- 'coord_list': [],
241
- 'label_list': [],
242
- 'box_list': [],
243
- 'ori_image': None,
244
- 'image_with_prompt': None,
245
- 'feature': None
246
- })
 
 
 
 
 
 
 
 
 
247
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
248
  with gr.Row():
249
  with gr.Column(scale=1):
250
  # Title
251
  gr.Markdown(title)
252
 
253
- with gr.Tab("Point mode") as tab_p:
254
  # Images
255
  with gr.Row(variant="panel"):
256
  with gr.Column(scale=1):
257
- img_p.render()
258
- with gr.Column(scale=1):
259
- with gr.Row():
260
- add_or_remove = gr.Radio(
261
- ["Positive", "Negative"],
262
- value="Positive",
263
- label="Point Type"
264
- )
265
 
266
- with gr.Column():
267
- clear_btn_p = gr.Button("Clear", variant="secondary")
268
- reset_btn_p = gr.Button("Reset", variant="secondary")
269
- with gr.Row():
270
- gr.Markdown(description_p)
271
 
 
 
272
  with gr.Row():
273
  with gr.Column():
 
 
 
 
 
 
 
274
  gr.Markdown("Try some of the examples below ⬇️")
275
  gr.Examples(
276
  examples=examples,
277
- inputs=[img_p, session_state],
278
- outputs=[img_p, session_state],
279
- examples_per_page=8,
280
- fn=on_image_upload,
281
- run_on_click=True
282
  )
283
 
284
- with gr.Tab("Box mode") as tab_b:
 
 
 
 
285
  # Images
286
  with gr.Row(variant="panel"):
287
  with gr.Column(scale=1):
288
- img_b.render()
289
- with gr.Row():
290
- with gr.Column():
291
- clear_btn_b = gr.Button("Clear", variant="secondary")
292
- reset_btn_b = gr.Button("Reset", variant="secondary")
293
- gr.Markdown(description_b)
294
 
 
 
 
 
295
  with gr.Row():
296
  with gr.Column():
 
 
 
 
 
 
 
297
  gr.Markdown("Try some of the examples below ⬇️")
298
  gr.Examples(
299
  examples=examples,
300
- inputs=[img_b, session_state],
301
- outputs=[img_b, session_state],
302
- examples_per_page=8,
303
- fn=on_image_upload,
304
- run_on_click=True
305
  )
306
 
307
- with gr.Row():
308
- with gr.Column(scale=1):
309
- gr.Markdown(
310
- "<center><img src='https://visitor-badge.laobi.icu/badge?page_id=xinghaochen.tinysam' alt='visitors'></center>")
 
311
 
312
- img_p.upload(on_image_upload, [img_p, session_state], [img_p, session_state])
313
- img_p.select(segment_with_points, [add_or_remove, session_state], [img_p, session_state])
314
 
315
- clear_btn_p.click(clear, [session_state], [img_p, session_state])
316
- reset_btn_p.click(reset, [session_state], [img_p, session_state])
317
- tab_p.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
 
 
 
 
318
 
319
- img_b.upload(on_image_upload, [img_b, session_state], [img_b, session_state])
320
- img_b.select(segment_with_box, [session_state], [img_b, session_state])
321
 
322
- clear_btn_b.click(clear, [session_state], [img_b, session_state])
323
- reset_btn_b.click(reset, [session_state], [img_b, session_state])
324
- tab_b.select(fn=reset_all, inputs=[session_state], outputs=[img_p, img_b, session_state])
325
 
326
  demo.queue()
327
  demo.launch()
 
1
+ # Code credit: [EfficientSAM Demo](https://huggingface.co/spaces/yunyangx/EfficientSAM).
2
+
3
import copy
import os  # noqa

import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download
from PIL import ImageDraw
from torchvision.transforms import ToTensor

from tinysam import sam_model_registry, SamPredictor
from utils.tools import format_results, point_prompt
from utils.tools_gradio import fast_process
15
+
16
 
17
  snapshot_download("merve/tinysam", local_dir="tinysam")
18
 
 
23
  sam.eval()
24
  predictor = SamPredictor(sam)
25
 
26
# Description
# HTML heading shown at the top of the demo; every <font> tag is closed.
title = "<center><strong><font size='8'>TinySAM</font></strong> <a href='https://github.com/xinghaochen/TinySAM'><font size='6'>[GitHub]</font></a> </center>"

# One-line Markdown blurb linking to the project repository.
description_e = """This is a demo of [TinySAM Model](https://github.com/xinghaochen/TinySAM).
"""

# Markdown/HTML usage instructions rendered next to the examples.
description_p = """# Interactive Instance Segmentation
- Point-prompt instruction
<ol>
<li> Click on the left image (point input), visualizing the point on the right image </li>
<li> Click the button of Segment with Point Prompt </li>
</ol>
- Box-prompt instruction
<ol>
<li> Click on the left image (one point input), visualizing the point on the right image </li>
<li> Click on the left image (another point input), visualizing the point and the box on the right image</li>
<li> Click the button of Segment with Box Prompt </li>
</ol>
- Github [link](https://github.com/xinghaochen/TinySAM)
"""
46
+
47
+ # examples
48
  examples = [
49
  ["assets/1.jpg"],
50
  ["assets/2.jpg"],
 
54
  ["assets/6.jpeg"]
55
  ]
56
 
57
+ default_example = examples[0]
 
 
 
58
 
59
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
 
 
 
 
 
 
60
 
 
61
 
62
def segment_with_boxs(
    image,
    seg_image,
    global_points,
    global_point_label,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment the region inside the box spanned by the first two points.

    The previous body read ``session_state`` and ``model``, neither of which
    exists in this module, and never handed the image or the box to the
    predictor. This version uses the module-level ``predictor`` with a box
    prompt.

    Args:
        image: PIL image to segment. It is resized so that its longer side
            equals ``input_size``.
        seg_image: current content of the output image; returned untouched
            when there is nothing to segment.
        global_points: list of ``[x, y]`` clicks in original-image
            coordinates; only the first two are used.
        global_point_label: labels collected alongside the points. They are
            not needed for a box prompt and are only passed back.
        input_size: target length of the longer image side.
        better_quality, withContours, use_retina, mask_random_color:
            forwarded unchanged to ``fast_process``.

    Returns:
        ``(seg_image, global_points, global_point_label)`` unchanged when no
        image is loaded or fewer than two points exist; otherwise
        ``(fig, [], [])`` where ``fig`` is the ``fast_process`` result and
        the empty lists reset the stored prompts.
    """
    if image is None or global_points is None or len(global_points) < 2:
        return seg_image, global_points, global_point_label

    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    image = image.resize((int(w * scale), int(h * scale)))

    # Clicks were made on the unscaled image, so map them onto the resized one.
    scaled_points = np.array(
        [[int(value * scale) for value in point] for point in global_points[:2]]
    )
    # Order the corners here so the box is valid whatever the click order was.
    box_np = np.concatenate([scaled_points.min(axis=0), scaled_points.max(axis=0)])

    predictor.set_image(np.array(image))
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=box_np[None, :],
    )
    print(f'scores: {scores}')

    # Keep only the best-scoring mask, with a leading axis of length one.
    annotations = np.expand_dims(masks[scores.argmax()], axis=0)

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        use_retina=use_retina,
        bbox=box_np,
        withContours=withContours,
    )

    # Hand back empty prompt lists so the next click starts a new box.
    return fig, [], []
 
 
164
 
165
 
166
+ def segment_with_points(
167
  image,
168
+ global_points,
169
+ global_point_label,
170
+ input_size=1024,
171
+ better_quality=False,
172
+ withContours=True,
173
+ use_retina=True,
174
+ mask_random_color=True,
175
  ):
176
+ print("Original Image : ", image.size)
 
 
177
 
178
  input_size = int(input_size)
179
  w, h = image.size
 
181
  new_w = int(w * scale)
182
  new_h = int(h * scale)
183
  image = image.resize((new_w, new_h))
 
 
 
 
 
184
 
185
+ print("Scaled Image : ", image.size)
186
+ print("Scale : ", scale)
187
 
188
+ if global_points is None:
189
+ return image, global_points, global_point_label
190
+ if len(global_points) < 1:
191
+ return image, global_points, global_point_label
192
+ scaled_points = np.array(
193
+ [[int(x * scale) for x in point] for point in global_points]
194
+ )
195
+ scaled_point_label = np.array(global_point_label)
196
 
197
+ print(scaled_points, scaled_points is not None)
198
+ print(scaled_point_label, scaled_point_label is not None)
 
 
 
 
 
 
 
 
199
 
200
+ if scaled_points.size == 0 and scaled_point_label.size == 0:
201
+ print("No points selected")
202
+ return image, global_points, global_point_label
203
 
204
+ nd_image = np.array(image)
205
+ img_tensor = ToTensor()(nd_image)
206
+ print(img_tensor.shape)
207
+
208
+ predictor.set_image(nd_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
209
  masks, scores, logits = predictor.predict(
210
+ point_coords=scaled_points,
211
+ point_labels=global_point_label,
212
  )
213
  print(f'scores: {scores}')
214
  area = masks.sum(axis=(1, 2))
215
  print(f'area: {area}')
 
216
  annotations = np.expand_dims(masks[scores.argmax()], axis=0)
217
 
218
+ fig = fast_process(
219
  annotations=annotations,
220
  image=image,
221
  device=device,
222
  scale=(1024 // input_size),
223
  better_quality=better_quality,
224
  mask_random_color=mask_random_color,
225
+ points = scaled_points,
226
  bbox=None,
227
  use_retina=use_retina,
228
  withContours=withContours,
229
  )
230
 
231
+ global_points = []
232
+ global_point_label = []
233
+ # return fig, None
234
+ return fig, global_points, global_point_label
235
 
236
 
237
def get_points_with_draw(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
    """Store a clicked point and mark it on the preview image.

    Args:
        image: current preview image; replaced by a deep copy of
            ``cond_image`` when no points have been stored yet.
        cond_image: the input image the user clicked on.
        global_points: list of ``[x, y]`` clicks; appended to in place.
        global_point_label: list of point labels; appended to in place.
        evt: Gradio select event; ``evt.index`` supplies the click position.

    Returns:
        ``(image, global_points, global_point_label)``.
    """
    print("Starting functioning")
    # First point of a new prompt: start drawing on a clean copy.
    if not len(global_points):
        image = copy.deepcopy(cond_image)

    click_x, click_y = evt.index[0], evt.index[1]
    # The mode is fixed, so every point is stored with label 1.
    mode = "Add Mask"
    adding = mode == "Add Mask"
    radius = 15
    colour = (255, 255, 0) if adding else (255, 0, 255)

    global_points.append([click_x, click_y])
    global_point_label.append(1 if adding else 0)

    print(click_x, click_y, adding)

    if image is None:
        return image, global_points, global_point_label

    top_left = (click_x - radius, click_y - radius)
    bottom_right = (click_x + radius, click_y + radius)
    ImageDraw.Draw(image).ellipse([top_left, bottom_right], fill=colour)

    return image, global_points, global_point_label
262
 
263
def get_points_with_draw_(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
    """Store a clicked box corner and draw it, plus the box once complete.

    Changes from the previous body:
      * Clicks are ignored once two corners are stored. The old guard was
        ``> 2``, which accepted a stray third corner.
      * The rectangle is drawn only when a drawing surface exists, so a
        missing image can no longer hit an undefined ``draw``.
      * The four-way corner-ordering branch is replaced by min/max, which
        yields the same ordering.

    Args:
        image: current preview image; replaced by a deep copy of
            ``cond_image`` when no corners have been stored yet.
        cond_image: the input image the user clicked on.
        global_points: list of ``[x, y]`` corners; appended to and, once two
            are present, rewritten in place as top-left then bottom-right.
        global_point_label: list of point labels; 1 is appended per corner.
        evt: Gradio select event; ``evt.index`` supplies the click position.

    Returns:
        ``(image, global_points, global_point_label)``.
    """
    # First corner of a new box: start drawing on a clean copy.
    if len(global_points) == 0:
        image = copy.deepcopy(cond_image)
    # A box has exactly two corners; ignore further clicks until a reset.
    if len(global_points) >= 2:
        return image, global_points, global_point_label

    x, y = evt.index[0], evt.index[1]
    point_radius, point_color = 15, (255, 255, 0)
    global_points.append([x, y])
    global_point_label.append(1)

    print(x, y, True)

    draw = None
    if image is not None:
        draw = ImageDraw.Draw(image)
        draw.ellipse(
            [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
            fill=point_color,
        )

    if len(global_points) == 2:
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        left, right = min(x1, x2), max(x1, x2)
        top, bottom = min(y1, y2), max(y1, y2)
        # Store the corners ordered so later code can treat them as a box.
        global_points[0][0], global_points[0][1] = left, top
        global_points[1][0], global_points[1][1] = right, bottom
        if draw is not None:
            draw.rectangle([left, top, right, bottom], outline="red", width=5)

    return image, global_points, global_point_label
314
+
315
+
316
+ cond_img_p = gr.Image(label="Input with Point", value=default_example[0], type="pil")
317
+ cond_img_b = gr.Image(label="Input with Box", value=default_example[0], type="pil")
318
+
319
+ segm_img_p = gr.Image(
320
+ label="Segmented Image with Point-Prompt", interactive=False, type="pil"
321
+ )
322
+ segm_img_b = gr.Image(
323
+ label="Segmented Image with Box-Prompt", interactive=False, type="pil"
324
+ )
325
+
326
+ input_size_slider = gr.components.Slider(
327
+ minimum=512,
328
+ maximum=1024,
329
+ value=1024,
330
+ step=64,
331
+ label="Input_size",
332
+ info="Our model was trained on a size of 1024",
333
+ )
334
+
335
+ with gr.Blocks(css=css, title="TinySAM") as demo:
336
+ global_points = gr.State([])
337
+ global_point_label = gr.State([])
338
  with gr.Row():
339
  with gr.Column(scale=1):
340
  # Title
341
  gr.Markdown(title)
342
 
343
+ with gr.Tab("Point mode"):
344
  # Images
345
  with gr.Row(variant="panel"):
346
  with gr.Column(scale=1):
347
+ cond_img_p.render()
 
 
 
 
 
 
 
348
 
349
+ with gr.Column(scale=1):
350
+ segm_img_p.render()
 
 
 
351
 
352
+ # Submit & Clear
353
+ # ###
354
  with gr.Row():
355
  with gr.Column():
356
+
357
+ with gr.Column():
358
+ segment_btn_p = gr.Button(
359
+ "Segment with Point Prompt", variant="primary"
360
+ )
361
+ clear_btn_p = gr.Button("Clear", variant="secondary")
362
+
363
  gr.Markdown("Try some of the examples below ⬇️")
364
  gr.Examples(
365
  examples=examples,
366
+ inputs=[cond_img_p],
367
+ examples_per_page=4,
 
 
 
368
  )
369
 
370
+ with gr.Column():
371
+ # Description
372
+ gr.Markdown(description_p)
373
+
374
+ with gr.Tab("Box mode"):
375
  # Images
376
  with gr.Row(variant="panel"):
377
  with gr.Column(scale=1):
378
+ cond_img_b.render()
 
 
 
 
 
379
 
380
+ with gr.Column(scale=1):
381
+ segm_img_b.render()
382
+
383
+ # Submit & Clear
384
  with gr.Row():
385
  with gr.Column():
386
+
387
+ with gr.Column():
388
+ segment_btn_b = gr.Button(
389
+ "Segment with Box Prompt", variant="primary"
390
+ )
391
+ clear_btn_b = gr.Button("Clear", variant="secondary")
392
+
393
  gr.Markdown("Try some of the examples below ⬇️")
394
  gr.Examples(
395
  examples=examples,
396
+ inputs=[cond_img_b],
397
+
398
+ examples_per_page=4,
 
 
399
  )
400
 
401
+ with gr.Column():
402
+ # Description
403
+ gr.Markdown(description_p)
404
+
405
+ cond_img_p.select(get_points_with_draw, inputs = [segm_img_p, cond_img_p, global_points, global_point_label], outputs = [segm_img_p, global_points, global_point_label])
406
 
407
+ cond_img_b.select(get_points_with_draw_, [segm_img_b, cond_img_b, global_points, global_point_label], [segm_img_b, global_points, global_point_label])
 
408
 
409
+ segment_btn_p.click(
410
+ segment_with_points, inputs=[cond_img_p, global_points, global_point_label], outputs=[segm_img_p, global_points, global_point_label]
411
+ )
412
+
413
+ segment_btn_b.click(
414
+ segment_with_boxs, inputs=[cond_img_b, segm_img_b, global_points, global_point_label], outputs=[segm_img_b,global_points, global_point_label]
415
+ )
416
 
417
+ def clear():
418
+ return None, None, [], []
419
 
420
+ clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, global_points, global_point_label])
421
+ clear_btn_b.click(clear, outputs=[cond_img_b, segm_img_b, global_points, global_point_label])
 
422
 
423
  demo.queue()
424
  demo.launch()