glenn-jocher commited on
Commit
12499f1
1 Parent(s): 9728e2b

--image_weights bug fix (#1524)

Browse files
Files changed (1) hide show
  1. utils/datasets.py +8 -6
utils/datasets.py CHANGED
@@ -72,12 +72,14 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
72
  batch_size = min(batch_size, len(dataset))
73
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
74
  sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
75
- dataloader = InfiniteDataLoader(dataset,
76
- batch_size=batch_size,
77
- num_workers=nw,
78
- sampler=sampler,
79
- pin_memory=True,
80
- collate_fn=LoadImagesAndLabels.collate_fn) # torch.utils.data.DataLoader()
 
 
81
  return dataloader, dataset
82
 
83
 
 
72
  batch_size = min(batch_size, len(dataset))
73
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
74
  sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
75
+ loader = torch.utils.data.DataLoader if image_weights else InfiniteDataLoader
76
+ # Use torch.utils.data.DataLoader() if dataset.properties will update during training else InfiniteDataLoader()
77
+ dataloader = loader(dataset,
78
+ batch_size=batch_size,
79
+ num_workers=nw,
80
+ sampler=sampler,
81
+ pin_memory=True,
82
+ collate_fn=LoadImagesAndLabels.collate_fn)
83
  return dataloader, dataset
84
 
85