seungduk winglian committed on
Commit
43bdc5d
1 Parent(s): b1e3e1b

Add a config not to shuffle merged dataset (#1394) [skip ci]

Browse files

* Add a config not to shuffle merged dataset

* Update README.md

* Update src/axolotl/utils/config/models/input/v0_4_1/__init__.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

* invert the condition name

* update README

* info -> debug

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

README.md CHANGED
@@ -678,6 +678,10 @@ datasets:
678
 # For `completion` datasets only, uses the provided field instead of `text` column
679
  field:
680
 
 
 
 
 
681
  # A list of one or more datasets to eval the model with.
682
  # You can use either test_datasets, or val_set_size, but not both.
683
  test_datasets:
 
678
 # For `completion` datasets only, uses the provided field instead of `text` column
679
  field:
680
 
681
+ # If false, the datasets will not be shuffled and will keep their original order in `datasets`.
682
+ # The same applies to the `test_datasets` option and the `pretraining_dataset` option. Default is true.
683
+ shuffle_merged_datasets: true
684
+
685
  # A list of one or more datasets to eval the model with.
686
  # You can use either test_datasets, or val_set_size, but not both.
687
  test_datasets:
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -416,6 +416,7 @@ class AxolotlInputConfig(
416
 
417
  datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
418
  test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
 
419
  dataset_prepared_path: Optional[str] = None
420
  dataset_shard_num: Optional[int] = None
421
  dataset_shard_idx: Optional[int] = None
 
416
 
417
  datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
418
  test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
419
+ shuffle_merged_datasets: Optional[bool] = True
420
  dataset_prepared_path: Optional[str] = None
421
  dataset_shard_num: Optional[int] = None
422
  dataset_shard_idx: Optional[int] = None
src/axolotl/utils/data.py CHANGED
@@ -415,8 +415,11 @@ def load_tokenized_prepared_datasets(
415
  dataset = concatenate_datasets(datasets)
416
 
417
  if len(datasets) > 1:
418
- LOG.info("shuffle merged datasets")
419
- dataset = dataset.shuffle(seed=seed)
 
 
 
420
 
421
  dataset, _ = process_datasets_for_packing(cfg, dataset, None)
422
 
@@ -819,7 +822,11 @@ def wrap_pretraining_dataset(
819
  else:
820
  encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
821
 
822
- dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
 
 
 
 
823
  dataset = dataset.map(
824
  encode,
825
  batched=True,
 
415
  dataset = concatenate_datasets(datasets)
416
 
417
  if len(datasets) > 1:
418
+ if cfg.shuffle_merged_datasets:
419
+ LOG.debug("shuffle merged datasets")
420
+ dataset = dataset.shuffle(seed=seed)
421
+ else:
422
+ LOG.debug("NOT shuffling merged datasets")
423
 
424
  dataset, _ = process_datasets_for_packing(cfg, dataset, None)
425
 
 
822
  else:
823
  encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
824
 
825
+ if cfg.shuffle_merged_datasets:
826
+ dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
827
+ else:
828
+ LOG.debug("NOT shuffling merged pretraining datasets")
829
+
830
  dataset = dataset.map(
831
  encode,
832
  batched=True,