File size: 5,810 Bytes
50704de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
import os
import time
os.chdir(os.path.dirname(os.path.abspath(__file__)))
import numpy as np
import onnxruntime
from rknnlite.api import RKNNLite
from PIL import Image
import matplotlib.pyplot as plt
import cv2
def load_image(path):
    """Load an image and letterbox it into a 1024x1024 encoder input.

    The image is converted to RGB and scaled by one factor so that it fits
    inside 1024x1024 with its aspect ratio kept. It is then pasted onto the
    centre of a black 1024x1024 canvas. Pixel values stay in 0-255, because
    normalisation is folded into the exported model.

    Args:
        path: Path of the image file to open.

    Returns:
        A tuple ``(image, tensor, (scale, paste_x, paste_y))``:

        - ``image`` is the RGB image at its original size.
        - ``tensor`` is a float32 array of shape (1, 3, 1024, 1024).
        - ``scale``, ``paste_x`` and ``paste_y`` map original-image pixel
          coordinates into the letterboxed frame:
          ``x_1024 = x * scale + paste_x``.

    Side effects:
        Prints diagnostics and writes the letterboxed image to
        ``debug_processed_image.png`` in the current working directory.
    """
    canvas_w, canvas_h = 1024, 1024

    rgb = Image.open(path).convert("RGB")
    print(f"Original image size: {rgb.size}")

    src_w, src_h = rgb.size
    # Largest single factor that still fits the whole image on the canvas.
    factor = min(canvas_w / src_w, canvas_h / src_h)
    fit_w = int(src_w * factor)
    fit_h = int(src_h * factor)
    print(f"Scale factor: {factor}")
    print(f"Resized dimensions: {fit_w}x{fit_h}")

    fitted = rgb.resize((fit_w, fit_h), Image.Resampling.LANCZOS)
    canvas = Image.new("RGB", (canvas_w, canvas_h), (0, 0, 0))

    # Centre the resized image; the rest of the canvas stays black padding.
    left = (canvas_w - fit_w) // 2
    top = (canvas_h - fit_h) // 2
    print(f"Paste position: ({left}, {top})")
    canvas.paste(fitted, (left, top))

    # Keep the letterboxed image on disk so the preprocessing can be inspected.
    canvas.save("debug_processed_image.png")

    # HWC uint8 -> float32 (no /255: the model normalises), then CHW plus a batch axis.
    tensor = np.asarray(canvas, dtype=np.float32).transpose(2, 0, 1)[np.newaxis]
    print(f"Final input tensor shape: {tensor.shape}")
    return rgb, tensor, (factor, left, top)
def prepare_point_input(point_coords, point_labels, image_size=(1024, 1024)):
    """Build the prompt tensors for the decoder from click points.

    Args:
        point_coords: Sequence of ``(x, y)`` points. The caller in this file
            has already mapped them into the 1024x1024 letterboxed frame.
            An empty sequence is accepted.
        point_labels: One label per point (this script uses ``1`` for its
            clicked point).
        image_size: Pair stored as ``orig_im_size``. ``main`` passes the
            original image's ``(height, width)``.

    Returns:
        ``(point_coords, point_labels, mask_input, has_mask_input,
        orig_im_size)``:

        - ``point_coords`` and ``point_labels`` are float32 arrays of shape
          (1, N, 2) and (1, N).
        - ``mask_input`` is an all-zero float32 array of shape
          (1, 1, 256, 256).
        - ``has_mask_input`` is a float32 zero of shape (1,). Presumably this
          tells the decoder to ignore ``mask_input``; confirm against the
          exported model.
        - ``orig_im_size`` is an int32 array holding ``image_size``.

    Raises:
        ValueError: If ``point_coords`` is not N x 2, or if the number of
            labels differs from the number of points.
    """
    coords = np.array(point_coords, dtype=np.float32)
    if coords.size == 0:
        # np.array([]) has shape (0,); force (0, 2) so the batched rank stays 3.
        coords = coords.reshape(0, 2)
    if coords.ndim != 2 or coords.shape[1] != 2:
        raise ValueError(f"point_coords must have shape (N, 2), got {coords.shape}")

    labels = np.array(point_labels, dtype=np.float32).reshape(-1)
    if labels.shape[0] != coords.shape[0]:
        raise ValueError(
            f"got {coords.shape[0]} points but {labels.shape[0]} labels"
        )

    # Add the batch dimension expected by the decoder.
    coords = coords[np.newaxis]
    labels = labels[np.newaxis]

    # No previous mask is supplied: zero mask plus a zero "has mask" flag.
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    has_mask_input = np.zeros(1, dtype=np.float32)
    orig_im_size = np.array(image_size, dtype=np.int32)
    return coords, labels, mask_input, has_mask_input, orig_im_size
def _run_rknn_encoder(encoder_path, input_image):
    """Run the RKNN image encoder and return ``(high_res_feats_0, high_res_feats_1, image_embed)``.

    Exits the process with a non-zero status if the model cannot be loaded,
    if the runtime cannot be initialised, or if inference yields no outputs.
    """
    import sys  # only needed for the failure exits below

    rknn_lite = RKNNLite(verbose=False)
    try:
        ret = rknn_lite.load_rknn(encoder_path)
        if ret != 0:
            print('Load RKNN model failed')
            sys.exit(ret)
        ret = rknn_lite.init_runtime()
        if ret != 0:
            print('Init runtime environment failed')
            sys.exit(ret)
        start_time = time.perf_counter()
        encoder_outputs = rknn_lite.inference(inputs=[input_image], data_format="nchw")
        print(f"RKNN encoder time: {time.perf_counter() - start_time} seconds")
    finally:
        # Release the runtime on every path. SystemExit from the early
        # exits above still runs this clause.
        rknn_lite.release()

    if encoder_outputs is None:
        # NOTE(review): assumes RKNNLite.inference signals failure by returning None; confirm.
        print('RKNN encoder inference failed')
        sys.exit(1)
    high_res_feats_0, high_res_feats_1, image_embed = encoder_outputs
    return high_res_feats_0, high_res_feats_1, image_embed


def _masks_to_original(low_res_masks, orig_size, scale, offset_x, offset_y, input_size=1024):
    """Map decoder mask logits back onto the original image as boolean masks.

    Each logit map in ``low_res_masks[0]`` goes through four steps:

    1. Upsample to the ``input_size`` x ``input_size`` letterboxed frame.
    2. Crop away the padding placed at ``(offset_x, offset_y)``.
    3. Resize to ``orig_size``, given as ``(width, height)``.
    4. Threshold at 0 (logit > 0 is foreground).
    """
    w, h = orig_size
    # Same arithmetic load_image used, so the crop matches the pasted region exactly.
    new_w = int(w * scale)
    new_h = int(h * scale)
    masks = []
    for logits in low_res_masks[0]:
        upsampled = cv2.resize(logits, (input_size, input_size), interpolation=cv2.INTER_LINEAR)
        unpadded = upsampled[offset_y:offset_y + new_h, offset_x:offset_x + new_w]
        # cv2.resize takes dsize as (width, height).
        full = cv2.resize(unpadded, (w, h), interpolation=cv2.INTER_LINEAR)
        masks.append(full > 0.0)
    return masks


def _save_mask_figure(orig_image, masks, iou_scores, points, out_path="result.png"):
    """Overlay each mask on the image in its own panel, highest predicted IoU first, and save the figure."""
    order = np.argsort(iou_scores)[::-1]  # descending predicted IoU
    n_panels = max(len(order), 1)  # one panel per mask, whatever count the decoder emits
    fig = plt.figure(figsize=(5 * n_panels, 5))
    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    for panel, mask_idx in enumerate(order, start=1):
        plt.subplot(1, n_panels, panel)
        plt.imshow(orig_image)
        plt.imshow(masks[mask_idx], alpha=0.5)
        plt.plot(xs, ys, 'rx')  # every prompt point, not just the first
        plt.title(f'Mask {mask_idx+1}\nIoU: {iou_scores[mask_idx]:.3f}')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(fig)


def main():
    """Segment a clicked point in dog.jpg with SAM2.1 (RKNN encoder plus ONNX decoder).

    Steps:

    1. Letterbox the image.
    2. Run the image encoder through RKNNLite.
    3. Run the prompt decoder with onnxruntime.
    4. Map the predicted masks back to the original resolution.
    5. Save the masks to result.png, ordered by predicted IoU.
    """
    # 1. Load the original image and build the letterboxed encoder input.
    path = "dog.jpg"
    orig_image, input_image, (scale, offset_x, offset_y) = load_image(path)
    decoder_path = "sam2.1_hiera_small_decoder.onnx"
    encoder_path = "sam2.1_hiera_small_encoder.rknn"

    # 2. Prompt point(s) in original-image pixels ([[750, 400]] is another test point).
    input_point_orig = [[189, 394]]
    input_label = [1]
    # Apply the same transform as load_image. Coordinates stay float
    # (prepare_point_input casts them to float32), so nothing is truncated.
    input_point = [
        [x * scale + offset_x, y * scale + offset_y] for x, y in input_point_orig
    ]

    # 3. Image encoder.
    print("Running RKNN encoder...")
    high_res_feats_0, high_res_feats_1, image_embed = _run_rknn_encoder(encoder_path, input_image)

    # 4. Prompt decoder.
    print("Running ONNX decoder...")
    # Explicit providers: onnxruntime >= 1.9 raises on builds that expose
    # GPU providers when none are specified.
    decoder_session = onnxruntime.InferenceSession(
        decoder_path, providers=["CPUExecutionProvider"]
    )
    # orig_im_size is not fed below; presumably the exported decoder has no such input.
    point_coords, point_labels, mask_input, has_mask_input, _orig_im_size = prepare_point_input(
        input_point, input_label, orig_image.size[::-1]
    )
    decoder_inputs = {
        'image_embed': image_embed,
        'high_res_feats_0': high_res_feats_0,
        'high_res_feats_1': high_res_feats_1,
        'point_coords': point_coords,
        'point_labels': point_labels,
        'mask_input': mask_input,
        'has_mask_input': has_mask_input,
    }
    start_time = time.perf_counter()
    low_res_masks, iou_predictions = decoder_session.run(None, decoder_inputs)
    print(f"ONNX decoder time: {time.perf_counter() - start_time} seconds")
    print(low_res_masks.shape)

    # 5. Post-process every predicted mask back to the original resolution.
    masks = _masks_to_original(low_res_masks, orig_image.size, scale, offset_x, offset_y)

    # 6. Visualise, best predicted IoU first.
    _save_mask_figure(orig_image, masks, iou_predictions[0], input_point_orig)
    print(f"\nIoU predictions: {iou_predictions}")
# Run the demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()