MMOCR / tools /deployment /pytorch2onnx.py
tomofi's picture
Add application file
2366e36
raw
history blame
13 kB
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from argparse import ArgumentParser
from functools import partial
import cv2
import numpy as np
import torch
from mmcv.onnx import register_extra_symbolics
from mmcv.parallel import collate
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from torch import nn
from mmocr.apis import init_detector
from mmocr.core.deployment import ONNXRuntimeDetector, ONNXRuntimeRecognizer
from mmocr.datasets.pipelines.crop import crop_img # noqa: F401
from mmocr.utils import is_2dlist
def _convert_batchnorm(module):
    """Recursively replace ``SyncBatchNorm`` layers with ``BatchNorm2d``.

    Used before ONNX export to swap every ``torch.nn.SyncBatchNorm`` in the
    module tree for a plain ``torch.nn.BatchNorm2d`` that carries the same
    hyper-parameters, affine parameters, running statistics and train/eval
    mode.

    Args:
        module (nn.Module): Module (tree) to convert. Children of a module
            that is not itself a ``SyncBatchNorm`` are replaced in place.

    Returns:
        nn.Module: ``module`` itself when it is not a ``SyncBatchNorm``,
        otherwise the newly created ``BatchNorm2d``.
    """
    module_output = module
    if isinstance(module, torch.nn.SyncBatchNorm):
        module_output = torch.nn.BatchNorm2d(module.num_features, module.eps,
                                             module.momentum, module.affine,
                                             module.track_running_stats)
        if module.affine:
            module_output.weight.data = module.weight.data.clone().detach()
            module_output.bias.data = module.bias.data.clone().detach()
            # keep requires_grad unchanged
            module_output.weight.requires_grad = module.weight.requires_grad
            module_output.bias.requires_grad = module.bias.requires_grad
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        # A freshly built module starts in training mode; mirror the source
        # layer so a model put into eval() before conversion stays in eval.
        module_output.train(module.training)
    for name, child in module.named_children():
        # Children are converted recursively and re-registered under the
        # same name, which replaces them in place on ``module_output``.
        module_output.add_module(name, _convert_batchnorm(child))
    return module_output
def _prepare_data(cfg, imgs):
    """Run image(s) through the test pipeline and collate them into a batch.

    Args:
        cfg (mmcv.Config): Config whose ``data.test.pipeline`` describes the
            test-time preprocessing. It is NOT modified by this function.
        imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
            Either image files or loaded images. The type of the first
            element decides how the whole batch is loaded.

    Returns:
        dict: Collated batch. ``'img'`` holds the image tensor(s) and
        ``'img_metas'`` the per-image meta information, both unwrapped via
        their ``.data`` attribute.

    Raises:
        AssertionError: If ``imgs`` is empty or not strings / numpy arrays.
        Exception: If augmented test (list-valued ``img``) is combined with
            a batch of more than one image.
    """
    import copy

    if isinstance(imgs, (list, tuple)):
        if not imgs or not isinstance(imgs[0], (np.ndarray, str)):
            raise AssertionError('imgs must be strings or numpy arrays')
    elif isinstance(imgs, (np.ndarray, str)):
        imgs = [imgs]
    else:
        raise AssertionError('imgs must be strings or numpy arrays')

    is_ndarray = isinstance(imgs[0], np.ndarray)

    # Build the pipeline from a private deep copy so the caller's cfg
    # (typically ``model.cfg``) is never mutated. ``cfg.copy()`` is only a
    # shallow copy, so the loader override below used to leak back into it.
    pipeline_cfg = copy.deepcopy(cfg.data.test.pipeline)
    if is_ndarray:
        # set loading pipeline type
        pipeline_cfg[0].type = 'LoadImageFromNdarray'
    pipeline_cfg = replace_ImageToTensor(pipeline_cfg)
    test_pipeline = Compose(pipeline_cfg)

    data = []
    for img in imgs:
        if is_ndarray:
            # directly add img
            datum = dict(img=img)
        else:
            # add information into dict
            datum = dict(img_info=dict(filename=img), img_prefix=None)
        # build the data pipeline
        data.append(test_pipeline(datum))

    if isinstance(data[0]['img'], list) and len(data) > 1:
        raise Exception('aug test does not support '
                        f'inference with batch size '
                        f'{len(data)}')

    data = collate(data, samples_per_gpu=len(imgs))

    # process img_metas
    if isinstance(data['img_metas'], list):
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
    else:
        data['img_metas'] = data['img_metas'].data

    if isinstance(data['img'], list):
        data['img'] = [img.data for img in data['img']]
        if isinstance(data['img'][0], list):
            data['img'] = [img[0] for img in data['img']]
    else:
        data['img'] = data['img'].data
    return data
def _get_dynamic_axes(model_type):
    """Return the ``dynamic_axes`` mapping used for dynamic ONNX export."""
    if model_type == 'det':
        return {
            'input': {
                0: 'batch',
                2: 'height',
                3: 'width'
            },
            'output': {
                0: 'batch',
                2: 'height',
                3: 'width'
            }
        }
    if model_type == 'recog':
        return {
            'input': {
                0: 'batch',
                3: 'width'
            },
            'output': {
                0: 'batch',
                1: 'seq_len',
                2: 'num_classes'
            }
        }
    return None


def _outputs_match(model_type, onnx_out, pytorch_out, rtol=1e-4, atol=1e-4):
    """Return True when ONNX and PyTorch results agree within tolerance."""
    if model_type == 'recog':
        for onnx_result, pytorch_result in zip(onnx_out, pytorch_out):
            # Texts must match exactly; scores only within tolerance.
            if onnx_result['text'] != pytorch_result['text']:
                return False
            if not np.allclose(
                    np.array(onnx_result['score']),
                    np.array(pytorch_result['score']),
                    rtol=rtol,
                    atol=atol):
                return False
        return True

    for onnx_result, pytorch_result in zip(onnx_out[0]['boundary_result'],
                                           pytorch_out[0]['boundary_result']):
        if not np.allclose(
                np.array(onnx_result),
                np.array(pytorch_result),
                rtol=rtol,
                atol=atol):
            return False
    return True


def pytorch2onnx(model: nn.Module,
                 model_type: str,
                 img_path: str,
                 verbose: bool = False,
                 show: bool = False,
                 opset_version: int = 11,
                 output_file: str = 'tmp.onnx',
                 verify: bool = False,
                 dynamic_export: bool = False,
                 device_id: int = 0):
    """Export PyTorch model to ONNX model and verify the outputs are same
    between PyTorch and ONNX.

    Args:
        model (nn.Module): PyTorch model we want to export. Its ``cfg``
            attribute supplies the test pipeline used to build the input.
        model_type (str): Model type, ``'det'`` (detection) or ``'recog'``
            (recognition).
        img_path (str): We need to use this input to execute the model.
        verbose (bool): Whether print the computation graph. Default: False.
        show (bool): Whether to visualize final results. Only takes effect
            together with ``verify``. Default: False.
        opset_version (int): The onnx op version. Default: 11.
        output_file (str): The path to where we store the output ONNX model.
            Default: `tmp.onnx`.
        verify (bool): Whether compare the outputs between PyTorch and ONNX.
            Default: False.
        dynamic_export (bool): Whether apply dynamic export.
            Default: False.
        device_id (int): CUDA device id to place model and data. Values
            such as ``'0'`` are coerced to int. Default: 0.
    """
    device_id = int(device_id)
    device = torch.device(type='cuda', index=device_id)
    model.to(device).eval()
    # SyncBatchNorm children are swapped in place, so the return value
    # (``model`` itself for a non-SyncBatchNorm root) can be ignored.
    _convert_batchnorm(model)

    # prepare inputs
    mm_inputs = _prepare_data(cfg=model.cfg, imgs=img_path)
    imgs = mm_inputs.pop('img')
    img_metas = mm_inputs.pop('img_metas')

    if isinstance(imgs, list):
        imgs = imgs[0]

    # Add a leading batch dimension to every image tensor.
    img_list = [img[None, :].to(device) for img in imgs]

    # torch.onnx.export only passes the image tensor, so bind the remaining
    # arguments into ``forward`` for the duration of the export.
    origin_forward = model.forward
    if model_type == 'det':
        model.forward = partial(
            model.simple_test, img_metas=img_metas, rescale=True)
    else:
        model.forward = partial(
            model.forward,
            img_metas=img_metas,
            return_loss=False,
            rescale=True)

    # pytorch has some bug in pytorch1.3, we have to fix it
    # by replacing these existing op
    register_extra_symbolics(opset_version)

    dynamic_axes = _get_dynamic_axes(model_type) if dynamic_export else None

    try:
        with torch.no_grad():
            torch.onnx.export(
                model, (img_list[0], ),
                output_file,
                input_names=['input'],
                output_names=['output'],
                export_params=True,
                keep_initializers_as_inputs=False,
                verbose=verbose,
                opset_version=opset_version,
                dynamic_axes=dynamic_axes)
    finally:
        # Always undo the monkey-patch so the caller's model is not left
        # with the export-only partial, even when export fails or
        # ``verify`` is False.
        model.forward = origin_forward
    print(f'Successfully exported ONNX model: {output_file}')

    if not verify:
        return

    # check by onnx
    import onnx
    onnx_model = onnx.load(output_file)
    onnx.checker.check_model(onnx_model)

    scale_factor = (0.5, 0.5) if model_type == 'det' else (1, 0.5)
    if dynamic_export:
        # scale image for dynamic shape test
        img_list = [
            nn.functional.interpolate(_, scale_factor=scale_factor)
            for _ in img_list
        ]
        if model_type == 'det':
            # ``scale_factor * 2`` repeats the pair into a 4-tuple;
            # presumably matching a (w, h, w, h) meta scale_factor - confirm.
            img_metas[0][0][
                'scale_factor'] = img_metas[0][0]['scale_factor'] * (
                    scale_factor * 2)

    # check the numerical value
    # get pytorch output
    with torch.no_grad():
        pytorch_out = model.simple_test(
            img_list[0], img_metas[0], rescale=True)

    # get onnx output
    if model_type == 'det':
        onnx_model = ONNXRuntimeDetector(output_file, model.cfg, device_id)
    else:
        onnx_model = ONNXRuntimeRecognizer(output_file, model.cfg,
                                           device_id)
    onnx_out = onnx_model.simple_test(
        img_list[0], img_metas[0], rescale=True)

    # compare results
    same_diff = ('same' if _outputs_match(model_type, onnx_out, pytorch_out)
                 else 'different')
    print(f'The outputs are {same_diff} between PyTorch and ONNX')

    if show:
        onnx_img = onnx_model.show_result(
            img_path, onnx_out[0], out_file='onnx.jpg', show=False)
        pytorch_img = model.show_result(
            img_path, pytorch_out[0], out_file='pytorch.jpg', show=False)
        # Fall back to the raw input image when show_result returns None.
        if onnx_img is None:
            onnx_img = cv2.imread(img_path)
        if pytorch_img is None:
            pytorch_img = cv2.imread(img_path)

        cv2.imshow('PyTorch', pytorch_img)
        cv2.imshow('ONNXRuntime', onnx_img)
        cv2.waitKey()
def main():
    """Command-line entry point: parse arguments, build the model, export.

    Emits a deprecation warning pointing to MMDeploy, loads the model from
    config + checkpoint on the requested CUDA device, makes sure
    ``model.cfg.data.test.pipeline`` is populated (taking it from the first
    test dataset when only ``datasets`` is configured), then calls
    ``pytorch2onnx``.
    """
    parser = ArgumentParser(
        description='Convert MMOCR models from pytorch to ONNX')
    parser.add_argument('model_config', type=str, help='Config file.')
    parser.add_argument(
        'model_ckpt', type=str, help='Checkpoint file (local or url).')
    parser.add_argument(
        'model_type',
        type=str,
        help='Detection or recognition model to deploy.',
        choices=['recog', 'det'])
    parser.add_argument('image_path', type=str, help='Input Image file.')
    parser.add_argument(
        '--output-file',
        type=str,
        help='Output file name of the onnx model.',
        default='tmp.onnx')
    # type=int: without it ``--device-id 1`` yields the string '1', while
    # torch.device(index=...) expects an integer.
    parser.add_argument(
        '--device-id', type=int, default=0, help='Device used for inference.')
    parser.add_argument(
        '--opset-version',
        type=int,
        help='ONNX opset version, default to 11.',
        default=11)
    parser.add_argument(
        '--verify',
        action='store_true',
        help='Whether verify the outputs of onnx and pytorch are same.',
        default=False)
    parser.add_argument(
        '--verbose',
        action='store_true',
        help='Whether print the computation graph.',
        default=False)
    parser.add_argument(
        '--show',
        action='store_true',
        help='Whether visualize final output.',
        default=False)
    parser.add_argument(
        '--dynamic-export',
        action='store_true',
        help='Whether dynamically export onnx model.',
        default=False)
    args = parser.parse_args()

    # Following strings of text style are from colorama package
    bright_style, reset_style = '\x1b[1m', '\x1b[0m'
    red_text, blue_text = '\x1b[31m', '\x1b[34m'
    white_background = '\x1b[107m'

    msg = white_background + bright_style + red_text
    msg += 'DeprecationWarning: This tool will be deprecated in future. '
    msg += blue_text + 'Welcome to use the unified model deployment toolbox '
    msg += 'MMDeploy: https://github.com/open-mmlab/mmdeploy'
    msg += reset_style
    warnings.warn(msg)

    device = torch.device(type='cuda', index=args.device_id)

    # build model
    model = init_detector(args.model_config, args.model_ckpt, device=device)
    if hasattr(model, 'module'):
        model = model.module
    # Configs that only define ``data.test.datasets`` get their first
    # dataset's pipeline promoted to ``data.test.pipeline``.
    if model.cfg.data.test.get('pipeline', None) is None:
        if is_2dlist(model.cfg.data.test.datasets):
            model.cfg.data.test.pipeline = \
                model.cfg.data.test.datasets[0][0].pipeline
        else:
            model.cfg.data.test.pipeline = \
                model.cfg.data.test['datasets'][0].pipeline
    if is_2dlist(model.cfg.data.test.pipeline):
        model.cfg.data.test.pipeline = model.cfg.data.test.pipeline[0]

    pytorch2onnx(
        model,
        model_type=args.model_type,
        output_file=args.output_file,
        img_path=args.image_path,
        opset_version=args.opset_version,
        verify=args.verify,
        verbose=args.verbose,
        show=args.show,
        device_id=args.device_id,
        dynamic_export=args.dynamic_export)
if __name__ == '__main__':
    # Run the converter only when executed as a script, not on import.
    main()