microhan committed on
Commit
565cb88
1 Parent(s): 8dddb51

Upload 11 files

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+efficientsam_ti_cpu.jit filter=lfs diff=lfs merge=lfs -text
+efficientsam_ti_gpu.jit filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,430 @@
import copy
import os  # noqa

import gradio as gr
import numpy as np
import torch
from PIL import ImageDraw
from torchvision.transforms import ToTensor

from utils.tools import format_results, point_prompt
from utils.tools_gradio import fast_process

# Most of our demo code is from the [FastSAM demo](https://huggingface.co/spaces/An-619/FastSAM). Thanks to An-619.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

gpu_checkpoint_path = "efficientsam_s_gpu.jit"
cpu_checkpoint_path = "efficientsam_s_cpu.jit"

if torch.cuda.is_available():
    model = torch.jit.load(gpu_checkpoint_path)
else:
    model = torch.jit.load(cpu_checkpoint_path)
model.eval()
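
# The "*.jit" checkpoints are TorchScript exports, so torch.jit.load restores
# the whole model without needing the EfficientSAM Python sources at runtime.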

# Description
title = "<center><strong><font size='8'>Efficient Segment Anything (EfficientSAM)</font></strong></center>"

description_e = """This is a demo of the [Efficient Segment Anything (EfficientSAM) model](https://github.com/yformer/EfficientSAM).
"""

description_p = """# Interactive Instance Segmentation
- Point-prompt instructions
<ol>
<li> Click on the left image (point input); the point is visualized on the right image </li>
<li> Click the Segment with Point Prompt button </li>
</ol>
- Box-prompt instructions
<ol>
<li> Click on the left image (one point input); the point is visualized on the right image </li>
<li> Click on the left image (another point input); the point and the box are visualized on the right image </li>
<li> Click the Segment with Box Prompt button </li>
</ol>
- GitHub [link](https://github.com/yformer/EfficientSAM)
"""

# Examples
examples = [
    ["examples/image1.jpg"],
    ["examples/image2.jpg"],
    ["examples/image3.jpg"],
    ["examples/image4.jpg"],
    ["examples/image5.jpg"],
    ["examples/image6.jpg"],
    ["examples/image7.jpg"],
    ["examples/image8.jpg"],
    ["examples/image9.jpg"],
    ["examples/image10.jpg"],
    ["examples/image11.jpg"],
    ["examples/image12.jpg"],
    ["examples/image13.jpg"],
    ["examples/image14.jpg"],
]

default_example = examples[0]

css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
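
# Prompt encoding (SAM-style convention used by this demo):
#   label 1 -> foreground point, label 0 -> background point,
#   labels 2 and 3 -> top-left and bottom-right corners of a box prompt.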


def segment_with_boxs(
    image,
    seg_image,
    global_points,
    global_point_label,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    if len(global_points) < 2:
        return seg_image, global_points, global_point_label
    print("Original Image : ", image.size)

    # Resize so the longer side matches the model's input size.
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    print("Scaled Image : ", image.size)
    print("Scale : ", scale)

    # Rescale the two clicked corner points into the resized image.
    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    scaled_points = scaled_points[:2]
    scaled_point_label = np.array(global_point_label)[:2]

    print(scaled_points, scaled_points is not None)
    print(scaled_point_label, scaled_point_label is not None)

    if scaled_points.size == 0 and scaled_point_label.size == 0:
        print("No points selected")
        return image, global_points, global_point_label

    nd_image = np.array(image)
    img_tensor = ToTensor()(nd_image)

    print(img_tensor.shape)
    pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
    pts_sampled = pts_sampled[:, :, :2, :]
    # Labels [2, 3] mark the two points as box corners rather than clicks.
    pts_labels = torch.reshape(torch.tensor([2, 3]), [1, 1, 2])

    predicted_logits, predicted_iou = model(
        img_tensor[None, ...].to(device),
        pts_sampled.to(device),
        pts_labels.to(device),
    )
    predicted_logits = predicted_logits.cpu()
    all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
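
    # The model returns several candidate masks per prompt together with its
    # own IoU estimate for each; keep the candidate with the highest score.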
    max_predicted_iou = -1
    selected_mask_using_predicted_iou = None
    selected_predicted_iou = None

    for m in range(all_masks.shape[0]):
        curr_predicted_iou = predicted_iou[m]
        if (
            curr_predicted_iou > max_predicted_iou
            or selected_mask_using_predicted_iou is None
        ):
            max_predicted_iou = curr_predicted_iou
            selected_mask_using_predicted_iou = all_masks[m : m + 1]
            selected_predicted_iou = predicted_iou[m : m + 1]

    results = format_results(
        selected_mask_using_predicted_iou, selected_predicted_iou, predicted_logits, 0
    )

    annotations = results[0]["segmentation"]
    annotations = np.array([annotations])
    print(scaled_points.shape)
    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        use_retina=use_retina,
        bbox=scaled_points.reshape([4]),
        withContours=withContours,
    )

    # Reset the prompt state so the next interaction starts fresh.
    global_points = []
    global_point_label = []
    return fig, global_points, global_point_label


def segment_with_points(
    image,
    global_points,
    global_point_label,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    print("Original Image : ", image.size)

    # Resize so the longer side matches the model's input size.
    input_size = int(input_size)
    w, h = image.size
    scale = input_size / max(w, h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    image = image.resize((new_w, new_h))

    print("Scaled Image : ", image.size)
    print("Scale : ", scale)

    if global_points is None or len(global_points) < 1:
        return image, global_points, global_point_label
    scaled_points = np.array(
        [[int(x * scale) for x in point] for point in global_points]
    )
    scaled_point_label = np.array(global_point_label)

    print(scaled_points, scaled_points is not None)
    print(scaled_point_label, scaled_point_label is not None)

    if scaled_points.size == 0 and scaled_point_label.size == 0:
        print("No points selected")
        return image, global_points, global_point_label

    nd_image = np.array(image)
    img_tensor = ToTensor()(nd_image)

    print(img_tensor.shape)
    pts_sampled = torch.reshape(torch.tensor(scaled_points), [1, 1, -1, 2])
    pts_labels = torch.reshape(torch.tensor(global_point_label), [1, 1, -1])

    predicted_logits, predicted_iou = model(
        img_tensor[None, ...].to(device),
        pts_sampled.to(device),
        pts_labels.to(device),
    )
    predicted_logits = predicted_logits.cpu()
    all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
    predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()

    results = format_results(all_masks, predicted_iou, predicted_logits, 0)

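    # point_prompt merges the candidate masks according to the clicked points:
    # masks under a foreground point are added, masks under a background point
    # are subtracted (see utils/tools.py).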
    annotations, _ = point_prompt(
        results, scaled_points, scaled_point_label, new_h, new_w
    )
    annotations = np.array([annotations])

    fig = fast_process(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        points=scaled_points,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )

    # Reset the prompt state so the next interaction starts fresh.
    global_points = []
    global_point_label = []
    return fig, global_points, global_point_label


def get_points_with_draw(image, cond_image, global_points, global_point_label, evt: gr.SelectData):
    print("Starting functioning")
    if len(global_points) == 0:
        image = copy.deepcopy(cond_image)
    x, y = evt.index[0], evt.index[1]
    label = "Add Mask"
    point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (
        255,
        0,
        255,
    )
    global_points.append([x, y])
    global_point_label.append(1 if label == "Add Mask" else 0)

    print(x, y, label == "Add Mask")

    if image is not None:
        draw = ImageDraw.Draw(image)
        # Mark the clicked point on the preview image.
        draw.ellipse(
            [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
            fill=point_color,
        )

    return image, global_points, global_point_label


def get_points_with_draw_(image, cond_image, global_points, global_point_label, evt: gr.SelectData):

    if len(global_points) == 0:
        image = copy.deepcopy(cond_image)
    if len(global_points) > 2:
        return image, global_points, global_point_label
    x, y = evt.index[0], evt.index[1]
    label = "Add Mask"
    point_radius, point_color = 15, (255, 255, 0) if label == "Add Mask" else (
        255,
        0,
        255,
    )
    global_points.append([x, y])
    global_point_label.append(1 if label == "Add Mask" else 0)

    print(x, y, label == "Add Mask")

    if image is not None:
        draw = ImageDraw.Draw(image)
        draw.ellipse(
            [(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
            fill=point_color,
        )

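    # Once two points are placed, draw the prompt box and reorder
    # global_points so [0] is the top-left and [1] the bottom-right corner.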
    if len(global_points) == 2:
        x1, y1 = global_points[0]
        x2, y2 = global_points[1]
        if x1 < x2 and y1 < y2:
            draw.rectangle([x1, y1, x2, y2], outline="red", width=5)
        elif x1 < x2 and y1 >= y2:
            draw.rectangle([x1, y2, x2, y1], outline="red", width=5)
            global_points[0][0] = x1
            global_points[0][1] = y2
            global_points[1][0] = x2
            global_points[1][1] = y1
        elif x1 >= x2 and y1 < y2:
            draw.rectangle([x2, y1, x1, y2], outline="red", width=5)
            global_points[0][0] = x2
            global_points[0][1] = y1
            global_points[1][0] = x1
            global_points[1][1] = y2
        elif x1 >= x2 and y1 >= y2:
            draw.rectangle([x2, y2, x1, y1], outline="red", width=5)
            global_points[0][0] = x2
            global_points[0][1] = y2
            global_points[1][0] = x1
            global_points[1][1] = y1

    return image, global_points, global_point_label


cond_img_p = gr.Image(label="Input with Point", value=default_example[0], type="pil")
cond_img_b = gr.Image(label="Input with Box", value=default_example[0], type="pil")

segm_img_p = gr.Image(
    label="Segmented Image with Point-Prompt", interactive=False, type="pil"
)
segm_img_b = gr.Image(
    label="Segmented Image with Box-Prompt", interactive=False, type="pil"
)

input_size_slider = gr.components.Slider(
    minimum=512,
    maximum=1024,
    value=1024,
    step=64,
    label="Input_size",
    info="Our model was trained on a size of 1024",
)

with gr.Blocks(css=css, title="Efficient SAM") as demo:
    global_points = gr.State([])
    global_point_label = gr.State([])
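    # gr.State holds the clicked points per user session, so concurrent
    # visitors do not share prompt state.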
    with gr.Row():
        with gr.Column(scale=1):
            # Title
            gr.Markdown(title)

    with gr.Tab("Point mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_p.render()

            with gr.Column(scale=1):
                segm_img_p.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    segment_btn_p = gr.Button(
                        "Segment with Point Prompt", variant="primary"
                    )
                    clear_btn_p = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[cond_img_p],
                    examples_per_page=4,
                )

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    with gr.Tab("Box mode"):
        # Images
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                cond_img_b.render()

            with gr.Column(scale=1):
                segm_img_b.render()

        # Submit & Clear
        with gr.Row():
            with gr.Column():
                with gr.Column():
                    segment_btn_b = gr.Button(
                        "Segment with Box Prompt", variant="primary"
                    )
                    clear_btn_b = gr.Button("Clear", variant="secondary")

                gr.Markdown("Try some of the examples below ⬇️")
                gr.Examples(
                    examples=examples,
                    inputs=[cond_img_b],
                    examples_per_page=4,
                )

            with gr.Column():
                # Description
                gr.Markdown(description_p)

    # Clicking an input image records a prompt point and previews it.
    cond_img_p.select(
        get_points_with_draw,
        inputs=[segm_img_p, cond_img_p, global_points, global_point_label],
        outputs=[segm_img_p, global_points, global_point_label],
    )

    cond_img_b.select(
        get_points_with_draw_,
        [segm_img_b, cond_img_b, global_points, global_point_label],
        [segm_img_b, global_points, global_point_label],
    )

    segment_btn_p.click(
        segment_with_points,
        inputs=[cond_img_p, global_points, global_point_label],
        outputs=[segm_img_p, global_points, global_point_label],
    )

    segment_btn_b.click(
        segment_with_boxs,
        inputs=[cond_img_b, segm_img_b, global_points, global_point_label],
        outputs=[segm_img_b, global_points, global_point_label],
    )

    def clear():
        return None, None, [], []

    clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p, global_points, global_point_label])
    clear_btn_b.click(clear, outputs=[cond_img_b, segm_img_b, global_points, global_point_label])

demo.queue()
demo.launch()
efficientsam_ti.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:143c3198a7b2a15f23c21cdb723432fb3fbcdbabbdad3483cf3babd8b95c1397
size 41365520
efficientsam_ti_cpu.jit ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2369ab027799ba26c8828834a00708aa92c66937d1d211ad43346934b0d5171c
size 41247427
efficientsam_ti_decoder.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a62f8fa5ea080447c0689418d69e58f1e83e0b7adf9c142e2bd9bcc8045c0b11
size 16565728
efficientsam_ti_encoder.onnx ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:84ed466ffcc5c1f8d08409bc34a23bb364ab2c15e402cb12d4335a42be0e0951
size 24799761
efficientsam_ti_gpu.jit ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8beeac933a4b99ca118e545ff3b0abb5c433e2f1fa861ad0ed9f2d378d29004a
size 41247427
requirements.txt ADDED
@@ -0,0 +1,6 @@
torch
torchvision
gradio
opencv-python
pandas
matplotlib
utils/__init__.py ADDED
File without changes
utils/tools.py ADDED
@@ -0,0 +1,409 @@
import os
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


def convert_box_xywh_to_xyxy(box):
    x1 = box[0]
    y1 = box[1]
    x2 = box[0] + box[2]
    y2 = box[1] + box[3]
    return [x1, y1, x2, y2]


def segment_image(image, bbox):
    image_array = np.array(image)
    segmented_image_array = np.zeros_like(image_array)
    x1, y1, x2, y2 = bbox
    segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
    segmented_image = Image.fromarray(segmented_image_array)
    black_image = Image.new("RGB", image.size, (255, 255, 255))
    transparency_mask = np.zeros(
        (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
    )
    transparency_mask[y1:y2, x1:x2] = 255
    transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
    black_image.paste(segmented_image, mask=transparency_mask_image)
    return black_image


def format_results(masks, scores, logits, filter=0):
    annotations = []
    n = len(scores)
    for i in range(n):
        annotation = {}

        mask = masks[i]
        tmp = np.where(mask != 0)
        if np.sum(mask) < filter:
            continue
        annotation["id"] = i
        annotation["segmentation"] = mask
        # Note: rows come first, so the ordering is [y_min, x_min, x_max, y_max].
        annotation["bbox"] = [
            np.min(tmp[0]),
            np.min(tmp[1]),
            np.max(tmp[1]),
            np.max(tmp[0]),
        ]
        annotation["score"] = scores[i]
        annotation["area"] = annotation["segmentation"].sum()
        annotations.append(annotation)
    return annotations


def filter_masks(annotations):  # filter out overlapping masks
    annotations.sort(key=lambda x: x["area"], reverse=True)
    to_remove = set()
    for i in range(0, len(annotations)):
        a = annotations[i]
        for j in range(i + 1, len(annotations)):
            b = annotations[j]
            if i != j and j not in to_remove:
                if b["area"] < a["area"]:
                    if (a["segmentation"] & b["segmentation"]).sum() / b[
                        "segmentation"
                    ].sum() > 0.8:
                        to_remove.add(j)

    return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove


def get_bbox_from_mask(mask):
    mask = mask.astype(np.uint8)
    contours, hierarchy = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    x1, y1, w, h = cv2.boundingRect(contours[0])
    x2, y2 = x1 + w, y1 + h
    if len(contours) > 1:
        for b in contours:
            x_t, y_t, w_t, h_t = cv2.boundingRect(b)
            # Merge the bboxes of all contours into one.
            x1 = min(x1, x_t)
            y1 = min(y1, y_t)
            x2 = max(x2, x_t + w_t)
            y2 = max(y2, y_t + h_t)
        h = y2 - y1
        w = x2 - x1
    return [x1, y1, x2, y2]


def fast_process(
    annotations, args, mask_random_color, bbox=None, points=None, edges=False
):
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]
    result_name = os.path.basename(args.img_path)
    image = cv2.imread(args.img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_h = image.shape[0]
    original_w = image.shape[1]
    if sys.platform == "darwin":
        plt.switch_backend("TkAgg")
    plt.figure(figsize=(original_w / 100, original_h / 100))
    # Add subplot with no margin.
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(image)
    if args.better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        for i, mask in enumerate(annotations):
            # Morphological close then open to smooth the mask boundary.
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )
    if args.device == "cpu":
        annotations = np.array(annotations)
        fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            annotations = torch.from_numpy(annotations)
        fast_show_mask_gpu(
            annotations,
            plt.gca(),
            random_color=args.randomcolor,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()
    if args.withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for i, mask in enumerate(annotations):
            if type(mask) == dict:
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if not args.retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, hierarchy = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            for contour in contours:
                contour_all.append(contour)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)
        plt.imshow(contour_mask)

    save_path = args.output
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    plt.axis("off")
    fig = plt.gcf()
    plt.draw()

    try:
        buf = fig.canvas.tostring_rgb()
    except AttributeError:
        fig.canvas.draw()
        buf = fig.canvas.tostring_rgb()

    cols, rows = fig.canvas.get_width_height()
    img_array = np.frombuffer(buf, dtype=np.uint8).reshape(rows, cols, 3)
    cv2.imwrite(
        os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
    )


# CPU post-processing
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
    # Sort the masks by area (ascending).
    areas = np.sum(annotation, axis=(1, 2))
    sorted_indices = np.argsort(areas)
    annotation = annotation[sorted_indices]

    index = (annotation != 0).argmax(axis=0)
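    # index[h, w] is the first mask (in ascending-area order) covering each
    # pixel, so every pixel takes the color of the smallest mask containing it.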
    if random_color:
        color = np.random.random((mask_sum, 1, 1, 3))
    else:
        color = np.ones((mask_sum, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    show = np.zeros((height, weight, 4))
    h_indices, w_indices = np.meshgrid(
        np.arange(height), np.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # Make updates based on the indices.
    show[h_indices, w_indices, :] = mask_image[indices]
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    # Draw the prompt points: yellow for foreground, magenta for background.
    if points is not None:
        plt.scatter(
            [point[0] for i, point in enumerate(points) if point_label[i] == 1],
            [point[1] for i, point in enumerate(points) if point_label[i] == 1],
            s=20,
            c="y",
        )
        plt.scatter(
            [point[0] for i, point in enumerate(points) if point_label[i] == 0],
            [point[1] for i, point in enumerate(points) if point_label[i] == 0],
            s=20,
            c="m",
        )

    if not retinamask:
        show = cv2.resize(
            show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(show)


def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]
    # Find the first non-zero mask index for each pixel.
    index = (annotation != 0).to(torch.long).argmax(dim=0)
    if random_color:
        color = torch.rand((mask_sum, 1, 1, 3)).to(annotation.device)
    else:
        color = torch.ones((mask_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(annotation.device)
    transparency = torch.ones((mask_sum, 1, 1, 1)).to(annotation.device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual
    # Compose the per-pixel overlay on the GPU, then move it to the CPU.
    show = torch.zeros((height, weight, 4)).to(annotation.device)
    h_indices, w_indices = torch.meshgrid(
        torch.arange(height), torch.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    show[h_indices, w_indices, :] = mask_image[indices]
    show_cpu = show.cpu().numpy()
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    # Draw the prompt points: yellow for foreground, magenta for background.
    if points is not None:
        plt.scatter(
            [point[0] for i, point in enumerate(points) if point_label[i] == 1],
            [point[1] for i, point in enumerate(points) if point_label[i] == 1],
            s=20,
            c="y",
        )
        plt.scatter(
            [point[0] for i, point in enumerate(points) if point_label[i] == 0],
            [point[1] for i, point in enumerate(points) if point_label[i] == 0],
            s=20,
            c="m",
        )
    if not retinamask:
        show_cpu = cv2.resize(
            show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(show_cpu)


def crop_image(annotations, image_like):
    if isinstance(image_like, str):
        image = Image.open(image_like)
    else:
        image = image_like
    ori_w, ori_h = image.size
    mask_h, mask_w = annotations[0]["segmentation"].shape
    if ori_w != mask_w or ori_h != mask_h:
        image = image.resize((mask_w, mask_h))
    cropped_boxes = []
    cropped_images = []
    not_crop = []
    filter_id = []
    for _, mask in enumerate(annotations):
        # Skip tiny masks.
        if np.sum(mask["segmentation"]) <= 100:
            filter_id.append(_)
            continue
        bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
        cropped_boxes.append(segment_image(image, bbox))
        cropped_images.append(bbox)

    return cropped_boxes, cropped_images, not_crop, filter_id, annotations


def box_prompt(masks, bbox, target_height, target_width):
    h = masks[0]["segmentation"].shape[1]
    w = masks[0]["segmentation"].shape[2]
    masks = masks[0]["segmentation"]
    bbox = bbox.reshape([4])
    # Rescale the box into mask coordinates, then clamp it to the image.
    if h != target_height or w != target_width:
        bbox = [
            int(bbox[0] * w / target_width),
            int(bbox[1] * h / target_height),
            int(bbox[2] * w / target_width),
            int(bbox[3] * h / target_height),
        ]
    bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
    bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
    bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
    bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
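
    # Score each candidate mask by its IoU with the box; the intersection is
    # the mask area that falls inside the box.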
    bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])

    masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
    orig_masks_area = torch.sum(masks, dim=(1, 2))

    union = bbox_area + orig_masks_area - masks_area
    IoUs = masks_area / union
    max_iou_index = torch.argmax(IoUs)

    return masks[max_iou_index].cpu().numpy(), max_iou_index


def point_prompt(masks, points, point_label, target_height, target_width):  # numpy
    h = masks[0]["segmentation"].shape[0]
    w = masks[0]["segmentation"].shape[1]
    # Rescale the points into mask coordinates if the sizes differ.
    if h != target_height or w != target_width:
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]
    onemask = np.zeros((h, w))
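    # Vote with the prompt points: masks containing a foreground point are
    # added, masks containing a background point are subtracted, and pixels
    # with a net positive vote form the final mask.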
    for annotation in masks:
        if type(annotation) == dict:
            mask = annotation["segmentation"]
        else:
            mask = annotation
        for i, point in enumerate(points):
            if point[1] < mask.shape[0] and point[0] < mask.shape[1]:
                if mask[point[1], point[0]] == 1 and point_label[i] == 1:
                    onemask += mask
                if mask[point[1], point[0]] == 1 and point_label[i] == 0:
                    onemask -= mask
    onemask = onemask >= 1
    return onemask, 0
utils/tools_gradio.py ADDED
@@ -0,0 +1,193 @@
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


def fast_process(
    annotations,
    image,
    device,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    points=None,
    use_retina=True,
    withContours=True,
):
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]

    original_h = image.height
    original_w = image.width
    if better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        for i, mask in enumerate(annotations):
            # Morphological close then open to smooth the mask boundary.
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )
    if device == "cpu":
        annotations = np.array(annotations)
        inner_mask = fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            annotations = np.array(annotations)
            annotations = torch.from_numpy(annotations)
        inner_mask = fast_show_mask_gpu(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()
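
    # Draw all mask contours with OpenCV, then composite the translucent mask
    # and contour layers onto the image as RGBA overlays.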
    if withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for i, mask in enumerate(annotations):
            if type(mask) == dict:
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if not use_retina:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            for contour in contours:
                contour_all.append(contour)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert("RGBA")
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image


# CPU post-processing
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
    # Sort the masks by area (ascending).
    areas = np.sum(annotation, axis=(1, 2))
    sorted_indices = np.argsort(areas)
    annotation = annotation[sorted_indices]

    index = (annotation != 0).argmax(axis=0)
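    # index[h, w] is the first mask (in ascending-area order) covering each
    # pixel, so every pixel takes the color of the smallest mask containing it.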
    if random_color:
        color = np.random.random((mask_sum, 1, 1, 3))
    else:
        color = np.ones((mask_sum, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
    visual = np.concatenate([color, transparency], axis=-1)
    mask_image = np.expand_dims(annotation, -1) * visual

    mask = np.zeros((height, weight, 4))

    h_indices, w_indices = np.meshgrid(
        np.arange(height), np.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))

    mask[h_indices, w_indices, :] = mask_image[indices]
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )

    if not retinamask:
        mask = cv2.resize(
            mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )

    return mask


def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    device = annotation.device
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]
    # Find the first non-zero mask index for each pixel.
    index = (annotation != 0).to(torch.long).argmax(dim=0)
    if random_color:
        color = torch.rand((mask_sum, 1, 1, 3)).to(device)
    else:
        color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(device)
    transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual
    # Compose the per-pixel overlay on the GPU, then move it to the CPU.
    mask = torch.zeros((height, weight, 4)).to(device)
    h_indices, w_indices = torch.meshgrid(
        torch.arange(height), torch.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    mask[h_indices, w_indices, :] = mask_image[indices]
    mask_cpu = mask.cpu().numpy()
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    if not retinamask:
        mask_cpu = cv2.resize(
            mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    return mask_cpu