glenn-jocher committed on
Commit
481d46c
1 Parent(s): d61930e

Improved corruption handling during scan and cache (#999)

Browse files
Files changed (1) hide show
  1. utils/datasets.py +16 -17
utils/datasets.py CHANGED
@@ -328,6 +328,14 @@ class LoadStreams: # multiple IP or RTSP cameras
328
  class LoadImagesAndLabels(Dataset): # for training/testing
329
  def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
330
  cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
 
 
 
 
 
 
 
 
331
 
332
  def img2label_paths(img_paths):
333
  # Define label paths as a function of image paths
@@ -349,25 +357,10 @@ class LoadImagesAndLabels(Dataset): # for training/testing
349
  raise Exception('%s does not exist' % p)
350
  self.img_files = sorted(
351
  [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
 
352
  except Exception as e:
353
  raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
354
 
355
- n = len(self.img_files)
356
- assert n > 0, 'No images found in %s. See %s' % (path, help_url)
357
- bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
358
- nb = bi[-1] + 1 # number of batches
359
-
360
- self.n = n # number of images
361
- self.batch = bi # batch index of image
362
- self.img_size = img_size
363
- self.augment = augment
364
- self.hyp = hyp
365
- self.image_weights = image_weights
366
- self.rect = False if image_weights else rect
367
- self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
368
- self.mosaic_border = [-img_size // 2, -img_size // 2]
369
- self.stride = stride
370
-
371
  # Check cache
372
  self.label_files = img2label_paths(self.img_files) # labels
373
  cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
@@ -386,6 +379,12 @@ class LoadImagesAndLabels(Dataset): # for training/testing
386
  self.img_files = list(cache.keys()) # update
387
  self.label_files = img2label_paths(cache.keys()) # update
388
 
 
 
 
 
 
 
389
  # Rectangular Training
390
  if self.rect:
391
  # Sort by aspect ratio
@@ -500,7 +499,7 @@ class LoadImagesAndLabels(Dataset): # for training/testing
500
  l = np.zeros((0, 5), dtype=np.float32)
501
  x[img] = [l, shape]
502
  except Exception as e:
503
- print('WARNING: Ignoring corrupted image and/or label:%s: %s' % (img, e))
504
 
505
  x['hash'] = get_hash(self.label_files + self.img_files)
506
  torch.save(x, path) # save for next time
 
328
  class LoadImagesAndLabels(Dataset): # for training/testing
329
  def __init__(self, path, img_size=640, batch_size=16, augment=False, hyp=None, rect=False, image_weights=False,
330
  cache_images=False, single_cls=False, stride=32, pad=0.0, rank=-1):
331
+ self.img_size = img_size
332
+ self.augment = augment
333
+ self.hyp = hyp
334
+ self.image_weights = image_weights
335
+ self.rect = False if image_weights else rect
336
+ self.mosaic = self.augment and not self.rect # load 4 images at a time into a mosaic (only during training)
337
+ self.mosaic_border = [-img_size // 2, -img_size // 2]
338
+ self.stride = stride
339
 
340
  def img2label_paths(img_paths):
341
  # Define label paths as a function of image paths
 
357
  raise Exception('%s does not exist' % p)
358
  self.img_files = sorted(
359
  [x.replace('/', os.sep) for x in f if os.path.splitext(x)[-1].lower() in img_formats])
360
+ assert len(self.img_files) > 0, 'No images found'
361
  except Exception as e:
362
  raise Exception('Error loading data from %s: %s\nSee %s' % (path, e, help_url))
363
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
  # Check cache
365
  self.label_files = img2label_paths(self.img_files) # labels
366
  cache_path = str(Path(self.label_files[0]).parent) + '.cache' # cached labels
 
379
  self.img_files = list(cache.keys()) # update
380
  self.label_files = img2label_paths(cache.keys()) # update
381
 
382
+ n = len(shapes) # number of images
383
+ bi = np.floor(np.arange(n) / batch_size).astype(np.int) # batch index
384
+ nb = bi[-1] + 1 # number of batches
385
+ self.batch = bi # batch index of image
386
+ self.n = n
387
+
388
  # Rectangular Training
389
  if self.rect:
390
  # Sort by aspect ratio
 
499
  l = np.zeros((0, 5), dtype=np.float32)
500
  x[img] = [l, shape]
501
  except Exception as e:
502
+ print('WARNING: Ignoring corrupted image and/or label %s: %s' % (img, e))
503
 
504
  x['hash'] = get_hash(self.label_files + self.img_files)
505
  torch.save(x, path) # save for next time