Spaces:
Running
Running
import argparse | |
import os | |
import sys | |
import warnings | |
from io import BytesIO | |
from pathlib import Path | |
import onnx | |
import torch | |
from mmdet.apis import init_detector | |
from mmengine.config import ConfigDict | |
from mmengine.logging import print_log | |
from mmengine.utils.path import mkdir_or_exist | |
# Make the MMYOLO repository root (three levels above this file's directory)
# importable so that the ``projects.easydeploy`` package below resolves.
sys.path.append(str(Path(__file__).resolve().parents[3]))
from projects.easydeploy.model import DeployModel, MMYOLOBackend  # noqa E402

# Silence warning categories that tracing/exporting tends to emit; each
# category gets its own 'ignore' filter, installed in the listed order.
for _category in (torch.jit.TracerWarning, torch.jit.ScriptWarning,
                  UserWarning, FutureWarning, ResourceWarning):
    warnings.filterwarnings(action='ignore', category=_category)
del _category
def parse_args(argv=None):
    """Parse the command-line options of the ONNX export script.

    Args:
        argv (list[str], optional): Argument strings to parse. Defaults to
            ``None``, in which case argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options. ``img_size`` is always a
        two-element ``[height, width]`` list; a single ``--img-size`` value
        is used for both dimensions.

    Raises:
        SystemExit: Via ``parser.error`` when ``--img-size`` is given
            anything other than one or two values.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('config', help='Config file')
    parser.add_argument('checkpoint', help='Checkpoint file')
    parser.add_argument(
        '--model-only', action='store_true', help='Export model only')
    parser.add_argument(
        '--work-dir', default='./work_dir', help='Path to save export model')
    parser.add_argument(
        '--img-size',
        nargs='+',
        type=int,
        default=[640, 640],
        help='Image size of height and width')
    parser.add_argument('--batch-size', type=int, default=1, help='Batch size')
    parser.add_argument(
        '--device', default='cuda:0', help='Device used for inference')
    parser.add_argument(
        '--simplify',
        action='store_true',
        help='Simplify onnx model by onnx-sim')
    parser.add_argument(
        '--opset', type=int, default=11, help='ONNX opset version')
    parser.add_argument(
        '--backend',
        type=str,
        default='onnxruntime',
        help='Backend for export onnx')
    parser.add_argument(
        '--pre-topk',
        type=int,
        default=1000,
        help='Postprocess pre topk bboxes feed into NMS')
    parser.add_argument(
        '--keep-topk',
        type=int,
        default=100,
        help='Postprocess keep topk bboxes out of NMS')
    parser.add_argument(
        '--iou-threshold',
        type=float,
        default=0.65,
        help='IoU threshold for NMS')
    parser.add_argument(
        '--score-threshold',
        type=float,
        default=0.25,
        help='Score threshold for NMS')
    args = parser.parse_args(argv)
    # img_size is later splatted into the (batch, 3, H, W) dummy input, so
    # it must end up with exactly two entries.
    if len(args.img_size) == 1:
        # A single value means a square input.
        args.img_size = args.img_size * 2
    elif len(args.img_size) != 2:
        parser.error('--img-size expects one value (square) or two values '
                     f'(height width), got {len(args.img_size)}')
    return args
def build_model_from_cfg(config_path, checkpoint_path, device):
    """Instantiate a detector via ``init_detector`` and switch it to eval.

    Args:
        config_path: Config file passed through to ``init_detector``.
        checkpoint_path: Checkpoint file passed through to ``init_detector``.
        device: Device string forwarded as ``device=``.

    Returns:
        The model returned by ``init_detector``, after ``.eval()`` was called.
    """
    detector = init_detector(config_path, checkpoint_path, device=device)
    detector.eval()
    return detector
def _simplify_onnx_model(onnx_model):
    """Try to simplify ``onnx_model`` with onnx-sim, keeping it on failure.

    The simplified graph is adopted only when ``onnxsim.simplify`` reports
    that its check passed. If onnxsim is missing, raises, or the check
    fails, the failure is logged and the input model is returned unchanged.
    """
    try:
        import onnxsim
        simplified_model, check = onnxsim.simplify(onnx_model)
    except Exception as e:  # best effort: never abort the export here
        print_log(f'Simplify failure: {e}')
        return onnx_model
    if not check:
        # Explicit check (not ``assert``) so it also runs under ``python -O``.
        print_log('Simplify failure: assert check failed')
        return onnx_model
    return simplified_model


def _fix_tensorrt_output_shapes(onnx_model, batch_size, keep_topk):
    """Write fixed values into the graph outputs' ``dim_param`` fields.

    This only affects how the shapes are displayed. The values are consumed
    one per output dimension, in graph-output order.
    """
    dims = iter([
        batch_size, 1, batch_size, keep_topk, 4, batch_size, keep_topk,
        batch_size, keep_topk
    ])
    for output in onnx_model.graph.output:
        for dim in output.type.tensor_type.shape.dim:
            dim.dim_param = str(next(dims))


def main():
    """Export the configured detector to an ONNX file in ``--work-dir``.

    Steps:
        1. Parse the CLI options and build the detector.
        2. Wrap the detector in ``DeployModel``, with a postprocess
           (pre/keep top-k, IoU/score thresholds) config unless
           ``--model-only`` is set or forced.
        3. Dry-run it on a random input.
        4. Export to an in-memory ONNX buffer and validate it with
           ``onnx.checker``.
        5. Optionally rewrite the TensorRT output dims and simplify with
           onnx-sim.
        6. Save ``<checkpoint stem>.onnx`` into the work dir.
    """
    args = parse_args()
    mkdir_or_exist(args.work_dir)
    backend = MMYOLOBackend(args.backend.lower())
    if backend in (MMYOLOBackend.ONNXRUNTIME, MMYOLOBackend.OPENVINO,
                   MMYOLOBackend.TENSORRT8, MMYOLOBackend.TENSORRT7):
        if not args.model_only:
            print_log('Export ONNX with bbox decoder and NMS ...')
    else:
        # Other backends cannot carry the postprocess; force model-only.
        args.model_only = True
        print_log(f'Can not export postprocess for {args.backend.lower()}.\n'
                  f'Set "args.model_only=True" default.')

    if args.model_only:
        postprocess_cfg = None
        output_names = None
    else:
        postprocess_cfg = ConfigDict(
            pre_top_k=args.pre_topk,
            keep_top_k=args.keep_topk,
            iou_threshold=args.iou_threshold,
            score_threshold=args.score_threshold)
        output_names = ['num_dets', 'boxes', 'scores', 'labels']

    base_model = build_model_from_cfg(args.config, args.checkpoint,
                                      args.device)
    deploy_model = DeployModel(
        baseModel=base_model, backend=backend, postprocess_cfg=postprocess_cfg)
    deploy_model.eval()

    fake_input = torch.randn(args.batch_size, 3,
                             *args.img_size).to(args.device)
    # dry run
    deploy_model(fake_input)

    # Swap only the file extension. ``str.replace('pth', 'onnx')`` would also
    # rewrite 'pth' inside the name (e.g. 'depth_...') and leave non-.pth
    # checkpoints with their original extension.
    save_onnx_path = os.path.join(
        args.work_dir,
        Path(args.checkpoint).with_suffix('.onnx').name)

    # export onnx
    with BytesIO() as f:
        torch.onnx.export(
            deploy_model,
            fake_input,
            f,
            input_names=['images'],
            output_names=output_names,
            opset_version=args.opset)
        f.seek(0)
        onnx_model = onnx.load(f)
        onnx.checker.check_model(onnx_model)

    # Fix tensorrt onnx output shape, just for view
    if not args.model_only and backend in (MMYOLOBackend.TENSORRT8,
                                           MMYOLOBackend.TENSORRT7):
        _fix_tensorrt_output_shapes(onnx_model, args.batch_size,
                                    args.keep_topk)

    if args.simplify:
        onnx_model = _simplify_onnx_model(onnx_model)

    onnx.save(onnx_model, save_onnx_path)
    print_log(f'ONNX export success, save into {save_onnx_path}')
if __name__ == "__main__":
    # Run the exporter only when executed as a script, not when imported.
    main()