AAAAAAAyq commited on
Commit
9951234
·
1 Parent(s): 766f95f

Add text mode

Browse files
.gitignore ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ *.pyc
2
+ *.pyo
3
+ *.pyd
4
+ .DS_Store
5
+ .idea
6
+ weights
7
+ gradio_cached_examples
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: pink
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.35.2
8
- app_file: app.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
 
5
  colorTo: indigo
6
  sdk: gradio
7
  sdk_version: 3.35.2
8
+ app_file: app_gradio.py
9
  pinned: false
10
  license: apache-2.0
11
  ---
__pycache__/tools.cpython-39.pyc DELETED
Binary file (8.4 kB)
 
app.py → app_gradio.py RENAMED
@@ -1,24 +1,31 @@
1
  from ultralytics import YOLO
2
  import gradio as gr
3
  import torch
4
- from tools import fast_process, format_results, box_prompt, point_prompt
 
5
  from PIL import ImageDraw
6
  import numpy as np
7
 
8
  # Load the pre-trained model
9
- model = YOLO('checkpoints/FastSAM.pt')
10
 
11
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
12
 
13
  # Description
14
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
15
 
16
  news = """ # 📖 News
 
17
 
18
- 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
19
-
20
  🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
21
-
 
22
  """
23
 
24
  description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
@@ -52,8 +59,8 @@ description_p = """ # 🎯 Instructions for points mode
52
 
53
  """
54
 
55
- examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"], ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
56
- ["assets/sa_561.jpg"], ["assets/sa_192.jpg"], ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
57
 
58
  default_example = examples[0]
59
 
@@ -68,10 +75,10 @@ def segment_everything(
68
  better_quality=False,
69
  withContours=True,
70
  use_retina=True,
 
71
  mask_random_color=True,
72
- ):
73
  input_size = int(input_size) # 确保 imgsz 是整数
74
-
75
  # Thanks for the suggestion by hysts in HuggingFace.
76
  w, h = input.size
77
  scale = input_size / max(w, h)
@@ -85,18 +92,26 @@ def segment_everything(
85
  iou=iou_threshold,
86
  conf=conf_threshold,
87
  imgsz=input_size,)
 
 
 
 
 
 
 
88
 
89
- fig = fast_process(annotations=results[0].masks.data,
90
- image=input,
91
- device=device,
92
- scale=(1024 // input_size),
93
- better_quality=better_quality,
94
- mask_random_color=mask_random_color,
95
- bbox=None,
96
- use_retina=use_retina,
97
- withContours=withContours,)
98
  return fig
99
 
 
100
  def segment_with_points(
101
  input,
102
  input_size=1024,
@@ -104,9 +119,9 @@ def segment_with_points(
104
  conf_threshold=0.25,
105
  better_quality=False,
106
  withContours=True,
107
- mask_random_color=True,
108
  use_retina=True,
109
- ):
 
110
  global global_points
111
  global global_point_label
112
 
@@ -128,56 +143,51 @@ def segment_with_points(
128
  imgsz=input_size,)
129
 
130
  results = format_results(results[0], 0)
131
-
132
  annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
133
  annotations = np.array([annotations])
134
-
135
  fig = fast_process(annotations=annotations,
136
- image=input,
137
- device=device,
138
- scale=(1024 // input_size),
139
- better_quality=better_quality,
140
- mask_random_color=mask_random_color,
141
- bbox=None,
142
- use_retina=use_retina,
143
- withContours=withContours,)
 
144
  global_points = []
145
  global_point_label = []
146
  return fig, None
147
 
 
148
  def get_points_with_draw(image, label, evt: gr.SelectData):
149
- x, y = evt.index[0], evt.index[1]
150
- point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
151
  global global_points
152
  global global_point_label
153
- print((x, y))
 
 
154
  global_points.append([x, y])
155
  global_point_label.append(1 if label == 'Add Mask' else 0)
156
 
 
 
157
  # 创建一个可以在图像上绘图的对象
158
  draw = ImageDraw.Draw(image)
159
  draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
160
  return image
161
-
162
 
163
- # input_size=1024
164
- # high_quality_visual=True
165
- # inp = 'assets/sa_192.jpg'
166
- # input = Image.open(inp)
167
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
168
- # input_size = int(input_size) # 确保 imgsz 是整数
169
- # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
170
- # pil_image = fast_process(annotations=results[0].masks.data,
171
- # image=input, high_quality=high_quality_visual, device=device)
172
 
173
  cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
174
  cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
 
175
 
176
  segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
177
  segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
 
178
 
179
  global_points = []
180
- global_point_label = [] # TODO:Clear points each image
181
 
182
  input_size_slider = gr.components.Slider(minimum=512,
183
  maximum=1024,
@@ -188,14 +198,14 @@ input_size_slider = gr.components.Slider(minimum=512,
188
 
189
  with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
190
  with gr.Row():
191
- with gr.Column(scale=1):
192
- # Title
193
- gr.Markdown(title)
194
-
195
- with gr.Column(scale=1):
196
- # News
197
- gr.Markdown(news)
198
-
199
  with gr.Tab("Everything mode"):
200
  # Images
201
  with gr.Row(variant="panel"):
@@ -227,13 +237,13 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
227
 
228
  with gr.Column():
229
  with gr.Accordion("Advanced options", open=False):
 
230
  iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
231
  conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
232
  with gr.Row():
233
  mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
234
  with gr.Column():
235
  retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
236
-
237
  # Description
238
  gr.Markdown(description_e)
239
 
@@ -259,30 +269,102 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
259
  gr.Markdown("Try some of the examples below ⬇️")
260
  gr.Examples(examples=examples,
261
  inputs=[cond_img_p],
262
- outputs=segm_img_p,
263
- fn=segment_with_points,
264
  # cache_examples=True,
265
  examples_per_page=4)
266
 
267
  with gr.Column():
268
  # Description
269
  gr.Markdown(description_p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
272
 
273
  segment_btn_e.click(segment_everything,
274
- inputs=[cond_img_e, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check, retina_check],
275
- outputs=segm_img_e)
276
-
 
 
 
 
 
 
 
 
277
  segment_btn_p.click(segment_with_points,
278
- inputs=[cond_img_p],
279
- outputs=[segm_img_p, cond_img_p])
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  def clear():
282
  return None, None
283
 
 
 
 
284
  clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
285
  clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
 
286
 
287
  demo.queue()
288
  demo.launch()
 
1
  from ultralytics import YOLO
2
  import gradio as gr
3
  import torch
4
+ from utils.tools_gradio import fast_process
5
+ from utils.tools import format_results, box_prompt, point_prompt, text_prompt
6
  from PIL import ImageDraw
7
  import numpy as np
8
 
9
  # Load the pre-trained model
10
+ model = YOLO('./weights/FastSAM.pt')
11
 
12
+ device = torch.device(
13
+ "cuda"
14
+ if torch.cuda.is_available()
15
+ else "mps"
16
+ if torch.backends.mps.is_available()
17
+ else "cpu"
18
+ )
19
 
20
  # Description
21
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
22
 
23
  news = """ # 📖 News
24
+ 🔥 2023/06/29: Support the text mode (Thanks for [gaoxinge](https://github.com/CASIA-IVA-Lab/FastSAM/pull/47)).
25
 
 
 
26
  🔥 2023/06/26: Support the points mode. (Better and faster interaction will come soon!)
27
+
28
+ 🔥 2023/06/24: Add the 'Advanced options" in Everything mode to get a more detailed adjustment.
29
  """
30
 
31
  description_e = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM). Welcome to give a star ⭐️ to it.
 
59
 
60
  """
61
 
62
+ examples = [["examples/sa_8776.jpg"], ["examples/sa_414.jpg"], ["examples/sa_1309.jpg"], ["examples/sa_11025.jpg"],
63
+ ["examples/sa_561.jpg"], ["examples/sa_192.jpg"], ["examples/sa_10039.jpg"], ["examples/sa_862.jpg"]]
64
 
65
  default_example = examples[0]
66
 
 
75
  better_quality=False,
76
  withContours=True,
77
  use_retina=True,
78
+ text="",
79
  mask_random_color=True,
80
+ ):
81
  input_size = int(input_size) # 确保 imgsz 是整数
 
82
  # Thanks for the suggestion by hysts in HuggingFace.
83
  w, h = input.size
84
  scale = input_size / max(w, h)
 
92
  iou=iou_threshold,
93
  conf=conf_threshold,
94
  imgsz=input_size,)
95
+
96
+ if len(text) > 0:
97
+ results = format_results(results[0], 0)
98
+ annotations, _ = text_prompt(results, text, input, device=device)
99
+ annotations = np.array([annotations])
100
+ else:
101
+ annotations = results[0].masks.data
102
 
103
+ fig = fast_process(annotations=annotations,
104
+ image=input,
105
+ device=device,
106
+ scale=(1024 // input_size),
107
+ better_quality=better_quality,
108
+ mask_random_color=mask_random_color,
109
+ bbox=None,
110
+ use_retina=use_retina,
111
+ withContours=withContours,)
112
  return fig
113
 
114
+
115
  def segment_with_points(
116
  input,
117
  input_size=1024,
 
119
  conf_threshold=0.25,
120
  better_quality=False,
121
  withContours=True,
 
122
  use_retina=True,
123
+ mask_random_color=True,
124
+ ):
125
  global global_points
126
  global global_point_label
127
 
 
143
  imgsz=input_size,)
144
 
145
  results = format_results(results[0], 0)
 
146
  annotations, _ = point_prompt(results, scaled_points, global_point_label, new_h, new_w)
147
  annotations = np.array([annotations])
148
+
149
  fig = fast_process(annotations=annotations,
150
+ image=input,
151
+ device=device,
152
+ scale=(1024 // input_size),
153
+ better_quality=better_quality,
154
+ mask_random_color=mask_random_color,
155
+ bbox=None,
156
+ use_retina=use_retina,
157
+ withContours=withContours,)
158
+
159
  global_points = []
160
  global_point_label = []
161
  return fig, None
162
 
163
+
164
  def get_points_with_draw(image, label, evt: gr.SelectData):
 
 
165
  global global_points
166
  global global_point_label
167
+
168
+ x, y = evt.index[0], evt.index[1]
169
+ point_radius, point_color = 15, (255, 255, 0) if label == 'Add Mask' else (255, 0, 255)
170
  global_points.append([x, y])
171
  global_point_label.append(1 if label == 'Add Mask' else 0)
172
 
173
+ print(x, y, label == 'Add Mask')
174
+
175
  # 创建一个可以在图像上绘图的对象
176
  draw = ImageDraw.Draw(image)
177
  draw.ellipse([(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)], fill=point_color)
178
  return image
 
179
 
 
 
 
 
 
 
 
 
 
180
 
181
  cond_img_e = gr.Image(label="Input", value=default_example[0], type='pil')
182
  cond_img_p = gr.Image(label="Input with points", value=default_example[0], type='pil')
183
+ cond_img_t = gr.Image(label="Input with text", value="examples/dogs.jpg", type='pil')
184
 
185
  segm_img_e = gr.Image(label="Segmented Image", interactive=False, type='pil')
186
  segm_img_p = gr.Image(label="Segmented Image with points", interactive=False, type='pil')
187
+ segm_img_t = gr.Image(label="Segmented Image with text", interactive=False, type='pil')
188
 
189
  global_points = []
190
+ global_point_label = []
191
 
192
  input_size_slider = gr.components.Slider(minimum=512,
193
  maximum=1024,
 
198
 
199
  with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
200
  with gr.Row():
201
+ with gr.Column(scale=1):
202
+ # Title
203
+ gr.Markdown(title)
204
+
205
+ with gr.Column(scale=1):
206
+ # News
207
+ gr.Markdown(news)
208
+
209
  with gr.Tab("Everything mode"):
210
  # Images
211
  with gr.Row(variant="panel"):
 
237
 
238
  with gr.Column():
239
  with gr.Accordion("Advanced options", open=False):
240
+ # text_box = gr.Textbox(label="text prompt")
241
  iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
242
  conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
243
  with gr.Row():
244
  mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
245
  with gr.Column():
246
  retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
 
247
  # Description
248
  gr.Markdown(description_e)
249
 
 
269
  gr.Markdown("Try some of the examples below ⬇️")
270
  gr.Examples(examples=examples,
271
  inputs=[cond_img_p],
272
+ # outputs=segm_img_p,
273
+ # fn=segment_with_points,
274
  # cache_examples=True,
275
  examples_per_page=4)
276
 
277
  with gr.Column():
278
  # Description
279
  gr.Markdown(description_p)
280
+
281
+ with gr.Tab("Text mode"):
282
+ # Images
283
+ with gr.Row(variant="panel"):
284
+ with gr.Column(scale=1):
285
+ cond_img_t.render()
286
+
287
+ with gr.Column(scale=1):
288
+ segm_img_t.render()
289
+
290
+ # Submit & Clear
291
+ with gr.Row():
292
+ with gr.Column():
293
+ input_size_slider_t = gr.components.Slider(minimum=512,
294
+ maximum=1024,
295
+ value=1024,
296
+ step=64,
297
+ label='Input_size',
298
+ info='Our model was trained on a size of 1024')
299
+ with gr.Row():
300
+ with gr.Column():
301
+ contour_check = gr.Checkbox(value=True, label='withContours', info='draw the edges of the masks')
302
+ text_box = gr.Textbox(label="text prompt", value="a black dog")
303
+
304
+ with gr.Column():
305
+ segment_btn_t = gr.Button("Segment with text", variant='primary')
306
+ clear_btn_t = gr.Button("Clear", variant="secondary")
307
+
308
+ gr.Markdown("Try some of the examples below ⬇️")
309
+ gr.Examples(examples=["examples/dogs.jpg"],
310
+ inputs=[cond_img_e],
311
+ # outputs=segm_img_e,
312
+ # fn=segment_everything,
313
+ # cache_examples=True,
314
+ examples_per_page=4)
315
+
316
+ with gr.Column():
317
+ with gr.Accordion("Advanced options", open=False):
318
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou', info='iou threshold for filtering the annotations')
319
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf', info='object confidence threshold')
320
+ with gr.Row():
321
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality', info='better quality using morphologyEx')
322
+ with gr.Column():
323
+ retina_check = gr.Checkbox(value=True, label='use_retina', info='draw high-resolution segmentation masks')
324
+
325
+ # Description
326
+ gr.Markdown(description_e)
327
 
328
  cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
329
 
330
  segment_btn_e.click(segment_everything,
331
+ inputs=[
332
+ cond_img_e,
333
+ input_size_slider,
334
+ iou_threshold,
335
+ conf_threshold,
336
+ mor_check,
337
+ contour_check,
338
+ retina_check,
339
+ ],
340
+ outputs=segm_img_e)
341
+
342
  segment_btn_p.click(segment_with_points,
343
+ inputs=[cond_img_p],
344
+ outputs=[segm_img_p, cond_img_p])
345
 
346
+ segment_btn_t.click(segment_everything,
347
+ inputs=[
348
+ cond_img_t,
349
+ input_size_slider_t,
350
+ iou_threshold,
351
+ conf_threshold,
352
+ mor_check,
353
+ contour_check,
354
+ retina_check,
355
+ text_box,
356
+ ],
357
+ outputs=segm_img_t)
358
+
359
  def clear():
360
  return None, None
361
 
362
+ def clear_text():
363
+ return None, None, None
364
+
365
  clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
366
  clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
367
+ clear_btn_t.click(clear_text, outputs=[cond_img_p, segm_img_p, text_box])
368
 
369
  demo.queue()
370
  demo.launch()
checkpoints/FastSAM.pt → examples/dogs.jpg RENAMED
File without changes
{assets → examples}/sa_10039.jpg RENAMED
File without changes
{assets → examples}/sa_11025.jpg RENAMED
File without changes
{assets → examples}/sa_1309.jpg RENAMED
File without changes
{assets → examples}/sa_192.jpg RENAMED
File without changes
{assets → examples}/sa_414.jpg RENAMED
File without changes
{assets → examples}/sa_561.jpg RENAMED
File without changes
{assets → examples}/sa_862.jpg RENAMED
File without changes
{assets → examples}/sa_8776.jpg RENAMED
File without changes
requirements.txt CHANGED
@@ -2,6 +2,7 @@
2
  matplotlib==3.2.2
3
  numpy
4
  opencv-python
 
5
  # Pillow>=7.1.2
6
  # PyYAML>=5.3.1
7
  # requests>=2.23.0
 
2
  matplotlib==3.2.2
3
  numpy
4
  opencv-python
5
+ clip>=0.2.0
6
  # Pillow>=7.1.2
7
  # PyYAML>=5.3.1
8
  # requests>=2.23.0
utils/__init__.py ADDED
File without changes
tools.py → utils/tools.py RENAMED
@@ -3,7 +3,9 @@ from PIL import Image
3
  import matplotlib.pyplot as plt
4
  import cv2
5
  import torch
6
- # import clip
 
 
7
 
8
 
9
  def convert_box_xywh_to_xyxy(box):
@@ -49,7 +51,7 @@ def format_results(result, filter=0):
49
  return annotations
50
 
51
 
52
- def filter_masks(annotations): # filte the overlap mask
53
  annotations.sort(key=lambda x: x["area"], reverse=True)
54
  to_remove = set()
55
  for i in range(0, len(annotations)):
@@ -86,126 +88,171 @@ def get_bbox_from_mask(mask):
86
  w = x2 - x1
87
  return [x1, y1, x2, y2]
88
 
 
89
  def fast_process(
90
- annotations,
91
- image,
92
- device,
93
- scale,
94
- better_quality=False,
95
- mask_random_color=True,
96
- bbox=None,
97
- use_retina=True,
98
- withContours=True,
99
- ):
100
  if isinstance(annotations[0], dict):
101
- annotations = [annotation['segmentation'] for annotation in annotations]
102
-
103
- original_h = image.height
104
- original_w = image.width
105
- if better_quality:
 
 
 
 
 
 
 
 
 
 
 
106
  if isinstance(annotations[0], torch.Tensor):
107
  annotations = np.array(annotations.cpu())
108
  for i, mask in enumerate(annotations):
109
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
110
- annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
111
- if device == 'cpu':
 
 
 
 
112
  annotations = np.array(annotations)
113
- inner_mask = fast_show_mask(
114
  annotations,
115
  plt.gca(),
116
  random_color=mask_random_color,
117
  bbox=bbox,
118
- retinamask=use_retina,
 
 
119
  target_height=original_h,
120
  target_width=original_w,
121
  )
122
  else:
123
  if isinstance(annotations[0], np.ndarray):
124
  annotations = torch.from_numpy(annotations)
125
- inner_mask = fast_show_mask_gpu(
126
  annotations,
127
  plt.gca(),
128
- random_color=mask_random_color,
129
  bbox=bbox,
130
- retinamask=use_retina,
 
 
131
  target_height=original_h,
132
  target_width=original_w,
133
  )
134
  if isinstance(annotations, torch.Tensor):
135
  annotations = annotations.cpu().numpy()
136
-
137
- if withContours:
138
  contour_all = []
139
  temp = np.zeros((original_h, original_w, 1))
140
  for i, mask in enumerate(annotations):
141
  if type(mask) == dict:
142
- mask = mask['segmentation']
143
  annotation = mask.astype(np.uint8)
144
- if use_retina == False:
145
  annotation = cv2.resize(
146
  annotation,
147
  (original_w, original_h),
148
  interpolation=cv2.INTER_NEAREST,
149
  )
150
- contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
 
 
151
  for contour in contours:
152
  contour_all.append(contour)
153
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
154
- color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
155
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
156
-
157
- image = image.convert('RGBA')
158
- overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
159
- image.paste(overlay_inner, (0, 0), overlay_inner)
160
-
161
- if withContours:
162
- overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
163
- image.paste(overlay_contour, (0, 0), overlay_contour)
164
-
165
- return image
166
-
167
-
168
- # CPU post process
 
 
 
 
 
 
 
 
169
  def fast_show_mask(
170
  annotation,
171
  ax,
172
  random_color=False,
173
  bbox=None,
 
 
174
  retinamask=True,
175
  target_height=960,
176
  target_width=960,
177
  ):
178
- mask_sum = annotation.shape[0]
179
  height = annotation.shape[1]
180
  weight = annotation.shape[2]
181
  # 将annotation 按照面积 排序
182
  areas = np.sum(annotation, axis=(1, 2))
183
- sorted_indices = np.argsort(areas)[::1]
184
  annotation = annotation[sorted_indices]
185
 
186
  index = (annotation != 0).argmax(axis=0)
187
  if random_color == True:
188
- color = np.random.random((mask_sum, 1, 1, 3))
189
  else:
190
- color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
191
- transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
 
 
192
  visual = np.concatenate([color, transparency], axis=-1)
193
  mask_image = np.expand_dims(annotation, -1) * visual
194
 
195
- mask = np.zeros((height, weight, 4))
196
-
197
- h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
 
198
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
199
-
200
- mask[h_indices, w_indices, :] = mask_image[indices]
201
  if bbox is not None:
202
  x1, y1, x2, y2 = bbox
203
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
 
205
  if retinamask == False:
206
- mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
207
-
208
- return mask
 
209
 
210
 
211
  def fast_show_mask_gpu(
@@ -213,12 +260,13 @@ def fast_show_mask_gpu(
213
  ax,
214
  random_color=False,
215
  bbox=None,
 
 
216
  retinamask=True,
217
  target_height=960,
218
  target_width=960,
219
  ):
220
- device = annotation.device
221
- mask_sum = annotation.shape[0]
222
  height = annotation.shape[1]
223
  weight = annotation.shape[2]
224
  areas = torch.sum(annotation, dim=(1, 2))
@@ -227,21 +275,23 @@ def fast_show_mask_gpu(
227
  # 找每个位置第一个非零值下标
228
  index = (annotation != 0).to(torch.long).argmax(dim=0)
229
  if random_color == True:
230
- color = torch.rand((mask_sum, 1, 1, 3)).to(device)
231
  else:
232
- color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
233
  [30 / 255, 144 / 255, 255 / 255]
234
- ).to(device)
235
- transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
236
  visual = torch.cat([color, transparency], dim=-1)
237
  mask_image = torch.unsqueeze(annotation, -1) * visual
238
  # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
239
- mask = torch.zeros((height, weight, 4)).to(device)
240
- h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
 
 
241
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
242
  # 使用向量化索引更新show的值
243
- mask[h_indices, w_indices, :] = mask_image[indices]
244
- mask_cpu = mask.cpu().numpy()
245
  if bbox is not None:
246
  x1, y1, x2, y2 = bbox
247
  ax.add_patch(
@@ -249,31 +299,48 @@ def fast_show_mask_gpu(
249
  (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
250
  )
251
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  if retinamask == False:
253
- mask_cpu = cv2.resize(
254
- mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
255
  )
256
- return mask_cpu
257
-
258
-
259
- # # clip
260
- # @torch.no_grad()
261
- # def retriev(
262
- # model, preprocess, elements, search_text: str, device
263
- # ) -> int:
264
- # preprocessed_images = [preprocess(image).to(device) for image in elements]
265
- # tokenized_text = clip.tokenize([search_text]).to(device)
266
- # stacked_images = torch.stack(preprocessed_images)
267
- # image_features = model.encode_image(stacked_images)
268
- # text_features = model.encode_text(tokenized_text)
269
- # image_features /= image_features.norm(dim=-1, keepdim=True)
270
- # text_features /= text_features.norm(dim=-1, keepdim=True)
271
- # probs = 100.0 * image_features @ text_features.T
272
- # return probs[:, 0].softmax(dim=0)
273
-
274
-
275
- def crop_image(annotations, image_path):
276
- image = Image.open(image_path)
 
 
 
277
  ori_w, ori_h = image.size
278
  mask_h, mask_w = annotations[0]["segmentation"].shape
279
  if ori_w != mask_w or ori_h != mask_h:
@@ -324,7 +391,7 @@ def box_prompt(masks, bbox, target_height, target_width):
324
  return masks[max_iou_index].cpu().numpy(), max_iou_index
325
 
326
 
327
- def point_prompt(masks, points, pointlabel, target_height, target_width): # numpy 处理
328
  h = masks[0]["segmentation"].shape[0]
329
  w = masks[0]["segmentation"].shape[1]
330
  if h != target_height or w != target_width:
@@ -339,23 +406,23 @@ def point_prompt(masks, points, pointlabel, target_height, target_width): # num
339
  else:
340
  mask = annotation
341
  for i, point in enumerate(points):
342
- if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
343
  onemask += mask
344
- if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
345
  onemask -= mask
346
  onemask = onemask >= 1
347
  return onemask, 0
348
 
349
 
350
- # def text_prompt(annotations, args):
351
- # cropped_boxes, cropped_images, not_crop, filter_id, annotaions = crop_image(
352
- # annotations, args.img_path
353
- # )
354
- # clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
355
- # scores = retriev(
356
- # clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
357
- # )
358
- # max_idx = scores.argsort()
359
- # max_idx = max_idx[-1]
360
- # max_idx += sum(np.array(filter_id) <= int(max_idx))
361
- # return annotaions[max_idx]["segmentation"], max_idx
 
3
  import matplotlib.pyplot as plt
4
  import cv2
5
  import torch
6
+ import os
7
+ import sys
8
+ import clip
9
 
10
 
11
  def convert_box_xywh_to_xyxy(box):
 
51
  return annotations
52
 
53
 
54
+ def filter_masks(annotations): # filter the overlap mask
55
  annotations.sort(key=lambda x: x["area"], reverse=True)
56
  to_remove = set()
57
  for i in range(0, len(annotations)):
 
88
  w = x2 - x1
89
  return [x1, y1, x2, y2]
90
 
91
+
92
  def fast_process(
93
+ annotations, args, mask_random_color, bbox=None, points=None, edges=False
94
+ ):
 
 
 
 
 
 
 
 
95
  if isinstance(annotations[0], dict):
96
+ annotations = [annotation["segmentation"] for annotation in annotations]
97
+ result_name = os.path.basename(args.img_path)
98
+ image = cv2.imread(args.img_path)
99
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
100
+ original_h = image.shape[0]
101
+ original_w = image.shape[1]
102
+ if sys.platform == "darwin":
103
+ plt.switch_backend("TkAgg")
104
+ plt.figure(figsize=(original_w/100, original_h/100))
105
+ # Add subplot with no margin.
106
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
107
+ plt.margins(0, 0)
108
+ plt.gca().xaxis.set_major_locator(plt.NullLocator())
109
+ plt.gca().yaxis.set_major_locator(plt.NullLocator())
110
+ plt.imshow(image)
111
+ if args.better_quality == True:
112
  if isinstance(annotations[0], torch.Tensor):
113
  annotations = np.array(annotations.cpu())
114
  for i, mask in enumerate(annotations):
115
+ mask = cv2.morphologyEx(
116
+ mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
117
+ )
118
+ annotations[i] = cv2.morphologyEx(
119
+ mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
120
+ )
121
+ if args.device == "cpu":
122
  annotations = np.array(annotations)
123
+ fast_show_mask(
124
  annotations,
125
  plt.gca(),
126
  random_color=mask_random_color,
127
  bbox=bbox,
128
+ points=points,
129
+ point_label=args.point_label,
130
+ retinamask=args.retina,
131
  target_height=original_h,
132
  target_width=original_w,
133
  )
134
  else:
135
  if isinstance(annotations[0], np.ndarray):
136
  annotations = torch.from_numpy(annotations)
137
+ fast_show_mask_gpu(
138
  annotations,
139
  plt.gca(),
140
+ random_color=args.randomcolor,
141
  bbox=bbox,
142
+ points=points,
143
+ point_label=args.point_label,
144
+ retinamask=args.retina,
145
  target_height=original_h,
146
  target_width=original_w,
147
  )
148
  if isinstance(annotations, torch.Tensor):
149
  annotations = annotations.cpu().numpy()
150
+ if args.withContours == True:
 
151
  contour_all = []
152
  temp = np.zeros((original_h, original_w, 1))
153
  for i, mask in enumerate(annotations):
154
  if type(mask) == dict:
155
+ mask = mask["segmentation"]
156
  annotation = mask.astype(np.uint8)
157
+ if args.retina == False:
158
  annotation = cv2.resize(
159
  annotation,
160
  (original_w, original_h),
161
  interpolation=cv2.INTER_NEAREST,
162
  )
163
+ contours, hierarchy = cv2.findContours(
164
+ annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
165
+ )
166
  for contour in contours:
167
  contour_all.append(contour)
168
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
169
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
170
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
171
+ plt.imshow(contour_mask)
172
+
173
+ save_path = args.output
174
+ if not os.path.exists(save_path):
175
+ os.makedirs(save_path)
176
+ plt.axis("off")
177
+ fig = plt.gcf()
178
+ plt.draw()
179
+
180
+ try:
181
+ buf = fig.canvas.tostring_rgb()
182
+ except AttributeError:
183
+ fig.canvas.draw()
184
+ buf = fig.canvas.tostring_rgb()
185
+
186
+ cols, rows = fig.canvas.get_width_height()
187
+ img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
188
+ cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
189
+
190
+
191
+ # CPU post process
192
  def fast_show_mask(
193
  annotation,
194
  ax,
195
  random_color=False,
196
  bbox=None,
197
+ points=None,
198
+ point_label=None,
199
  retinamask=True,
200
  target_height=960,
201
  target_width=960,
202
  ):
203
+ msak_sum = annotation.shape[0]
204
  height = annotation.shape[1]
205
  weight = annotation.shape[2]
206
  # 将annotation 按照面积 排序
207
  areas = np.sum(annotation, axis=(1, 2))
208
+ sorted_indices = np.argsort(areas)
209
  annotation = annotation[sorted_indices]
210
 
211
  index = (annotation != 0).argmax(axis=0)
212
  if random_color == True:
213
+ color = np.random.random((msak_sum, 1, 1, 3))
214
  else:
215
+ color = np.ones((msak_sum, 1, 1, 3)) * np.array(
216
+ [30 / 255, 144 / 255, 255 / 255]
217
+ )
218
+ transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
219
  visual = np.concatenate([color, transparency], axis=-1)
220
  mask_image = np.expand_dims(annotation, -1) * visual
221
 
222
+ show = np.zeros((height, weight, 4))
223
+ h_indices, w_indices = np.meshgrid(
224
+ np.arange(height), np.arange(weight), indexing="ij"
225
+ )
226
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
227
+ # 使用向量化索引更新show的值
228
+ show[h_indices, w_indices, :] = mask_image[indices]
229
  if bbox is not None:
230
  x1, y1, x2, y2 = bbox
231
+ ax.add_patch(
232
+ plt.Rectangle(
233
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
234
+ )
235
+ )
236
+ # draw point
237
+ if points is not None:
238
+ plt.scatter(
239
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
240
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
241
+ s=20,
242
+ c="y",
243
+ )
244
+ plt.scatter(
245
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
246
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
247
+ s=20,
248
+ c="m",
249
+ )
250
 
251
  if retinamask == False:
252
+ show = cv2.resize(
253
+ show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
254
+ )
255
+ ax.imshow(show)
256
 
257
 
258
  def fast_show_mask_gpu(
 
260
  ax,
261
  random_color=False,
262
  bbox=None,
263
+ points=None,
264
+ point_label=None,
265
  retinamask=True,
266
  target_height=960,
267
  target_width=960,
268
  ):
269
+ msak_sum = annotation.shape[0]
 
270
  height = annotation.shape[1]
271
  weight = annotation.shape[2]
272
  areas = torch.sum(annotation, dim=(1, 2))
 
275
  # 找每个位置第一个非零值下标
276
  index = (annotation != 0).to(torch.long).argmax(dim=0)
277
  if random_color == True:
278
+ color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
279
  else:
280
+ color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
281
  [30 / 255, 144 / 255, 255 / 255]
282
+ ).to(annotation.device)
283
+ transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
284
  visual = torch.cat([color, transparency], dim=-1)
285
  mask_image = torch.unsqueeze(annotation, -1) * visual
286
  # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
287
+ show = torch.zeros((height, weight, 4)).to(annotation.device)
288
+ h_indices, w_indices = torch.meshgrid(
289
+ torch.arange(height), torch.arange(weight), indexing="ij"
290
+ )
291
  indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
292
  # 使用向量化索引更新show的值
293
+ show[h_indices, w_indices, :] = mask_image[indices]
294
+ show_cpu = show.cpu().numpy()
295
  if bbox is not None:
296
  x1, y1, x2, y2 = bbox
297
  ax.add_patch(
 
299
  (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
300
  )
301
  )
302
+ # draw point
303
+ if points is not None:
304
+ plt.scatter(
305
+ [point[0] for i, point in enumerate(points) if point_label[i] == 1],
306
+ [point[1] for i, point in enumerate(points) if point_label[i] == 1],
307
+ s=20,
308
+ c="y",
309
+ )
310
+ plt.scatter(
311
+ [point[0] for i, point in enumerate(points) if point_label[i] == 0],
312
+ [point[1] for i, point in enumerate(points) if point_label[i] == 0],
313
+ s=20,
314
+ c="m",
315
+ )
316
  if retinamask == False:
317
+ show_cpu = cv2.resize(
318
+ show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
319
  )
320
+ ax.imshow(show_cpu)
321
+
322
+
323
+ # clip
324
+ @torch.no_grad()
325
+ def retriev(
326
+ model, preprocess, elements, search_text: str, device
327
+ ) -> int:
328
+ preprocessed_images = [preprocess(image).to(device) for image in elements]
329
+ tokenized_text = clip.tokenize([search_text]).to(device)
330
+ stacked_images = torch.stack(preprocessed_images)
331
+ image_features = model.encode_image(stacked_images)
332
+ text_features = model.encode_text(tokenized_text)
333
+ image_features /= image_features.norm(dim=-1, keepdim=True)
334
+ text_features /= text_features.norm(dim=-1, keepdim=True)
335
+ probs = 100.0 * image_features @ text_features.T
336
+ return probs[:, 0].softmax(dim=0)
337
+
338
+
339
+ def crop_image(annotations, image_like):
340
+ if isinstance(image_like, str):
341
+ image = Image.open(image_like)
342
+ else:
343
+ image = image_like
344
  ori_w, ori_h = image.size
345
  mask_h, mask_w = annotations[0]["segmentation"].shape
346
  if ori_w != mask_w or ori_h != mask_h:
 
391
  return masks[max_iou_index].cpu().numpy(), max_iou_index
392
 
393
 
394
+ def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理
395
  h = masks[0]["segmentation"].shape[0]
396
  w = masks[0]["segmentation"].shape[1]
397
  if h != target_height or w != target_width:
 
406
  else:
407
  mask = annotation
408
  for i, point in enumerate(points):
409
+ if mask[point[1], point[0]] == 1 and point_label[i] == 1:
410
  onemask += mask
411
+ if mask[point[1], point[0]] == 1 and point_label[i] == 0:
412
  onemask -= mask
413
  onemask = onemask >= 1
414
  return onemask, 0
415
 
416
 
417
+ def text_prompt(annotations, text, img_path, device):
418
+ cropped_boxes, cropped_images, not_crop, filter_id, annotations_ = crop_image(
419
+ annotations, img_path
420
+ )
421
+ clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
422
+ scores = retriev(
423
+ clip_model, preprocess, cropped_boxes, text, device=device
424
+ )
425
+ max_idx = scores.argsort()
426
+ max_idx = max_idx[-1]
427
+ max_idx += sum(np.array(filter_id) <= int(max_idx))
428
+ return annotations_[max_idx]["segmentation"], max_idx
utils/tools_gradio.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+
7
+
8
+ def fast_process(
9
+ annotations,
10
+ image,
11
+ device,
12
+ scale,
13
+ better_quality=False,
14
+ mask_random_color=True,
15
+ bbox=None,
16
+ use_retina=True,
17
+ withContours=True,
18
+ ):
19
+ if isinstance(annotations[0], dict):
20
+ annotations = [annotation['segmentation'] for annotation in annotations]
21
+
22
+ original_h = image.height
23
+ original_w = image.width
24
+ if better_quality:
25
+ if isinstance(annotations[0], torch.Tensor):
26
+ annotations = np.array(annotations.cpu())
27
+ for i, mask in enumerate(annotations):
28
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
29
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
30
+ if device == 'cpu':
31
+ annotations = np.array(annotations)
32
+ inner_mask = fast_show_mask(
33
+ annotations,
34
+ plt.gca(),
35
+ random_color=mask_random_color,
36
+ bbox=bbox,
37
+ retinamask=use_retina,
38
+ target_height=original_h,
39
+ target_width=original_w,
40
+ )
41
+ else:
42
+ if isinstance(annotations[0], np.ndarray):
43
+ annotations = torch.from_numpy(annotations)
44
+ inner_mask = fast_show_mask_gpu(
45
+ annotations,
46
+ plt.gca(),
47
+ random_color=mask_random_color,
48
+ bbox=bbox,
49
+ retinamask=use_retina,
50
+ target_height=original_h,
51
+ target_width=original_w,
52
+ )
53
+ if isinstance(annotations, torch.Tensor):
54
+ annotations = annotations.cpu().numpy()
55
+
56
+ if withContours:
57
+ contour_all = []
58
+ temp = np.zeros((original_h, original_w, 1))
59
+ for i, mask in enumerate(annotations):
60
+ if type(mask) == dict:
61
+ mask = mask['segmentation']
62
+ annotation = mask.astype(np.uint8)
63
+ if use_retina == False:
64
+ annotation = cv2.resize(
65
+ annotation,
66
+ (original_w, original_h),
67
+ interpolation=cv2.INTER_NEAREST,
68
+ )
69
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
70
+ for contour in contours:
71
+ contour_all.append(contour)
72
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
73
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
74
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
75
+
76
+ image = image.convert('RGBA')
77
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
78
+ image.paste(overlay_inner, (0, 0), overlay_inner)
79
+
80
+ if withContours:
81
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
82
+ image.paste(overlay_contour, (0, 0), overlay_contour)
83
+
84
+ return image
85
+
86
+
87
+ # CPU post process
88
+ def fast_show_mask(
89
+ annotation,
90
+ ax,
91
+ random_color=False,
92
+ bbox=None,
93
+ retinamask=True,
94
+ target_height=960,
95
+ target_width=960,
96
+ ):
97
+ mask_sum = annotation.shape[0]
98
+ height = annotation.shape[1]
99
+ weight = annotation.shape[2]
100
+ # 将annotation 按照面积 排序
101
+ areas = np.sum(annotation, axis=(1, 2))
102
+ sorted_indices = np.argsort(areas)[::1]
103
+ annotation = annotation[sorted_indices]
104
+
105
+ index = (annotation != 0).argmax(axis=0)
106
+ if random_color == True:
107
+ color = np.random.random((mask_sum, 1, 1, 3))
108
+ else:
109
+ color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
110
+ transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
111
+ visual = np.concatenate([color, transparency], axis=-1)
112
+ mask_image = np.expand_dims(annotation, -1) * visual
113
+
114
+ mask = np.zeros((height, weight, 4))
115
+
116
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
117
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
118
+
119
+ mask[h_indices, w_indices, :] = mask_image[indices]
120
+ if bbox is not None:
121
+ x1, y1, x2, y2 = bbox
122
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123
+
124
+ if retinamask == False:
125
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126
+
127
+ return mask
128
+
129
+
130
+ def fast_show_mask_gpu(
131
+ annotation,
132
+ ax,
133
+ random_color=False,
134
+ bbox=None,
135
+ retinamask=True,
136
+ target_height=960,
137
+ target_width=960,
138
+ ):
139
+ device = annotation.device
140
+ mask_sum = annotation.shape[0]
141
+ height = annotation.shape[1]
142
+ weight = annotation.shape[2]
143
+ areas = torch.sum(annotation, dim=(1, 2))
144
+ sorted_indices = torch.argsort(areas, descending=False)
145
+ annotation = annotation[sorted_indices]
146
+ # 找每个位置第一个非���值下标
147
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
148
+ if random_color == True:
149
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150
+ else:
151
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
152
+ [30 / 255, 144 / 255, 255 / 255]
153
+ ).to(device)
154
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
155
+ visual = torch.cat([color, transparency], dim=-1)
156
+ mask_image = torch.unsqueeze(annotation, -1) * visual
157
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
158
+ mask = torch.zeros((height, weight, 4)).to(device)
159
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
160
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
161
+ # 使用向量化索引更新show的值
162
+ mask[h_indices, w_indices, :] = mask_image[indices]
163
+ mask_cpu = mask.cpu().numpy()
164
+ if bbox is not None:
165
+ x1, y1, x2, y2 = bbox
166
+ ax.add_patch(
167
+ plt.Rectangle(
168
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169
+ )
170
+ )
171
+ if retinamask == False:
172
+ mask_cpu = cv2.resize(
173
+ mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174
+ )
175
+ return mask_cpu