# OpenVoice-RKNN2 / convert_rknn.py
# happyme531's picture
# Upload 10 files
# cc403c3 verified
#!/usr/bin/env python
# coding: utf-8
import datetime
import argparse
from rknn.api import RKNN
from sys import exit
# Model configuration: maps each supported model type to the ONNX file
# that convert_model() loads for it.
MODELS = {
    'tone_clone': 'tone_clone_model.onnx',
    'tone_color_extract': 'tone_color_extract_model.onnx',
}
# Values substituted into the variable dimension of the 'tone_clone' audio
# input; convert_model() generates one set of input shapes per entry.
TARGET_AUDIO_LENS = [1024]
# Values substituted into the variable dimension of the
# 'tone_color_extract' audio input; one set of input shapes per entry.
SOURCE_AUDIO_LENS = [1024]
# NOTE(review): presumably the fixed (non-length) dimension of the audio
# inputs -- the audio shapes in convert_model() use the same value, 513.
AUDIO_DIM = 513
# NOTE(review): convert_model() builds with do_quantization=False; confirm
# whether this flag is meant to control that.
QUANTIZE=False
# NOTE(review): no consumer of this flag is visible in this part of the file.
detailed_performance_log = True
def convert_model(model_type):
    """Convert one model listed in ``MODELS`` from ONNX to an RKNN file.

    The ONNX file named by ``MODELS[model_type]`` is loaded, built without
    quantization for the ``rk3588`` target, and exported to the same path
    with its extension replaced by ``.rknn``.

    Args:
        model_type: Key into ``MODELS`` (``'tone_clone'`` or
            ``'tone_color_extract'``).

    Returns:
        ``True`` if loading, building and exporting all returned 0.
        ``False`` if the model type is unsupported, has no input shapes
        defined here, or any of those steps failed (a message is printed).
    """
    # Function-scope import so this block does not depend on a new
    # file-level import.
    from os.path import splitext

    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    onnx_model = MODELS[model_type]
    # Swap only the final extension; a plain str.replace would touch every
    # ".onnx" occurrence and would return the input path unchanged (so the
    # export would overwrite it) when the suffix is missing.
    rknn_model = splitext(onnx_model)[0] + ".rknn"

    # One list of per-input shapes for every configured audio length; the
    # result is handed to rknn.config() as dynamic_input.
    if model_type == 'tone_clone':
        shapes = [
            [
                [1, AUDIO_DIM, target_audio_len],  # audio
                [1],                               # audio_length
                [1, 256, 1],                       # src_tone
                [1, 256, 1],                       # dest_tone
                [1],                               # tau
            ]
            for target_audio_len in TARGET_AUDIO_LENS
        ]
    elif model_type == 'tone_color_extract':
        shapes = [
            [
                [1, source_audio_len, AUDIO_DIM],  # audio
            ]
            for source_audio_len in SOURCE_AUDIO_LENS
        ]
    else:
        # A key was added to MODELS without a matching shape definition;
        # fail cleanly instead of raising UnboundLocalError below.
        print(f"错误: 未定义模型类型 {model_type} 的输入形状")
        return False

    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted by: qq: 232004040, email: 2302004040@qq.com at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=shapes,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
            # NOTE: previously tried option, left disabled:
            # disable_rules=['convert_gemm_by_exmatmul']
        )

        print(f"开始转换 {model_type} 模型...")

        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # Quantization is switched off for this build.
        ret = rknn.build(do_quantization=False, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Free the toolkit's resources on every exit path, including the
        # early failure returns above.
        rknn.release()
def main():
    """Parse the command line and convert the requested model(s).

    The optional positional argument ``model_type`` selects a single model
    or ``all`` (the default). In ``all`` mode every entry of ``MODELS`` is
    attempted even if an earlier one fails.

    Raises:
        SystemExit: With status 1 if at least one conversion failed, so
            that shells and CI jobs can detect the failure. argparse also
            raises it for invalid arguments.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', 'tone_clone', 'tone_color_extract'],
                        help='要转换的模型类型 (默认: all)')
    args = parser.parse_args()

    # Either every configured model or just the one that was requested.
    if args.model_type == 'all':
        requested = list(MODELS)
    else:
        requested = [args.model_type]

    any_failed = False
    for model_type in requested:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            any_failed = True

    if any_failed:
        # Report the failure through the process exit status.
        exit(1)
# Run the converter only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()