aapot commited on
Commit
d2f0aae
1 Parent(s): 6c26cc3
Files changed (1) hide show
  1. EasyLM/data.py +4 -8
EasyLM/data.py CHANGED
@@ -175,13 +175,14 @@ class HuggingfaceDataset(object):
175
  self._index = 0
176
 
177
  def __iter__(self):
 
 
178
  chunk_size = self.config.batch_size * self.config.seq_length
179
- total_tokens = 0
180
  while True:
181
  token_buffer = []
182
  loss_mask_buffer = []
183
- if not self._eval_dataset:
184
- self._shuffle()
185
  for index, example in enumerate(self._dataset):
186
  self._index = index
187
  if not self._eval_dataset and self._dataset_loc > index:
@@ -217,12 +218,7 @@ class HuggingfaceDataset(object):
217
  break
218
  else:
219
  self._dataset_loc = 0
220
- self._shuffle()
221
  self._train_epochs += 1
222
- print(f"TRAIN {self._train_epochs} EPOCH DONE")
223
-
224
- def _shuffle(self):
225
- self._dataset = self._dataset.shuffle(buffer_size=100)
226
 
227
  def get_state_dict(self):
228
  return dict(
 
175
  self._index = 0
176
 
177
  def __iter__(self):
178
+ if not self._eval_dataset and self._train_epochs > 0:
179
+ self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
180
  chunk_size = self.config.batch_size * self.config.seq_length
 
181
  while True:
182
  token_buffer = []
183
  loss_mask_buffer = []
184
+ if not self._eval_dataset and self._train_epochs > 0:
185
+ self._dataset.set_epoch(self._train_epochs)
186
  for index, example in enumerate(self._dataset):
187
  self._index = index
188
  if not self._eval_dataset and self._dataset_loc > index:
 
218
  break
219
  else:
220
  self._dataset_loc = 0
 
221
  self._train_epochs += 1
 
 
 
 
222
 
223
  def get_state_dict(self):
224
  return dict(