glenn-jocher commited on
Commit
1f3e482
1 Parent(s): e5d7122

ONNX Simplifier (#2815)

Browse files

* ONNX Simplifier

Add ONNX Simplifier to ONNX export pipeline in export.py. Will auto-install onnx-simplifier if onnx is installed but onnx-simplifier is not.

* Update general.py

Files changed (2) hide show
  1. models/export.py +30 -15
  2. utils/general.py +1 -1
models/export.py CHANGED
@@ -1,7 +1,7 @@
1
  """Exports a YOLOv5 *.pt model to ONNX and TorchScript formats
2
 
3
  Usage:
4
- $ export PYTHONPATH="$PWD" && python models/export.py --weights ./weights/yolov5s.pt --img 640 --batch 1
5
  """
6
 
7
  import argparse
@@ -16,7 +16,7 @@ import torch.nn as nn
16
  import models
17
  from models.experimental import attempt_load
18
  from utils.activations import Hardswish, SiLU
19
- from utils.general import set_logging, check_img_size
20
  from utils.torch_utils import select_device
21
 
22
  if __name__ == '__main__':
@@ -59,20 +59,22 @@ if __name__ == '__main__':
59
  y = model(img) # dry run
60
 
61
  # TorchScript export
 
62
  try:
63
- print('\nStarting TorchScript export with torch %s...' % torch.__version__)
64
  f = opt.weights.replace('.pt', '.torchscript.pt') # filename
65
  ts = torch.jit.trace(model, img, strict=False)
66
  ts.save(f)
67
- print('TorchScript export success, saved as %s' % f)
68
  except Exception as e:
69
- print('TorchScript export failure: %s' % e)
70
 
71
  # ONNX export
 
72
  try:
73
  import onnx
74
 
75
- print('\nStarting ONNX export with onnx %s...' % onnx.__version__)
76
  f = opt.weights.replace('.pt', '.onnx') # filename
77
  torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
78
  output_names=['classes', 'boxes'] if y is None else ['output'],
@@ -80,25 +82,38 @@ if __name__ == '__main__':
80
  'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
81
 
82
  # Checks
83
- onnx_model = onnx.load(f) # load onnx model
84
- onnx.checker.check_model(onnx_model) # check onnx model
85
- # print(onnx.helper.printable_graph(onnx_model.graph)) # print a human readable model
86
- print('ONNX export success, saved as %s' % f)
 
 
 
 
 
 
 
 
 
 
 
 
87
  except Exception as e:
88
- print('ONNX export failure: %s' % e)
89
 
90
  # CoreML export
 
91
  try:
92
  import coremltools as ct
93
 
94
- print('\nStarting CoreML export with coremltools %s...' % ct.__version__)
95
  # convert model from torchscript and apply pixel scaling as per detect.py
96
  model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
97
  f = opt.weights.replace('.pt', '.mlmodel') # filename
98
  model.save(f)
99
- print('CoreML export success, saved as %s' % f)
100
  except Exception as e:
101
- print('CoreML export failure: %s' % e)
102
 
103
  # Finish
104
- print('\nExport complete (%.2fs). Visualize with https://github.com/lutzroeder/netron.' % (time.time() - t))
 
1
  """Exports a YOLOv5 *.pt model to ONNX and TorchScript formats
2
 
3
  Usage:
4
+ $ export PYTHONPATH="$PWD" && python models/export.py --weights yolov5s.pt --img 640 --batch 1
5
  """
6
 
7
  import argparse
 
16
  import models
17
  from models.experimental import attempt_load
18
  from utils.activations import Hardswish, SiLU
19
+ from utils.general import colorstr, check_img_size, check_requirements, set_logging
20
  from utils.torch_utils import select_device
21
 
22
  if __name__ == '__main__':
 
59
  y = model(img) # dry run
60
 
61
  # TorchScript export
62
+ prefix = colorstr('TorchScript:')
63
  try:
64
+ print(f'\n{prefix} starting export with torch {torch.__version__}...')
65
  f = opt.weights.replace('.pt', '.torchscript.pt') # filename
66
  ts = torch.jit.trace(model, img, strict=False)
67
  ts.save(f)
68
+ print(f'{prefix} export success, saved as {f}')
69
  except Exception as e:
70
+ print(f'{prefix} export failure: {e}')
71
 
72
  # ONNX export
73
+ prefix = colorstr('ONNX:')
74
  try:
75
  import onnx
76
 
77
+ print(f'{prefix} starting export with onnx {onnx.__version__}...')
78
  f = opt.weights.replace('.pt', '.onnx') # filename
79
  torch.onnx.export(model, img, f, verbose=False, opset_version=12, input_names=['images'],
80
  output_names=['classes', 'boxes'] if y is None else ['output'],
 
82
  'output': {0: 'batch', 2: 'y', 3: 'x'}} if opt.dynamic else None)
83
 
84
  # Checks
85
+ model_onnx = onnx.load(f) # load onnx model
86
+ onnx.checker.check_model(model_onnx) # check onnx model
87
+ # print(onnx.helper.printable_graph(model_onnx.graph)) # print
88
+
89
+ # Simplify
90
+ try:
91
+ check_requirements(['onnx-simplifier'])
92
+ import onnxsim
93
+
94
+ print(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
95
+ model_onnx, check = onnxsim.simplify(model_onnx)
96
+ assert check, 'assert check failed'
97
+ onnx.save(model_onnx, f)
98
+ except Exception as e:
99
+ print(f'{prefix} simplifier failure: {e}')
100
+ print(f'{prefix} export success, saved as {f}')
101
  except Exception as e:
102
+ print(f'{prefix} export failure: {e}')
103
 
104
  # CoreML export
105
+ prefix = colorstr('CoreML:')
106
  try:
107
  import coremltools as ct
108
 
109
+ print(f'{prefix} starting export with coremltools {onnx.__version__}...')
110
  # convert model from torchscript and apply pixel scaling as per detect.py
111
  model = ct.convert(ts, inputs=[ct.ImageType(name='image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
112
  f = opt.weights.replace('.pt', '.mlmodel') # filename
113
  model.save(f)
114
+ print(f'{prefix} export success, saved as {f}')
115
  except Exception as e:
116
+ print(f'{prefix} export failure: {e}')
117
 
118
  # Finish
119
+ print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
utils/general.py CHANGED
@@ -111,7 +111,7 @@ def check_requirements(requirements='requirements.txt', exclude=()):
111
  except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
112
  n += 1
113
  print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
114
- print(subprocess.check_output(f"pip install '{e.req}'", shell=True).decode())
115
 
116
  if n: # if packages updated
117
  source = file.resolve() if 'file' in locals() else requirements
 
111
  except Exception as e: # DistributionNotFound or VersionConflict if requirements not met
112
  n += 1
113
  print(f"{prefix} {e.req} not found and is required by YOLOv5, attempting auto-update...")
114
+ print(subprocess.check_output(f"pip install {e.req}", shell=True).decode())
115
 
116
  if n: # if packages updated
117
  source = file.resolve() if 'file' in locals() else requirements