| |
| |
|
|
| import datetime |
| import argparse |
| from rknn.api import RKNN |
| from sys import exit |
|
|
| |
# ONNX source file for each convertible model type, keyed by the type name
# accepted on the command line. Dict order is the order used for 'all'.
MODELS = dict(
    tone_clone='tone_clone_model.onnx',
    tone_color_extract='tone_color_extract_model.onnx',
)

# Values for the variable-length axis of the audio inputs; one dynamic input
# shape set is generated per entry.
TARGET_AUDIO_LENS = [1024]  # used for the 'tone_clone' model
SOURCE_AUDIO_LENS = [1024]  # used for the 'tone_color_extract' model

# Size of the fixed 513-wide axis that appears in both models' audio input shapes.
AUDIO_DIM = 513

# Quantization toggle.
QUANTIZE = False
# Detailed performance logging flag (not read by the code in this file).
detailed_performance_log = True
|
|
def _dynamic_input_shapes(model_type):
    """Return the RKNN ``dynamic_input`` shape sets for *model_type*, or None.

    Each returned entry is one complete set of input shapes (a list holding one
    shape per model input); one set is produced per configured audio length.
    Returns None when no shape specification exists for *model_type*.
    """
    if model_type == 'tone_clone':
        return [
            [
                [1, AUDIO_DIM, target_audio_len],
                [1],
                [1, 256, 1],
                [1, 256, 1],
                [1],
            ]
            for target_audio_len in TARGET_AUDIO_LENS
        ]
    if model_type == 'tone_color_extract':
        return [
            [
                [1, source_audio_len, AUDIO_DIM],
            ]
            for source_audio_len in SOURCE_AUDIO_LENS
        ]
    return None


def convert_model(model_type):
    """Convert the ONNX model registered under *model_type* to RKNN format.

    The input file name is looked up in ``MODELS``; the output file name is the
    same with ".onnx" replaced by ".rknn".

    Args:
        model_type: A key of ``MODELS``.

    Returns:
        True on success; False if the type is unknown, has no input-shape
        specification, or any RKNN step (load / build / export) returns a
        non-zero status.
    """
    if model_type not in MODELS:
        print(f"错误: 不支持的模型类型 {model_type}")
        return False

    shapes = _dynamic_input_shapes(model_type)
    if shapes is None:
        # Without this guard an unmatched type left `shapes` unbound (NameError).
        print(f"错误: 模型类型 {model_type} 未定义输入形状")
        return False

    onnx_model = MODELS[model_type]
    rknn_model = onnx_model.replace(".onnx", ".rknn")

    timedate_iso = datetime.datetime.now().isoformat()

    rknn = RKNN(verbose=True)
    try:
        rknn.config(
            quantized_dtype='w8a8',
            quantized_algorithm='normal',
            quantized_method='channel',
            quantized_hybrid_level=0,
            target_platform='rk3588',
            quant_img_RGB2BGR=False,
            float_dtype='float16',
            optimization_level=3,
            custom_string=f"converted by: qq: 232004040, email: 2302004040@qq.com at {timedate_iso}",
            remove_weight=False,
            compress_weight=False,
            inputs_yuv_fmt=None,
            single_core_mode=False,
            dynamic_input=shapes,
            model_pruning=False,
            op_target=None,
            quantize_weight=False,
            remove_reshape=False,
            sparse_infer=False,
            enable_flash_attention=False,
        )

        print(f"开始转换 {model_type} 模型...")
        ret = rknn.load_onnx(model=onnx_model)
        if ret != 0:
            print("加载ONNX模型失败")
            return False

        # NOTE(review): quantized builds normally need a calibration `dataset=`;
        # none is passed here, so enabling QUANTIZE likely requires adding one.
        ret = rknn.build(do_quantization=QUANTIZE, rknn_batch_size=None)
        if ret != 0:
            print("构建RKNN模型失败")
            return False

        ret = rknn.export_rknn(rknn_model)
        if ret != 0:
            print("导出RKNN模型失败")
            return False

        print(f"成功转换模型: {rknn_model}")
        return True
    finally:
        # Release the toolkit object on every exit path, including early
        # failures; previously it was never released.
        rknn.release()
|
|
def main():
    """Command-line entry point: convert one model type, or every type in MODELS.

    Returns:
        Process exit status: 0 if every requested conversion succeeded, 1 if
        any of them failed.
    """
    parser = argparse.ArgumentParser(description='转换ONNX模型到RKNN格式')
    # Derive the choices from MODELS so new model types are accepted automatically.
    parser.add_argument('model_type', nargs='?', default='all',
                        choices=['all', *MODELS],
                        help='要转换的模型类型 (默认: all)')

    args = parser.parse_args()

    targets = list(MODELS) if args.model_type == 'all' else [args.model_type]

    failed = False
    for model_type in targets:
        if not convert_model(model_type):
            print(f"转换 {model_type} 失败")
            failed = True
    return 1 if failed else 0


if __name__ == '__main__':
    # Propagate failure to the shell instead of always exiting with status 0.
    exit(main())
|
|
|
|
|
|