# Copyright (c) OpenMMLab. All rights reserved.
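"""Convert an MMDetection ONNX model to a TensorRT engine with mmcv.tensorrt.

With ``--verify`` the script additionally runs the same example image through
the ONNXRuntime and TensorRT wrappers and checks that their outputs match
numerically.
"""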
import argparse
import os
import os.path as osp
import warnings

import numpy as np
import onnx
import torch
from mmcv import Config
from mmcv.tensorrt import is_tensorrt_plugin_loaded, onnx2trt, save_trt_engine

from mmdet.core.export import preprocess_example_input
from mmdet.core.export.model_wrappers import (ONNXRuntimeDetector,
                                              TensorRTDetector)
from mmdet.datasets import DATASETS


def get_GiB(x: int):
    """Return x GiB expressed in bytes."""
    return x * (1 << 30)


def onnx2tensorrt(onnx_file,
                  trt_file,
                  input_config,
                  verify=False,
                  show=False,
                  workspace_size=1,
                  verbose=False):
    import tensorrt as trt
    onnx_model = onnx.load(onnx_file)
    max_shape = input_config['max_shape']
    min_shape = input_config['min_shape']
    opt_shape = input_config['opt_shape']
    fp16_mode = False
    # create trt engine and wrapper
    opt_shape_dict = {'input': [min_shape, opt_shape, max_shape]}
    max_workspace_size = get_GiB(workspace_size)
    trt_engine = onnx2trt(
        onnx_model,
        opt_shape_dict,
        log_level=trt.Logger.VERBOSE if verbose else trt.Logger.ERROR,
        fp16_mode=fp16_mode,
        max_workspace_size=max_workspace_size)
    save_dir, _ = osp.split(trt_file)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    save_trt_engine(trt_engine, trt_file)
    print(f'Successfully created TensorRT engine: {trt_file}')

    if verify:
        # prepare input
        one_img, one_meta = preprocess_example_input(input_config)
        img_list, img_meta_list = [one_img], [[one_meta]]
        img_list = [_.cuda().contiguous() for _ in img_list]

        # wrap ONNX and TensorRT model
        onnx_model = ONNXRuntimeDetector(onnx_file, CLASSES, device_id=0)
        trt_model = TensorRTDetector(trt_file, CLASSES, device_id=0)

        # inference with wrapped model
        with torch.no_grad():
            onnx_results = onnx_model(
                img_list, img_metas=img_meta_list, return_loss=False)[0]
            trt_results = trt_model(
                img_list, img_metas=img_meta_list, return_loss=False)[0]

        if show:
            out_file_ort, out_file_trt = None, None
        else:
            out_file_ort, out_file_trt = 'show-ort.png', 'show-trt.png'
        show_img = one_meta['show_img']
        score_thr = 0.3
        onnx_model.show_result(
            show_img,
            onnx_results,
            score_thr=score_thr,
            show=True,
            win_name='ONNXRuntime',
            out_file=out_file_ort)
        trt_model.show_result(
            show_img,
            trt_results,
            score_thr=score_thr,
            show=True,
            win_name='TensorRT',
            out_file=out_file_trt)

        with_mask = trt_model.with_masks
        # compare a part of result
        if with_mask:
            compare_pairs = list(zip(onnx_results, trt_results))
        else:
            compare_pairs = [(onnx_results, trt_results)]
        err_msg = 'The numerical values are different between ONNXRuntime' + \
                  ' and TensorRT, but it does not necessarily mean the' + \
                  ' TensorRT engine is problematic.'
        # check the numerical value
        for onnx_res, trt_res in compare_pairs:
            for o_res, t_res in zip(onnx_res, trt_res):
                np.testing.assert_allclose(
                    o_res, t_res, rtol=1e-03, atol=1e-05, err_msg=err_msg)
        print('The numerical values are the same between '
              'ONNXRuntime and TensorRT')


def parse_normalize_cfg(test_pipeline):
    """Return the `Normalize` transform config from the test pipeline."""
    transforms = None
    for pipeline in test_pipeline:
        if 'transforms' in pipeline:
            transforms = pipeline['transforms']
            break
    assert transforms is not None, 'Failed to find `transforms`'
    norm_config_li = [_ for _ in transforms if _['type'] == 'Normalize']
    assert len(norm_config_li) == 1, '`norm_config` should only have one'
    norm_config = norm_config_li[0]
    return norm_config


def parse_args():
    parser = argparse.ArgumentParser(
        description='Convert MMDetection models from ONNX to TensorRT')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('model', help='Filename of input ONNX model')
    parser.add_argument(
        '--trt-file',
        type=str,
        default='tmp.trt',
        help='Filename of output TensorRT engine')
    parser.add_argument(
        '--input-img', type=str, default='', help='Image for test')
    parser.add_argument(
        '--show', action='store_true', help='Whether to show output results')
    parser.add_argument(
        '--dataset',
        type=str,
        default='coco',
        help='Dataset name. This argument is deprecated and will be \
            removed in future releases.')
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Verify the outputs of ONNXRuntime and TensorRT')
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether to output verbose logging messages while creating \
            the TensorRT engine. Defaults to False.')
    parser.add_argument(
        '--to-rgb',
        action='store_false',
        help='Feed model with RGB or BGR image. Default is RGB. This \
            argument is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[400, 600],
        help='Input size of the model')
    parser.add_argument(
        '--mean',
        type=float,
        nargs='+',
        default=[123.675, 116.28, 103.53],
        help='Mean value used to preprocess input data. This argument \
            is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--std',
        type=float,
        nargs='+',
        default=[58.395, 57.12, 57.375],
        help='Standard deviation used to preprocess input data. This \
            argument is deprecated and will be removed in future releases.')
    parser.add_argument(
        '--min-shape',
        type=int,
        nargs='+',
        default=None,
        help='Minimum input size of the model in TensorRT')
    parser.add_argument(
        '--max-shape',
        type=int,
        nargs='+',
        default=None,
        help='Maximum input size of the model in TensorRT')
    parser.add_argument(
        '--workspace-size',
        type=int,
        default=1,
        help='Max workspace size in GiB')
    args = parser.parse_args()
    return args


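# Example invocation (the config and model paths below are illustrative; use
# the config that corresponds to the exported ONNX model):
#   python onnx2tensorrt.py path/to/config.py path/to/model.onnx \
#       --trt-file model.trt --shape 400 600 --verify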
if __name__ == '__main__':

    assert is_tensorrt_plugin_loaded(), 'TensorRT plugin should be compiled.'
    args = parse_args()
    warnings.warn(
        'Arguments like `--to-rgb`, `--mean`, `--std`, and `--dataset` are '
        'now parsed directly from the config file; they are deprecated and '
        'will be removed in future releases.')
    if not args.input_img:
        args.input_img = osp.join(osp.dirname(__file__), '../../demo/demo.jpg')

    cfg = Config.fromfile(args.config)

    def parse_shape(shape):
        if len(shape) == 1:
            shape = (1, 3, shape[0], shape[0])
        elif len(shape) == 2:
            shape = (1, 3) + tuple(shape)
        else:
            raise ValueError('invalid input shape')
        return shape

    if args.shape:
        input_shape = parse_shape(args.shape)
    else:
        img_scale = cfg.test_pipeline[1]['img_scale']
        input_shape = (1, 3, img_scale[1], img_scale[0])

    if not args.max_shape:
        max_shape = input_shape
    else:
        max_shape = parse_shape(args.max_shape)

    if not args.min_shape:
        min_shape = input_shape
    else:
        min_shape = parse_shape(args.min_shape)

    dataset = DATASETS.get(cfg.data.test['type'])
    assert dataset is not None
    CLASSES = dataset.CLASSES
    normalize_cfg = parse_normalize_cfg(cfg.test_pipeline)

    input_config = {
        'min_shape': min_shape,
        'opt_shape': input_shape,
        'max_shape': max_shape,
        'input_shape': input_shape,
        'input_path': args.input_img,
        'normalize_cfg': normalize_cfg
    }
    # Create TensorRT engine
    onnx2tensorrt(
        args.model,
        args.trt_file,
        input_config,
        verify=args.verify,
        show=args.show,
        workspace_size=args.workspace_size,
        verbose=args.verbose)

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)