more dpo fixes for dataset loading and docs (#1185) [skip ci]
Browse files* more dpo fixes for dataset loading and docs
* preprocess dpo datasets
- docs/rlhf.md +10 -0
- src/axolotl/cli/preprocess.py +6 -1
- src/axolotl/core/trainer_builder.py +6 -1
- src/axolotl/utils/data.py +51 -2
docs/rlhf.md
CHANGED
@@ -34,6 +34,16 @@ datasets:
|
|
34 |
rl: ipo
|
35 |
```
|
36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
37 |
#### Trl autounwrap for peft
|
38 |
|
39 |
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
|
|
|
34 |
rl: ipo
|
35 |
```
|
36 |
|
37 |
+
#### Using local dataset files
|
38 |
+
```yaml
|
39 |
+
datasets:
|
40 |
+
- ds_type: json
|
41 |
+
data_files:
|
42 |
+
- orca_rlhf.jsonl
|
43 |
+
split: train
|
44 |
+
type: chatml.intel
|
45 |
+
```
|
46 |
+
|
47 |
#### Trl autounwrap for peft
|
48 |
|
49 |
Trl supports autounwrapping peft models, so that a ref model does not need to be additionally loaded, leading to less VRAM needed. This is on by default. To turn it off, pass the following config.
|
src/axolotl/cli/preprocess.py
CHANGED
@@ -13,6 +13,7 @@ from axolotl.cli import (
|
|
13 |
check_user_token,
|
14 |
load_cfg,
|
15 |
load_datasets,
|
|
|
16 |
print_axolotl_text_art,
|
17 |
)
|
18 |
from axolotl.common.cli import PreprocessCliArgs
|
@@ -43,7 +44,11 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|
43 |
LOG.warning(msg)
|
44 |
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
45 |
|
46 |
-
|
|
|
|
|
|
|
|
|
47 |
LOG.info(
|
48 |
Fore.GREEN
|
49 |
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
|
|
13 |
check_user_token,
|
14 |
load_cfg,
|
15 |
load_datasets,
|
16 |
+
load_rl_datasets,
|
17 |
print_axolotl_text_art,
|
18 |
)
|
19 |
from axolotl.common.cli import PreprocessCliArgs
|
|
|
44 |
LOG.warning(msg)
|
45 |
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
46 |
|
47 |
+
if parsed_cfg.rl:
|
48 |
+
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
49 |
+
else:
|
50 |
+
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
51 |
+
|
52 |
LOG.info(
|
53 |
Fore.GREEN
|
54 |
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
src/axolotl/core/trainer_builder.py
CHANGED
@@ -996,6 +996,12 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|
996 |
training_args_kwargs["lr_scheduler_kwargs"] = (
|
997 |
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
998 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
999 |
|
1000 |
if self.cfg.dataloader_pin_memory is not None:
|
1001 |
training_args_kwargs[
|
@@ -1013,7 +1019,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|
1013 |
training_args = TrainingArguments(
|
1014 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
1015 |
max_steps=self.cfg.max_steps or total_num_steps,
|
1016 |
-
remove_unused_columns=False,
|
1017 |
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
1018 |
learning_rate=self.cfg.learning_rate,
|
1019 |
save_strategy="steps",
|
|
|
996 |
training_args_kwargs["lr_scheduler_kwargs"] = (
|
997 |
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
998 |
)
|
999 |
+
if self.cfg.remove_unused_columns is not None:
|
1000 |
+
training_args_kwargs[
|
1001 |
+
"remove_unused_columns"
|
1002 |
+
] = self.cfg.remove_unused_columns
|
1003 |
+
else:
|
1004 |
+
training_args_kwargs["remove_unused_columns"] = False
|
1005 |
|
1006 |
if self.cfg.dataloader_pin_memory is not None:
|
1007 |
training_args_kwargs[
|
|
|
1019 |
training_args = TrainingArguments(
|
1020 |
per_device_train_batch_size=self.cfg.micro_batch_size,
|
1021 |
max_steps=self.cfg.max_steps or total_num_steps,
|
|
|
1022 |
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
1023 |
learning_rate=self.cfg.learning_rate,
|
1024 |
save_strategy="steps",
|
src/axolotl/utils/data.py
CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
|
|
7 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
8 |
|
9 |
import torch
|
|
|
10 |
from datasets import (
|
11 |
Dataset,
|
12 |
DatasetDict,
|
@@ -853,6 +854,41 @@ def encode_packed_pretraining(
|
|
853 |
return chunked_data
|
854 |
|
855 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
856 |
def load_prepare_dpo_datasets(cfg):
|
857 |
def load_split(dataset_cfgs, _cfg):
|
858 |
split_datasets: List[Any] = []
|
@@ -889,12 +925,25 @@ def load_prepare_dpo_datasets(cfg):
|
|
889 |
return concatenate_datasets(split_datasets)
|
890 |
|
891 |
with zero_first(is_main_process()):
|
892 |
-
|
|
|
|
|
|
|
|
|
|
|
893 |
|
894 |
eval_dataset = None
|
895 |
if cfg.test_datasets:
|
896 |
-
eval_dataset
|
|
|
|
|
|
|
897 |
if not eval_dataset:
|
898 |
eval_dataset = None
|
899 |
|
|
|
|
|
|
|
|
|
|
|
900 |
return train_dataset, eval_dataset
|
|
|
7 |
from typing import Any, Dict, List, Optional, Tuple, Union
|
8 |
|
9 |
import torch
|
10 |
+
import yaml
|
11 |
from datasets import (
|
12 |
Dataset,
|
13 |
DatasetDict,
|
|
|
854 |
return chunked_data
|
855 |
|
856 |
|
857 |
+
def _get_path(ds_hash, cfg):
|
858 |
+
prepared_ds_path = (
|
859 |
+
Path(cfg.dataset_prepared_path) / ds_hash
|
860 |
+
if cfg.dataset_prepared_path
|
861 |
+
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
862 |
+
)
|
863 |
+
|
864 |
+
return prepared_ds_path
|
865 |
+
|
866 |
+
|
867 |
+
def _load_preprocessed_ds(cfg, sub_cfg):
|
868 |
+
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
869 |
+
prepared_ds_path = _get_path(ds_hash, cfg)
|
870 |
+
dataset = None
|
871 |
+
|
872 |
+
if (
|
873 |
+
cfg.dataset_prepared_path
|
874 |
+
and any(prepared_ds_path.glob("*"))
|
875 |
+
and not cfg.is_preprocess
|
876 |
+
):
|
877 |
+
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
878 |
+
dataset = load_from_disk(str(prepared_ds_path))
|
879 |
+
|
880 |
+
return dataset
|
881 |
+
|
882 |
+
|
883 |
+
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
884 |
+
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
885 |
+
prepared_ds_path = _get_path(ds_hash, cfg)
|
886 |
+
|
887 |
+
if cfg.is_preprocess and is_main_process():
|
888 |
+
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
889 |
+
dataset.save_to_disk(str(prepared_ds_path))
|
890 |
+
|
891 |
+
|
892 |
def load_prepare_dpo_datasets(cfg):
|
893 |
def load_split(dataset_cfgs, _cfg):
|
894 |
split_datasets: List[Any] = []
|
|
|
925 |
return concatenate_datasets(split_datasets)
|
926 |
|
927 |
with zero_first(is_main_process()):
|
928 |
+
train_is_preprocessed = False
|
929 |
+
eval_is_preprocessed = False
|
930 |
+
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
931 |
+
train_is_preprocessed = True
|
932 |
+
else:
|
933 |
+
train_dataset = load_split(cfg.datasets, cfg)
|
934 |
|
935 |
eval_dataset = None
|
936 |
if cfg.test_datasets:
|
937 |
+
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
938 |
+
eval_is_preprocessed = True
|
939 |
+
else:
|
940 |
+
eval_dataset = load_split(cfg.test_datasets, cfg)
|
941 |
if not eval_dataset:
|
942 |
eval_dataset = None
|
943 |
|
944 |
+
if not train_is_preprocessed:
|
945 |
+
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
946 |
+
if eval_dataset and not eval_is_preprocessed:
|
947 |
+
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
948 |
+
|
949 |
return train_dataset, eval_dataset
|