Marc
committed on
max workers for dataloader (#722)
Browse files
- train.py +4 -2
- utils/datasets.py +2 -2
train.py
CHANGED
@@ -159,7 +159,7 @@ def train(hyp, opt, device, tb_writer=None):
|
|
159 |
# Trainloader
|
160 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
161 |
cache=opt.cache_images, rect=opt.rect, rank=rank,
|
162 |
-
world_size=opt.world_size)
|
163 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
164 |
nb = len(dataloader) # number of batches
|
165 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
@@ -168,7 +168,8 @@ def train(hyp, opt, device, tb_writer=None):
|
|
168 |
if rank in [-1, 0]:
|
169 |
# local_rank is set to -1. Because only the first process is expected to do evaluation.
|
170 |
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
|
171 |
-
cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size)[0]
|
|
|
172 |
|
173 |
# Model parameters
|
174 |
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
@@ -403,6 +404,7 @@ if __name__ == '__main__':
|
|
403 |
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
404 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
405 |
parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
|
|
|
406 |
opt = parser.parse_args()
|
407 |
|
408 |
# Set DDP variables
|
|
|
159 |
# Trainloader
|
160 |
dataloader, dataset = create_dataloader(train_path, imgsz, batch_size, gs, opt, hyp=hyp, augment=True,
|
161 |
cache=opt.cache_images, rect=opt.rect, rank=rank,
|
162 |
+
world_size=opt.world_size, workers=opt.workers)
|
163 |
mlc = np.concatenate(dataset.labels, 0)[:, 0].max() # max label class
|
164 |
nb = len(dataloader) # number of batches
|
165 |
assert mlc < nc, 'Label class %g exceeds nc=%g in %s. Possible class labels are 0-%g' % (mlc, nc, opt.data, nc - 1)
|
|
|
168 |
if rank in [-1, 0]:
|
169 |
# local_rank is set to -1. Because only the first process is expected to do evaluation.
|
170 |
testloader = create_dataloader(test_path, imgsz_test, total_batch_size, gs, opt, hyp=hyp, augment=False,
|
171 |
+
cache=opt.cache_images, rect=True, rank=-1, world_size=opt.world_size,
|
172 |
+
workers=opt.workers)[0]
|
173 |
|
174 |
# Model parameters
|
175 |
hyp['cls'] *= nc / 80. # scale coco-tuned hyp['cls'] to current dataset
|
|
|
404 |
parser.add_argument('--sync-bn', action='store_true', help='use SyncBatchNorm, only available in DDP mode')
|
405 |
parser.add_argument('--local_rank', type=int, default=-1, help='DDP parameter, do not modify')
|
406 |
parser.add_argument('--logdir', type=str, default='runs/', help='logging directory')
|
407 |
+
parser.add_argument('--workers', type=int, default=8, help='maximum number of workers for dataloader')
|
408 |
opt = parser.parse_args()
|
409 |
|
410 |
# Set DDP variables
|
utils/datasets.py
CHANGED
@@ -47,7 +47,7 @@ def exif_size(img):
|
|
47 |
|
48 |
|
49 |
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
50 |
-
rank=-1, world_size=1):
|
51 |
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
|
52 |
with torch_distributed_zero_first(rank):
|
53 |
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
@@ -61,7 +61,7 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
|
|
61 |
rank=rank)
|
62 |
|
63 |
batch_size = min(batch_size, len(dataset))
|
64 |
-
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, 8]) # number of workers
|
65 |
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
66 |
dataloader = torch.utils.data.DataLoader(dataset,
|
67 |
batch_size=batch_size,
|
|
|
47 |
|
48 |
|
49 |
def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=False, cache=False, pad=0.0, rect=False,
|
50 |
+
rank=-1, world_size=1, workers=8):
|
51 |
# Make sure only the first process in DDP process the dataset first, and the following others can use the cache.
|
52 |
with torch_distributed_zero_first(rank):
|
53 |
dataset = LoadImagesAndLabels(path, imgsz, batch_size,
|
|
|
61 |
rank=rank)
|
62 |
|
63 |
batch_size = min(batch_size, len(dataset))
|
64 |
+
nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
|
65 |
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
|
66 |
dataloader = torch.utils.data.DataLoader(dataset,
|
67 |
batch_size=batch_size,
|