imyhxy commited on
Commit
a4207a2
1 Parent(s): 5ca5dd4

Fix TensorRT potential unordered binding addresses (#5826)

Browse files

* feat: change file suffix in pythonic way

* fix: enforce binding addresses order

* fix: enforce binding addresses order

Files changed (2) hide show
  1. export.py +2 -1
  2. models/common.py +3 -3
export.py CHANGED
@@ -276,7 +276,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
276
  assert onnx.exists(), f'failed to export ONNX file: {onnx}'
277
 
278
  LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
279
- f = str(file).replace('.pt', '.engine') # TensorRT engine file
280
  logger = trt.Logger(trt.Logger.INFO)
281
  if verbose:
282
  logger.min_severity = trt.Logger.Severity.VERBOSE
@@ -310,6 +310,7 @@ def export_engine(model, im, file, train, half, simplify, workspace=4, verbose=F
310
  except Exception as e:
311
  LOGGER.info(f'\n{prefix} export failure: {e}')
312
 
 
313
  @torch.no_grad()
314
  def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
315
  weights=ROOT / 'yolov5s.pt', # weights path
 
276
  assert onnx.exists(), f'failed to export ONNX file: {onnx}'
277
 
278
  LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
279
+ f = file.with_suffix('.engine') # TensorRT engine file
280
  logger = trt.Logger(trt.Logger.INFO)
281
  if verbose:
282
  logger.min_severity = trt.Logger.Severity.VERBOSE
 
310
  except Exception as e:
311
  LOGGER.info(f'\n{prefix} export failure: {e}')
312
 
313
+
314
  @torch.no_grad()
315
  def run(data=ROOT / 'data/coco128.yaml', # 'dataset.yaml path'
316
  weights=ROOT / 'yolov5s.pt', # weights path
models/common.py CHANGED
@@ -7,7 +7,7 @@ import json
7
  import math
8
  import platform
9
  import warnings
10
- from collections import namedtuple
11
  from copy import copy
12
  from pathlib import Path
13
 
@@ -326,14 +326,14 @@ class DetectMultiBackend(nn.Module):
326
  logger = trt.Logger(trt.Logger.INFO)
327
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
328
  model = runtime.deserialize_cuda_engine(f.read())
329
- bindings = dict()
330
  for index in range(model.num_bindings):
331
  name = model.get_binding_name(index)
332
  dtype = trt.nptype(model.get_binding_dtype(index))
333
  shape = tuple(model.get_binding_shape(index))
334
  data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
335
  bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
336
- binding_addrs = {n: d.ptr for n, d in bindings.items()}
337
  context = model.create_execution_context()
338
  batch_size = bindings['images'].shape[0]
339
  else: # TensorFlow model (TFLite, pb, saved_model)
 
7
  import math
8
  import platform
9
  import warnings
10
+ from collections import OrderedDict, namedtuple
11
  from copy import copy
12
  from pathlib import Path
13
 
 
326
  logger = trt.Logger(trt.Logger.INFO)
327
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
328
  model = runtime.deserialize_cuda_engine(f.read())
329
+ bindings = OrderedDict()
330
  for index in range(model.num_bindings):
331
  name = model.get_binding_name(index)
332
  dtype = trt.nptype(model.get_binding_dtype(index))
333
  shape = tuple(model.get_binding_shape(index))
334
  data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
335
  bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
336
+ binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
337
  context = model.create_execution_context()
338
  batch_size = bindings['images'].shape[0]
339
  else: # TensorFlow model (TFLite, pb, saved_model)