File size: 6,481 Bytes
50704de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
import os
os.chdir(os.path.dirname(os.path.abspath(__file__)))

import numpy as np
import torch
import onnxruntime
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def load_image(url, target_size=(1024, 1024)):
    """Download an image and letterbox it to ``target_size`` for the encoder.

    The image is scaled with its aspect ratio preserved so it fits inside
    ``target_size``, pasted centered on a black canvas, and converted to a
    normalized NCHW float32 batch of one.

    Args:
        url: HTTP(S) URL of the image to download.
        target_size: (width, height) of the model input canvas. Defaults to
            (1024, 1024), the resolution the SAM2 encoder was exported with.

    Returns:
        Tuple ``(image, img_np, (scale, paste_x, paste_y))`` where ``image``
        is the original RGB PIL image, ``img_np`` is a float32 array of shape
        ``(1, 3, target_h, target_w)`` with values in [0, 1], and the last
        element is the resize scale plus the top-left paste offset — needed
        to map point coordinates between original-image and canvas space.

    Raises:
        requests.HTTPError: if the download returns an error status.
    """
    response = requests.get(url)
    # Fail loudly on HTTP errors instead of handing an error page to PIL.
    response.raise_for_status()
    image = Image.open(BytesIO(response.content)).convert("RGB")
    print(f"Original image size: {image.size}")

    # Uniform scale so the whole image fits inside the canvas.
    w, h = image.size
    scale = min(target_size[0] / w, target_size[1] / h)
    new_w = int(w * scale)
    new_h = int(h * scale)
    print(f"Scale factor: {scale}")
    print(f"Resized dimensions: {new_w}x{new_h}")

    resized_image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)

    # Letterbox: center the resized image on a black background.
    processed_image = Image.new("RGB", target_size, (0, 0, 0))
    paste_x = (target_size[0] - new_w) // 2
    paste_y = (target_size[1] - new_h) // 2
    print(f"Paste position: ({paste_x}, {paste_y})")
    processed_image.paste(resized_image, (paste_x, paste_y))

    # Keep a copy on disk for visual debugging.
    processed_image.save("debug_processed_image.png")

    # HWC uint8 -> normalized CHW float32, then add the batch dimension.
    img_np = np.array(processed_image).astype(np.float32) / 255.0
    img_np = img_np.transpose(2, 0, 1)
    img_np = np.expand_dims(img_np, axis=0)

    print(f"Final input tensor shape: {img_np.shape}")

    return image, img_np, (scale, paste_x, paste_y)

def prepare_point_input(point_coords, point_labels, image_size=(1024, 1024)):
    """Pack prompt points into the tensors the ONNX decoder expects.

    Args:
        point_coords: list of [x, y] click positions in model-input space.
        point_labels: list of per-point labels (1 = foreground click).
        image_size: value emitted as the ``orig_im_size`` tensor.

    Returns:
        Tuple ``(point_coords, point_labels, mask_input, has_mask_input,
        orig_im_size)`` of numpy arrays; coords and labels carry a leading
        batch dimension of 1.
    """
    coords = np.asarray(point_coords, dtype=np.float32)[None, ...]
    labels = np.asarray(point_labels, dtype=np.float32)[None, ...]

    # No prior mask prompt: an all-zero low-res mask plus a zero flag.
    mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
    has_mask_input = np.zeros(1, dtype=np.float32)
    orig_im_size = np.array(image_size, dtype=np.int32)

    return coords, labels, mask_input, has_mask_input, orig_im_size

def main():
    """Run the same point prompt through the PyTorch and ONNX SAM2 pipelines
    and display both predicted masks side by side.

    NOTE(review): the PyTorch path loads the *large* checkpoint while the
    ONNX path loads *tiny* encoder/decoder files — the two outputs are not
    strictly comparable; confirm this mismatch is intentional.
    """
    # 1. Load the source image (also returns the letterbox transform).
    url = "https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/dog.jpg"
    orig_image, input_image, (scale, offset_x, offset_y) = load_image(url)
    
    # 2. Map the click from original-image coordinates into the padded
    #    1024x1024 canvas used by the ONNX encoder.
    input_point_orig = [[750, 400]]
    input_point = [[
        int(x * scale + offset_x), 
        int(y * scale + offset_y)
    ] for x, y in input_point_orig]
    print(f"Original point: {input_point_orig}")
    print(f"Transformed point: {input_point}")
    input_label = [1]
    
    # 3. Run the PyTorch reference model.
    print("Running PyTorch model...")
    checkpoint = "sam2.1_hiera_large.pt"
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
    predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
    
    with torch.inference_mode():
        predictor.set_image(orig_image)
        # BUGFIX: predict() expects points in the coordinate frame of the
        # image passed to set_image (the original image), not the padded
        # canvas — the canvas-space point was shifting the PyTorch prompt.
        masks_pt, iou_scores_pt, low_res_masks_pt = predictor.predict(
            point_coords=np.array(input_point_orig),
            point_labels=np.array(input_label),
            multimask_output=True
        )
    
    # 4. Run the ONNX model.
    print("Running ONNX model...")
    encoder_path = "sam2.1_hiera_tiny_encoder.s.onnx"
    decoder_path = "sam2.1_hiera_tiny_decoder.onnx"
    
    # Create the ONNX Runtime sessions.
    encoder_session = onnxruntime.InferenceSession(encoder_path)
    decoder_session = onnxruntime.InferenceSession(decoder_path)
    
    # Run the encoder on the letterboxed canvas.
    encoder_inputs = {'image': input_image}
    high_res_feats_0, high_res_feats_1, image_embed = encoder_session.run(None, encoder_inputs)
    
    # Prepare decoder inputs; the ONNX decoder takes canvas-space points.
    point_coords, point_labels, mask_input, has_mask_input, orig_im_size = prepare_point_input(
        input_point, input_label, orig_image.size[::-1]
    )
    
    # Run the decoder.
    decoder_inputs = {
        'image_embed': image_embed,
        'high_res_feats_0': high_res_feats_0,
        'high_res_feats_1': high_res_feats_1,
        'point_coords': point_coords,
        'point_labels': point_labels,
        # 'orig_im_size': orig_im_size,  # presumably dropped at export; verify the decoder's input list
        'mask_input': mask_input,
        'has_mask_input': has_mask_input,
    }
    
    # NOTE(review): output order (low_res_masks, iou_predictions) assumed
    # from the export script — confirm against the decoder's output names.
    low_res_masks, iou_predictions = decoder_session.run(None, decoder_inputs)
    
    # Post-process: bring low_res_masks back to the original image size.
    w, h = orig_image.size
    
    # 4a. Upsample the mask to the 1024x1024 canvas resolution.
    masks_1024 = torch.nn.functional.interpolate(
        torch.from_numpy(low_res_masks),
        size=(1024, 1024),
        mode="bilinear",
        align_corners=False
    )
    
    # 4b. Crop away the letterbox padding (inverse of load_image's paste).
    new_h = int(h * scale)
    new_w = int(w * scale)
    start_h = (1024 - new_h) // 2
    start_w = (1024 - new_w) // 2
    masks_no_pad = masks_1024[..., start_h:start_h+new_h, start_w:start_w+new_w]
    
    # 4c. Resize to the original image dimensions.
    masks_onnx = torch.nn.functional.interpolate(
        masks_no_pad,
        size=(h, w),
        mode="bilinear",
        align_corners=False
    )
    
    # 4d. Binarize at the logit-zero threshold.
    masks_onnx = masks_onnx > 0.0
    masks_onnx = masks_onnx.numpy()
    
    # Print output shapes for a quick sanity check.
    print(f"\nOutput shapes:")
    print(f"PyTorch masks shape: {masks_pt.shape}")
    print(f"ONNX masks shape: {masks_onnx.shape}")
    
    # 5. Visualize both masks over the original image.
    plt.figure(figsize=(10, 5))
    
    # PyTorch result.
    plt.subplot(121)
    plt.imshow(orig_image)
    plt.imshow(masks_pt[0], alpha=0.5)
    plt.plot(input_point_orig[0][0], input_point_orig[0][1], 'rx')
    plt.title('PyTorch Output')
    plt.axis('off')
    
    # ONNX result.
    plt.subplot(122)
    plt.imshow(orig_image)
    plt.imshow(masks_onnx[0,0], alpha=0.5)
    plt.plot(input_point_orig[0][0], input_point_orig[0][1], 'rx')
    plt.title('ONNX Output')
    plt.axis('off')
    
    plt.tight_layout()
    plt.show()
    
    # 6. Print score statistics from both pipelines.
    print("\nStatistics:")
    print(f"PyTorch IoU scores: {iou_scores_pt}")
    print(f"ONNX IoU predictions: {iou_predictions}")

# Run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()