AAAAAAyq committed on
Commit ca3609f • 1 Parent(s): d1be458

Update the interface layout

Files changed (1)
  1. app.py +122 -42
app.py CHANGED
@@ -4,22 +4,45 @@ import matplotlib.pyplot as plt
 import gradio as gr
 import cv2
 import torch
-# import queue
-# import threading
 from PIL import Image
 
-model = YOLO('checkpoints/FastSAM.pt') # load a custom model
+# Load the pre-trained model
+model = YOLO('checkpoints/FastSAM.pt')
 
+# Description
+title = "<center><strong><font size='8'>🏃 Fast Segment Anything 🤗</font></strong></center>"
+
+description = """This is a demo of the GitHub project 🏃 [Fast Segment Anything Model](https://github.com/CASIA-IVA-Lab/FastSAM).
+
+🎯 Upload an image and segment it with Fast Segment Anything (Everything mode). The other modes will come soon.
+
+⌛️ It takes about 4 seconds to generate the segmentation results. The queue's concurrency_count is 1, so please wait a moment when it is busy.
+
+🚀 To get faster results, you can use a smaller input size and leave high_visual_quality unchecked.
+
+📣 You can also obtain the segmentation results of any image through this Colab: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oX14f6IneGGw612WgVlAiy91UHwFAvr9?usp=sharing)
+
+😚 A huge thanks goes out to the @HuggingFace team for supporting us with a GPU grant.
+
+🏠 Check out our [Model Card 🏃](https://huggingface.co/An-619/FastSAM)
+
+"""
+
+examples = [["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
+            ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
+            ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
+            ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"]]
+
+default_example = examples[5]
+
+css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
 
-def fast_process(annotations, image, high_quality, device):
+def fast_process(annotations, image, high_quality, device, scale):
     if isinstance(annotations[0],dict):
         annotations = [annotation['segmentation'] for annotation in annotations]
 
     original_h = image.height
     original_w = image.width
-    # fig = plt.figure(figsize=(10, 10))
-    # plt.imshow(image)
     if high_quality == True:
         if isinstance(annotations[0],torch.Tensor):
             annotations = np.array(annotations.cpu())
@@ -57,10 +80,9 @@ def fast_process(annotations, image, high_quality, device):
         contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
         for contour in contours:
             contour_all.append(contour)
-    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+    cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
     color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
     contour_mask = temp / 255 * color.reshape(1, 1, -1)
-    # plt.imshow(contour_mask)
     image = image.convert('RGBA')
 
     overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
@@ -71,10 +93,6 @@ def fast_process(annotations, image, high_quality, device):
     image.paste(overlay_contour, (0, 0), overlay_contour)
 
     return image
-    # plt.axis('off')
-    # plt.tight_layout()
-    # return fig
-
 
 # CPU post process
 def fast_show_mask(annotation, ax, bbox=None,
@@ -111,7 +129,6 @@ def fast_show_mask(annotation, ax, bbox=None,
 
     if retinamask==False:
         mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
-    # ax.imshow(mask)
 
     return mask
 
@@ -145,19 +162,12 @@ def fast_show_mask_gpu(annotation, ax,
     if points is not None:
         plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==1], [point[1] for i, point in enumerate(points) if pointlabel[i]==1], s=20, c='y')
         plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i]==0], [point[1] for i, point in enumerate(points) if pointlabel[i]==0], s=20, c='m')
-    # ax.imshow(mask_cpu)
     return mask_cpu
 
 
-# # Prediction queue
-# prediction_queue = queue.Queue(maxsize=5)
-
-# # Thread lock
-# lock = threading.Lock()
-
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-def predict(input, input_size=1024, high_visual_quality=True):
+def segment_image(input, input_size=1024, high_visual_quality=True, iou_threshold=0.7, conf_threshold=0.25):
     input_size = int(input_size)  # make sure imgsz is an integer
 
     # Thanks for the suggestion by hysts on HuggingFace.
@@ -167,9 +177,10 @@ def predict(input, input_size=1024, high_visual_quality=True):
     new_h = int(h * scale)
     input = input.resize((new_w, new_h))
 
-    results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
+    results = model(input, device=device, retina_masks=True, iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
     fig = fast_process(annotations=results[0].masks.data,
-                       image=input, high_quality=high_visual_quality, device=device)
+                       image=input, high_quality=high_visual_quality,
+                       device=device, scale=(1024 // input_size))
     return fig
 
 # input_size=1024
@@ -182,22 +193,91 @@ def predict(input, input_size=1024, high_visual_quality=True):
 # pil_image = fast_process(annotations=results[0].masks.data,
 #                          image=input, high_quality=high_quality_visual, device=device)
 
-app_interface = gr.Interface(fn=predict,
-                             inputs=[gr.Image(type='pil'),
-                                     gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
-                                     gr.components.Checkbox(value=True, label='high_visual_quality')],
-                             # outputs=['plot'],
-                             outputs=gr.Image(type='pil'),
-                             # examples=[["assets/sa_8776.jpg"]],
-                             # #           ["assets/sa_1309.jpg", 1024]],
-                             examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
-                                       ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
-                                       ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
-                                       ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
-                             cache_examples=True,
-                             title="Fast Segment Anything (Everything mode)"
-                             )
-
-
-app_interface.queue(concurrency_count=1, max_size=20)
-app_interface.launch()
+cond_img = gr.Image(label="Input", value=default_example[0], type='pil')
+
+segm_img = gr.Image(label="Segmented Image", interactive=False, type='pil')
+
+input_size_slider = gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size')
+
+with gr.Blocks(css=css, title='Fast Segment Anything') as demo:
+    with gr.Row():
+        # Title
+        gr.Markdown(title)
+        # # # Description
+        # # gr.Markdown(description)
+
+    # Images
+    with gr.Row(variant="panel"):
+        with gr.Column(scale=1):
+            cond_img.render()
+
+        with gr.Column(scale=1):
+            segm_img.render()
+
+    # Submit & Clear
+    with gr.Row():
+        with gr.Column():
+            input_size_slider.render()
+
+            with gr.Row():
+                vis_check = gr.Checkbox(value=True, label='high_visual_quality')
+
+                with gr.Column():
+                    segment_btn = gr.Button("Segment Anything", variant='primary')
+
+                # with gr.Column():
+                #     clear_btn = gr.Button("Clear", variant="primary")
+
+            gr.Markdown("Try some of the examples below ⬇️")
+            gr.Examples(examples=examples,
+                        inputs=[cond_img],
+                        outputs=segm_img,
+                        fn=segment_image,
+                        cache_examples=True,
+                        examples_per_page=4)
+            # gr.Markdown("Try some of the examples below ⬇️")
+            # gr.Examples(examples=examples,
+            #             inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
+            #             outputs=output,
+            #             fn=segment_image,
+            #             examples_per_page=4)
+
+        with gr.Column():
+            with gr.Accordion("Advanced options", open=False):
+                iou_threshold = gr.Slider(0.1, 0.9, 0.7, step=0.1, label='iou_threshold')
+                conf_threshold = gr.Slider(0.1, 0.9, 0.25, step=0.05, label='conf_threshold')
+
+            # Description
+            gr.Markdown(description)
+
+    segment_btn.click(segment_image,
+                      inputs=[cond_img, input_size_slider, vis_check, iou_threshold, conf_threshold],
+                      outputs=segm_img)
+
+    # def clear():
+    #     return None, None
+
+    # clear_btn.click(fn=clear, inputs=None, outputs=None)
+
+demo.queue()
+demo.launch()
+
+# app_interface = gr.Interface(fn=predict,
+#                              inputs=[gr.Image(type='pil'),
+#                                      gr.components.Slider(minimum=512, maximum=1024, value=1024, step=64, label='input_size'),
+#                                      gr.components.Checkbox(value=True, label='high_visual_quality')],
+#                              # outputs=['plot'],
+#                              outputs=gr.Image(type='pil'),
#                              # examples=[["assets/sa_8776.jpg"]],
+#                              # #           ["assets/sa_1309.jpg", 1024]],
+#                              examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
+#                                        ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
+#                                        ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
+#                                        ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
+#                              cache_examples=True,
+#                              title="Fast Segment Anything (Everything mode)"
+#                              )
+
+
+# app_interface.queue(concurrency_count=1, max_size=20)
+# app_interface.launch()
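
Taken together, the commit swaps the one-shot gr.Interface for a gr.Blocks layout: the input/output image pair, the input_size slider, the high_visual_quality checkbox, and the new iou_threshold and conf_threshold sliders are all wired into segment_image through segment_btn.click. segment_image also forwards scale = 1024 // input_size to fast_process, so the contour stroke width 2 // scale drops from 2 px at input_size=1024 to 1 px at input_size=512. The sketch below illustrates that inference path outside the Gradio UI; it is not part of the commit, and it assumes the ultralytics-style YOLO wrapper, the checkpoints/FastSAM.pt weights, and one of the assets/ example images referenced above.

# Illustrative sketch only (assumptions: ultralytics-style YOLO wrapper,
# checkpoints/FastSAM.pt, and a local image from assets/ as used in app.py).
from ultralytics import YOLO
from PIL import Image
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = YOLO('checkpoints/FastSAM.pt')

def everything_mode_masks(image_path, input_size=1024, iou_threshold=0.7, conf_threshold=0.25):
    # Mirror segment_image(): shrink the longer side to input_size before inference.
    image = Image.open(image_path)
    w, h = image.size
    scale = input_size / max(w, h)
    image = image.resize((int(w * scale), int(h * scale)))
    # Same call pattern as the updated segment_image(); iou/conf are the values
    # exposed in the UI through the new "Advanced options" accordion.
    results = model(image, device=device, retina_masks=True,
                    iou=iou_threshold, conf=conf_threshold, imgsz=input_size)
    masks = results[0].masks
    # Everything mode normally returns many masks; guard for the empty case anyway.
    return masks.data if masks is not None else torch.empty(0)

masks = everything_mode_masks('assets/sa_8776.jpg', input_size=512)
print(masks.shape)

Running the sketch with input_size=512 roughly matches what the UI does when the slider is lowered, trading mask fidelity for speed, as the demo description suggests.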