Marc commited on
Commit
a925f28
·
unverified ·
1 Parent(s): 7875f4c

max workers for dataloader (#722)

Browse files
Files changed (2) hide show
  1. train.py +4 -2
  2. utils/datasets.py +2 -2
train.py CHANGED
@@ -159,7 +159,7 @@ def train(hyp, opt, device, tb_writer=None):
159
  # Trainloader
160
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
161
  cache=opt.cache_images, rect=opt.rect, rank=rank,
162
- world_size=opt.world_size)
163
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
164
  nb = len(dataloader) # number of batches
165
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
@@ -168,7 +168,8 @@ def train(hyp, opt, device, tb_writer=None):
168
  if rank in [-1, 0]:
169
  # local_rank is set to -1. Because only the first process is expected to do evaluation.
170
  testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
171
- cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size)[0]
 
172
 
173
  # Model parameters
174
  hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
@@ -403,6 +404,7 @@ if __name__ == '__main__':
403
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
404
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
405
  parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
 
406
  opt = parser.parse_args()
407
 
408
  # Set DDP variables
 
159
  # Trainloader
160
  dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
161
  cache=opt.cache_images, rect=opt.rect, rank=rank,
162
+ world_size=opt.world_size, workers=opt.workers)
163
  mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
164
  nb = len(dataloader) # number of batches
165
  assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
 
168
  if rank in [-1, 0]:
169
  # local_rank is set to -1. Because only the first process is expected to do evaluation.
170
  testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
171
+ cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size,
172
+ workers=opt.workers)[0]
173
 
174
  # Model parameters
175
  hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
 
404
  parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
405
  parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
406
  parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
407
+ parser.add_argument('--workers', type=int, default=8, help='maximum number of workers for dataloader')
408
  opt = parser.parse_args()
409
 
410
  # Set DDP variables
utils/datasets.py CHANGED
@@ -47,7 +47,7 @@ def exif_size(img):
47
 
48
 
49
  def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
50
- rank=-1, world_size=1):
51
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
52
  with torch_distributed_zero_first(rank):
53
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
@@ -61,7 +61,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
61
  rank=rank)
62
 
63
  batch_size = min(batch_size, len(dataset))
64
- nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
65
  train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
66
  dataloader = torch.utils.data.DataLoader(dataset,
67
  batch_size=batch_size,
 
47
 
48
 
49
  def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
50
+ rank=-1, world_size=1, workers=8):
51
  # Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
52
  with torch_distributed_zero_first(rank):
53
  dataset = LoadImagesAndLabels(path, imgsz, batch_size,
 
61
  rank=rank)
62
 
63
  batch_size = min(batch_size, len(dataset))
64
+ nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
65
  train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
66
  dataloader = torch.utils.data.DataLoader(dataset,
67
  batch_size=batch_size,