Commit
•
407a905
1
Parent(s):
c1249a4
Check TensorRT>=8.0.0 version (#6021)
Browse files* Check TensorRT>=8.0.0 version
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- models/common.py +3 -2
- utils/general.py +7 -5
models/common.py
CHANGED
@@ -21,8 +21,8 @@ from PIL import Image
|
|
21 |
from torch.cuda import amp
|
22 |
|
23 |
from utils.datasets import exif_transpose, letterbox
|
24 |
-
from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path,
|
25 |
-
non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
|
26 |
from utils.plots import Annotator, colors, save_one_box
|
27 |
from utils.torch_utils import copy_attr, time_sync
|
28 |
|
@@ -328,6 +328,7 @@ class DetectMultiBackend(nn.Module):
|
|
328 |
elif engine: # TensorRT
|
329 |
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
330 |
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
|
|
331 |
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
332 |
logger = trt.Logger(trt.Logger.INFO)
|
333 |
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
|
|
21 |
from torch.cuda import amp
|
22 |
|
23 |
from utils.datasets import exif_transpose, letterbox
|
24 |
+
from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
|
25 |
+
make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
|
26 |
from utils.plots import Annotator, colors, save_one_box
|
27 |
from utils.torch_utils import copy_attr, time_sync
|
28 |
|
|
|
328 |
elif engine: # TensorRT
|
329 |
LOGGER.info(f'Loading {w} for TensorRT inference...')
|
330 |
import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
|
331 |
+
check_version(trt.__version__, '8.0.0', verbose=True) # version requirement
|
332 |
Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
|
333 |
logger = trt.Logger(trt.Logger.INFO)
|
334 |
with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
|
utils/general.py
CHANGED
@@ -248,14 +248,16 @@ def check_python(minimum='3.6.2'):
|
|
248 |
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
249 |
|
250 |
|
251 |
-
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False):
|
252 |
# Check version vs. required version
|
253 |
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
254 |
result = (current == minimum) if pinned else (current >= minimum) # bool
|
255 |
-
|
256 |
-
|
257 |
-
|
258 |
-
|
|
|
|
|
259 |
|
260 |
|
261 |
@try_except
|
|
|
248 |
check_version(platform.python_version(), minimum, name='Python ', hard=True)
|
249 |
|
250 |
|
251 |
+
def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=False, hard=False, verbose=False):
|
252 |
# Check version vs. required version
|
253 |
current, minimum = (pkg.parse_version(x) for x in (current, minimum))
|
254 |
result = (current == minimum) if pinned else (current >= minimum) # bool
|
255 |
+
s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
|
256 |
+
if hard:
|
257 |
+
assert result, s # assert min requirements met
|
258 |
+
if verbose and not result:
|
259 |
+
LOGGER.warning(s)
|
260 |
+
return result
|
261 |
|
262 |
|
263 |
@try_except
|