#!/usr/bin/env python
# coding: utf-8
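"""Convert SAM2 ONNX encoder/decoder models to RKNN format for the RK3588 NPU.

Expects <model_name>_encoder.onnx and <model_name>_decoder.onnx in the current
directory and writes the matching .rknn files next to them. Usage sketch (the
script file name here is assumed, not part of the original source):

    python convert_to_rknn.py sam2.1_hiera_tiny
"""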

import datetime
import argparse
from rknn.api import RKNN
from sys import exit
import os
import onnxslim

num_pointss = [1]
num_labelss = [1]
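# These lists enumerate the (num_labels, num_points) combinations that are
# registered as dynamic input shapes for the decoder below; extending them
# (e.g. num_pointss = [1, 2, 4]) adds more accepted prompt sizes.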

def convert_to_rknn(onnx_model, model_part, dataset="/home/zt/rk3588-nn/rknn_model_zoo/datasets/COCO/coco_subset_20.txt", quantize=False):
    """转换单个ONNX模型到RKNN格式"""
    rknn_model = onnx_model.replace(".onnx", ".rknn")
    timedate_iso = datetime.datetime.now().isoformat()
    
    print(f"\n开始转换 {onnx_model}{rknn_model}")

    # Only the decoder is given explicit (dynamic) input shapes; the encoder
    # keeps the static shape from its ONNX graph.
    input_shapes = None
    if model_part == "decoder":
        input_shapes = [
            [
                [1, 256, 64, 64],  # image_embedding
                [1, 32, 256, 256],  # high_res_feats_0
                [1, 64, 128, 128],  # high_res_feats_1
                [num_labels, num_points, 2],  # point_coords
                [num_labels, num_points],  # point_labels
                [num_labels, 1, 256, 256],  # mask_input
                [num_labels],  # has_mask_input
            ]
            for num_labels in num_labelss
            for num_points in num_pointss
        ]
    
    rknn = RKNN(verbose=True)
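    # Notes on the config below (hedged, based on general RKNN-Toolkit2 behavior,
    # not verified against this exact toolkit version):
    # - std_values=[[255, 255, 255]] rescales the encoder's 0-255 image input to
    #   roughly [0, 1]; the decoder receives raw float tensors, hence None.
    # - dynamic_input registers one shape set per (num_labels, num_points) pair.
    # - The quantized_* options only take effect when build() runs with
    #   do_quantization=True; by default this script builds float16 models.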
    rknn.config(
        dynamic_input=input_shapes,
        std_values=[[255,255,255]] if model_part == "encoder" else None,
        quantized_dtype='w8a8',
        quantized_algorithm='normal',
        quantized_method='channel',
        quantized_hybrid_level=0,
        target_platform='rk3588',
        quant_img_RGB2BGR=False,
        float_dtype='float16',
        optimization_level=3,
        custom_string=f"converted at {timedate_iso}",
        remove_weight=False,
        compress_weight=False,
        inputs_yuv_fmt=None,
        single_core_mode=False,
        model_pruning=False,
        op_target=None,
        quantize_weight=False,
        remove_reshape=False,
        sparse_infer=False,
        enable_flash_attention=False,
    )

    ret = rknn.load_onnx(model=onnx_model)
    if ret != 0:
        exit(f"load_onnx failed for {onnx_model} (ret={ret})")
    ret = rknn.build(do_quantization=quantize, dataset=dataset, rknn_batch_size=None)
    if ret != 0:
        exit(f"build failed for {onnx_model} (ret={ret})")
    ret = rknn.export_rknn(rknn_model)
    if ret != 0:
        exit(f"export_rknn failed for {rknn_model} (ret={ret})")
    rknn.release()
    print(f"Finished converting {rknn_model}\n")

def main():
    parser = argparse.ArgumentParser(description='Convert SAM models from ONNX to RKNN format')
    parser.add_argument('model_name', type=str, help='model name, e.g.: sam2.1_hiera_tiny')
    args = parser.parse_args()
    
    # Build the encoder and decoder ONNX file names
    encoder_onnx = f"{args.model_name}_encoder.onnx"
    decoder_onnx = f"{args.model_name}_decoder.onnx"
    
    # Check that the files exist
    for model in [encoder_onnx, decoder_onnx]:
        if not os.path.exists(model):
            print(f"错误: 找不到文件 {model}")
            exit(1)
    
    # Convert the encoder and decoder
    # The encoder has to be run through onnxslim first
    print("Converting encoder...")
    onnxslim.slim(encoder_onnx, output_model="encoder_slim.onnx", skip_fusion_patterns=["EliminationSlice"])
    convert_to_rknn("encoder_slim.onnx", model_part="encoder")
    os.rename("encoder_slim.rknn", encoder_onnx.replace(".onnx", ".rknn"))
    os.remove("encoder_slim.onnx")

    # convert_to_rknn(decoder_onnx, model_part="decoder")  # broken
    
    print("所有模型转换完成!")

if __name__ == "__main__":
    main()