glenn-jocher commited on
Commit
3bef77f
1 Parent(s): 442a7ab

Addition refactor `export.py` (#4089)

Browse files

* Addition refactor `export.py`

* Update export.py

Files changed (1) hide show
  1. export.py +10 -9
export.py CHANGED
@@ -45,7 +45,7 @@ def export_onnx(model, img, file, opset_version, train, dynamic, simplify):
45
  check_requirements(('onnx', 'onnx-simplifier'))
46
  import onnx
47
 
48
- print(f'{prefix} starting export with onnx {onnx.__version__}...')
49
  f = file.with_suffix('.onnx')
50
  torch.onnx.export(model, img, f, verbose=False, opset_version=opset_version,
51
  training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
@@ -80,16 +80,17 @@ def export_onnx(model, img, file, opset_version, train, dynamic, simplify):
80
  print(f'{prefix} export failure: {e}')
81
 
82
 
83
- def export_coreml(ts_model, img, file, train):
84
  # CoreML model export
85
  prefix = colorstr('CoreML:')
86
  try:
87
  import coremltools as ct
88
 
89
- print(f'{prefix} starting export with coremltools {ct.__version__}...')
90
  f = file.with_suffix('.mlmodel')
91
- assert train, 'CoreML exports should be placed in model.train() mode with `python export.py --train`'
92
- model = ct.convert(ts_model, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
 
93
  model.save(f)
94
  print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
95
  except Exception as e:
@@ -145,12 +146,12 @@ def run(weights='./yolov5s.pt', # weights path
145
  print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")
146
 
147
  # Exports
 
 
148
  if 'onnx' in include:
149
  export_onnx(model, img, file, opset_version, train, dynamic, simplify)
150
- if 'torchscript' in include or 'coreml' in include:
151
- ts = export_torchscript(model, img, file, optimize)
152
- if 'coreml' in include:
153
- export_coreml(ts, img, file, train)
154
 
155
  # Finish
156
  print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')
 
45
  check_requirements(('onnx', 'onnx-simplifier'))
46
  import onnx
47
 
48
+ print(f'\n{prefix} starting export with onnx {onnx.__version__}...')
49
  f = file.with_suffix('.onnx')
50
  torch.onnx.export(model, img, f, verbose=False, opset_version=opset_version,
51
  training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
 
80
  print(f'{prefix} export failure: {e}')
81
 
82
 
83
+ def export_coreml(model, img, file):
84
  # CoreML model export
85
  prefix = colorstr('CoreML:')
86
  try:
87
  import coremltools as ct
88
 
89
+ print(f'\n{prefix} starting export with coremltools {ct.__version__}...')
90
  f = file.with_suffix('.mlmodel')
91
+ model.train() # CoreML exports should be placed in model.train() mode
92
+ ts = torch.jit.trace(model, img, strict=False) # TorchScript model
93
+ model = ct.convert(ts, inputs=[ct.ImageType('image', shape=img.shape, scale=1 / 255.0, bias=[0, 0, 0])])
94
  model.save(f)
95
  print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
96
  except Exception as e:
 
146
  print(f"\n{colorstr('PyTorch:')} starting from {weights} ({file_size(weights):.1f} MB)")
147
 
148
  # Exports
149
+ if 'torchscript' in include:
150
+ export_torchscript(model, img, file, optimize)
151
  if 'onnx' in include:
152
  export_onnx(model, img, file, opset_version, train, dynamic, simplify)
153
+ if 'coreml' in include:
154
+ export_coreml(model, img, file)
 
 
155
 
156
  # Finish
157
  print(f'\nExport complete ({time.time() - t:.2f}s). Visualize with https://github.com/lutzroeder/netron.')