glenn-jocher commited on
Commit
7b31a53
1 Parent(s): a2f4a17

Add `tensorrt>=7.0.0` checks (#6193)

Browse files

* Add `tensorrt>=7.0.0` checks

* Update export.py

* Update common.py

* Update export.py

Files changed (2) hide show
  1. export.py +6 -6
  2. models/common.py +1 -1
export.py CHANGED
@@ -61,8 +61,8 @@ from models.experimental import attempt_load
61
  from models.yolo import Detect
62
  from utils.activations import SiLU
63
  from utils.datasets import LoadImages
64
- from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, colorstr, file_size, print_args,
65
- url2file)
66
  from utils.torch_utils import select_device
67
 
68
 
@@ -174,14 +174,14 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
174
  check_requirements(('tensorrt',))
175
  import tensorrt as trt
176
 
177
- opset = (12, 13)[trt.__version__[0] == '8'] # test on TensorRT 7.x and 8.x
178
- if opset == 12: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
179
  grid = model.model[-1].anchor_grid
180
  model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
181
- export_onnx(model, im, file, opset, train, False, simplify)
182
  model.model[-1].anchor_grid = grid
183
  else: # TensorRT >= 8
184
- export_onnx(model, im, file, opset, train, False, simplify)
 
185
  onnx = file.with_suffix('.onnx')
186
  assert onnx.exists(), f'failed to export ONNX file: {onnx}'
187
 
 
61
  from models.yolo import Detect
62
  from utils.activations import SiLU
63
  from utils.datasets import LoadImages
64
+ from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, colorstr,
65
+ file_size, print_args, url2file)
66
  from utils.torch_utils import select_device
67
 
68
 
 
174
  check_requirements(('tensorrt',))
175
  import tensorrt as trt
176
 
177
+ if trt.__version__[0] == 7: # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
 
178
  grid = model.model[-1].anchor_grid
179
  model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
180
+ export_onnx(model, im, file, 12, train, False, simplify) # opset 12
181
  model.model[-1].anchor_grid = grid
182
  else: # TensorRT >= 8
183
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=8.0.0
184
+ export_onnx(model, im, file, 13, train, False, simplify) # opset 13
185
  onnx = file.with_suffix('.onnx')
186
  assert onnx.exists(), f'failed to export ONNX file: {onnx}'
187
 
models/common.py CHANGED
@@ -337,7 +337,7 @@ class DetectMultiBackend(nn.Module):
337
  elif engine: # TensorRT
338
  LOGGER.info(f'Loading {w} for TensorRT inference...')
339
  import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
340
- check_version(trt.__version__, '8.0.0', verbose=True) # version requirement
341
  Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
342
  logger = trt.Logger(trt.Logger.INFO)
343
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
 
337
  elif engine: # TensorRT
338
  LOGGER.info(f'Loading {w} for TensorRT inference...')
339
  import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
340
+ check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
341
  Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
342
  logger = trt.Logger(trt.Logger.INFO)
343
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime: