glenn-jocher and pre-commit-ci[bot] committed
Commit
3883261
1 Parent(s): 79bca2b

New `DetectMultiBackend()` class (#5549)

* New `DetectMultiBackend()` class

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* pb to pt fix

* Cleanup

* explicit apply_classifier path

* Cleanup2

* Cleanup3

* Cleanup4

* Cleanup5

* Cleanup6

* val.py MultiBackend inference

* warmup fix

* to device fix

* pt fix

* device fix

* Val cleanup

* COCO128 URL to assets

* half fix

* detect fix

* detect fix 2

* remove half from DetectMultiBackend

* training half handling

* training half handling 2

* training half handling 3

* Cleanup

* Fix CI error

* Add torchscript _extra_files

* Add TorchScript

* Add CoreML

* CoreML cleanup

* revert default to pt

* Add Usage examples

* Cleanup val

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
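
The net effect of the commit: the per-backend loading and inference logic previously inlined in detect.py is consolidated into a single `DetectMultiBackend` wrapper in models/common.py that selects a backend from the weights-file suffix. A minimal usage sketch based on the diff below (file names are illustrative):

    import torch
    from models.common import DetectMultiBackend

    model = DetectMultiBackend('yolov5s.onnx', device=torch.device('cpu'), dnn=False)  # ONNX Runtime backend
    im = torch.zeros(1, 3, 640, 640)  # BCHW float32 input
    pred = model(im)  # same call for .pt, .torchscript.pt, .mlmodel, .onnx, .tflite, .pb, *_saved_model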

Files changed (7)
  1. data/coco128.yaml +1 -1
  2. detect.py +27 -106
  3. export.py +4 -1
  4. models/common.py +127 -1
  5. utils/general.py +2 -1
  6. utils/torch_utils.py +0 -20
  7. val.py +39 -35
data/coco128.yaml CHANGED
@@ -27,4 +27,4 @@ names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 't
 
 
 # Download script/URL (optional)
-download: https://github.com/ultralytics/yolov5/releases/download/v1.0/coco128.zip
+download: https://ultralytics.com/assets/coco128.zip
detect.py CHANGED
@@ -14,12 +14,10 @@ Usage:
 
 import argparse
 import os
-import platform
 import sys
 from pathlib import Path
 
 import cv2
-import numpy as np
 import torch
 import torch.backends.cudnn as cudnn
 
@@ -29,13 +27,12 @@ if str(ROOT) not in sys.path:
 sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
-from models.experimental import attempt_load
+from models.common import DetectMultiBackend
 from utils.datasets import IMG_FORMATS, VID_FORMATS, LoadImages, LoadStreams
-from utils.general import (LOGGER, apply_classifier, check_file, check_img_size, check_imshow, check_requirements,
-                           check_suffix, colorstr, increment_path, non_max_suppression, print_args, scale_coords,
-                           strip_optimizer, xyxy2xywh)
+from utils.general import (LOGGER, check_file, check_img_size, check_imshow, check_requirements, colorstr,
+                           increment_path, non_max_suppression, print_args, scale_coords, strip_optimizer, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
-from utils.torch_utils import load_classifier, select_device, time_sync
+from utils.torch_utils import select_device, time_sync
 
 
 @torch.no_grad()
@@ -77,120 +74,45 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
     save_dir = increment_path(Path(project) / name, exist_ok=exist_ok)  # increment run
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
 
-    # Initialize
+    # Load model
     device = select_device(device)
-    half &= device.type != 'cpu'  # half precision only supported on CUDA
+    model = DetectMultiBackend(weights, device=device, dnn=dnn)
+    stride, names, pt, jit, onnx = model.stride, model.names, model.pt, model.jit, model.onnx
+    imgsz = check_img_size(imgsz, s=stride)  # check image size
 
-    # Load model
-    w = str(weights[0] if isinstance(weights, list) else weights)
-    classify, suffix, suffixes = False, Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '']
-    check_suffix(w, suffixes)  # check weights have acceptable suffix
-    pt, onnx, tflite, pb, saved_model = (suffix == x for x in suffixes)  # backend booleans
-    stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
+    # Half
+    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
     if pt:
-        model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
-        stride = int(model.stride.max())  # model stride
-        names = model.module.names if hasattr(model, 'module') else model.names  # get class names
-        if half:
-            model.half()  # to FP16
-        if classify:  # second-stage classifier
-            modelc = load_classifier(name='resnet50', n=2)  # initialize
-            modelc.load_state_dict(torch.load('resnet50.pt', map_location=device)['model']).to(device).eval()
-    elif onnx:
-        if dnn:
-            check_requirements(('opencv-python>=4.5.4',))
-            net = cv2.dnn.readNetFromONNX(w)
-        else:
-            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
-            import onnxruntime
-            session = onnxruntime.InferenceSession(w, None)
-    else:  # TensorFlow models
-        import tensorflow as tf
-        if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
-            def wrap_frozen_graph(gd, inputs, outputs):
-                x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped import
-                return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
-                               tf.nest.map_structure(x.graph.as_graph_element, outputs))
-
-            graph_def = tf.Graph().as_graph_def()
-            graph_def.ParseFromString(open(w, 'rb').read())
-            frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
-        elif saved_model:
-            model = tf.keras.models.load_model(w)
-        elif tflite:
-            if "edgetpu" in w:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
-                import tflite_runtime.interpreter as tflri
-                delegate = {'Linux': 'libedgetpu.so.1',  # install libedgetpu https://coral.ai/software/#edgetpu-runtime
-                            'Darwin': 'libedgetpu.1.dylib',
-                            'Windows': 'edgetpu.dll'}[platform.system()]
-                interpreter = tflri.Interpreter(model_path=w, experimental_delegates=[tflri.load_delegate(delegate)])
-            else:
-                interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
-            interpreter.allocate_tensors()  # allocate
-            input_details = interpreter.get_input_details()  # inputs
-            output_details = interpreter.get_output_details()  # outputs
-            int8 = input_details[0]['dtype'] == np.uint8  # is TFLite quantized uint8 model
-    imgsz = check_img_size(imgsz, s=stride)  # check image size
+        model.model.half() if half else model.model.float()
 
     # Dataloader
     if webcam:
         view_img = check_imshow()
         cudnn.benchmark = True  # set True to speed up constant image size inference
-        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt)
+        dataset = LoadStreams(source, img_size=imgsz, stride=stride, auto=pt and not jit)
         bs = len(dataset)  # batch_size
     else:
-        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt)
+        dataset = LoadImages(source, img_size=imgsz, stride=stride, auto=pt and not jit)
         bs = 1  # batch_size
     vid_path, vid_writer = [None] * bs, [None] * bs
 
     # Run inference
     if pt and device.type != 'cpu':
-        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.parameters())))  # run once
+        model(torch.zeros(1, 3, *imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
     dt, seen = [0.0, 0.0, 0.0], 0
-    for path, img, im0s, vid_cap, s in dataset:
+    for path, im, im0s, vid_cap, s in dataset:
         t1 = time_sync()
-        if onnx:
-            img = img.astype('float32')
-        else:
-            img = torch.from_numpy(img).to(device)
-            img = img.half() if half else img.float()  # uint8 to fp16/32
-        img /= 255  # 0 - 255 to 0.0 - 1.0
-        if len(img.shape) == 3:
-            img = img[None]  # expand for batch dim
+        im = torch.from_numpy(im).to(device)
+        im = im.half() if half else im.float()  # uint8 to fp16/32
+        im /= 255  # 0 - 255 to 0.0 - 1.0
+        if len(im.shape) == 3:
+            im = im[None]  # expand for batch dim
         t2 = time_sync()
         dt[0] += t2 - t1
 
         # Inference
-        if pt:
-            visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
-            pred = model(img, augment=augment, visualize=visualize)[0]
-        elif onnx:
-            if dnn:
-                net.setInput(img)
-                pred = torch.tensor(net.forward())
-            else:
-                pred = torch.tensor(session.run([session.get_outputs()[0].name], {session.get_inputs()[0].name: img}))
-        else:  # tensorflow model (tflite, pb, saved_model)
-            imn = img.permute(0, 2, 3, 1).cpu().numpy()  # image in numpy
-            if pb:
-                pred = frozen_func(x=tf.constant(imn)).numpy()
-            elif saved_model:
-                pred = model(imn, training=False).numpy()
-            elif tflite:
-                if int8:
-                    scale, zero_point = input_details[0]['quantization']
-                    imn = (imn / scale + zero_point).astype(np.uint8)  # de-scale
-                interpreter.set_tensor(input_details[0]['index'], imn)
-                interpreter.invoke()
-                pred = interpreter.get_tensor(output_details[0]['index'])
-                if int8:
-                    scale, zero_point = output_details[0]['quantization']
-                    pred = (pred.astype(np.float32) - zero_point) * scale  # re-scale
-            pred[..., 0] *= imgsz[1]  # x
-            pred[..., 1] *= imgsz[0]  # y
-            pred[..., 2] *= imgsz[1]  # w
-            pred[..., 3] *= imgsz[0]  # h
-            pred = torch.tensor(pred)
+        visualize = increment_path(save_dir / Path(path).stem, mkdir=True) if visualize else False
+        pred = model(im, augment=augment, visualize=visualize)
         t3 = time_sync()
         dt[1] += t3 - t2
 
@@ -199,8 +121,7 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
         dt[2] += time_sync() - t3
 
         # Second-stage classifier (optional)
-        if classify:
-            pred = apply_classifier(pred, modelc, img, im0s)
+        # pred = utils.general.apply_classifier(pred, classifier_model, im, im0s)
 
         # Process predictions
         for i, det in enumerate(pred):  # per image
@@ -212,15 +133,15 @@ def run(weights=ROOT / 'yolov5s.pt',  # model.pt path(s)
                 p, im0, frame = path, im0s.copy(), getattr(dataset, 'frame', 0)
 
             p = Path(p)  # to Path
-            save_path = str(save_dir / p.name)  # img.jpg
-            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # img.txt
-            s += '%gx%g ' % img.shape[2:]  # print string
+            save_path = str(save_dir / p.name)  # im.jpg
+            txt_path = str(save_dir / 'labels' / p.stem) + ('' if dataset.mode == 'image' else f'_{frame}')  # im.txt
+            s += '%gx%g ' % im.shape[2:]  # print string
             gn = torch.tensor(im0.shape)[[1, 0, 1, 0]]  # normalization gain whwh
             imc = im0.copy() if save_crop else im0  # for save_crop
             annotator = Annotator(im0, line_width=line_thickness, example=str(names))
             if len(det):
                 # Rescale boxes from img_size to im0 size
-                det[:, :4] = scale_coords(img.shape[2:], det[:, :4], im0.shape).round()
+                det[:, :4] = scale_coords(im.shape[2:], det[:, :4], im0.shape).round()
 
                 # Print results
                 for c in det[:, -1].unique():
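
Because backend selection now happens inside `DetectMultiBackend`, one script serves every export format; the backend is chosen from the --weights suffix. Illustrative invocations, matching the run() signature in the diff above (OpenCV DNN for ONNX requires dnn=True to be passed through to run()):

    python detect.py --weights yolov5s.pt              # PyTorch
    python detect.py --weights yolov5s.torchscript.pt  # TorchScript
    python detect.py --weights yolov5s.onnx            # ONNX Runtime
    python detect.py --weights yolov5s.tflite          # TensorFlow Lite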
export.py CHANGED
@@ -21,6 +21,7 @@ TensorFlow.js:
 """
 
 import argparse
+import json
 import os
 import subprocess
 import sys
@@ -54,7 +55,9 @@ def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:'
         f = file.with_suffix('.torchscript.pt')
 
         ts = torch.jit.trace(model, im, strict=False)
-        (optimize_for_mobile(ts) if optimize else ts).save(f)
+        d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
+        extra_files = {'config.txt': json.dumps(d)}  # torch._C.ExtraFilesMap()
+        (optimize_for_mobile(ts) if optimize else ts).save(f, _extra_files=extra_files)
 
         LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
     except Exception as e:
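
The `_extra_files` argument embeds model metadata (input shape, stride, class names) directly in the TorchScript archive, and `DetectMultiBackend` in models/common.py reads it back at load time. A self-contained sketch of the round trip (the traced module and values here are stand-ins):

    import json
    import torch
    import torch.nn as nn

    # save side, as in export_torchscript() above
    ts = torch.jit.trace(nn.Identity(), torch.zeros(1, 3, 640, 640))
    d = {'shape': [1, 3, 640, 640], 'stride': 32, 'names': ['person', 'bicycle']}
    ts.save('model.torchscript.pt', _extra_files={'config.txt': json.dumps(d)})

    # load side, as in DetectMultiBackend.__init__()
    extra_files = {'config.txt': ''}  # torch.jit.load fills the values in place
    model = torch.jit.load('model.torchscript.pt', _extra_files=extra_files)
    d = json.loads(extra_files['config.txt'])
    stride, names = int(d['stride']), d['names']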
models/common.py CHANGED
@@ -3,11 +3,14 @@
 Common modules
 """
 
+import json
 import math
+import platform
 import warnings
 from copy import copy
 from pathlib import Path
 
+import cv2
 import numpy as np
 import pandas as pd
 import requests
@@ -17,7 +20,8 @@ from PIL import Image
 from torch.cuda import amp
 
 from utils.datasets import exif_transpose, letterbox
-from utils.general import LOGGER, colorstr, increment_path, make_divisible, non_max_suppression, scale_coords, xyxy2xywh
+from utils.general import (LOGGER, check_requirements, check_suffix, colorstr, increment_path, make_divisible,
+                           non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.plots import Annotator, colors, save_one_box
 from utils.torch_utils import time_sync
 
@@ -269,6 +273,128 @@ class Concat(nn.Module):
         return torch.cat(x, self.d)
 
 
+class DetectMultiBackend(nn.Module):
+    # YOLOv5 MultiBackend class for python inference on various backends
+    def __init__(self, weights='yolov5s.pt', device=None, dnn=True):
+        # Usage:
+        #   PyTorch:         weights = *.pt
+        #   TorchScript:               *.torchscript.pt
+        #   CoreML:                    *.mlmodel
+        #   TensorFlow:                *_saved_model
+        #   TensorFlow:                *.pb
+        #   TensorFlow Lite:           *.tflite
+        #   ONNX Runtime:              *.onnx
+        #   OpenCV DNN:                *.onnx with dnn=True
+        super().__init__()
+        w = str(weights[0] if isinstance(weights, list) else weights)
+        suffix, suffixes = Path(w).suffix.lower(), ['.pt', '.onnx', '.tflite', '.pb', '', '.mlmodel']
+        check_suffix(w, suffixes)  # check weights have acceptable suffix
+        pt, onnx, tflite, pb, saved_model, coreml = (suffix == x for x in suffixes)  # backend booleans
+        jit = pt and 'torchscript' in w.lower()
+        stride, names = 64, [f'class{i}' for i in range(1000)]  # assign defaults
+
+        if jit:  # TorchScript
+            LOGGER.info(f'Loading {w} for TorchScript inference...')
+            extra_files = {'config.txt': ''}  # model metadata
+            model = torch.jit.load(w, _extra_files=extra_files)
+            if extra_files['config.txt']:
+                d = json.loads(extra_files['config.txt'])  # extra_files dict
+                stride, names = int(d['stride']), d['names']
+        elif pt:  # PyTorch
+            from models.experimental import attempt_load  # scoped to avoid circular import
+            model = torch.jit.load(w) if 'torchscript' in w else attempt_load(weights, map_location=device)
+            stride = int(model.stride.max())  # model stride
+            names = model.module.names if hasattr(model, 'module') else model.names  # get class names
+        elif coreml:  # CoreML *.mlmodel
+            import coremltools as ct
+            model = ct.models.MLModel(w)
+        elif dnn:  # ONNX OpenCV DNN
+            LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
+            check_requirements(('opencv-python>=4.5.4',))
+            net = cv2.dnn.readNetFromONNX(w)
+        elif onnx:  # ONNX Runtime
+            LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
+            check_requirements(('onnx', 'onnxruntime-gpu' if torch.has_cuda else 'onnxruntime'))
+            import onnxruntime
+            session = onnxruntime.InferenceSession(w, None)
+        else:  # TensorFlow model (TFLite, pb, saved_model)
+            import tensorflow as tf
+            if pb:  # https://www.tensorflow.org/guide/migrate#a_graphpb_or_graphpbtxt
+                def wrap_frozen_graph(gd, inputs, outputs):
+                    x = tf.compat.v1.wrap_function(lambda: tf.compat.v1.import_graph_def(gd, name=""), [])  # wrapped
+                    return x.prune(tf.nest.map_structure(x.graph.as_graph_element, inputs),
+                                   tf.nest.map_structure(x.graph.as_graph_element, outputs))
+
+                LOGGER.info(f'Loading {w} for TensorFlow *.pb inference...')
+                graph_def = tf.Graph().as_graph_def()
+                graph_def.ParseFromString(open(w, 'rb').read())
+                frozen_func = wrap_frozen_graph(gd=graph_def, inputs="x:0", outputs="Identity:0")
+            elif saved_model:
+                LOGGER.info(f'Loading {w} for TensorFlow saved_model inference...')
+                model = tf.keras.models.load_model(w)
+            elif tflite:  # https://www.tensorflow.org/lite/guide/python#install_tensorflow_lite_for_python
+                if 'edgetpu' in w.lower():
+                    LOGGER.info(f'Loading {w} for TensorFlow Edge TPU inference...')
+                    import tflite_runtime.interpreter as tfli
+                    delegate = {'Linux': 'libedgetpu.so.1',  # install https://coral.ai/software/#edgetpu-runtime
+                                'Darwin': 'libedgetpu.1.dylib',
+                                'Windows': 'edgetpu.dll'}[platform.system()]
+                    interpreter = tfli.Interpreter(model_path=w, experimental_delegates=[tfli.load_delegate(delegate)])
+                else:
+                    LOGGER.info(f'Loading {w} for TensorFlow Lite inference...')
+                    interpreter = tf.lite.Interpreter(model_path=w)  # load TFLite model
+                interpreter.allocate_tensors()  # allocate
+                input_details = interpreter.get_input_details()  # inputs
+                output_details = interpreter.get_output_details()  # outputs
+        self.__dict__.update(locals())  # assign all variables to self
+
+    def forward(self, im, augment=False, visualize=False, val=False):
+        # YOLOv5 MultiBackend inference
+        b, ch, h, w = im.shape  # batch, channel, height, width
+        if self.pt:  # PyTorch
+            y = self.model(im) if self.jit else self.model(im, augment=augment, visualize=visualize)
+            return y if val else y[0]
+        elif self.coreml:  # CoreML *.mlmodel
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            im = Image.fromarray((im[0] * 255).astype('uint8'))
+            # im = im.resize((192, 320), Image.ANTIALIAS)
+            y = self.model.predict({'image': im})  # coordinates are xywh normalized
+            box = xywh2xyxy(y['coordinates'] * [[w, h, w, h]])  # xyxy pixels
+            conf, cls = y['confidence'].max(1), y['confidence'].argmax(1).astype(np.float)
+            y = np.concatenate((box, conf.reshape(-1, 1), cls.reshape(-1, 1)), 1)
+        elif self.onnx:  # ONNX
+            im = im.cpu().numpy()  # torch to numpy
+            if self.dnn:  # ONNX OpenCV DNN
+                self.net.setInput(im)
+                y = self.net.forward()
+            else:  # ONNX Runtime
+                y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
+        else:  # TensorFlow model (TFLite, pb, saved_model)
+            im = im.permute(0, 2, 3, 1).cpu().numpy()  # torch BCHW to numpy BHWC shape(1,320,192,3)
+            if self.pb:
+                y = self.frozen_func(x=self.tf.constant(im)).numpy()
+            elif self.saved_model:
+                y = self.model(im, training=False).numpy()
+            elif self.tflite:
+                input, output = self.input_details[0], self.output_details[0]
+                int8 = input['dtype'] == np.uint8  # is TFLite quantized uint8 model
+                if int8:
+                    scale, zero_point = input['quantization']
+                    im = (im / scale + zero_point).astype(np.uint8)  # de-scale
+                self.interpreter.set_tensor(input['index'], im)
+                self.interpreter.invoke()
+                y = self.interpreter.get_tensor(output['index'])
+                if int8:
+                    scale, zero_point = output['quantization']
+                    y = (y.astype(np.float32) - zero_point) * scale  # re-scale
+            y[..., 0] *= w  # x
+            y[..., 1] *= h  # y
+            y[..., 2] *= w  # w
+            y[..., 3] *= h  # h
+        y = torch.tensor(y)
+        return (y, []) if val else y
+
+
 class AutoShape(nn.Module):
     # YOLOv5 input-robust model wrapper for passing cv2/np/PIL/torch inputs. Includes preprocessing, inference and NMS
     conf = 0.25  # NMS confidence threshold
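
Two details of the class worth noting. `self.__dict__.update(locals())` at the end of __init__ promotes every local (the backend booleans, `model`, `net`, `session`, `interpreter`, `stride`, `names`, ...) to an attribute, which is what forward() later reads as self.pt, self.session, and so on. And the `val` flag switches the return convention between the two callers; a sketch of both call sites, per the code above:

    model = DetectMultiBackend(weights, device=device, dnn=False)
    pred = model(im)                      # detect.py path: detections only
    out, train_out = model(im, val=True)  # val.py path: (inference output, loss output)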
utils/general.py CHANGED
@@ -785,7 +785,8 @@ def print_mutation(results, hyp, save_dir, bucket):
 
 
 def apply_classifier(x, model, img, im0):
-    # Apply a second stage classifier to yolo outputs
+    # Apply a second stage classifier to YOLO outputs
+    # Example model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
     im0 = [im0] if isinstance(im0, np.ndarray) else im0
     for i, d in enumerate(x):  # per image
         if d is not None and len(d):
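
The new comment replaces the removed `load_classifier()` helper (see utils/torch_utils.py below) with a one-liner: build the second-stage model yourself via torchvision. A hedged sketch of that setup (assumes a torchvision release that ships efficientnet_b0):

    import torch
    import torchvision

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    classifier_model = torchvision.models.__dict__['efficientnet_b0'](pretrained=True).to(device).eval()
    # then, as in the detect.py comment:
    # pred = apply_classifier(pred, classifier_model, im, im0s)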
utils/torch_utils.py CHANGED
@@ -17,7 +17,6 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-import torchvision
 
 from utils.general import LOGGER
 
@@ -235,25 +234,6 @@ def model_info(model, verbose=False, img_size=640):
     LOGGER.info(f"Model Summary: {len(list(model.modules()))} layers, {n_p} parameters, {n_g} gradients{fs}")
 
 
-def load_classifier(name='resnet101', n=2):
-    # Loads a pretrained model reshaped to n-class output
-    model = torchvision.models.__dict__[name](pretrained=True)
-
-    # ResNet model properties
-    # input_size = [3, 224, 224]
-    # input_space = 'RGB'
-    # input_range = [0, 1]
-    # mean = [0.485, 0.456, 0.406]
-    # std = [0.229, 0.224, 0.225]
-
-    # Reshape output to n classes
-    filters = model.fc.weight.shape[1]
-    model.fc.bias = nn.Parameter(torch.zeros(n), requires_grad=True)
-    model.fc.weight = nn.Parameter(torch.zeros(n, filters), requires_grad=True)
-    model.fc.out_features = n
-    return model
-
-
 def scale_img(img, ratio=1.0, same_shape=False, gs=32):  # img(16,3,256,416)
     # scales img(bs,3,y,x) by ratio constrained to gs-multiple
     if ratio == 1.0:
val.py CHANGED
@@ -23,10 +23,10 @@ if str(ROOT) not in sys.path:
 sys.path.append(str(ROOT))  # add ROOT to PATH
 ROOT = Path(os.path.relpath(ROOT, Path.cwd()))  # relative
 
-from models.experimental import attempt_load
+from models.common import DetectMultiBackend
 from utils.callbacks import Callbacks
 from utils.datasets import create_dataloader
-from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_suffix, check_yaml,
+from utils.general import (LOGGER, box_iou, check_dataset, check_img_size, check_requirements, check_yaml,
                            coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
                            scale_coords, xywh2xyxy, xyxy2xywh)
 from utils.metrics import ConfusionMatrix, ap_per_class
@@ -100,6 +100,7 @@ def run(data,
         name='exp',  # save to project/name
         exist_ok=False,  # existing project/name ok, do not increment
         half=True,  # use FP16 half-precision inference
+        dnn=False,  # use OpenCV DNN for ONNX inference
         model=None,
         dataloader=None,
         save_dir=Path(''),
@@ -110,8 +111,10 @@
     # Initialize/load model and set device
     training = model is not None
     if training:  # called by train.py
-        device = next(model.parameters()).device  # get model device
+        device, pt = next(model.parameters()).device, True  # get model device, PyTorch model
 
+        half &= device.type != 'cpu'  # half precision only supported on CUDA
+        model.half() if half else model.float()
     else:  # called directly
         device = select_device(device, batch_size=batch_size)
@@ -120,22 +123,21 @@
     (save_dir / 'labels' if save_txt else save_dir).mkdir(parents=True, exist_ok=True)  # make dir
 
     # Load model
-    check_suffix(weights, '.pt')
-    model = attempt_load(weights, map_location=device)  # load FP32 model
-    gs = max(int(model.stride.max()), 32)  # grid size (max stride)
-    imgsz = check_img_size(imgsz, s=gs)  # check image size
-
-    # Multi-GPU disabled, incompatible with .half() https://github.com/ultralytics/yolov5/issues/99
-    # if device.type != 'cpu' and torch.cuda.device_count() > 1:
-    #     model = nn.DataParallel(model)
+    model = DetectMultiBackend(weights, device=device, dnn=dnn)
+    stride, pt = model.stride, model.pt
+    imgsz = check_img_size(imgsz, s=stride)  # check image size
+    half &= pt and device.type != 'cpu'  # half precision only supported by PyTorch on CUDA
+    if pt:
+        model.model.half() if half else model.model.float()
+    else:
+        half = False
+        batch_size = 1  # export.py models default to batch-size 1
+        device = torch.device('cpu')
+        LOGGER.info(f'Forcing --batch-size 1 square inference shape(1,3,{imgsz},{imgsz}) for non-PyTorch backends')
 
     # Data
     data = check_dataset(data)  # check
 
-    # Half
-    half &= device.type != 'cpu'  # half precision only supported on CUDA
-    model.half() if half else model.float()
-
     # Configure
     model.eval()
     is_coco = isinstance(data.get('val'), str) and data['val'].endswith('coco/val2017.txt')  # COCO dataset
@@ -145,11 +147,11 @@
 
     # Dataloader
     if not training:
-        if device.type != 'cpu':
-            model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.parameters())))  # run once
+        if pt and device.type != 'cpu':
+            model(torch.zeros(1, 3, imgsz, imgsz).to(device).type_as(next(model.model.parameters())))  # warmup
         pad = 0.0 if task == 'speed' else 0.5
         task = task if task in ('train', 'val', 'test') else 'val'  # path to train/val/test images
-        dataloader = create_dataloader(data[task], imgsz, batch_size, gs, single_cls, pad=pad, rect=True,
+        dataloader = create_dataloader(data[task], imgsz, batch_size, stride, single_cls, pad=pad, rect=pt,
                                        prefix=colorstr(f'{task}: '))[0]
 
     seen = 0
@@ -160,32 +162,33 @@
     dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
     loss = torch.zeros(3, device=device)
     jdict, stats, ap, ap_class = [], [], [], []
-    for batch_i, (img, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
+    for batch_i, (im, targets, paths, shapes) in enumerate(tqdm(dataloader, desc=s)):
         t1 = time_sync()
-        img = img.to(device, non_blocking=True)
-        img = img.half() if half else img.float()  # uint8 to fp16/32
-        img /= 255  # 0 - 255 to 0.0 - 1.0
-        targets = targets.to(device)
-        nb, _, height, width = img.shape  # batch size, channels, height, width
+        if pt:
+            im = im.to(device, non_blocking=True)
+            targets = targets.to(device)
+        im = im.half() if half else im.float()  # uint8 to fp16/32
+        im /= 255  # 0 - 255 to 0.0 - 1.0
+        nb, _, height, width = im.shape  # batch size, channels, height, width
         t2 = time_sync()
         dt[0] += t2 - t1
 
-        # Run model
-        out, train_out = model(img, augment=augment)  # inference and training outputs
+        # Inference
+        out, train_out = model(im) if training else model(im, augment=augment, val=True)  # inference, loss outputs
         dt[1] += time_sync() - t2
 
-        # Compute loss
+        # Loss
         if compute_loss:
            loss += compute_loss([x.float() for x in train_out], targets)[1]  # box, obj, cls
 
-        # Run NMS
+        # NMS
         targets[:, 2:] *= torch.Tensor([width, height, width, height]).to(device)  # to pixels
         lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else []  # for autolabelling
         t3 = time_sync()
         out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
         dt[2] += time_sync() - t3
 
-        # Statistics per image
+        # Metrics
         for si, pred in enumerate(out):
             labels = targets[targets[:, 0] == si, 1:]
             nl = len(labels)
@@ -202,12 +205,12 @@
             if single_cls:
                 pred[:, 5] = 0
             predn = pred.clone()
-            scale_coords(img[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
+            scale_coords(im[si].shape[1:], predn[:, :4], shape, shapes[si][1])  # native-space pred
 
             # Evaluate
             if nl:
                 tbox = xywh2xyxy(labels[:, 1:5])  # target boxes
-                scale_coords(img[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
+                scale_coords(im[si].shape[1:], tbox, shape, shapes[si][1])  # native-space labels
                 labelsn = torch.cat((labels[:, 0:1], tbox), 1)  # native-space labels
                 correct = process_batch(predn, labelsn, iouv)
                 if plots:
@@ -221,16 +224,16 @@
                 save_one_txt(predn, save_conf, shape, file=save_dir / 'labels' / (path.stem + '.txt'))
             if save_json:
                 save_one_json(predn, jdict, path, class_map)  # append to COCO-JSON dictionary
-            callbacks.run('on_val_image_end', pred, predn, path, names, img[si])
+            callbacks.run('on_val_image_end', pred, predn, path, names, im[si])
 
         # Plot images
         if plots and batch_i < 3:
             f = save_dir / f'val_batch{batch_i}_labels.jpg'  # labels
-            Thread(target=plot_images, args=(img, targets, paths, f, names), daemon=True).start()
+            Thread(target=plot_images, args=(im, targets, paths, f, names), daemon=True).start()
             f = save_dir / f'val_batch{batch_i}_pred.jpg'  # predictions
-            Thread(target=plot_images, args=(img, output_to_target(out), paths, f, names), daemon=True).start()
+            Thread(target=plot_images, args=(im, output_to_target(out), paths, f, names), daemon=True).start()
 
-    # Compute statistics
+    # Compute metrics
     stats = [np.concatenate(x, 0) for x in zip(*stats)]  # to numpy
     if len(stats) and stats[0].any():
         p, r, ap, f1, ap_class = ap_per_class(*stats, plot=plots, save_dir=save_dir, names=names)
@@ -318,6 +321,7 @@ def parse_opt():
     parser.add_argument('--name', default='exp', help='save to project/name')
     parser.add_argument('--exist-ok', action='store_true', help='existing project/name ok, do not increment')
     parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
+    parser.add_argument('--dnn', action='store_true', help='use OpenCV DNN for ONNX inference')
    opt = parser.parse_args()
    opt.data = check_yaml(opt.data)  # check YAML
    opt.save_json |= opt.data.endswith('coco.yaml')
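
With the MultiBackend loader, val.py now evaluates exported models as well; non-PyTorch backends are forced to batch-size 1 square CPU inference, as logged in the diff above. Illustrative invocations (flag names other than --weights and --dnn assume the existing val.py argparse options):

    python val.py --weights yolov5s.pt --data coco128.yaml          # PyTorch
    python val.py --weights yolov5s.onnx --data coco128.yaml        # ONNX Runtime
    python val.py --weights yolov5s.onnx --data coco128.yaml --dnn  # OpenCV DNN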