imyhxy glenn-jocher commited on
Commit
a2f4a17
1 Parent(s): fb83929

TensorRT 7 `anchor_grid` compatibility fix (#6185)

Browse files

* fix: TensorRT 7 incompatiable

* Add comment

* Add if: else and comment

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (1) hide show
  1. export.py +9 -3
export.py CHANGED
@@ -175,7 +175,13 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
175
  import tensorrt as trt
176
 
177
  opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
178
- export_onnx(model, im, file, opset, train, False, simplify)
 
 
 
 
 
 
179
  onnx = file.with_suffix('.onnx')
180
  assert onnx.exists(), f'failed to export ONNX file: {onnx}'
181
 
@@ -418,12 +424,12 @@ def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
418
  # Exports
419
  if 'torchscript' in include:
420
  export_torchscript(model, im, file, optimize)
 
 
421
  if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
422
  export_onnx(model, im, file, opset, train, dynamic, simplify)
423
  if 'openvino' in include:
424
  export_openvino(model, im, file)
425
- if 'engine' in include:
426
- export_engine(model, im, file, train, half, simplify, workspace, verbose)
427
  if 'coreml' in include:
428
  export_coreml(model, im, file)
429
 
 
175
  import tensorrt as trt
176
 
177
  opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
178
+ if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
179
+ grid = model.model[-1].anchor_grid
180
+ model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
181
+ export_onnx(model, im, file, opset, train, False, simplify)
182
+ model.model[-1].anchor_grid = grid
183
+ else: # TensorRT >= 8
184
+ export_onnx(model, im, file, opset, train, False, simplify)
185
  onnx = file.with_suffix('.onnx')
186
  assert onnx.exists(), f'failed to export ONNX file: {onnx}'
187
 
 
424
  # Exports
425
  if 'torchscript' in include:
426
  export_torchscript(model, im, file, optimize)
427
+ if 'engine' in include: # TensorRT required before ONNX
428
+ export_engine(model, im, file, train, half, simplify, workspace, verbose)
429
  if ('onnx' in include) or ('openvino' in include): # OpenVINO requires ONNX
430
  export_onnx(model, im, file, opset, train, dynamic, simplify)
431
  if 'openvino' in include:
432
  export_openvino(model, im, file)
 
 
433
  if 'coreml' in include:
434
  export_coreml(model, im, file)
435