AAAAAAyq committed on
Commit 9724c61 · 1 Parent(s): 086b1c4

Update the examples

Files changed (4)
  1. __pycache__/tools.cpython-39.pyc +0 -0
  2. app.py +71 -188
  3. app_debug.py +126 -44
  4. tools.py +395 -0
__pycache__/tools.cpython-39.pyc ADDED
Binary file (11 kB).
 
app.py CHANGED
@@ -1,10 +1,7 @@
1
  from ultralytics import YOLO
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
  import gradio as gr
5
- import cv2
6
  import torch
7
- from PIL import Image
8
 
9
  # Load the pre-trained model
10
  model = YOLO('checkpoints/FastSAM.pt')
@@ -12,6 +9,13 @@ model = YOLO('checkpoints/FastSAM.pt')
12
  # Description
13
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
14

15
  description = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
16
 
17
  🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
@@ -28,161 +32,56 @@ description = """This is a demo on Github project 🏃 [Fast Segment Anything Mo
28
 
29
  """
30
 
31
- examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"],
32
- ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
33
- ["assets/sa_561.jpg"], ["assets/sa_192.jpg"],
34
- ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
35
 
36
  default_example = examples[0]
37
 
38
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
39
 
40
- def fast_process(annotations, image, high_quality, device, scale):
41
- if isinstance(annotations[0],dict):
42
- annotations = [annotation['segmentation'] for annotation in annotations]
43
-
44
- original_h = image.height
45
- original_w = image.width
46
- if high_quality == True:
47
- if isinstance(annotations[0],torch.Tensor):
48
- annotations = np.array(annotations.cpu())
49
- for i, mask in enumerate(annotations):
50
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
51
- annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
52
- if device == 'cpu':
53
- annotations = np.array(annotations)
54
- inner_mask = fast_show_mask(annotations,
55
- plt.gca(),
56
- bbox=None,
57
- points=None,
58
- pointlabel=None,
59
- retinamask=True,
60
- target_height=original_h,
61
- target_width=original_w)
62
- else:
63
- if isinstance(annotations[0],np.ndarray):
64
- annotations = torch.from_numpy(annotations)
65
- inner_mask = fast_show_mask_gpu(annotations,
66
- plt.gca(),
67
- bbox=None,
68
- points=None,
69
- pointlabel=None)
70
- if isinstance(annotations, torch.Tensor):
71
- annotations = annotations.cpu().numpy()
72
-
73
- if high_quality:
74
- contour_all = []
75
- temp = np.zeros((original_h, original_w,1))
76
- for i, mask in enumerate(annotations):
77
- if type(mask) == dict:
78
- mask = mask['segmentation']
79
- annotation = mask.astype(np.uint8)
80
- contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
81
- for contour in contours:
82
- contour_all.append(contour)
83
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
84
- color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
85
- contour_mask = temp / 255 * color.reshape(1, 1, -1)
86
- image = image.convert('RGBA')
87
-
88
- overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
89
- image.paste(overlay_inner, (0, 0), overlay_inner)
90
-
91
- if high_quality:
92
- overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
93
- image.paste(overlay_contour, (0, 0), overlay_contour)
94
-
95
- return image
96
-
97
- # CPU post process
98
- def fast_show_mask(annotation, ax, bbox=None,
99
- points=None, pointlabel=None,
100
- retinamask=True, target_height=960,
101
- target_width=960):
102
- msak_sum = annotation.shape[0]
103
- height = annotation.shape[1]
104
- weight = annotation.shape[2]
105
- # 将annotation 按照面积 排序
106
- areas = np.sum(annotation, axis=(1, 2))
107
- sorted_indices = np.argsort(areas)[::1]
108
- annotation = annotation[sorted_indices]
109
-
110
- index = (annotation != 0).argmax(axis=0)
111
- color = np.random.random((msak_sum,1,1,3))
112
- transparency = np.ones((msak_sum,1,1,1)) * 0.6
113
- visual = np.concatenate([color,transparency],axis=-1)
114
- mask_image = np.expand_dims(annotation,-1) * visual
115
-
116
- mask = np.zeros((height,weight,4))
117
-
118
- h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
119
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
120
- # 使用向量化索引更新show的值
121
- mask[h_indices, w_indices, :] = mask_image[indices]
122
- if bbox is not None:
123
- x1, y1, x2, y2 = bbox
124
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
125
- # draw point
126
- if points is not None:
127
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
128
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
129
-
130
- if retinamask==False:
131
- mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
132
-
133
- return mask
134
-
135
-
136
- def fast_show_mask_gpu(annotation, ax,
137
- bbox=None, points=None,
138
- pointlabel=None):
139
- msak_sum = annotation.shape[0]
140
- height = annotation.shape[1]
141
- weight = annotation.shape[2]
142
- areas = torch.sum(annotation, dim=(1, 2))
143
- sorted_indices = torch.argsort(areas, descending=False)
144
- annotation = annotation[sorted_indices]
145
- # 找每个位置第一个非零值下标
146
- index = (annotation != 0).to(torch.long).argmax(dim=0)
147
- color = torch.rand((msak_sum,1,1,3)).to(annotation.device)
148
- transparency = torch.ones((msak_sum,1,1,1)).to(annotation.device) * 0.6
149
- visual = torch.cat([color,transparency],dim=-1)
150
- mask_image = torch.unsqueeze(annotation,-1) * visual
151
- # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
152
- mask = torch.zeros((height,weight,4)).to(annotation.device)
153
- h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
154
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
155
- # 使用向量化索引更新show的值
156
- mask[h_indices, w_indices, :] = mask_image[indices]
157
- mask_cpu = mask.cpu().numpy()
158
- if bbox is not None:
159
- x1, y1, x2, y2 = bbox
160
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
161
- # draw point
162
- if points is not None:
163
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
164
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
165
- return mask_cpu
166
-
167
-
168
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
169
 
170
- def segment_image(input, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
171
  input_size = int(input_size)  # make sure imgsz is an integer
172
-
173
  # Thanks for the suggestion by hysts in HuggingFace.
174
  w, h = input.size
175
  scale = input_size / max(w, h)
176
  new_w = int(w * scale)
177
  new_h = int(h * scale)
178
  input = input.resize((new_w, new_h))
179
-
180
- results = model(input, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
181
  fig = fast_process(annotations=results[0].masks.data,
182
- image=input, high_quality=high_visual_quality,
183
- device=device, scale=(1024 // input_size))
184
  return fig
185
 
 
186
  # input_size=1024
187
  # high_quality_visual=True
188
  # inp = 'assets/sa_192.jpg'
@@ -193,41 +92,50 @@ def segment_image(input, input_size=1024, high_visual_quality=True, iou_threshol
193
  # pil_image = fast_process(annotations=results[0].masks.data,
194
  # image=input, high_quality=high_quality_visual, device=device)
195
 
 
 
196
  cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
197
 
198
  segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
199
 
200
- input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='Input_size (Our model was trained on a size of 1024)')
201
 
202
  with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
203
  with gr.Row():
204
- # Title
205
- gr.Markdown(title)
206
- # # # Description
207
- # # gr.Markdown(description)
208
-
 
 
 
209
  # Images
210
  with gr.Row(variant="panel"):
211
  with gr.Column(scale=1):
212
  cond_img.render()
213
-
214
  with gr.Column(scale=1):
215
  segm_img.render()
216
-
217
  # Submit & Clear
218
  with gr.Row():
219
  with gr.Column():
220
  input_size_slider.render()
221
-
222
  with gr.Row():
223
- vis_check = gr.Checkbox(value=True, label='high_visual_quality')
224
-
225
  with gr.Column():
226
  segment_btn = gr.Button("Segment Anything", variant='primary')
227
-
228
  # with gr.Column():
229
- # clear_btn = gr.Button("Clear", variant="primary")
230
-
231
  gr.Markdown("Try some of the examples below ⬇️")
232
  gr.Examples(examples=examples,
233
  inputs=[cond_img],
@@ -235,49 +143,24 @@ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
235
  fn=segment_image,
236
  cache_examples=True,
237
  examples_per_page=4)
238
- # gr.Markdown("Try some of the examples below ⬇️")
239
- # gr.Examples(examples=examples,
240
- # inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
241
- # outputs=output,
242
- # fn=segment_image,
243
- # examples_per_page=4)
244
 
245
  with gr.Column():
246
  with gr.Accordion("Advanced options", open=False):
247
  iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
248
  conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
 
249
 
250
  # Description
251
  gr.Markdown(description)
252
-
253
  segment_btn.click(segment_image,
254
- inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
255
- outputs=segm_img)
256
-
257
  # def clear():
258
- # return None, None
259
 
260
  # clear_btn.click(fn=clear, inputs=None, outputs=None)
261
 
262
  demo.queue()
263
  demo.launch()
264
-
265
- # app_interface = gr.Interface(fn=predict,
266
- # inputs=[gr.Image(type='pil'),
267
- # gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
268
- # gr.components.Checkbox(value=True, label='high_visual_quality')],
269
- # # outputs=['plot'],
270
- # outputs=gr.Image(type='pil'),
271
- # # examples=[["assets/sa_8776.jpg"]],
272
- # # # ["assets/sa_1309.jpg", 1024]],
273
- # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
274
- # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
275
- # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
276
- # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
277
- # cache_examples=True,
278
- # title="Fast Segment Anything (Everything mode)"
279
- # )
280
-
281
-
282
- # app_interface.queue(concurrency_count=1, max_size=20)
283
- # app_interface.launch()
 
1
  from ultralytics import YOLO
 
 
2
  import gradio as gr
 
3
  import torch
4
+ from tools import fast_process
5
 
6
  # Load the pre-trained model
7
  model = YOLO('checkpoints/FastSAM.pt')
 
9
  # Description
10
  title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
11
 
12
+ news = """ # News
13
+
14
+ 🔥 Added 'Advanced options' in Everything mode for more detailed adjustment.
15
+
16
+ # 🔥 Support the points mode and box mode; text mode will come soon.
17
+ """
18
+
19
  description = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
20
 
21
  🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
 
32
 
33
  """
34
 
35
+ examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"], ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
36
+ ["assets/sa_561.jpg"], ["assets/sa_192.jpg"], ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
 
 
37
 
38
  default_example = examples[0]
39
 
40
  css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
41

42
 
43
+ def segment_image(
44
+ input,
45
+ input_size=1024,
46
+ iou_threshold=0.7,
47
+ conf_threshold=0.25,
48
+ better_quality=False,
49
+ withContours=True,  # kept before mask_random_color so segment_btn.click's positional inputs map contour_check here
50
+ mask_random_color=True,
51
+ points=None,
52
+ bbox=None,
53
+ point_label=None,
54
+ use_retina=True,
55
+ ):
56
  input_size = int(input_size)  # make sure imgsz is an integer
57
+
58
  # Thanks for the suggestion by hysts in HuggingFace.
59
  w, h = input.size
60
  scale = input_size / max(w, h)
61
  new_w = int(w * scale)
62
  new_h = int(h * scale)
63
  input = input.resize((new_w, new_h))
64
+
65
+ results = model(input,
66
+ device=device,
67
+ retina_masks=True,
68
+ iou=iou_threshold,
69
+ conf=conf_threshold,
70
+ imgsz=input_size,)
71
  fig = fast_process(annotations=results[0].masks.data,
72
+ image=input,
73
+ device=device,
74
+ scale=(1024 // input_size),
75
+ better_quality=better_quality,
76
+ mask_random_color=mask_random_color,
77
+ points=points,
78
+ bbox=bbox,
79
+ point_label=point_label,
80
+ use_retina=use_retina,
81
+ withContours=withContours,)
82
  return fig
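# Illustrative usage sketch, assuming the example assets above are present locally: the
# description's tip that a smaller input size runs faster can be tried by calling
# segment_image directly (hypothetical snippet, values are only examples):
#
#   from PIL import Image
#   img = Image.open('assets/sa_8776.jpg')
#   out = segment_image(img, input_size=512, better_quality=False, withContours=True)
#   out.save('segmented_512.png')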
83
 
84
+
85
  # input_size=1024
86
  # high_quality_visual=True
87
  # inp = 'assets/sa_192.jpg'
 
92
  # pil_image = fast_process(annotations=results[0].masks.data,
93
  # image=input, high_quality=high_quality_visual, device=device)
94
 
95
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
96
+
97
  cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
98
 
99
  segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
100
 
101
+ input_size_slider = gr.components.Slider(minimum=512,
102
+ maximum=1024,
103
+ value=1024,
104
+ step=64,
105
+ label='Input_size (Our model was trained on a size of 1024)')
106
 
107
  with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
108
  with gr.Row():
109
+ with gr.Column(scale=1):
110
+ # Title
111
+ gr.Markdown(title)
112
+
113
+ with gr.Column(scale=1):
114
+ # News
115
+ gr.Markdown(news)
116
+
117
  # Images
118
  with gr.Row(variant="panel"):
119
  with gr.Column(scale=1):
120
  cond_img.render()
121
+
122
  with gr.Column(scale=1):
123
  segm_img.render()
124
+
125
  # Submit & Clear
126
  with gr.Row():
127
  with gr.Column():
128
  input_size_slider.render()
129
+
130
  with gr.Row():
131
+ contour_check = gr.Checkbox(value=True, label='withContours')
132
+
133
  with gr.Column():
134
  segment_btn = gr.Button("Segment Anything", variant='primary')
135
+
136
  # with gr.Column():
137
+ # clear_btn = gr.Button("Clear", variant="primary")
138
+
139
  gr.Markdown("Try some of the examples below ⬇️")
140
  gr.Examples(examples=examples,
141
  inputs=[cond_img],
 
143
  fn=segment_image,
144
  cache_examples=True,
145
  examples_per_page=4)
146
 
147
  with gr.Column():
148
  with gr.Accordion("Advanced options", open=False):
149
  iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
150
  conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
151
+ mor_check = gr.Checkbox(value=False, label='better_visual_quality')
152
 
153
  # Description
154
  gr.Markdown(description)
155
+
156
  segment_btn.click(segment_image,
157
+ inputs=[cond_img, input_size_slider, iou_threshold, conf_threshold, mor_check, contour_check],
158
+ outputs=segm_img)
159
+
160
  # def clear():
161
+ # return None, None
162
 
163
  # clear_btn.click(fn=clear, inputs=None, outputs=None)
164
 
165
  demo.queue()
166
  demo.launch()

app_debug.py CHANGED
@@ -4,22 +4,45 @@ import matplotlib.pyplot as plt
4
  import gradio as gr
5
  import cv2
6
  import torch
7
- # import queue
8
- # import threading
9
  from PIL import Image
10
 
 
 
11
 
12
- model = YOLO('checkpoints/FastSAM.pt') # load a custom model
 
13

14
 
15
- def fast_process(annotations, image, high_quality, device):
16
  if isinstance(annotations[0],dict):
17
  annotations = [annotation['segmentation'] for annotation in annotations]
18
 
19
  original_h = image.height
20
  original_w = image.width
21
- # fig = plt.figure(figsize=(10, 10))
22
- # plt.imshow(image)
23
  if high_quality == True:
24
  if isinstance(annotations[0],torch.Tensor):
25
  annotations = np.array(annotations.cpu())
@@ -57,10 +80,9 @@ def fast_process(annotations, image, high_quality, device):
57
  contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
58
  for contour in contours:
59
  contour_all.append(contour)
60
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
61
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
62
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
63
- # plt.imshow(contour_mask)
64
  image = image.convert('RGBA')
65
 
66
  overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
@@ -71,10 +93,6 @@ def fast_process(annotations, image, high_quality, device):
71
  image.paste(overlay_contour, (0, 0), overlay_contour)
72
 
73
  return image
74
- # plt.axis('off')
75
- # plt.tight_layout()
76
- # return fig
77
-
78
 
79
  # CPU post process
80
  def fast_show_mask(annotation, ax, bbox=None,
@@ -111,7 +129,6 @@ def fast_show_mask(annotation, ax, bbox=None,
111
 
112
  if retinamask==False:
113
  mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
114
- # ax.imshow(mask)
115
 
116
  return mask
117
 
@@ -145,19 +162,13 @@ def fast_show_mask_gpu(annotation, ax,
145
  if points is not None:
146
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
147
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
148
- # ax.imshow(mask_cpu)
149
  return mask_cpu
150
 
151
 
152
- # # 预测队列
153
- # prediction_queue = queue.Queue(maxsize=5)
154
-
155
- # # 线程锁
156
- # lock = threading.Lock()
157
-
158
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
159
 
160
- def predict(input, input_size=1024, high_visual_quality=True):
 
161
  input_size = int(input_size)  # make sure imgsz is an integer
162
 
163
  # Thanks for the suggestion by hysts in HuggingFace.
@@ -167,13 +178,13 @@ def predict(input, input_size=1024, high_visual_quality=True):
167
  new_h = int(h * scale)
168
  input = input.resize((new_w, new_h))
169
 
170
- results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
171
  fig = fast_process(annotations=results[0].masks.data,
172
- image=input, high_quality=high_visual_quality, device=device)
 
 
173
  return fig
174
 
175
-
176
-
177
  # input_size=1024
178
  # high_quality_visual=True
179
  # inp = 'assets/sa_192.jpg'
@@ -184,22 +195,93 @@ def predict(input, input_size=1024, high_visual_quality=True):
184
  # pil_image = fast_process(annotations=results[0].masks.data,
185
  # image=input, high_quality=high_quality_visual, device=device)
186
 
187
- app_interface = gr.Interface(fn=predict,
188
- inputs=[gr.Image(type='pil'),
189
- gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
190
- gr.components.Checkbox(value=True, label='high_visual_quality')],
191
- # outputs=['plot'],
192
- outputs=gr.Image(type='pil'),
193
- examples=[["assets/sa_8776.jpg"]],
194
- # # ["assets/sa_1309.jpg", 1024]],
195
- # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
196
- # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
197
- # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
198
- # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
199
- cache_examples=True,
200
- title="Fast Segment Anything (Everything mode)"
201
- )
202
-
203
-
204
- app_interface.queue(concurrency_count=1, max_size=20)
205
- app_interface.launch()

4
  import gradio as gr
5
  import cv2
6
  import torch
 
 
7
  from PIL import Image
8
 
9
+ # Load the pre-trained model
10
+ model = YOLO('checkpoints/FastSAM.pt')
11
 
12
+ # Description
13
+ title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
14
 
15
+ description = """This is a demo on Github project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
16
+
17
+ 🎯 Upload an Image, segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
18
+
19
+ ⌛️ It takes about 4 seconds to generate segmentation results. The queue's concurrency_count is 1, so please wait a moment when it is crowded.
20
+
21
+ 🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
22
+
23
+ 📣 You can also obtain the segmentation results of any Image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
24
+
25
+ 😚 A huge thanks goes out to the @HuggingFace Team for supporting us with a GPU grant.
26
+
27
+ 🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
28
+
29
+ """
30
 
31
+ examples = [["assets/sa_8776.jpg"], ["assets/sa_414.jpg"],
32
+ ["assets/sa_1309.jpg"], ["assets/sa_11025.jpg"],
33
+ ["assets/sa_561.jpg"], ["assets/sa_192.jpg"],
34
+ ["assets/sa_10039.jpg"], ["assets/sa_862.jpg"]]
35
+
36
+ default_example = examples[0]
37
+
38
+ css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
39
+
40
+ def fast_process(annotations, image, high_quality, device, scale):
41
  if isinstance(annotations[0],dict):
42
  annotations = [annotation['segmentation'] for annotation in annotations]
43
 
44
  original_h = image.height
45
  original_w = image.width
 
 
46
  if high_quality == True:
47
  if isinstance(annotations[0],torch.Tensor):
48
  annotations = np.array(annotations.cpu())
 
80
  contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
81
  for contour in contours:
82
  contour_all.append(contour)
83
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
84
  color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
85
  contour_mask = temp / 255 * color.reshape(1, 1, -1)
 
86
  image = image.convert('RGBA')
87
 
88
  overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
 
93
  image.paste(overlay_contour, (0, 0), overlay_contour)
94
 
95
  return image
96
 
97
  # CPU post process
98
  def fast_show_mask(annotation, ax, bbox=None,
 
129
 
130
  if retinamask==False:
131
  mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
 
132
 
133
  return mask
134
 
 
162
  if points is not None:
163
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
164
  plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
 
165
  return mask_cpu
166
 
167

168
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
169
 
170
+ def segment_image(input, evt: gr.SelectData=None, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
171
+ point = (evt.index[0], evt.index[1]) if evt is not None else None  # selected pixel, when called from cond_img.select
172
  input_size = int(input_size)  # make sure imgsz is an integer
173
 
174
  # Thanks for the suggestion by hysts in HuggingFace.
 
178
  new_h = int(h * scale)
179
  input = input.resize((new_w, new_h))
180
 
181
+ results = model(input, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
182
  fig = fast_process(annotations=results[0].masks.data,
183
+ image=input, high_quality=high_visual_quality,
184
+ device=device, scale=(1024 // input_size))
185
+ # TODO: pass `point` once fast_process accepts point prompts
186
  return fig
187
 
 
 
188
  # input_size=1024
189
  # high_quality_visual=True
190
  # inp = 'assets/sa_192.jpg'
 
195
  # pil_image = fast_process(annotations=results[0].masks.data,
196
  # image=input, high_quality=high_quality_visual, device=device)
197
 
198
+ cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
199
+
200
+ segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
201
+
202
+ input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='Input_size (Our model was trained on a size of 1024)')
203
+
204
+ with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
205
+ with gr.Row():
206
+ # Title
207
+ gr.Markdown(title)
208
+ # # # Description
209
+ # # gr.Markdown(description)
210
+
211
+ # Images
212
+ with gr.Row(variant="panel"):
213
+ with gr.Column(scale=1):
214
+ cond_img.render()
215
+
216
+ with gr.Column(scale=1):
217
+ segm_img.render()
218
+
219
+ # Submit & Clear
220
+ with gr.Row():
221
+ with gr.Column():
222
+ input_size_slider.render()
223
+
224
+ with gr.Row():
225
+ vis_check = gr.Checkbox(value=True, label='high_visual_quality')
226
+
227
+ with gr.Column():
228
+ segment_btn = gr.Button("Segment Anything", variant='primary')
229
+
230
+ # with gr.Column():
231
+ # clear_btn = gr.Button("Clear", variant="primary")
232
+
233
+ gr.Markdown("Try some of the examples below ⬇️")
234
+ gr.Examples(examples=examples,
235
+ inputs=[cond_img],
236
+ outputs=segm_img,
237
+ fn=segment_image,
238
+ cache_examples=True,
239
+ examples_per_page=4)
240
+ # gr.Markdown("Try some of the examples below ⬇️")
241
+ # gr.Examples(examples=examples,
242
+ # inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
243
+ # outputs=output,
244
+ # fn=segment_image,
245
+ # examples_per_page=4)
246
+
247
+ with gr.Column():
248
+ with gr.Accordion("Advanced options", open=False):
249
+ iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
250
+ conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
251
+
252
+ # Description
253
+ gr.Markdown(description)
254
+
255
+ cond_img.select(segment_image, inputs=[cond_img], outputs=segm_img)
256
+
257
+ segment_btn.click(segment_image,
258
+ inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
259
+ outputs=segm_img)
260
+
261
+ # def clear():
262
+ # return None, None
263
+
264
+ # clear_btn.click(fn=clear, inputs=None, outputs=None)
265
+
266
+ demo.queue()
267
+ demo.launch()
268
+
269
+ # app_interface = gr.Interface(fn=predict,
270
+ # inputs=[gr.Image(type='pil'),
271
+ # gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
272
+ # gr.components.Checkbox(value=True, label='high_visual_quality')],
273
+ # # outputs=['plot'],
274
+ # outputs=gr.Image(type='pil'),
275
+ # # examples=[["assets/sa_8776.jpg"]],
276
+ # # # ["assets/sa_1309.jpg", 1024]],
277
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
278
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
279
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
280
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
281
+ # cache_examples=True,
282
+ # title="Fast Segment Anything (Everything mode)"
283
+ # )
284
+
285
+
286
+ # app_interface.queue(concurrency_count=1, max_size=20)
287
+ # app_interface.launch()
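# Illustrative sketch, assuming fast_process is later extended to accept point prompts:
# the pixel captured by cond_img.select could be turned into a single mask with the new
# tools.point_prompt helper, reusing the names already defined inside segment_image above:
#
#   from tools import format_results, point_prompt
#   annotations = format_results(results[0], filter=0)
#   mask, _ = point_prompt(annotations, [list(point)], [1], new_h, new_w)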
tools.py ADDED
@@ -0,0 +1,395 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ import cv2
5
+ import torch
6
+ import clip
7
+
8
+
9
+ def convert_box_xywh_to_xyxy(box):
10
+ x1 = box[0]
11
+ y1 = box[1]
12
+ x2 = box[0] + box[2]
13
+ y2 = box[1] + box[3]
14
+ return [x1, y1, x2, y2]
15
+
16
+
17
+ def segment_image(image, bbox):
18
+ image_array = np.array(image)
19
+ segmented_image_array = np.zeros_like(image_array)
20
+ x1, y1, x2, y2 = bbox
21
+ segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
22
+ segmented_image = Image.fromarray(segmented_image_array)
23
+ black_image = Image.new("RGB", image.size, (255, 255, 255))
24
+ # transparency_mask = np.zeros_like((), dtype=np.uint8)
25
+ transparency_mask = np.zeros(
26
+ (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
27
+ )
28
+ transparency_mask[y1:y2, x1:x2] = 255
29
+ transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
30
+ black_image.paste(segmented_image, mask=transparency_mask_image)
31
+ return black_image
32
+
33
+
34
+ def format_results(result, filter=0):
35
+ annotations = []
36
+ n = len(result.masks.data)
37
+ for i in range(n):
38
+ annotation = {}
39
+ mask = result.masks.data[i] == 1.0
40
+
41
+ if torch.sum(mask) < filter:
42
+ continue
43
+ annotation["id"] = i
44
+ annotation["segmentation"] = mask.cpu().numpy()
45
+ annotation["bbox"] = result.boxes.data[i]
46
+ annotation["score"] = result.boxes.conf[i]
47
+ annotation["area"] = annotation["segmentation"].sum()
48
+ annotations.append(annotation)
49
+ return annotations
50
+
51
+
52
+ def filter_masks(annotations):  # filter out overlapping masks
53
+ annotations.sort(key=lambda x: x["area"], reverse=True)
54
+ to_remove = set()
55
+ for i in range(0, len(annotations)):
56
+ a = annotations[i]
57
+ for j in range(i + 1, len(annotations)):
58
+ b = annotations[j]
59
+ if i != j and j not in to_remove:
60
+ # check if
61
+ if b["area"] < a["area"]:
62
+ if (a["segmentation"] & b["segmentation"]).sum() / b[
63
+ "segmentation"
64
+ ].sum() > 0.8:
65
+ to_remove.add(j)
66
+
67
+ return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
68
+
69
+
70
+ def get_bbox_from_mask(mask):
71
+ mask = mask.astype(np.uint8)
72
+ contours, hierarchy = cv2.findContours(
73
+ mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
74
+ )
75
+ x1, y1, w, h = cv2.boundingRect(contours[0])
76
+ x2, y2 = x1 + w, y1 + h
77
+ if len(contours) > 1:
78
+ for b in contours:
79
+ x_t, y_t, w_t, h_t = cv2.boundingRect(b)
80
+ # merge multiple bboxes into one
81
+ x1 = min(x1, x_t)
82
+ y1 = min(y1, y_t)
83
+ x2 = max(x2, x_t + w_t)
84
+ y2 = max(y2, y_t + h_t)
85
+ h = y2 - y1
86
+ w = x2 - x1
87
+ return [x1, y1, x2, y2]
88
+
89
+ def fast_process(
90
+ annotations,
91
+ image,
92
+ device,
93
+ scale,
94
+ better_quality=False,
95
+ mask_random_color=True,
96
+ points=None,
97
+ bbox=None,
98
+ point_label=None,
99
+ use_retina=True,
100
+ withContours=True,
101
+ ):
102
+ if isinstance(annotations[0], dict):
103
+ annotations = [annotation['segmentation'] for annotation in annotations]
104
+
105
+ original_h = image.height
106
+ original_w = image.width
107
+ if better_quality:
108
+ if isinstance(annotations[0], torch.Tensor):
109
+ annotations = np.array(annotations.cpu())
110
+ for i, mask in enumerate(annotations):
111
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
112
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
113
+ if device == 'cpu':
114
+ annotations = np.array(annotations)
115
+ inner_mask = fast_show_mask(
116
+ annotations,
117
+ plt.gca(),
118
+ random_color=mask_random_color,
119
+ bbox=bbox,
120
+ points=points,
121
+ pointlabel=point_label,
122
+ retinamask=use_retina,
123
+ target_height=original_h,
124
+ target_width=original_w,
125
+ )
126
+ else:
127
+ if isinstance(annotations[0], np.ndarray):
128
+ annotations = torch.from_numpy(annotations)
129
+ inner_mask = fast_show_mask_gpu(
130
+ annotations,
131
+ plt.gca(),
132
+ random_color=mask_random_color,
133
+ bbox=bbox,
134
+ points=points,
135
+ pointlabel=point_label,
136
+ retinamask=use_retina,
137
+ target_height=original_h,
138
+ target_width=original_w,
139
+ )
140
+ if isinstance(annotations, torch.Tensor):
141
+ annotations = annotations.cpu().numpy()
142
+
143
+ if withContours:
144
+ contour_all = []
145
+ temp = np.zeros((original_h, original_w, 1))
146
+ for i, mask in enumerate(annotations):
147
+ if type(mask) == dict:
148
+ mask = mask['segmentation']
149
+ annotation = mask.astype(np.uint8)
150
+ if use_retina == False:
151
+ annotation = cv2.resize(
152
+ annotation,
153
+ (original_w, original_h),
154
+ interpolation=cv2.INTER_NEAREST,
155
+ )
156
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
157
+ for contour in contours:
158
+ contour_all.append(contour)
159
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
160
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
161
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
162
163
+ image = image.convert('RGBA')
164
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
165
+ image.paste(overlay_inner, (0, 0), overlay_inner)
166
+
167
+ if withContours:
168
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
169
+ image.paste(overlay_contour, (0, 0), overlay_contour)
170
+
171
+ return image
172
+
173
+
174
+ # CPU post process
175
+ def fast_show_mask(
176
+ annotation,
177
+ ax,
178
+ random_color=False,
179
+ bbox=None,
180
+ points=None,
181
+ pointlabel=None,
182
+ retinamask=True,
183
+ target_height=960,
184
+ target_width=960,
185
+ ):
186
+ mask_sum = annotation.shape[0]
187
+ height = annotation.shape[1]
188
+ weight = annotation.shape[2]
189
+ # sort annotations by area
190
+ areas = np.sum(annotation, axis=(1, 2))
191
+ sorted_indices = np.argsort(areas)[::1]
192
+ annotation = annotation[sorted_indices]
193
+
194
+ index = (annotation != 0).argmax(axis=0)
195
+ if random_color == True:
196
+ color = np.random.random((mask_sum, 1, 1, 3))
197
+ else:
198
+ color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
199
+ transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
200
+ visual = np.concatenate([color, transparency], axis=-1)
201
+ mask_image = np.expand_dims(annotation, -1) * visual
202
+
203
+ mask = np.zeros((height, weight, 4))
204
+
205
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
206
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
207
+
208
+ mask[h_indices, w_indices, :] = mask_image[indices]
209
+ if bbox is not None:
210
+ x1, y1, x2, y2 = bbox
211
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
212
+ # draw point
213
+ if points is not None:
214
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
215
+ [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
216
+ s=20,
217
+ c='y')
218
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
219
+ [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
220
+ s=20,
221
+ c='m')
222
+
223
+ if retinamask == False:
224
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
225
+
226
+ return mask
227
+
228
+
229
+ def fast_show_mask_gpu(
230
+ annotation,
231
+ ax,
232
+ random_color=False,
233
+ bbox=None,
234
+ points=None,
235
+ pointlabel=None,
236
+ retinamask=True,
237
+ target_height=960,
238
+ target_width=960,
239
+ ):
240
+ device = annotation.device
241
+ mask_sum = annotation.shape[0]
242
+ height = annotation.shape[1]
243
+ weight = annotation.shape[2]
244
+ areas = torch.sum(annotation, dim=(1, 2))
245
+ sorted_indices = torch.argsort(areas, descending=False)
246
+ annotation = annotation[sorted_indices]
247
+ # find the index of the first non-zero mask at each pixel
248
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
249
+ if random_color == True:
250
+ color = torch.rand((mask_sum, 1, 1, 3)).to(device)
251
+ else:
252
+ color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
253
+ [30 / 255, 144 / 255, 255 / 255]
254
+ ).to(device)
255
+ transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
256
+ visual = torch.cat([color, transparency], dim=-1)
257
+ mask_image = torch.unsqueeze(annotation, -1) * visual
258
+ # gather by index: for each pixel, index selects which mask to take, collapsing mask_image into a single layer
259
+ mask = torch.zeros((height, weight, 4)).to(device)
260
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
261
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
262
+ # update the displayed values with vectorized indexing
263
+ mask[h_indices, w_indices, :] = mask_image[indices]
264
+ mask_cpu = mask.cpu().numpy()
265
+ if bbox is not None:
266
+ x1, y1, x2, y2 = bbox
267
+ ax.add_patch(
268
+ plt.Rectangle(
269
+ (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
270
+ )
271
+ )
272
+ # draw point
273
+ if points is not None:
274
+ plt.scatter(
275
+ [point[0] for i, point in enumerate(points) if pointlabel[i] == 1],
276
+ [point[1] for i, point in enumerate(points) if pointlabel[i] == 1],
277
+ s=20,
278
+ c="y",
279
+ )
280
+ plt.scatter(
281
+ [point[0] for i, point in enumerate(points) if pointlabel[i] == 0],
282
+ [point[1] for i, point in enumerate(points) if pointlabel[i] == 0],
283
+ s=20,
284
+ c="m",
285
+ )
286
+ if retinamask == False:
287
+ mask_cpu = cv2.resize(
288
+ mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
289
+ )
290
+ return mask_cpu
291
+
292
+
293
+ # clip
294
+ @torch.no_grad()
295
+ def retriev(
296
+ model, preprocess, elements: list[Image.Image], search_text: str, device
297
+ ) -> torch.Tensor:
298
+ preprocessed_images = [preprocess(image).to(device) for image in elements]
299
+ tokenized_text = clip.tokenize([search_text]).to(device)
300
+ stacked_images = torch.stack(preprocessed_images)
301
+ image_features = model.encode_image(stacked_images)
302
+ text_features = model.encode_text(tokenized_text)
303
+ image_features /= image_features.norm(dim=-1, keepdim=True)
304
+ text_features /= text_features.norm(dim=-1, keepdim=True)
305
+ probs = 100.0 * image_features @ text_features.T
306
+ return probs[:, 0].softmax(dim=0)
307
+
308
+
309
+ def crop_image(annotations, image_path):
310
+ image = Image.open(image_path)
311
+ ori_w, ori_h = image.size
312
+ mask_h, mask_w = annotations[0]["segmentation"].shape
313
+ if ori_w != mask_w or ori_h != mask_h:
314
+ image = image.resize((mask_w, mask_h))
315
+ cropped_boxes = []
316
+ cropped_images = []
317
+ not_crop = []
318
+ filter_id = []
319
+ # annotations, _ = filter_masks(annotations)
320
+ # filter_id = list(_)
321
+ for _, mask in enumerate(annotations):
322
+ if np.sum(mask["segmentation"]) <= 100:
323
+ filter_id.append(_)
324
+ continue
325
+ bbox = get_bbox_from_mask(mask["segmentation"])  # bbox of the mask
326
+ cropped_boxes.append(segment_image(image, bbox))  # save the cropped image
327
+ # cropped_boxes.append(segment_image(image,mask["segmentation"]))
328
+ cropped_images.append(bbox)  # save the bbox of the cropped image
329
+
330
+ return cropped_boxes, cropped_images, not_crop, filter_id, annotations
331
+
332
+
333
+ def box_prompt(masks, bbox, target_height, target_width):
334
+ h = masks.shape[1]
335
+ w = masks.shape[2]
336
+ if h != target_height or w != target_width:
337
+ bbox = [
338
+ int(bbox[0] * w / target_width),
339
+ int(bbox[1] * h / target_height),
340
+ int(bbox[2] * w / target_width),
341
+ int(bbox[3] * h / target_height),
342
+ ]
343
+ bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
344
+ bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
345
+ bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
346
+ bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
347
+
348
+ # IoUs = torch.zeros(len(masks), dtype=torch.float32)
349
+ bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
350
+
351
+ masks_area = torch.sum(masks[:, bbox[1] : bbox[3], bbox[0] : bbox[2]], dim=(1, 2))
352
+ orig_masks_area = torch.sum(masks, dim=(1, 2))
353
+
354
+ union = bbox_area + orig_masks_area - masks_area
355
+ IoUs = masks_area / union
356
+ max_iou_index = torch.argmax(IoUs)
357
+
358
+ return masks[max_iou_index].cpu().numpy(), max_iou_index
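# Illustrative usage sketch, assuming `results` comes from a FastSAM/YOLO call as in app.py
# (`orig_h`/`orig_w` stand for the original image height and width; the box values are examples):
#
#   masks = results[0].masks.data                     # (N, H, W) mask tensor
#   best_mask, idx = box_prompt(masks, [100, 100, 400, 300], orig_h, orig_w)
#
# box_prompt rescales the box to the mask resolution and returns the single mask with the
# highest IoU against the box, which is how a box prompt would select one segment.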
359
+
360
+
361
+ def point_prompt(masks, points, pointlabel, target_height, target_width):  # numpy-based processing
362
+ h = masks[0]["segmentation"].shape[0]
363
+ w = masks[0]["segmentation"].shape[1]
364
+ if h != target_height or w != target_width:
365
+ points = [
366
+ [int(point[0] * w / target_width), int(point[1] * h / target_height)]
367
+ for point in points
368
+ ]
369
+ onemask = np.zeros((h, w))
370
+ for i, annotation in enumerate(masks):
371
+ if type(annotation) == dict:
372
+ mask = annotation["segmentation"]
373
+ else:
374
+ mask = annotation
375
+ for i, point in enumerate(points):
376
+ if mask[point[1], point[0]] == 1 and pointlabel[i] == 1:
377
+ onemask += mask
378
+ if mask[point[1], point[0]] == 1 and pointlabel[i] == 0:
379
+ onemask -= mask
380
+ onemask = onemask >= 1
381
+ return onemask, 0
382
+
383
+
384
+ def text_prompt(annotations, args):
385
+ cropped_boxes, cropped_images, not_crop, filter_id, annotations = crop_image(
386
+ annotations, args.img_path
387
+ )
388
+ clip_model, preprocess = clip.load("ViT-B/32", device=args.device)
389
+ scores = retriev(
390
+ clip_model, preprocess, cropped_boxes, args.text_prompt, device=args.device
391
+ )
392
+ max_idx = scores.argsort()
393
+ max_idx = max_idx[-1]
394
+ max_idx += sum(np.array(filter_id) <= int(max_idx))
395
+ return annotations[max_idx]["segmentation"], max_idx