import logging
import time
from copy import deepcopy

import cv2
import numpy as np
import onnxruntime
from rknnlite.api.rknn_lite import RKNNLite

logging.basicConfig(level=logging.DEBUG)


class SegmentAnythingONNXRKNN:
    """Segment Anything model with an RKNN image encoder and an ONNX prompt decoder."""

    def __init__(self, encoder_model_path, decoder_model_path) -> None:
        self.target_size = 1024
        self.input_size = (1024, 1024)

        # The heavy ViT image encoder runs on the Rockchip NPU via RKNNLite.
        self.encoder_session = RKNNLite()
        self.encoder_session.load_rknn(encoder_model_path)
        self.encoder_session.init_runtime()

        # The lightweight prompt decoder runs on the CPU via ONNX Runtime.
        self.decoder_session = onnxruntime.InferenceSession(
            decoder_model_path, providers=["CPUExecutionProvider"]
        )
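
    # Prompts are lists of marks. Point labels follow the SAM convention:
    # 1 = foreground point, 0 = background point; rectangle corners are
    # encoded as two extra points with the special labels 2 (top-left) and
    # 3 (bottom-right). Example (hypothetical coordinates):
    #   [{"type": "point", "data": [540, 512], "label": 1},
    #    {"type": "rectangle", "data": [100, 100, 400, 400]}]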
    def get_input_points(self, prompt):
        """Collect point coordinates and labels from a prompt."""
        points = []
        labels = []
        for mark in prompt:
            if mark["type"] == "point":
                points.append(mark["data"])
                labels.append(mark["label"])
            elif mark["type"] == "rectangle":
                points.append([mark["data"][0], mark["data"][1]])
                points.append([mark["data"][2], mark["data"][3]])
                labels.append(2)
                labels.append(3)
        points, labels = np.array(points), np.array(labels)
        return points, labels

    def run_encoder(self, encoder_inputs):
        """Run the RKNN image encoder and return the image embedding."""
        start_time = time.time()
        output = self.encoder_session.inference(inputs=[encoder_inputs])
        print(f"Encoder Inference Time (ms): {(time.time() - start_time) * 1000}")
        image_embedding = output[0]
        return image_embedding
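
    # Worked example: a 1080x1920 (H x W) image with long_side_length=1024
    # scales by 1024/1920 ~ 0.533, giving (newh, neww) = (576, 1024).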
    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int):
        """Compute the output size given an input size and target long-side length."""
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)

    def apply_coords(self, coords: np.ndarray, original_size, target_length):
        """
        Expects a numpy array of length 2 in the final dimension. Requires the
        original image size in (H, W) format.
        """
        old_h, old_w = original_size
        new_h, new_w = self.get_preprocess_shape(
            original_size[0], original_size[1], target_length
        )
        coords = deepcopy(coords).astype(float)
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def run_decoder(self, image_embedding, original_size, transform_matrix, prompt):
        """Run the ONNX prompt decoder and return masks in original image space."""
        input_points, input_labels = self.get_input_points(prompt)
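
        # The SAM ONNX decoder expects a padding point with label -1 appended
        # to the prompt (the exporter's convention when no box is supplied),
        # so a dummy (0, 0) point is concatenated here.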
        onnx_coord = np.concatenate(
            [input_points, np.array([[0.0, 0.0]])], axis=0
        )[None, :, :]
        onnx_label = np.concatenate([input_labels, np.array([-1])], axis=0)[
            None, :
        ].astype(np.float32)
        onnx_coord = self.apply_coords(
            onnx_coord, self.input_size, self.target_size
        ).astype(np.float32)
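
        # Map prompt coordinates from original image space into the 1024x1024
        # letterboxed input space by appending a homogeneous coordinate and
        # applying the same affine transform used on the image.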
        onnx_coord = np.concatenate(
            [
                onnx_coord,
                np.ones((1, onnx_coord.shape[1], 1), dtype=np.float32),
            ],
            axis=2,
        )
        onnx_coord = np.matmul(onnx_coord, transform_matrix.T)
        onnx_coord = onnx_coord[:, :, :2].astype(np.float32)
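
        # No previous low-resolution mask is supplied, so mask_input is all
        # zeros and has_mask_input is 0.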
        onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
        onnx_has_mask_input = np.zeros(1, dtype=np.float32)

        decoder_inputs = {
            "image_embeddings": image_embedding,
            "point_coords": onnx_coord,
            "point_labels": onnx_label,
            "mask_input": onnx_mask_input,
            "has_mask_input": onnx_has_mask_input,
            # Masks are decoded at the 1024x1024 input size and warped back to
            # the true original size afterwards via the inverse transform.
            "orig_im_size": np.array(self.input_size, dtype=np.float32),
        }
        start_time = time.time()
        masks, _, _ = self.decoder_session.run(None, decoder_inputs)
        print(f"Decoder Inference Time (ms): {(time.time() - start_time) * 1000}")

        inv_transform_matrix = np.linalg.inv(transform_matrix)
        transformed_masks = self.transform_masks(
            masks, original_size, inv_transform_matrix
        )

        return transformed_masks

    def transform_masks(self, masks, original_size, transform_matrix):
        """Warp the masks back to the original image size."""
        output_masks = []
        for batch in range(masks.shape[0]):
            batch_masks = []
            for mask_id in range(masks.shape[1]):
                mask = masks[batch, mask_id]
                mask = cv2.warpAffine(
                    mask,
                    transform_matrix[:2],
                    (original_size[1], original_size[0]),
                    flags=cv2.INTER_LINEAR,
                )
                batch_masks.append(mask)
            output_masks.append(batch_masks)
        return np.array(output_masks)

    def encode(self, cv_image):
        """Calculate the embedding and metadata for a single image."""
        original_size = cv_image.shape[:2]
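
        # Aspect-preserving letterbox: scale the image so its long side fits
        # the 1024x1024 input, anchored at the top-left; warpAffine zero-pads
        # the bottom/right, matching SAM's preprocessing.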
        scale_x = self.input_size[1] / cv_image.shape[1]
        scale_y = self.input_size[0] / cv_image.shape[0]
        scale = min(scale_x, scale_y)
        transform_matrix = np.array(
            [
                [scale, 0, 0],
                [0, scale, 0],
                [0, 0, 1],
            ]
        )
        cv_image = cv2.warpAffine(
            cv_image,
            transform_matrix[:2],
            (self.input_size[1], self.input_size[0]),
            flags=cv2.INTER_LINEAR,
        )
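
        # The raw HWC float image is fed to the encoder as-is; this assumes
        # any mean/std normalization was baked into the RKNN model at
        # conversion time. Note that cv2.imread yields BGR; whether the
        # converted encoder expects BGR or RGB depends on how it was exported.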
        encoder_inputs = cv_image.astype(np.float32)
        logging.debug("Encoder input shape: %s", encoder_inputs.shape)
        image_embedding = self.run_encoder(encoder_inputs)
        return {
            "image_embedding": image_embedding,
            "original_size": original_size,
            "transform_matrix": transform_matrix,
        }

    def predict_masks(self, embedding, prompt):
        """Predict masks for a single image."""
        masks = self.run_decoder(
            embedding["image_embedding"],
            embedding["original_size"],
            embedding["transform_matrix"],
            prompt,
        )
        return masks


if __name__ == "__main__":
    encoder_model_path = "sam_vit_b_01ec64.pth.encoder.patched.onnx.rknn"
    decoder_model_path = "sam_vit_b_01ec64.pth.decoder.onnx"
    segmenter = SegmentAnythingONNXRKNN(encoder_model_path, decoder_model_path)

    image = cv2.imread("input.jpg")
    embedding = segmenter.encode(image)
    prompt = [
        {"type": "point", "data": [540, 512], "label": 1},
    ]
    masks = segmenter.predict_masks(embedding, prompt)
    print(masks.shape)
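
    # masks has shape (1, num_masks, H, W) in original image space; paint
    # every pixel any predicted mask covers (logit > 0) blue in BGR.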
    mask = np.zeros((masks.shape[2], masks.shape[3], 3), dtype=np.uint8)
    for m in masks[0, :, :, :]:
        mask[m > 0.0] = [255, 0, 0]

    visualized = cv2.addWeighted(image, 0.5, mask, 0.5, 0)
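
    # Draw the prompts: green for positive points, red for negative, and a
    # green rectangle for box prompts.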
    for p in prompt:
        if p["type"] == "point":
            color = (0, 255, 0) if p["label"] == 1 else (0, 0, 255)
            cv2.circle(visualized, (p["data"][0], p["data"][1]), 10, color, -1)
        elif p["type"] == "rectangle":
            cv2.rectangle(
                visualized,
                (p["data"][0], p["data"][1]),
                (p["data"][2], p["data"][3]),
                (0, 255, 0),
                2,
            )

    cv2.imwrite("output.jpg", visualized)