# stevengrove
# initial commit
# 186701e
import argparse
import os
import sys
import warnings
from io import BytesIO
from pathlib import Path
import onnx
import torch
from mmdet.apis import init_detector
from mmengine.config import ConfigDict
from mmengine.logging import print_log
from mmengine.utils.path import mkdir_or_exist
# Add MMYOLO ROOT to sys.path
sys.path.append(str(Path(__file__).resolve().parents[3]))
from projects.easydeploy.model import DeployModel, MMYOLOBackend # noqa E402
# Ignore the torch.jit tracer/script warnings as well as User, Future and
# Resource warnings for the rest of the process.
_SILENCED_WARNING_CATEGORIES = (
    torch.jit.TracerWarning,
    torch.jit.ScriptWarning,
    UserWarning,
    FutureWarning,
    ResourceWarning,
)
for _warning_category in _SILENCED_WARNING_CATEGORIES:
    warnings.filterwarnings(action='ignore', category=_warning_category)
def parse_args():
    """Parse the command-line options of the ONNX export script.

    Returns:
        argparse.Namespace: The parsed options. ``img_size`` is always a
        two-element list; when a single value is given on the command line
        it is duplicated so that both height and width are set.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid input, including an
            ``--img-size`` with more than two values or with a
            non-positive value.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--model-only', action='store_true', help='Export model only')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Simplify onnx model by onnx-sim')
    parser.add_argument(
        '--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument(
        '--backend',
        type=str,
        default='onnxruntime',
        help='Backend for export onnx')
    parser.add_argument(
        '--pre-topk',
        type=int,
        default=1000,
        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument(
        '--keep-topk',
        type=int,
        default=100,
        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument(
        '--iou-threshold',
        type=float,
        default=0.65,
        help='IoU threshold for NMS')
    parser.add_argument(
        '--score-threshold',
        type=float,
        default=0.25,
        help='Score threshold for NMS')
    args = parser.parse_args()

    # ``nargs='+'`` accepts any number of values, but the size is later
    # unpacked as the two spatial dimensions of the fake input tensor, so
    # anything other than one or two positive values is rejected up front.
    if len(args.img_size) > 2:
        parser.error(
            '--img-size expects one value (square) or two values '
            f'(height width), got {len(args.img_size)}')
    if any(side <= 0 for side in args.img_size):
        parser.error('--img-size values must be positive integers')

    # A single value describes a square input: use it for both dimensions.
    if len(args.img_size) == 1:
        args.img_size = args.img_size * 2
    return args
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Create a detector via ``init_detector`` and switch it to eval mode.

    Args:
        config_path: Config file passed to ``init_detector``.
        checkpoint_path: Checkpoint file passed to ``init_detector``.
        device: Device passed to ``init_detector`` as ``device``.

    Returns:
        The object returned by ``init_detector`` after ``eval()`` has been
        called on it.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    # Put the model in evaluation mode before handing it to the caller.
    detector.eval()
    return detector
def main():
    """Export the detector given on the command line to an ONNX file.

    Steps: parse the arguments, decide whether post-processing (bbox decoder
    and NMS) is exported for the selected backend, wrap the detector in a
    ``DeployModel``, export it to ONNX in memory, validate it, optionally
    simplify it with ``onnxsim``, and save it into ``args.work_dir``.

    The output file is named after the checkpoint with its extension
    replaced by ``.onnx``.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)
    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Any other backend is exported without post-processing.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')
    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(
            pre_top_k=args.pre_topk,
            keep_top_k=args.keep_topk,
            iou_threshold=args.iou_threshold,
            score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']
    baseModel = build_model_from_cfg(args.config, args.checkpoint, args.device)
    deploy_model = DeployModel(
        baseModel=baseModel, backend=backend, postprocess_cfg=postprocess_cfg)
    deploy_model.eval()
    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # dry run
    deploy_model(fake_input)

    # Swap only the file extension. A plain ``str.replace('pth', 'onnx')``
    # would also rewrite a 'pth' inside the name itself (e.g. 'depth.pth')
    # and would leave checkpoints with another extension unchanged.
    checkpoint_stem = os.path.splitext(os.path.basename(args.checkpoint))[0]
    save_onnx_path = os.path.join(args.work_dir, checkpoint_stem + '.onnx')

    # export onnx into an in-memory buffer, then load it back for checking
    with BytesIO() as f:
        torch.onnx.export(
            deploy_model,
            fake_input,
            f,
            input_names=['images'],
            output_names=output_names,
            opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
        onnx.checker.check_model(onnx_model)
        # Fix tensorrt onnx output shape, just for view
        if not args.model_only and backend in (MMYOLOBackend.TENSORRT8,
                                               MMYOLOBackend.TENSORRT7):
            # Flattened dim labels, consumed in graph-output order.
            # Presumably num_dets (B, 1), boxes (B, K, 4), scores (B, K)
            # and labels (B, K) -- confirm against DeployModel's outputs.
            shapes = [
                args.batch_size, 1, args.batch_size, args.keep_topk, 4,
                args.batch_size, args.keep_topk, args.batch_size,
                args.keep_topk
            ]
            for i in onnx_model.graph.output:
                for j in i.type.tensor_type.shape.dim:
                    j.dim_param = str(shapes.pop(0))
    if args.simplify:
        # Simplification is best effort: on failure the unsimplified model
        # is still saved and the reason is logged.
        try:
            import onnxsim
            onnx_model, check = onnxsim.simplify(onnx_model)
            assert check, 'assert check failed'
        except Exception as e:
            print_log(f'Simplify failure: {e}')
    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')
# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()