|
|
|
|
|
|
|
import datetime |
|
import argparse |
|
from rknn.api import RKNN |
|
from sys import exit |
|
import os |
|
import onnxslim |
|
|
|
# Candidate sizes enumerated by convert_to_rknn to build the decoder's
# dynamic input shape sets: one shape set per (num_labels, num_points) pair.
num_pointss, num_labelss = [1], [1]
|
|
|
def _decoder_input_shapes():
    """Return the dynamic input shape sets used when converting the decoder.

    One shape set (a list of seven input shapes, in model-input order) is
    produced for every (num_labels, num_points) combination taken from the
    module-level ``num_labelss`` / ``num_pointss`` lists.
    """
    return [
        [
            [1, 256, 64, 64],
            [1, 32, 256, 256],
            [1, 64, 128, 128],
            [num_labels, num_points, 2],
            [num_labels, num_points],
            [num_labels, 1, 256, 256],
            [num_labels],
        ]
        for num_labels in num_labelss
        for num_points in num_pointss
    ]


def _check_rknn_ret(ret, step, model_path):
    """Raise RuntimeError when an RKNN toolkit call reports failure (non-zero)."""
    if ret != 0:
        raise RuntimeError(f"RKNN {step} failed (ret={ret}) for {model_path}")


def convert_to_rknn(onnx_model, model_part, dataset="/home/zt/rk3588-nn/rknn_model_zoo/datasets/COCO/coco_subset_20.txt", quantize=False):
    """Convert a single ONNX model to RKNN format.

    The result is written next to the input with the extension swapped for
    ``.rknn`` (e.g. ``encoder_slim.onnx`` -> ``encoder_slim.rknn``).

    Args:
        onnx_model: Path to the source ONNX model.
        model_part: ``"encoder"`` applies ``std_values=[[255, 255, 255]]``;
            ``"decoder"`` enables the dynamic input shapes from
            ``_decoder_input_shapes()``. Any other value uses neither.
        dataset: Calibration dataset list forwarded to ``RKNN.build``
            (used by the toolkit when ``quantize`` is true).
        quantize: Passed to ``RKNN.build`` as ``do_quantization``.

    Raises:
        RuntimeError: if ``load_onnx``, ``build`` or ``export_rknn`` returns a
            non-zero status.
    """
    # splitext only swaps the final extension; str.replace would rewrite every
    # ".onnx" in the path and return the input path unchanged (overwriting the
    # source on export) when the suffix is not ".onnx".
    rknn_model = os.path.splitext(onnx_model)[0] + ".rknn"
    timedate_iso = datetime.datetime.now().isoformat()

    print(f"\n开始转换 {onnx_model} 到 {rknn_model}")

    # Only the decoder is converted with dynamic inputs; the encoder is not.
    input_shapes = _decoder_input_shapes() if model_part == "decoder" else None

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            dynamic_input=input_shapes,
            std_values=[[255, 255, 255]] if model_part == "encoder" else None,
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
        )

        # RKNN API calls return 0 on success; stop at the first failure rather
        # than reporting a conversion that never produced a model.
        _check_rknn_ret(rknn.load_onnx(model=onnx_model), "load_onnx", onnx_model)
        _check_rknn_ret(
            rknn.build(do_quantization=quantize, dataset=dataset, rknn_batch_size=None),
            "build",
            onnx_model,
        )
        _check_rknn_ret(rknn.export_rknn(rknn_model), "export_rknn", rknn_model)
    finally:
        # Free the toolkit's resources even when a step fails.
        rknn.release()

    print(f"完成转换 {rknn_model}\n")
|
|
|
def main():
    """Command-line entry point: convert a SAM encoder ONNX model to RKNN.

    Takes one positional argument, ``model_name``, and expects both
    ``<model_name>_encoder.onnx`` and ``<model_name>_decoder.onnx`` to exist;
    exits with status 1 otherwise. The encoder is simplified with onnxslim into
    a temporary ``encoder_slim.onnx`` in the current directory, converted with
    ``convert_to_rknn``, and the result is saved as
    ``<model_name>_encoder.rknn``. The temporary ONNX file is always removed.

    NOTE(review): the decoder file is required to exist but is not converted
    here, even though ``convert_to_rknn`` supports ``model_part="decoder"``.
    Confirm this is intentional.
    """
    parser = argparse.ArgumentParser(description='转换SAM模型从ONNX到RKNN格式')
    parser.add_argument('model_name', type=str, help='模型名称,例如: sam2.1_hiera_tiny')
    args = parser.parse_args()

    encoder_onnx = f"{args.model_name}_encoder.onnx"
    decoder_onnx = f"{args.model_name}_decoder.onnx"
    encoder_rknn = f"{args.model_name}_encoder.rknn"

    for model in (encoder_onnx, decoder_onnx):
        if not os.path.exists(model):
            # sys.exit with a string prints it to stderr and exits with status 1.
            exit(f"错误: 找不到文件 {model}")

    print("开始转换encoder...")
    slim_onnx = "encoder_slim.onnx"
    # convert_to_rknn writes its output next to its input with a .rknn suffix.
    slim_rknn = "encoder_slim.rknn"

    # Drop leftover output from an earlier run so a failed conversion cannot
    # be mistaken for a fresh one when it is moved into place below.
    if os.path.exists(slim_rknn):
        os.remove(slim_rknn)

    try:
        onnxslim.slim(encoder_onnx, output_model=slim_onnx, skip_fusion_patterns=["EliminationSlice"])
        convert_to_rknn(slim_onnx, model_part="encoder")
        # os.replace overwrites an existing target on every platform (os.rename
        # fails on Windows if it exists), so re-running the script works.
        os.replace(slim_rknn, encoder_rknn)
    finally:
        # Remove the intermediate slimmed model even if conversion fails.
        if os.path.exists(slim_onnx):
            os.remove(slim_onnx)

    print("所有模型转换完成!")


if __name__ == "__main__":
    main()
|
|