# Copyright (c) Tencent Inc. All rights reserved.
import argparse
import logging
import os
import os.path as osp
from functools import partial

import mmengine
import torch.multiprocessing as mp
from torch.multiprocessing import Process, set_start_method

from mmdeploy.apis import (create_calib_input_data, extract_model,
                           get_predefined_partition_cfg, torch2onnx,
                           torch2torchscript, visualize_model)
from mmdeploy.apis.core import PIPELINE_MANAGER
from mmdeploy.apis.utils import to_backend
from mmdeploy.backend.sdk.export_info import export2SDK
from mmdeploy.utils import (IR, Backend, get_backend, get_calib_filename,
                            get_ir_config, get_partition_config,
                            get_root_logger, load_config, target_wrapper)


def parse_args():
    parser = argparse.ArgumentParser(description='Export model to backends.')
    parser.add_argument('deploy_cfg', help='deploy config path')
    parser.add_argument('model_cfg', help='model config path')
    parser.add_argument('checkpoint', help='model checkpoint path')
    parser.add_argument('img', help='image used to convert model')
    parser.add_argument(
        '--test-img',
        default=None,
        type=str,
        nargs='+',
        help='image used to test model')
    parser.add_argument(
        '--work-dir',
        default=os.getcwd(),
        help='the dir to save logs and models')
    parser.add_argument(
        '--calib-dataset-cfg',
        help='dataset config path used to calibrate in int8 mode. If not \
            specified, it will use "val" dataset in model config instead.',
        default=None)
    parser.add_argument(
        '--device', help='device used for conversion', default='cpu')
    parser.add_argument(
        '--log-level',
        help='set log level',
        default='INFO',
        choices=list(logging._nameToLevel.keys()))
    parser.add_argument(
        '--show', action='store_true', help='Show detection outputs')
    parser.add_argument(
        '--dump-info', action='store_true', help='Output information for SDK')
    parser.add_argument(
        '--quant-image-dir',
        default=None,
        help='Image directory for quantizing the model.')
    parser.add_argument(
        '--quant', action='store_true', help='Quantize model to low bit.')
    parser.add_argument(
        '--uri',
        default='192.168.1.1:60000',
        help='Remote ipv4:port or ipv6:port for inference on edge device.')
    args = parser.parse_args()
    return args


def create_process(name, target, args, kwargs, ret_value=None):
    logger = get_root_logger()
    logger.info(f'{name} start.')
    log_level = logger.level
    wrap_func = partial(target_wrapper, target, log_level, ret_value)

    process = Process(target=wrap_func, args=args, kwargs=kwargs)
    process.start()
    process.join()

    if ret_value is not None:
        if ret_value.value != 0:
            logger.error(f'{name} failed.')
            exit(1)
        else:
            logger.info(f'{name} success.')


def torch2ir(ir_type: IR):
    """Return the conversion function from torch to the intermediate
    representation.

    Args:
        ir_type (IR): The type of the intermediate representation.
""" if ir_type == IR.ONNX: return torch2onnx elif ir_type == IR.TORCHSCRIPT: return torch2torchscript else: raise KeyError(f'Unexpected IR type {ir_type}') def main(): args = parse_args() set_start_method('spawn', force=True) logger = get_root_logger() log_level = logging.getLevelName(args.log_level) logger.setLevel(log_level) pipeline_funcs = [ torch2onnx, torch2torchscript, extract_model, create_calib_input_data ] PIPELINE_MANAGER.enable_multiprocess(True, pipeline_funcs) PIPELINE_MANAGER.set_log_level(log_level, pipeline_funcs) deploy_cfg_path = args.deploy_cfg model_cfg_path = args.model_cfg checkpoint_path = args.checkpoint quant = args.quant quant_image_dir = args.quant_image_dir # load deploy_cfg deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path) # create work_dir if not mmengine.mkdir_or_exist(osp.abspath(args.work_dir)) if args.dump_info: export2SDK( deploy_cfg, model_cfg, args.work_dir, pth=checkpoint_path, device=args.device) ret_value = mp.Value('d', 0, lock=False) # convert to IR ir_config = get_ir_config(deploy_cfg) ir_save_file = ir_config['save_file'] ir_type = IR.get(ir_config['type']) torch2ir(ir_type)( args.img, args.work_dir, ir_save_file, deploy_cfg_path, model_cfg_path, checkpoint_path, device=args.device) # convert backend ir_files = [osp.join(args.work_dir, ir_save_file)] # partition model partition_cfgs = get_partition_config(deploy_cfg) if partition_cfgs is not None: if 'partition_cfg' in partition_cfgs: partition_cfgs = partition_cfgs.get('partition_cfg', None) else: assert 'type' in partition_cfgs partition_cfgs = get_predefined_partition_cfg( deploy_cfg, partition_cfgs['type']) origin_ir_file = ir_files[0] ir_files = [] for partition_cfg in partition_cfgs: save_file = partition_cfg['save_file'] save_path = osp.join(args.work_dir, save_file) start = partition_cfg['start'] end = partition_cfg['end'] dynamic_axes = partition_cfg.get('dynamic_axes', None) extract_model( origin_ir_file, start, end, dynamic_axes=dynamic_axes, save_file=save_path) ir_files.append(save_path) # calib data calib_filename = get_calib_filename(deploy_cfg) if calib_filename is not None: calib_path = osp.join(args.work_dir, calib_filename) create_calib_input_data( calib_path, deploy_cfg_path, model_cfg_path, checkpoint_path, dataset_cfg=args.calib_dataset_cfg, dataset_type='val', device=args.device) backend_files = ir_files # convert backend backend = get_backend(deploy_cfg) # preprocess deploy_cfg if backend == Backend.RKNN: # TODO: Add this to task_processor in the future import tempfile from mmdeploy.utils import (get_common_config, get_normalization, get_quantization_config, get_rknn_quantization) quantization_cfg = get_quantization_config(deploy_cfg) common_params = get_common_config(deploy_cfg) if get_rknn_quantization(deploy_cfg) is True: transform = get_normalization(model_cfg) common_params.update( dict( mean_values=[transform['mean']], std_values=[transform['std']])) dataset_file = tempfile.NamedTemporaryFile(suffix='.txt').name with open(dataset_file, 'w') as f: f.writelines([osp.abspath(args.img)]) if quantization_cfg.get('dataset', None) is None: quantization_cfg['dataset'] = dataset_file if backend == Backend.ASCEND: # TODO: Add this to backend manager in the future if args.dump_info: from mmdeploy.backend.ascend import update_sdk_pipeline update_sdk_pipeline(args.work_dir) if backend == Backend.VACC: # TODO: Add this to task_processor in the future from onnx2vacc_quant_dataset import get_quant from mmdeploy.utils import get_model_inputs deploy_cfg, model_cfg 
        deploy_cfg, model_cfg = load_config(deploy_cfg_path, model_cfg_path)
        model_inputs = get_model_inputs(deploy_cfg)

        for onnx_path, model_input in zip(ir_files, model_inputs):
            quant_mode = model_input.get('qconfig', {}).get('dtype', 'fp16')
            assert quant_mode in ['int8',
                                  'fp16'], quant_mode + ' not supported now'
            shape_dict = model_input.get('shape', {})

            if quant_mode == 'int8':
                create_process(
                    'vacc quant dataset',
                    target=get_quant,
                    args=(deploy_cfg, model_cfg, shape_dict, checkpoint_path,
                          args.work_dir, args.device),
                    kwargs=dict(),
                    ret_value=ret_value)

    # convert to backend
    PIPELINE_MANAGER.set_log_level(log_level, [to_backend])
    if backend == Backend.TENSORRT:
        PIPELINE_MANAGER.enable_multiprocess(True, [to_backend])
    backend_files = to_backend(
        backend,
        ir_files,
        work_dir=args.work_dir,
        deploy_cfg=deploy_cfg,
        log_level=log_level,
        device=args.device,
        uri=args.uri)

    # ncnn quantization
    if backend == Backend.NCNN and quant:
        from onnx2ncnn_quant_table import get_table

        from mmdeploy.apis.ncnn import get_quant_model_file, ncnn2int8
        model_param_paths = backend_files[::2]
        model_bin_paths = backend_files[1::2]
        backend_files = []
        for onnx_path, model_param_path, model_bin_path in zip(
                ir_files, model_param_paths, model_bin_paths):
            deploy_cfg, model_cfg = load_config(deploy_cfg_path,
                                                model_cfg_path)
            quant_onnx, quant_table, quant_param, quant_bin = get_quant_model_file(  # noqa: E501
                onnx_path, args.work_dir)

            create_process(
                'ncnn quant table',
                target=get_table,
                args=(onnx_path, deploy_cfg, model_cfg, quant_onnx,
                      quant_table, quant_image_dir, args.device),
                kwargs=dict(),
                ret_value=ret_value)

            create_process(
                'ncnn_int8',
                target=ncnn2int8,
                args=(model_param_path, model_bin_path, quant_table,
                      quant_param, quant_bin),
                kwargs=dict(),
                ret_value=ret_value)
            backend_files += [quant_param, quant_bin]

    if args.test_img is None:
        args.test_img = args.img

    extra = dict(
        backend=backend,
        output_file=osp.join(args.work_dir, f'output_{backend.value}.jpg'),
        show_result=args.show)
    if backend == Backend.SNPE:
        extra['uri'] = args.uri

    # get backend inference result, try render
    create_process(
        f'visualize {backend.value} model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, backend_files, args.test_img,
              args.device),
        kwargs=extra,
        ret_value=ret_value)

    # get pytorch model inference result, try visualize if possible
    create_process(
        'visualize pytorch model',
        target=visualize_model,
        args=(model_cfg_path, deploy_cfg_path, [checkpoint_path],
              args.test_img, args.device),
        kwargs=dict(
            backend=Backend.PYTORCH,
            output_file=osp.join(args.work_dir, 'output_pytorch.jpg'),
            show_result=args.show),
        ret_value=ret_value)
    logger.info('All processes succeeded.')


if __name__ == '__main__':
    main()
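
# Example invocation (a minimal sketch; the config, checkpoint, and image
# paths below are illustrative placeholders, not files shipped with this
# script):
#
#   python deploy.py \
#       <deploy_cfg>.py <model_cfg>.py <checkpoint>.pth <image>.jpg \
#       --work-dir work_dirs/deploy --device cuda:0 --dump-info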