# segment-anything-rknn2 / run_sam_rknn.py
# happyme531 - "Fix label" - 35aa782 (verified)
import logging
logging.basicConfig(level=logging.DEBUG)
from copy import deepcopy
import cv2
import numpy as np
from rknnlite.api.rknn_lite import RKNNLite
import onnxruntime
import time
class SegmentAnythingONNXRKNN:
    """Segment Anything pipeline: RKNN image encoder + ONNX Runtime decoder.

    The encoder model is loaded through ``RKNNLite`` and the decoder through
    an ``onnxruntime.InferenceSession`` restricted to the CPU execution
    provider. Typical use is one ``encode(image)`` call per image followed by
    one or more ``predict_masks(embedding, prompt)`` calls.
    """

    def __init__(self, encoder_model_path, decoder_model_path) -> None:
        """Load the encoder and decoder models.

        Args:
            encoder_model_path: Path handed to ``RKNNLite.load_rknn``.
            decoder_model_path: Path handed to ``onnxruntime.InferenceSession``.

        Raises:
            RuntimeError: If RKNNLite reports a non-zero status while loading
                the encoder model or initialising its runtime.
        """
        self.target_size = 1024
        self.input_size = (1024, 1024)

        self.encoder_session = RKNNLite()
        # RKNNLite reports failure through its return code instead of raising,
        # so check it here rather than failing later with an obscure error.
        ret = self.encoder_session.load_rknn(encoder_model_path)
        if ret != 0:
            raise RuntimeError(
                f"Failed to load RKNN model {encoder_model_path!r} (code {ret})"
            )
        ret = self.encoder_session.init_runtime()
        if ret != 0:
            raise RuntimeError(
                f"Failed to initialise RKNN runtime (code {ret})"
            )

        self.decoder_session = onnxruntime.InferenceSession(
            decoder_model_path, providers=["CPUExecutionProvider"]
        )

    def get_input_points(self, prompt):
        """Convert prompt marks into point coordinates and labels.

        Each mark is a dict with a ``"type"`` key:

        * ``"point"``: contributes ``mark["data"]`` with ``mark["label"]``.
        * ``"rectangle"``: contributes the two corners ``data[0:2]`` and
          ``data[2:4]`` with labels 2 and 3 respectively.

        Marks of any other type are ignored.

        Args:
            prompt: Iterable of mark dicts.

        Returns:
            Tuple ``(points, labels)`` of numpy arrays. When no mark
            contributes a point, ``points`` has shape ``(0, 2)`` so it can
            still be concatenated with the padding point in ``run_decoder``.
        """
        points = []
        labels = []
        for mark in prompt:
            kind = mark["type"]
            if kind == "point":
                points.append(mark["data"])
                labels.append(mark["label"])
            elif kind == "rectangle":
                data = mark["data"]
                points.append([data[0], data[1]])  # top left
                points.append([data[2], data[3]])  # bottom right
                labels.extend((2, 3))

        # np.array([]) would be 1-D; keep the (N, 2) layout for N == 0 too.
        points = np.array(points) if points else np.empty((0, 2))
        return points, np.array(labels)

    def run_encoder(self, encoder_inputs):
        """Run the RKNN encoder and return its first output.

        Args:
            encoder_inputs: Preprocessed image array passed to
                ``RKNNLite.inference`` as the single model input.

        Returns:
            The first element of the inference result (the image embedding).

        Raises:
            RuntimeError: If the inference call returns no result.
        """
        start_time = time.perf_counter()
        output = self.encoder_session.inference(inputs=[encoder_inputs])
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        print(f"Encoder Inference Time (ms): {elapsed_ms}")

        # NOTE(review): RKNNLite appears to hand back None when inference
        # fails; fail with a clear message instead of a TypeError on indexing.
        if output is None:
            raise RuntimeError("RKNN encoder inference returned no output")
        return output[0]

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
        """Compute the resized (height, width) for a target long-side length.

        The scale is chosen so the longer side becomes ``long_side_length``;
        both sides are rounded to the nearest integer.

        Returns:
            Tuple ``(newh, neww)`` of ints.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh = int(oldh * scale + 0.5)
        neww = int(oldw * scale + 0.5)
        return (newh, neww)

    def apply_coords(self, coords: np.ndarray, original_size, target_length):
        """Rescale coordinates from the original image to the resized image.

        Args:
            coords: Array whose final dimension has length 2, holding
                ``(x, y)`` pairs. The input array is not modified.
            original_size: Image size in ``(H, W)`` format.
            target_length: Long-side length of the resized image.

        Returns:
            A new float array of the same shape with scaled coordinates.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, target_length)
        # astype() already returns a fresh copy, so the caller's array is safe.
        coords = coords.astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def run_decoder(
        self, image_embedding, original_size, transform_matrix, prompt
    ):
        """Run the ONNX decoder for a prompt and return masks at original size.

        Args:
            image_embedding: Encoder output for the image.
            original_size: ``(H, W)`` of the original image.
            transform_matrix: 3x3 matrix mapping original-image coordinates to
                encoder-input coordinates (as built by ``encode``).
            prompt: Iterable of mark dicts, see ``get_input_points``.

        Returns:
            Array of masks indexed ``[batch, mask]``, each warped back to
            ``original_size``.
        """
        input_points, input_labels = self.get_input_points(prompt)

        # Add a batch index, concatenate a padding point, and transform.
        onnx_coord = np.concatenate(
            [input_points, np.array([[0.0, 0.0]])], axis=0
        )[None, :, :]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)

        onnx_coord = self.apply_coords(
            onnx_coord, self.input_size, self.target_size
        ).astype(np.float32)

        # Apply the transformation matrix: append a homogeneous 1 to every
        # point, multiply, then drop the homogeneous component again.
        homogeneous = np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32)
        onnx_coord = np.concatenate([onnx_coord, homogeneous], axis=2)
        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)

        # Create an empty mask input and an indicator for no mask.
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            "orig_im_size": np.array(self.input_size, dtype=np.float32),
        }

        start_time = time.perf_counter()
        masks, _, _ = self.decoder_session.run(None, decoder_inputs)
        elapsed_ms = (time.perf_counter() - start_time) * 1000
        print(f"Decoder Inference Time (ms): {elapsed_ms}")

        # Transform the masks back to the original image size.
        inv_transform_matrix = np.linalg.inv(transform_matrix)
        return self.transform_masks(masks, original_size, inv_transform_matrix)

    def transform_masks(self, masks, original_size, transform_matrix):
        """Warp every mask back to the original image size.

        Args:
            masks: Array indexed ``[batch, mask, row, col]``.
            original_size: ``(H, W)`` of the output masks.
            transform_matrix: 3x3 matrix; its first two rows are used as the
                affine transform passed to ``cv2.warpAffine``.

        Returns:
            Array of warped masks indexed ``[batch, mask]``.
        """
        # cv2 expects the destination size as (width, height).
        dsize = (original_size[1], original_size[0])
        affine = transform_matrix[:2]

        output_masks = [
            [
                cv2.warpAffine(
                    masks[batch, mask_id],
                    affine,
                    dsize,
                    flags=cv2.INTER_LINEAR,
                )
                for mask_id in range(masks.shape[1])
            ]
            for batch in range(masks.shape[0])
        ]
        return np.array(output_masks)

    def encode(self, cv_image):
        """Calculate the embedding and metadata for a single image.

        The image is scaled uniformly so that it fits inside
        ``self.input_size`` and warped onto a canvas of exactly that size
        before being passed to the encoder as float32.

        Args:
            cv_image: Image array whose first two dimensions are (H, W).

        Returns:
            Dict with keys ``"image_embedding"``, ``"original_size"`` and
            ``"transform_matrix"`` as consumed by ``predict_masks``.

        Raises:
            ValueError: If ``cv_image`` is ``None`` (e.g. ``cv2.imread``
                could not read the file).
        """
        if cv_image is None:
            raise ValueError("cv_image is None; was the image file readable?")

        original_size = cv_image.shape[:2]

        # Calculate a transformation matrix to convert to self.input_size.
        # Using the smaller of the two ratios preserves the aspect ratio.
        scale_x = self.input_size[1] / cv_image.shape[1]
        scale_y = self.input_size[0] / cv_image.shape[0]
        scale = min(scale_x, scale_y)
        transform_matrix = np.array(
            [
                [scale, 0, 0],
                [0, scale, 0],
                [0, 0, 1],
            ]
        )
        cv_image = cv2.warpAffine(
            cv_image,
            transform_matrix[:2],
            (self.input_size[1], self.input_size[0]),
            flags=cv2.INTER_LINEAR,
        )

        encoder_inputs = cv_image.astype(np.float32)
        print(encoder_inputs.shape)
        image_embedding = self.run_encoder(encoder_inputs)

        return {
            "image_embedding": image_embedding,
            "original_size": original_size,
            "transform_matrix": transform_matrix,
        }

    def predict_masks(self, embedding, prompt):
        """Predict masks for a single image.

        Args:
            embedding: Dict returned by ``encode``.
            prompt: Iterable of mark dicts, see ``get_input_points``.

        Returns:
            The masks produced by ``run_decoder``.
        """
        return self.run_decoder(
            embedding["image_embedding"],
            embedding["original_size"],
            embedding["transform_matrix"],
            prompt,
        )
if __name__ == "__main__":
    # Demo: segment "input.jpg" around a single positive point prompt and
    # write a red mask overlay with the prompt drawn on top to "output.jpg".
    encoder_model_path = "sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn"
    decoder_model_path = "sam_vit_b_01ec64.pth.decoder.onnx"
    input_path = "input.jpg"
    output_path = "output.jpg"

    segmenter = SegmentAnythingONNXRKNN(encoder_model_path, decoder_model_path)

    # cv2.imread signals an unreadable file by returning None, not by raising.
    image = cv2.imread(input_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {input_path}")

    embedding = segmenter.encode(image)

    prompt = [
        {"type": "point", "data": [540, 512], "label": 1},
    ]
    masks = segmenter.predict_masks(embedding, prompt)
    print(masks.shape)

    # Save the masks as a single image: every pixel that is positive in any
    # mask of the first batch entry is painted with the overlay colour.
    mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
    for m in masks[0, :, :, :]:
        mask[m > 0.0] = [255, 0, 0]

    # Blend image and mask at equal weight.
    visualized = cv2.addWeighted(image, 0.5, mask, 0.5, 0)

    # Draw the prompt points and rectangles.
    for p in prompt:
        if p["type"] == "point":
            color = (
                (0, 255, 0) if p["label"] == 1 else (0, 0, 255)
            )  # green for positive, red for negative
            cv2.circle(visualized, (p["data"][0], p["data"][1]), 10, color, -1)
        elif p["type"] == "rectangle":
            cv2.rectangle(
                visualized,
                (p["data"][0], p["data"][1]),
                (p["data"][2], p["data"][3]),
                (0, 255, 0),
                2,
            )

    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(output_path, visualized):
        raise RuntimeError(f"Could not write image: {output_path}")