File size: 3,341 Bytes
cc403c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#!/usr/bin/env python
# coding: utf-8

import datetime
import argparse
from rknn.api import RKNN
from sys import exit

# Model configuration: maps each supported model type to its ONNX file name.
MODELS = dict(
    tone_clone='tone_clone_model.onnx',
    tone_color_extract='tone_color_extract_model.onnx',
)

# Lengths substituted into the variable axis of the tone_clone "audio" input;
# one dynamic-shape set is produced per entry.
TARGET_AUDIO_LENS = [1024]
# Same role for the tone_color_extract "audio" input.
SOURCE_AUDIO_LENS = [1024]

# Width of the fixed (non-variable) axis of the audio inputs.
AUDIO_DIM = 513

# Feature flags.
# NOTE(review): their consumers are not visible here — confirm where they are read.
QUANTIZE = False
detailed_performance_log = True

def _dynamic_input_shapes(model_type):
    """Return the dynamic-input shape sets for *model_type*.

    One shape set (a list holding one shape per model input) is generated for
    every configured audio length. Returns ``None`` for a model type that has
    no shape definition here, so no dynamic-input configuration is passed.
    """
    if model_type == 'tone_clone':
        return [
            [
                [1, AUDIO_DIM, target_audio_len],  # audio
                [1],  # audio_length
                [1, 256, 1],  # src_tone
                [1, 256, 1],  # dest_tone
                [1],  # tau
            ]
            for target_audio_len in TARGET_AUDIO_LENS
        ]
    if model_type == 'tone_color_extract':
        return [
            [
                [1, source_audio_len, AUDIO_DIM],  # audio
            ]
            for source_audio_len in SOURCE_AUDIO_LENS
        ]
    return None


def convert_model(model_type):
    """Convert the ONNX model registered under *model_type* to RKNN format.

    The output file name is the ONNX file name with ``.onnx`` replaced by
    ``.rknn``.

    Args:
        model_type: A key of ``MODELS``.

    Returns:
        ``True`` if the model was loaded, built and exported successfully;
        ``False`` if the model type is unknown or any of those steps reported
        a non-zero status.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")
    shapes = _dynamic_input_shapes(model_type)

    # The conversion timestamp is embedded in the model's custom string.
    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    # Release the toolkit instance on every exit path (success, failure or
    # exception) so that repeated conversions do not accumulate resources.
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted by: qq: 232004040, email: 2302004040@qq.com at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=shapes,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # Quantization is disabled for this build.
        ret = rknn.build(do_quantization=False, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        rknn.release()

def main():
    """Parse the command line and convert the requested model(s).

    The optional positional argument selects one model type, or ``all``
    (the default) to convert every entry of ``MODELS``. When converting
    several models, a failure does not stop the remaining conversions.

    Raises:
        SystemExit: With status 1 if at least one conversion failed, so that
            calling scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        # Derived from MODELS so the two cannot drift apart.
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')

    args = parser.parse_args()

    if args.model_type == 'all':
        requested = list(MODELS)
    else:
        requested = [args.model_type]

    failed = []
    for model_type in requested:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            failed.append(model_type)

    if failed:
        # Signal failure through the exit status instead of only printing.
        raise SystemExit(1)


if __name__ == '__main__':
    main()