# ONNX
# Zhang-Yang-Sustech
# Add multi-points input, foreground/background points input and box input to EfficientSAM model (#291)
# bf92df0
import argparse
import numpy as np
import cv2 as cv
from efficientSAM import EfficientSAM
# Check OpenCV version
def opencv_python_version(str_version):
    """Convert a dotted version string into a tuple of ints for comparison.

    Only the leading digits of each dot-separated component are used, and
    parsing stops at the first component that carries a non-numeric suffix.
    Versions such as "4.10.0-dev" or "4.11.0-pre" therefore become
    (4, 10, 0) and (4, 11, 0) instead of raising ValueError.

    Args:
        str_version (str): Version string, e.g. ``cv.__version__``.

    Returns:
        tuple: The numeric components, e.g. ``(4, 10, 0)``.
    """
    components = []
    for part in str_version.strip().split("."):
        # Length of the leading run of ASCII digits in this component.
        digit_count = len(part) - len(part.lstrip("0123456789"))
        digits = part[:digit_count]
        if not digits:
            break
        components.append(int(digits))
        if digits != part:
            # A suffix such as "-dev" ends the numeric part of the version.
            break
    return tuple(components)
# Stop at import time when the installed OpenCV is older than 4.10.0.
assert opencv_python_version("4.10.0") <= opencv_python_version(cv.__version__), \
    "Please install latest opencv-python for benchmark: python3 -m pip install --upgrade opencv-python"
# Valid combinations of backends and targets.
# The list index is the value accepted by the --backend_target option.
backend_target_pairs = [
    [cv.dnn.DNN_BACKEND_OPENCV, cv.dnn.DNN_TARGET_CPU],        # 0
    [cv.dnn.DNN_BACKEND_CUDA,   cv.dnn.DNN_TARGET_CUDA],       # 1
    [cv.dnn.DNN_BACKEND_CUDA,   cv.dnn.DNN_TARGET_CUDA_FP16],  # 2
    [cv.dnn.DNN_BACKEND_TIMVX,  cv.dnn.DNN_TARGET_NPU],        # 3
    [cv.dnn.DNN_BACKEND_CANN,   cv.dnn.DNN_TARGET_NPU],        # 4
]
# Command-line interface of the demo.
parser = argparse.ArgumentParser(description='EfficientSAM Demo')
parser.add_argument('--input', '-i', type=str,
                    help='Set input path to a certain image.')
parser.add_argument('--model', '-m', type=str, default='image_segmentation_efficientsam_ti_2025april.onnx',
                    help='Set model path, defaults to image_segmentation_efficientsam_ti_2025april.onnx.')
# `choices` makes argparse reject an index outside backend_target_pairs with a
# usage error, instead of an IndexError later (or a silently wrapped negative
# index) when the pair is looked up.
parser.add_argument('--backend_target', '-bt', type=int, default=0,
                    choices=range(len(backend_target_pairs)),
                    help='''Choose one of the backend-target pair to run this demo:
{:d}: (default) OpenCV implementation + CPU,
{:d}: CUDA + GPU (CUDA),
{:d}: CUDA + GPU (CUDA FP16),
{:d}: TIM-VX + NPU,
{:d}: CANN + NPU
'''.format(*range(len(backend_target_pairs))))
parser.add_argument('--save', '-s', action='store_true',
                    help='Specify to save a file with results. Invalid in case of camera input.')
args = parser.parse_args()
# Global configuration
MAX_POINTS = 6            # Maximum number of prompt points (a box uses two of them)
WINDOW_SIZE = (800, 600)  # Upper bound for the display windows (width, height)

# Prompt state shared by the mouse callback and the main loop.
# `points` holds [x, y] pairs and `labels` the matching code for each pair.
# Codes appended by select(): 1 = click (foreground), -1 = long press (logged
# as "background"), 2 = box top-left corner, 3 = box bottom-right corner.
# NOTE(review): the earlier legend here listed -1 as "useless" and 0 as
# background, but select() never appends 0 - confirm what the model expects.
points = []
labels = []

# State of the gesture currently in progress.
backend_point = []   # [x, y] of the last button press; empty when none is pending
rectangle = False    # True while a drag (box selection) is in progress
current_img = None   # Last annotated image, used as the base of the drag preview
def visualize(image, result):
    """
    Overlay a segmentation mask on a copy of the input image.

    Pixels selected by the mask get their channel at index 2 (the one the
    code treats as red) brightened, and the mask outline is drawn in white.

    Args:
        image (np.ndarray): The input image.
        result (np.ndarray): The inference result (mask); it is binarized
            with a threshold of 127.

    Returns:
        vis_result (np.ndarray): A copy of the image with the mask highlighted.
    """
    # Work on copies so neither argument is modified.
    overlay = np.copy(image)
    mask = np.copy(result)

    # Binarize the mask: values above 127 become 255, everything else 0.
    _, binary = cv.threshold(mask, 127, 255, cv.THRESH_BINARY)
    assert set(np.unique(binary)) <= {0, 255}, "The mask must be a binary image."

    # Brighten the red channel inside the mask to make the segment stand out;
    # the boost is capped at 255 before it is written back.
    gain = 1.8
    red = overlay[:, :, 2]
    boosted = np.minimum(red * gain, 255)
    overlay[:, :, 2] = np.where(binary == 255, boosted, red)

    # Outline the segment.
    outlines, _ = cv.findContours(binary, cv.RETR_LIST, cv.CHAIN_APPROX_TC89_L1)
    cv.drawContours(overlay, outlines, contourIdx=-1, color=(255, 255, 255), thickness=2)

    return overlay
def select(event, x, y, flags, param):
    """Mouse callback that records point and box prompts.

    Left-button gestures:
      * press and release without moving adds one point. A press held longer
        than 0.5 s is stored with label -1 (logged as background), otherwise
        with label 1 (foreground).
      * press, move and release adds a box, stored as two points labelled
        2 (top-left corner) and 3 (bottom-right corner).

    After each release the prompts are redrawn on a copy of the original
    image, shown in the window and kept in ``current_img``.

    Args:
        event: OpenCV mouse event code.
        x, y: Pointer coordinates reported by OpenCV.
        flags: Unused.
        param (dict): Reads 'original_img' and 'image_window'; writes
            'mouse_down_time'.
    """
    global points, labels, backend_point, rectangle, current_img
    orig_img = param['original_img']
    image_window = param['image_window']

    # A drag can be reported outside the image area; keep every prompt
    # coordinate inside the image.
    height, width = orig_img.shape[:2]
    x = min(max(x, 0), width - 1)
    y = min(max(y, 0), height - 1)

    if event == cv.EVENT_LBUTTONDOWN:
        param['mouse_down_time'] = cv.getTickCount()
        backend_point = [x, y]

    elif event == cv.EVENT_MOUSEMOVE:
        if rectangle:
            # Live preview of the box, drawn over the last annotated image.
            rectangle_change_img = current_img.copy()
            cv.rectangle(rectangle_change_img, (backend_point[0], backend_point[1]), (x, y), (255, 0, 0), 2)
            cv.imshow(image_window, rectangle_change_img)
        elif len(backend_point) != 0 and len(points) < MAX_POINTS:
            # Movement while the button is held turns the gesture into a drag.
            rectangle = True

    elif event == cv.EVENT_LBUTTONUP:
        if len(points) >= MAX_POINTS:
            print(f"Maximum points reached {MAX_POINTS}.")
            # Drop the pending press so it cannot leak into a later gesture.
            rectangle = False
            backend_point.clear()
            return

        if not backend_point:
            # Release without a recorded press: there is nothing to add, and
            # indexing backend_point below would raise IndexError.
            rectangle = False
            return

        if not rectangle:
            duration = (cv.getTickCount() - param['mouse_down_time']) / cv.getTickFrequency()
            # NOTE(review): long presses are recorded as -1; confirm that
            # EfficientSAM.infer treats -1 (rather than 0) as a background point.
            label = -1 if duration > 0.5 else 1  # Long press = background

            points.append([backend_point[0], backend_point[1]])
            labels.append(label)
            kind = 'foreground' if label == 1 else 'background'
            print(f"Added {kind} point {backend_point}.")
        else:
            # A box consumes two point slots.
            if len(points) + 1 >= MAX_POINTS:
                rectangle = False
                backend_point.clear()
                cv.imshow(image_window, current_img)
                print(f"Points reached {MAX_POINTS}, could not add box.")
                return

            # Order the corners per axis so the first point really is the
            # top-left and the second the bottom-right, whatever the drag
            # direction.
            start_x, start_y = backend_point
            point_leftup = [min(start_x, x), min(start_y, y)]
            point_rightdown = [max(start_x, x), max(start_y, y)]
            points.append(point_leftup)
            points.append(point_rightdown)
            print(f"Added box from {point_leftup} to {point_rightdown}.")
            labels.append(2)
            labels.append(3)
            rectangle = False
        backend_point.clear()

        # Redraw every prompt on a fresh copy of the original image.
        marked_img = orig_img.copy()
        top_left = None
        for (px, py), lbl in zip(points, labels):
            if lbl == -1:
                cv.circle(marked_img, (px, py), 5, (0, 0, 255), -1)
            elif lbl == 1:
                cv.circle(marked_img, (px, py), 5, (0, 255, 0), -1)
            elif lbl == 2:
                top_left = (px, py)
            elif lbl == 3:
                # Label 3 is always appended right after its label-2 partner.
                bottom_right = (px, py)
                cv.rectangle(marked_img, top_left, bottom_right, (255, 0, 0), 2)
        cv.imshow(image_window, marked_img)
        current_img = marked_img.copy()
if __name__ == '__main__':
    # Only the save step at the end needs `os`.
    import os

    backend_id = backend_target_pairs[args.backend_target][0]
    target_id = backend_target_pairs[args.backend_target][1]
    # NOTE(review): backend_id and target_id are computed but never handed to
    # the model below, so --backend_target currently has no effect. Confirm
    # how EfficientSAM accepts them before wiring them in.

    # Load the EfficientSAM model
    model = EfficientSAM(modelPath=args.model)

    if args.input is not None:
        # Read image
        image = cv.imread(args.input)
        if image is None:
            print('Could not open or find the image:', args.input)
            exit(0)

        # create window
        image_window = "Origin image"
        cv.namedWindow(image_window, cv.WINDOW_NORMAL)

        # Shrink the window (never enlarge it) so the whole image fits inside
        # WINDOW_SIZE while keeping its aspect ratio.
        rate = min(1.0, WINDOW_SIZE[0] / image.shape[1], WINDOW_SIZE[1] / image.shape[0])
        # width, height
        WINDOW_SIZE = (int(image.shape[1] * rate), int(image.shape[0] * rate))
        cv.resizeWindow(image_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
        # put the window on the left of the screen
        cv.moveWindow(image_window, 50, 100)

        # set listener to record user's click point
        param = {
            'original_img': image,
            'mouse_down_time': 0,
            'image_window': image_window
        }
        cv.setMouseCallback(image_window, select, param)

        # tips in the terminal
        print("Click — Select foreground point\n"
              "Long press — Select background point\n"
              "Drag — Create selection box\n"
              "Enter — Infer\n"
              "Backspace — Clear the prompts\n"
              "Q - Quit")

        # show image
        cv.imshow(image_window, image)
        current_img = image.copy()

        # create window to show visualized result
        vis_image = image.copy()
        segmentation_window = "Segment result"
        cv.namedWindow(segmentation_window, cv.WINDOW_NORMAL)
        cv.resizeWindow(segmentation_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
        cv.moveWindow(segmentation_window, WINDOW_SIZE[0] + 51, 100)
        cv.imshow(segmentation_window, vis_image)

        # Last successful inference. Both stay None until an inference
        # produces a mask, so the save step can tell whether there is
        # anything to write.
        result = None
        vis_result = None

        # waiting for click
        while True:
            # Check window status
            # if click × to close the image window then ending
            if (cv.getWindowProperty(image_window, cv.WND_PROP_VISIBLE) < 1 or
                    cv.getWindowProperty(segmentation_window, cv.WND_PROP_VISIBLE) < 1):
                break

            # Handle keyboard input
            key = cv.waitKey(1)

            # receive enter
            if key == 13:
                vis_image = image.copy()
                cv.putText(vis_image, "infering...",
                           (50, vis_image.shape[0] // 2),
                           cv.FONT_HERSHEY_SIMPLEX, 10, (255, 255, 255), 5)
                cv.imshow(segmentation_window, vis_image)
                # HighGUI only repaints inside waitKey; without this call the
                # progress frame would not appear before inference blocks.
                cv.waitKey(1)

                mask = model.infer(image=image, points=points, labels=labels)
                if len(mask) == 0:
                    print("clear and select points again!")
                else:
                    # Keep the mask and its visualization together so the
                    # saved pair always belongs to the same inference.
                    result = mask
                    vis_result = visualize(image, result)
                    cv.imshow(segmentation_window, vis_result)

            elif key == 8 or key == 127:  # ASCII for Backspace or Delete
                points.clear()
                labels.clear()
                backend_point = []
                rectangle = False
                # Copy, as everywhere else, so current_img never aliases the
                # pristine input image.
                current_img = image.copy()
                print("Points are cleared.")
                cv.imshow(image_window, image)

            elif key == ord('q') or key == ord('Q'):
                break

        cv.destroyAllWindows()

        # Save results if save is true
        if args.save:
            if vis_result is None:
                # No inference succeeded; previously this raised NameError.
                print('No segmentation result to save.')
            else:
                # The output directory may not exist yet.
                os.makedirs('./example_outputs', exist_ok=True)
                cv.imwrite('./example_outputs/vis_result.jpg', vis_result)
                cv.imwrite("./example_outputs/mask.jpg", result)
                print('vis_result.jpg and mask.jpg are saved to ./example_outputs/')

    else:
        print('Set input path to a certain image.')