Zengyf-CVer commited on
Commit
aa24cb4
1 Parent(s): 112bf3b

app update

Browse files
data/coco128.yaml CHANGED
@@ -14,16 +14,87 @@ val: images/train2017 # val images (relative to 'path') 128 images
14
  test: # test images (optional)
15
 
16
  # Classes
17
- nc: 80 # number of classes
18
- names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
19
- 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
20
- 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
21
- 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard',
22
- 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
23
- 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
24
- 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
25
- 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', 'scissors', 'teddy bear',
26
- 'hair drier', 'toothbrush'] # class names
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
 
29
  # Download script/URL (optional)
 
14
  test: # test images (optional)
15
 
16
  # Classes
17
+ names:
18
+ 0: person
19
+ 1: bicycle
20
+ 2: car
21
+ 3: motorcycle
22
+ 4: airplane
23
+ 5: bus
24
+ 6: train
25
+ 7: truck
26
+ 8: boat
27
+ 9: traffic light
28
+ 10: fire hydrant
29
+ 11: stop sign
30
+ 12: parking meter
31
+ 13: bench
32
+ 14: bird
33
+ 15: cat
34
+ 16: dog
35
+ 17: horse
36
+ 18: sheep
37
+ 19: cow
38
+ 20: elephant
39
+ 21: bear
40
+ 22: zebra
41
+ 23: giraffe
42
+ 24: backpack
43
+ 25: umbrella
44
+ 26: handbag
45
+ 27: tie
46
+ 28: suitcase
47
+ 29: frisbee
48
+ 30: skis
49
+ 31: snowboard
50
+ 32: sports ball
51
+ 33: kite
52
+ 34: baseball bat
53
+ 35: baseball glove
54
+ 36: skateboard
55
+ 37: surfboard
56
+ 38: tennis racket
57
+ 39: bottle
58
+ 40: wine glass
59
+ 41: cup
60
+ 42: fork
61
+ 43: knife
62
+ 44: spoon
63
+ 45: bowl
64
+ 46: banana
65
+ 47: apple
66
+ 48: sandwich
67
+ 49: orange
68
+ 50: broccoli
69
+ 51: carrot
70
+ 52: hot dog
71
+ 53: pizza
72
+ 54: donut
73
+ 55: cake
74
+ 56: chair
75
+ 57: couch
76
+ 58: potted plant
77
+ 59: bed
78
+ 60: dining table
79
+ 61: toilet
80
+ 62: tv
81
+ 63: laptop
82
+ 64: mouse
83
+ 65: remote
84
+ 66: keyboard
85
+ 67: cell phone
86
+ 68: microwave
87
+ 69: oven
88
+ 70: toaster
89
+ 71: sink
90
+ 72: refrigerator
91
+ 73: book
92
+ 74: clock
93
+ 75: vase
94
+ 76: scissors
95
+ 77: teddy bear
96
+ 78: hair drier
97
+ 79: toothbrush
98
 
99
 
100
  # Download script/URL (optional)
export.py CHANGED
@@ -21,19 +21,19 @@ Requirements:
21
  $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
22
 
23
  Usage:
24
- $ python path/to/export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
25
 
26
  Inference:
27
- $ python path/to/detect.py --weights yolov5s.pt # PyTorch
28
- yolov5s.torchscript # TorchScript
29
- yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
30
- yolov5s.xml # OpenVINO
31
- yolov5s.engine # TensorRT
32
- yolov5s.mlmodel # CoreML (macOS-only)
33
- yolov5s_saved_model # TensorFlow SavedModel
34
- yolov5s.pb # TensorFlow GraphDef
35
- yolov5s.tflite # TensorFlow Lite
36
- yolov5s_edgetpu.tflite # TensorFlow Edge TPU
37
 
38
  TensorFlow.js:
39
  $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
@@ -65,10 +65,10 @@ if platform.system() != 'Windows':
65
  ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
66
 
67
  from models.experimental import attempt_load
68
- from models.yolo import Detect
69
  from utils.dataloaders import LoadImages
70
- from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_version, check_yaml,
71
- colorstr, file_size, print_args, url2file)
72
  from utils.torch_utils import select_device, smart_inference_mode
73
 
74
 
@@ -89,200 +89,199 @@ def export_formats():
89
  return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
90
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
93
  # YOLOv5 TorchScript model export
94
- try:
95
- LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
96
- f = file.with_suffix('.torchscript')
97
-
98
- ts = torch.jit.trace(model, im, strict=False)
99
- d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
100
- extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
101
- if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
102
- optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
103
- else:
104
- ts.save(str(f), _extra_files=extra_files)
105
 
106
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
107
- return f
108
- except Exception as e:
109
- LOGGER.info(f'{prefix} export failure: {e}')
 
 
 
 
110
 
111
 
 
112
  def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
113
  # YOLOv5 ONNX export
114
- try:
115
- check_requirements(('onnx',))
116
- import onnx
117
-
118
- LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
119
- f = file.with_suffix('.onnx')
120
-
121
- torch.onnx.export(
122
- model.cpu() if dynamic else model, # --dynamic only compatible with cpu
123
- im.cpu() if dynamic else im,
124
- f,
125
- verbose=False,
126
- opset_version=opset,
127
- training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
128
- do_constant_folding=not train,
129
- input_names=['images'],
130
- output_names=['output'],
131
- dynamic_axes={
132
- 'images': {
133
- 0: 'batch',
134
- 2: 'height',
135
- 3: 'width'}, # shape(1,3,640,640)
136
- 'output': {
137
- 0: 'batch',
138
- 1: 'anchors'} # shape(1,25200,85)
139
- } if dynamic else None)
140
-
141
- # Checks
142
- model_onnx = onnx.load(f) # load onnx model
143
- onnx.checker.check_model(model_onnx) # check onnx model
144
-
145
- # Metadata
146
- d = {'stride': int(max(model.stride)), 'names': model.names}
147
- for k, v in d.items():
148
- meta = model_onnx.metadata_props.add()
149
- meta.key, meta.value = k, str(v)
150
- onnx.save(model_onnx, f)
151
-
152
- # Simplify
153
- if simplify:
154
- try:
155
- cuda = torch.cuda.is_available()
156
- check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
157
- import onnxsim
158
-
159
- LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
160
- model_onnx, check = onnxsim.simplify(model_onnx)
161
- assert check, 'assert check failed'
162
- onnx.save(model_onnx, f)
163
- except Exception as e:
164
- LOGGER.info(f'{prefix} simplifier failure: {e}')
165
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
166
- return f
167
- except Exception as e:
168
- LOGGER.info(f'{prefix} export failure: {e}')
169
 
170
 
 
171
  def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
172
  # YOLOv5 OpenVINO export
173
- try:
174
- check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
175
- import openvino.inference_engine as ie
176
-
177
- LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
178
- f = str(file).replace('.pt', f'_openvino_model{os.sep}')
179
 
180
- cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
181
- subprocess.check_output(cmd.split()) # export
182
- with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
183
- yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
184
 
185
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
186
- return f
187
- except Exception as e:
188
- LOGGER.info(f'\n{prefix} export failure: {e}')
 
189
 
190
 
 
191
  def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
192
  # YOLOv5 CoreML export
193
- try:
194
- check_requirements(('coremltools',))
195
- import coremltools as ct
196
-
197
- LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
198
- f = file.with_suffix('.mlmodel')
199
-
200
- ts = torch.jit.trace(model, im, strict=False) # TorchScript model
201
- ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
202
- bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
203
- if bits < 32:
204
- if platform.system() == 'Darwin': # quantization only supported on macOS
205
- with warnings.catch_warnings():
206
- warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
207
- ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
208
- else:
209
- print(f'{prefix} quantization only supported on macOS, skipping...')
210
- ct_model.save(f)
211
-
212
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
213
- return ct_model, f
214
- except Exception as e:
215
- LOGGER.info(f'\n{prefix} export failure: {e}')
216
- return None, None
217
-
218
-
219
- def export_engine(model, im, file, train, half, dynamic, simplify, workspace=4, verbose=False):
220
  # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
221
- prefix = colorstr('TensorRT:')
222
  try:
223
- assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
224
- try:
225
- import tensorrt as trt
226
- except Exception:
227
- if platform.system() == 'Linux':
228
- check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
229
- import tensorrt as trt
230
-
231
- if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
232
- grid = model.model[-1].anchor_grid
233
- model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
234
- export_onnx(model, im, file, 12, train, dynamic, simplify) # opset 12
235
- model.model[-1].anchor_grid = grid
236
- else: # TensorRT >= 8
237
- check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
238
- export_onnx(model, im, file, 13, train, dynamic, simplify) # opset 13
239
- onnx = file.with_suffix('.onnx')
240
-
241
- LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
242
- assert onnx.exists(), f'failed to export ONNX file: {onnx}'
243
- f = file.with_suffix('.engine') # TensorRT engine file
244
- logger = trt.Logger(trt.Logger.INFO)
245
- if verbose:
246
- logger.min_severity = trt.Logger.Severity.VERBOSE
247
-
248
- builder = trt.Builder(logger)
249
- config = builder.create_builder_config()
250
- config.max_workspace_size = workspace * 1 << 30
251
- # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
252
-
253
- flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
254
- network = builder.create_network(flag)
255
- parser = trt.OnnxParser(network, logger)
256
- if not parser.parse_from_file(str(onnx)):
257
- raise RuntimeError(f'failed to load ONNX file: {onnx}')
258
-
259
- inputs = [network.get_input(i) for i in range(network.num_inputs)]
260
- outputs = [network.get_output(i) for i in range(network.num_outputs)]
261
- LOGGER.info(f'{prefix} Network Description:')
 
 
 
 
 
 
 
262
  for inp in inputs:
263
- LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
264
- for out in outputs:
265
- LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
266
-
267
- if dynamic:
268
- if im.shape[0] <= 1:
269
- LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
270
- profile = builder.create_optimization_profile()
271
- for inp in inputs:
272
- profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
273
- config.add_optimization_profile(profile)
274
-
275
- LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
276
- if builder.platform_has_fast_fp16 and half:
277
- config.set_flag(trt.BuilderFlag.FP16)
278
- with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
279
- t.write(engine.serialize())
280
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
281
- return f
282
- except Exception as e:
283
- LOGGER.info(f'\n{prefix} export failure: {e}')
284
 
285
 
 
286
  def export_saved_model(model,
287
  im,
288
  file,
@@ -296,163 +295,142 @@ def export_saved_model(model,
296
  keras=False,
297
  prefix=colorstr('TensorFlow SavedModel:')):
298
  # YOLOv5 TensorFlow SavedModel export
299
- try:
300
- import tensorflow as tf
301
- from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
302
-
303
- from models.tf import TFDetect, TFModel
304
-
305
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
306
- f = str(file).replace('.pt', '_saved_model')
307
- batch_size, ch, *imgsz = list(im.shape) # BCHW
308
-
309
- tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
310
- im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
311
- _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
312
- inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
313
- outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
314
- keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
315
- keras_model.trainable = False
316
- keras_model.summary()
317
- if keras:
318
- keras_model.save(f, save_format='tf')
319
- else:
320
- spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
321
- m = tf.function(lambda x: keras_model(x)) # full model
322
- m = m.get_concrete_function(spec)
323
- frozen_func = convert_variables_to_constants_v2(m)
324
- tfm = tf.Module()
325
- tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
326
- tfm.__call__(im)
327
- tf.saved_model.save(tfm,
328
- f,
329
- options=tf.saved_model.SaveOptions(experimental_custom_gradients=False)
330
- if check_version(tf.__version__, '2.6') else tf.saved_model.SaveOptions())
331
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
332
- return keras_model, f
333
- except Exception as e:
334
- LOGGER.info(f'\n{prefix} export failure: {e}')
335
- return None, None
336
 
337
 
 
338
  def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
339
  # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
340
- try:
341
- import tensorflow as tf
342
- from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
343
 
344
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
345
- f = file.with_suffix('.pb')
346
 
347
- m = tf.function(lambda x: keras_model(x)) # full model
348
- m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
349
- frozen_func = convert_variables_to_constants_v2(m)
350
- frozen_func.graph.as_graph_def()
351
- tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
352
-
353
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
354
- return f
355
- except Exception as e:
356
- LOGGER.info(f'\n{prefix} export failure: {e}')
357
 
358
 
 
359
  def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
360
  # YOLOv5 TensorFlow Lite export
361
- try:
362
- import tensorflow as tf
363
-
364
- LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
365
- batch_size, ch, *imgsz = list(im.shape) # BCHW
366
- f = str(file).replace('.pt', '-fp16.tflite')
367
-
368
- converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
369
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
370
- converter.target_spec.supported_types = [tf.float16]
371
- converter.optimizations = [tf.lite.Optimize.DEFAULT]
372
- if int8:
373
- from models.tf import representative_dataset_gen
374
- dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
375
- converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
376
- converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
377
- converter.target_spec.supported_types = []
378
- converter.inference_input_type = tf.uint8 # or tf.int8
379
- converter.inference_output_type = tf.uint8 # or tf.int8
380
- converter.experimental_new_quantizer = True
381
- f = str(file).replace('.pt', '-int8.tflite')
382
- if nms or agnostic_nms:
383
- converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
384
-
385
- tflite_model = converter.convert()
386
- open(f, "wb").write(tflite_model)
387
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
388
- return f
389
- except Exception as e:
390
- LOGGER.info(f'\n{prefix} export failure: {e}')
391
-
392
-
393
  def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
394
  # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
395
- try:
396
- cmd = 'edgetpu_compiler --version'
397
- help_url = 'https://coral.ai/docs/edgetpu/compiler/'
398
- assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
399
- if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
400
- LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
401
- sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
402
- for c in (
403
- 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
404
- 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
405
- 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
406
- subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
407
- ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
408
-
409
- LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
410
- f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
411
- f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
412
-
413
- cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
414
- subprocess.run(cmd.split(), check=True)
415
-
416
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
417
- return f
418
- except Exception as e:
419
- LOGGER.info(f'\n{prefix} export failure: {e}')
420
-
421
-
422
  def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
423
  # YOLOv5 TensorFlow.js export
424
- try:
425
- check_requirements(('tensorflowjs',))
426
- import re
427
-
428
- import tensorflowjs as tfjs
429
-
430
- LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
431
- f = str(file).replace('.pt', '_web_model') # js dir
432
- f_pb = file.with_suffix('.pb') # *.pb path
433
- f_json = f'{f}/model.json' # *.json path
434
-
435
- cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
436
- f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
437
- subprocess.run(cmd.split())
438
-
439
- with open(f_json) as j:
440
- json = j.read()
441
- with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
442
- subst = re.sub(
443
- r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
444
- r'"Identity.?.?": {"name": "Identity.?.?"}, '
445
- r'"Identity.?.?": {"name": "Identity.?.?"}, '
446
- r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
447
- r'"Identity_1": {"name": "Identity_1"}, '
448
- r'"Identity_2": {"name": "Identity_2"}, '
449
- r'"Identity_3": {"name": "Identity_3"}}}', json)
450
- j.write(subst)
451
-
452
- LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')
453
- return f
454
- except Exception as e:
455
- LOGGER.info(f'\n{prefix} export failure: {e}')
456
 
457
 
458
  @smart_inference_mode()
@@ -495,11 +473,9 @@ def run(
495
  assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
496
  assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
497
  model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
498
- nc, names = model.nc, model.names # number of classes, class names
499
 
500
  # Checks
501
  imgsz *= 2 if len(imgsz) == 1 else 1 # expand
502
- assert nc == len(names), f'Model class count {nc} != len(names) {len(names)}'
503
  if optimize:
504
  assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
505
 
@@ -513,36 +489,37 @@ def run(
513
  for k, m in model.named_modules():
514
  if isinstance(m, Detect):
515
  m.inplace = inplace
516
- m.onnx_dynamic = dynamic
517
  m.export = True
518
 
519
  for _ in range(2):
520
  y = model(im) # dry runs
521
  if half and not coreml:
522
  im, model = im.half(), model.half() # to FP16
523
- shape = tuple(y[0].shape) # model output shape
524
  LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
525
 
526
  # Exports
527
  f = [''] * 10 # exported filenames
528
  warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
529
  if jit:
530
- f[0] = export_torchscript(model, im, file, optimize)
531
  if engine: # TensorRT required before ONNX
532
- f[1] = export_engine(model, im, file, train, half, dynamic, simplify, workspace, verbose)
533
  if onnx or xml: # OpenVINO requires ONNX
534
- f[2] = export_onnx(model, im, file, opset, train, dynamic, simplify)
535
  if xml: # OpenVINO
536
- f[3] = export_openvino(model, file, half)
537
  if coreml:
538
- _, f[4] = export_coreml(model, im, file, int8, half)
539
 
540
  # TensorFlow Exports
541
  if any((saved_model, pb, tflite, edgetpu, tfjs)):
542
  if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
543
  check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
544
  assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
545
- model, f[5] = export_saved_model(model.cpu(),
 
546
  im,
547
  file,
548
  dynamic,
@@ -554,19 +531,19 @@ def run(
554
  conf_thres=conf_thres,
555
  keras=keras)
556
  if pb or tfjs: # pb prerequisite to tfjs
557
- f[6] = export_pb(model, file)
558
  if tflite or edgetpu:
559
- f[7] = export_tflite(model, im, file, int8=int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
560
  if edgetpu:
561
- f[8] = export_edgetpu(file)
562
  if tfjs:
563
- f[9] = export_tfjs(file)
564
 
565
  # Finish
566
  f = [str(x) for x in f if x] # filter out '' and None
567
  if any(f):
568
  h = '--half' if half else '' # --half FP16 inference arg
569
- LOGGER.info(f'\nExport complete ({time.time() - t:.2f}s)'
570
  f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
571
  f"\nDetect: python detect.py --weights {f[-1]} {h}"
572
  f"\nValidate: python val.py --weights {f[-1]} {h}"
@@ -601,7 +578,7 @@ def parse_opt():
601
  parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
602
  parser.add_argument('--include',
603
  nargs='+',
604
- default=['torchscript', 'onnx'],
605
  help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
606
  opt = parser.parse_args()
607
  print_args(vars(opt))
 
21
  $ pip install -r requirements.txt coremltools onnx onnx-simplifier onnxruntime-gpu openvino-dev tensorflow # GPU
22
 
23
  Usage:
24
+ $ python export.py --weights yolov5s.pt --include torchscript onnx openvino engine coreml tflite ...
25
 
26
  Inference:
27
+ $ python detect.py --weights yolov5s.pt # PyTorch
28
+ yolov5s.torchscript # TorchScript
29
+ yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
30
+ yolov5s.xml # OpenVINO
31
+ yolov5s.engine # TensorRT
32
+ yolov5s.mlmodel # CoreML (macOS-only)
33
+ yolov5s_saved_model # TensorFlow SavedModel
34
+ yolov5s.pb # TensorFlow GraphDef
35
+ yolov5s.tflite # TensorFlow Lite
36
+ yolov5s_edgetpu.tflite # TensorFlow Edge TPU
37
 
38
  TensorFlow.js:
39
  $ cd .. && git clone https://github.com/zldrobit/tfjs-yolov5-example.git && cd tfjs-yolov5-example
 
65
  ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
66
 
67
  from models.experimental import attempt_load
68
+ from models.yolo import ClassificationModel, Detect
69
  from utils.dataloaders import LoadImages
70
+ from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_version,
71
+ check_yaml, colorstr, file_size, get_default_args, print_args, url2file)
72
  from utils.torch_utils import select_device, smart_inference_mode
73
 
74
 
 
89
  return pd.DataFrame(x, columns=['Format', 'Argument', 'Suffix', 'CPU', 'GPU'])
90
 
91
 
92
+ def try_export(inner_func):
93
+ # YOLOv5 export decorator, i..e @try_export
94
+ inner_args = get_default_args(inner_func)
95
+
96
+ def outer_func(*args, **kwargs):
97
+ prefix = inner_args['prefix']
98
+ try:
99
+ with Profile() as dt:
100
+ f, model = inner_func(*args, **kwargs)
101
+ LOGGER.info(f'{prefix} export success ✅ {dt.t:.1f}s, saved as {f} ({file_size(f):.1f} MB)')
102
+ return f, model
103
+ except Exception as e:
104
+ LOGGER.info(f'{prefix} export failure ❌ {dt.t:.1f}s: {e}')
105
+ return None, None
106
+
107
+ return outer_func
108
+
109
+
110
+ @try_export
111
  def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
112
  # YOLOv5 TorchScript model export
113
+ LOGGER.info(f'\n{prefix} starting export with torch {torch.__version__}...')
114
+ f = file.with_suffix('.torchscript')
 
 
 
 
 
 
 
 
 
115
 
116
+ ts = torch.jit.trace(model, im, strict=False)
117
+ d = {"shape": im.shape, "stride": int(max(model.stride)), "names": model.names}
118
+ extra_files = {'config.txt': json.dumps(d)} # torch._C.ExtraFilesMap()
119
+ if optimize: # https://pytorch.org/tutorials/recipes/mobile_interpreter.html
120
+ optimize_for_mobile(ts)._save_for_lite_interpreter(str(f), _extra_files=extra_files)
121
+ else:
122
+ ts.save(str(f), _extra_files=extra_files)
123
+ return f, None
124
 
125
 
126
+ @try_export
127
  def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):
128
  # YOLOv5 ONNX export
129
+ check_requirements(('onnx',))
130
+ import onnx
131
+
132
+ LOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')
133
+ f = file.with_suffix('.onnx')
134
+
135
+ torch.onnx.export(
136
+ model.cpu() if dynamic else model, # --dynamic only compatible with cpu
137
+ im.cpu() if dynamic else im,
138
+ f,
139
+ verbose=False,
140
+ opset_version=opset,
141
+ training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,
142
+ do_constant_folding=not train,
143
+ input_names=['images'],
144
+ output_names=['output'],
145
+ dynamic_axes={
146
+ 'images': {
147
+ 0: 'batch',
148
+ 2: 'height',
149
+ 3: 'width'}, # shape(1,3,640,640)
150
+ 'output': {
151
+ 0: 'batch',
152
+ 1: 'anchors'} # shape(1,25200,85)
153
+ } if dynamic else None)
154
+
155
+ # Checks
156
+ model_onnx = onnx.load(f) # load onnx model
157
+ onnx.checker.check_model(model_onnx) # check onnx model
158
+
159
+ # Metadata
160
+ d = {'stride': int(max(model.stride)), 'names': model.names}
161
+ for k, v in d.items():
162
+ meta = model_onnx.metadata_props.add()
163
+ meta.key, meta.value = k, str(v)
164
+ onnx.save(model_onnx, f)
165
+
166
+ # Simplify
167
+ if simplify:
168
+ try:
169
+ cuda = torch.cuda.is_available()
170
+ check_requirements(('onnxruntime-gpu' if cuda else 'onnxruntime', 'onnx-simplifier>=0.4.1'))
171
+ import onnxsim
172
+
173
+ LOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')
174
+ model_onnx, check = onnxsim.simplify(model_onnx)
175
+ assert check, 'assert check failed'
176
+ onnx.save(model_onnx, f)
177
+ except Exception as e:
178
+ LOGGER.info(f'{prefix} simplifier failure: {e}')
179
+ return f, model_onnx
 
 
 
 
180
 
181
 
182
+ @try_export
183
  def export_openvino(model, file, half, prefix=colorstr('OpenVINO:')):
184
  # YOLOv5 OpenVINO export
185
+ check_requirements(('openvino-dev',)) # requires openvino-dev: https://pypi.org/project/openvino-dev/
186
+ import openvino.inference_engine as ie
 
 
 
 
187
 
188
+ LOGGER.info(f'\n{prefix} starting export with openvino {ie.__version__}...')
189
+ f = str(file).replace('.pt', f'_openvino_model{os.sep}')
 
 
190
 
191
+ cmd = f"mo --input_model {file.with_suffix('.onnx')} --output_dir {f} --data_type {'FP16' if half else 'FP32'}"
192
+ subprocess.check_output(cmd.split()) # export
193
+ with open(Path(f) / file.with_suffix('.yaml').name, 'w') as g:
194
+ yaml.dump({'stride': int(max(model.stride)), 'names': model.names}, g) # add metadata.yaml
195
+ return f, None
196
 
197
 
198
+ @try_export
199
  def export_coreml(model, im, file, int8, half, prefix=colorstr('CoreML:')):
200
  # YOLOv5 CoreML export
201
+ check_requirements(('coremltools',))
202
+ import coremltools as ct
203
+
204
+ LOGGER.info(f'\n{prefix} starting export with coremltools {ct.__version__}...')
205
+ f = file.with_suffix('.mlmodel')
206
+
207
+ ts = torch.jit.trace(model, im, strict=False) # TorchScript model
208
+ ct_model = ct.convert(ts, inputs=[ct.ImageType('image', shape=im.shape, scale=1 / 255, bias=[0, 0, 0])])
209
+ bits, mode = (8, 'kmeans_lut') if int8 else (16, 'linear') if half else (32, None)
210
+ if bits < 32:
211
+ if platform.system() == 'Darwin': # quantization only supported on macOS
212
+ with warnings.catch_warnings():
213
+ warnings.filterwarnings("ignore", category=DeprecationWarning) # suppress numpy==1.20 float warning
214
+ ct_model = ct.models.neural_network.quantization_utils.quantize_weights(ct_model, bits, mode)
215
+ else:
216
+ print(f'{prefix} quantization only supported on macOS, skipping...')
217
+ ct_model.save(f)
218
+ return f, ct_model
219
+
220
+
221
+ @try_export
222
+ def export_engine(model, im, file, half, dynamic, simplify, workspace=4, verbose=False, prefix=colorstr('TensorRT:')):
 
 
 
 
 
223
  # YOLOv5 TensorRT export https://developer.nvidia.com/tensorrt
224
+ assert im.device.type != 'cpu', 'export running on CPU but must be on GPU, i.e. `python export.py --device 0`'
225
  try:
226
+ import tensorrt as trt
227
+ except Exception:
228
+ if platform.system() == 'Linux':
229
+ check_requirements(('nvidia-tensorrt',), cmds=('-U --index-url https://pypi.ngc.nvidia.com',))
230
+ import tensorrt as trt
231
+
232
+ if trt.__version__[0] == '7': # TensorRT 7 handling https://github.com/ultralytics/yolov5/issues/6012
233
+ grid = model.model[-1].anchor_grid
234
+ model.model[-1].anchor_grid = [a[..., :1, :1, :] for a in grid]
235
+ export_onnx(model, im, file, 12, False, dynamic, simplify) # opset 12
236
+ model.model[-1].anchor_grid = grid
237
+ else: # TensorRT >= 8
238
+ check_version(trt.__version__, '8.0.0', hard=True) # require tensorrt>=8.0.0
239
+ export_onnx(model, im, file, 13, False, dynamic, simplify) # opset 13
240
+ onnx = file.with_suffix('.onnx')
241
+
242
+ LOGGER.info(f'\n{prefix} starting export with TensorRT {trt.__version__}...')
243
+ assert onnx.exists(), f'failed to export ONNX file: {onnx}'
244
+ f = file.with_suffix('.engine') # TensorRT engine file
245
+ logger = trt.Logger(trt.Logger.INFO)
246
+ if verbose:
247
+ logger.min_severity = trt.Logger.Severity.VERBOSE
248
+
249
+ builder = trt.Builder(logger)
250
+ config = builder.create_builder_config()
251
+ config.max_workspace_size = workspace * 1 << 30
252
+ # config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, workspace << 30) # fix TRT 8.4 deprecation notice
253
+
254
+ flag = (1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
255
+ network = builder.create_network(flag)
256
+ parser = trt.OnnxParser(network, logger)
257
+ if not parser.parse_from_file(str(onnx)):
258
+ raise RuntimeError(f'failed to load ONNX file: {onnx}')
259
+
260
+ inputs = [network.get_input(i) for i in range(network.num_inputs)]
261
+ outputs = [network.get_output(i) for i in range(network.num_outputs)]
262
+ LOGGER.info(f'{prefix} Network Description:')
263
+ for inp in inputs:
264
+ LOGGER.info(f'{prefix}\tinput "{inp.name}" with shape {inp.shape} and dtype {inp.dtype}')
265
+ for out in outputs:
266
+ LOGGER.info(f'{prefix}\toutput "{out.name}" with shape {out.shape} and dtype {out.dtype}')
267
+
268
+ if dynamic:
269
+ if im.shape[0] <= 1:
270
+ LOGGER.warning(f"{prefix}WARNING: --dynamic model requires maximum --batch-size argument")
271
+ profile = builder.create_optimization_profile()
272
  for inp in inputs:
273
+ profile.set_shape(inp.name, (1, *im.shape[1:]), (max(1, im.shape[0] // 2), *im.shape[1:]), im.shape)
274
+ config.add_optimization_profile(profile)
275
+
276
+ LOGGER.info(f'{prefix} building FP{16 if builder.platform_has_fast_fp16 and half else 32} engine in {f}')
277
+ if builder.platform_has_fast_fp16 and half:
278
+ config.set_flag(trt.BuilderFlag.FP16)
279
+ with builder.build_engine(network, config) as engine, open(f, 'wb') as t:
280
+ t.write(engine.serialize())
281
+ return f, None
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
 
284
+ @try_export
285
  def export_saved_model(model,
286
  im,
287
  file,
 
295
  keras=False,
296
  prefix=colorstr('TensorFlow SavedModel:')):
297
  # YOLOv5 TensorFlow SavedModel export
298
+ import tensorflow as tf
299
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
300
+
301
+ from models.tf import TFModel
302
+
303
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
304
+ f = str(file).replace('.pt', '_saved_model')
305
+ batch_size, ch, *imgsz = list(im.shape) # BCHW
306
+
307
+ tf_model = TFModel(cfg=model.yaml, model=model, nc=model.nc, imgsz=imgsz)
308
+ im = tf.zeros((batch_size, *imgsz, ch)) # BHWC order for TensorFlow
309
+ _ = tf_model.predict(im, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
310
+ inputs = tf.keras.Input(shape=(*imgsz, ch), batch_size=None if dynamic else batch_size)
311
+ outputs = tf_model.predict(inputs, tf_nms, agnostic_nms, topk_per_class, topk_all, iou_thres, conf_thres)
312
+ keras_model = tf.keras.Model(inputs=inputs, outputs=outputs)
313
+ keras_model.trainable = False
314
+ keras_model.summary()
315
+ if keras:
316
+ keras_model.save(f, save_format='tf')
317
+ else:
318
+ spec = tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype)
319
+ m = tf.function(lambda x: keras_model(x)) # full model
320
+ m = m.get_concrete_function(spec)
321
+ frozen_func = convert_variables_to_constants_v2(m)
322
+ tfm = tf.Module()
323
+ tfm.__call__ = tf.function(lambda x: frozen_func(x)[:4] if tf_nms else frozen_func(x)[0], [spec])
324
+ tfm.__call__(im)
325
+ tf.saved_model.save(tfm,
326
+ f,
327
+ options=tf.saved_model.SaveOptions(experimental_custom_gradients=False) if check_version(
328
+ tf.__version__, '2.6') else tf.saved_model.SaveOptions())
329
+ return f, keras_model
 
 
 
 
 
330
 
331
 
332
+ @try_export
333
  def export_pb(keras_model, file, prefix=colorstr('TensorFlow GraphDef:')):
334
  # YOLOv5 TensorFlow GraphDef *.pb export https://github.com/leimao/Frozen_Graph_TensorFlow
335
+ import tensorflow as tf
336
+ from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
 
337
 
338
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
339
+ f = file.with_suffix('.pb')
340
 
341
+ m = tf.function(lambda x: keras_model(x)) # full model
342
+ m = m.get_concrete_function(tf.TensorSpec(keras_model.inputs[0].shape, keras_model.inputs[0].dtype))
343
+ frozen_func = convert_variables_to_constants_v2(m)
344
+ frozen_func.graph.as_graph_def()
345
+ tf.io.write_graph(graph_or_graph_def=frozen_func.graph, logdir=str(f.parent), name=f.name, as_text=False)
346
+ return f, None
 
 
 
 
347
 
348
 
349
+ @try_export
350
  def export_tflite(keras_model, im, file, int8, data, nms, agnostic_nms, prefix=colorstr('TensorFlow Lite:')):
351
  # YOLOv5 TensorFlow Lite export
352
+ import tensorflow as tf
353
+
354
+ LOGGER.info(f'\n{prefix} starting export with tensorflow {tf.__version__}...')
355
+ batch_size, ch, *imgsz = list(im.shape) # BCHW
356
+ f = str(file).replace('.pt', '-fp16.tflite')
357
+
358
+ converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
359
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS]
360
+ converter.target_spec.supported_types = [tf.float16]
361
+ converter.optimizations = [tf.lite.Optimize.DEFAULT]
362
+ if int8:
363
+ from models.tf import representative_dataset_gen
364
+ dataset = LoadImages(check_dataset(check_yaml(data))['train'], img_size=imgsz, auto=False)
365
+ converter.representative_dataset = lambda: representative_dataset_gen(dataset, ncalib=100)
366
+ converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
367
+ converter.target_spec.supported_types = []
368
+ converter.inference_input_type = tf.uint8 # or tf.int8
369
+ converter.inference_output_type = tf.uint8 # or tf.int8
370
+ converter.experimental_new_quantizer = True
371
+ f = str(file).replace('.pt', '-int8.tflite')
372
+ if nms or agnostic_nms:
373
+ converter.target_spec.supported_ops.append(tf.lite.OpsSet.SELECT_TF_OPS)
374
+
375
+ tflite_model = converter.convert()
376
+ open(f, "wb").write(tflite_model)
377
+ return f, None
378
+
379
+
380
+ @try_export
 
 
 
381
  def export_edgetpu(file, prefix=colorstr('Edge TPU:')):
382
  # YOLOv5 Edge TPU export https://coral.ai/docs/edgetpu/models-intro/
383
+ cmd = 'edgetpu_compiler --version'
384
+ help_url = 'https://coral.ai/docs/edgetpu/compiler/'
385
+ assert platform.system() == 'Linux', f'export only supported on Linux. See {help_url}'
386
+ if subprocess.run(f'{cmd} >/dev/null', shell=True).returncode != 0:
387
+ LOGGER.info(f'\n{prefix} export requires Edge TPU compiler. Attempting install from {help_url}')
388
+ sudo = subprocess.run('sudo --version >/dev/null', shell=True).returncode == 0 # sudo installed on system
389
+ for c in (
390
+ 'curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | sudo apt-key add -',
391
+ 'echo "deb https://packages.cloud.google.com/apt coral-edgetpu-stable main" | sudo tee /etc/apt/sources.list.d/coral-edgetpu.list',
392
+ 'sudo apt-get update', 'sudo apt-get install edgetpu-compiler'):
393
+ subprocess.run(c if sudo else c.replace('sudo ', ''), shell=True, check=True)
394
+ ver = subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1]
395
+
396
+ LOGGER.info(f'\n{prefix} starting export with Edge TPU compiler {ver}...')
397
+ f = str(file).replace('.pt', '-int8_edgetpu.tflite') # Edge TPU model
398
+ f_tfl = str(file).replace('.pt', '-int8.tflite') # TFLite model
399
+
400
+ cmd = f"edgetpu_compiler -s -d -k 10 --out_dir {file.parent} {f_tfl}"
401
+ subprocess.run(cmd.split(), check=True)
402
+ return f, None
403
+
404
+
405
+ @try_export
 
 
 
 
406
  def export_tfjs(file, prefix=colorstr('TensorFlow.js:')):
407
  # YOLOv5 TensorFlow.js export
408
+ check_requirements(('tensorflowjs',))
409
+ import re
410
+
411
+ import tensorflowjs as tfjs
412
+
413
+ LOGGER.info(f'\n{prefix} starting export with tensorflowjs {tfjs.__version__}...')
414
+ f = str(file).replace('.pt', '_web_model') # js dir
415
+ f_pb = file.with_suffix('.pb') # *.pb path
416
+ f_json = f'{f}/model.json' # *.json path
417
+
418
+ cmd = f'tensorflowjs_converter --input_format=tf_frozen_model ' \
419
+ f'--output_node_names=Identity,Identity_1,Identity_2,Identity_3 {f_pb} {f}'
420
+ subprocess.run(cmd.split())
421
+
422
+ json = Path(f_json).read_text()
423
+ with open(f_json, 'w') as j: # sort JSON Identity_* in ascending order
424
+ subst = re.sub(
425
+ r'{"outputs": {"Identity.?.?": {"name": "Identity.?.?"}, '
426
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
427
+ r'"Identity.?.?": {"name": "Identity.?.?"}, '
428
+ r'"Identity.?.?": {"name": "Identity.?.?"}}}', r'{"outputs": {"Identity": {"name": "Identity"}, '
429
+ r'"Identity_1": {"name": "Identity_1"}, '
430
+ r'"Identity_2": {"name": "Identity_2"}, '
431
+ r'"Identity_3": {"name": "Identity_3"}}}', json)
432
+ j.write(subst)
433
+ return f, None
 
 
 
 
 
 
434
 
435
 
436
  @smart_inference_mode()
 
473
  assert device.type != 'cpu' or coreml, '--half only compatible with GPU export, i.e. use --device 0'
474
  assert not dynamic, '--half not compatible with --dynamic, i.e. use either --half or --dynamic but not both'
475
  model = attempt_load(weights, device=device, inplace=True, fuse=True) # load FP32 model
 
476
 
477
  # Checks
478
  imgsz *= 2 if len(imgsz) == 1 else 1 # expand
 
479
  if optimize:
480
  assert device.type == 'cpu', '--optimize not compatible with cuda devices, i.e. use --device cpu'
481
 
 
489
  for k, m in model.named_modules():
490
  if isinstance(m, Detect):
491
  m.inplace = inplace
492
+ m.dynamic = dynamic
493
  m.export = True
494
 
495
  for _ in range(2):
496
  y = model(im) # dry runs
497
  if half and not coreml:
498
  im, model = im.half(), model.half() # to FP16
499
+ shape = tuple((y[0] if isinstance(y, tuple) else y).shape) # model output shape
500
  LOGGER.info(f"\n{colorstr('PyTorch:')} starting from {file} with output shape {shape} ({file_size(file):.1f} MB)")
501
 
502
  # Exports
503
  f = [''] * 10 # exported filenames
504
  warnings.filterwarnings(action='ignore', category=torch.jit.TracerWarning) # suppress TracerWarning
505
  if jit:
506
+ f[0], _ = export_torchscript(model, im, file, optimize)
507
  if engine: # TensorRT required before ONNX
508
+ f[1], _ = export_engine(model, im, file, half, dynamic, simplify, workspace, verbose)
509
  if onnx or xml: # OpenVINO requires ONNX
510
+ f[2], _ = export_onnx(model, im, file, opset, train, dynamic, simplify)
511
  if xml: # OpenVINO
512
+ f[3], _ = export_openvino(model, file, half)
513
  if coreml:
514
+ f[4], _ = export_coreml(model, im, file, int8, half)
515
 
516
  # TensorFlow Exports
517
  if any((saved_model, pb, tflite, edgetpu, tfjs)):
518
  if int8 or edgetpu: # TFLite --int8 bug https://github.com/ultralytics/yolov5/issues/5707
519
  check_requirements(('flatbuffers==1.12',)) # required before `import tensorflow`
520
  assert not tflite or not tfjs, 'TFLite and TF.js models must be exported separately, please pass only one type.'
521
+ assert not isinstance(model, ClassificationModel), 'ClassificationModel export to TF formats not yet supported.'
522
+ f[5], model = export_saved_model(model.cpu(),
523
  im,
524
  file,
525
  dynamic,
 
531
  conf_thres=conf_thres,
532
  keras=keras)
533
  if pb or tfjs: # pb prerequisite to tfjs
534
+ f[6], _ = export_pb(model, file)
535
  if tflite or edgetpu:
536
+ f[7], _ = export_tflite(model, im, file, int8 or edgetpu, data=data, nms=nms, agnostic_nms=agnostic_nms)
537
  if edgetpu:
538
+ f[8], _ = export_edgetpu(file)
539
  if tfjs:
540
+ f[9], _ = export_tfjs(file)
541
 
542
  # Finish
543
  f = [str(x) for x in f if x] # filter out '' and None
544
  if any(f):
545
  h = '--half' if half else '' # --half FP16 inference arg
546
+ LOGGER.info(f'\nExport complete ({time.time() - t:.1f}s)'
547
  f"\nResults saved to {colorstr('bold', file.parent.resolve())}"
548
  f"\nDetect: python detect.py --weights {f[-1]} {h}"
549
  f"\nValidate: python val.py --weights {f[-1]} {h}"
 
578
  parser.add_argument('--conf-thres', type=float, default=0.25, help='TF.js NMS: confidence threshold')
579
  parser.add_argument('--include',
580
  nargs='+',
581
+ default=['torchscript'],
582
  help='torchscript, onnx, openvino, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs')
583
  opt = parser.parse_args()
584
  print_args(vars(opt))
models/common.py CHANGED
@@ -17,15 +17,15 @@ import pandas as pd
17
  import requests
18
  import torch
19
  import torch.nn as nn
20
- import yaml
21
  from PIL import Image
22
  from torch.cuda import amp
23
 
24
  from utils.dataloaders import exif_transpose, letterbox
25
- from utils.general import (LOGGER, check_requirements, check_suffix, check_version, colorstr, increment_path,
26
- make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh)
 
27
  from utils.plots import Annotator, colors, save_one_box
28
- from utils.torch_utils import copy_attr, smart_inference_mode, time_sync
29
 
30
 
31
  def autopad(k, p=None): # kernel, padding
@@ -322,13 +322,10 @@ class DetectMultiBackend(nn.Module):
322
 
323
  super().__init__()
324
  w = str(weights[0] if isinstance(weights, list) else weights)
325
- pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self.model_type(w) # get backend
326
  w = attempt_download(w) # download if not local
327
- fp16 &= (pt or jit or onnx or engine) and device.type != 'cpu' # FP16
328
- stride, names = 32, [f'class{i}' for i in range(1000)] # assign defaults
329
- if data: # assign class names (optional)
330
- with open(data, errors='ignore') as f:
331
- names = yaml.safe_load(f)['names']
332
 
333
  if pt: # PyTorch
334
  model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
@@ -341,8 +338,10 @@ class DetectMultiBackend(nn.Module):
341
  extra_files = {'config.txt': ''} # model metadata
342
  model = torch.jit.load(w, _extra_files=extra_files)
343
  model.half() if fp16 else model.float()
344
- if extra_files['config.txt']:
345
- d = json.loads(extra_files['config.txt']) # extra_files dict
 
 
346
  stride, names = int(d['stride']), d['names']
347
  elif dnn: # ONNX OpenCV DNN
348
  LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
@@ -350,11 +349,12 @@ class DetectMultiBackend(nn.Module):
350
  net = cv2.dnn.readNetFromONNX(w)
351
  elif onnx: # ONNX Runtime
352
  LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
353
- cuda = torch.cuda.is_available()
354
  check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
355
  import onnxruntime
356
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
357
  session = onnxruntime.InferenceSession(w, providers=providers)
 
358
  meta = session.get_modelmeta().custom_metadata_map # metadata
359
  if 'stride' in meta:
360
  stride, names = int(meta['stride']), eval(meta['names'])
@@ -373,13 +373,13 @@ class DetectMultiBackend(nn.Module):
373
  batch_size = batch_dim.get_length()
374
  executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
375
  output_layer = next(iter(executable_network.outputs))
376
- meta = Path(w).with_suffix('.yaml')
377
- if meta.exists():
378
- stride, names = self._load_metadata(meta) # load metadata
379
  elif engine: # TensorRT
380
  LOGGER.info(f'Loading {w} for TensorRT inference...')
381
  import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
382
  check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
 
 
383
  Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
384
  logger = trt.Logger(trt.Logger.INFO)
385
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
@@ -398,8 +398,8 @@ class DetectMultiBackend(nn.Module):
398
  if dtype == np.float16:
399
  fp16 = True
400
  shape = tuple(context.get_binding_shape(index))
401
- data = torch.from_numpy(np.empty(shape, dtype=np.dtype(dtype))).to(device)
402
- bindings[name] = Binding(name, dtype, shape, data, int(data.data_ptr()))
403
  binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
404
  batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
405
  elif coreml: # CoreML
@@ -445,28 +445,35 @@ class DetectMultiBackend(nn.Module):
445
  input_details = interpreter.get_input_details() # inputs
446
  output_details = interpreter.get_output_details() # outputs
447
  elif tfjs:
448
- raise Exception('ERROR: YOLOv5 TF.js inference is not supported')
449
  else:
450
- raise Exception(f'ERROR: {w} is not a supported format')
 
 
 
 
 
 
 
451
  self.__dict__.update(locals()) # assign all variables to self
452
 
453
- def forward(self, im, augment=False, visualize=False, val=False):
454
  # YOLOv5 MultiBackend inference
455
  b, ch, h, w = im.shape # batch, channel, height, width
456
  if self.fp16 and im.dtype != torch.float16:
457
  im = im.half() # to FP16
458
 
459
  if self.pt: # PyTorch
460
- y = self.model(im, augment=augment, visualize=visualize)[0]
461
  elif self.jit: # TorchScript
462
- y = self.model(im)[0]
463
  elif self.dnn: # ONNX OpenCV DNN
464
  im = im.cpu().numpy() # torch to numpy
465
  self.net.setInput(im)
466
  y = self.net.forward()
467
  elif self.onnx: # ONNX Runtime
468
  im = im.cpu().numpy() # torch to numpy
469
- y = self.session.run([self.session.get_outputs()[0].name], {self.session.get_inputs()[0].name: im})[0]
470
  elif self.xml: # OpenVINO
471
  im = im.cpu().numpy() # FP32
472
  y = self.executable_network([im])[self.output_layer]
@@ -513,20 +520,24 @@ class DetectMultiBackend(nn.Module):
513
  y = (y.astype(np.float32) - zero_point) * scale # re-scale
514
  y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
515
 
516
- if isinstance(y, np.ndarray):
517
- y = torch.tensor(y, device=self.device)
518
- return (y, []) if val else y
 
 
 
 
519
 
520
  def warmup(self, imgsz=(1, 3, 640, 640)):
521
  # Warmup model by running inference once
522
  warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
523
  if any(warmup_types) and self.device.type != 'cpu':
524
- im = torch.zeros(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
525
  for _ in range(2 if self.jit else 1): #
526
  self.forward(im) # warmup
527
 
528
  @staticmethod
529
- def model_type(p='path/to/model.pt'):
530
  # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
531
  from export import export_formats
532
  suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
@@ -538,11 +549,12 @@ class DetectMultiBackend(nn.Module):
538
  return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
539
 
540
  @staticmethod
541
- def _load_metadata(f='path/to/meta.yaml'):
542
  # Load metadata from meta.yaml if it exists
543
- with open(f, errors='ignore') as f:
544
- d = yaml.safe_load(f)
545
- return d['stride'], d['names'] # assign stride, names
 
546
 
547
 
548
  class AutoShape(nn.Module):
@@ -579,9 +591,9 @@ class AutoShape(nn.Module):
579
  return self
580
 
581
  @smart_inference_mode()
582
- def forward(self, imgs, size=640, augment=False, profile=False):
583
- # Inference from various sources. For height=640, width=1280, RGB images example inputs are:
584
- # file: imgs = 'data/images/zidane.jpg' # str or PosixPath
585
  # URI: = 'https://ultralytics.com/images/zidane.jpg'
586
  # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
587
  # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
@@ -589,65 +601,67 @@ class AutoShape(nn.Module):
589
  # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
590
  # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
591
 
592
- t = [time_sync()]
593
- p = next(self.model.parameters()) if self.pt else torch.zeros(1, device=self.model.device) # for device, type
594
- autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
595
- if isinstance(imgs, torch.Tensor): # torch
596
- with amp.autocast(autocast):
597
- return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
598
-
599
- # Pre-process
600
- n, imgs = (len(imgs), list(imgs)) if isinstance(imgs, (list, tuple)) else (1, [imgs]) # number, list of images
601
- shape0, shape1, files = [], [], [] # image and inference shapes, filenames
602
- for i, im in enumerate(imgs):
603
- f = f'image{i}' # filename
604
- if isinstance(im, (str, Path)): # filename or uri
605
- im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
606
- im = np.asarray(exif_transpose(im))
607
- elif isinstance(im, Image.Image): # PIL Image
608
- im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
609
- files.append(Path(f).with_suffix('.jpg').name)
610
- if im.shape[0] < 5: # image in CHW
611
- im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
612
- im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input
613
- s = im.shape[:2] # HWC
614
- shape0.append(s) # image shape
615
- g = (size / max(s)) # gain
616
- shape1.append([y * g for y in s])
617
- imgs[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
618
- shape1 = [make_divisible(x, self.stride) if self.pt else size for x in np.array(shape1).max(0)] # inf shape
619
- x = [letterbox(im, shape1, auto=False)[0] for im in imgs] # pad
620
- x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
621
- x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
622
- t.append(time_sync())
 
 
623
 
624
  with amp.autocast(autocast):
625
  # Inference
626
- y = self.model(x, augment, profile) # forward
627
- t.append(time_sync())
628
 
629
  # Post-process
630
- y = non_max_suppression(y if self.dmb else y[0],
631
- self.conf,
632
- self.iou,
633
- self.classes,
634
- self.agnostic,
635
- self.multi_label,
636
- max_det=self.max_det) # NMS
637
- for i in range(n):
638
- scale_coords(shape1, y[i][:, :4], shape0[i])
 
639
 
640
- t.append(time_sync())
641
- return Detections(imgs, y, files, t, self.names, x.shape)
642
 
643
 
644
  class Detections:
645
  # YOLOv5 detections class for inference results
646
- def __init__(self, imgs, pred, files, times=(0, 0, 0, 0), names=None, shape=None):
647
  super().__init__()
648
  d = pred[0].device # device
649
- gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in imgs] # normalizations
650
- self.imgs = imgs # list of images as numpy arrays
651
  self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
652
  self.names = names # class names
653
  self.files = files # image filenames
@@ -657,12 +671,12 @@ class Detections:
657
  self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
658
  self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
659
  self.n = len(self.pred) # number of images (batch size)
660
- self.t = tuple((times[i + 1] - times[i]) * 1000 / self.n for i in range(3)) # timestamps (ms)
661
  self.s = shape # inference BCHW shape
662
 
663
  def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
664
  crops = []
665
- for i, (im, pred) in enumerate(zip(self.imgs, self.pred)):
666
  s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
667
  if pred.shape[0]:
668
  for c in pred[:, -1].unique():
@@ -697,7 +711,7 @@ class Detections:
697
  if i == self.n - 1:
698
  LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
699
  if render:
700
- self.imgs[i] = np.asarray(im)
701
  if crop:
702
  if save:
703
  LOGGER.info(f'Saved results to {save_dir}\n')
@@ -720,7 +734,7 @@ class Detections:
720
 
721
  def render(self, labels=True):
722
  self.display(render=True, labels=labels) # render results
723
- return self.imgs
724
 
725
  def pandas(self):
726
  # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
@@ -735,9 +749,9 @@ class Detections:
735
  def tolist(self):
736
  # return a list of Detections objects, i.e. 'for result in results.tolist():'
737
  r = range(self.n) # iterable
738
- x = [Detections([self.imgs[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
739
  # for d in x:
740
- # for k in ['imgs', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
741
  # setattr(d, k, getattr(d, k)[0]) # pop out of list
742
  return x
743
 
@@ -753,10 +767,13 @@ class Classify(nn.Module):
753
  # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
754
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
755
  super().__init__()
756
- self.aap = nn.AdaptiveAvgPool2d(1) # to x(b,c1,1,1)
757
- self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g) # to x(b,c2,1,1)
758
- self.flat = nn.Flatten()
 
 
759
 
760
  def forward(self, x):
761
- z = torch.cat([self.aap(y) for y in (x if isinstance(x, list) else [x])], 1) # cat if list
762
- return self.flat(self.conv(z)) # flatten to x(b,c2)
 
 
17
  import requests
18
  import torch
19
  import torch.nn as nn
 
20
  from PIL import Image
21
  from torch.cuda import amp
22
 
23
  from utils.dataloaders import exif_transpose, letterbox
24
+ from utils.general import (LOGGER, ROOT, Profile, check_requirements, check_suffix, check_version, colorstr,
25
+ increment_path, make_divisible, non_max_suppression, scale_coords, xywh2xyxy, xyxy2xywh,
26
+ yaml_load)
27
  from utils.plots import Annotator, colors, save_one_box
28
+ from utils.torch_utils import copy_attr, smart_inference_mode
29
 
30
 
31
  def autopad(k, p=None): # kernel, padding
 
322
 
323
  super().__init__()
324
  w = str(weights[0] if isinstance(weights, list) else weights)
325
+ pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs = self._model_type(w) # get backend
326
  w = attempt_download(w) # download if not local
327
+ fp16 &= pt or jit or onnx or engine # FP16
328
+ stride = 32 # default stride
 
 
 
329
 
330
  if pt: # PyTorch
331
  model = attempt_load(weights if isinstance(weights, list) else w, device=device, inplace=True, fuse=fuse)
 
338
  extra_files = {'config.txt': ''} # model metadata
339
  model = torch.jit.load(w, _extra_files=extra_files)
340
  model.half() if fp16 else model.float()
341
+ if extra_files['config.txt']: # load metadata dict
342
+ d = json.loads(extra_files['config.txt'],
343
+ object_hook=lambda d: {int(k) if k.isdigit() else k: v
344
+ for k, v in d.items()})
345
  stride, names = int(d['stride']), d['names']
346
  elif dnn: # ONNX OpenCV DNN
347
  LOGGER.info(f'Loading {w} for ONNX OpenCV DNN inference...')
 
349
  net = cv2.dnn.readNetFromONNX(w)
350
  elif onnx: # ONNX Runtime
351
  LOGGER.info(f'Loading {w} for ONNX Runtime inference...')
352
+ cuda = torch.cuda.is_available() and device.type != 'cpu'
353
  check_requirements(('onnx', 'onnxruntime-gpu' if cuda else 'onnxruntime'))
354
  import onnxruntime
355
  providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if cuda else ['CPUExecutionProvider']
356
  session = onnxruntime.InferenceSession(w, providers=providers)
357
+ output_names = [x.name for x in session.get_outputs()]
358
  meta = session.get_modelmeta().custom_metadata_map # metadata
359
  if 'stride' in meta:
360
  stride, names = int(meta['stride']), eval(meta['names'])
 
373
  batch_size = batch_dim.get_length()
374
  executable_network = ie.compile_model(network, device_name="CPU") # device_name="MYRIAD" for Intel NCS2
375
  output_layer = next(iter(executable_network.outputs))
376
+ stride, names = self._load_metadata(Path(w).with_suffix('.yaml')) # load metadata
 
 
377
  elif engine: # TensorRT
378
  LOGGER.info(f'Loading {w} for TensorRT inference...')
379
  import tensorrt as trt # https://developer.nvidia.com/nvidia-tensorrt-download
380
  check_version(trt.__version__, '7.0.0', hard=True) # require tensorrt>=7.0.0
381
+ if device.type == 'cpu':
382
+ device = torch.device('cuda:0')
383
  Binding = namedtuple('Binding', ('name', 'dtype', 'shape', 'data', 'ptr'))
384
  logger = trt.Logger(trt.Logger.INFO)
385
  with open(w, 'rb') as f, trt.Runtime(logger) as runtime:
 
398
  if dtype == np.float16:
399
  fp16 = True
400
  shape = tuple(context.get_binding_shape(index))
401
+ im = torch.from_numpy(np.empty(shape, dtype=dtype)).to(device)
402
+ bindings[name] = Binding(name, dtype, shape, im, int(im.data_ptr()))
403
  binding_addrs = OrderedDict((n, d.ptr) for n, d in bindings.items())
404
  batch_size = bindings['images'].shape[0] # if dynamic, this is instead max batch size
405
  elif coreml: # CoreML
 
445
  input_details = interpreter.get_input_details() # inputs
446
  output_details = interpreter.get_output_details() # outputs
447
  elif tfjs:
448
+ raise NotImplementedError('ERROR: YOLOv5 TF.js inference is not supported')
449
  else:
450
+ raise NotImplementedError(f'ERROR: {w} is not a supported format')
451
+
452
+ # class names
453
+ if 'names' not in locals():
454
+ names = yaml_load(data)['names'] if data else {i: f'class{i}' for i in range(999)}
455
+ if names[0] == 'n01440764' and len(names) == 1000: # ImageNet
456
+ names = yaml_load(ROOT / 'data/ImageNet.yaml')['names'] # human-readable names
457
+
458
  self.__dict__.update(locals()) # assign all variables to self
459
 
460
+ def forward(self, im, augment=False, visualize=False):
461
  # YOLOv5 MultiBackend inference
462
  b, ch, h, w = im.shape # batch, channel, height, width
463
  if self.fp16 and im.dtype != torch.float16:
464
  im = im.half() # to FP16
465
 
466
  if self.pt: # PyTorch
467
+ y = self.model(im, augment=augment, visualize=visualize) if augment or visualize else self.model(im)
468
  elif self.jit: # TorchScript
469
+ y = self.model(im)
470
  elif self.dnn: # ONNX OpenCV DNN
471
  im = im.cpu().numpy() # torch to numpy
472
  self.net.setInput(im)
473
  y = self.net.forward()
474
  elif self.onnx: # ONNX Runtime
475
  im = im.cpu().numpy() # torch to numpy
476
+ y = self.session.run(self.output_names, {self.session.get_inputs()[0].name: im})
477
  elif self.xml: # OpenVINO
478
  im = im.cpu().numpy() # FP32
479
  y = self.executable_network([im])[self.output_layer]
 
520
  y = (y.astype(np.float32) - zero_point) * scale # re-scale
521
  y[..., :4] *= [w, h, w, h] # xywh normalized to pixels
522
 
523
+ if isinstance(y, (list, tuple)):
524
+ return self.from_numpy(y[0]) if len(y) == 1 else [self.from_numpy(x) for x in y]
525
+ else:
526
+ return self.from_numpy(y)
527
+
528
+ def from_numpy(self, x):
529
+ return torch.from_numpy(x).to(self.device) if isinstance(x, np.ndarray) else x
530
 
531
  def warmup(self, imgsz=(1, 3, 640, 640)):
532
  # Warmup model by running inference once
533
  warmup_types = self.pt, self.jit, self.onnx, self.engine, self.saved_model, self.pb
534
  if any(warmup_types) and self.device.type != 'cpu':
535
+ im = torch.empty(*imgsz, dtype=torch.half if self.fp16 else torch.float, device=self.device) # input
536
  for _ in range(2 if self.jit else 1): #
537
  self.forward(im) # warmup
538
 
539
  @staticmethod
540
+ def _model_type(p='path/to/model.pt'):
541
  # Return model type from model path, i.e. path='path/to/model.onnx' -> type=onnx
542
  from export import export_formats
543
  suffixes = list(export_formats().Suffix) + ['.xml'] # export suffixes
 
549
  return pt, jit, onnx, xml, engine, coreml, saved_model, pb, tflite, edgetpu, tfjs
550
 
551
  @staticmethod
552
+ def _load_metadata(f=Path('path/to/meta.yaml')):
553
  # Load metadata from meta.yaml if it exists
554
+ if f.exists():
555
+ d = yaml_load(f)
556
+ return d['stride'], d['names'] # assign stride, names
557
+ return None, None
558
 
559
 
560
  class AutoShape(nn.Module):
 
591
  return self
592
 
593
  @smart_inference_mode()
594
+ def forward(self, ims, size=640, augment=False, profile=False):
595
+ # Inference from various sources. For size(height=640, width=1280), RGB images example inputs are:
596
+ # file: ims = 'data/images/zidane.jpg' # str or PosixPath
597
  # URI: = 'https://ultralytics.com/images/zidane.jpg'
598
  # OpenCV: = cv2.imread('image.jpg')[:,:,::-1] # HWC BGR to RGB x(640,1280,3)
599
  # PIL: = Image.open('image.jpg') or ImageGrab.grab() # HWC x(640,1280,3)
 
601
  # torch: = torch.zeros(16,3,320,640) # BCHW (scaled to size=640, 0-1 values)
602
  # multiple: = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...] # list of images
603
 
604
+ dt = (Profile(), Profile(), Profile())
605
+ with dt[0]:
606
+ if isinstance(size, int): # expand
607
+ size = (size, size)
608
+ p = next(self.model.parameters()) if self.pt else torch.empty(1, device=self.model.device) # param
609
+ autocast = self.amp and (p.device.type != 'cpu') # Automatic Mixed Precision (AMP) inference
610
+ if isinstance(ims, torch.Tensor): # torch
611
+ with amp.autocast(autocast):
612
+ return self.model(ims.to(p.device).type_as(p), augment, profile) # inference
613
+
614
+ # Pre-process
615
+ n, ims = (len(ims), list(ims)) if isinstance(ims, (list, tuple)) else (1, [ims]) # number, list of images
616
+ shape0, shape1, files = [], [], [] # image and inference shapes, filenames
617
+ for i, im in enumerate(ims):
618
+ f = f'image{i}' # filename
619
+ if isinstance(im, (str, Path)): # filename or uri
620
+ im, f = Image.open(requests.get(im, stream=True).raw if str(im).startswith('http') else im), im
621
+ im = np.asarray(exif_transpose(im))
622
+ elif isinstance(im, Image.Image): # PIL Image
623
+ im, f = np.asarray(exif_transpose(im)), getattr(im, 'filename', f) or f
624
+ files.append(Path(f).with_suffix('.jpg').name)
625
+ if im.shape[0] < 5: # image in CHW
626
+ im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1)
627
+ im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input
628
+ s = im.shape[:2] # HWC
629
+ shape0.append(s) # image shape
630
+ g = max(size) / max(s) # gain
631
+ shape1.append([y * g for y in s])
632
+ ims[i] = im if im.data.contiguous else np.ascontiguousarray(im) # update
633
+ shape1 = [make_divisible(x, self.stride) for x in np.array(shape1).max(0)] if self.pt else size # inf shape
634
+ x = [letterbox(im, shape1, auto=False)[0] for im in ims] # pad
635
+ x = np.ascontiguousarray(np.array(x).transpose((0, 3, 1, 2))) # stack and BHWC to BCHW
636
+ x = torch.from_numpy(x).to(p.device).type_as(p) / 255 # uint8 to fp16/32
637
 
638
  with amp.autocast(autocast):
639
  # Inference
640
+ with dt[1]:
641
+ y = self.model(x, augment, profile) # forward
642
 
643
  # Post-process
644
+ with dt[2]:
645
+ y = non_max_suppression(y if self.dmb else y[0],
646
+ self.conf,
647
+ self.iou,
648
+ self.classes,
649
+ self.agnostic,
650
+ self.multi_label,
651
+ max_det=self.max_det) # NMS
652
+ for i in range(n):
653
+ scale_coords(shape1, y[i][:, :4], shape0[i])
654
 
655
+ return Detections(ims, y, files, dt, self.names, x.shape)
 
656
 
657
 
658
  class Detections:
659
  # YOLOv5 detections class for inference results
660
+ def __init__(self, ims, pred, files, times=(0, 0, 0), names=None, shape=None):
661
  super().__init__()
662
  d = pred[0].device # device
663
+ gn = [torch.tensor([*(im.shape[i] for i in [1, 0, 1, 0]), 1, 1], device=d) for im in ims] # normalizations
664
+ self.ims = ims # list of images as numpy arrays
665
  self.pred = pred # list of tensors pred[0] = (xyxy, conf, cls)
666
  self.names = names # class names
667
  self.files = files # image filenames
 
671
  self.xyxyn = [x / g for x, g in zip(self.xyxy, gn)] # xyxy normalized
672
  self.xywhn = [x / g for x, g in zip(self.xywh, gn)] # xywh normalized
673
  self.n = len(self.pred) # number of images (batch size)
674
+ self.t = tuple(x.t / self.n * 1E3 for x in times) # timestamps (ms)
675
  self.s = shape # inference BCHW shape
676
 
677
  def display(self, pprint=False, show=False, save=False, crop=False, render=False, labels=True, save_dir=Path('')):
678
  crops = []
679
+ for i, (im, pred) in enumerate(zip(self.ims, self.pred)):
680
  s = f'image {i + 1}/{len(self.pred)}: {im.shape[0]}x{im.shape[1]} ' # string
681
  if pred.shape[0]:
682
  for c in pred[:, -1].unique():
 
711
  if i == self.n - 1:
712
  LOGGER.info(f"Saved {self.n} image{'s' * (self.n > 1)} to {colorstr('bold', save_dir)}")
713
  if render:
714
+ self.ims[i] = np.asarray(im)
715
  if crop:
716
  if save:
717
  LOGGER.info(f'Saved results to {save_dir}\n')
 
734
 
735
  def render(self, labels=True):
736
  self.display(render=True, labels=labels) # render results
737
+ return self.ims
738
 
739
  def pandas(self):
740
  # return detections as pandas DataFrames, i.e. print(results.pandas().xyxy[0])
 
749
  def tolist(self):
750
  # return a list of Detections objects, i.e. 'for result in results.tolist():'
751
  r = range(self.n) # iterable
752
+ x = [Detections([self.ims[i]], [self.pred[i]], [self.files[i]], self.times, self.names, self.s) for i in r]
753
  # for d in x:
754
+ # for k in ['ims', 'pred', 'xyxy', 'xyxyn', 'xywh', 'xywhn']:
755
  # setattr(d, k, getattr(d, k)[0]) # pop out of list
756
  return x
757
 
 
767
  # Classification head, i.e. x(b,c1,20,20) to x(b,c2)
768
  def __init__(self, c1, c2, k=1, s=1, p=None, g=1): # ch_in, ch_out, kernel, stride, padding, groups
769
  super().__init__()
770
+ c_ = 1280 # efficientnet_b0 size
771
+ self.conv = Conv(c1, c_, k, s, autopad(k, p), g)
772
+ self.pool = nn.AdaptiveAvgPool2d(1) # to x(b,c_,1,1)
773
+ self.drop = nn.Dropout(p=0.0, inplace=True)
774
+ self.linear = nn.Linear(c_, c2) # to x(b,c2)
775
 
776
  def forward(self, x):
777
+ if isinstance(x, list):
778
+ x = torch.cat(x, 1)
779
+ return self.linear(self.drop(self.pool(self.conv(x)).flatten(1)))
models/experimental.py CHANGED
@@ -8,7 +8,6 @@ import numpy as np
8
  import torch
9
  import torch.nn as nn
10
 
11
- from models.common import Conv
12
  from utils.downloads import attempt_download
13
 
14
 
@@ -79,9 +78,16 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
79
  for w in weights if isinstance(weights, list) else [weights]:
80
  ckpt = torch.load(attempt_download(w), map_location='cpu') # load
81
  ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
82
- model.append(ckpt.fuse().eval() if fuse else ckpt.eval()) # fused or un-fused model in eval mode
83
 
84
- # Compatibility updates
 
 
 
 
 
 
 
 
85
  for m in model.modules():
86
  t = type(m)
87
  if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
@@ -92,11 +98,14 @@ def attempt_load(weights, device=None, inplace=True, fuse=True):
92
  elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
93
  m.recompute_scale_factor = None # torch 1.11.0 compatibility
94
 
 
95
  if len(model) == 1:
96
- return model[-1] # return model
 
 
97
  print(f'Ensemble created with {weights}\n')
98
  for k in 'names', 'nc', 'yaml':
99
  setattr(model, k, getattr(model[0], k))
100
  model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
101
  assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
102
- return model # return ensemble
 
8
  import torch
9
  import torch.nn as nn
10
 
 
11
  from utils.downloads import attempt_download
12
 
13
 
 
78
  for w in weights if isinstance(weights, list) else [weights]:
79
  ckpt = torch.load(attempt_download(w), map_location='cpu') # load
80
  ckpt = (ckpt.get('ema') or ckpt['model']).to(device).float() # FP32 model
 
81
 
82
+ # Model compatibility updates
83
+ if not hasattr(ckpt, 'stride'):
84
+ ckpt.stride = torch.tensor([32.])
85
+ if hasattr(ckpt, 'names') and isinstance(ckpt.names, (list, tuple)):
86
+ ckpt.names = dict(enumerate(ckpt.names)) # convert to dict
87
+
88
+ model.append(ckpt.fuse().eval() if fuse and hasattr(ckpt, 'fuse') else ckpt.eval()) # model in eval mode
89
+
90
+ # Module compatibility updates
91
  for m in model.modules():
92
  t = type(m)
93
  if t in (nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU, Detect, Model):
 
98
  elif t is nn.Upsample and not hasattr(m, 'recompute_scale_factor'):
99
  m.recompute_scale_factor = None # torch 1.11.0 compatibility
100
 
101
+ # Return model
102
  if len(model) == 1:
103
+ return model[-1]
104
+
105
+ # Return detection ensemble
106
  print(f'Ensemble created with {weights}\n')
107
  for k in 'names', 'nc', 'yaml':
108
  setattr(model, k, getattr(model[0], k))
109
  model.stride = model[torch.argmax(torch.tensor([m.stride.max() for m in model])).int()].stride # max stride
110
  assert all(model[0].nc == m.nc for m in model), f'Models have different class counts: {[m.nc for m in model]}'
111
+ return model
models/tf.py CHANGED
@@ -7,7 +7,7 @@ Usage:
7
  $ python models/tf.py --weights yolov5s.pt
8
 
9
  Export:
10
- $ python path/to/export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
11
  """
12
 
13
  import argparse
 
7
  $ python models/tf.py --weights yolov5s.pt
8
 
9
  Export:
10
+ $ python export.py --weights yolov5s.pt --include saved_model pb tflite tfjs
11
  """
12
 
13
  import argparse
models/yolo.py CHANGED
@@ -3,7 +3,7 @@
3
  YOLO-specific modules
4
 
5
  Usage:
6
- $ python path/to/models/yolo.py --cfg yolov5s.yaml
7
  """
8
 
9
  import argparse
@@ -37,7 +37,7 @@ except ImportError:
37
 
38
  class Detect(nn.Module):
39
  stride = None # strides computed during build
40
- onnx_dynamic = False # ONNX export parameter
41
  export = False # export mode
42
 
43
  def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
@@ -46,8 +46,8 @@ class Detect(nn.Module):
46
  self.no = nc + 5 # number of outputs per anchor
47
  self.nl = len(anchors) # number of detection layers
48
  self.na = len(anchors[0]) // 2 # number of anchors
49
- self.grid = [torch.zeros(1)] * self.nl # init grid
50
- self.anchor_grid = [torch.zeros(1)] * self.nl # init anchor grid
51
  self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
52
  self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
53
  self.inplace = inplace # use inplace ops (e.g. slice assignment)
@@ -60,7 +60,7 @@ class Detect(nn.Module):
60
  x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
61
 
62
  if not self.training: # inference
63
- if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
64
  self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
65
 
66
  y = x[i].sigmoid()
@@ -81,17 +81,70 @@ class Detect(nn.Module):
81
  t = self.anchors[i].dtype
82
  shape = 1, self.na, ny, nx, 2 # grid shape
83
  y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
84
- if torch_1_10: # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
85
- yv, xv = torch.meshgrid(y, x, indexing='ij')
86
- else:
87
- yv, xv = torch.meshgrid(y, x)
88
  grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
89
  anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
90
  return grid, anchor_grid
91
 
92
 
93
- class Model(nn.Module):
94
- # YOLOv5 model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
  def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
96
  super().__init__()
97
  if isinstance(cfg, dict):
@@ -119,7 +172,7 @@ class Model(nn.Module):
119
  if isinstance(m, Detect):
120
  s = 256 # 2x min stride
121
  m.inplace = self.inplace
122
- m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.zeros(1, ch, s, s))]) # forward
123
  check_anchor_order(m) # must be in pixel-space (not grid-space)
124
  m.anchors /= m.stride.view(-1, 1, 1)
125
  self.stride = m.stride
@@ -149,19 +202,6 @@ class Model(nn.Module):
149
  y = self._clip_augmented(y) # clip augmented tails
150
  return torch.cat(y, 1), None # augmented inference, train
151
 
152
- def _forward_once(self, x, profile=False, visualize=False):
153
- y, dt = [], [] # outputs
154
- for m in self.model:
155
- if m.f != -1: # if not from previous layer
156
- x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
157
- if profile:
158
- self._profile_one_layer(m, x, dt)
159
- x = m(x) # run
160
- y.append(x if m.i in self.save else None) # save output
161
- if visualize:
162
- feature_visualization(x, m.type, m.i, save_dir=visualize)
163
- return x
164
-
165
  def _descale_pred(self, p, flips, scale, img_size):
166
  # de-scale predictions following augmented inference (inverse operation)
167
  if self.inplace:
@@ -190,19 +230,6 @@ class Model(nn.Module):
190
  y[-1] = y[-1][:, i:] # small
191
  return y
192
 
193
- def _profile_one_layer(self, m, x, dt):
194
- c = isinstance(m, Detect) # is final layer, copy input as inplace fix
195
- o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
196
- t = time_sync()
197
- for _ in range(10):
198
- m(x.copy() if c else x)
199
- dt.append((time_sync() - t) * 100)
200
- if m == self.model[0]:
201
- LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
202
- LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
203
- if c:
204
- LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
205
-
206
  def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
207
  # https://arxiv.org/abs/1708.02002 section 3.3
208
  # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
@@ -213,41 +240,34 @@ class Model(nn.Module):
213
  b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
214
  mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
215
 
216
- def _print_biases(self):
217
- m = self.model[-1] # Detect() module
218
- for mi in m.m: # from
219
- b = mi.bias.detach().view(m.na, -1).T # conv.bias(255) to (3,85)
220
- LOGGER.info(
221
- ('%6g Conv2d.bias:' + '%10.3g' * 6) % (mi.weight.shape[1], *b[:5].mean(1).tolist(), b[5:].mean()))
222
 
223
- # def _print_weights(self):
224
- # for m in self.model.modules():
225
- # if type(m) is Bottleneck:
226
- # LOGGER.info('%10.3g' % (m.w.detach().sigmoid() * 2)) # shortcut weights
227
 
228
- def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
229
- LOGGER.info('Fusing layers... ')
230
- for m in self.model.modules():
231
- if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
232
- m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
233
- delattr(m, 'bn') # remove batchnorm
234
- m.forward = m.forward_fuse # update forward
235
- self.info()
236
- return self
237
 
238
- def info(self, verbose=False, img_size=640): # print model information
239
- model_info(self, verbose, img_size)
240
-
241
- def _apply(self, fn):
242
- # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
243
- self = super()._apply(fn)
244
- m = self.model[-1] # Detect()
245
- if isinstance(m, Detect):
246
- m.stride = fn(m.stride)
247
- m.grid = list(map(fn, m.grid))
248
- if isinstance(m.anchor_grid, list):
249
- m.anchor_grid = list(map(fn, m.anchor_grid))
250
- return self
 
 
 
 
 
 
 
 
 
 
 
251
 
252
 
253
  def parse_model(d, ch): # model_dict, input_channels(3)
@@ -321,7 +341,7 @@ if __name__ == '__main__':
321
 
322
  # Options
323
  if opt.line_profile: # profile layer by layer
324
- _ = model(im, profile=True)
325
 
326
  elif opt.profile: # profile forward-backward
327
  results = profile(input=im, ops=[model], n=3)
 
3
  YOLO-specific modules
4
 
5
  Usage:
6
+ $ python models/yolo.py --cfg yolov5s.yaml
7
  """
8
 
9
  import argparse
 
37
 
38
  class Detect(nn.Module):
39
  stride = None # strides computed during build
40
+ dynamic = False # force grid reconstruction
41
  export = False # export mode
42
 
43
  def __init__(self, nc=80, anchors=(), ch=(), inplace=True): # detection layer
 
46
  self.no = nc + 5 # number of outputs per anchor
47
  self.nl = len(anchors) # number of detection layers
48
  self.na = len(anchors[0]) // 2 # number of anchors
49
+ self.grid = [torch.empty(1)] * self.nl # init grid
50
+ self.anchor_grid = [torch.empty(1)] * self.nl # init anchor grid
51
  self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) # shape(nl,na,2)
52
  self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) # output conv
53
  self.inplace = inplace # use inplace ops (e.g. slice assignment)
 
60
  x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
61
 
62
  if not self.training: # inference
63
+ if self.dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
64
  self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
65
 
66
  y = x[i].sigmoid()
 
81
  t = self.anchors[i].dtype
82
  shape = 1, self.na, ny, nx, 2 # grid shape
83
  y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
84
+ yv, xv = torch.meshgrid(y, x, indexing='ij') if torch_1_10 else torch.meshgrid(y, x) # torch>=0.7 compatibility
 
 
 
85
  grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 # add grid offset, i.e. y = 2.0 * x - 0.5
86
  anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
87
  return grid, anchor_grid
88
 
89
 
90
+ class BaseModel(nn.Module):
91
+ # YOLOv5 base model
92
+ def forward(self, x, profile=False, visualize=False):
93
+ return self._forward_once(x, profile, visualize) # single-scale inference, train
94
+
95
+ def _forward_once(self, x, profile=False, visualize=False):
96
+ y, dt = [], [] # outputs
97
+ for m in self.model:
98
+ if m.f != -1: # if not from previous layer
99
+ x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
100
+ if profile:
101
+ self._profile_one_layer(m, x, dt)
102
+ x = m(x) # run
103
+ y.append(x if m.i in self.save else None) # save output
104
+ if visualize:
105
+ feature_visualization(x, m.type, m.i, save_dir=visualize)
106
+ return x
107
+
108
+ def _profile_one_layer(self, m, x, dt):
109
+ c = m == self.model[-1] # is final layer, copy input as inplace fix
110
+ o = thop.profile(m, inputs=(x.copy() if c else x,), verbose=False)[0] / 1E9 * 2 if thop else 0 # FLOPs
111
+ t = time_sync()
112
+ for _ in range(10):
113
+ m(x.copy() if c else x)
114
+ dt.append((time_sync() - t) * 100)
115
+ if m == self.model[0]:
116
+ LOGGER.info(f"{'time (ms)':>10s} {'GFLOPs':>10s} {'params':>10s} module")
117
+ LOGGER.info(f'{dt[-1]:10.2f} {o:10.2f} {m.np:10.0f} {m.type}')
118
+ if c:
119
+ LOGGER.info(f"{sum(dt):10.2f} {'-':>10s} {'-':>10s} Total")
120
+
121
+ def fuse(self): # fuse model Conv2d() + BatchNorm2d() layers
122
+ LOGGER.info('Fusing layers... ')
123
+ for m in self.model.modules():
124
+ if isinstance(m, (Conv, DWConv)) and hasattr(m, 'bn'):
125
+ m.conv = fuse_conv_and_bn(m.conv, m.bn) # update conv
126
+ delattr(m, 'bn') # remove batchnorm
127
+ m.forward = m.forward_fuse # update forward
128
+ self.info()
129
+ return self
130
+
131
+ def info(self, verbose=False, img_size=640): # print model information
132
+ model_info(self, verbose, img_size)
133
+
134
+ def _apply(self, fn):
135
+ # Apply to(), cpu(), cuda(), half() to model tensors that are not parameters or registered buffers
136
+ self = super()._apply(fn)
137
+ m = self.model[-1] # Detect()
138
+ if isinstance(m, Detect):
139
+ m.stride = fn(m.stride)
140
+ m.grid = list(map(fn, m.grid))
141
+ if isinstance(m.anchor_grid, list):
142
+ m.anchor_grid = list(map(fn, m.anchor_grid))
143
+ return self
144
+
145
+
146
+ class DetectionModel(BaseModel):
147
+ # YOLOv5 detection model
148
  def __init__(self, cfg='yolov5s.yaml', ch=3, nc=None, anchors=None): # model, input channels, number of classes
149
  super().__init__()
150
  if isinstance(cfg, dict):
 
172
  if isinstance(m, Detect):
173
  s = 256 # 2x min stride
174
  m.inplace = self.inplace
175
+ m.stride = torch.tensor([s / x.shape[-2] for x in self.forward(torch.empty(1, ch, s, s))]) # forward
176
  check_anchor_order(m) # must be in pixel-space (not grid-space)
177
  m.anchors /= m.stride.view(-1, 1, 1)
178
  self.stride = m.stride
 
202
  y = self._clip_augmented(y) # clip augmented tails
203
  return torch.cat(y, 1), None # augmented inference, train
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  def _descale_pred(self, p, flips, scale, img_size):
206
  # de-scale predictions following augmented inference (inverse operation)
207
  if self.inplace:
 
230
  y[-1] = y[-1][:, i:] # small
231
  return y
232
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  def _initialize_biases(self, cf=None): # initialize biases into Detect(), cf is class frequency
234
  # https://arxiv.org/abs/1708.02002 section 3.3
235
  # cf = torch.bincount(torch.tensor(np.concatenate(dataset.labels, 0)[:, 0]).long(), minlength=nc) + 1.
 
240
  b[:, 5:] += math.log(0.6 / (m.nc - 0.999999)) if cf is None else torch.log(cf / cf.sum()) # cls
241
  mi.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
242
 
 
 
 
 
 
 
243
 
244
+ Model = DetectionModel # retain YOLOv5 'Model' class for backwards compatibility
 
 
 
245
 
 
 
 
 
 
 
 
 
 
246
 
247
+ class ClassificationModel(BaseModel):
248
+ # YOLOv5 classification model
249
+ def __init__(self, cfg=None, model=None, nc=1000, cutoff=10): # yaml, model, number of classes, cutoff index
250
+ super().__init__()
251
+ self._from_detection_model(model, nc, cutoff) if model is not None else self._from_yaml(cfg)
252
+
253
+ def _from_detection_model(self, model, nc=1000, cutoff=10):
254
+ # Create a YOLOv5 classification model from a YOLOv5 detection model
255
+ if isinstance(model, DetectMultiBackend):
256
+ model = model.model # unwrap DetectMultiBackend
257
+ model.model = model.model[:cutoff] # backbone
258
+ m = model.model[-1] # last layer
259
+ ch = m.conv.in_channels if hasattr(m, 'conv') else m.cv1.conv.in_channels # ch into module
260
+ c = Classify(ch, nc) # Classify()
261
+ c.i, c.f, c.type = m.i, m.f, 'models.common.Classify' # index, from, type
262
+ model.model[-1] = c # replace
263
+ self.model = model.model
264
+ self.stride = model.stride
265
+ self.save = []
266
+ self.nc = nc
267
+
268
+ def _from_yaml(self, cfg):
269
+ # Create a YOLOv5 classification model from a *.yaml file
270
+ self.model = None
271
 
272
 
273
  def parse_model(d, ch): # model_dict, input_channels(3)
 
341
 
342
  # Options
343
  if opt.line_profile: # profile layer by layer
344
+ model(im, profile=True)
345
 
346
  elif opt.profile: # profile forward-backward
347
  results = profile(input=im, ops=[model], n=3)
requirements.txt CHANGED
@@ -23,9 +23,9 @@ pandas>=1.1.4
23
  seaborn>=0.11.0
24
 
25
  # Export --------------------------------------
26
- coremltools>=4.1 # CoreML export
27
  onnx>=1.9.0 # ONNX export
28
- onnx-simplifier>=0.3.6 # ONNX simplifier
29
  onnxruntime
30
  # nvidia-pyindex # TensorRT export
31
  # nvidia-tensorrt # TensorRT export
 
23
  seaborn>=0.11.0
24
 
25
  # Export --------------------------------------
26
+ coremltools>=5.2 # CoreML export
27
  onnx>=1.9.0 # ONNX export
28
+ onnx-simplifier>=0.4.1 # ONNX simplifier
29
  onnxruntime
30
  # nvidia-pyindex # TensorRT export
31
  # nvidia-tensorrt # TensorRT export
utils/__init__.py CHANGED
@@ -3,6 +3,33 @@
3
  utils/initialization
4
  """
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def notebook_init(verbose=True):
8
  # Check system software and hardware
@@ -11,10 +38,12 @@ def notebook_init(verbose=True):
11
  import os
12
  import shutil
13
 
14
- from utils.general import check_requirements, emojis, is_colab
15
  from utils.torch_utils import select_device # imports
16
 
17
  check_requirements(('psutil', 'IPython'))
 
 
18
  import psutil
19
  from IPython import display # to display images and clear console output
20
 
 
3
  utils/initialization
4
  """
5
 
6
+ import contextlib
7
+ import threading
8
+
9
+
10
+ class TryExcept(contextlib.ContextDecorator):
11
+ # YOLOv5 TryExcept class. Usage: @TryExcept() decorator or 'with TryExcept():' context manager
12
+ def __init__(self, msg=''):
13
+ self.msg = msg
14
+
15
+ def __enter__(self):
16
+ pass
17
+
18
+ def __exit__(self, exc_type, value, traceback):
19
+ if value:
20
+ print(f'{self.msg}{value}')
21
+ return True
22
+
23
+
24
+ def threaded(func):
25
+ # Multi-threads a target function and returns thread. Usage: @threaded decorator
26
+ def wrapper(*args, **kwargs):
27
+ thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
28
+ thread.start()
29
+ return thread
30
+
31
+ return wrapper
32
+
33
 
34
  def notebook_init(verbose=True):
35
  # Check system software and hardware
 
38
  import os
39
  import shutil
40
 
41
+ from utils.general import check_font, check_requirements, emojis, is_colab
42
  from utils.torch_utils import select_device # imports
43
 
44
  check_requirements(('psutil', 'IPython'))
45
+ check_font()
46
+
47
  import psutil
48
  from IPython import display # to display images and clear console output
49
 
utils/augmentations.py CHANGED
@@ -8,15 +8,22 @@ import random
8
 
9
  import cv2
10
  import numpy as np
 
 
 
11
 
12
  from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
13
  from utils.metrics import bbox_ioa
14
 
 
 
 
15
 
16
  class Albumentations:
17
  # YOLOv5 Albumentations class (optional, only used if package is installed)
18
  def __init__(self):
19
  self.transform = None
 
20
  try:
21
  import albumentations as A
22
  check_version(A.__version__, '1.0.3', hard=True) # version requirement
@@ -31,11 +38,11 @@ class Albumentations:
31
  A.ImageCompression(quality_lower=75, p=0.0)] # transforms
32
  self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
33
 
34
- LOGGER.info(colorstr('albumentations: ') + ', '.join(f'{x}' for x in self.transform.transforms if x.p))
35
  except ImportError: # package not installed, skip
36
  pass
37
  except Exception as e:
38
- LOGGER.info(colorstr('albumentations: ') + f'{e}')
39
 
40
  def __call__(self, im, labels, p=1.0):
41
  if self.transform and random.random() < p:
@@ -44,6 +51,18 @@ class Albumentations:
44
  return im, labels
45
 
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
48
  # HSV color-space augmentation
49
  if hgain or sgain or vgain:
@@ -282,3 +301,96 @@ def box_candidates(box1, box2, wh_thr=2, ar_thr=100, area_thr=0.1, eps=1e-16):
282
  w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
283
  ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
284
  return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  import cv2
10
  import numpy as np
11
+ import torch
12
+ import torchvision.transforms as T
13
+ import torchvision.transforms.functional as TF
14
 
15
  from utils.general import LOGGER, check_version, colorstr, resample_segments, segment2box
16
  from utils.metrics import bbox_ioa
17
 
18
+ IMAGENET_MEAN = 0.485, 0.456, 0.406 # RGB mean
19
+ IMAGENET_STD = 0.229, 0.224, 0.225 # RGB standard deviation
20
+
21
 
22
  class Albumentations:
23
  # YOLOv5 Albumentations class (optional, only used if package is installed)
24
  def __init__(self):
25
  self.transform = None
26
+ prefix = colorstr('albumentations: ')
27
  try:
28
  import albumentations as A
29
  check_version(A.__version__, '1.0.3', hard=True) # version requirement
 
38
  A.ImageCompression(quality_lower=75, p=0.0)] # transforms
39
  self.transform = A.Compose(T, bbox_params=A.BboxParams(format='yolo', label_fields=['class_labels']))
40
 
41
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
42
  except ImportError: # package not installed, skip
43
  pass
44
  except Exception as e:
45
+ LOGGER.info(f'{prefix}{e}')
46
 
47
  def __call__(self, im, labels, p=1.0):
48
  if self.transform and random.random() < p:
 
51
  return im, labels
52
 
53
 
54
+ def normalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD, inplace=False):
55
+ # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = (x - mean) / std
56
+ return TF.normalize(x, mean, std, inplace=inplace)
57
+
58
+
59
+ def denormalize(x, mean=IMAGENET_MEAN, std=IMAGENET_STD):
60
+ # Denormalize RGB images x per ImageNet stats in BCHW format, i.e. = x * std + mean
61
+ for i in range(3):
62
+ x[:, i] = x[:, i] * std[i] + mean[i]
63
+ return x
64
+
65
+
66
  def augment_hsv(im, hgain=0.5, sgain=0.5, vgain=0.5):
67
  # HSV color-space augmentation
68
  if hgain or sgain or vgain:
 
301
  w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
302
  ar = np.maximum(w2 / (h2 + eps), h2 / (w2 + eps)) # aspect ratio
303
  return (w2 > wh_thr) & (h2 > wh_thr) & (w2 * h2 / (w1 * h1 + eps) > area_thr) & (ar < ar_thr) # candidates
304
+
305
+
306
+ def classify_albumentations(augment=True,
307
+ size=224,
308
+ scale=(0.08, 1.0),
309
+ hflip=0.5,
310
+ vflip=0.0,
311
+ jitter=0.4,
312
+ mean=IMAGENET_MEAN,
313
+ std=IMAGENET_STD,
314
+ auto_aug=False):
315
+ # YOLOv5 classification Albumentations (optional, only used if package is installed)
316
+ prefix = colorstr('albumentations: ')
317
+ try:
318
+ import albumentations as A
319
+ from albumentations.pytorch import ToTensorV2
320
+ check_version(A.__version__, '1.0.3', hard=True) # version requirement
321
+ if augment: # Resize and crop
322
+ T = [A.RandomResizedCrop(height=size, width=size, scale=scale)]
323
+ if auto_aug:
324
+ # TODO: implement AugMix, AutoAug & RandAug in albumentation
325
+ LOGGER.info(f'{prefix}auto augmentations are currently not supported')
326
+ else:
327
+ if hflip > 0:
328
+ T += [A.HorizontalFlip(p=hflip)]
329
+ if vflip > 0:
330
+ T += [A.VerticalFlip(p=vflip)]
331
+ if jitter > 0:
332
+ color_jitter = (float(jitter),) * 3 # repeat value for brightness, contrast, satuaration, 0 hue
333
+ T += [A.ColorJitter(*color_jitter, 0)]
334
+ else: # Use fixed crop for eval set (reproducibility)
335
+ T = [A.SmallestMaxSize(max_size=size), A.CenterCrop(height=size, width=size)]
336
+ T += [A.Normalize(mean=mean, std=std), ToTensorV2()] # Normalize and convert to Tensor
337
+ LOGGER.info(prefix + ', '.join(f'{x}'.replace('always_apply=False, ', '') for x in T if x.p))
338
+ return A.Compose(T)
339
+
340
+ except ImportError: # package not installed, skip
341
+ pass
342
+ except Exception as e:
343
+ LOGGER.info(f'{prefix}{e}')
344
+
345
+
346
+ def classify_transforms(size=224):
347
+ # Transforms to apply if albumentations not installed
348
+ assert isinstance(size, int), f'ERROR: classify_transforms size {size} must be integer, not (list, tuple)'
349
+ # T.Compose([T.ToTensor(), T.Resize(size), T.CenterCrop(size), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
350
+ return T.Compose([CenterCrop(size), ToTensor(), T.Normalize(IMAGENET_MEAN, IMAGENET_STD)])
351
+
352
+
353
+ class LetterBox:
354
+ # YOLOv5 LetterBox class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
355
+ def __init__(self, size=(640, 640), auto=False, stride=32):
356
+ super().__init__()
357
+ self.h, self.w = (size, size) if isinstance(size, int) else size
358
+ self.auto = auto # pass max size integer, automatically solve for short side using stride
359
+ self.stride = stride # used with auto
360
+
361
+ def __call__(self, im): # im = np.array HWC
362
+ imh, imw = im.shape[:2]
363
+ r = min(self.h / imh, self.w / imw) # ratio of new/old
364
+ h, w = round(imh * r), round(imw * r) # resized image
365
+ hs, ws = (math.ceil(x / self.stride) * self.stride for x in (h, w)) if self.auto else self.h, self.w
366
+ top, left = round((hs - h) / 2 - 0.1), round((ws - w) / 2 - 0.1)
367
+ im_out = np.full((self.h, self.w, 3), 114, dtype=im.dtype)
368
+ im_out[top:top + h, left:left + w] = cv2.resize(im, (w, h), interpolation=cv2.INTER_LINEAR)
369
+ return im_out
370
+
371
+
372
+ class CenterCrop:
373
+ # YOLOv5 CenterCrop class for image preprocessing, i.e. T.Compose([CenterCrop(size), ToTensor()])
374
+ def __init__(self, size=640):
375
+ super().__init__()
376
+ self.h, self.w = (size, size) if isinstance(size, int) else size
377
+
378
+ def __call__(self, im): # im = np.array HWC
379
+ imh, imw = im.shape[:2]
380
+ m = min(imh, imw) # min dimension
381
+ top, left = (imh - m) // 2, (imw - m) // 2
382
+ return cv2.resize(im[top:top + m, left:left + m], (self.w, self.h), interpolation=cv2.INTER_LINEAR)
383
+
384
+
385
+ class ToTensor:
386
+ # YOLOv5 ToTensor class for image preprocessing, i.e. T.Compose([LetterBox(size), ToTensor()])
387
+ def __init__(self, half=False):
388
+ super().__init__()
389
+ self.half = half
390
+
391
+ def __call__(self, im): # im = np.array HWC in BGR order
392
+ im = np.ascontiguousarray(im.transpose((2, 0, 1))[::-1]) # HWC to CHW -> BGR to RGB -> contiguous
393
+ im = torch.from_numpy(im) # to torch
394
+ im = im.half() if self.half else im.float() # uint8 to fp16/32
395
+ im /= 255.0 # 0-255 to 0.0-1.0
396
+ return im
utils/autoanchor.py CHANGED
@@ -10,6 +10,7 @@ import torch
10
  import yaml
11
  from tqdm import tqdm
12
 
 
13
  from utils.general import LOGGER, colorstr
14
 
15
  PREFIX = colorstr('AutoAnchor: ')
@@ -25,6 +26,7 @@ def check_anchor_order(m):
25
  m.anchors[:] = m.anchors.flip(0)
26
 
27
 
 
28
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
29
  # Check anchor fit to data, recompute if necessary
30
  m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
@@ -49,10 +51,7 @@ def check_anchors(dataset, model, thr=4.0, imgsz=640):
49
  else:
50
  LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
51
  na = m.anchors.numel() // 2 # number of anchors
52
- try:
53
- anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
54
- except Exception as e:
55
- LOGGER.info(f'{PREFIX}ERROR: {e}')
56
  new_bpr = metric(anchors)[0]
57
  if new_bpr > bpr: # replace anchors
58
  anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
@@ -124,7 +123,7 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
124
  i = (wh0 < 3.0).any(1).sum()
125
  if i:
126
  LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
127
- wh = wh0[(wh0 >= 2.0).any(1)] # filter > 2 pixels
128
  # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
129
 
130
  # Kmeans init
@@ -167,4 +166,4 @@ def kmean_anchors(dataset='./data/coco128.yaml', n=9, img_size=640, thr=4.0, gen
167
  if verbose:
168
  print_results(k, verbose)
169
 
170
- return print_results(k)
 
10
  import yaml
11
  from tqdm import tqdm
12
 
13
+ from utils import TryExcept
14
  from utils.general import LOGGER, colorstr
15
 
16
  PREFIX = colorstr('AutoAnchor: ')
 
26
  m.anchors[:] = m.anchors.flip(0)
27
 
28
 
29
+ @TryExcept(f'{PREFIX}ERROR: ')
30
  def check_anchors(dataset, model, thr=4.0, imgsz=640):
31
  # Check anchor fit to data, recompute if necessary
32
  m = model.module.model[-1] if hasattr(model, 'module') else model.model[-1] # Detect()
 
51
  else:
52
  LOGGER.info(f'{s}Anchors are a poor fit to dataset ⚠️, attempting to improve...')
53
  na = m.anchors.numel() // 2 # number of anchors
54
+ anchors = kmean_anchors(dataset, n=na, img_size=imgsz, thr=thr, gen=1000, verbose=False)
 
 
 
55
  new_bpr = metric(anchors)[0]
56
  if new_bpr > bpr: # replace anchors
57
  anchors = torch.tensor(anchors, device=m.anchors.device).type_as(m.anchors)
 
123
  i = (wh0 < 3.0).any(1).sum()
124
  if i:
125
  LOGGER.info(f'{PREFIX}WARNING: Extremely small objects found: {i} of {len(wh0)} labels are < 3 pixels in size')
126
+ wh = wh0[(wh0 >= 2.0).any(1)].astype(np.float32) # filter > 2 pixels
127
  # wh = wh * (npr.rand(wh.shape[0], 1) * 0.9 + 0.1) # multiply by random scale 0-1
128
 
129
  # Kmeans init
 
166
  if verbose:
167
  print_results(k, verbose)
168
 
169
+ return print_results(k).astype(np.float32)
utils/autobatch.py CHANGED
@@ -18,7 +18,7 @@ def check_train_batch_size(model, imgsz=640, amp=True):
18
  return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
19
 
20
 
21
- def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
22
  # Automatically estimate best batch size to use `fraction` of available CUDA memory
23
  # Usage:
24
  # import torch
@@ -47,7 +47,7 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
47
  # Profile batch sizes
48
  batch_sizes = [1, 2, 4, 8, 16]
49
  try:
50
- img = [torch.zeros(b, 3, imgsz, imgsz) for b in batch_sizes]
51
  results = profile(img, model, n=3, device=device)
52
  except Exception as e:
53
  LOGGER.warning(f'{prefix}{e}')
@@ -60,6 +60,9 @@ def autobatch(model, imgsz=640, fraction=0.9, batch_size=16):
60
  i = results.index(None) # first fail index
61
  if b >= batch_sizes[i]: # y intercept above failure point
62
  b = batch_sizes[max(i - 1, 0)] # select prior safe point
 
 
 
63
 
64
  fraction = np.polyval(p, b) / t # actual fraction predicted
65
  LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
 
18
  return autobatch(deepcopy(model).train(), imgsz) # compute optimal batch size
19
 
20
 
21
+ def autobatch(model, imgsz=640, fraction=0.8, batch_size=16):
22
  # Automatically estimate best batch size to use `fraction` of available CUDA memory
23
  # Usage:
24
  # import torch
 
47
  # Profile batch sizes
48
  batch_sizes = [1, 2, 4, 8, 16]
49
  try:
50
+ img = [torch.empty(b, 3, imgsz, imgsz) for b in batch_sizes]
51
  results = profile(img, model, n=3, device=device)
52
  except Exception as e:
53
  LOGGER.warning(f'{prefix}{e}')
 
60
  i = results.index(None) # first fail index
61
  if b >= batch_sizes[i]: # y intercept above failure point
62
  b = batch_sizes[max(i - 1, 0)] # select prior safe point
63
+ if b < 1 or b > 1024: # b outside of safe range
64
+ b = batch_size
65
+ LOGGER.warning(f'{prefix}WARNING: ⚠️ CUDA anomaly detected, recommend restart environment and retry command.')
66
 
67
  fraction = np.polyval(p, b) / t # actual fraction predicted
68
  LOGGER.info(f'{prefix}Using batch-size {b} for {d} {t * fraction:.2f}G/{t:.2f}G ({fraction * 100:.0f}%) ✅')
utils/benchmarks.py CHANGED
@@ -92,10 +92,14 @@ def run(
92
  LOGGER.info('\n')
93
  parse_opt()
94
  notebook_init() # print system info
95
- c = ['Format', 'Size (MB)', 'mAP@0.5:0.95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
96
  py = pd.DataFrame(y, columns=c)
97
  LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
98
  LOGGER.info(str(py if map else py.iloc[:, :2]))
 
 
 
 
99
  return py
100
 
101
 
@@ -141,7 +145,7 @@ def parse_opt():
141
  parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
142
  parser.add_argument('--test', action='store_true', help='test exports only')
143
  parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
144
- parser.add_argument('--hard-fail', action='store_true', help='throw error on benchmark failure')
145
  opt = parser.parse_args()
146
  opt.data = check_yaml(opt.data) # check YAML
147
  print_args(vars(opt))
 
92
  LOGGER.info('\n')
93
  parse_opt()
94
  notebook_init() # print system info
95
+ c = ['Format', 'Size (MB)', 'mAP50-95', 'Inference time (ms)'] if map else ['Format', 'Export', '', '']
96
  py = pd.DataFrame(y, columns=c)
97
  LOGGER.info(f'\nBenchmarks complete ({time.time() - t:.2f}s)')
98
  LOGGER.info(str(py if map else py.iloc[:, :2]))
99
+ if hard_fail and isinstance(hard_fail, str):
100
+ metrics = py['mAP50-95'].array # values to compare to floor
101
+ floor = eval(hard_fail) # minimum metric floor to pass, i.e. = 0.29 mAP for YOLOv5n
102
+ assert all(x > floor for x in metrics if pd.notna(x)), f'HARD FAIL: mAP50-95 < floor {floor}'
103
  return py
104
 
105
 
 
145
  parser.add_argument('--half', action='store_true', help='use FP16 half-precision inference')
146
  parser.add_argument('--test', action='store_true', help='test exports only')
147
  parser.add_argument('--pt-only', action='store_true', help='test PyTorch only')
148
+ parser.add_argument('--hard-fail', nargs='?', const=True, default=False, help='Exception on error or < min metric')
149
  opt = parser.parse_args()
150
  opt.data = check_yaml(opt.data) # check YAML
151
  print_args(vars(opt))
utils/callbacks.py CHANGED
@@ -3,6 +3,8 @@
3
  Callback utils
4
  """
5
 
 
 
6
 
7
  class Callbacks:
8
  """"
@@ -55,17 +57,20 @@ class Callbacks:
55
  """
56
  return self._callbacks[hook] if hook else self._callbacks
57
 
58
- def run(self, hook, *args, **kwargs):
59
  """
60
- Loop through the registered actions and fire all callbacks
61
 
62
  Args:
63
  hook: The name of the hook to check, defaults to all
64
  args: Arguments to receive from YOLOv5
 
65
  kwargs: Keyword Arguments to receive from YOLOv5
66
  """
67
 
68
  assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
69
-
70
  for logger in self._callbacks[hook]:
71
- logger['callback'](*args, **kwargs)
 
 
 
 
3
  Callback utils
4
  """
5
 
6
+ import threading
7
+
8
 
9
  class Callbacks:
10
  """"
 
57
  """
58
  return self._callbacks[hook] if hook else self._callbacks
59
 
60
+ def run(self, hook, *args, thread=False, **kwargs):
61
  """
62
+ Loop through the registered actions and fire all callbacks on main thread
63
 
64
  Args:
65
  hook: The name of the hook to check, defaults to all
66
  args: Arguments to receive from YOLOv5
67
+ thread: (boolean) Run callbacks in daemon thread
68
  kwargs: Keyword Arguments to receive from YOLOv5
69
  """
70
 
71
  assert hook in self._callbacks, f"hook '{hook}' not found in callbacks {self._callbacks}"
 
72
  for logger in self._callbacks[hook]:
73
+ if thread:
74
+ threading.Thread(target=logger['callback'], args=args, kwargs=kwargs, daemon=True).start()
75
+ else:
76
+ logger['callback'](*args, **kwargs)
utils/dataloaders.py CHANGED
@@ -22,22 +22,25 @@ from zipfile import ZipFile
22
  import numpy as np
23
  import torch
24
  import torch.nn.functional as F
 
25
  import yaml
26
  from PIL import ExifTags, Image, ImageOps
27
  from torch.utils.data import DataLoader, Dataset, dataloader, distributed
28
  from tqdm import tqdm
29
 
30
- from utils.augmentations import Albumentations, augment_hsv, copy_paste, letterbox, mixup, random_perspective
 
31
  from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
32
  cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
33
  from utils.torch_utils import torch_distributed_zero_first
34
 
35
  # Parameters
36
- HELP_URL = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
37
- IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp' # include image suffixes
38
  VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
39
  BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
40
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
 
41
 
42
  # Get orientation exif tag
43
  for orientation in ExifTags.TAGS.keys():
@@ -81,7 +84,7 @@ def exif_transpose(image):
81
  5: Image.TRANSPOSE,
82
  6: Image.ROTATE_270,
83
  7: Image.TRANSVERSE,
84
- 8: Image.ROTATE_90,}.get(orientation)
85
  if method is not None:
86
  image = image.transpose(method)
87
  del exif[0x0112]
@@ -142,7 +145,7 @@ def create_dataloader(path,
142
  shuffle=shuffle and sampler is None,
143
  num_workers=nw,
144
  sampler=sampler,
145
- pin_memory=True,
146
  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
147
  worker_init_fn=seed_worker,
148
  generator=generator), dataset
@@ -184,7 +187,7 @@ class _RepeatSampler:
184
 
185
  class LoadImages:
186
  # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
187
- def __init__(self, path, img_size=640, stride=32, auto=True):
188
  files = []
189
  for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
190
  p = str(Path(p).resolve())
@@ -208,8 +211,10 @@ class LoadImages:
208
  self.video_flag = [False] * ni + [True] * nv
209
  self.mode = 'image'
210
  self.auto = auto
 
 
211
  if any(videos):
212
- self.new_video(videos[0]) # new video
213
  else:
214
  self.cap = None
215
  assert self.nf > 0, f'No images or videos found in {p}. ' \
@@ -227,103 +232,71 @@ class LoadImages:
227
  if self.video_flag[self.count]:
228
  # Read video
229
  self.mode = 'video'
230
- ret_val, img0 = self.cap.read()
 
231
  while not ret_val:
232
  self.count += 1
233
  self.cap.release()
234
  if self.count == self.nf: # last video
235
  raise StopIteration
236
  path = self.files[self.count]
237
- self.new_video(path)
238
- ret_val, img0 = self.cap.read()
239
 
240
  self.frame += 1
 
241
  s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
242
 
243
  else:
244
  # Read image
245
  self.count += 1
246
- img0 = cv2.imread(path) # BGR
247
- assert img0 is not None, f'Image Not Found {path}'
248
  s = f'image {self.count}/{self.nf} {path}: '
249
 
250
- # Padded resize
251
- img = letterbox(img0, self.img_size, stride=self.stride, auto=self.auto)[0]
252
-
253
- # Convert
254
- img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
255
- img = np.ascontiguousarray(img)
256
 
257
- return path, img, img0, self.cap, s
258
 
259
- def new_video(self, path):
 
260
  self.frame = 0
261
  self.cap = cv2.VideoCapture(path)
262
- self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
 
 
 
 
 
 
 
 
 
 
263
 
264
  def __len__(self):
265
  return self.nf # number of files
266
 
267
 
268
- class LoadWebcam: # for inference
269
- # YOLOv5 local webcam dataloader, i.e. `python detect.py --source 0`
270
- def __init__(self, pipe='0', img_size=640, stride=32):
271
- self.img_size = img_size
272
- self.stride = stride
273
- self.pipe = eval(pipe) if pipe.isnumeric() else pipe
274
- self.cap = cv2.VideoCapture(self.pipe) # video capture object
275
- self.cap.set(cv2.CAP_PROP_BUFFERSIZE, 3) # set buffer size
276
-
277
- def __iter__(self):
278
- self.count = -1
279
- return self
280
-
281
- def __next__(self):
282
- self.count += 1
283
- if cv2.waitKey(1) == ord('q'): # q to quit
284
- self.cap.release()
285
- cv2.destroyAllWindows()
286
- raise StopIteration
287
-
288
- # Read frame
289
- ret_val, img0 = self.cap.read()
290
- img0 = cv2.flip(img0, 1) # flip left-right
291
-
292
- # Print
293
- assert ret_val, f'Camera Error {self.pipe}'
294
- img_path = 'webcam.jpg'
295
- s = f'webcam {self.count}: '
296
-
297
- # Padded resize
298
- img = letterbox(img0, self.img_size, stride=self.stride)[0]
299
-
300
- # Convert
301
- img = img.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
302
- img = np.ascontiguousarray(img)
303
-
304
- return img_path, img, img0, None, s
305
-
306
- def __len__(self):
307
- return 0
308
-
309
-
310
  class LoadStreams:
311
  # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
312
- def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True):
 
313
  self.mode = 'stream'
314
  self.img_size = img_size
315
  self.stride = stride
316
-
317
- if os.path.isfile(sources):
318
- with open(sources) as f:
319
- sources = [x.strip() for x in f.read().strip().splitlines() if len(x.strip())]
320
- else:
321
- sources = [sources]
322
-
323
  n = len(sources)
324
- self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
325
  self.sources = [clean_str(x) for x in sources] # clean source names for later
326
- self.auto = auto
327
  for i, s in enumerate(sources): # index, source
328
  # Start thread to read frames from video stream
329
  st = f'{i + 1}/{n}: {s}... '
@@ -350,19 +323,20 @@ class LoadStreams:
350
  LOGGER.info('') # newline
351
 
352
  # check for common shapes
353
- s = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0].shape for x in self.imgs])
354
  self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
 
 
355
  if not self.rect:
356
  LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
357
 
358
  def update(self, i, cap, stream):
359
  # Read stream `i` frames in daemon thread
360
- n, f, read = 0, self.frames[i], 1 # frame number, frame array, inference every 'read' frame
361
  while cap.isOpened() and n < f:
362
  n += 1
363
- # _, self.imgs[index] = cap.read()
364
- cap.grab()
365
- if n % read == 0:
366
  success, im = cap.retrieve()
367
  if success:
368
  self.imgs[i] = im
@@ -382,18 +356,15 @@ class LoadStreams:
382
  cv2.destroyAllWindows()
383
  raise StopIteration
384
 
385
- # Letterbox
386
- img0 = self.imgs.copy()
387
- img = [letterbox(x, self.img_size, stride=self.stride, auto=self.rect and self.auto)[0] for x in img0]
388
-
389
- # Stack
390
- img = np.stack(img, 0)
391
-
392
- # Convert
393
- img = img[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
394
- img = np.ascontiguousarray(img)
395
 
396
- return self.sources, img, img0, None, ''
397
 
398
  def __len__(self):
399
  return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
@@ -453,7 +424,7 @@ class LoadImagesAndLabels(Dataset):
453
  # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
454
  assert self.im_files, f'{prefix}No images found'
455
  except Exception as e:
456
- raise Exception(f'{prefix}Error loading data from {path}: {e}\nSee {HELP_URL}')
457
 
458
  # Check cache
459
  self.label_files = img2label_paths(self.im_files) # labels
@@ -472,11 +443,13 @@ class LoadImagesAndLabels(Dataset):
472
  tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
473
  if cache['msgs']:
474
  LOGGER.info('\n'.join(cache['msgs'])) # display warnings
475
- assert nf > 0 or not augment, f'{prefix}No labels in {cache_path}. Can not train without labels. See {HELP_URL}'
476
 
477
  # Read cache
478
  [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
479
  labels, shapes, self.segments = zip(*cache.values())
 
 
480
  self.labels = list(labels)
481
  self.shapes = np.array(shapes)
482
  self.im_files = list(cache.keys()) # update
@@ -569,7 +542,7 @@ class LoadImagesAndLabels(Dataset):
569
  if msgs:
570
  LOGGER.info('\n'.join(msgs))
571
  if nf == 0:
572
- LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. See {HELP_URL}')
573
  x['hash'] = get_hash(self.label_files + self.im_files)
574
  x['results'] = nf, nm, ne, nc, len(self.im_files)
575
  x['msgs'] = msgs # warnings
@@ -831,7 +804,7 @@ class LoadImagesAndLabels(Dataset):
831
 
832
  @staticmethod
833
  def collate_fn4(batch):
834
- img, label, path, shapes = zip(*batch) # transposed
835
  n = len(shapes) // 4
836
  im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
837
 
@@ -841,13 +814,13 @@ class LoadImagesAndLabels(Dataset):
841
  for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
842
  i *= 4
843
  if random.random() < 0.5:
844
- im = F.interpolate(img[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
845
- align_corners=False)[0].type(img[i].type())
846
  lb = label[i]
847
  else:
848
- im = torch.cat((torch.cat((img[i], img[i + 1]), 1), torch.cat((img[i + 2], img[i + 3]), 1)), 2)
849
  lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
850
- im4.append(im)
851
  label4.append(lb)
852
 
853
  for i, lb in enumerate(label4):
@@ -870,7 +843,7 @@ def flatten_recursive(path=DATASETS_DIR / 'coco128'):
870
  def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
871
  # Convert detection dataset into classification dataset, with one directory per class
872
  path = Path(path) # images dir
873
- shutil.rmtree(path / 'classifier') if (path / 'classifier').is_dir() else None # remove existing
874
  files = list(path.rglob('*.*'))
875
  n = len(files) # number of files
876
  for im_file in tqdm(files, total=n):
@@ -916,7 +889,9 @@ def autosplit(path=DATASETS_DIR / 'coco128/images', weights=(0.9, 0.1, 0.0), ann
916
  indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
917
 
918
  txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
919
- [(path.parent / x).unlink(missing_ok=True) for x in txt] # remove existing
 
 
920
 
921
  print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
922
  for i, img in tqdm(zip(indices, files), total=n):
@@ -962,7 +937,7 @@ def verify_image_label(args):
962
  if len(i) < nl: # duplicate row check
963
  lb = lb[i] # remove duplicates
964
  if segments:
965
- segments = segments[i]
966
  msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
967
  else:
968
  ne = 1 # label empty
@@ -1002,7 +977,7 @@ class HUBDatasetStats():
1002
  self.hub_dir = Path(data['path'] + '-hub')
1003
  self.im_dir = self.hub_dir / 'images'
1004
  self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
1005
- self.stats = {'nc': data['nc'], 'names': data['names']} # statistics dictionary
1006
  self.data = data
1007
 
1008
  @staticmethod
@@ -1090,3 +1065,65 @@ class HUBDatasetStats():
1090
  pass
1091
  print(f'Done. All images saved to {self.im_dir}')
1092
  return self.im_dir
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  import numpy as np
23
  import torch
24
  import torch.nn.functional as F
25
+ import torchvision
26
  import yaml
27
  from PIL import ExifTags, Image, ImageOps
28
  from torch.utils.data import DataLoader, Dataset, dataloader, distributed
29
  from tqdm import tqdm
30
 
31
+ from utils.augmentations import (Albumentations, augment_hsv, classify_albumentations, classify_transforms, copy_paste,
32
+ letterbox, mixup, random_perspective)
33
  from utils.general import (DATASETS_DIR, LOGGER, NUM_THREADS, check_dataset, check_requirements, check_yaml, clean_str,
34
  cv2, is_colab, is_kaggle, segments2boxes, xyn2xy, xywh2xyxy, xywhn2xyxy, xyxy2xywhn)
35
  from utils.torch_utils import torch_distributed_zero_first
36
 
37
  # Parameters
38
+ HELP_URL = 'See https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
39
+ IMG_FORMATS = 'bmp', 'dng', 'jpeg', 'jpg', 'mpo', 'png', 'tif', 'tiff', 'webp', 'pfm' # include image suffixes
40
  VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'ts', 'wmv' # include video suffixes
41
  BAR_FORMAT = '{l_bar}{bar:10}{r_bar}{bar:-10b}' # tqdm bar format
42
  LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) # https://pytorch.org/docs/stable/elastic/run.html
43
+ PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders
44
 
45
  # Get orientation exif tag
46
  for orientation in ExifTags.TAGS.keys():
 
84
  5: Image.TRANSPOSE,
85
  6: Image.ROTATE_270,
86
  7: Image.TRANSVERSE,
87
+ 8: Image.ROTATE_90}.get(orientation)
88
  if method is not None:
89
  image = image.transpose(method)
90
  del exif[0x0112]
 
145
  shuffle=shuffle and sampler is None,
146
  num_workers=nw,
147
  sampler=sampler,
148
+ pin_memory=PIN_MEMORY,
149
  collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn,
150
  worker_init_fn=seed_worker,
151
  generator=generator), dataset
 
187
 
188
  class LoadImages:
189
  # YOLOv5 image/video dataloader, i.e. `python detect.py --source image.jpg/vid.mp4`
190
+ def __init__(self, path, img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
191
  files = []
192
  for p in sorted(path) if isinstance(path, (list, tuple)) else [path]:
193
  p = str(Path(p).resolve())
 
211
  self.video_flag = [False] * ni + [True] * nv
212
  self.mode = 'image'
213
  self.auto = auto
214
+ self.transforms = transforms # optional
215
+ self.vid_stride = vid_stride # video frame-rate stride
216
  if any(videos):
217
+ self._new_video(videos[0]) # new video
218
  else:
219
  self.cap = None
220
  assert self.nf > 0, f'No images or videos found in {p}. ' \
 
232
  if self.video_flag[self.count]:
233
  # Read video
234
  self.mode = 'video'
235
+ ret_val, im0 = self.cap.read()
236
+ self.cap.set(cv2.CAP_PROP_POS_FRAMES, self.vid_stride * (self.frame + 1)) # read at vid_stride
237
  while not ret_val:
238
  self.count += 1
239
  self.cap.release()
240
  if self.count == self.nf: # last video
241
  raise StopIteration
242
  path = self.files[self.count]
243
+ self._new_video(path)
244
+ ret_val, im0 = self.cap.read()
245
 
246
  self.frame += 1
247
+ # im0 = self._cv2_rotate(im0) # for use if cv2 autorotation is False
248
  s = f'video {self.count + 1}/{self.nf} ({self.frame}/{self.frames}) {path}: '
249
 
250
  else:
251
  # Read image
252
  self.count += 1
253
+ im0 = cv2.imread(path) # BGR
254
+ assert im0 is not None, f'Image Not Found {path}'
255
  s = f'image {self.count}/{self.nf} {path}: '
256
 
257
+ if self.transforms:
258
+ im = self.transforms(im0) # transforms
259
+ else:
260
+ im = letterbox(im0, self.img_size, stride=self.stride, auto=self.auto)[0] # padded resize
261
+ im = im.transpose((2, 0, 1))[::-1] # HWC to CHW, BGR to RGB
262
+ im = np.ascontiguousarray(im) # contiguous
263
 
264
+ return path, im, im0, self.cap, s
265
 
266
+ def _new_video(self, path):
267
+ # Create a new video capture object
268
  self.frame = 0
269
  self.cap = cv2.VideoCapture(path)
270
+ self.frames = int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT) / self.vid_stride)
271
+ self.orientation = int(self.cap.get(cv2.CAP_PROP_ORIENTATION_META)) # rotation degrees
272
+ # self.cap.set(cv2.CAP_PROP_ORIENTATION_AUTO, 0) # disable https://github.com/ultralytics/yolov5/issues/8493
273
+
274
+ def _cv2_rotate(self, im):
275
+ # Rotate a cv2 video manually
276
+ if self.orientation == 0:
277
+ return cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)
278
+ elif self.orientation == 180:
279
+ return cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)
280
+ elif self.orientation == 90:
281
+ return cv2.rotate(im, cv2.ROTATE_180)
282
+ return im
283
 
284
  def __len__(self):
285
  return self.nf # number of files
286
 
287
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  class LoadStreams:
289
  # YOLOv5 streamloader, i.e. `python detect.py --source 'rtsp://example.com/media.mp4' # RTSP, RTMP, HTTP streams`
290
+ def __init__(self, sources='streams.txt', img_size=640, stride=32, auto=True, transforms=None, vid_stride=1):
291
+ torch.backends.cudnn.benchmark = True # faster for fixed-size inference
292
  self.mode = 'stream'
293
  self.img_size = img_size
294
  self.stride = stride
295
+ self.vid_stride = vid_stride # video frame-rate stride
296
+ sources = Path(sources).read_text().rsplit() if Path(sources).is_file() else [sources]
 
 
 
 
 
297
  n = len(sources)
 
298
  self.sources = [clean_str(x) for x in sources] # clean source names for later
299
+ self.imgs, self.fps, self.frames, self.threads = [None] * n, [0] * n, [0] * n, [None] * n
300
  for i, s in enumerate(sources): # index, source
301
  # Start thread to read frames from video stream
302
  st = f'{i + 1}/{n}: {s}... '
 
323
  LOGGER.info('') # newline
324
 
325
  # check for common shapes
326
+ s = np.stack([letterbox(x, img_size, stride=stride, auto=auto)[0].shape for x in self.imgs])
327
  self.rect = np.unique(s, axis=0).shape[0] == 1 # rect inference if all shapes equal
328
+ self.auto = auto and self.rect
329
+ self.transforms = transforms # optional
330
  if not self.rect:
331
  LOGGER.warning('WARNING: Stream shapes differ. For optimal performance supply similarly-shaped streams.')
332
 
333
  def update(self, i, cap, stream):
334
  # Read stream `i` frames in daemon thread
335
+ n, f = 0, self.frames[i] # frame number, frame array
336
  while cap.isOpened() and n < f:
337
  n += 1
338
+ cap.grab() # .read() = .grab() followed by .retrieve()
339
+ if n % self.vid_stride == 0:
 
340
  success, im = cap.retrieve()
341
  if success:
342
  self.imgs[i] = im
 
356
  cv2.destroyAllWindows()
357
  raise StopIteration
358
 
359
+ im0 = self.imgs.copy()
360
+ if self.transforms:
361
+ im = np.stack([self.transforms(x) for x in im0]) # transforms
362
+ else:
363
+ im = np.stack([letterbox(x, self.img_size, stride=self.stride, auto=self.auto)[0] for x in im0]) # resize
364
+ im = im[..., ::-1].transpose((0, 3, 1, 2)) # BGR to RGB, BHWC to BCHW
365
+ im = np.ascontiguousarray(im) # contiguous
 
 
 
366
 
367
+ return self.sources, im, im0, None, ''
368
 
369
  def __len__(self):
370
  return len(self.sources) # 1E12 frames = 32 streams at 30 FPS for 30 years
 
424
  # self.img_files = sorted([x for x in f if x.suffix[1:].lower() in IMG_FORMATS]) # pathlib
425
  assert self.im_files, f'{prefix}No images found'
426
  except Exception as e:
427
+ raise Exception(f'{prefix}Error loading data from {path}: {e}\n{HELP_URL}')
428
 
429
  # Check cache
430
  self.label_files = img2label_paths(self.im_files) # labels
 
443
  tqdm(None, desc=prefix + d, total=n, initial=n, bar_format=BAR_FORMAT) # display cache results
444
  if cache['msgs']:
445
  LOGGER.info('\n'.join(cache['msgs'])) # display warnings
446
+ assert nf > 0 or not augment, f'{prefix}No labels found in {cache_path}, can not start training. {HELP_URL}'
447
 
448
  # Read cache
449
  [cache.pop(k) for k in ('hash', 'version', 'msgs')] # remove items
450
  labels, shapes, self.segments = zip(*cache.values())
451
+ nl = len(np.concatenate(labels, 0)) # number of labels
452
+ assert nl > 0 or not augment, f'{prefix}All labels empty in {cache_path}, can not start training. {HELP_URL}'
453
  self.labels = list(labels)
454
  self.shapes = np.array(shapes)
455
  self.im_files = list(cache.keys()) # update
 
542
  if msgs:
543
  LOGGER.info('\n'.join(msgs))
544
  if nf == 0:
545
+ LOGGER.warning(f'{prefix}WARNING: No labels found in {path}. {HELP_URL}')
546
  x['hash'] = get_hash(self.label_files + self.im_files)
547
  x['results'] = nf, nm, ne, nc, len(self.im_files)
548
  x['msgs'] = msgs # warnings
 
804
 
805
  @staticmethod
806
  def collate_fn4(batch):
807
+ im, label, path, shapes = zip(*batch) # transposed
808
  n = len(shapes) // 4
809
  im4, label4, path4, shapes4 = [], [], path[:n], shapes[:n]
810
 
 
814
  for i in range(n): # zidane torch.zeros(16,3,720,1280) # BCHW
815
  i *= 4
816
  if random.random() < 0.5:
817
+ im1 = F.interpolate(im[i].unsqueeze(0).float(), scale_factor=2.0, mode='bilinear',
818
+ align_corners=False)[0].type(im[i].type())
819
  lb = label[i]
820
  else:
821
+ im1 = torch.cat((torch.cat((im[i], im[i + 1]), 1), torch.cat((im[i + 2], im[i + 3]), 1)), 2)
822
  lb = torch.cat((label[i], label[i + 1] + ho, label[i + 2] + wo, label[i + 3] + ho + wo), 0) * s
823
+ im4.append(im1)
824
  label4.append(lb)
825
 
826
  for i, lb in enumerate(label4):
 
843
  def extract_boxes(path=DATASETS_DIR / 'coco128'): # from utils.dataloaders import *; extract_boxes()
844
  # Convert detection dataset into classification dataset, with one directory per class
845
  path = Path(path) # images dir
846
+ shutil.rmtree(path / 'classification') if (path / 'classification').is_dir() else None # remove existing
847
  files = list(path.rglob('*.*'))
848
  n = len(files) # number of files
849
  for im_file in tqdm(files, total=n):
 
889
  indices = random.choices([0, 1, 2], weights=weights, k=n) # assign each image to a split
890
 
891
  txt = ['autosplit_train.txt', 'autosplit_val.txt', 'autosplit_test.txt'] # 3 txt files
892
+ for x in txt:
893
+ if (path.parent / x).exists():
894
+ (path.parent / x).unlink() # remove existing
895
 
896
  print(f'Autosplitting images from {path}' + ', using *.txt labeled images only' * annotated_only)
897
  for i, img in tqdm(zip(indices, files), total=n):
 
937
  if len(i) < nl: # duplicate row check
938
  lb = lb[i] # remove duplicates
939
  if segments:
940
+ segments = [segments[x] for x in i]
941
  msg = f'{prefix}WARNING: {im_file}: {nl - len(i)} duplicate labels removed'
942
  else:
943
  ne = 1 # label empty
 
977
  self.hub_dir = Path(data['path'] + '-hub')
978
  self.im_dir = self.hub_dir / 'images'
979
  self.im_dir.mkdir(parents=True, exist_ok=True) # makes /images
980
+ self.stats = {'nc': data['nc'], 'names': list(data['names'].values())} # statistics dictionary
981
  self.data = data
982
 
983
  @staticmethod
 
1065
  pass
1066
  print(f'Done. All images saved to {self.im_dir}')
1067
  return self.im_dir
1068
+
1069
+
1070
+ # Classification dataloaders -------------------------------------------------------------------------------------------
1071
+ class ClassificationDataset(torchvision.datasets.ImageFolder):
1072
+ """
1073
+ YOLOv5 Classification Dataset.
1074
+ Arguments
1075
+ root: Dataset path
1076
+ transform: torchvision transforms, used by default
1077
+ album_transform: Albumentations transforms, used if installed
1078
+ """
1079
+
1080
+ def __init__(self, root, augment, imgsz, cache=False):
1081
+ super().__init__(root=root)
1082
+ self.torch_transforms = classify_transforms(imgsz)
1083
+ self.album_transforms = classify_albumentations(augment, imgsz) if augment else None
1084
+ self.cache_ram = cache is True or cache == 'ram'
1085
+ self.cache_disk = cache == 'disk'
1086
+ self.samples = [list(x) + [Path(x[0]).with_suffix('.npy'), None] for x in self.samples] # file, index, npy, im
1087
+
1088
+ def __getitem__(self, i):
1089
+ f, j, fn, im = self.samples[i] # filename, index, filename.with_suffix('.npy'), image
1090
+ if self.cache_ram and im is None:
1091
+ im = self.samples[i][3] = cv2.imread(f)
1092
+ elif self.cache_disk:
1093
+ if not fn.exists(): # load npy
1094
+ np.save(fn.as_posix(), cv2.imread(f))
1095
+ im = np.load(fn)
1096
+ else: # read image
1097
+ im = cv2.imread(f) # BGR
1098
+ if self.album_transforms:
1099
+ sample = self.album_transforms(image=cv2.cvtColor(im, cv2.COLOR_BGR2RGB))["image"]
1100
+ else:
1101
+ sample = self.torch_transforms(im)
1102
+ return sample, j
1103
+
1104
+
1105
+ def create_classification_dataloader(path,
1106
+ imgsz=224,
1107
+ batch_size=16,
1108
+ augment=True,
1109
+ cache=False,
1110
+ rank=-1,
1111
+ workers=8,
1112
+ shuffle=True):
1113
+ # Returns Dataloader object to be used with YOLOv5 Classifier
1114
+ with torch_distributed_zero_first(rank): # init dataset *.cache only once if DDP
1115
+ dataset = ClassificationDataset(root=path, imgsz=imgsz, augment=augment, cache=cache)
1116
+ batch_size = min(batch_size, len(dataset))
1117
+ nd = torch.cuda.device_count()
1118
+ nw = min([os.cpu_count() // max(nd, 1), batch_size if batch_size > 1 else 0, workers])
1119
+ sampler = None if rank == -1 else distributed.DistributedSampler(dataset, shuffle=shuffle)
1120
+ generator = torch.Generator()
1121
+ generator.manual_seed(0)
1122
+ return InfiniteDataLoader(dataset,
1123
+ batch_size=batch_size,
1124
+ shuffle=shuffle and sampler is None,
1125
+ num_workers=nw,
1126
+ sampler=sampler,
1127
+ pin_memory=PIN_MEMORY,
1128
+ worker_init_fn=seed_worker,
1129
+ generator=generator) # or DataLoader(persistent_workers=True)
utils/downloads.py CHANGED
@@ -33,6 +33,12 @@ def gsutil_getsize(url=''):
33
  return eval(s.split(' ')[0]) if len(s) else 0 # bytes
34
 
35
 
 
 
 
 
 
 
36
  def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
37
  # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
38
  from utils.general import LOGGER
@@ -44,24 +50,26 @@ def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
44
  torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
45
  assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
46
  except Exception as e: # url2
47
- file.unlink(missing_ok=True) # remove partial downloads
 
48
  LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
49
- os.system(f"curl -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
50
  finally:
51
  if not file.exists() or file.stat().st_size < min_bytes: # check
52
- file.unlink(missing_ok=True) # remove partial downloads
 
53
  LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
54
  LOGGER.info('')
55
 
56
 
57
- def attempt_download(file, repo='ultralytics/yolov5', release='v6.1'):
58
- # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.1', etc.
59
  from utils.general import LOGGER
60
 
61
  def github_assets(repository, version='latest'):
62
- # Return GitHub repo tag (i.e. 'v6.1') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
63
  if version != 'latest':
64
- version = f'tags/{version}' # i.e. tags/v6.1
65
  response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
66
  return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
67
 
@@ -112,8 +120,10 @@ def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
112
  file = Path(file)
113
  cookie = Path('cookie') # gdrive cookie
114
  print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
115
- file.unlink(missing_ok=True) # remove existing file
116
- cookie.unlink(missing_ok=True) # remove existing cookie
 
 
117
 
118
  # Attempt file download
119
  out = "NUL" if platform.system() == "Windows" else "/dev/null"
@@ -123,11 +133,13 @@ def gdrive_download(id='16TiPfZj7htmTyhntwcZyEEAejOUxuT6m', file='tmp.zip'):
123
  else: # small file
124
  s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
125
  r = os.system(s) # execute, capture return
126
- cookie.unlink(missing_ok=True) # remove existing cookie
 
127
 
128
  # Error check
129
  if r != 0:
130
- file.unlink(missing_ok=True) # remove partial
 
131
  print('Download error ') # raise Exception('Download error')
132
  return r
133
 
 
33
  return eval(s.split(' ')[0]) if len(s) else 0 # bytes
34
 
35
 
36
+ def url_getsize(url='https://ultralytics.com/images/bus.jpg'):
37
+ # Return downloadable file size in bytes
38
+ response = requests.head(url, allow_redirects=True)
39
+ return int(response.headers.get('content-length', -1))
40
+
41
+
42
  def safe_download(file, url, url2=None, min_bytes=1E0, error_msg=''):
43
  # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
44
  from utils.general import LOGGER
 
50
  torch.hub.download_url_to_file(url, str(file), progress=LOGGER.level <= logging.INFO)
51
  assert file.exists() and file.stat().st_size > min_bytes, assert_msg # check
52
  except Exception as e: # url2
53
+ if file.exists():
54
+ file.unlink() # remove partial downloads
55
  LOGGER.info(f'ERROR: {e}\nRe-attempting {url2 or url} to {file}...')
56
+ os.system(f"curl -# -L '{url2 or url}' -o '{file}' --retry 3 -C -") # curl download, retry and resume on fail
57
  finally:
58
  if not file.exists() or file.stat().st_size < min_bytes: # check
59
+ if file.exists():
60
+ file.unlink() # remove partial downloads
61
  LOGGER.info(f"ERROR: {assert_msg}\n{error_msg}")
62
  LOGGER.info('')
63
 
64
 
65
+ def attempt_download(file, repo='ultralytics/yolov5', release='v6.2'):
66
+ # Attempt file download from GitHub release assets if not found locally. release = 'latest', 'v6.2', etc.
67
  from utils.general import LOGGER
68
 
69
  def github_assets(repository, version='latest'):
70
+ # Return GitHub repo tag (i.e. 'v6.2') and assets (i.e. ['yolov5s.pt', 'yolov5m.pt', ...])
71
  if version != 'latest':
72
+ version = f'tags/{version}' # i.e. tags/v6.2
73
  response = requests.get(f'https://api.github.com/repos/{repository}/releases/{version}').json() # github api
74
  return response['tag_name'], [x['name'] for x in response['assets']] # tag, assets
75
 
 
120
  file = Path(file)
121
  cookie = Path('cookie') # gdrive cookie
122
  print(f'Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ', end='')
123
+ if file.exists():
124
+ file.unlink() # remove existing file
125
+ if cookie.exists():
126
+ cookie.unlink() # remove existing cookie
127
 
128
  # Attempt file download
129
  out = "NUL" if platform.system() == "Windows" else "/dev/null"
 
133
  else: # small file
134
  s = f'curl -s -L -o {file} "drive.google.com/uc?export=download&id={id}"'
135
  r = os.system(s) # execute, capture return
136
+ if cookie.exists():
137
+ cookie.unlink() # remove existing cookie
138
 
139
  # Error check
140
  if r != 0:
141
+ if file.exists():
142
+ file.unlink() # remove partial
143
  print('Download error ') # raise Exception('Download error')
144
  return r
145
 
utils/general.py CHANGED
@@ -15,7 +15,6 @@ import re
15
  import shutil
16
  import signal
17
  import sys
18
- import threading
19
  import time
20
  import urllib
21
  from datetime import datetime
@@ -34,6 +33,7 @@ import torch
34
  import torchvision
35
  import yaml
36
 
 
37
  from utils.downloads import gsutil_getsize
38
  from utils.metrics import box_iou, fitness
39
 
@@ -56,13 +56,35 @@ os.environ['NUMEXPR_MAX_THREADS'] = str(NUM_THREADS) # NumExpr max threads
56
  os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
57
 
58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
  def is_kaggle():
60
  # Is environment a Kaggle Notebook?
61
- try:
62
- assert os.environ.get('PWD') == '/kaggle/working'
63
- assert os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
 
 
 
64
  return True
65
- except AssertionError:
 
 
 
66
  return False
67
 
68
 
@@ -82,7 +104,7 @@ def is_writeable(dir, test=False):
82
 
83
  def set_logging(name=None, verbose=VERBOSE):
84
  # Sets level and returns logger
85
- if is_kaggle():
86
  for h in logging.root.handlers:
87
  logging.root.removeHandler(h) # remove all handlers associated with the root logger object
88
  rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
@@ -119,16 +141,27 @@ CONFIG_DIR = user_config_dir() # Ultralytics settings dir
119
 
120
 
121
  class Profile(contextlib.ContextDecorator):
122
- # Usage: @Profile() decorator or 'with Profile():' context manager
 
 
 
 
123
  def __enter__(self):
124
- self.start = time.time()
 
125
 
126
  def __exit__(self, type, value, traceback):
127
- print(f'Profile results: {time.time() - self.start:.5f}s')
 
 
 
 
 
 
128
 
129
 
130
  class Timeout(contextlib.ContextDecorator):
131
- # Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
132
  def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
133
  self.seconds = int(seconds)
134
  self.timeout_message = timeout_msg
@@ -162,64 +195,50 @@ class WorkingDirectory(contextlib.ContextDecorator):
162
  os.chdir(self.cwd)
163
 
164
 
165
- def try_except(func):
166
- # try-except function. Usage: @try_except decorator
167
- def handler(*args, **kwargs):
168
- try:
169
- func(*args, **kwargs)
170
- except Exception as e:
171
- print(e)
172
-
173
- return handler
174
-
175
-
176
- def threaded(func):
177
- # Multi-threads a target function and returns thread. Usage: @threaded decorator
178
- def wrapper(*args, **kwargs):
179
- thread = threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True)
180
- thread.start()
181
- return thread
182
-
183
- return wrapper
184
-
185
-
186
  def methods(instance):
187
  # Get class/instance methods
188
  return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
189
 
190
 
191
- def print_args(args: Optional[dict] = None, show_file=True, show_fcn=False):
192
  # Print function arguments (optional args dict)
193
  x = inspect.currentframe().f_back # previous frame
194
- file, _, fcn, _, _ = inspect.getframeinfo(x)
195
  if args is None: # get args automatically
196
  args, _, _, frm = inspect.getargvalues(x)
197
  args = {k: v for k, v in frm.items() if k in args}
198
- s = (f'{Path(file).stem}: ' if show_file else '') + (f'{fcn}: ' if show_fcn else '')
 
 
 
 
199
  LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
200
 
201
 
202
  def init_seeds(seed=0, deterministic=False):
203
  # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
204
- # cudnn seed 0 settings are slower and more reproducible, else faster and less reproducible
205
- import torch.backends.cudnn as cudnn
206
-
207
- if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
208
- torch.use_deterministic_algorithms(True)
209
- os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
210
- os.environ['PYTHONHASHSEED'] = str(seed)
211
-
212
  random.seed(seed)
213
  np.random.seed(seed)
214
  torch.manual_seed(seed)
215
- cudnn.benchmark, cudnn.deterministic = (False, True) if seed == 0 else (True, False)
216
  torch.cuda.manual_seed(seed)
217
  torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
 
 
 
 
 
 
218
 
219
 
220
  def intersect_dicts(da, db, exclude=()):
221
  # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
222
- return {k: v for k, v in da.items() if k in db and not any(x in k for x in exclude) and v.shape == db[k].shape}
 
 
 
 
 
 
223
 
224
 
225
  def get_latest_run(search_dir='.'):
@@ -228,42 +247,6 @@ def get_latest_run(search_dir='.'):
228
  return max(last_list, key=os.path.getctime) if last_list else ''
229
 
230
 
231
- def is_docker() -> bool:
232
- """Check if the process runs inside a docker container."""
233
- if Path("/.dockerenv").exists():
234
- return True
235
- try: # check if docker is in control groups
236
- with open("/proc/self/cgroup") as file:
237
- return any("docker" in line for line in file)
238
- except OSError:
239
- return False
240
-
241
-
242
- def is_colab():
243
- # Is environment a Google Colab instance?
244
- try:
245
- import google.colab
246
- return True
247
- except ImportError:
248
- return False
249
-
250
-
251
- def is_pip():
252
- # Is file in a pip package?
253
- return 'site-packages' in Path(__file__).resolve().parts
254
-
255
-
256
- def is_ascii(s=''):
257
- # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
258
- s = str(s) # convert list, tuple, None, etc. to str
259
- return len(s.encode().decode('ascii', 'ignore')) == len(s)
260
-
261
-
262
- def is_chinese(s='人工智能'):
263
- # Is string composed of any Chinese characters?
264
- return bool(re.search('[\u4e00-\u9fff]', str(s)))
265
-
266
-
267
  def emojis(str=''):
268
  # Return platform-dependent emoji-safe version of string
269
  return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
@@ -312,9 +295,9 @@ def git_describe(path=ROOT): # path must be a directory
312
  return ''
313
 
314
 
315
- @try_except
316
  @WorkingDirectory(ROOT)
317
- def check_git_status(repo='ultralytics/yolov5'):
318
  # YOLOv5 status check, recommend 'git pull' if code is out of date
319
  url = f'https://github.com/{repo}'
320
  msg = f', for updates see {url}'
@@ -330,10 +313,10 @@ def check_git_status(repo='ultralytics/yolov5'):
330
  remote = 'ultralytics'
331
  check_output(f'git remote add {remote} {url}', shell=True)
332
  check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
333
- branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
334
- n = int(check_output(f'git rev-list {branch}..{remote}/master --count', shell=True)) # commits behind
335
  if n > 0:
336
- pull = 'git pull' if remote == 'origin' else f'git pull {remote} master'
337
  s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
338
  else:
339
  s += f'up to date with {url} ✅'
@@ -349,17 +332,17 @@ def check_version(current='0.0.0', minimum='0.0.0', name='version ', pinned=Fals
349
  # Check version vs. required version
350
  current, minimum = (pkg.parse_version(x) for x in (current, minimum))
351
  result = (current == minimum) if pinned else (current >= minimum) # bool
352
- s = f'{name}{minimum} required by YOLOv5, but {name}{current} is currently installed' # string
353
  if hard:
354
- assert result, s # assert min requirements met
355
  if verbose and not result:
356
  LOGGER.warning(s)
357
  return result
358
 
359
 
360
- @try_except
361
  def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
362
- # Check installed dependencies meet requirements (pass *.txt file or list of packages)
363
  prefix = colorstr('red', 'bold', 'requirements:')
364
  check_python() # check python version
365
  if isinstance(requirements, (str, Path)): # requirements.txt file
@@ -470,7 +453,7 @@ def check_font(font=FONT, progress=False):
470
  font = Path(font)
471
  file = CONFIG_DIR / font.name
472
  if not font.exists() and not file.exists():
473
- url = "https://ultralytics.com/assets/" + font.name
474
  LOGGER.info(f'Downloading {url} to {file}...')
475
  torch.hub.download_url_to_file(url, str(file), progress=progress)
476
 
@@ -491,11 +474,11 @@ def check_dataset(data, autodownload=True):
491
  data = yaml.safe_load(f) # dictionary
492
 
493
  # Checks
494
- for k in 'train', 'val', 'nc':
495
  assert k in data, f"data.yaml '{k}:' field missing ❌"
496
- if 'names' not in data:
497
- LOGGER.warning("data.yaml 'names:' field missing ⚠️, assigning default names 'class0', 'class1', etc.")
498
- data['names'] = [f'class{i}' for i in range(data['nc'])] # default names
499
 
500
  # Resolve paths
501
  path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
@@ -549,8 +532,8 @@ def check_amp(model):
549
 
550
  prefix = colorstr('AMP: ')
551
  device = next(model.parameters()).device # get model device
552
- if device.type == 'cpu':
553
- return False # AMP disabled on CPU
554
  f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
555
  im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
556
  try:
@@ -563,6 +546,18 @@ def check_amp(model):
563
  return False
564
 
565
 
 
 
 
 
 
 
 
 
 
 
 
 
566
  def url2file(url):
567
  # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
568
  url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
@@ -570,7 +565,7 @@ def url2file(url):
570
 
571
 
572
  def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
573
- # Multi-threaded file download and unzip function, used in data.yaml for autodownload
574
  def download_one(url, dir):
575
  # Download 1 file
576
  success = True
@@ -582,7 +577,8 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
582
  for i in range(retry + 1):
583
  if curl:
584
  s = 'sS' if threads > 1 else '' # silent
585
- r = os.system(f'curl -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
 
586
  success = r == 0
587
  else:
588
  torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
@@ -594,10 +590,12 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
594
  else:
595
  LOGGER.warning(f'Failed to download {url}...')
596
 
597
- if unzip and success and f.suffix in ('.zip', '.gz'):
598
  LOGGER.info(f'Unzipping {f}...')
599
  if f.suffix == '.zip':
600
  ZipFile(f).extractall(path=dir) # unzip
 
 
601
  elif f.suffix == '.gz':
602
  os.system(f'tar xfz {f} --directory {f.parent}') # unzip
603
  if delete:
@@ -607,7 +605,7 @@ def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry
607
  dir.mkdir(parents=True, exist_ok=True) # make directory
608
  if threads > 1:
609
  pool = ThreadPool(threads)
610
- pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multi-threaded
611
  pool.close()
612
  pool.join()
613
  else:
@@ -815,6 +813,9 @@ def non_max_suppression(prediction,
815
  list of detections, on (n,6) tensor per image [xyxy, conf, cls]
816
  """
817
 
 
 
 
818
  bs = prediction.shape[0] # batch size
819
  nc = prediction.shape[2] - 5 # number of classes
820
  xc = prediction[..., 4] > conf_thres # candidates
 
15
  import shutil
16
  import signal
17
  import sys
 
18
  import time
19
  import urllib
20
  from datetime import datetime
 
33
  import torchvision
34
  import yaml
35
 
36
+ from utils import TryExcept
37
  from utils.downloads import gsutil_getsize
38
  from utils.metrics import box_iou, fitness
39
 
 
56
  os.environ['OMP_NUM_THREADS'] = '1' if platform.system() == 'darwin' else str(NUM_THREADS) # OpenMP (PyTorch and SciPy)
57
 
58
 
59
+ def is_ascii(s=''):
60
+ # Is string composed of all ASCII (no UTF) characters? (note str().isascii() introduced in python 3.7)
61
+ s = str(s) # convert list, tuple, None, etc. to str
62
+ return len(s.encode().decode('ascii', 'ignore')) == len(s)
63
+
64
+
65
+ def is_chinese(s='人工智能'):
66
+ # Is string composed of any Chinese characters?
67
+ return bool(re.search('[\u4e00-\u9fff]', str(s)))
68
+
69
+
70
+ def is_colab():
71
+ # Is environment a Google Colab instance?
72
+ return 'COLAB_GPU' in os.environ
73
+
74
+
75
  def is_kaggle():
76
  # Is environment a Kaggle Notebook?
77
+ return os.environ.get('PWD') == '/kaggle/working' and os.environ.get('KAGGLE_URL_BASE') == 'https://www.kaggle.com'
78
+
79
+
80
+ def is_docker() -> bool:
81
+ """Check if the process runs inside a docker container."""
82
+ if Path("/.dockerenv").exists():
83
  return True
84
+ try: # check if docker is in control groups
85
+ with open("/proc/self/cgroup") as file:
86
+ return any("docker" in line for line in file)
87
+ except OSError:
88
  return False
89
 
90
 
 
104
 
105
  def set_logging(name=None, verbose=VERBOSE):
106
  # Sets level and returns logger
107
+ if is_kaggle() or is_colab():
108
  for h in logging.root.handlers:
109
  logging.root.removeHandler(h) # remove all handlers associated with the root logger object
110
  rank = int(os.getenv('RANK', -1)) # rank in world for Multi-GPU trainings
 
141
 
142
 
143
  class Profile(contextlib.ContextDecorator):
144
+ # YOLOv5 Profile class. Usage: @Profile() decorator or 'with Profile():' context manager
145
+ def __init__(self, t=0.0):
146
+ self.t = t
147
+ self.cuda = torch.cuda.is_available()
148
+
149
  def __enter__(self):
150
+ self.start = self.time()
151
+ return self
152
 
153
  def __exit__(self, type, value, traceback):
154
+ self.dt = self.time() - self.start # delta-time
155
+ self.t += self.dt # accumulate dt
156
+
157
+ def time(self):
158
+ if self.cuda:
159
+ torch.cuda.synchronize()
160
+ return time.time()
161
 
162
 
163
  class Timeout(contextlib.ContextDecorator):
164
+ # YOLOv5 Timeout class. Usage: @Timeout(seconds) decorator or 'with Timeout(seconds):' context manager
165
  def __init__(self, seconds, *, timeout_msg='', suppress_timeout_errors=True):
166
  self.seconds = int(seconds)
167
  self.timeout_message = timeout_msg
 
195
  os.chdir(self.cwd)
196
 
197
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  def methods(instance):
199
  # Get class/instance methods
200
  return [f for f in dir(instance) if callable(getattr(instance, f)) and not f.startswith("__")]
201
 
202
 
203
+ def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
204
  # Print function arguments (optional args dict)
205
  x = inspect.currentframe().f_back # previous frame
206
+ file, _, func, _, _ = inspect.getframeinfo(x)
207
  if args is None: # get args automatically
208
  args, _, _, frm = inspect.getargvalues(x)
209
  args = {k: v for k, v in frm.items() if k in args}
210
+ try:
211
+ file = Path(file).resolve().relative_to(ROOT).with_suffix('')
212
+ except ValueError:
213
+ file = Path(file).stem
214
+ s = (f'{file}: ' if show_file else '') + (f'{func}: ' if show_func else '')
215
  LOGGER.info(colorstr(s) + ', '.join(f'{k}={v}' for k, v in args.items()))
216
 
217
 
218
  def init_seeds(seed=0, deterministic=False):
219
  # Initialize random number generator (RNG) seeds https://pytorch.org/docs/stable/notes/randomness.html
 
 
 
 
 
 
 
 
220
  random.seed(seed)
221
  np.random.seed(seed)
222
  torch.manual_seed(seed)
 
223
  torch.cuda.manual_seed(seed)
224
  torch.cuda.manual_seed_all(seed) # for Multi-GPU, exception safe
225
+ torch.backends.cudnn.benchmark = True # for faster training
226
+ if deterministic and check_version(torch.__version__, '1.12.0'): # https://github.com/ultralytics/yolov5/pull/8213
227
+ torch.use_deterministic_algorithms(True)
228
+ torch.backends.cudnn.deterministic = True
229
+ os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
230
+ os.environ['PYTHONHASHSEED'] = str(seed)
231
 
232
 
233
  def intersect_dicts(da, db, exclude=()):
234
  # Dictionary intersection of matching keys and shapes, omitting 'exclude' keys, using da values
235
+ return {k: v for k, v in da.items() if k in db and all(x not in k for x in exclude) and v.shape == db[k].shape}
236
+
237
+
238
+ def get_default_args(func):
239
+ # Get func() default arguments
240
+ signature = inspect.signature(func)
241
+ return {k: v.default for k, v in signature.parameters.items() if v.default is not inspect.Parameter.empty}
242
 
243
 
244
  def get_latest_run(search_dir='.'):
 
247
  return max(last_list, key=os.path.getctime) if last_list else ''
248
 
249
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  def emojis(str=''):
251
  # Return platform-dependent emoji-safe version of string
252
  return str.encode().decode('ascii', 'ignore') if platform.system() == 'Windows' else str
 
295
  return ''
296
 
297
 
298
+ @TryExcept()
299
  @WorkingDirectory(ROOT)
300
+ def check_git_status(repo='ultralytics/yolov5', branch='master'):
301
  # YOLOv5 status check, recommend 'git pull' if code is out of date
302
  url = f'https://github.com/{repo}'
303
  msg = f', for updates see {url}'
 
313
  remote = 'ultralytics'
314
  check_output(f'git remote add {remote} {url}', shell=True)
315
  check_output(f'git fetch {remote}', shell=True, timeout=5) # git fetch
316
+ local_branch = check_output('git rev-parse --abbrev-ref HEAD', shell=True).decode().strip() # checked out
317
+ n = int(check_output(f'git rev-list {local_branch}..{remote}/{branch} --count', shell=True)) # commits behind
318
  if n > 0:
319
+ pull = 'git pull' if remote == 'origin' else f'git pull {remote} {branch}'
320
  s += f"⚠️ YOLOv5 is out of date by {n} commit{'s' * (n > 1)}. Use `{pull}` or `git clone {url}` to update."
321
  else:
322
  s += f'up to date with {url} ✅'
 
332
  # Check version vs. required version
333
  current, minimum = (pkg.parse_version(x) for x in (current, minimum))
334
  result = (current == minimum) if pinned else (current >= minimum) # bool
335
+ s = f'WARNING: ⚠️ {name}{minimum} is required by YOLOv5, but {name}{current} is currently installed' # string
336
  if hard:
337
+ assert result, emojis(s) # assert min requirements met
338
  if verbose and not result:
339
  LOGGER.warning(s)
340
  return result
341
 
342
 
343
+ @TryExcept()
344
  def check_requirements(requirements=ROOT / 'requirements.txt', exclude=(), install=True, cmds=()):
345
+ # Check installed dependencies meet YOLOv5 requirements (pass *.txt file or list of packages)
346
  prefix = colorstr('red', 'bold', 'requirements:')
347
  check_python() # check python version
348
  if isinstance(requirements, (str, Path)): # requirements.txt file
 
453
  font = Path(font)
454
  file = CONFIG_DIR / font.name
455
  if not font.exists() and not file.exists():
456
+ url = f'https://ultralytics.com/assets/{font.name}'
457
  LOGGER.info(f'Downloading {url} to {file}...')
458
  torch.hub.download_url_to_file(url, str(file), progress=progress)
459
 
 
474
  data = yaml.safe_load(f) # dictionary
475
 
476
  # Checks
477
+ for k in 'train', 'val', 'names':
478
  assert k in data, f"data.yaml '{k}:' field missing ❌"
479
+ if isinstance(data['names'], (list, tuple)): # old array format
480
+ data['names'] = dict(enumerate(data['names'])) # convert to dict
481
+ data['nc'] = len(data['names'])
482
 
483
  # Resolve paths
484
  path = Path(extract_dir or data.get('path') or '') # optional 'path' default to '.'
 
532
 
533
  prefix = colorstr('AMP: ')
534
  device = next(model.parameters()).device # get model device
535
+ if device.type in ('cpu', 'mps'):
536
+ return False # AMP only used on CUDA devices
537
  f = ROOT / 'data' / 'images' / 'bus.jpg' # image to check
538
  im = f if f.exists() else 'https://ultralytics.com/images/bus.jpg' if check_online() else np.ones((640, 640, 3))
539
  try:
 
546
  return False
547
 
548
 
549
+ def yaml_load(file='data.yaml'):
550
+ # Single-line safe yaml loading
551
+ with open(file, errors='ignore') as f:
552
+ return yaml.safe_load(f)
553
+
554
+
555
+ def yaml_save(file='data.yaml', data={}):
556
+ # Single-line safe yaml saving
557
+ with open(file, 'w') as f:
558
+ yaml.safe_dump({k: str(v) if isinstance(v, Path) else v for k, v in data.items()}, f, sort_keys=False)
559
+
560
+
561
  def url2file(url):
562
  # Convert URL to filename, i.e. https://url.com/file.txt?auth -> file.txt
563
  url = str(Path(url)).replace(':/', '://') # Pathlib turns :// -> :/
 
565
 
566
 
567
  def download(url, dir='.', unzip=True, delete=True, curl=False, threads=1, retry=3):
568
+ # Multithreaded file download and unzip function, used in data.yaml for autodownload
569
  def download_one(url, dir):
570
  # Download 1 file
571
  success = True
 
577
  for i in range(retry + 1):
578
  if curl:
579
  s = 'sS' if threads > 1 else '' # silent
580
+ r = os.system(
581
+ f'curl -# -{s}L "{url}" -o "{f}" --retry 9 -C -') # curl download with retry, continue
582
  success = r == 0
583
  else:
584
  torch.hub.download_url_to_file(url, f, progress=threads == 1) # torch download
 
590
  else:
591
  LOGGER.warning(f'Failed to download {url}...')
592
 
593
+ if unzip and success and f.suffix in ('.zip', '.tar', '.gz'):
594
  LOGGER.info(f'Unzipping {f}...')
595
  if f.suffix == '.zip':
596
  ZipFile(f).extractall(path=dir) # unzip
597
+ elif f.suffix == '.tar':
598
+ os.system(f'tar xf {f} --directory {f.parent}') # unzip
599
  elif f.suffix == '.gz':
600
  os.system(f'tar xfz {f} --directory {f.parent}') # unzip
601
  if delete:
 
605
  dir.mkdir(parents=True, exist_ok=True) # make directory
606
  if threads > 1:
607
  pool = ThreadPool(threads)
608
+ pool.imap(lambda x: download_one(*x), zip(url, repeat(dir))) # multithreaded
609
  pool.close()
610
  pool.join()
611
  else:
 
813
  list of detections, on (n,6) tensor per image [xyxy, conf, cls]
814
  """
815
 
816
+ if isinstance(prediction, (list, tuple)): # YOLOv5 model in validation model, output = (inference_out, loss_out)
817
+ prediction = prediction[0] # select only inference output
818
+
819
  bs = prediction.shape[0] # batch size
820
  nc = prediction.shape[2] - 5 # number of classes
821
  xc = prediction[..., 4] > conf_thres # candidates
utils/metrics.py CHANGED
@@ -11,6 +11,8 @@ import matplotlib.pyplot as plt
11
  import numpy as np
12
  import torch
13
 
 
 
14
 
15
  def fitness(x):
16
  # Model fitness as a weighted combination of metrics
@@ -141,7 +143,7 @@ class ConfusionMatrix:
141
  """
142
  if detections is None:
143
  gt_classes = labels.int()
144
- for i, gc in enumerate(gt_classes):
145
  self.matrix[self.nc, gc] += 1 # background FN
146
  return
147
 
@@ -184,36 +186,35 @@ class ConfusionMatrix:
184
  # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
185
  return tp[:-1], fp[:-1] # remove background class
186
 
 
187
  def plot(self, normalize=True, save_dir='', names=()):
188
- try:
189
- import seaborn as sn
190
-
191
- array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
192
- array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
193
-
194
- fig = plt.figure(figsize=(12, 9), tight_layout=True)
195
- nc, nn = self.nc, len(names) # number of classes, names
196
- sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
197
- labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
198
- with warnings.catch_warnings():
199
- warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
200
- sn.heatmap(array,
201
- annot=nc < 30,
202
- annot_kws={
203
- "size": 8},
204
- cmap='Blues',
205
- fmt='.2f',
206
- square=True,
207
- vmin=0.0,
208
- xticklabels=names + ['background FP'] if labels else "auto",
209
- yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
210
- fig.axes[0].set_xlabel('True')
211
- fig.axes[0].set_ylabel('Predicted')
212
- plt.title('Confusion Matrix')
213
- fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
214
- plt.close()
215
- except Exception as e:
216
- print(f'WARNING: ConfusionMatrix plot failure: {e}')
217
 
218
  def print(self):
219
  for i in range(self.nc + 1):
@@ -320,6 +321,7 @@ def wh_iou(wh1, wh2, eps=1e-7):
320
  # Plots ----------------------------------------------------------------------------------------------------------------
321
 
322
 
 
323
  def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
324
  # Precision-recall curve
325
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@@ -336,12 +338,13 @@ def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
336
  ax.set_ylabel('Precision')
337
  ax.set_xlim(0, 1)
338
  ax.set_ylim(0, 1)
339
- plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
340
- plt.title('Precision-Recall Curve')
341
  fig.savefig(save_dir, dpi=250)
342
- plt.close()
343
 
344
 
 
345
  def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
346
  # Metric-confidence curve
347
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
@@ -358,7 +361,7 @@ def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confi
358
  ax.set_ylabel(ylabel)
359
  ax.set_xlim(0, 1)
360
  ax.set_ylim(0, 1)
361
- plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
362
- plt.title(f'{ylabel}-Confidence Curve')
363
  fig.savefig(save_dir, dpi=250)
364
- plt.close()
 
11
  import numpy as np
12
  import torch
13
 
14
+ from utils import TryExcept, threaded
15
+
16
 
17
  def fitness(x):
18
  # Model fitness as a weighted combination of metrics
 
143
  """
144
  if detections is None:
145
  gt_classes = labels.int()
146
+ for gc in gt_classes:
147
  self.matrix[self.nc, gc] += 1 # background FN
148
  return
149
 
 
186
  # fn = self.matrix.sum(0) - tp # false negatives (missed detections)
187
  return tp[:-1], fp[:-1] # remove background class
188
 
189
+ @TryExcept('WARNING: ConfusionMatrix plot failure: ')
190
  def plot(self, normalize=True, save_dir='', names=()):
191
+ import seaborn as sn
192
+
193
+ array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1E-9) if normalize else 1) # normalize columns
194
+ array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)
195
+
196
+ fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
197
+ nc, nn = self.nc, len(names) # number of classes, names
198
+ sn.set(font_scale=1.0 if nc < 50 else 0.8) # for label size
199
+ labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
200
+ with warnings.catch_warnings():
201
+ warnings.simplefilter('ignore') # suppress empty matrix RuntimeWarning: All-NaN slice encountered
202
+ sn.heatmap(array,
203
+ ax=ax,
204
+ annot=nc < 30,
205
+ annot_kws={
206
+ "size": 8},
207
+ cmap='Blues',
208
+ fmt='.2f',
209
+ square=True,
210
+ vmin=0.0,
211
+ xticklabels=names + ['background FP'] if labels else "auto",
212
+ yticklabels=names + ['background FN'] if labels else "auto").set_facecolor((1, 1, 1))
213
+ ax.set_ylabel('True')
214
+ ax.set_ylabel('Predicted')
215
+ ax.set_title('Confusion Matrix')
216
+ fig.savefig(Path(save_dir) / 'confusion_matrix.png', dpi=250)
217
+ plt.close(fig)
 
 
218
 
219
  def print(self):
220
  for i in range(self.nc + 1):
 
321
  # Plots ----------------------------------------------------------------------------------------------------------------
322
 
323
 
324
+ @threaded
325
  def plot_pr_curve(px, py, ap, save_dir=Path('pr_curve.png'), names=()):
326
  # Precision-recall curve
327
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
 
338
  ax.set_ylabel('Precision')
339
  ax.set_xlim(0, 1)
340
  ax.set_ylim(0, 1)
341
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
342
+ ax.set_title('Precision-Recall Curve')
343
  fig.savefig(save_dir, dpi=250)
344
+ plt.close(fig)
345
 
346
 
347
+ @threaded
348
  def plot_mc_curve(px, py, save_dir=Path('mc_curve.png'), names=(), xlabel='Confidence', ylabel='Metric'):
349
  # Metric-confidence curve
350
  fig, ax = plt.subplots(1, 1, figsize=(9, 6), tight_layout=True)
 
361
  ax.set_ylabel(ylabel)
362
  ax.set_xlim(0, 1)
363
  ax.set_ylim(0, 1)
364
+ ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
365
+ ax.set_title(f'{ylabel}-Confidence Curve')
366
  fig.savefig(save_dir, dpi=250)
367
+ plt.close(fig)
utils/plots.py CHANGED
@@ -3,6 +3,7 @@
3
  Plotting utils
4
  """
5
 
 
6
  import math
7
  import os
8
  from copy import copy
@@ -18,8 +19,9 @@ import seaborn as sn
18
  import torch
19
  from PIL import Image, ImageDraw, ImageFont
20
 
21
- from utils.general import (CONFIG_DIR, FONT, LOGGER, Timeout, check_font, check_requirements, clip_coords,
22
- increment_path, is_ascii, threaded, try_except, xywh2xyxy, xyxy2xywh)
 
23
  from utils.metrics import fitness
24
 
25
  # Settings
@@ -115,10 +117,12 @@ class Annotator:
115
  # Add rectangle to image (PIL-only)
116
  self.draw.rectangle(xy, fill, outline, width)
117
 
118
- def text(self, xy, text, txt_color=(255, 255, 255)):
119
  # Add text to image (PIL-only)
120
- w, h = self.font.getsize(text) # text width, height
121
- self.draw.text((xy[0], xy[1] - h + 1), text, fill=txt_color, font=self.font)
 
 
122
 
123
  def result(self):
124
  # Return annotated image as array
@@ -180,8 +184,7 @@ def output_to_target(output):
180
  # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
181
  targets = []
182
  for i, o in enumerate(output):
183
- for *box, conf, cls in o.cpu().numpy():
184
- targets.append([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf])
185
  return np.array(targets)
186
 
187
 
@@ -221,7 +224,7 @@ def plot_images(images, targets, paths=None, fname='images.jpg', names=None, max
221
  x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
222
  annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
223
  if paths:
224
- annotator.text((x + 5, y + 5 + h), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
225
  if len(targets) > 0:
226
  ti = targets[targets[:, 0] == i] # image targets
227
  boxes = xywh2xyxy(ti[:, 2:6]).T
@@ -339,8 +342,7 @@ def plot_val_study(file='', dir='', x=None): # from utils.plots import *; plot_
339
  plt.savefig(f, dpi=300)
340
 
341
 
342
- @try_except # known issue https://github.com/ultralytics/yolov5/issues/5395
343
- @Timeout(30) # known issue https://github.com/ultralytics/yolov5/issues/5611
344
  def plot_labels(labels, names=(), save_dir=Path('')):
345
  # plot dataset labels
346
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
@@ -357,10 +359,8 @@ def plot_labels(labels, names=(), save_dir=Path('')):
357
  matplotlib.use('svg') # faster
358
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
359
  y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
360
- try: # color histogram bars by class
361
  [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
362
- except Exception:
363
- pass
364
  ax[0].set_ylabel('instances')
365
  if 0 < len(names) < 30:
366
  ax[0].set_xticks(range(len(names)))
@@ -388,6 +388,35 @@ def plot_labels(labels, names=(), save_dir=Path('')):
388
  plt.close()
389
 
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
391
  def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
392
  # Plot evolve.csv hyp evolution results
393
  evolve_csv = Path(evolve_csv)
 
3
  Plotting utils
4
  """
5
 
6
+ import contextlib
7
  import math
8
  import os
9
  from copy import copy
 
19
  import torch
20
  from PIL import Image, ImageDraw, ImageFont
21
 
22
+ from utils import TryExcept, threaded
23
+ from utils.general import (CONFIG_DIR, FONT, LOGGER, check_font, check_requirements, clip_coords, increment_path,
24
+ is_ascii, xywh2xyxy, xyxy2xywh)
25
  from utils.metrics import fitness
26
 
27
  # Settings
 
117
  # Add rectangle to image (PIL-only)
118
  self.draw.rectangle(xy, fill, outline, width)
119
 
120
+ def text(self, xy, text, txt_color=(255, 255, 255), anchor='top'):
121
  # Add text to image (PIL-only)
122
+ if anchor == 'bottom': # start y from font bottom
123
+ w, h = self.font.getsize(text) # text width, height
124
+ xy[1] += 1 - h
125
+ self.draw.text(xy, text, fill=txt_color, font=self.font)
126
 
127
  def result(self):
128
  # Return annotated image as array
 
184
  # Convert model output to target format [batch_id, class_id, x, y, w, h, conf]
185
  targets = []
186
  for i, o in enumerate(output):
187
+ targets.extend([i, cls, *list(*xyxy2xywh(np.array(box)[None])), conf] for *box, conf, cls in o.cpu().numpy())
 
188
  return np.array(targets)
189
 
190
 
 
224
  x, y = int(w * (i // ns)), int(h * (i % ns)) # block origin
225
  annotator.rectangle([x, y, x + w, y + h], None, (255, 255, 255), width=2) # borders
226
  if paths:
227
+ annotator.text((x + 5, y + 5), text=Path(paths[i]).name[:40], txt_color=(220, 220, 220)) # filenames
228
  if len(targets) > 0:
229
  ti = targets[targets[:, 0] == i] # image targets
230
  boxes = xywh2xyxy(ti[:, 2:6]).T
 
342
  plt.savefig(f, dpi=300)
343
 
344
 
345
+ @TryExcept() # known issue https://github.com/ultralytics/yolov5/issues/5395
 
346
  def plot_labels(labels, names=(), save_dir=Path('')):
347
  # plot dataset labels
348
  LOGGER.info(f"Plotting labels to {save_dir / 'labels.jpg'}... ")
 
359
  matplotlib.use('svg') # faster
360
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
361
  y = ax[0].hist(c, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
362
+ with contextlib.suppress(Exception): # color histogram bars by class
363
  [y[2].patches[i].set_color([x / 255 for x in colors(i)]) for i in range(nc)] # known issue #3195
 
 
364
  ax[0].set_ylabel('instances')
365
  if 0 < len(names) < 30:
366
  ax[0].set_xticks(range(len(names)))
 
388
  plt.close()
389
 
390
 
391
+ def imshow_cls(im, labels=None, pred=None, names=None, nmax=25, verbose=False, f=Path('images.jpg')):
392
+ # Show classification image grid with labels (optional) and predictions (optional)
393
+ from utils.augmentations import denormalize
394
+
395
+ names = names or [f'class{i}' for i in range(1000)]
396
+ blocks = torch.chunk(denormalize(im.clone()).cpu().float(), len(im),
397
+ dim=0) # select batch index 0, block by channels
398
+ n = min(len(blocks), nmax) # number of plots
399
+ m = min(8, round(n ** 0.5)) # 8 x 8 default
400
+ fig, ax = plt.subplots(math.ceil(n / m), m) # 8 rows x n/8 cols
401
+ ax = ax.ravel() if m > 1 else [ax]
402
+ # plt.subplots_adjust(wspace=0.05, hspace=0.05)
403
+ for i in range(n):
404
+ ax[i].imshow(blocks[i].squeeze().permute((1, 2, 0)).numpy().clip(0.0, 1.0))
405
+ ax[i].axis('off')
406
+ if labels is not None:
407
+ s = names[labels[i]] + (f'—{names[pred[i]]}' if pred is not None else '')
408
+ ax[i].set_title(s, fontsize=8, verticalalignment='top')
409
+ plt.savefig(f, dpi=300, bbox_inches='tight')
410
+ plt.close()
411
+ if verbose:
412
+ LOGGER.info(f"Saving {f}")
413
+ if labels is not None:
414
+ LOGGER.info('True: ' + ' '.join(f'{names[i]:3s}' for i in labels[:nmax]))
415
+ if pred is not None:
416
+ LOGGER.info('Predicted:' + ' '.join(f'{names[i]:3s}' for i in pred[:nmax]))
417
+ return f
418
+
419
+
420
  def plot_evolve(evolve_csv='path/to/evolve.csv'): # from utils.plots import *; plot_evolve()
421
  # Plot evolve.csv hyp evolution results
422
  evolve_csv = Path(evolve_csv)
utils/torch_utils.py CHANGED
@@ -42,6 +42,15 @@ def smart_inference_mode(torch_1_9=check_version(torch.__version__, '1.9.0')):
42
  return decorate
43
 
44
 
 
 
 
 
 
 
 
 
 
45
  def smart_DDP(model):
46
  # Model DDP creation with checks
47
  assert not check_version(torch.__version__, '1.12.0', pinned=True), \
@@ -53,6 +62,28 @@ def smart_DDP(model):
53
  return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
54
 
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  @contextmanager
57
  def torch_distributed_zero_first(local_rank: int):
58
  # Decorator to make all processes in distributed training wait for each local_master to do something
@@ -86,7 +117,7 @@ def select_device(device='', batch_size=0, newline=True):
86
  assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
87
  f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
88
 
89
- if not (cpu or mps) and torch.cuda.is_available(): # prefer GPU if available
90
  devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
91
  n = len(devices) # device count
92
  if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
@@ -117,14 +148,13 @@ def time_sync():
117
 
118
 
119
  def profile(input, ops, n=10, device=None):
120
- # YOLOv5 speed/memory/FLOPs profiler
121
- #
122
- # Usage:
123
- # input = torch.randn(16, 3, 640, 640)
124
- # m1 = lambda x: x * torch.sigmoid(x)
125
- # m2 = nn.SiLU()
126
- # profile(input, [m1, m2], n=100) # profile over 100 iterations
127
-
128
  results = []
129
  if not isinstance(device, torch.device):
130
  device = select_device(device)
@@ -251,7 +281,7 @@ def model_info(model, verbose=False, imgsz=640):
251
  try: # FLOPs
252
  p = next(model.parameters())
253
  stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
254
- im = torch.zeros((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
255
  flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
256
  imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
257
  fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
@@ -313,6 +343,18 @@ def smart_optimizer(model, name='Adam', lr=0.001, momentum=0.9, decay=1e-5):
313
  return optimizer
314
 
315
 
 
 
 
 
 
 
 
 
 
 
 
 
316
  def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
317
  # Resume training from a partially trained checkpoint
318
  best_fitness = 0.0
@@ -365,14 +407,11 @@ class ModelEMA:
365
  def __init__(self, model, decay=0.9999, tau=2000, updates=0):
366
  # Create EMA
367
  self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
368
- # if next(model.parameters()).device.type != 'cpu':
369
- # self.ema.half() # FP16 EMA
370
  self.updates = updates # number of EMA updates
371
  self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
372
  for p in self.ema.parameters():
373
  p.requires_grad_(False)
374
 
375
- @smart_inference_mode()
376
  def update(self, model):
377
  # Update EMA parameters
378
  self.updates += 1
@@ -380,9 +419,10 @@ class ModelEMA:
380
 
381
  msd = de_parallel(model).state_dict() # model state_dict
382
  for k, v in self.ema.state_dict().items():
383
- if v.dtype.is_floating_point:
384
  v *= d
385
  v += (1 - d) * msd[k].detach()
 
386
 
387
  def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
388
  # Update EMA attributes
 
42
  return decorate
43
 
44
 
45
+ def smartCrossEntropyLoss(label_smoothing=0.0):
46
+ # Returns nn.CrossEntropyLoss with label smoothing enabled for torch>=1.10.0
47
+ if check_version(torch.__version__, '1.10.0'):
48
+ return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
49
+ if label_smoothing > 0:
50
+ LOGGER.warning(f'WARNING: label smoothing {label_smoothing} requires torch>=1.10.0')
51
+ return nn.CrossEntropyLoss()
52
+
53
+
54
  def smart_DDP(model):
55
  # Model DDP creation with checks
56
  assert not check_version(torch.__version__, '1.12.0', pinned=True), \
 
62
  return DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
63
 
64
 
65
+ def reshape_classifier_output(model, n=1000):
66
+ # Update a TorchVision classification model to class count 'n' if required
67
+ from models.common import Classify
68
+ name, m = list((model.model if hasattr(model, 'model') else model).named_children())[-1] # last module
69
+ if isinstance(m, Classify): # YOLOv5 Classify() head
70
+ if m.linear.out_features != n:
71
+ m.linear = nn.Linear(m.linear.in_features, n)
72
+ elif isinstance(m, nn.Linear): # ResNet, EfficientNet
73
+ if m.out_features != n:
74
+ setattr(model, name, nn.Linear(m.in_features, n))
75
+ elif isinstance(m, nn.Sequential):
76
+ types = [type(x) for x in m]
77
+ if nn.Linear in types:
78
+ i = types.index(nn.Linear) # nn.Linear index
79
+ if m[i].out_features != n:
80
+ m[i] = nn.Linear(m[i].in_features, n)
81
+ elif nn.Conv2d in types:
82
+ i = types.index(nn.Conv2d) # nn.Conv2d index
83
+ if m[i].out_channels != n:
84
+ m[i] = nn.Conv2d(m[i].in_channels, n, m[i].kernel_size, m[i].stride, bias=m[i].bias)
85
+
86
+
87
  @contextmanager
88
  def torch_distributed_zero_first(local_rank: int):
89
  # Decorator to make all processes in distributed training wait for each local_master to do something
 
117
  assert torch.cuda.is_available() and torch.cuda.device_count() >= len(device.replace(',', '')), \
118
  f"Invalid CUDA '--device {device}' requested, use '--device cpu' or pass valid CUDA device(s)"
119
 
120
+ if not cpu and not mps and torch.cuda.is_available(): # prefer GPU if available
121
  devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
122
  n = len(devices) # device count
123
  if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
 
148
 
149
 
150
  def profile(input, ops, n=10, device=None):
151
+ """ YOLOv5 speed/memory/FLOPs profiler
152
+ Usage:
153
+ input = torch.randn(16, 3, 640, 640)
154
+ m1 = lambda x: x * torch.sigmoid(x)
155
+ m2 = nn.SiLU()
156
+ profile(input, [m1, m2], n=100) # profile over 100 iterations
157
+ """
 
158
  results = []
159
  if not isinstance(device, torch.device):
160
  device = select_device(device)
 
281
  try: # FLOPs
282
  p = next(model.parameters())
283
  stride = max(int(model.stride.max()), 32) if hasattr(model, 'stride') else 32 # max stride
284
+ im = torch.empty((1, p.shape[1], stride, stride), device=p.device) # input image in BCHW format
285
  flops = thop.profile(deepcopy(model), inputs=(im,), verbose=False)[0] / 1E9 * 2 # stride GFLOPs
286
  imgsz = imgsz if isinstance(imgsz, list) else [imgsz, imgsz] # expand if int/float
287
  fs = f', {flops * imgsz[0] / stride * imgsz[1] / stride:.1f} GFLOPs' # 640x640 GFLOPs
 
343
  return optimizer
344
 
345
 
346
+ def smart_hub_load(repo='ultralytics/yolov5', model='yolov5s', **kwargs):
347
+ # YOLOv5 torch.hub.load() wrapper with smart error/issue handling
348
+ if check_version(torch.__version__, '1.9.1'):
349
+ kwargs['skip_validation'] = True # validation causes GitHub API rate limit errors
350
+ if check_version(torch.__version__, '1.12.0'):
351
+ kwargs['trust_repo'] = True # argument required starting in torch 0.12
352
+ try:
353
+ return torch.hub.load(repo, model, **kwargs)
354
+ except Exception:
355
+ return torch.hub.load(repo, model, force_reload=True, **kwargs)
356
+
357
+
358
  def smart_resume(ckpt, optimizer, ema=None, weights='yolov5s.pt', epochs=300, resume=True):
359
  # Resume training from a partially trained checkpoint
360
  best_fitness = 0.0
 
407
  def __init__(self, model, decay=0.9999, tau=2000, updates=0):
408
  # Create EMA
409
  self.ema = deepcopy(de_parallel(model)).eval() # FP32 EMA
 
 
410
  self.updates = updates # number of EMA updates
411
  self.decay = lambda x: decay * (1 - math.exp(-x / tau)) # decay exponential ramp (to help early epochs)
412
  for p in self.ema.parameters():
413
  p.requires_grad_(False)
414
 
 
415
  def update(self, model):
416
  # Update EMA parameters
417
  self.updates += 1
 
419
 
420
  msd = de_parallel(model).state_dict() # model state_dict
421
  for k, v in self.ema.state_dict().items():
422
+ if v.dtype.is_floating_point: # true for FP16 and FP32
423
  v *= d
424
  v += (1 - d) * msd[k].detach()
425
+ # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
426
 
427
  def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
428
  # Update EMA attributes
val.py CHANGED
@@ -1,21 +1,21 @@
1
  # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
  """
3
- Validate a trained YOLOv5 model accuracy on a custom dataset
4
 
5
  Usage:
6
- $ python path/to/val.py --weights yolov5s.pt --data coco128.yaml --img 640
7
 
8
  Usage - formats:
9
- $ python path/to/val.py --weights yolov5s.pt # PyTorch
10
- yolov5s.torchscript # TorchScript
11
- yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
12
- yolov5s.xml # OpenVINO
13
- yolov5s.engine # TensorRT
14
- yolov5s.mlmodel # CoreML (macOS-only)
15
- yolov5s_saved_model # TensorFlow SavedModel
16
- yolov5s.pb # TensorFlow GraphDef
17
- yolov5s.tflite # TensorFlow Lite
18
- yolov5s_edgetpu.tflite # TensorFlow Edge TPU
19
  """
20
 
21
  import argparse
@@ -37,12 +37,12 @@ ROOT = Path(os.path.relpath(ROOT, Path.cwd())) # relative
37
  from models.common import DetectMultiBackend
38
  from utils.callbacks import Callbacks
39
  from utils.dataloaders import create_dataloader
40
- from utils.general import (LOGGER, check_dataset, check_img_size, check_requirements, check_yaml,
41
  coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
42
  scale_coords, xywh2xyxy, xyxy2xywh)
43
  from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
44
  from utils.plots import output_to_target, plot_images, plot_val_study
45
- from utils.torch_utils import select_device, smart_inference_mode, time_sync
46
 
47
 
48
  def save_one_txt(predn, save_conf, shape, file):
@@ -182,40 +182,39 @@ def run(
182
 
183
  seen = 0
184
  confusion_matrix = ConfusionMatrix(nc=nc)
185
- names = dict(enumerate(model.names if hasattr(model, 'names') else model.module.names))
 
 
186
  class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
187
- s = ('%20s' + '%11s' * 6) % ('Class', 'Images', 'Labels', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
188
- dt, p, r, f1, mp, mr, map50, map = [0.0, 0.0, 0.0], 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
189
  loss = torch.zeros(3, device=device)
190
  jdict, stats, ap, ap_class = [], [], [], []
191
  callbacks.run('on_val_start')
192
  pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
193
  for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
194
  callbacks.run('on_val_batch_start')
195
- t1 = time_sync()
196
- if cuda:
197
- im = im.to(device, non_blocking=True)
198
- targets = targets.to(device)
199
- im = im.half() if half else im.float() # uint8 to fp16/32
200
- im /= 255 # 0 - 255 to 0.0 - 1.0
201
- nb, _, height, width = im.shape # batch size, channels, height, width
202
- t2 = time_sync()
203
- dt[0] += t2 - t1
204
 
205
  # Inference
206
- out, train_out = model(im) if training else model(im, augment=augment, val=True) # inference, loss outputs
207
- dt[1] += time_sync() - t2
208
 
209
  # Loss
210
  if compute_loss:
211
- loss += compute_loss([x.float() for x in train_out], targets)[1] # box, obj, cls
212
 
213
  # NMS
214
  targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
215
  lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
216
- t3 = time_sync()
217
- out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
218
- dt[2] += time_sync() - t3
219
 
220
  # Metrics
221
  for si, pred in enumerate(out):
@@ -271,7 +270,7 @@ def run(
271
  nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class
272
 
273
  # Print results
274
- pf = '%20s' + '%11i' * 2 + '%11.3g' * 4 # print format
275
  LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
276
  if nt.sum() == 0:
277
  LOGGER.warning(f'WARNING: no labels found in {task} set, can not compute metrics without labels ⚠️')
@@ -282,7 +281,7 @@ def run(
282
  LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
283
 
284
  # Print speeds
285
- t = tuple(x / seen * 1E3 for x in dt) # speeds per image
286
  if not training:
287
  shape = (batch_size, 3, imgsz, imgsz)
288
  LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
@@ -366,6 +365,8 @@ def main(opt):
366
  if opt.task in ('train', 'val', 'test'): # run normally
367
  if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466
368
  LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} > 0.001 produces invalid results ⚠️')
 
 
369
  run(**vars(opt))
370
 
371
  else:
 
1
  # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
2
  """
3
+ Validate a trained YOLOv5 detection model on a detection dataset
4
 
5
  Usage:
6
+ $ python val.py --weights yolov5s.pt --data coco128.yaml --img 640
7
 
8
  Usage - formats:
9
+ $ python val.py --weights yolov5s.pt # PyTorch
10
+ yolov5s.torchscript # TorchScript
11
+ yolov5s.onnx # ONNX Runtime or OpenCV DNN with --dnn
12
+ yolov5s.xml # OpenVINO
13
+ yolov5s.engine # TensorRT
14
+ yolov5s.mlmodel # CoreML (macOS-only)
15
+ yolov5s_saved_model # TensorFlow SavedModel
16
+ yolov5s.pb # TensorFlow GraphDef
17
+ yolov5s.tflite # TensorFlow Lite
18
+ yolov5s_edgetpu.tflite # TensorFlow Edge TPU
19
  """
20
 
21
  import argparse
 
37
  from models.common import DetectMultiBackend
38
  from utils.callbacks import Callbacks
39
  from utils.dataloaders import create_dataloader
40
+ from utils.general import (LOGGER, Profile, check_dataset, check_img_size, check_requirements, check_yaml,
41
  coco80_to_coco91_class, colorstr, increment_path, non_max_suppression, print_args,
42
  scale_coords, xywh2xyxy, xyxy2xywh)
43
  from utils.metrics import ConfusionMatrix, ap_per_class, box_iou
44
  from utils.plots import output_to_target, plot_images, plot_val_study
45
+ from utils.torch_utils import select_device, smart_inference_mode
46
 
47
 
48
  def save_one_txt(predn, save_conf, shape, file):
 
182
 
183
  seen = 0
184
  confusion_matrix = ConfusionMatrix(nc=nc)
185
+ names = model.names if hasattr(model, 'names') else model.module.names # get class names
186
+ if isinstance(names, (list, tuple)): # old format
187
+ names = dict(enumerate(names))
188
  class_map = coco80_to_coco91_class() if is_coco else list(range(1000))
189
+ s = ('%22s' + '%11s' * 6) % ('Class', 'Images', 'Instances', 'P', 'R', 'mAP@.5', 'mAP@.5:.95')
190
+ dt, p, r, f1, mp, mr, map50, map = (Profile(), Profile(), Profile()), 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
191
  loss = torch.zeros(3, device=device)
192
  jdict, stats, ap, ap_class = [], [], [], []
193
  callbacks.run('on_val_start')
194
  pbar = tqdm(dataloader, desc=s, bar_format='{l_bar}{bar:10}{r_bar}{bar:-10b}') # progress bar
195
  for batch_i, (im, targets, paths, shapes) in enumerate(pbar):
196
  callbacks.run('on_val_batch_start')
197
+ with dt[0]:
198
+ if cuda:
199
+ im = im.to(device, non_blocking=True)
200
+ targets = targets.to(device)
201
+ im = im.half() if half else im.float() # uint8 to fp16/32
202
+ im /= 255 # 0 - 255 to 0.0 - 1.0
203
+ nb, _, height, width = im.shape # batch size, channels, height, width
 
 
204
 
205
  # Inference
206
+ with dt[1]:
207
+ out, train_out = model(im) if compute_loss else (model(im, augment=augment), None)
208
 
209
  # Loss
210
  if compute_loss:
211
+ loss += compute_loss(train_out, targets)[1] # box, obj, cls
212
 
213
  # NMS
214
  targets[:, 2:] *= torch.tensor((width, height, width, height), device=device) # to pixels
215
  lb = [targets[targets[:, 0] == i, 1:] for i in range(nb)] if save_hybrid else [] # for autolabelling
216
+ with dt[2]:
217
+ out = non_max_suppression(out, conf_thres, iou_thres, labels=lb, multi_label=True, agnostic=single_cls)
 
218
 
219
  # Metrics
220
  for si, pred in enumerate(out):
 
270
  nt = np.bincount(stats[3].astype(int), minlength=nc) # number of targets per class
271
 
272
  # Print results
273
+ pf = '%22s' + '%11i' * 2 + '%11.3g' * 4 # print format
274
  LOGGER.info(pf % ('all', seen, nt.sum(), mp, mr, map50, map))
275
  if nt.sum() == 0:
276
  LOGGER.warning(f'WARNING: no labels found in {task} set, can not compute metrics without labels ⚠️')
 
281
  LOGGER.info(pf % (names[c], seen, nt[c], p[i], r[i], ap50[i], ap[i]))
282
 
283
  # Print speeds
284
+ t = tuple(x.t / seen * 1E3 for x in dt) # speeds per image
285
  if not training:
286
  shape = (batch_size, 3, imgsz, imgsz)
287
  LOGGER.info(f'Speed: %.1fms pre-process, %.1fms inference, %.1fms NMS per image at shape {shape}' % t)
 
365
  if opt.task in ('train', 'val', 'test'): # run normally
366
  if opt.conf_thres > 0.001: # https://github.com/ultralytics/yolov5/issues/1466
367
  LOGGER.info(f'WARNING: confidence threshold {opt.conf_thres} > 0.001 produces invalid results ⚠️')
368
+ if opt.save_hybrid:
369
+ LOGGER.info('WARNING: --save-hybrid will return high mAP from hybrid labels, not from predictions alone ⚠️')
370
  run(**vars(opt))
371
 
372
  else: