# Source: uploaded by happyme531 in commit 50704de ("Upload 12 files"); 5.81 kB.
import os
import time
# Make the relative paths used below (model files, dog.jpg, the debug and
# result images) resolve against this script's directory rather than the
# caller's current working directory.
os.chdir(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import onnxruntime
from rknnlite.api import RKNNLite
from PIL import Image
import matplotlib.pyplot as plt
import cv2
def load_image(path, target_size=(1024, 1024), debug_path="debug_processed_image.png"):
    """Load an image and letterbox it into the encoder's fixed input size.

    The image is scaled by one factor so that it fits inside ``target_size``
    while keeping its aspect ratio, then pasted centred on a black canvas of
    that size.

    Args:
        path: Path of the image file to load.
        target_size: ``(width, height)`` of the padded canvas. Defaults to
            1024x1024, the size this script feeds to the encoder.
        debug_path: File to save the padded image to for inspection. Pass
            ``None`` to skip writing it.

    Returns:
        ``(image, tensor, (scale, paste_x, paste_y))``, where:
        - ``image`` is the original RGB ``PIL.Image``.
        - ``tensor`` is a float32 array of shape
          ``(1, 3, target_height, target_width)`` holding raw 0-255 pixel
          values. Normalisation is left to the model.
        - ``scale``, ``paste_x`` and ``paste_y`` map original coordinates into
          the padded frame: ``x' = x * scale + paste_x``.
    """
    # The context manager closes the underlying file handle. convert() loads
    # the pixels and returns an independent image, so `image` stays usable.
    with Image.open(path) as src:
        image = src.convert("RGB")
    print(f"Original image size: {image.size}")

    # Compute the resized dimensions, preserving the aspect ratio.
    target_w, target_h = target_size
    w, h = image.size
    scale = min(target_w / w, target_h / h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    print(f"Scale factor: {scale}")
    print(f"Resized dimensions: {new_w}x{new_h}")

    resized_image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Paste the resized image centred on a black canvas of the target size.
    processed_image = Image.new("RGB", (target_w, target_h), (0, 0, 0))
    paste_x = (target_w - new_w) // 2
    paste_y = (target_h - new_h) // 2
    print(f"Paste position: ({paste_x}, {paste_y})")
    processed_image.paste(resized_image, (paste_x, paste_y))

    # Optionally keep the padded image for visual inspection.
    if debug_path is not None:
        processed_image.save(debug_path)

    # HWC uint8 -> CHW float32 with a leading batch axis. The /255 scaling is
    # intentionally omitted: per the original comment, it is built into the model.
    img_np = np.array(processed_image).astype(np.float32)
    img_np = img_np.transpose(2, 0, 1)[np.newaxis]
    print(f"Final input tensor shape: {img_np.shape}")
    return image, img_np, (scale, paste_x, paste_y)
def prepare_point_input(point_coords, point_labels, image_size=(1024, 1024)):
    """Build the prompt tensors passed to the mask decoder.

    Args:
        point_coords: Sequence of ``[x, y]`` click positions, expressed in
            the padded encoder-input frame.
        point_labels: One label per point.
        image_size: Value returned as ``orig_im_size``. ``main`` passes the
            original image's ``(height, width)``.

    Returns:
        ``(point_coords, point_labels, mask_input, has_mask_input, orig_im_size)``:
        - float32 coordinates and labels, each with a leading batch axis of 1;
        - an all-zero float32 mask prompt of shape ``(1, 1, 256, 256)``;
        - a float32 ``[0.]`` flag marking that no mask prompt is supplied;
        - ``image_size`` as an int32 array.
    """
    # Prepend a batch axis of size 1 to both prompt arrays.
    coords = np.array(point_coords, dtype=np.float32)[np.newaxis, ...]
    labels = np.array(point_labels, dtype=np.float32)[np.newaxis, ...]

    # No previous mask is fed back, so the mask prompt is empty and its flag is off.
    empty_mask = np.zeros((1, 1, 256, 256), dtype=np.float32)
    mask_flag = np.zeros(1, dtype=np.float32)

    size = np.array(image_size, dtype=np.int32)
    return coords, labels, empty_mask, mask_flag, size
def _run_encoder(encoder_path, input_image):
    """Run the RKNN image encoder and return its three output tensors.

    Returns ``(high_res_feats_0, high_res_feats_1, image_embed)`` in the order
    the RKNN model emits them. Exits the process with the RKNN return code if
    the model cannot be loaded or the runtime cannot be initialised.
    """
    import sys

    rknn_lite = RKNNLite(verbose=False)
    ret = rknn_lite.load_rknn(encoder_path)
    if ret != 0:
        print('Load RKNN model failed')
        # sys.exit rather than the builtin exit(), which only exists when `site` is loaded.
        sys.exit(ret)
    try:
        ret = rknn_lite.init_runtime()
        if ret != 0:
            print('Init runtime environment failed')
            sys.exit(ret)
        start_time = time.perf_counter()
        encoder_outputs = rknn_lite.inference(inputs=[input_image], data_format="nchw")
        end_time = time.perf_counter()
        print(f"RKNN encoder time: {end_time - start_time} seconds")
    finally:
        # Release the runtime even if init/inference fails (SystemExit included).
        rknn_lite.release()
    high_res_feats_0, high_res_feats_1, image_embed = encoder_outputs
    return high_res_feats_0, high_res_feats_1, image_embed


def _run_decoder(decoder_path, decoder_inputs):
    """Run the ONNX mask decoder and return ``(low_res_masks, iou_predictions)``."""
    # Pin the CPU provider explicitly: onnxruntime >= 1.9 raises when
    # `providers` is omitted on builds that ship several execution providers.
    decoder_session = onnxruntime.InferenceSession(
        decoder_path, providers=["CPUExecutionProvider"]
    )
    start_time = time.perf_counter()
    low_res_masks, iou_predictions = decoder_session.run(None, decoder_inputs)
    end_time = time.perf_counter()
    print(f"ONNX decoder time: {end_time - start_time} seconds")
    return low_res_masks, iou_predictions


def _postprocess_masks(low_res_masks, orig_size, scale, offset_x, offset_y, input_size=1024):
    """Map decoder mask logits back to boolean masks at the original image size.

    ``orig_size`` is the original ``(width, height)``. ``scale``, ``offset_x``
    and ``offset_y`` are the letterbox parameters returned by ``load_image``.
    """
    w, h = orig_size
    # Same truncation as load_image, so the crop removes exactly the padding it added.
    new_w = int(w * scale)
    new_h = int(h * scale)
    masks = []
    for logits in low_res_masks[0]:
        # Upscale the low-res logits to the padded encoder-input size.
        upscaled = cv2.resize(logits, (input_size, input_size), interpolation=cv2.INTER_LINEAR)
        # Strip the letterbox padding.
        cropped = upscaled[offset_y:offset_y + new_h, offset_x:offset_x + new_w]
        # Resize to the original image dimensions and binarise the logits at 0.
        mask = cv2.resize(cropped, (w, h), interpolation=cv2.INTER_LINEAR)
        masks.append(mask > 0.0)
    return masks


def _save_visualization(orig_image, masks, iou_scores, points, out_path="result.png"):
    """Overlay each mask on the image, highest predicted IoU first, and save the figure."""
    order = np.argsort(iou_scores)[::-1]  # descending predicted IoU
    plt.figure(figsize=(15, 5))
    for idx, mask_idx in enumerate(order):
        plt.subplot(1, len(order), idx + 1)
        plt.imshow(orig_image)
        plt.imshow(masks[mask_idx], alpha=0.5)
        for x, y in points:
            plt.plot(x, y, 'rx')
        plt.title(f'Mask {mask_idx+1}\nIoU: {iou_scores[mask_idx]:.3f}')
        plt.axis('off')
    plt.tight_layout()
    # plt.show()
    plt.savefig(out_path)
    plt.close()


def main():
    """Segment the object under a click in dog.jpg and save the masks to result.png.

    Steps:
    1. Letterbox the image.
    2. Run the SAM2.1 Hiera-Small image encoder through RKNNLite.
    3. Run the mask decoder with onnxruntime.
    4. Map the low-resolution mask logits back to the original image.
    5. Plot every candidate mask, ordered by predicted IoU.
    """
    # 1. Load and letterbox the original image.
    path = "dog.jpg"
    orig_image, input_image, (scale, offset_x, offset_y) = load_image(path)
    decoder_path = "sam2.1_hiera_small_decoder.onnx"
    encoder_path = "sam2.1_hiera_small_encoder.rknn"

    # 2. Prompt points in original-image pixels, mapped into the padded input frame.
    # Alternative prompt: [[750, 400]]
    input_point_orig = [[189, 394]]
    input_point = [
        [int(x * scale + offset_x), int(y * scale + offset_y)]
        for x, y in input_point_orig
    ]
    input_label = [1]

    # 3. Image encoder (RKNN).
    print("Running RKNN encoder...")
    high_res_feats_0, high_res_feats_1, image_embed = _run_encoder(encoder_path, input_image)

    # 4. Mask decoder (ONNX). orig_im_size is not a decoder input here.
    print("Running ONNX decoder...")
    point_coords, point_labels, mask_input, has_mask_input, _orig_im_size = prepare_point_input(
        input_point, input_label, orig_image.size[::-1]
    )
    decoder_inputs = {
        'image_embed': image_embed,
        'high_res_feats_0': high_res_feats_0,
        'high_res_feats_1': high_res_feats_1,
        'point_coords': point_coords,
        'point_labels': point_labels,
        'mask_input': mask_input,
        'has_mask_input': has_mask_input,
    }
    low_res_masks, iou_predictions = _run_decoder(decoder_path, decoder_inputs)
    print(low_res_masks.shape)

    # 5. Post-process every candidate mask.
    masks_rknn = _postprocess_masks(low_res_masks, orig_image.size, scale, offset_x, offset_y)

    # 6. Visualise the results.
    _save_visualization(orig_image, masks_rknn, iou_predictions[0], input_point_orig)
    print(f"\nIoU predictions: {iou_predictions}")
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()