AAAAAAyq committed on
Commit
bd6726a
1 Parent(s): 7c3570f

try to fix CUDA bug

Browse files
Files changed (3) hide show
  1. app.py +176 -238
  2. app_copy.py +196 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -1,239 +1,177 @@
1
- from ultralytics import YOLO
2
- import numpy as np
3
- import matplotlib.pyplot as plt
4
- import gradio as gr
5
- import cv2
6
- import torch
7
- # import queue
8
- # import threading
9
- # from PIL import Image
10
-
11
-
12
- model = YOLO('checkpoints/FastSAM.pt') # load a custom model
13
-
14
-
15
- def fast_process(annotations, image, high_quality, device):
16
- if isinstance(annotations[0],dict):
17
- annotations = [annotation['segmentation'] for annotation in annotations]
18
-
19
- original_h = image.height
20
- original_w = image.width
21
- fig = plt.figure(figsize=(10, 10))
22
- plt.imshow(image)
23
- if high_quality == True:
24
- if isinstance(annotations[0],torch.Tensor):
25
- annotations = np.array(annotations.cpu())
26
- for i, mask in enumerate(annotations):
27
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
28
- annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
29
- if device == 'cpu':
30
- annotations = np.array(annotations)
31
- fast_show_mask(annotations,
32
- plt.gca(),
33
- bbox=None,
34
- points=None,
35
- pointlabel=None,
36
- retinamask=True,
37
- target_height=original_h,
38
- target_width=original_w)
39
- else:
40
- if isinstance(annotations[0],np.ndarray):
41
- annotations = torch.from_numpy(annotations)
42
- fast_show_mask_gpu(annotations,
43
- plt.gca(),
44
- bbox=None,
45
- points=None,
46
- pointlabel=None)
47
- if isinstance(annotations, torch.Tensor):
48
- annotations = annotations.cpu().numpy()
49
- if high_quality == True:
50
- contour_all = []
51
- temp = np.zeros((original_h, original_w,1))
52
- for i, mask in enumerate(annotations):
53
- if type(mask) == dict:
54
- mask = mask['segmentation']
55
- annotation = mask.astype(np.uint8)
56
- contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
57
- for contour in contours:
58
- contour_all.append(contour)
59
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
60
- color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
61
- contour_mask = temp / 225 * color.reshape(1, 1, -1)
62
- plt.imshow(contour_mask)
63
-
64
- plt.axis('off')
65
- plt.tight_layout()
66
- return fig
67
-
68
-
69
- # CPU post process
70
- def fast_show_mask(annotation, ax, bbox=None,
71
- points=None, pointlabel=None,
72
- retinamask=True, target_height=960,
73
- target_width=960):
74
- msak_sum = annotation.shape[0]
75
- height = annotation.shape[1]
76
- weight = annotation.shape[2]
77
- # 将annotation 按照面积 排序
78
- areas = np.sum(annotation, axis=(1, 2))
79
- sorted_indices = np.argsort(areas)[::1]
80
- annotation = annotation[sorted_indices]
81
-
82
- index = (annotation != 0).argmax(axis=0)
83
- color = np.random.random((msak_sum,1,1,3))
84
- transparency = np.ones((msak_sum,1,1,1)) * 0.6
85
- visual = np.concatenate([color,transparency],axis=-1)
86
- mask_image = np.expand_dims(annotation,-1) * visual
87
-
88
- show = np.zeros((height,weight,4))
89
-
90
- h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
91
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
92
- # 使用向量化索引更新show的值
93
- show[h_indices, w_indices, :] = mask_image[indices]
94
- if bbox is not None:
95
- x1, y1, x2, y2 = bbox
96
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
97
- # draw point
98
- if points is not None:
99
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
100
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
101
-
102
- if retinamask==False:
103
- show = cv2.resize(show,(target_width,target_height),interpolation=cv2.INTER_NEAREST)
104
- ax.imshow(show)
105
-
106
-
107
- def fast_show_mask_gpu(annotation, ax,
108
- bbox=None, points=None,
109
- pointlabel=None):
110
- msak_sum = annotation.shape[0]
111
- height = annotation.shape[1]
112
- weight = annotation.shape[2]
113
- areas = torch.sum(annotation, dim=(1, 2))
114
- sorted_indices = torch.argsort(areas, descending=False)
115
- annotation = annotation[sorted_indices]
116
- # 找每个位置第一个非零值下标
117
- index = (annotation != 0).to(torch.long).argmax(dim=0)
118
- color = torch.rand((msak_sum,1,1,3)).to(annotation.device)
119
- transparency = torch.ones((msak_sum,1,1,1)).to(annotation.device) * 0.6
120
- visual = torch.cat([color,transparency],dim=-1)
121
- mask_image = torch.unsqueeze(annotation,-1) * visual
122
- # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
123
- show = torch.zeros((height,weight,4)).to(annotation.device)
124
- h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
125
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
126
- # 使用向量化索引更新show的值
127
- show[h_indices, w_indices, :] = mask_image[indices]
128
- show_cpu = show.cpu().numpy()
129
- if bbox is not None:
130
- x1, y1, x2, y2 = bbox
131
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
132
- # draw point
133
- if points is not None:
134
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
135
- plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
136
- ax.imshow(show_cpu)
137
-
138
-
139
- # # 预测队列
140
- # prediction_queue = queue.Queue(maxsize=5)
141
-
142
- # # 线程锁
143
- # lock = threading.Lock()
144
-
145
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
146
-
147
- def predict(input, input_size=512, high_visual_quality=False):
148
- input_size = int(input_size) # 确保 imgsz 是整数
149
- # # 获取线程锁
150
- # with lock:
151
- # print('5')
152
- # # 将任务添加到队列
153
- # prediction_queue.put((input, input_size, high_visual_quality))
154
-
155
- # # 等待结果
156
- # print('6')
157
- # fig = prediction_queue.get()[0]
158
- # print(fig)
159
- # return fig
160
- results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
161
- fig = fast_process(annotations=results[0].masks.data,
162
- image=input, high_quality=high_visual_quality, device=device)
163
- return fig
164
-
165
- # def worker():
166
- # while True:
167
- # # 从队列获取任务
168
- # print('1')
169
- # input, input_size, high_visual_quality = prediction_queue.get()
170
-
171
- # # 执行模型预测
172
- # print('2')
173
- # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
174
- # print('3')
175
- # fig = fast_process(annotations=results[0].masks.data,
176
- # image=input, high_quality=high_visual_quality, device=device)
177
- # print('4')
178
- # # 将结果放回队列
179
- # prediction_queue.put(fig)
180
-
181
- # # 在一个新的线程中启动工作函数
182
- # threading.Thread(target=worker).start()
183
-
184
- # # 将耗时的函数包装在另一个函数中,用于控制队列和线程同步
185
- # def process_request():
186
- # while True:
187
- # if not request_queue.empty():
188
- # # 如果请求队列不为空,则处理该请求
189
- # try:
190
- # lock.put(1) # 加锁,防止同时处理多个请求
191
- # input, input_size, high_visual_quality = request_queue.get()
192
- # fig = predict(input, input_size, high_visual_quality)
193
- # request_queue.task_done() # 请求处理结束,移除请求
194
- # lock.get() # 解锁
195
- # yield fig # 返回预测结果
196
- # except:
197
- # lock.get() # 出错时也需要解锁
198
- # else:
199
- # # 如果请求队列为空,则等待新的请求到达
200
- # time.sleep(1)
201
-
202
-
203
- # input_size=1024
204
- # high_quality_visual=True
205
- # inp = 'assets/sa_192.jpg'
206
- # input = Image.open(inp)
207
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
208
- # input_size = int(input_size) # 确保 imgsz 是整数
209
- # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
210
- # pil_image = fast_process(annotations=results[0].masks.data,
211
- # image=input, high_quality=high_quality_visual, device=device)
212
- app_interface = gr.Interface(fn=predict,
213
- inputs=[gr.components.Image(type='pil'),
214
- gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
215
- gr.components.Checkbox(value=False, label='high_visual_quality')],
216
- outputs=['plot'],
217
- examples=[["assets/sa_8776.jpg", 1024, True]],
218
- # # ["assets/sa_1309.jpg", 1024]],
219
- # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
220
- # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
221
- # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
222
- # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
223
- cache_examples=True,
224
- title="Fast Segment Anything (Everything mode)"
225
- )
226
-
227
- # # 定义一个请求处理函数，将请求添加到队列中
228
- # def handle_request(value):
229
- # try:
230
- # request_queue.put_nowait(value) # 添加请求到队列
231
- # except:
232
- # return "当前队列已满,请稍后再试!"
233
- # return None
234
-
235
- # # 添加请求处理函数到应用程序界面
236
- # app_interface.call_function()
237
-
238
- app_interface.queue(concurrency_count=1, max_size=20)
239
  app_interface.launch()
 
1
+ from ultralytics import YOLO
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+ import cv2
6
+ import torch
7
+ # import queue
8
+ # import threading
9
+ # from PIL import Image
10
+
11
+
12
+ model = YOLO('checkpoints/FastSAM.pt') # load a custom model
13
+
14
+
15
+ def fast_process(annotations, image, high_quality, device):
16
+ if isinstance(annotations[0],dict):
17
+ annotations = [annotation['segmentation'] for annotation in annotations]
18
+
19
+ original_h = image.height
20
+ original_w = image.width
21
+ fig = plt.figure(figsize=(10, 10))
22
+ plt.imshow(image)
23
+ if high_quality == True:
24
+ if isinstance(annotations[0],torch.Tensor):
25
+ annotations = np.array(annotations.cpu())
26
+ for i, mask in enumerate(annotations):
27
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
28
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
29
+ if device == 'cpu':
30
+ annotations = np.array(annotations)
31
+ fast_show_mask(annotations,
32
+ plt.gca(),
33
+ bbox=None,
34
+ points=None,
35
+ pointlabel=None,
36
+ retinamask=True,
37
+ target_height=original_h,
38
+ target_width=original_w)
39
+ else:
40
+ if isinstance(annotations[0],np.ndarray):
41
+ annotations = torch.from_numpy(annotations)
42
+ fast_show_mask_gpu(annotations,
43
+ plt.gca(),
44
+ bbox=None,
45
+ points=None,
46
+ pointlabel=None)
47
+ if isinstance(annotations, torch.Tensor):
48
+ annotations = annotations.cpu().numpy()
49
+ if high_quality == True:
50
+ contour_all = []
51
+ temp = np.zeros((original_h, original_w,1))
52
+ for i, mask in enumerate(annotations):
53
+ if type(mask) == dict:
54
+ mask = mask['segmentation']
55
+ annotation = mask.astype(np.uint8)
56
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
57
+ for contour in contours:
58
+ contour_all.append(contour)
59
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
60
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
61
+ contour_mask = temp / 225 * color.reshape(1, 1, -1)
62
+ plt.imshow(contour_mask)
63
+
64
+ plt.axis('off')
65
+ plt.tight_layout()
66
+ return fig
67
+
68
+
69
+ # CPU post process
70
+ def fast_show_mask(annotation, ax, bbox=None,
71
+ points=None, pointlabel=None,
72
+ retinamask=True, target_height=960,
73
+ target_width=960):
74
+ msak_sum = annotation.shape[0]
75
+ height = annotation.shape[1]
76
+ weight = annotation.shape[2]
77
+ # 将annotation 按照面积 排序
78
+ areas = np.sum(annotation, axis=(1, 2))
79
+ sorted_indices = np.argsort(areas)[::1]
80
+ annotation = annotation[sorted_indices]
81
+
82
+ index = (annotation != 0).argmax(axis=0)
83
+ color = np.random.random((msak_sum,1,1,3))
84
+ transparency = np.ones((msak_sum,1,1,1)) * 0.6
85
+ visual = np.concatenate([color,transparency],axis=-1)
86
+ mask_image = np.expand_dims(annotation,-1) * visual
87
+
88
+ show = np.zeros((height,weight,4))
89
+
90
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
91
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
92
+ # 使用向量化索引更新show的值
93
+ show[h_indices, w_indices, :] = mask_image[indices]
94
+ if bbox is not None:
95
+ x1, y1, x2, y2 = bbox
96
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
97
+ # draw point
98
+ if points is not None:
99
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
100
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
101
+
102
+ if retinamask==False:
103
+ show = cv2.resize(show,(target_width,target_height),interpolation=cv2.INTER_NEAREST)
104
+ ax.imshow(show)
105
+
106
+
107
+ def fast_show_mask_gpu(annotation, ax,
108
+ bbox=None, points=None,
109
+ pointlabel=None):
110
+ msak_sum = annotation.shape[0]
111
+ height = annotation.shape[1]
112
+ weight = annotation.shape[2]
113
+ areas = torch.sum(annotation, dim=(1, 2))
114
+ sorted_indices = torch.argsort(areas, descending=False)
115
+ annotation = annotation[sorted_indices]
116
+ # 找每个位置第一个非零值下标
117
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
118
+ color = torch.rand((msak_sum,1,1,3)).to(annotation.device)
119
+ transparency = torch.ones((msak_sum,1,1,1)).to(annotation.device) * 0.6
120
+ visual = torch.cat([color,transparency],dim=-1)
121
+ mask_image = torch.unsqueeze(annotation,-1) * visual
122
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
123
+ show = torch.zeros((height,weight,4)).to(annotation.device)
124
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
125
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
126
+ # 使用向量化索引更新show的值
127
+ show[h_indices, w_indices, :] = mask_image[indices]
128
+ show_cpu = show.cpu().numpy()
129
+ if bbox is not None:
130
+ x1, y1, x2, y2 = bbox
131
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
132
+ # draw point
133
+ if points is not None:
134
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
135
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
136
+ ax.imshow(show_cpu)
137
+
138
+
139
+
140
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
141
+
142
+ def predict(input, input_size=512, high_visual_quality=False):
143
+ input_size = int(input_size) # 确保 imgsz 是整数
144
+ results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
145
+ fig = fast_process(annotations=results[0].masks.data,
146
+ image=input, high_quality=high_visual_quality, device=device)
147
+ return fig
148
+
149
+
150
+
151
+ # input_size=1024
152
+ # high_quality_visual=True
153
+ # inp = 'assets/sa_192.jpg'
154
+ # input = Image.open(inp)
155
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
156
+ # input_size = int(input_size) # 确保 imgsz 是整数
157
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
158
+ # pil_image = fast_process(annotations=results[0].masks.data,
159
+ # image=input, high_quality=high_quality_visual, device=device)
160
+ app_interface = gr.Interface(fn=predict,
161
+ inputs=[gr.components.Image(type='pil'),
162
+ gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
163
+ gr.components.Checkbox(value=False, label='high_visual_quality')],
164
+ outputs=['plot'],
165
+ examples=[["assets/sa_8776.jpg", 1024, True]],
166
+ # # ["assets/sa_1309.jpg", 1024]],
167
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
168
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
169
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
170
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
171
+ cache_examples=True,
172
+ title="Fast Segment Anything (Everything mode)"
173
+ )
174
+
175
+
176
+ app_interface.queue(concurrency_count=1, max_size=20)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  app_interface.launch()
app_copy.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from ultralytics import YOLO
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+ import cv2
6
+ import torch
7
+ # import queue
8
+ # import threading
9
+ from PIL import Image
10
+
11
+
12
+ model = YOLO('checkpoints/FastSAM.pt') # load a custom model
13
+
14
+
15
+ def fast_process(annotations, image, high_quality, device):
16
+ if isinstance(annotations[0],dict):
17
+ annotations = [annotation['segmentation'] for annotation in annotations]
18
+
19
+ original_h = image.height
20
+ original_w = image.width
21
+ # fig = plt.figure(figsize=(10, 10))
22
+ # plt.imshow(image)
23
+ if high_quality == True:
24
+ if isinstance(annotations[0],torch.Tensor):
25
+ annotations = np.array(annotations.cpu())
26
+ for i, mask in enumerate(annotations):
27
+ mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
28
+ annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
29
+ if device == 'cpu':
30
+ annotations = np.array(annotations)
31
+ inner_mask = fast_show_mask(annotations,
32
+ plt.gca(),
33
+ bbox=None,
34
+ points=None,
35
+ pointlabel=None,
36
+ retinamask=True,
37
+ target_height=original_h,
38
+ target_width=original_w)
39
+ else:
40
+ if isinstance(annotations[0],np.ndarray):
41
+ annotations = torch.from_numpy(annotations)
42
+ inner_mask = fast_show_mask_gpu(annotations,
43
+ plt.gca(),
44
+ bbox=None,
45
+ points=None,
46
+ pointlabel=None)
47
+ if isinstance(annotations, torch.Tensor):
48
+ annotations = annotations.cpu().numpy()
49
+ if high_quality == True:
50
+ contour_all = []
51
+ temp = np.zeros((original_h, original_w,1))
52
+ for i, mask in enumerate(annotations):
53
+ if type(mask) == dict:
54
+ mask = mask['segmentation']
55
+ annotation = mask.astype(np.uint8)
56
+ contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
57
+ for contour in contours:
58
+ contour_all.append(contour)
59
+ cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
60
+ color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
61
+ contour_mask = temp / 255 * color.reshape(1, 1, -1)
62
+ # plt.imshow(contour_mask)
63
+ image = image.convert('RGBA')
64
+ overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
65
+ overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
66
+ # image = image.convert('RGBA')
67
+ # image = Image.alpha_composite(image, overlay_inner)
68
+ # image = Image.alpha_composite(image, overlay_contour)
69
+ image.paste(overlay_inner, (0, 0), overlay_inner)
70
+ image.paste(overlay_contour, (0, 0), overlay_contour)
71
+
72
+ return image
73
+ # plt.axis('off')
74
+ # plt.tight_layout()
75
+ # return fig
76
+
77
+
78
+ # CPU post process
79
+ def fast_show_mask(annotation, ax, bbox=None,
80
+ points=None, pointlabel=None,
81
+ retinamask=True, target_height=960,
82
+ target_width=960):
83
+ msak_sum = annotation.shape[0]
84
+ height = annotation.shape[1]
85
+ weight = annotation.shape[2]
86
+ # 将annotation 按照面积 排序
87
+ areas = np.sum(annotation, axis=(1, 2))
88
+ sorted_indices = np.argsort(areas)[::1]
89
+ annotation = annotation[sorted_indices]
90
+
91
+ index = (annotation != 0).argmax(axis=0)
92
+ color = np.random.random((msak_sum,1,1,3))
93
+ transparency = np.ones((msak_sum,1,1,1)) * 0.6
94
+ visual = np.concatenate([color,transparency],axis=-1)
95
+ mask_image = np.expand_dims(annotation,-1) * visual
96
+
97
+ mask = np.zeros((height,weight,4))
98
+
99
+ h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
100
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
101
+ # 使用向量化索引更新show的值
102
+ mask[h_indices, w_indices, :] = mask_image[indices]
103
+ if bbox is not None:
104
+ x1, y1, x2, y2 = bbox
105
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
106
+ # draw point
107
+ if points is not None:
108
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
109
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
110
+
111
+ if retinamask==False:
112
+ mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
113
+ # ax.imshow(mask)
114
+
115
+ return mask
116
+
117
+
118
+ def fast_show_mask_gpu(annotation, ax,
119
+ bbox=None, points=None,
120
+ pointlabel=None):
121
+ msak_sum = annotation.shape[0]
122
+ height = annotation.shape[1]
123
+ weight = annotation.shape[2]
124
+ areas = torch.sum(annotation, dim=(1, 2))
125
+ sorted_indices = torch.argsort(areas, descending=False)
126
+ annotation = annotation[sorted_indices]
127
+ # 找每个位置第一个非零值下标
128
+ index = (annotation != 0).to(torch.long).argmax(dim=0)
129
+ color = torch.rand((msak_sum,1,1,3)).to(annotation.device)
130
+ transparency = torch.ones((msak_sum,1,1,1)).to(annotation.device) * 0.6
131
+ visual = torch.cat([color,transparency],dim=-1)
132
+ mask_image = torch.unsqueeze(annotation,-1) * visual
133
+ # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
134
+ mask = torch.zeros((height,weight,4)).to(annotation.device)
135
+ h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
136
+ indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
137
+ # 使用向量化索引更新show的值
138
+ mask[h_indices, w_indices, :] = mask_image[indices]
139
+ mask_cpu = mask.cpu().numpy()
140
+ if bbox is not None:
141
+ x1, y1, x2, y2 = bbox
142
+ ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
143
+ # draw point
144
+ if points is not None:
145
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
146
+ plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
147
+ # ax.imshow(mask_cpu)
148
+ return mask_cpu
149
+
150
+
151
+ # # 预测队列
152
+ # prediction_queue = queue.Queue(maxsize=5)
153
+
154
+ # # 线程锁
155
+ # lock = threading.Lock()
156
+
157
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
158
+
159
+ def predict(input, input_size=512, high_visual_quality=False):
160
+ input_size = int(input_size) # 确保 imgsz 是整数
161
+ results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
162
+ fig = fast_process(annotations=results[0].masks.data,
163
+ image=input, high_quality=high_visual_quality, device=device)
164
+ return fig
165
+
166
+
167
+
168
+ # input_size=1024
169
+ # high_quality_visual=True
170
+ # inp = 'assets/sa_192.jpg'
171
+ # input = Image.open(inp)
172
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
173
+ # input_size = int(input_size) # 确保 imgsz 是整数
174
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
175
+ # pil_image = fast_process(annotations=results[0].masks.data,
176
+ # image=input, high_quality=high_quality_visual, device=device)
177
+
178
+ app_interface = gr.Interface(fn=predict,
179
+ inputs=[gr.components.Image(type='pil'),
180
+ gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
181
+ gr.components.Checkbox(value=False, label='high_visual_quality')],
182
+ # outputs=['plot'],
183
+ outputs=gr.components.Image(type='pil'),
184
+ examples=[["assets/sa_8776.jpg", 1024, True]],
185
+ # # ["assets/sa_1309.jpg", 1024]],
186
+ # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
187
+ # ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
188
+ # ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
189
+ # ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
190
+ cache_examples=True,
191
+ title="Fast Segment Anything (Everything mode)"
192
+ )
193
+
194
+
195
+ app_interface.queue(concurrency_count=1, max_size=20)
196
+ app_interface.launch()
requirements.txt CHANGED
@@ -14,5 +14,5 @@ opencv-python
14
  # seaborn>=0.11.0
15
 
16
  # Ultralytics-----------------------------------
17
- ultralytics==8.0.112
18
 
 
14
  # seaborn>=0.11.0
15
 
16
  # Ultralytics-----------------------------------
17
+ ultralytics==8.0.121
18