boris committed on
Commit
ddcbc6a
1 Parent(s): e2781bc

fix(data): no shuffling of validation data

Browse files
Files changed (1) hide show
  1. src/dalle_mini/data.py +11 -7
src/dalle_mini/data.py CHANGED
@@ -182,15 +182,20 @@ class Dataset:
182
  yield batch
183
 
184
  def _dataloader_datasets_streaming(
185
- dataset: Dataset, batch_size: int, epoch: int
186
  ):
187
  keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
188
  batch = {k: [] for k in keys}
189
- first_loop = True
190
- while self.multi_hosts or first_loop:
191
  # in multi-host, we run forever (no epoch) as hosts need to stop
192
- # at the same time and we don't know how much data is on each host
193
- dataset.set_epoch(epoch) # reshuffle data at each epoch
 
 
 
 
 
194
  for item in dataset:
195
  for k, v in item.items():
196
  batch[k].append(v)
@@ -199,7 +204,6 @@ class Dataset:
199
  batch = shard(batch)
200
  yield batch
201
  batch = {k: [] for k in keys}
202
- epoch += 1
203
  first_loop = False
204
 
205
  if split == "train":
@@ -210,7 +214,7 @@ class Dataset:
210
  raise ValueError(f'split must be "train" or "eval", got {split}')
211
 
212
  if self.streaming:
213
- return _dataloader_datasets_streaming(ds, batch_size, epoch)
214
  else:
215
  if split == "train":
216
  self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
 
182
  yield batch
183
 
184
  def _dataloader_datasets_streaming(
185
+ dataset: Dataset, split: str, batch_size: int, epoch: int
186
  ):
187
  keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
188
  batch = {k: [] for k in keys}
189
+ first_loop = True # stop after one loop in some cases
190
+ while (self.multi_hosts and split == "train") or first_loop:
191
  # in multi-host, we run forever (no epoch) as hosts need to stop
192
+ # at the same time and training data may not be split equally
193
+ # For validation data we put the entire set on each host as we could lose
194
+ # too many samples on pods
195
+ if epoch is not None:
196
+ # reshuffle training data at each epoch (not applicable with validation set)
197
+ dataset.set_epoch(epoch)
198
+ epoch += 1
199
  for item in dataset:
200
  for k, v in item.items():
201
  batch[k].append(v)
 
204
  batch = shard(batch)
205
  yield batch
206
  batch = {k: [] for k in keys}
 
207
  first_loop = False
208
 
209
  if split == "train":
 
214
  raise ValueError(f'split must be "train" or "eval", got {split}')
215
 
216
  if self.streaming:
217
+ return _dataloader_datasets_streaming(ds, split, batch_size, epoch)
218
  else:
219
  if split == "train":
220
  self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)