glenn-jocher pre-commit-ci[bot] commited on
Commit
407a905
1 Parent(s): c1249a4

Check TensorRT>=8.0.0 version (#6021)

Browse files

* Check TensorRT>=8.0.0 version

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Files changed (2) hide show
  1. models/common.py +3 -2
  2. utils/general.py +7 -5
models/common.py CHANGED
@@ -21,8 +21,8 @@ from PIL import Image
21
  from torch.cuda import amp
22
 
23
  from utils.datasets import exif_transpose, letterbox
24
- from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
25
- non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
26
  from utils.plots import Annotator, colors, save_one_box
27
  from utils.torch_utils import copy_attr, time_sync
28
 
@@ -328,6 +328,7 @@ class DetectMultiBackend(nn.Module):
328
  elif engine: # TensorRT
329
  LOGGER.info(f'Loading {w} for TensorRT inference...')
330
  import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
 
331
  Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
332
  logger = trt.Logger(trt.Logger.INFO)
333
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
 
21
  from torch.cuda import amp
22
 
23
  from utils.datasets import exif_transpose, letterbox
24
+ from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
25
+ make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
26
  from utils.plots import Annotator, colors, save_one_box
27
  from utils.torch_utils import copy_attr, time_sync
28
 
 
328
  elif engine: # TensorRT
329
  LOGGER.info(f'Loading {w} for TensorRT inference...')
330
  import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
331
+ check_version(trt.__version__, '8.0.0', verbose=True) # version requirement
332
  Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
333
  logger = trt.Logger(trt.Logger.INFO)
334
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
utils/general.py CHANGED
@@ -248,14 +248,16 @@ def check_python(minimum='3.6.2'):
248
  check_version(platform.python_version(), minimum, name='Python ', hard=True)
249
 
250
 
251
- def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
252
  # Check version vs. required version
253
  current, minimum = (pkg.parse_version(x) for x in (current, minimum))
254
  result = (current == minimum) if pinned else (current >= minimum) # bool
255
- if hard: # assert min requirements met
256
- assert result, f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed'
257
- else:
258
- return result
 
 
259
 
260
 
261
  @try_except
 
248
  check_version(platform.python_version(), minimum, name='Python ', hard=True)
249
 
250
 
251
+ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
252
  # Check version vs. required version
253
  current, minimum = (pkg.parse_version(x) for x in (current, minimum))
254
  result = (current == minimum) if pinned else (current >= minimum) # bool
255
+ s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
256
+ if hard:
257
+ assert result, s # assert min requirements met
258
+ if verbose and not result:
259
+ LOGGER.warning(s)
260
+ return result
261
 
262
 
263
  @try_except