# Copyright (c) OpenMMLab. All rights reserved. | |
import warnings | |
from argparse import ArgumentParser | |
from functools import partial | |
import cv2 | |
import numpy as np | |
import torch | |
from mmcv.onnx import register_extra_symbolics | |
from mmcv.parallel import collate | |
from mmdet.datasets import replace_ImageToTensor | |
from mmdet.datasets.pipelines import Compose | |
from torch import nn | |
from mmocr.apis import init_detector | |
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer | |
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401 | |
from mmocr.utils import is_2dlist | |
def _convert_batchnorm(module): | |
module_output = module | |
if isinstance(module, torch.nn.SyncBatchNorm): | |
module_output = torch.nn.BatchNorm2d(module.num_features, module.eps, | |
module.momentum, module.affine, | |
module.track_running_stats) | |
if module.affine: | |
module_output.weight.data = module.weight.data.clone().detach() | |
module_output.bias.data = module.bias.data.clone().detach() | |
# keep requires_grad unchanged | |
module_output.weight.requires_grad = module.weight.requires_grad | |
module_output.bias.requires_grad = module.bias.requires_grad | |
module_output.running_mean = module.running_mean | |
module_output.running_var = module.running_var | |
module_output.num_batches_tracked = module.num_batches_tracked | |
for name, child in module.named_children(): | |
module_output.add_module(name, _convert_batchnorm(child)) | |
del module | |
return module_output | |
def _prepare_data(cfg, imgs):
    """Preprocess image(s) through the config's test pipeline.

    Args:
        cfg (mmcv.Config): Model config; ``cfg.data.test.pipeline`` defines
            the preprocessing steps applied to each image.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images.

    Returns:
        dict: Collated batch with ``img`` and ``img_metas`` entries, ready
        to be fed to the model.

    Raises:
        AssertionError: If ``imgs`` is not a str/ndarray or a list/tuple of
            them.
        Exception: If augmentation test is combined with batch size > 1.
    """
    if isinstance(imgs, (list, tuple)):
        if not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')
    elif isinstance(imgs, (np.ndarray, str)):
        # Normalize a single image to a one-element batch.
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')
    is_ndarray = isinstance(imgs[0], np.ndarray)
    if is_ndarray:
        # Copy so the caller's config is not mutated.
        cfg = cfg.copy()
        # set loading pipeline type
        cfg.data.test.pipeline[0].type = 'LoadImageFromNdarray'
    cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
    test_pipeline = Compose(cfg.data.test.pipeline)
    data = []
    for img in imgs:
        # prepare data
        if is_ndarray:
            # directly add img
            datum = dict(img=img)
        else:
            # add information into dict
            datum = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        datum = test_pipeline(datum)
        # get tensor from list to stack for batch mode (text detection)
        data.append(datum)
    # A list-valued 'img' means test-time augmentation; that only works
    # with a single input image.
    if isinstance(data[0]['img'], list) and len(data) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(data)}')
    data = collate(data, samples_per_gpu=len(imgs))
    # process img_metas: unwrap mmcv DataContainer objects.
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data
    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data
    return data
def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model we want to export.
        model_type (str): Model type, detection ('det') or recognition
            ('recog') model.
        img_path (str): We need to use this input to execute the model.
        verbose (bool): Whether print the computation graph. Default: False.
        show (bool): Whether visualize final results. Default: False.
        opset_version (int): The onnx op version. Default: 11.
        output_file (string): The path to where we store the output ONNX
            model. Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether apply dynamic export.
            Default: False.
        device_id (int): Device id to place model and data.
            Default: 0
    """
    # Export runs on CUDA: model, inputs and (if requested) the ONNXRuntime
    # verification all use this device.
    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    # SyncBatchNorm is not ONNX-exportable; fold it into BatchNorm2d first.
    _convert_batchnorm(model)
    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')
    if isinstance(imgs, list):
        imgs = imgs[0]
    # Add the batch dimension and move each image onto the export device.
    img_list = [img[None, :].to(device) for img in imgs]
    # Swap forward() for a test-time entry point so torch.onnx.export traces
    # inference behavior; the original is restored before verification.
    origin_forward = model.forward
    if (model_type == 'det'):
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            model.forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)
    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)
    dynamic_axes = None
    if dynamic_export and model_type == 'det':
        # Detection: batch size and spatial dims may vary at runtime.
        dynamic_axes = {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    elif dynamic_export and model_type == 'recog':
        # Recognition: variable batch and width; output is a per-step
        # class-score sequence.
        dynamic_axes = {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                1: 'seq_len',
                2: 'num_classes'
            }
        }
    with torch.no_grad():
        torch.onnx.export(
            model, (img_list[0], ),
            output_file,
            input_names=['input'],
            output_names=['output'],
            export_params=True,
            keep_initializers_as_inputs=False,
            verbose=verbose,
            opset_version=opset_version,
            dynamic_axes=dynamic_axes)
    print(f'Successfully exported ONNX model: {output_file}')
    if verify:
        # check by onnx
        import onnx
        onnx_model = onnx.load(output_file)
        onnx.checker.check_model(onnx_model)
        scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
        if dynamic_export:
            # scale image for dynamic shape test
            img_list = [
                nn.functional.interpolate(_, scale_factor=scale_factor)
                for _ in img_list
            ]
            if model_type == 'det':
                # `scale_factor * 2` repeats the (w, h) tuple to length 4,
                # presumably matching a (w, h, w, h) scale_factor stored in
                # img_metas — TODO(review): confirm against pipeline output.
                img_metas[0][0][
                    'scale_factor'] = img_metas[0][0]['scale_factor'] * (
                        scale_factor * 2)
        # check the numerical value
        # get pytorch output
        with torch.no_grad():
            model.forward = origin_forward
            pytorch_out = model.simple_test(
                img_list[0], img_metas[0], rescale=True)
        # get onnx output
        if model_type == 'det':
            onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
        else:
            onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
                                               device_id)
        onnx_out = onnx_model.simple_test(
            img_list[0], img_metas[0], rescale=True)
        # compare results
        same_diff = 'same'
        if model_type == 'recog':
            # Recognition: compare decoded text exactly and scores within
            # tolerance.
            for onnx_result, pytorch_result in zip(onnx_out, pytorch_out):
                if onnx_result['text'] != pytorch_result[
                        'text'] or not np.allclose(
                            np.array(onnx_result['score']),
                            np.array(pytorch_result['score']),
                            rtol=1e-4,
                            atol=1e-4):
                    same_diff = 'different'
                    break
        else:
            # Detection: compare boundary polygons numerically.
            for onnx_result, pytorch_result in zip(
                    onnx_out[0]['boundary_result'],
                    pytorch_out[0]['boundary_result']):
                if not np.allclose(
                        np.array(onnx_result),
                        np.array(pytorch_result),
                        rtol=1e-4,
                        atol=1e-4):
                    same_diff = 'different'
                    break
        print('The outputs are {} between PyTorch and ONNX'.format(same_diff))
        if show:
            # Render both predictions for a side-by-side visual check;
            # show_result may return None, in which case fall back to the
            # raw input image.
            onnx_img = onnx_model.show_result(
                img_path, onnx_out[0], out_file='onnx.jpg', show=False)
            pytorch_img = model.show_result(
                img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
            if onnx_img is None:
                onnx_img = cv2.imread(img_path)
            if pytorch_img is None:
                pytorch_img = cv2.imread(img_path)
            cv2.imshow('PyTorch', pytorch_img)
            cv2.imshow('ONNXRuntime', onnx_img)
            cv2.waitKey()
    return
def main():
    """CLI entry point: parse arguments, build the model, export to ONNX.

    Raises:
        SystemExit: On invalid command-line arguments (via argparse).
    """
    parser = ArgumentParser(
        description='Convert MMOCR models from pytorch to ONNX')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_ckpt', type=str, help='Checkpoint file (local or url).')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Input Image file.')
    parser.add_argument(
        '--output-file',
        type=str,
        help='Output file name of the onnx model.',
        default='tmp.onnx')
    # type=int is required: torch.device(index=...) rejects a string index.
    parser.add_argument(
        '--device-id',
        type=int,
        default=0,
        help='Device used for inference.')
    parser.add_argument(
        '--opset-version',
        type=int,
        help='ONNX opset version, default to 11.',
        default=11)
    # store_true flags default to False; no explicit default needed.
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether verify the outputs of onnx and pytorch are same.')
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether print the computation graph.')
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether visualize final output.')
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether dynamically export onnx model.')
    args = parser.parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    device = torch.device(type='cuda', index=args.device_id)

    # build model
    model = init_detector(args.model_config, args.model_ckpt, device=device)
    if hasattr(model, 'module'):
        # Unwrap DataParallel/DistributedDataParallel-style wrappers.
        model = model.module

    # Some configs keep the test pipeline per-dataset; hoist it so the
    # exporter can find it at cfg.data.test.pipeline.
    if model.cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(model.cfg.data.test.datasets):
            model.cfg.data.test.pipeline = \
                model.cfg.data.test.datasets[0][0].pipeline
        else:
            model.cfg.data.test.pipeline = \
                model.cfg.data.test['datasets'][0].pipeline
    if is_2dlist(model.cfg.data.test.pipeline):
        model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]

    pytorch2onnx(
        model,
        model_type=args.model_type,
        output_file=args.output_file,
        img_path=args.image_path,
        opset_version=args.opset_version,
        verify=args.verify,
        verbose=args.verbose,
        show=args.show,
        device_id=args.device_id,
        dynamic_export=args.dynamic_export)
# Script entry point.
if __name__ == '__main__':
    main()