jianyouli committed on
Commit
b759a29
1 Parent(s): 6d443fe

Add application file1

Browse files
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (141 Bytes). View file
 
utils/__pycache__/tools.cpython-38.pyc ADDED
Binary file (10.9 kB). View file
 
utils/__pycache__/tools_gradio.cpython-38.pyc ADDED
Binary file (4.19 kB). View file
 
utils/test5.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ # import matplotlib
3
+ # matplotlib.use('Qt5Agg')
4
+ import matplotlib.pyplot as plt
5
+ import gradio as gr
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from mobile_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
10
+ from PIL import ImageDraw,Image
11
+ from utils.tools import box_prompt, format_results, point_prompt
12
+ from utils.tools_gradio import fast_process
13
+
14
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pre-trained model, move it to the chosen device and switch it to
# evaluation mode.
model_type = "vit_t"
sam_checkpoint = r"F:\zht\code\MobileSAM-master\weights\mobile_sam.pt"
mobile_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint).to(
    device=device
)
mobile_sam.eval()

# Both helpers wrap the same model instance.
mask_generator = SamAutomaticMaskGenerator(mobile_sam)
predictor = SamPredictor(mobile_sam)
24
+
25
+ # default_example = examples[0]
26
+
27
@torch.no_grad()
def segment_with_boxs(
    image,
    input_size=1024,
    better_quality=False,
    withContours=True,
    use_retina=True,
    mask_random_color=True,
):
    """Segment ``image`` using the prompts stored in module-level globals.

    Prompts are read from ``global_points`` / ``global_point_label``; both
    globals are reset to empty lists before the function returns.

    Args:
        image: Input image exposing ``size`` and ``resize``.
        input_size: Target length, in pixels, of the longer image side.
        better_quality: Forwarded unchanged to ``fast_process``.
        withContours: Forwarded unchanged to ``fast_process``.
        use_retina: Forwarded unchanged to ``fast_process``.
        mask_random_color: Forwarded unchanged to ``fast_process``.

    Returns:
        ``(fig, image)``: the value returned by ``fast_process`` and the
        resized input image.
    """
    global global_points
    global global_point_label

    # Resize so that the longer side equals ``input_size``.
    input_size = int(input_size)
    width, height = image.size
    scale = input_size / max(width, height)
    new_w, new_h = int(width * scale), int(height * scale)
    image = image.resize((new_w, new_h))

    # Bring every prompt coordinate into the resized image's pixel space.
    # NOTE(review): the predictor below receives these as point coordinates;
    # the script's __main__ block stores a 4-value entry here -- confirm that
    # this is intended.
    scaled_points = np.array(
        [[int(coord * scale) for coord in prompt] for prompt in global_points]
    )
    print("nnnnnnnnnnnnnnnnnnnnnnnnnnnnn00nnnnn", scaled_points)
    scaled_point_label = np.array(global_point_label)

    nd_image = np.array(image)
    print("mmmmmmm0mmmm", nd_image.shape)
    predictor.set_image(nd_image)
    masks, scores, logits = predictor.predict(
        point_coords=scaled_points,
        point_labels=scaled_point_label,
        multimask_output=True,
    )

    results = format_results(masks, scores, logits, 0)
    print("mmmmmmmmmmmmmmmm2222m", len(results))
    merged_mask, _ = point_prompt(
        results, scaled_points, scaled_point_label, new_h, new_w
    )
    annotations = np.array([merged_mask])

    # Debug view of the merged mask: plotted with the 'viridis' colour map and
    # a colour bar, saved to a fixed path, then shown.
    plt.imshow(annotations[0], cmap='viridis')
    plt.colorbar()
    plt.savefig(r'F:\zht\code\2.png')
    plt.show()

    # NOTE(review): ``1024 // input_size`` evaluates to 0 whenever
    # ``input_size`` exceeds 1024 -- confirm how ``fast_process`` handles it.
    overlay_options = dict(
        annotations=annotations,
        image=image,
        device=device,
        scale=(1024 // input_size),
        better_quality=better_quality,
        mask_random_color=mask_random_color,
        bbox=None,
        use_retina=use_retina,
        withContours=withContours,
    )
    fig = fast_process(**overlay_options)

    # Clear the prompts so that the next call starts from a clean state.
    global_points = []
    global_point_label = []
    return fig, image
88
+
89
+ #################################################
90
if __name__ == "__main__":
    # Manual smoke test on one local sample image.
    path = r"F:\zht\code\MobileSAM-master\app\assets\05.jpg"
    image1 = Image.open(path)
    print(image1.size)

    # Prompt consumed by ``segment_with_boxs`` through the module globals.
    # A single-point alternative would be ``[[1069, 928]]``.
    global_points = [[324, 740, 1448, 1192]]
    global_point_label = [1]

    run_options = {
        "input_size": 1024,
        "better_quality": False,
        "withContours": True,
        "use_retina": True,
        "mask_random_color": True,
    }
    segment_with_boxs(image1, **run_options)
utils/tools.py ADDED
@@ -0,0 +1,412 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import cv2
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image
9
+
10
+
11
def convert_box_xywh_to_xyxy(box):
    """Convert a box from ``[x, y, width, height]`` to ``[x1, y1, x2, y2]``.

    Only the first four entries of ``box`` are read.
    """
    left, top, width, height = box[0], box[1], box[2], box[3]
    return [left, top, left + width, top + height]
17
+
18
+
19
def segment_image(image, bbox):
    """Return a white canvas of ``image``'s size showing only the ``bbox`` area.

    Args:
        image: Source image; converted with ``np.array`` and queried for
            ``size``.
        bbox: ``(x1, y1, x2, y2)`` used directly as slice bounds.

    Returns:
        A new RGB image that is white everywhere except inside the box, where
        the source pixels are pasted.
    """
    x1, y1, x2, y2 = bbox
    pixels = np.array(image)

    # Copy of the source that is zero everywhere outside the box.
    boxed_pixels = np.zeros_like(pixels)
    boxed_pixels[y1:y2, x1:x2] = pixels[y1:y2, x1:x2]
    boxed_image = Image.fromarray(boxed_pixels)

    # Paste mask: 255 inside the box, 0 elsewhere.
    paste_mask = np.zeros((pixels.shape[0], pixels.shape[1]), dtype=np.uint8)
    paste_mask[y1:y2, x1:x2] = 255
    paste_mask_image = Image.fromarray(paste_mask, mode="L")

    canvas = Image.new("RGB", image.size, (255, 255, 255))
    canvas.paste(boxed_image, mask=paste_mask_image)
    return canvas
34
+
35
+
36
def format_results(masks, scores, logits, filter=0):
    """Turn raw predictor output into a list of annotation dictionaries.

    Args:
        masks: Indexable collection of masks, one per entry of ``scores``;
            each mask must have at least two dimensions.
        scores: Per-mask scores; its length decides how many masks are read.
        logits: Accepted for call compatibility; not used.
        filter: Masks whose element sum is below this value are skipped.

    Returns:
        A list of dicts with the keys ``id`` (index into ``masks``),
        ``segmentation`` (the mask itself), ``bbox``
        (``[x_min, y_min, x_max, y_max]`` of the non-zero pixels, inclusive),
        ``score`` and ``area`` (the mask's element sum). A mask with no
        non-zero pixel gets the bbox ``[0, 0, 0, 0]``.
    """
    annotations = []
    for i in range(len(scores)):
        mask = masks[i]
        if np.sum(mask) < filter:
            continue

        # np.where returns the row (y) indices first, then the column (x)
        # indices, so the x extents come from the second array.
        nonzero = np.where(mask != 0)
        rows, cols = nonzero[0], nonzero[1]
        if rows.size:
            bbox = [
                int(np.min(cols)),
                int(np.min(rows)),
                int(np.max(cols)),
                int(np.max(rows)),
            ]
        else:
            # An empty mask has no extent; min/max would raise on it.
            bbox = [0, 0, 0, 0]

        annotations.append(
            {
                "id": i,
                "segmentation": mask,
                "bbox": bbox,
                "score": scores[i],
                "area": mask.sum(),
            }
        )
    return annotations
60
+
61
+
62
def filter_masks(annotations):  # filter the overlap mask
    """Drop masks that lie mostly inside a strictly larger mask.

    ``annotations`` is sorted in place by descending ``area``. A mask is
    dropped when more than 80% of its own pixels overlap a mask with a
    strictly larger ``area``.

    Returns:
        ``(kept, to_remove)``: the surviving annotations in sorted order and
        the set of removed positions, which refer to that sorted order.
    """
    annotations.sort(key=lambda item: item["area"], reverse=True)
    to_remove = set()
    for outer_pos, larger in enumerate(annotations):
        for inner_pos in range(outer_pos + 1, len(annotations)):
            if inner_pos in to_remove:
                continue
            smaller = annotations[inner_pos]
            if smaller["area"] < larger["area"]:
                shared = (larger["segmentation"] & smaller["segmentation"]).sum()
                if shared / smaller["segmentation"].sum() > 0.8:
                    to_remove.add(inner_pos)

    kept = [item for pos, item in enumerate(annotations) if pos not in to_remove]
    return kept, to_remove
78
+
79
+
80
def get_bbox_from_mask(mask):
    """Return one box enclosing every external contour of ``mask``.

    Args:
        mask: Array-like mask; it is cast to ``uint8`` before contour search.

    Returns:
        ``[x1, y1, x2, y2]`` where ``x2``/``y2`` are ``x + width`` and
        ``y + height`` of the merged bounding rectangle. When no contour is
        found the result is ``[0, 0, 0, 0]``.
    """
    mask = mask.astype(np.uint8)
    contours, _ = cv2.findContours(
        mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if len(contours) == 0:
        # Nothing to enclose; indexing contours[0] would raise IndexError.
        return [0, 0, 0, 0]

    x1, y1, w, h = cv2.boundingRect(contours[0])
    x2, y2 = x1 + w, y1 + h
    # Merge the bounding boxes of the remaining contours into a single box.
    for contour in contours[1:]:
        x_t, y_t, w_t, h_t = cv2.boundingRect(contour)
        x1 = min(x1, x_t)
        y1 = min(y1, y_t)
        x2 = max(x2, x_t + w_t)
        y2 = max(y2, y_t + h_t)
    return [x1, y1, x2, y2]
98
+
99
+
100
def _figure_to_rgb_array(fig):
    """Render ``fig`` and return its pixels as a ``(rows, cols, 3)`` uint8 array."""
    canvas = fig.canvas
    canvas.draw()
    if hasattr(canvas, "buffer_rgba"):
        # Preferred path: ``tostring_rgb`` is deprecated and missing from
        # newer Matplotlib releases. Drop the alpha channel and hand OpenCV a
        # contiguous array.
        return np.ascontiguousarray(np.asarray(canvas.buffer_rgba())[..., :3])
    cols, rows = canvas.get_width_height()
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    return np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8).reshape(
        rows, cols, 3
    )


def fast_process(
    annotations, args, mask_random_color, bbox=None, points=None, edges=False
):
    """Overlay masks on the image at ``args.img_path`` and save the result.

    The rendered figure is written to ``args.output`` under the base name of
    ``args.img_path``; the output directory is created when missing.

    Args:
        annotations: Sequence of masks, or of dicts carrying a
            ``"segmentation"`` mask.
        args: Options object; the attributes read here are ``img_path``,
            ``better_quality``, ``device``, ``point_label``, ``retina``,
            ``withContours`` and ``output``.
        mask_random_color: Whether each mask gets a random colour.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle drawn on the figure.
        points: Optional prompt points drawn on the figure.
        edges: Accepted for call compatibility; not used.

    Returns:
        None.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]
    result_name = os.path.basename(args.img_path)
    image = cv2.imread(args.img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    original_h = image.shape[0]
    original_w = image.shape[1]
    if sys.platform == "darwin":
        plt.switch_backend("TkAgg")
    # NOTE(review): the figure created here is never closed in this function;
    # confirm whether callers still need it afterwards.
    plt.figure(figsize=(original_w / 100, original_h / 100))
    # Add subplot with no margin.
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
    plt.margins(0, 0)
    plt.gca().xaxis.set_major_locator(plt.NullLocator())
    plt.gca().yaxis.set_major_locator(plt.NullLocator())
    plt.imshow(image)

    if args.better_quality == True:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        # Close small holes first, then remove small specks.
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )

    # str() lets a torch.device("cpu") take the NumPy path just like "cpu".
    if str(args.device) == "cpu":
        annotations = np.array(annotations)
        fast_show_mask(
            annotations,
            plt.gca(),
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            # Stack first: torch.from_numpy rejects a plain list of arrays.
            annotations = torch.from_numpy(np.array(annotations))
        fast_show_mask_gpu(
            annotations,
            plt.gca(),
            # Use the explicit parameter on both paths.
            random_color=mask_random_color,
            bbox=bbox,
            points=points,
            point_label=args.point_label,
            retinamask=args.retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if args.withContours == True:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if args.retina == False:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            contour_all.extend(contours)
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
        # Blue outline at 80% opacity.
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)
        plt.imshow(contour_mask)

    save_path = args.output
    os.makedirs(save_path, exist_ok=True)
    plt.axis("off")
    fig = plt.gcf()

    img_array = _figure_to_rgb_array(fig)
    cv2.imwrite(
        os.path.join(save_path, result_name),
        cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR),
    )
199
+
200
+
201
+ # CPU post process
202
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Composite a stack of masks into one RGBA overlay and draw it on ``ax``.

    NumPy implementation. Where masks overlap, the mask with the smallest
    element sum decides the pixel colour.

    Args:
        annotation: Array indexed as ``(mask, row, column)``.
        ax: Axes-like object; receives the overlay through ``imshow`` and the
            optional box through ``add_patch``.
        random_color: ``True`` gives every mask a random colour; otherwise a
            fixed colour is used.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle to draw.
        points: Optional points to scatter-plot.
        point_label: Labels matching ``points`` (1 drawn yellow, 0 magenta).
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Overlay height used when resizing.
        target_width: Overlay width used when resizing.
    """
    mask_count = annotation.shape[0]
    n_rows = annotation.shape[1]
    n_cols = annotation.shape[2]

    # Sort the masks by ascending area.
    order = np.argsort(np.sum(annotation, axis=(1, 2)))
    annotation = annotation[order]

    # Per pixel: position of the first mask in sorted order that is non-zero
    # there (0 when no mask covers the pixel).
    first_hit = (annotation != 0).argmax(axis=0)

    if random_color == True:
        rgb = np.random.random((mask_count, 1, 1, 3))
    else:
        rgb = np.ones((mask_count, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    alpha = np.ones((mask_count, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    colored_masks = np.expand_dims(annotation, -1) * rgba

    # Gather, for every pixel, the RGBA value of its selected mask.
    overlay = np.zeros((n_rows, n_cols, 4))
    row_idx, col_idx = np.meshgrid(
        np.arange(n_rows), np.arange(n_cols), indexing="ij"
    )
    selection = (first_hit[row_idx, col_idx], row_idx, col_idx, slice(None))
    overlay[row_idx, col_idx, :] = colored_masks[selection]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # draw point
    if points is not None:
        for wanted_label, marker_color in ((1, "y"), (0, "m")):
            chosen = [
                point
                for i, point in enumerate(points)
                if point_label[i] == wanted_label
            ]
            plt.scatter(
                [point[0] for point in chosen],
                [point[1] for point in chosen],
                s=20,
                c=marker_color,
            )

    if retinamask == False:
        overlay = cv2.resize(
            overlay, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    ax.imshow(overlay)
266
+
267
+
268
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    points=None,
    point_label=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Composite a stack of masks into one RGBA overlay and draw it on ``ax``.

    Torch implementation; intermediate tensors are placed on
    ``annotation.device`` and the result is moved to NumPy before drawing.

    Args:
        annotation: Tensor indexed as ``(mask, row, column)``.
        ax: Axes-like object; receives the overlay through ``imshow`` and the
            optional box through ``add_patch``.
        random_color: ``True`` gives every mask a random colour; otherwise a
            fixed colour is used.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle to draw.
        points: Optional points to scatter-plot.
        point_label: Labels matching ``points`` (1 drawn yellow, 0 magenta).
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Overlay height used when resizing.
        target_width: Overlay width used when resizing.
    """
    target_device = annotation.device
    mask_count = annotation.shape[0]
    n_rows = annotation.shape[1]
    n_cols = annotation.shape[2]

    # Sort the masks by ascending area.
    order = torch.argsort(torch.sum(annotation, dim=(1, 2)), descending=False)
    annotation = annotation[order]

    # For every position, find the index of the first non-zero mask.
    first_hit = (annotation != 0).to(torch.long).argmax(dim=0)

    if random_color == True:
        rgb = torch.rand((mask_count, 1, 1, 3)).to(target_device)
    else:
        base_rgb = torch.tensor([30 / 255, 144 / 255, 255 / 255]).to(target_device)
        rgb = torch.ones((mask_count, 1, 1, 3)).to(target_device) * base_rgb
    alpha = torch.ones((mask_count, 1, 1, 1)).to(target_device) * 0.6
    rgba = torch.cat([rgb, alpha], dim=-1)
    colored_masks = torch.unsqueeze(annotation, -1) * rgba

    # Gather by index: ``first_hit`` says which mask supplies each pixel, which
    # collapses the stack into a single image.
    overlay = torch.zeros((n_rows, n_cols, 4)).to(target_device)
    row_idx, col_idx = torch.meshgrid(
        torch.arange(n_rows), torch.arange(n_cols), indexing="ij"
    )
    selection = (first_hit[row_idx, col_idx], row_idx, col_idx, slice(None))
    # Vectorised indexing fills the whole overlay in one assignment.
    overlay[row_idx, col_idx, :] = colored_masks[selection]
    overlay_np = overlay.cpu().numpy()

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    # draw point
    if points is not None:
        for wanted_label, marker_color in ((1, "y"), (0, "m")):
            chosen = [
                point
                for i, point in enumerate(points)
                if point_label[i] == wanted_label
            ]
            plt.scatter(
                [point[0] for point in chosen],
                [point[1] for point in chosen],
                s=20,
                c=marker_color,
            )

    if retinamask == False:
        overlay_np = cv2.resize(
            overlay_np,
            (target_width, target_height),
            interpolation=cv2.INTER_NEAREST,
        )
    ax.imshow(overlay_np)
331
+
332
+
333
def crop_image(annotations, image_like):
    """Cut one box-shaped crop out of the image for every sufficiently large mask.

    Args:
        annotations: Sequence of dicts carrying a ``"segmentation"`` mask.
        image_like: Either a path string, opened with ``Image.open``, or an
            image object used as-is.

    Returns:
        ``(cropped_boxes, cropped_images, not_crop, filter_id, annotations)``.
        Despite the names, ``cropped_boxes`` holds the cropped images and
        ``cropped_images`` holds their bounding boxes. ``not_crop`` is always
        empty, and ``filter_id`` lists the positions of masks whose element
        sum is 100 or less.
    """
    image = Image.open(image_like) if isinstance(image_like, str) else image_like

    # Bring the image to the resolution of the masks when they disagree.
    ori_w, ori_h = image.size
    mask_h, mask_w = annotations[0]["segmentation"].shape
    if (ori_w, ori_h) != (mask_w, mask_h):
        image = image.resize((mask_w, mask_h))

    cropped_boxes, cropped_images, not_crop, filter_id = [], [], [], []
    for position, entry in enumerate(annotations):
        segmentation = entry["segmentation"]
        if np.sum(segmentation) <= 100:
            filter_id.append(position)
            continue
        bbox = get_bbox_from_mask(segmentation)  # bounding box of the mask
        cropped_boxes.append(segment_image(image, bbox))  # the cropped image
        cropped_images.append(bbox)  # the bounding box of that crop

    return cropped_boxes, cropped_images, not_crop, filter_id, annotations
358
+
359
+
360
def box_prompt(masks, bbox, target_height, target_width):
    """Pick the mask whose overlap with ``bbox`` has the highest IoU.

    Args:
        masks: Tensor indexed as ``(mask, row, column)``.
        bbox: ``[x1, y1, x2, y2]`` expressed in the
            ``target_width`` x ``target_height`` coordinate space. The
            caller's sequence is left untouched.
        target_height: Height of the space ``bbox`` is expressed in.
        target_width: Width of the space ``bbox`` is expressed in.

    Returns:
        ``(mask, index)``: the winning mask as a NumPy array and its index as
        a tensor.
    """
    h = masks.shape[1]
    w = masks.shape[2]

    # Work on local copies so the caller's bbox is not modified.
    x1, y1, x2, y2 = bbox[0], bbox[1], bbox[2], bbox[3]
    if h != target_height or w != target_width:
        # Rescale the box from the target space to the mask resolution.
        x1 = int(x1 * w / target_width)
        y1 = int(y1 * h / target_height)
        x2 = int(x2 * w / target_width)
        y2 = int(y2 * h / target_height)

    # Clamp the box to the mask extent.
    x1 = max(round(x1), 0)
    y1 = max(round(y1), 0)
    x2 = min(round(x2), w)
    y2 = min(round(y2), h)

    bbox_area = (y2 - y1) * (x2 - x1)
    masks_area = torch.sum(masks[:, y1:y2, x1:x2], dim=(1, 2))
    orig_masks_area = torch.sum(masks, dim=(1, 2))

    union = bbox_area + orig_masks_area - masks_area
    IoUs = masks_area / union
    # A zero union (empty box and empty mask) divides to NaN; score it as 0
    # so it cannot be selected over a real overlap.
    IoUs = torch.where(torch.isnan(IoUs), torch.zeros_like(IoUs), IoUs)
    max_iou_index = torch.argmax(IoUs)

    return masks[max_iou_index].cpu().numpy(), max_iou_index
386
+
387
+
388
def point_prompt(masks, points, point_label, target_height, target_width):
    """Merge the masks selected by labelled points into one boolean mask.

    NumPy implementation. For every mask and every point lying on that mask
    (mask value equal to 1), the mask is added when the point's label is 1 and
    subtracted when the label is 0.

    Args:
        masks: Sequence of masks, or of dicts carrying a ``"segmentation"``
            mask; the first entry must be such a dict.
        points: ``[x, y]`` points in the
            ``target_width`` x ``target_height`` coordinate space.
        point_label: Label for each point, 1 (include) or 0 (exclude).
        target_height: Height of the space ``points`` are expressed in.
        target_width: Width of the space ``points`` are expressed in.

    Returns:
        ``(onemask, 0)``: a boolean array that is ``True`` where the
        accumulated value is at least 1, and a constant ``0``.
    """
    h = masks[0]["segmentation"].shape[0]
    w = masks[0]["segmentation"].shape[1]
    if h != target_height or w != target_width:
        # Rescale the points from the target space to the mask resolution.
        points = [
            [int(point[0] * w / target_width), int(point[1] * h / target_height)]
            for point in points
        ]

    onemask = np.zeros((h, w))
    for annotation in masks:
        if isinstance(annotation, dict):
            mask = annotation["segmentation"]
        else:
            mask = annotation
        for point_idx, point in enumerate(points):
            # Points are (x, y); arrays are indexed (row, column).
            if mask[point[1], point[0]] == 1:
                if point_label[point_idx] == 1:
                    onemask += mask
                if point_label[point_idx] == 0:
                    onemask -= mask
    onemask = onemask >= 1
    return onemask, 0
utils/tools_gradio.py ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import torch
5
+ from PIL import Image
6
+
7
+
8
def fast_process(
    annotations,
    image,
    device,
    scale,
    better_quality=False,
    mask_random_color=True,
    bbox=None,
    use_retina=True,
    withContours=True,
):
    """Overlay masks, and optionally their contours, on ``image``.

    Args:
        annotations: Sequence of masks, or of dicts carrying a
            ``"segmentation"`` mask.
        image: Image exposing ``height``, ``width``, ``convert`` and
            ``paste``.
        device: ``"cpu"`` or a device object whose ``type`` is ``"cpu"``
            selects the NumPy path; anything else selects the torch path.
        scale: Divisor applied to the base contour thickness of 2.
        better_quality: Clean every mask with a morphological close followed
            by an open before drawing.
        mask_random_color: Whether each mask gets a random colour.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle forwarded to the mask
            renderer.
        use_retina: When ``False`` masks are resized to the image size.
        withContours: Whether mask outlines are drawn on top.

    Returns:
        The RGBA image with the overlays pasted onto it.
    """
    if isinstance(annotations[0], dict):
        annotations = [annotation["segmentation"] for annotation in annotations]

    original_h = image.height
    original_w = image.width

    # The mask renderers only use the axes to draw a bbox, so avoid creating a
    # Matplotlib figure on every call when there is none to draw.
    ax = plt.gca() if bbox is not None else None

    if better_quality:
        if isinstance(annotations[0], torch.Tensor):
            annotations = np.array(annotations.cpu())
        # Close small holes first, then remove small specks.
        for i, mask in enumerate(annotations):
            mask = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
            )
            annotations[i] = cv2.morphologyEx(
                mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
            )

    # A torch.device never compares equal to the string "cpu", so compare its
    # ``type`` attribute when there is one.
    if getattr(device, "type", device) == "cpu":
        annotations = np.array(annotations)
        inner_mask = fast_show_mask(
            annotations,
            ax,
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    else:
        if isinstance(annotations[0], np.ndarray):
            annotations = np.array(annotations)
            annotations = torch.from_numpy(annotations)
        inner_mask = fast_show_mask_gpu(
            annotations,
            ax,
            random_color=mask_random_color,
            bbox=bbox,
            retinamask=use_retina,
            target_height=original_h,
            target_width=original_w,
        )
    if isinstance(annotations, torch.Tensor):
        annotations = annotations.cpu().numpy()

    if withContours:
        contour_all = []
        temp = np.zeros((original_h, original_w, 1))
        for mask in annotations:
            if isinstance(mask, dict):
                mask = mask["segmentation"]
            annotation = mask.astype(np.uint8)
            if use_retina == False:
                annotation = cv2.resize(
                    annotation,
                    (original_w, original_h),
                    interpolation=cv2.INTER_NEAREST,
                )
            contours, _ = cv2.findContours(
                annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            contour_all.extend(contours)
        # Keep the outline at least one pixel wide; a scale of 0 would
        # otherwise raise ZeroDivisionError.
        thickness = max(1, int(2 // scale)) if scale > 0 else 2
        cv2.drawContours(temp, contour_all, -1, (255, 255, 255), thickness)
        # Blue outline at 90% opacity.
        color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
        contour_mask = temp / 255 * color.reshape(1, 1, -1)

    image = image.convert("RGBA")
    overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
    image.paste(overlay_inner, (0, 0), overlay_inner)

    if withContours:
        overlay_contour = Image.fromarray(
            (contour_mask * 255).astype(np.uint8), "RGBA"
        )
        image.paste(overlay_contour, (0, 0), overlay_contour)

    return image
92
+
93
+
94
+ # CPU post process
95
def fast_show_mask(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Composite a stack of masks into a single RGBA overlay (NumPy path).

    Where masks overlap, the mask with the smallest element sum decides the
    pixel colour.

    Args:
        annotation: Array indexed as ``(mask, row, column)``.
        ax: Axes-like object; only used to draw ``bbox`` when one is given.
        random_color: ``True`` gives every mask a random colour; otherwise a
            fixed colour is used.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle added to ``ax``.
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Overlay height used when resizing.
        target_width: Overlay width used when resizing.

    Returns:
        Float array of shape ``(rows, columns, 4)`` holding RGBA values.
    """
    mask_count = annotation.shape[0]
    n_rows = annotation.shape[1]
    n_cols = annotation.shape[2]

    # Sort the masks by ascending area.
    order = np.argsort(np.sum(annotation, axis=(1, 2)))
    annotation = annotation[order]

    # Per pixel: position of the first mask in sorted order that is non-zero
    # there (0 when no mask covers the pixel).
    first_hit = (annotation != 0).argmax(axis=0)

    if random_color == True:
        rgb = np.random.random((mask_count, 1, 1, 3))
    else:
        rgb = np.ones((mask_count, 1, 1, 3)) * np.array(
            [30 / 255, 144 / 255, 255 / 255]
        )
    alpha = np.ones((mask_count, 1, 1, 1)) * 0.6
    rgba = np.concatenate([rgb, alpha], axis=-1)
    colored_masks = np.expand_dims(annotation, -1) * rgba

    # Gather, for every pixel, the RGBA value of its selected mask.
    overlay = np.zeros((n_rows, n_cols, 4))
    row_idx, col_idx = np.meshgrid(
        np.arange(n_rows), np.arange(n_cols), indexing="ij"
    )
    selection = (first_hit[row_idx, col_idx], row_idx, col_idx, slice(None))
    overlay[row_idx, col_idx, :] = colored_masks[selection]

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        outline = plt.Rectangle(
            (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
        )
        ax.add_patch(outline)

    if retinamask == False:
        overlay = cv2.resize(
            overlay, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )

    return overlay
145
+
146
+
147
def fast_show_mask_gpu(
    annotation,
    ax,
    random_color=False,
    bbox=None,
    retinamask=True,
    target_height=960,
    target_width=960,
):
    """Composite a stack of masks into a single RGBA overlay (torch path).

    Intermediate tensors are placed on ``annotation.device``; the result is
    returned as a NumPy array.

    Args:
        annotation: Tensor indexed as ``(mask, row, column)``.
        ax: Axes-like object; only used to draw ``bbox`` when one is given.
        random_color: ``True`` gives every mask a random colour; otherwise a
            fixed colour is used.
        bbox: Optional ``(x1, y1, x2, y2)`` rectangle added to ``ax``.
        retinamask: When ``False`` the overlay is resized to the target size.
        target_height: Overlay height used when resizing.
        target_width: Overlay width used when resizing.

    Returns:
        NumPy array of shape ``(rows, columns, 4)`` holding RGBA values.
    """
    device = annotation.device
    mask_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]

    # Sort the masks by ascending area.
    areas = torch.sum(annotation, dim=(1, 2))
    sorted_indices = torch.argsort(areas, descending=False)
    annotation = annotation[sorted_indices]

    # For every position, find the index of the first non-zero mask.
    index = (annotation != 0).to(torch.long).argmax(dim=0)
    if random_color == True:
        color = torch.rand((mask_sum, 1, 1, 3)).to(device)
    else:
        color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
            [30 / 255, 144 / 255, 255 / 255]
        ).to(device)
    transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
    visual = torch.cat([color, transparency], dim=-1)
    mask_image = torch.unsqueeze(annotation, -1) * visual

    # Gather by index: ``index`` says which mask supplies each pixel, which
    # collapses the stack into a single image.
    mask = torch.zeros((height, weight, 4)).to(device)
    # Spell out indexing="ij" (the historical default), matching the sibling
    # implementation in tools.py and avoiding torch's warning about the
    # implicit default.
    h_indices, w_indices = torch.meshgrid(
        torch.arange(height), torch.arange(weight), indexing="ij"
    )
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # Vectorised indexing fills the whole overlay in one assignment.
    mask[h_indices, w_indices, :] = mask_image[indices]
    mask_cpu = mask.cpu().numpy()

    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(
            plt.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
            )
        )
    if retinamask == False:
        mask_cpu = cv2.resize(
            mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
        )
    return mask_cpu