AAAAAAyq committed
Commit dc120b2
1 parent: 77f39b8

Update app.py

Files changed (2)
  1. app.py +116 -119
  2. requirements.txt +1 -1
app.py CHANGED
@@ -4,29 +4,70 @@ import matplotlib.pyplot as plt
import gradio as gr
import cv2
import torch
+ from PIL import Image

model = YOLO('checkpoints/FastSAM.pt')  # load a custom model


- def fast_process(annotations, image):
+ def fast_process(annotations, image, high_quality, device):
+     if isinstance(annotations[0], dict):
+         annotations = [annotation['segmentation'] for annotation in annotations]
+
+     original_h = image.height
+     original_w = image.width
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(image)
-     # original_h = image.shape[0]
-     # original_w = image.shape[1]
-     # for i, mask in enumerate(annotations):
-     #     mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
-     #     annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
-     fast_show_mask(annotations,
-                    plt.gca())
-                    # target_height=original_h,
-                    # target_width=original_w)
+     if high_quality == True:
+         if isinstance(annotations[0], torch.Tensor):
+             annotations = np.array(annotations.cpu())
+         for i, mask in enumerate(annotations):
+             mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
+             annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
+     if device == 'cpu':
+         annotations = np.array(annotations)
+         fast_show_mask(annotations,
+                        plt.gca(),
+                        bbox=None,
+                        points=None,
+                        pointlabel=None,
+                        retinamask=True,
+                        target_height=original_h,
+                        target_width=original_w)
+     else:
+         if isinstance(annotations[0], np.ndarray):
+             annotations = torch.from_numpy(annotations)
+         fast_show_mask_gpu(annotations,
+                            plt.gca(),
+                            bbox=None,
+                            points=None,
+                            pointlabel=None)
+     if isinstance(annotations, torch.Tensor):
+         annotations = annotations.cpu().numpy()
+     if high_quality == True:
+         contour_all = []
+         temp = np.zeros((original_h, original_w, 1))
+         for i, mask in enumerate(annotations):
+             if type(mask) == dict:
+                 mask = mask['segmentation']
+             annotation = mask.astype(np.uint8)
+             contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
+             for contour in contours:
+                 contour_all.append(contour)
+         cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
+         color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
+         contour_mask = temp / 255 * color.reshape(1, 1, -1)
+         plt.imshow(contour_mask)
+
    plt.axis('off')
    plt.tight_layout()
    return fig


# CPU post process
- def fast_show_mask(annotation, ax):
+ def fast_show_mask(annotation, ax, bbox=None,
+                    points=None, pointlabel=None,
+                    retinamask=True, target_height=960,
+                    target_width=960):
    msak_sum = annotation.shape[0]
    height = annotation.shape[1]
    weight = annotation.shape[2]
@@ -36,136 +77,92 @@ def fast_show_mask(annotation, ax):
    annotation = annotation[sorted_indices]

    index = (annotation != 0).argmax(axis=0)
-     color = np.random.random((msak_sum, 1, 1, 3))
-     transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
-     visual = np.concatenate([color, transparency], axis=-1)
-     mask_image = np.expand_dims(annotation, -1) * visual
+     color = np.random.random((msak_sum,1,1,3))
+     transparency = np.ones((msak_sum,1,1,1)) * 0.6
+     visual = np.concatenate([color,transparency],axis=-1)
+     mask_image = np.expand_dims(annotation,-1) * visual

-     show = np.zeros((height, weight, 4))
+     show = np.zeros((height,weight,4))

    h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
    indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
    # update show with vectorized indexing
    show[h_indices, w_indices, :] = mask_image[indices]
-
-
-     # if retinamask == False:
-     #     show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
+     # draw point
+     if points is not None:
+         plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1], [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], s=20, c='y')
+         plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0], [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], s=20, c='m')
+
+     if retinamask == False:
+         show = cv2.resize(show, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
    ax.imshow(show)


-
- # post_process(results[0].masks, Image.open("../data/cake.png"))
-
- def predict(input, input_size=512):
-     input_size = int(input_size)  # make sure imgsz is an integer
-     results = model(input, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
-     pil_image = fast_process(annotations=results[0].masks.data.numpy(), image=input)
-
-     return pil_image
-
-
- # inp = 'assets/sa_192.jpg'
- # results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
- # results = format_results(results[0], 100)
- # post_process(annotations=results, image_path=inp)
-
- demo = gr.Interface(fn=predict,
-                     inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024], default=512)],
-                     outputs=['plot'],
-                     examples=[["assets/sa_8776.jpg", 1024]],
-                     # ["assets/sa_1309.jpg", 1024]],
-                     # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
-                     #           ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
-                     #           ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
-                     #           ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
-                     )
-
- demo.launch()
- """
-
- from ultralytics import YOLO
- import numpy as np
- import matplotlib.pyplot as plt
- import gradio as gr
- import torch
-
- model = YOLO('checkpoints/FastSAM.pt')  # load a custom model
-
- def format_results(result, filter=0):
-     annotations = []
-     n = len(result.masks.data)
-     for i in range(n):
-         annotation = {}
-         mask = result.masks.data[i] == 1.0
-
-         if torch.sum(mask) < filter:
-             continue
-         annotation['id'] = i
-         annotation['segmentation'] = mask.cpu().numpy()
-         annotation['bbox'] = result.boxes.data[i]
-         annotation['score'] = result.boxes.conf[i]
-         annotation['area'] = annotation['segmentation'].sum()
-         annotations.append(annotation)
-     return annotations
-
- def show_mask(annotation, ax, random_color=True, bbox=None, points=None):
-     if random_color:  # random mask color
-         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
-     else:
-         color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
-     if type(annotation) == dict:
-         annotation = annotation['segmentation']
-     mask = annotation
-     h, w = mask.shape[-2:]
-     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
-     # draw box
+ def fast_show_mask_gpu(annotation, ax,
+                        bbox=None, points=None,
+                        pointlabel=None):
+     msak_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     areas = torch.sum(annotation, dim=(1, 2))
+     sorted_indices = torch.argsort(areas, descending=False)
+     annotation = annotation[sorted_indices]
+     # index of the first non-zero mask at each pixel
+     index = (annotation != 0).to(torch.long).argmax(dim=0)
+     color = torch.rand((msak_sum,1,1,3)).to(annotation.device)
+     transparency = torch.ones((msak_sum,1,1,1)).to(annotation.device) * 0.6
+     visual = torch.cat([color,transparency],dim=-1)
+     mask_image = torch.unsqueeze(annotation,-1) * visual
+     # gather by index: at each pixel, keep the values of whichever mask index selects, collapsing mask_image into a single image
+     show = torch.zeros((height,weight,4)).to(annotation.device)
+     h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+     # update show with vectorized indexing
+     show[h_indices, w_indices, :] = mask_image[indices]
+     show_cpu = show.cpu().numpy()
    if bbox is not None:
        x1, y1, x2, y2 = bbox
        ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
    # draw point
    if points is not None:
-         ax.scatter([point[0] for point in points], [point[1] for point in points], s=10, c='g')
-     ax.imshow(mask_image)
-     return mask_image
-
- def post_process(annotations, image, mask_random_color=True, bbox=None, points=None):
-     fig = plt.figure(figsize=(10, 10))
-     plt.imshow(image)
-     for i, mask in enumerate(annotations):
-         show_mask(mask, plt.gca(), random_color=mask_random_color, bbox=bbox, points=points)
-     plt.axis('off')
-
-     plt.tight_layout()
-     return fig
-
+         plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 1], [point[1] for i, point in enumerate(points) if pointlabel[i] == 1], s=20, c='y')
+         plt.scatter([point[0] for i, point in enumerate(points) if pointlabel[i] == 0], [point[1] for i, point in enumerate(points) if pointlabel[i] == 0], s=20, c='m')
+     ax.imshow(show_cpu)

# post_process(results[0].masks, Image.open("../data/cake.png"))

- def predict(input, input_size):
+ def predict(input, input_size=512, high_quality_visual=True):
+     device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input_size = int(input_size)  # make sure imgsz is an integer
-     results = model(input, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
-     results = format_results(results[0], 100)
-     results.sort(key=lambda x: x['area'], reverse=True)
-     pil_image = post_process(annotations=results, image=input)
+     results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
+     pil_image = fast_process(annotations=results[0].masks.data,
+                              image=input, high_quality=high_quality_visual, device=device)
    return pil_image

+ # input_size=1024
+ # high_quality_visual=True
# inp = 'assets/sa_192.jpg'
- # results = model(inp, device='cpu', retina_masks=True, iou=0.7, conf=0.25, imgsz=1024)
- # results = format_results(results[0], 100)
- # post_process(annotations=results, image_path=inp)
-
+ # input = Image.open(inp)
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ # input_size = int(input_size)  # make sure imgsz is an integer
+ # results = model(input, device=device, retina_masks=True, iou=0.7, conf=0.25, imgsz=input_size)
+ # pil_image = fast_process(annotations=results[0].masks.data,
+ #                          image=input, high_quality=high_quality_visual, device=device)
demo = gr.Interface(fn=predict,
-                     inputs=[gr.inputs.Image(type='pil'), gr.inputs.Dropdown(choices=[512, 800, 1024], default=1024)],
+                     inputs=[gr.components.Image(type='pil'),
+                             gr.components.Dropdown(choices=[512, 800, 1024], default=512),
+                             gr.components.Checkbox(default=True)],
                    outputs=['plot'],
-                     examples=[["assets/sa_8776.jpg", 1024]],
-                     # ["assets/sa_1309.jpg", 1024]],
-                     # examples=[["assets/sa_192.jpg"], ["assets/sa_414.jpg"],
-                     #           ["assets/sa_561.jpg"], ["assets/sa_862.jpg"],
-                     #           ["assets/sa_1309.jpg"], ["assets/sa_8776.jpg"],
-                     #           ["assets/sa_10039.jpg"], ["assets/sa_11025.jpg"],],
+                     # examples=[["assets/sa_8776.jpg", 1024, True]],
+                     # ["assets/sa_1309.jpg", 1024]],
+                     examples=[["assets/sa_192.jpg", 1024, True], ["assets/sa_414.jpg", 1024, True],
+                               ["assets/sa_561.jpg", 1024, True], ["assets/sa_862.jpg", 1024, True],
+                               ["assets/sa_1309.jpg", 1024, True], ["assets/sa_8776.jpg", 1024, True],
+                               ["assets/sa_10039.jpg", 1024, True], ["assets/sa_11025.jpg", 1024, True],],
+                     cache_examples=False,
                    )

- demo.launch()
-
- """
+ demo.launch()
 
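The high_quality path above cleans each mask before drawing: a morphological close with a small 3x3 kernel fills pinholes inside a mask, then an open with a larger 8x8 kernel drops stray specks. A minimal sketch of that cleanup on a synthetic mask (the mask here is invented; FastSAM masks arrive as an (N, H, W) array):

import numpy as np
import cv2

# Synthetic 64x64 mask: a solid square with one pinhole and one speck.
mask = np.zeros((64, 64), dtype=np.uint8)
mask[16:48, 16:48] = 1
mask[30, 30] = 0   # pinhole inside the square: closing should fill it
mask[2, 2] = 1     # isolated speck: opening should remove it

# Same kernel sizes as the commit: close with 3x3, then open with 8x8.
closed = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
cleaned = cv2.morphologyEx(closed, cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))

assert cleaned[30, 30] == 1 and cleaned[2, 2] == 0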
 
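fast_show_mask (and its GPU twin) composites all masks without a per-mask drawing loop: each mask gets a random RGBA layer, (annotation != 0).argmax(axis=0) finds, for every pixel, the first mask in area-sorted order that covers it, and one meshgrid-driven fancy-indexing gather assembles the overlay. A minimal numpy sketch with two tiny overlapping masks standing in for FastSAM output:

import numpy as np

annotation = np.zeros((2, 4, 4))   # (num_masks, height, width)
annotation[0, :2, :] = 1           # mask 0: top two rows
annotation[1, 1:, :] = 1           # mask 1: bottom three rows (overlaps row 1)

n, height, width = annotation.shape
color = np.random.random((n, 1, 1, 3))
transparency = np.ones((n, 1, 1, 1)) * 0.6
visual = np.concatenate([color, transparency], axis=-1)
mask_image = np.expand_dims(annotation, -1) * visual   # (n, H, W, 4)

# Per pixel: index of the first mask that covers it (argmax over a boolean
# stack returns the first True; it also returns 0 where no mask covers).
index = (annotation != 0).argmax(axis=0)

show = np.zeros((height, width, 4))
h_idx, w_idx = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
# One vectorized gather instead of looping over masks.
show[h_idx, w_idx, :] = mask_image[index[h_idx, w_idx], h_idx, w_idx, :]

assert show.shape == (4, 4, 4)
# Row 1 is covered by both masks; mask 0 wins because argmax picks the first.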
 
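With high_quality on, fast_process also traces mask outlines: each binary mask goes through cv2.findContours, all contours are drawn into one single-channel canvas, and the canvas is tinted blue at 0.8 alpha before plt.imshow lays it over the figure. A minimal sketch, assuming OpenCV 4.x, where findContours returns (contours, hierarchy):

import numpy as np
import cv2

h, w = 64, 64
masks = np.zeros((2, h, w), dtype=np.uint8)   # two synthetic masks
masks[0, 10:30, 10:30] = 1
masks[1, 35:60, 20:50] = 1

contour_all = []
for mask in masks:
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    contour_all.extend(contours)

# Draw all outlines, 2 px thick, into a single canvas, as the commit does.
temp = np.zeros((h, w, 1))
cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)

# Normalize to [0, 1] and tint blue with 0.8 alpha for plt.imshow.
color = np.array([0.0, 0.0, 1.0, 0.8])
contour_mask = temp / 255 * color.reshape(1, 1, -1)
assert contour_mask.shape == (h, w, 4)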
requirements.txt CHANGED
@@ -1,7 +1,7 @@
# Base-----------------------------------
matplotlib==3.2.2
numpy
- # opencv-python>=4.6.0
+ opencv-python
# Pillow>=7.1.2
# PyYAML>=5.3.1
# requests>=2.23.0
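opencv-python is promoted from a commented-out extra to a hard dependency because app.py now calls cv2.morphologyEx, cv2.findContours, and cv2.drawContours at render time. A quick import check for the three runtime dependencies:

import matplotlib
import numpy
import cv2  # installed by the opencv-python wheel

print(matplotlib.__version__, numpy.__version__, cv2.__version__)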