happyme531 committed
Commit ea6ccf2 • 1 Parent(s): bff24fd

Upload 7 files
- .gitattributes +1 -0
- convert_encoder.py +43 -0
- input.jpg +0 -0
- output.jpg +0 -0
- patch_graph.py +55 -0
- run_sam_rknn.py +237 -0
- sam_vit_b_01ec64.pth.decoder.onnx +3 -0
- sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn filter=lfs diff=lfs merge=lfs -text
convert_encoder.py
ADDED
@@ -0,0 +1,43 @@
import os
import urllib
import traceback
import time
import sys
import numpy as np
import cv2
from rknn.api import RKNN

ONNX_MODEL = "sam_vit_b_01ec64.pth.encoder.patched.onnx"
RKNN_MODEL = "sam_vit_b_01ec64.pth.encoder.onnx.rknn"
rknn = RKNN(verbose=True)

# pre-process config
print('--> config model')
rknn.config(target_platform='rk3588', single_core_mode=True)
print('done')

# Load model
print("--> Loading model")
ret = rknn.load_onnx(
    model=ONNX_MODEL, inputs=["input_image"], input_size_list=[[1024, 1024, 3]]
)
if ret != 0:
    print("Load model failed!")
    exit(ret)
print("done")

# Build model
print('--> Building model')
ret = rknn.build(do_quantization=False)
if ret != 0:
    print('Build model failed!')
    exit(ret)
print('done')

# Export rknn model
print('--> Export rknn model')
ret = rknn.export_rknn(RKNN_MODEL)
if ret != 0:
    print('Export rknn model failed!')
    exit(ret)
print('done')
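Note: convert_encoder.py runs offline on a development host with rknn-toolkit2 and consumes the ONNX produced by patch_graph.py (added below in this commit); do_quantization=False exports a non-quantized floating-point model. Before kicking off a long conversion it can help to confirm that the input tensor matches what load_onnx() is told. A minimal pre-flight sketch, not part of the commit, assuming only the onnx package:

import onnx

# Hypothetical sanity check: list the graph inputs of the patched encoder and
# verify that "input_image" with a 1024x1024x3 layout is really there before
# handing the file to rknn.load_onnx().
model = onnx.load("sam_vit_b_01ec64.pth.encoder.patched.onnx")
for inp in model.graph.input:
    dims = [d.dim_value for d in inp.type.tensor_type.shape.dim]
    print(inp.name, dims)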
input.jpg
ADDED
output.jpg
ADDED
patch_graph.py
ADDED
@@ -0,0 +1,55 @@
import onnx_graphsurgeon as gs
import onnx
import numpy as np

# Load the ONNX model
graph = gs.import_onnx(onnx.load("check3_fuse_ops.onnx"))

count = 0
# Iterate through all nodes in the graph
for node in graph.nodes:
    # Check if the node is a Reshape operator
    if node.op == 'Reshape':
        # Get the shape input of the Reshape node
        shape_input = node.inputs[1]

        # Check if the shape input is a constant (which it should be for static reshapes)
        if isinstance(shape_input, gs.Constant):
            current_shape = shape_input.values

            # Check if it's a 5D reshape with the target shape [12, 64, 64, ..., ...]
            if len(current_shape) == 5 and current_shape[0] == 12 and current_shape[1] == 64 and current_shape[2] == 64:
                # Modify the shape to [12, 4096, ..., ...]
                new_shape = np.array([12, 4096, current_shape[3], current_shape[4]], dtype=np.int64)
                print(f"Patched {current_shape} -> {new_shape}")

                # Update the shape input with the new shape
                shape_input.values = new_shape
                count += 1

            # Check if it's a 5D reshape with the target shape [300, 14, 14, ..., ...]
            if len(current_shape) == 5 and current_shape[0] == 300 and current_shape[1] == 14 and current_shape[2] == 14:
                # Modify the shape to [300, 196, ..., ...]
                new_shape = np.array([300, 196, current_shape[3], current_shape[4]], dtype=np.int64)
                print(f"Patched {current_shape} -> {new_shape}")

                # Update the shape input with the new shape
                shape_input.values = new_shape
                count += 1

graph.cleanup().toposort()
print(f"Patched {count} nodes.")

model = gs.export_onnx(graph)

# Delete old shape information from the model
for value_info in model.graph.value_info:
    value_info.type.tensor_type.ClearField('shape')

# Save the modified model
onnx.save(model, "sam_vit_b_01ec64.pth.encoder.patched.onnx")

print("Saved as 'sam_vit_b_01ec64.pth.encoder.patched.onnx'")
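patch_graph.py rewrites constant 5-D Reshape targets: [12, 64, 64, h, w] becomes [12, 4096, h, w] and [300, 14, 14, h, w] becomes [300, 196, h, w], presumably because the RKNN toolchain copes better with tensors of at most four dimensions. Collapsing two adjacent axes in a row-major reshape leaves the element order untouched, which is what makes the rewrite plausible; a small numpy sketch of that equivalence, illustrative only:

import numpy as np

# Sketch: flattening adjacent axes (64, 64) -> (4096,) in a row-major reshape
# preserves element order, so the patched Reshape yields the same data.
x = np.random.rand(12 * 64 * 64 * 2 * 5).astype(np.float32)
five_d = x.reshape(12, 64, 64, 2, 5)   # original 5-D target
four_d = x.reshape(12, 4096, 2, 5)     # patched 4-D target
assert np.array_equal(five_d.reshape(12, 4096, 2, 5), four_d)

Clearing every value_info shape at the end then forces downstream tools to re-infer tensor shapes rather than trip over the now-stale 5-D annotations.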
run_sam_rknn.py
ADDED
@@ -0,0 +1,237 @@
import logging

logging.basicConfig(level=logging.DEBUG)

from copy import deepcopy

import cv2
import numpy as np
from rknnlite.api.rknn_lite import RKNNLite
import onnxruntime
import time

class SegmentAnythingONNXRKNN:
    """Segmentation model using SegmentAnything"""

    def __init__(self, encoder_model_path, decoder_model_path) -> None:
        self.target_size = 1024
        self.input_size = (1024, 1024)

        self.encoder_session = RKNNLite()
        self.encoder_session.load_rknn(encoder_model_path)
        self.encoder_session.init_runtime()

        self.decoder_session = onnxruntime.InferenceSession(
            decoder_model_path, providers=["CPUExecutionProvider"]
        )

    def get_input_points(self, prompt):
        """Get input points"""
        points = []
        labels = []
        for mark in prompt:
            if mark["type"] == "point":
                points.append(mark["data"])
                labels.append(mark["label"])
            elif mark["type"] == "rectangle":
                points.append([mark["data"][0], mark["data"][1]])  # top left
                points.append([mark["data"][2], mark["data"][3]])  # bottom right
                labels.append(2)
                labels.append(3)
        points, labels = np.array(points), np.array(labels)
        return points, labels

    def run_encoder(self, encoder_inputs):
        """Run encoder"""
        start_time = time.time()
        output = self.encoder_session.inference(inputs=[encoder_inputs])
        print(f"Encoder Inference Time (ms): {(time.time() - start_time) * 1000}")
        image_embedding = output[0]
        return image_embedding

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
        """
        Compute the output size given input size and target long side length.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def apply_coords(self, coords: np.ndarray, original_size, target_length):
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def run_decoder(
        self, image_embedding, original_size, transform_matrix, prompt
    ):
        """Run decoder"""
        input_points, input_labels = self.get_input_points(prompt)

        # Add a batch index, concatenate a padding point, and transform.
        onnx_coord = np.concatenate(
            [input_points, np.array([[0.0, 0.0]])], axis=0
        )[None, :, :]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)
        onnx_coord = self.apply_coords(
            onnx_coord, self.input_size, self.target_size
        ).astype(np.float32)

        # Apply the transformation matrix to the coordinates.
        onnx_coord = np.concatenate(
            [
                onnx_coord,
                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
            ],
            axis=2,
        )
        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

        # Create an empty mask input and an indicator for no mask.
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(self.input_size, dtype=np.float32),
        }
        start_time = time.time()
        masks, _, _ = self.decoder_session.run(None, decoder_inputs)
        print(f"Decoder Inference Time (ms): {(time.time() - start_time) * 1000}")

        # Transform the masks back to the original image size.
        inv_transform_matrix = np.linalg.inv(transform_matrix)
        transformed_masks = self.transform_masks(
            masks, original_size, inv_transform_matrix
        )

        return transformed_masks

    def transform_masks(self, masks, original_size, transform_matrix):
        """Transform masks
        Transform the masks back to the original image size.
        """
        output_masks = []
        for batch in range(masks.shape[0]):
            batch_masks = []
            for mask_id in range(masks.shape[1]):
                mask = masks[batch, mask_id]
                mask = cv2.warpAffine(
                    mask,
                    transform_matrix[:2],
                    (original_size[1], original_size[0]),
                    flags=cv2.INTER_LINEAR,
                )
                batch_masks.append(mask)
            output_masks.append(batch_masks)
        return np.array(output_masks)

    def encode(self, cv_image):
        """
        Calculate embedding and metadata for a single image.
        """
        original_size = cv_image.shape[:2]

        # Calculate a transformation matrix to convert to self.input_size
        scale_x = self.input_size[1] / cv_image.shape[1]
        scale_y = self.input_size[0] / cv_image.shape[0]
        scale = min(scale_x, scale_y)
        transform_matrix = np.array(
            [
                [scale, 0, 0],
                [0, scale, 0],
                [0, 0, 1],
            ]
        )
        cv_image = cv2.warpAffine(
            cv_image,
            transform_matrix[:2],
            (self.input_size[1], self.input_size[0]),
            flags=cv2.INTER_LINEAR,
        )

        encoder_inputs = cv_image.astype(np.float32)
        print(encoder_inputs.shape)
        image_embedding = self.run_encoder(encoder_inputs)
        return {
            "image_embedding": image_embedding,
            "original_size": original_size,
            "transform_matrix": transform_matrix,
        }

    def predict_masks(self, embedding, prompt):
        """
        Predict masks for a single image.
        """
        masks = self.run_decoder(
            embedding["image_embedding"],
            embedding["original_size"],
            embedding["transform_matrix"],
            prompt,
        )

        return masks

if __name__ == "__main__":
    encoder_model_path = "sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn"
    decoder_model_path = "sam_vit_b_01ec64.pth.decoder.onnx"
    segmenter = SegmentAnythingONNXRKNN(encoder_model_path, decoder_model_path)

    image = cv2.imread("input.jpg")
    embedding = segmenter.encode(image)
    prompt = [
        {"type": "point", "data": [540, 512], "label": 0},
    ]
    masks = segmenter.predict_masks(embedding, prompt)
    print(masks.shape)

    # Save the masks as a single image.
    mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
    for m in masks[0, :, :, :]:
        mask[m > 0.0] = [255, 0, 0]

    # Blend the image and the mask.
    visualized = cv2.addWeighted(image, 0.5, mask, 0.5, 0)

    # Draw the prompt points and rectangles.
    for p in prompt:
        if p["type"] == "point":
            color = (
                (0, 255, 0) if p["label"] == 1 else (0, 0, 255)
            )  # green for positive, red for negative
            cv2.circle(visualized, (p["data"][0], p["data"][1]), 10, color, -1)
        elif p["type"] == "rectangle":
            cv2.rectangle(
                visualized,
                (p["data"][0], p["data"][1]),
                (p["data"][2], p["data"][3]),
                (0, 255, 0),
                2,
            )

    cv2.imwrite("output.jpg", visualized)
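run_sam_rknn.py runs the ViT encoder on the NPU via RKNNLite and the lightweight SAM decoder on the CPU via onnxruntime. Prompts are lists of marks: points carry a label (1 for foreground, 0 for background), and get_input_points() encodes a rectangle as its top-left/bottom-right corners with labels 2 and 3. Because encode() computes the embedding once, predict_masks() can be re-run with new prompts without touching the NPU again. A hypothetical box-prompt variant of the __main__ demo (coordinates invented for illustration):

import cv2
from run_sam_rknn import SegmentAnythingONNXRKNN  # assuming the file above is on the path

# Hypothetical usage sketch, not part of the commit: reuse one embedding for a
# box prompt instead of the single background point in the demo.
image = cv2.imread("input.jpg")
segmenter = SegmentAnythingONNXRKNN(
    "sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn",
    "sam_vit_b_01ec64.pth.decoder.onnx",
)
embedding = segmenter.encode(image)  # one NPU pass
prompt = [{"type": "rectangle", "data": [120, 80, 620, 540]}]
masks = segmenter.predict_masks(embedding, prompt)  # CPU-only decoder pass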
sam_vit_b_01ec64.pth.decoder.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ba3769ea1c7e4b9a7d3c01715ffdbc4aa9d351d793d8be95575e71c9f552424b
size 16496903
sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf85560466bdf24a2598694a08b71a584db6f692b12e7047012a8daac90d3706
size 238909266