winglian commited on
Commit
5bce45f
·
unverified ·
1 Parent(s): d85d494

more dpo fixes for dataset loading and docs (#1185) [skip ci]

Browse files

* more dpo fixes for dataset loading and docs

* preprocess dpo datasets

docs/rlhf.md CHANGED
@@ -34,6 +34,16 @@ datasets:
34
  rl: ipo
35
  ```
36
 
 
 
 
 
 
 
 
 
 
 
37
  #### Trl autounwrap for peft
38
 
39
  Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
 
34
  rl: ipo
35
  ```
36
 
37
+ #### Using local dataset files
38
+ ```yaml
39
+ datasets:
40
+ - ds_type: json
41
+ data_files:
42
+ - orca_rlhf.jsonl
43
+ split: train
44
+ type: chatml.intel
45
+ ```
46
+
47
  #### Trl autounwrap for peft
48
 
49
  Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
src/axolotl/cli/preprocess.py CHANGED
@@ -13,6 +13,7 @@ from axolotl.cli import (
13
  check_user_token,
14
  load_cfg,
15
  load_datasets,
 
16
  print_axolotl_text_art,
17
  )
18
  from axolotl.common.cli import PreprocessCliArgs
@@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
43
  LOG.warning(msg)
44
  parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
45
 
46
- _ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
 
 
 
 
47
  LOG.info(
48
  Fore.GREEN
49
  + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
 
13
  check_user_token,
14
  load_cfg,
15
  load_datasets,
16
+ load_rl_datasets,
17
  print_axolotl_text_art,
18
  )
19
  from axolotl.common.cli import PreprocessCliArgs
 
44
  LOG.warning(msg)
45
  parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
46
 
47
+ if parsed_cfg.rl:
48
+ load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
49
+ else:
50
+ load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
51
+
52
  LOG.info(
53
  Fore.GREEN
54
  + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
src/axolotl/core/trainer_builder.py CHANGED
@@ -996,6 +996,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
996
  training_args_kwargs["lr_scheduler_kwargs"] = (
997
  self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
998
  )
 
 
 
 
 
 
999
 
1000
  if self.cfg.dataloader_pin_memory is not None:
1001
  training_args_kwargs[
@@ -1013,7 +1019,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
1013
  training_args = TrainingArguments(
1014
  per_device_train_batch_size=self.cfg.micro_batch_size,
1015
  max_steps=self.cfg.max_steps or total_num_steps,
1016
- remove_unused_columns=False,
1017
  gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
1018
  learning_rate=self.cfg.learning_rate,
1019
  save_strategy="steps",
 
996
  training_args_kwargs["lr_scheduler_kwargs"] = (
997
  self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
998
  )
999
+ if self.cfg.remove_unused_columns is not None:
1000
+ training_args_kwargs[
1001
+ "remove_unused_columns"
1002
+ ] = self.cfg.remove_unused_columns
1003
+ else:
1004
+ training_args_kwargs["remove_unused_columns"] = False
1005
 
1006
  if self.cfg.dataloader_pin_memory is not None:
1007
  training_args_kwargs[
 
1019
  training_args = TrainingArguments(
1020
  per_device_train_batch_size=self.cfg.micro_batch_size,
1021
  max_steps=self.cfg.max_steps or total_num_steps,
 
1022
  gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
1023
  learning_rate=self.cfg.learning_rate,
1024
  save_strategy="steps",
src/axolotl/utils/data.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  from typing import Any, Dict, List, Optional, Tuple, Union
8
 
9
  import torch
 
10
  from datasets import (
11
  Dataset,
12
  DatasetDict,
@@ -853,6 +854,41 @@ def encode_packed_pretraining(
853
  return chunked_data
854
 
855
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
856
  def load_prepare_dpo_datasets(cfg):
857
  def load_split(dataset_cfgs, _cfg):
858
  split_datasets: List[Any] = []
@@ -889,12 +925,25 @@ def load_prepare_dpo_datasets(cfg):
889
  return concatenate_datasets(split_datasets)
890
 
891
  with zero_first(is_main_process()):
892
- train_dataset = load_split(cfg.datasets, cfg)
 
 
 
 
 
893
 
894
  eval_dataset = None
895
  if cfg.test_datasets:
896
- eval_dataset = load_split(cfg.test_datasets, cfg)
 
 
 
897
  if not eval_dataset:
898
  eval_dataset = None
899
 
 
 
 
 
 
900
  return train_dataset, eval_dataset
 
7
  from typing import Any, Dict, List, Optional, Tuple, Union
8
 
9
  import torch
10
+ import yaml
11
  from datasets import (
12
  Dataset,
13
  DatasetDict,
 
854
  return chunked_data
855
 
856
 
857
+ def _get_path(ds_hash, cfg):
858
+ prepared_ds_path = (
859
+ Path(cfg.dataset_prepared_path) / ds_hash
860
+ if cfg.dataset_prepared_path
861
+ else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
862
+ )
863
+
864
+ return prepared_ds_path
865
+
866
+
867
+ def _load_preprocessed_ds(cfg, sub_cfg):
868
+ ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
869
+ prepared_ds_path = _get_path(ds_hash, cfg)
870
+ dataset = None
871
+
872
+ if (
873
+ cfg.dataset_prepared_path
874
+ and any(prepared_ds_path.glob("*"))
875
+ and not cfg.is_preprocess
876
+ ):
877
+ LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
878
+ dataset = load_from_disk(str(prepared_ds_path))
879
+
880
+ return dataset
881
+
882
+
883
+ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
884
+ ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
885
+ prepared_ds_path = _get_path(ds_hash, cfg)
886
+
887
+ if cfg.is_preprocess and is_main_process():
888
+ LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
889
+ dataset.save_to_disk(str(prepared_ds_path))
890
+
891
+
892
  def load_prepare_dpo_datasets(cfg):
893
  def load_split(dataset_cfgs, _cfg):
894
  split_datasets: List[Any] = []
 
925
  return concatenate_datasets(split_datasets)
926
 
927
  with zero_first(is_main_process()):
928
+ train_is_preprocessed = False
929
+ eval_is_preprocessed = False
930
+ if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
931
+ train_is_preprocessed = True
932
+ else:
933
+ train_dataset = load_split(cfg.datasets, cfg)
934
 
935
  eval_dataset = None
936
  if cfg.test_datasets:
937
+ if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
938
+ eval_is_preprocessed = True
939
+ else:
940
+ eval_dataset = load_split(cfg.test_datasets, cfg)
941
  if not eval_dataset:
942
  eval_dataset = None
943
 
944
+ if not train_is_preprocessed:
945
+ _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
946
+ if eval_dataset and not eval_is_preprocessed:
947
+ _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
948
+
949
  return train_dataset, eval_dataset