Dean Mark glenn-jocher committed on
Commit
28bff22
1 Parent(s): c058a61

Use multi-threading in cache_labels (#3505)

Browse files

* Use multi-threading in cache_labels

* PEP8 reformat

* Add num_threads

* Changed ThreadPool.imap_unordered to Pool.imap_unordered

* Remove inplace additions

* Update datasets.py

refactor initial desc

Co-authored-by: Glenn Jocher <glenn.jocher@ultralytics.com>

Files changed (1) hide show
  1. utils/datasets.py +56 -43
utils/datasets.py CHANGED
@@ -9,7 +9,7 @@ import random
9
  import shutil
10
  import time
11
  from itertools import repeat
12
- from multiprocessing.pool import ThreadPool
13
  from pathlib import Path
14
  from threading import Thread
15
 
@@ -29,6 +29,7 @@ from utils.torch_utils import torch_distributed_zero_first
29
  help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
30
  img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
31
  vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
 
32
  logger = logging.getLogger(__name__)
33
 
34
  # Get orientation exif tag
@@ -447,7 +448,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
447
  if cache_images:
448
  gb = 0 # Gigabytes of cached images
449
  self.img_hw0, self.img_hw = [None] * n, [None] * n
450
- results = ThreadPool(8).imap(lambda x: load_image(*x), zip(repeat(self), range(n))) # 8 threads
451
  pbar = tqdm(enumerate(results), total=n)
452
  for i, x in pbar:
453
  self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
@@ -458,53 +459,24 @@ class LoadImagesAndLabels(Dataset): # for training/testing
458
  def cache_labels(self, path=Path('./labels.cache'), prefix=''):
459
  # Cache dataset labels, check images and read shapes
460
  x = {} # dict
461
- nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, duplicate
462
- pbar = tqdm(zip(self.img_files, self.label_files), desc='Scanning images', total=len(self.img_files))
463
- for i, (im_file, lb_file) in enumerate(pbar):
464
- try:
465
- # verify images
466
- im = Image.open(im_file)
467
- im.verify() # PIL verify
468
- shape = exif_size(im) # image size
469
- segments = [] # instance segments
470
- assert (shape[0] > 9) & (shape[1] > 9), f'image size {shape} <10 pixels'
471
- assert im.format.lower() in img_formats, f'invalid image format {im.format}'
472
-
473
- # verify labels
474
- if os.path.isfile(lb_file):
475
- nf += 1 # label found
476
- with open(lb_file, 'r') as f:
477
- l = [x.split() for x in f.read().strip().splitlines() if len(x)]
478
- if any([len(x) > 8 for x in l]): # is segment
479
- classes = np.array([x[0] for x in l], dtype=np.float32)
480
- segments = [np.array(x[1:], dtype=np.float32).reshape(-1, 2) for x in l] # (cls, xy1...)
481
- l = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1) # (cls, xywh)
482
- l = np.array(l, dtype=np.float32)
483
- if len(l):
484
- assert l.shape[1] == 5, 'labels require 5 columns each'
485
- assert (l >= 0).all(), 'negative labels'
486
- assert (l[:, 1:] <= 1).all(), 'non-normalized or out of bounds coordinate labels'
487
- assert np.unique(l, axis=0).shape[0] == l.shape[0], 'duplicate labels'
488
- else:
489
- ne += 1 # label empty
490
- l = np.zeros((0, 5), dtype=np.float32)
491
- else:
492
- nm += 1 # label missing
493
- l = np.zeros((0, 5), dtype=np.float32)
494
- x[im_file] = [l, shape, segments]
495
- except Exception as e:
496
- nc += 1
497
- logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
498
-
499
- pbar.desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels... " \
500
- f"{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
501
  pbar.close()
502
 
503
  if nf == 0:
504
  logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
505
 
506
  x['hash'] = get_hash(self.label_files + self.img_files)
507
- x['results'] = nf, nm, ne, nc, i + 1
508
  x['version'] = 0.2 # cache version
509
  try:
510
  torch.save(x, path) # save cache for next time
@@ -1069,3 +1041,44 @@ def autosplit(path='../coco128', weights=(0.9, 0.1, 0.0), annotated_only=False):
1069
  if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
1070
  with open(path / txt[i], 'a') as f:
1071
  f.write(str(img) + '\n') # add image to txt file
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  import shutil
10
  import time
11
  from itertools import repeat
12
+ from multiprocessing.pool import ThreadPool, Pool
13
  from pathlib import Path
14
  from threading import Thread
15
 
 
29
  help_url = 'https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data'
30
  img_formats = ['bmp', 'jpg', 'jpeg', 'png', 'tif', 'tiff', 'dng', 'webp', 'mpo'] # acceptable image suffixes
31
  vid_formats = ['mov', 'avi', 'mp4', 'mpg', 'mpeg', 'm4v', 'wmv', 'mkv'] # acceptable video suffixes
32
# os.cpu_count() returns None when the CPU count cannot be determined; fall back to 1 worker
num_threads = min(8, os.cpu_count() or 1)  # number of multiprocessing threads
33
  logger = logging.getLogger(__name__)
34
 
35
  # Get orientation exif tag
 
448
  if cache_images:
449
  gb = 0 # Gigabytes of cached images
450
  self.img_hw0, self.img_hw = [None] * n, [None] * n
451
+ results = ThreadPool(num_threads).imap(lambda x: load_image(*x), zip(repeat(self), range(n)))
452
  pbar = tqdm(enumerate(results), total=n)
453
  for i, x in pbar:
454
  self.imgs[i], self.img_hw0[i], self.img_hw[i] = x # img, hw_original, hw_resized = load_image(self, i)
 
459
  def cache_labels(self, path=Path('./labels.cache'), prefix=''):
460
  # Cache dataset labels, check images and read shapes
461
  x = {} # dict
462
+ nm, nf, ne, nc = 0, 0, 0, 0 # number missing, found, empty, corrupt
463
+ desc = f"{prefix}Scanning '{path.parent / path.stem}' images and labels..."
464
+ with Pool(num_threads) as pool:
465
+ pbar = tqdm(pool.imap_unordered(verify_image_label,
466
+ zip(self.img_files, self.label_files, repeat(prefix))),
467
+ desc=desc, total=len(self.img_files))
468
+ for im_file, l, shape, segments, nm_f, nf_f, ne_f, nc_f in pbar:
469
+ if im_file:
470
+ x[im_file] = [l, shape, segments]
471
+ nm, nf, ne, nc = nm + nm_f, nf + nf_f, ne + ne_f, nc + nc_f
472
+ pbar.desc = f"{desc}{nf} found, {nm} missing, {ne} empty, {nc} corrupted"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
473
  pbar.close()
474
 
475
  if nf == 0:
476
  logging.info(f'{prefix}WARNING: No labels found in {path}. See {help_url}')
477
 
478
  x['hash'] = get_hash(self.label_files + self.img_files)
479
+ x['results'] = nf, nm, ne, nc, len(self.img_files)
480
  x['version'] = 0.2 # cache version
481
  try:
482
  torch.save(x, path) # save cache for next time
 
1041
  if not annotated_only or Path(img2label_paths([str(img)])[0]).exists(): # check label
1042
  with open(path / txt[i], 'a') as f:
1043
  f.write(str(img) + '\n') # add image to txt file
1044
+
1045
+
1046
def verify_image_label(params):
    """Verify one image-label pair and report per-file scan counters.

    Args:
        params: ``(im_file, lb_file, prefix)`` packed into a single argument so the
            function can be mapped over an iterable of tuples (e.g. by a pool's
            ``imap_unordered``). ``prefix`` is prepended to the warning log message.

    Returns:
        On success, the tuple ``(im_file, labels, shape, segments, nm, nf, ne, nc)``:
        ``labels`` is a float32 array with 5 columns (zero rows when the label file
        is missing or empty), ``shape`` is the value returned by ``exif_size``,
        ``segments`` is a list of (n, 2) float32 arrays (empty unless some label row
        has more than 8 fields), and ``nm``/``nf``/``ne``/``nc`` are 0/1 flags for
        label missing / found / empty / corrupt.
        On any failure, the list ``[None, None, None, None, nm, nf, ne, 1]``; the
        flags set before the failure are kept as they were.
    """
    im_file, lb_file, prefix = params
    nm, nf, ne, nc = 0, 0, 0, 0  # number missing, found, empty, corrupt
    try:
        # verify images
        # The context manager releases the file handle even when verify() raises.
        with Image.open(im_file) as im:
            im.verify()  # PIL verify
            shape = exif_size(im)  # image size
            im_format = im.format
        segments = []  # instance segments
        # Explicit raises instead of assert: asserts are stripped under `python -O`,
        # which would silently disable every check below.
        if not (shape[0] > 9 and shape[1] > 9):
            raise ValueError(f'image size {shape} <10 pixels')
        if im_format.lower() not in img_formats:
            raise ValueError(f'invalid image format {im_format}')

        # verify labels
        if os.path.isfile(lb_file):
            nf = 1  # label found
            with open(lb_file, 'r') as f:
                rows = [line.split() for line in f.read().strip().splitlines() if len(line)]
            if any(len(row) > 8 for row in rows):  # is segment
                classes = np.array([row[0] for row in rows], dtype=np.float32)
                segments = [np.array(row[1:], dtype=np.float32).reshape(-1, 2) for row in rows]  # (cls, xy1...)
                rows = np.concatenate((classes.reshape(-1, 1), segments2boxes(segments)), 1)  # (cls, xywh)
            labels = np.array(rows, dtype=np.float32)
            if len(labels):
                if labels.shape[1] != 5:
                    raise ValueError('labels require 5 columns each')
                if not (labels >= 0).all():
                    raise ValueError('negative labels')
                if not (labels[:, 1:] <= 1).all():
                    raise ValueError('non-normalized or out of bounds coordinate labels')
                if np.unique(labels, axis=0).shape[0] != labels.shape[0]:
                    raise ValueError('duplicate labels')
            else:
                ne = 1  # label empty
                labels = np.zeros((0, 5), dtype=np.float32)
        else:
            nm = 1  # label missing
            labels = np.zeros((0, 5), dtype=np.float32)
        return im_file, labels, shape, segments, nm, nf, ne, nc
    except Exception as e:
        # Deliberately broad: any failure marks this pair as corrupt instead of
        # aborting the whole dataset scan.
        nc = 1
        logging.info(f'{prefix}WARNING: Ignoring corrupted image and/or label {im_file}: {e}')
        return [None] * 4 + [nm, nf, ne, nc]