boris committed on
Commit
8b72ed8
1 Parent(s): adbdff9

fix(train): handle seed_dataset

Browse files
Files changed (2) hide show
  1. src/dalle_mini/data.py +3 -8
  2. tools/train/train.py +2 -2
src/dalle_mini/data.py CHANGED
@@ -161,7 +161,7 @@ class Dataset:
161
  ):
162
  """
163
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
164
- Shuffle batches if `shuffle` is `True`.
165
  """
166
  steps_per_epoch = len(dataset) // batch_size
167
 
@@ -184,17 +184,13 @@ class Dataset:
184
  def _dataloader_datasets_streaming(
185
  dataset: Dataset, batch_size: int, epoch: int
186
  ):
187
- # epoch is only use for multi-host
188
  keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
189
  batch = {k: [] for k in keys}
190
  first_loop = True
191
  while self.multi_hosts or first_loop:
192
  # in multi-host, we run forever (no epoch) as hosts need to stop
193
  # at the same time and we don't know how much data is on each host
194
- if not first_loop:
195
- # multi-host setting, we reshuffle shards
196
- epoch += 1
197
- dataset.set_epoch(epoch)
198
  for item in dataset:
199
  for k, v in item.items():
200
  batch[k].append(v)
@@ -203,6 +199,7 @@ class Dataset:
203
  batch = shard(batch)
204
  yield batch
205
  batch = {k: [] for k in keys}
 
206
  first_loop = False
207
 
208
  if split == "train":
@@ -213,8 +210,6 @@ class Dataset:
213
  raise ValueError(f'split must be "train" or "eval", got {split}')
214
 
215
  if self.streaming:
216
- if split == "train":
217
- ds.set_epoch(epoch)
218
  return _dataloader_datasets_streaming(ds, batch_size, epoch)
219
  else:
220
  if split == "train":
 
161
  ):
162
  """
163
  Returns batches of size `batch_size` from truncated `dataset`, sharded over all local devices.
164
+ Shuffle batches if rng is set.
165
  """
166
  steps_per_epoch = len(dataset) // batch_size
167
 
 
184
  def _dataloader_datasets_streaming(
185
  dataset: Dataset, batch_size: int, epoch: int
186
  ):
 
187
  keys = ["input_ids", "attention_mask", "labels", "decoder_input_ids"]
188
  batch = {k: [] for k in keys}
189
  first_loop = True
190
  while self.multi_hosts or first_loop:
191
  # in multi-host, we run forever (no epoch) as hosts need to stop
192
  # at the same time and we don't know how much data is on each host
193
+ dataset.set_epoch(epoch) # reshuffle data at each epoch
 
 
 
194
  for item in dataset:
195
  for k, v in item.items():
196
  batch[k].append(v)
 
199
  batch = shard(batch)
200
  yield batch
201
  batch = {k: [] for k in keys}
202
+ epoch += 1
203
  first_loop = False
204
 
205
  if split == "train":
 
210
  raise ValueError(f'split must be "train" or "eval", got {split}')
211
 
212
  if self.streaming:
 
 
213
  return _dataloader_datasets_streaming(ds, batch_size, epoch)
214
  else:
215
  if split == "train":
tools/train/train.py CHANGED
@@ -241,7 +241,7 @@ class TrainingArguments:
241
  )
242
  optim_quantized: bool = field(
243
  default=False,
244
- metadat={
245
  "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
246
  },
247
  )
@@ -845,7 +845,7 @@ def main():
845
  metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
846
 
847
  # Generate an epoch by shuffling sampling indices from the train dataset
848
- train_loader = dataset.dataloader("train", train_batch_size)
849
  # train
850
  for batch in tqdm(
851
  train_loader,
 
241
  )
242
  optim_quantized: bool = field(
243
  default=False,
244
+ metadata={
245
  "help": "Whether to quantize optimizer (only supported with distributed_shampoo)."
246
  },
247
  )
 
845
  metrics_logger.log({"train/epoch": epoch}, step=unreplicate(state.step))
846
 
847
  # Generate an epoch by shuffling sampling indices from the train dataset
848
+ train_loader = dataset.dataloader("train", train_batch_size, epoch)
849
  # train
850
  for batch in tqdm(
851
  train_loader,