glenn-jocher commited on
Commit
d3f9bf2
1 Parent(s): 901243c

Update datasets.py

Browse files
Files changed (1) hide show
  1. utils/datasets.py +18 -25
utils/datasets.py CHANGED
@@ -62,26 +62,25 @@ def create_dataloader(path, imgsz, batch_size, stride, opt, hyp=None, augment=Fa
62
 
63
  batch_size = min(batch_size, len(dataset))
64
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
65
- train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
66
- dataloader = InfiniteDataLoader (dataset,
67
  batch_size=batch_size,
68
  num_workers=nw,
69
- sampler=train_sampler,
70
  pin_memory=True,
71
  collate_fn=LoadImagesAndLabels.collate_fn)
72
  return dataloader, dataset
73
 
74
 
75
  class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
76
- '''
77
- Dataloader that reuses workers.
78
 
79
  Uses same syntax as vanilla DataLoader.
80
- '''
81
 
82
  def __init__(self, *args, **kwargs):
83
  super().__init__(*args, **kwargs)
84
- object.__setattr__(self, 'batch_sampler', _RepeatSampler(self.batch_sampler))
85
  self.iterator = super().__iter__()
86
 
87
  def __len__(self):
@@ -91,22 +90,20 @@ class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
91
  for i in range(len(self)):
92
  yield next(self.iterator)
93
 
 
 
94
 
95
- class _RepeatSampler(object):
96
- '''
97
- Sampler that repeats forever.
98
 
99
- Args:
100
- sampler (Sampler)
101
- '''
102
 
103
- def __init__(self, sampler):
104
- self.sampler = sampler
 
105
 
106
- def __iter__(self):
107
- while True:
108
- yield from iter(self.sampler)
109
-
110
 
111
  class LoadImages: # for inference
112
  def __init__(self, path, img_size=640):
@@ -684,14 +681,10 @@ def load_mosaic(self, index):
684
  # Concat/clip labels
685
  if len(labels4):
686
  labels4 = np.concatenate(labels4, 0)
687
- # np.clip(labels4[:, 1:] - s / 2, 0, s, out=labels4[:, 1:]) # use with center crop
688
- np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_affine
689
-
690
- # Replicate
691
- # img4, labels4 = replicate(img4, labels4)
692
 
693
  # Augment
694
- # img4 = img4[s // 2: int(s * 1.5), s // 2:int(s * 1.5)] # center crop (WARNING, requires box pruning)
695
  img4, labels4 = random_perspective(img4, labels4,
696
  degrees=self.hyp['degrees'],
697
  translate=self.hyp['translate'],
 
62
 
63
  batch_size = min(batch_size, len(dataset))
64
  nw = min([os.cpu_count() // world_size, batch_size if batch_size > 1 else 0, workers]) # number of workers
65
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset) if rank != -1 else None
66
+ dataloader = InfiniteDataLoader(dataset,
67
  batch_size=batch_size,
68
  num_workers=nw,
69
+ sampler=sampler,
70
  pin_memory=True,
71
  collate_fn=LoadImagesAndLabels.collate_fn)
72
  return dataloader, dataset
73
 
74
 
75
  class InfiniteDataLoader(torch.utils.data.dataloader.DataLoader):
76
+ """ Dataloader that reuses workers.
 
77
 
78
  Uses same syntax as vanilla DataLoader.
79
+ """
80
 
81
  def __init__(self, *args, **kwargs):
82
  super().__init__(*args, **kwargs)
83
+ object.__setattr__(self, 'batch_sampler', self._RepeatSampler(self.batch_sampler))
84
  self.iterator = super().__iter__()
85
 
86
  def __len__(self):
 
90
  for i in range(len(self)):
91
  yield next(self.iterator)
92
 
93
+ class _RepeatSampler(object):
94
+ """ Sampler that repeats forever.
95
 
96
+ Args:
97
+ sampler (Sampler)
98
+ """
99
 
100
+ def __init__(self, sampler):
101
+ self.sampler = sampler
 
102
 
103
+ def __iter__(self):
104
+ while True:
105
+ yield from iter(self.sampler)
106
 
 
 
 
 
107
 
108
  class LoadImages: # for inference
109
  def __init__(self, path, img_size=640):
 
681
  # Concat/clip labels
682
  if len(labels4):
683
  labels4 = np.concatenate(labels4, 0)
684
+ np.clip(labels4[:, 1:], 0, 2 * s, out=labels4[:, 1:]) # use with random_perspective
685
+ # img4, labels4 = replicate(img4, labels4) # replicate
 
 
 
686
 
687
  # Augment
 
688
  img4, labels4 = random_perspective(img4, labels4,
689
  degrees=self.hyp['degrees'],
690
  translate=self.hyp['translate'],