Add a config not to shuffle merged dataset (#1394) [skip ci]
* Add a config not to shuffle merged dataset
* Update README.md
* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* invert the condition name
* update README
* info -> debug

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
README.md CHANGED

```diff
@@ -678,6 +678,10 @@ datasets:
     # For `completion` datsets only, uses the provided field instead of `text` column
     field:
 
+# If false, the datasets will not be shuffled and will keep their original order in `datasets`.
+# The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
+shuffle_merged_datasets: true
+
 # A list of one or more datasets to eval the model with.
 # You can use either test_datasets, or val_set_size, but not both.
 test_datasets:
```
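For illustration, a minimal config excerpt that opts out of shuffling; the dataset paths and types below are placeholders, not part of this commit:

```yaml
datasets:
  - path: org/first_dataset    # placeholder
    type: alpaca
  - path: org/second_dataset   # placeholder
    type: alpaca

# Keep the merged examples in the order the datasets are listed above.
shuffle_merged_datasets: false
```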
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED

```diff
@@ -416,6 +416,7 @@ class AxolotlInputConfig(
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None  # type: ignore
+    shuffle_merged_datasets: Optional[bool] = True
     dataset_prepared_path: Optional[str] = None
     dataset_shard_num: Optional[int] = None
     dataset_shard_idx: Optional[int] = None
```
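As a rough sketch of how this field behaves, here is a standalone pydantic model carrying only the new attribute; the `MiniConfig` class is hypothetical (`AxolotlInputConfig` has many more fields):

```python
from typing import Optional

from pydantic import BaseModel


class MiniConfig(BaseModel):
    """Hypothetical stand-in for AxolotlInputConfig with just the new field."""

    shuffle_merged_datasets: Optional[bool] = True


print(MiniConfig().shuffle_merged_datasets)  # True -> shuffling stays on by default
print(MiniConfig(shuffle_merged_datasets=False).shuffle_merged_datasets)  # False -> opt out
```

Defaulting to `True` means existing configs keep the previous shuffling behavior unchanged.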
src/axolotl/utils/data.py CHANGED

```diff
@@ -415,8 +415,11 @@ def load_tokenized_prepared_datasets(
     dataset = concatenate_datasets(datasets)
 
     if len(datasets) > 1:
-        LOG.info("shuffle merged datasets")
-        dataset = dataset.shuffle(seed=seed)
+        if cfg.shuffle_merged_datasets:
+            LOG.debug("shuffle merged datasets")
+            dataset = dataset.shuffle(seed=seed)
+        else:
+            LOG.debug("NOT shuffling merged datasets")
 
     dataset, _ = process_datasets_for_packing(cfg, dataset, None)
 
```
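To see what the branch above changes, here is a self-contained sketch using Hugging Face `datasets`; the toy rows and the `shuffle_merged` flag stand in for axolotl's real config plumbing:

```python
from datasets import Dataset, concatenate_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2"]})

# Merged order is deterministic: all of ds_a, then all of ds_b.
merged = concatenate_datasets([ds_a, ds_b])

shuffle_merged = False  # stand-in for cfg.shuffle_merged_datasets
if shuffle_merged:
    merged = merged.shuffle(seed=42)  # interleaves rows reproducibly

print(merged["text"])  # with shuffling off: ['a1', 'a2', 'b1', 'b2']
```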
```diff
@@ -819,7 +822,11 @@ def wrap_pretraining_dataset(
     else:
         encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
 
-    dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    if cfg.shuffle_merged_datasets:
+        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    else:
+        LOG.debug("NOT shuffling merged pretraining datasets")
+
     dataset = dataset.map(
         encode,
         batched=True,
```
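The pretraining path differs because the dataset is streamed: `IterableDataset.shuffle` can only approximate a global shuffle by sampling from a fixed-size buffer, which is why it takes a `buffer_size`. A sketch with toy values (the buffer size and seed are illustrative):

```python
from datasets import Dataset

# Build a small streaming dataset for demonstration.
stream = Dataset.from_dict({"n": list(range(10))}).to_iterable_dataset()

# Approximate shuffle: examples are drawn from a buffer of
# `buffer_size` items, so larger buffers mean better mixing.
shuffled = stream.shuffle(seed=42, buffer_size=4)

print([ex["n"] for ex in shuffled])
```

Skipping this call leaves the stream in source order, which is exactly what the new flag allows.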