robgonsalves's picture
add code example
3b5fa9b verified
metadata
license: mit

Segment Anything 8-Bit ONNX

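How to run (requires `onnxruntime`, `numpy`, `Pillow` and `matplotlib`, with `sam_encoder.onnx` and `sam_decoder.onnx` in the working directory):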

import onnxruntime as ort
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt

# Side length (pixels) of the square input the SAM image encoder consumes.
ENCODER_INPUT_SIZE = 1024

# Path to the image file
image_path = "example.png"

# Load the image; the file handle is closed once the RGB copy has been made.
with Image.open(image_path) as source_image:
    image = source_image.convert("RGB")
orig_width, orig_height = image.size

# SAM-style preprocessing: resize so the LONGEST side becomes 1024 px while
# keeping the aspect ratio (bilinear, rounding half up), normalize, then
# zero-pad bottom/right to 1024x1024. Padding alone is not enough: images
# larger than 1024 px would give negative pad widths (np.pad raises), and
# smaller images would be encoded at a different scale than the one the
# decoder assumes when it maps masks back using orig_im_size.
resize_scale = ENCODER_INPUT_SIZE / max(orig_width, orig_height)
resized_width = int(orig_width * resize_scale + 0.5)
resized_height = int(orig_height * resize_scale + 0.5)
resized_image = image.resize((resized_width, resized_height), Image.BILINEAR)

input_tensor = np.array(resized_image)
mean = np.array([123.675, 116.28, 103.53])
std = np.array([58.395, 57.12, 57.375])
input_tensor = (input_tensor - mean) / std
# HWC -> NCHW with a leading batch dimension of 1.
input_tensor = input_tensor.transpose(2, 0, 1)[None, :, :, :].astype(np.float32)

# Pad input tensor to 1024x1024 (bottom/right only, so pixel coordinates of
# the resized image are unchanged). Both pads are >= 0 after the resize.
pad_height = ENCODER_INPUT_SIZE - input_tensor.shape[2]
pad_width = ENCODER_INPUT_SIZE - input_tensor.shape[3]
input_tensor = np.pad(input_tensor, ((0, 0), (0, 0), (0, pad_height), (0, pad_width)))

# Load the encoder model and run inference
encoder = ort.InferenceSession("sam_encoder.onnx")
embeddings = encoder.run(None, {"images": input_tensor})[0]

# Choose a point (e.g., x=150, y=100) in the original image
point = [150, 100]
point = np.array([[point]])  # shape (1, 1, 2)

# point_coords and point_labels must describe the same number of points.
# Following SAM's ONNX example, an extra padding point (0, 0) with label -1
# is appended when no box prompt is given.
coords = np.concatenate([point, np.zeros((1, 1, 2))], axis=1)  # (1, 2, 2), float64
# Map into the resized frame using the same per-axis factors applied to the
# image (the aspect ratio was preserved, so both are ~resize_scale). Padding
# does not move the origin, so no offset is needed; (0, 0) stays (0, 0).
coords[..., 0] = coords[..., 0] * (resized_width / orig_width)
coords[..., 1] = coords[..., 1] * (resized_height / orig_height)
onnx_coord = coords.astype("float32")

# Prepare inputs for the decoder
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)  # no prior mask
# Per SAM's ONNX export, has_mask_input == 0 means "no mask input".
onnx_has_mask_input = np.zeros(1, dtype=np.float32)
# One label per coordinate: 1 for the chosen point, -1 for the padding point.
onnx_label = np.array([1, -1]).astype(np.float32)[None, :]

# Run the prompt decoder on the image embeddings and prompt tensors. The
# session yields three outputs; only the first is kept (as masks_output).
orig_im_size = np.array([orig_height, orig_width], dtype=np.float32)  # (height, width)
decoder = ort.InferenceSession("sam_decoder.onnx")
decoder_feeds = {
    "image_embeddings": embeddings,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": orig_im_size,
}
masks_output, _, _ = decoder.run(None, decoder_feeds)

# Binarize the leading [0, 0] slice of the decoder output into a uint8 image:
# strictly positive values become 255, everything else (including NaN) 0.
mask = np.where(masks_output[0, 0] > 0, 255, 0).astype(np.uint8)