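# SAM 2.1 point-prompted segmentation split across two backends:
# the image encoder runs as an RKNN model via RKNNLite, and the
# prompt decoder runs as an ONNX model via onnxruntime.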
import os
import time
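# Work from the script's own directory so the relative model and image paths below resolve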
os.chdir(os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import onnxruntime
from rknnlite.api import RKNNLite
from PIL import Image
import matplotlib.pyplot as plt
import cv2


def load_image(path):
    """加载并预处理图片"""
    image = Image.open(path).convert("RGB")
    print(f"Original image size: {image.size}")
    
    # Compute the resized dimensions, preserving the aspect ratio
    target_size = (1024, 1024)
    w, h = image.size
    scale = min(target_size[0] / w, target_size[1] / h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    print(f"Scale factor: {scale}")
    print(f"Resized dimensions: {new_w}x{new_h}")
    
    # Resize the image
    resized_image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
    
    # Create a 1024x1024 black canvas
    processed_image = Image.new("RGB", target_size, (0, 0, 0))
    # Paste the resized image at the center (letterbox padding)
    paste_x = (target_size[0] - new_w) // 2
    paste_y = (target_size[1] - new_h) // 2
    print(f"Paste position: ({paste_x}, {paste_y})")
    processed_image.paste(resized_image, (paste_x, paste_y))
    
    # Save the processed image for visual inspection
    processed_image.save("debug_processed_image.png")
    
    # Convert to a numpy array; scaling to [0, 1] is folded into the model, so / 255.0 stays disabled
    img_np = np.array(processed_image).astype(np.float32) # / 255.0
    # Reorder dimensions from HWC to CHW
    img_np = img_np.transpose(2, 0, 1)
    # Add a batch dimension
    img_np = np.expand_dims(img_np, axis=0)
    
    print(f"Final input tensor shape: {img_np.shape}")
    
    return image, img_np, (scale, paste_x, paste_y)

def prepare_point_input(point_coords, point_labels, image_size=(1024, 1024)):
    """准备点击输入数据"""
    point_coords = np.array(point_coords, dtype=np.float32)
    point_labels = np.array(point_labels, dtype=np.float32)
    
    # Add a batch dimension
    point_coords = np.expand_dims(point_coords, axis=0)
    point_labels = np.expand_dims(point_labels, axis=0)
    
    # Prepare an empty low-res mask input and flag that no prior mask is provided
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    has_mask_input = np.zeros(1, dtype=np.float32)
    orig_im_size = np.array(image_size, dtype=np.int32)
    
    return point_coords, point_labels, mask_input, has_mask_input, orig_im_size

def main():
    # 1. Load the original image
    path = "dog.jpg"
    orig_image, input_image, (scale, offset_x, offset_y) = load_image(path)
    decoder_path = "sam2.1_hiera_small_decoder.onnx"
    encoder_path = "sam2.1_hiera_small_encoder.rknn"

    # 2. Prepare the input point (pixel coordinates in the original image)
    # input_point_orig = [[750, 400]]
    input_point_orig = [[189, 394]]
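    # Map the click from original-image coordinates into the padded 1024x1024 encoder frame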
    input_point = [[
        int(x * scale + offset_x), 
        int(y * scale + offset_y)
    ] for x, y in input_point_orig]
    input_label = [1]
    
    # 3. Run the RKNN encoder
    print("Running RKNN encoder...")
    rknn_lite = RKNNLite(verbose=False)
    
    ret = rknn_lite.load_rknn(encoder_path)
    if ret != 0:
        print('Load RKNN model failed')
        exit(ret)
    
    ret = rknn_lite.init_runtime()
    if ret != 0:
        print('Init runtime environment failed')
        exit(ret)
    start_time = time.time()
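    # data_format="nchw" tells RKNNLite the input tensor is already channel-first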
    encoder_outputs = rknn_lite.inference(inputs=[input_image], data_format="nchw")
    end_time = time.time()
    print(f"RKNN encoder time: {end_time - start_time} seconds")
    high_res_feats_0, high_res_feats_1, image_embed = encoder_outputs
    rknn_lite.release()
    
    # 4. Run the ONNX decoder
    print("Running ONNX decoder...")
    decoder_session = onnxruntime.InferenceSession(decoder_path)
    
    point_coords, point_labels, mask_input, has_mask_input, orig_im_size = prepare_point_input(
        input_point, input_label, orig_image.size[::-1]
    )
    
    decoder_inputs = {
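        # Key names must match the input names of the exported decoder graph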
        'image_embed': image_embed,
        'high_res_feats_0': high_res_feats_0,
        'high_res_feats_1': high_res_feats_1,
        'point_coords': point_coords,
        'point_labels': point_labels,
        'mask_input': mask_input,
        'has_mask_input': has_mask_input,
    }
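    # The decoder returns one low-resolution logit map per candidate mask plus a predicted IoU score for each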
    start_time = time.time()
    low_res_masks, iou_predictions = decoder_session.run(None, decoder_inputs)
    end_time = time.time()
    print(f"ONNX decoder time: {end_time - start_time} seconds")
    print(f"low_res_masks shape: {low_res_masks.shape}")
    # 5. Post-processing: undo the preprocessing (upscale, crop the padding, resize back to the original size)
    w, h = orig_image.size
    masks_rknn = []
    
    # Process all three candidate masks
    for i in range(low_res_masks.shape[1]):
        # Upscale the low-res mask to the 1024x1024 input resolution
        masks_1024 = cv2.resize(
            low_res_masks[0,i],
            (1024, 1024),
            interpolation=cv2.INTER_LINEAR
        )
        
        # Crop away the letterbox padding
        new_h = int(h * scale)
        new_w = int(w * scale)
        start_h = (1024 - new_h) // 2
        start_w = (1024 - new_w) // 2
        masks_no_pad = masks_1024[start_h:start_h+new_h, start_w:start_w+new_w]
        
        # Resize back to the original image size
        mask = cv2.resize(
            masks_no_pad,
            (w, h),
            interpolation=cv2.INTER_LINEAR
        )
        
        # Binarize at the 0.0 logit threshold
        mask = mask > 0.0
        masks_rknn.append(mask)
    
    # 6. Visualize the results
    plt.figure(figsize=(15, 5))
    
    # Sort the mask indices by predicted IoU score, best first
    sorted_indices = np.argsort(iou_predictions[0])[::-1]  # descending order
    
    for idx, mask_idx in enumerate(sorted_indices):
        plt.subplot(1, 3, idx + 1)
        plt.imshow(orig_image)
        plt.imshow(masks_rknn[mask_idx], alpha=0.5)
        plt.plot(input_point_orig[0][0], input_point_orig[0][1], 'rx')
        plt.title(f'Mask {mask_idx+1}\nIoU: {iou_predictions[0][mask_idx]:.3f}')
        plt.axis('off')
    
    plt.tight_layout()
    # plt.show()
    plt.savefig("result.png")
    
    print(f"\nIoU predictions: {iou_predictions}")

if __name__ == "__main__":
    main()