aapot
commited on
Commit
•
d2f0aae
1
Parent(s):
6c26cc3
fix
Browse files- EasyLM/data.py +4 -8
EasyLM/data.py
CHANGED
@@ -175,13 +175,14 @@ class HuggingfaceDataset(object):
|
|
175 |
self._index = 0
|
176 |
|
177 |
def __iter__(self):
|
|
|
|
|
178 |
chunk_size = self.config.batch_size * self.config.seq_length
|
179 |
-
total_tokens = 0
|
180 |
while True:
|
181 |
token_buffer = []
|
182 |
loss_mask_buffer = []
|
183 |
-
if not self._eval_dataset:
|
184 |
-
self.
|
185 |
for index, example in enumerate(self._dataset):
|
186 |
self._index = index
|
187 |
if not self._eval_dataset and self._dataset_loc > index:
|
@@ -217,12 +218,7 @@ class HuggingfaceDataset(object):
|
|
217 |
break
|
218 |
else:
|
219 |
self._dataset_loc = 0
|
220 |
-
self._shuffle()
|
221 |
self._train_epochs += 1
|
222 |
-
print(f"TRAIN {self._train_epochs} EPOCH DONE")
|
223 |
-
|
224 |
-
def _shuffle(self):
|
225 |
-
self._dataset = self._dataset.shuffle(buffer_size=100)
|
226 |
|
227 |
def get_state_dict(self):
|
228 |
return dict(
|
|
|
175 |
self._index = 0
|
176 |
|
177 |
def __iter__(self):
|
178 |
+
if not self._eval_dataset and self._train_epochs > 0:
|
179 |
+
self._dataset = self._dataset.shuffle(seed=42, buffer_size=10000)
|
180 |
chunk_size = self.config.batch_size * self.config.seq_length
|
|
|
181 |
while True:
|
182 |
token_buffer = []
|
183 |
loss_mask_buffer = []
|
184 |
+
if not self._eval_dataset and self._train_epochs > 0:
|
185 |
+
self._dataset.set_epoch(self._train_epochs)
|
186 |
for index, example in enumerate(self._dataset):
|
187 |
self._index = index
|
188 |
if not self._eval_dataset and self._dataset_loc > index:
|
|
|
218 |
break
|
219 |
else:
|
220 |
self._dataset_loc = 0
|
|
|
221 |
self._train_epochs += 1
|
|
|
|
|
|
|
|
|
222 |
|
223 |
def get_state_dict(self):
|
224 |
return dict(
|