happyme531 committed
Commit ea6ccf2
Parent: bff24fd

Upload 7 files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn filter=lfs diff=lfs merge=lfs -text
convert_encoder.py ADDED
@@ -0,0 +1,43 @@
+ import os
+ import urllib
+ import traceback
+ import time
+ import sys
+ import numpy as np
+ import cv2
+ from rknn.api import RKNN
+ 
+ ONNX_MODEL = "sam_vit_b_01ec64.pth.encoder.patched.onnx"
+ RKNN_MODEL = "sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn"  # must match the name run_sam_rknn.py loads
+ rknn = RKNN(verbose=True)
+ 
+ # Pre-process config
+ print('--> config model')
+ rknn.config(target_platform='rk3588', single_core_mode=True)
+ print('done')
+ 
+ # Load model
+ print("--> Loading model")
+ ret = rknn.load_onnx(
+     model=ONNX_MODEL, inputs=["input_image"], input_size_list=[[1024, 1024, 3]]
+ )
+ if ret != 0:
+     print("Load model failed!")
+     exit(ret)
+ print("done")
+ 
+ # Build model (float, no quantization)
+ print('--> Building model')
+ ret = rknn.build(do_quantization=False)
+ if ret != 0:
+     print('Build model failed!')
+     exit(ret)
+ print('done')
+ 
+ # Export rknn model
+ print('--> Export rknn model')
+ ret = rknn.export_rknn(RKNN_MODEL)
+ if ret != 0:
+     print('Export rknn model failed!')
+     exit(ret)
+ print('done')
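
As a quick sanity check after conversion, a minimal on-device smoke test (a sketch, assuming an RK3588 board with rknnlite installed; the RKNNLite calls mirror the ones used in run_sam_rknn.py below) can confirm the exported model loads and produces an embedding before wiring it into the full pipeline:

import numpy as np
from rknnlite.api.rknn_lite import RKNNLite

rknn_lite = RKNNLite()
assert rknn_lite.load_rknn("sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn") == 0
assert rknn_lite.init_runtime() == 0

# Dummy HWC float32 input, matching the converter's input_size_list
dummy = np.zeros((1024, 1024, 3), dtype=np.float32)
outputs = rknn_lite.inference(inputs=[dummy])
print(outputs[0].shape)  # the image embedding tensor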
input.jpg ADDED
output.jpg ADDED
patch_graph.py ADDED
@@ -0,0 +1,55 @@
+ import onnx_graphsurgeon as gs
+ import onnx
+ import numpy as np
+ 
+ # Load the ONNX model
+ graph = gs.import_onnx(onnx.load("check3_fuse_ops.onnx"))
+ 
+ count = 0
+ # Iterate through all nodes in the graph
+ for node in graph.nodes:
+     # Check if the node is a Reshape operator
+     if node.op == 'Reshape':
+         # Get the shape input of the Reshape node
+         shape_input = node.inputs[1]
+ 
+         # Check if the shape input is a constant (which it should be for static reshapes)
+         if isinstance(shape_input, gs.Constant):
+             current_shape = shape_input.values
+ 
+             # Check if it's a 5D reshape with the target shape [12, 64, 64, ..., ...]
+             if len(current_shape) == 5 and current_shape[0] == 12 and current_shape[1] == 64 and current_shape[2] == 64:
+                 # Modify the shape to [12, 4096, ..., ...]
+                 new_shape = np.array([12, 4096, current_shape[3], current_shape[4]], dtype=np.int64)
+                 print(f"Patched {current_shape} -> {new_shape}")
+ 
+                 # Update the shape input with the new shape
+                 shape_input.values = new_shape
+                 count += 1
+                 # print(f"Patched {node}")
+ 
+             # Check if it's a 5D reshape with the target shape [300, 14, 14, ..., ...]
+             if len(current_shape) == 5 and current_shape[0] == 300 and current_shape[1] == 14 and current_shape[2] == 14:
+                 # Modify the shape to [300, 196, ..., ...]
+                 new_shape = np.array([300, 196, current_shape[3], current_shape[4]], dtype=np.int64)
+                 print(f"Patched {current_shape} -> {new_shape}")
+ 
+                 # Update the shape input with the new shape
+                 shape_input.values = new_shape
+                 count += 1
+                 # print(f"Patched {node}")
+ 
+ graph.cleanup().toposort()
+ print(f"Patched {count} nodes.")
+ 
+ model = gs.export_onnx(graph)
+ 
+ # Delete old shape information from the model
+ for value_info in model.graph.value_info:
+     value_info.type.tensor_type.ClearField('shape')
+ 
+ # Save the modified model
+ onnx.save(model, "sam_vit_b_01ec64.pth.encoder.patched.onnx")
+ 
+ print("Saved as 'sam_vit_b_01ec64.pth.encoder.patched.onnx'")
run_sam_rknn.py ADDED
@@ -0,0 +1,237 @@
+ import logging
+ 
+ logging.basicConfig(level=logging.DEBUG)
+ 
+ from copy import deepcopy
+ 
+ import cv2
+ import numpy as np
+ from rknnlite.api.rknn_lite import RKNNLite
+ import onnxruntime
+ import time
+ 
+ class SegmentAnythingONNXRKNN:
+     """Segmentation model using SegmentAnything"""
+ 
+     def __init__(self, encoder_model_path, decoder_model_path) -> None:
+         self.target_size = 1024
+         self.input_size = (1024, 1024)
+ 
+         self.encoder_session = RKNNLite()
+         self.encoder_session.load_rknn(encoder_model_path)
+         self.encoder_session.init_runtime()
+ 
+         self.decoder_session = onnxruntime.InferenceSession(
+             decoder_model_path, providers=["CPUExecutionProvider"]
+         )
+ 
+     def get_input_points(self, prompt):
+         """Get input points"""
+         points = []
+         labels = []
+         for mark in prompt:
+             if mark["type"] == "point":
+                 points.append(mark["data"])
+                 labels.append(mark["label"])
+             elif mark["type"] == "rectangle":
+                 points.append([mark["data"][0], mark["data"][1]])  # top left
+                 points.append(
+                     [mark["data"][2], mark["data"][3]]
+                 )  # bottom right
+                 labels.append(2)
+                 labels.append(3)
+         points, labels = np.array(points), np.array(labels)
+         return points, labels
+ 
+     def run_encoder(self, encoder_inputs):
+         """Run encoder"""
+         # output = self.encoder_session.run(None, encoder_inputs)
+         start_time = time.time()
+         output = self.encoder_session.inference(inputs=[encoder_inputs])
+         print(f"Encoder Inference Time (ms): {(time.time() - start_time) * 1000}")
+         image_embedding = output[0]
+         return image_embedding
+ 
+     @staticmethod
+     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
+         """
+         Compute the output size given input size and target long side length.
+         """
+         scale = long_side_length * 1.0 / max(oldh, oldw)
+         newh, neww = oldh * scale, oldw * scale
+         neww = int(neww + 0.5)
+         newh = int(newh + 0.5)
+         return (newh, neww)
+ 
+     def apply_coords(self, coords: np.ndarray, original_size, target_length):
+         """
+         Expects a numpy array of length 2 in the final dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], target_length
+         )
+         coords = deepcopy(coords).astype(float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+ 
+     def run_decoder(
+         self, image_embedding, original_size, transform_matrix, prompt
+     ):
+         """Run decoder"""
+         input_points, input_labels = self.get_input_points(prompt)
+ 
+         # Add a batch index, concatenate a padding point, and transform.
+         onnx_coord = np.concatenate(
+             [input_points, np.array([[0.0, 0.0]])], axis=0
+         )[None, :, :]
+         onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
+             None, :
+         ].astype(np.float32)
+         onnx_coord = self.apply_coords(
+             onnx_coord, self.input_size, self.target_size
+         ).astype(np.float32)
+ 
+         # Apply the transformation matrix to the coordinates.
+         onnx_coord = np.concatenate(
+             [
+                 onnx_coord,
+                 np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
+             ],
+             axis=2,
+         )
+         onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
+         onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
+ 
+         # Create an empty mask input and an indicator for no mask.
+         onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
+         onnx_has_mask_input = np.zeros(1, dtype=np.float32)
+ 
+         decoder_inputs = {
+             "image_embeddings": image_embedding,
+             "point_coords": onnx_coord,
+             "point_labels": onnx_label,
+             "mask_input": onnx_mask_input,
+             "has_mask_input": onnx_has_mask_input,
+             "orig_im_size": np.array(self.input_size, dtype=np.float32),
+         }
+         start_time = time.time()
+         masks, _, _ = self.decoder_session.run(None, decoder_inputs)
+         # masks, _, _ = self.decoder_session.run(inputs=[
+         #     image_embedding, onnx_coord, onnx_label, onnx_mask_input, onnx_has_mask_input, np.array(self.input_size, dtype=np.float32)
+         # ])
+         print(f"Decoder Inference Time (ms): {(time.time() - start_time) * 1000}")
+         # Transform the masks back to the original image size.
+         inv_transform_matrix = np.linalg.inv(transform_matrix)
+         transformed_masks = self.transform_masks(
+             masks, original_size, inv_transform_matrix
+         )
+ 
+         return transformed_masks
+ 
+     def transform_masks(self, masks, original_size, transform_matrix):
+         """Transform masks
+         Transform the masks back to the original image size.
+         """
+         output_masks = []
+         for batch in range(masks.shape[0]):
+             batch_masks = []
+             for mask_id in range(masks.shape[1]):
+                 mask = masks[batch, mask_id]
+                 mask = cv2.warpAffine(
+                     mask,
+                     transform_matrix[:2],
+                     (original_size[1], original_size[0]),
+                     flags=cv2.INTER_LINEAR,
+                 )
+                 batch_masks.append(mask)
+             output_masks.append(batch_masks)
+         return np.array(output_masks)
+ 
+     def encode(self, cv_image):
+         """
+         Calculate embedding and metadata for a single image.
+         """
+         original_size = cv_image.shape[:2]
+ 
+         # Calculate a transformation matrix to convert to self.input_size
+         scale_x = self.input_size[1] / cv_image.shape[1]
+         scale_y = self.input_size[0] / cv_image.shape[0]
+         scale = min(scale_x, scale_y)
+         transform_matrix = np.array(
+             [
+                 [scale, 0, 0],
+                 [0, scale, 0],
+                 [0, 0, 1],
+             ]
+         )
+         cv_image = cv2.warpAffine(
+             cv_image,
+             transform_matrix[:2],
+             (self.input_size[1], self.input_size[0]),
+             flags=cv2.INTER_LINEAR,
+         )
+ 
+         encoder_inputs = cv_image.astype(np.float32)
+         print(encoder_inputs.shape)
+         image_embedding = self.run_encoder(encoder_inputs)
+         return {
+             "image_embedding": image_embedding,
+             "original_size": original_size,
+             "transform_matrix": transform_matrix,
+         }
+ 
+     def predict_masks(self, embedding, prompt):
+         """
+         Predict masks for a single image.
+         """
+         masks = self.run_decoder(
+             embedding["image_embedding"],
+             embedding["original_size"],
+             embedding["transform_matrix"],
+             prompt,
+         )
+ 
+         return masks
+ 
+ if __name__ == "__main__":
+     encoder_model_path = "sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn"
+     decoder_model_path = "sam_vit_b_01ec64.pth.decoder.onnx"
+     segmenter = SegmentAnythingONNXRKNN(encoder_model_path, decoder_model_path)
+ 
+     image = cv2.imread("input.jpg")
+     embedding = segmenter.encode(image)
+     prompt = [
+         {"type": "point", "data": [540, 512], "label": 0},
+     ]
+     masks = segmenter.predict_masks(embedding, prompt)
+     print(masks.shape)
+ 
+     # Save the masks as a single image.
+     mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
+     for m in masks[0, :, :, :]:
+         mask[m > 0.0] = [255, 0, 0]
+ 
+     # Blend the image and the mask.
+     visualized = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
+ 
+     # Draw the prompt points and rectangles.
+     for p in prompt:
+         if p["type"] == "point":
+             color = (
+                 (0, 255, 0) if p["label"] == 1 else (0, 0, 255)
+             )  # green for positive, red for negative
+             cv2.circle(visualized, (p["data"][0], p["data"][1]), 10, color, -1)
+         elif p["type"] == "rectangle":
+             cv2.rectangle(
+                 visualized,
+                 (p["data"][0], p["data"][1]),
+                 (p["data"][2], p["data"][3]),
+                 (0, 255, 0),
+                 2,
+             )
+ 
+     cv2.imwrite("output.jpg", visualized)
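
The demo drives the decoder with a single point prompt, but get_input_points() also accepts box prompts, mapping a rectangle to its two corners with SAM's corner labels 2 and 3. A usage sketch, picking up segmenter and embedding from the __main__ block above (the coordinates are illustrative, not from the commit):

prompt = [
    {"type": "rectangle", "data": [100, 100, 600, 600]},  # x1, y1, x2, y2
]
masks = segmenter.predict_masks(embedding, prompt)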
sam_vit_b_01ec64.pth.decoder.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba3769ea1c7e4b9a7d3c01715ffdbc4aa9d351d793d8be95575e71c9f552424b
+ size 16496903
sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf85560466bdf24a2598694a08b71a584db6f692b12e7047012a8daac90d3706
+ size 238909266