benredmond winglian commited on
Commit
22ae21a
1 Parent(s): ba45531

Add KTO support (#1640)

Browse files

* add kto support

* test cleanup

* fix outdated comment

* fix llama3 ultra

* chore: lint

* update to use rl_beta instead of dpo_beta

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>

src/axolotl/core/trainer_builder.py CHANGED
@@ -30,7 +30,7 @@ from transformers import (
30
  )
31
  from transformers.trainer_utils import seed_worker
32
  from transformers.utils import is_sagemaker_mp_enabled
33
- from trl import DPOTrainer, ORPOConfig, ORPOTrainer
34
  from trl.trainer.utils import pad_to_length
35
 
36
  from axolotl.loraplus import create_loraplus_optimizer
@@ -826,6 +826,14 @@ class AxolotlORPOTrainer(ORPOTrainer):
826
  tag_names = ["axolotl", "orpo"]
827
 
828
 
 
 
 
 
 
 
 
 
829
  class TrainerBuilderBase(abc.ABC):
830
  """
831
  Base class for trainer builder
@@ -1532,6 +1540,22 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1532
  if self.cfg.max_prompt_len:
1533
  training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1534
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1535
  training_args = training_args_cls(
1536
  per_device_train_batch_size=self.cfg.micro_batch_size,
1537
  max_steps=self.cfg.max_steps or total_num_steps,
@@ -1567,7 +1591,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1567
  ] = self.cfg.precompute_ref_log_probs
1568
  if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
1569
  trainer_cls = AxolotlDPOTrainer
1570
- dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
1571
  trainer_cls_args = [self.model, self.model_ref]
1572
 
1573
  # these aren't used for the ORPO trainer
@@ -1580,6 +1604,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1580
  elif self.cfg.rl == "orpo":
1581
  trainer_cls = AxolotlORPOTrainer
1582
  trainer_cls_args = [self.model]
 
 
 
1583
  else:
1584
  raise ValueError(f"Unsupported RL: {self.cfg.rl}")
1585
  dpo_trainer = trainer_cls(
 
30
  )
31
  from transformers.trainer_utils import seed_worker
32
  from transformers.utils import is_sagemaker_mp_enabled
33
+ from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
34
  from trl.trainer.utils import pad_to_length
35
 
36
  from axolotl.loraplus import create_loraplus_optimizer
 
826
  tag_names = ["axolotl", "orpo"]
827
 
828
 
829
+ class AxolotlKTOTrainer(KTOTrainer):
830
+ """
831
+ Extend the base KTOTrainer for axolotl helpers
832
+ """
833
+
834
+ tag_names = ["axolotl", "kto"]
835
+
836
+
837
  class TrainerBuilderBase(abc.ABC):
838
  """
839
  Base class for trainer builder
 
1540
  if self.cfg.max_prompt_len:
1541
  training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1542
 
1543
+ if self.cfg.rl == "kto":
1544
+ training_args_cls = KTOConfig
1545
+
1546
+ training_args_kwargs["beta"] = self.cfg.rl_beta or 0.1
1547
+ training_args_kwargs["desirable_weight"] = (
1548
+ self.cfg.kto_desirable_weight or 1.0
1549
+ )
1550
+ training_args_kwargs["undesirable_weight"] = (
1551
+ self.cfg.kto_undesirable_weight or 1.0
1552
+ )
1553
+
1554
+ training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
1555
+ training_args_kwargs["max_length"] = self.cfg.sequence_len
1556
+ if self.cfg.max_prompt_len:
1557
+ training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
1558
+
1559
  training_args = training_args_cls(
1560
  per_device_train_batch_size=self.cfg.micro_batch_size,
1561
  max_steps=self.cfg.max_steps or total_num_steps,
 
1591
  ] = self.cfg.precompute_ref_log_probs
1592
  if self.cfg.rl in ["dpo", "ipo", "kto_pair"]:
1593
  trainer_cls = AxolotlDPOTrainer
1594
+ dpo_trainer_kwargs["beta"] = self.cfg.rl_beta or 0.1
1595
  trainer_cls_args = [self.model, self.model_ref]
1596
 
1597
  # these aren't used for the ORPO trainer
 
1604
  elif self.cfg.rl == "orpo":
1605
  trainer_cls = AxolotlORPOTrainer
1606
  trainer_cls_args = [self.model]
1607
+ elif self.cfg.rl == "kto":
1608
+ trainer_cls = AxolotlKTOTrainer
1609
+ trainer_cls_args = [self.model]
1610
  else:
1611
  raise ValueError(f"Unsupported RL: {self.cfg.rl}")
1612
  dpo_trainer = trainer_cls(
src/axolotl/prompt_strategies/kto/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ module for KTO style dataset transform strategies
3
+ """
4
+
5
+ from functools import partial
6
+
7
+ from ..base import load as load_base
8
+
9
+ load = partial(load_base, module_base="axolotl.prompt_strategies.kto")
src/axolotl/prompt_strategies/kto/chatml.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ KTO strategies for chatml
3
+ """
4
+ # pylint: disable=duplicate-code
5
+
6
+
7
+ def argilla(
8
+ cfg,
9
+ **kwargs,
10
+ ): # pylint: disable=possibly-unused-variable,unused-argument
11
+ def transform_fn(sample):
12
+ if "system" in sample and sample["system"]:
13
+ sample["prompt"] = (
14
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
15
+ f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
16
+ )
17
+ else:
18
+ sample[
19
+ "prompt"
20
+ ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
21
+ sample["completion"] = f"{sample['completion']}<|im_end|>"
22
+ return sample
23
+
24
+ return transform_fn
25
+
26
+
27
+ def argilla_chat(
28
+ cfg,
29
+ **kwargs,
30
+ ): # pylint: disable=possibly-unused-variable,unused-argument
31
+ """
32
+ for argilla/kto-mix-15k conversations
33
+ """
34
+
35
+ def transform_fn(sample):
36
+ sample[
37
+ "prompt"
38
+ ] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
39
+ sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
40
+ return sample
41
+
42
+ return transform_fn
43
+
44
+
45
+ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
46
+ """
47
+ For Intel Orca KTO
48
+ ex: argilla/distilabel-intel-orca-kto
49
+ """
50
+
51
+ def transform_fn(sample):
52
+ if "system" in sample and sample["system"]:
53
+ sample["prompt"] = (
54
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
55
+ f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
56
+ )
57
+ else:
58
+ sample[
59
+ "prompt"
60
+ ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
61
+ sample["completion"] = f"{sample['completion']}<|im_end|>"
62
+ return sample
63
+
64
+ return transform_fn
65
+
66
+
67
+ def prompt_pairs(
68
+ cfg, **kwargs
69
+ ): # pylint: disable=possibly-unused-variable,unused-argument
70
+ def transform_fn(sample):
71
+ if "system" in sample and sample["system"]:
72
+ sample["prompt"] = (
73
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
74
+ f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
75
+ )
76
+ else:
77
+ sample[
78
+ "prompt"
79
+ ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
80
+ sample["completion"] = f"{sample['completion']}<|im_end|>"
81
+ return sample
82
+
83
+ return transform_fn
84
+
85
+
86
+ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
87
+ """
88
+ for ultrafeedback binarized conversations
89
+ ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
90
+ """
91
+
92
+ def transform_fn(sample):
93
+ if "system" in sample and sample["system"]:
94
+ sample["prompt"] = (
95
+ f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
96
+ f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
97
+ )
98
+ else:
99
+ sample[
100
+ "prompt"
101
+ ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
102
+ sample["completion"] = f"{sample['completion']}<|im_end|>"
103
+ return sample
104
+
105
+ return transform_fn
src/axolotl/prompt_strategies/kto/llama3.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ KTO strategies for llama-3 chat template
3
+ """
4
+ # pylint: disable=duplicate-code
5
+
6
+
7
+ def argilla(
8
+ cfg,
9
+ **kwargs,
10
+ ): # pylint: disable=possibly-unused-variable,unused-argument
11
+ def transform_fn(sample):
12
+ if "system" in sample and sample["system"]:
13
+ sample["prompt"] = (
14
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
15
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
16
+ )
17
+ else:
18
+ sample[
19
+ "prompt"
20
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
21
+ sample["completion"] = f"{sample['completion']}<|eot_id|>"
22
+ return sample
23
+
24
+ return transform_fn
25
+
26
+
27
+ def argilla_chat(
28
+ cfg,
29
+ **kwargs,
30
+ ): # pylint: disable=possibly-unused-variable,unused-argument
31
+ """
32
+ for argilla/kto-mix-15k conversations
33
+ """
34
+
35
+ def transform_fn(sample):
36
+ sample[
37
+ "prompt"
38
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
39
+ sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
40
+ return sample
41
+
42
+ return transform_fn
43
+
44
+
45
+ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
46
+ """
47
+ For Intel Orca KTO
48
+ ex: argilla/distilabel-intel-orca-kto
49
+ """
50
+
51
+ def transform_fn(sample):
52
+ if "system" in sample and sample["system"]:
53
+ sample["prompt"] = (
54
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
55
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
56
+ )
57
+ else:
58
+ sample[
59
+ "prompt"
60
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
61
+ sample["completion"] = f"{sample['completion']}<|eot_id|>"
62
+ return sample
63
+
64
+ return transform_fn
65
+
66
+
67
+ def prompt_pairs(
68
+ cfg, **kwargs
69
+ ): # pylint: disable=possibly-unused-variable,unused-argument
70
+ def transform_fn(sample):
71
+ if "system" in sample and sample["system"]:
72
+ sample["prompt"] = (
73
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
74
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
75
+ )
76
+ else:
77
+ sample[
78
+ "prompt"
79
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
80
+ sample["completion"] = f"{sample['completion']}<|eot_id|>"
81
+ return sample
82
+
83
+ return transform_fn
84
+
85
+
86
+ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-argument
87
+ """
88
+ for ultrafeedback binarized conversations
89
+ ex: argilla/ultrafeedback-binarized-preferences-cleaned-kto
90
+ """
91
+
92
+ def transform_fn(sample):
93
+ if "system" in sample and sample["system"]:
94
+ sample["prompt"] = (
95
+ f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
96
+ f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
97
+ )
98
+ else:
99
+ sample[
100
+ "prompt"
101
+ ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
102
+ sample["completion"] = f"{sample['completion']}<|eot_id|>"
103
+ return sample
104
+
105
+ return transform_fn
src/axolotl/prompt_strategies/kto/user_defined.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ User-defined KTO strategies
3
+ """
4
+ # pylint: disable=duplicate-code
5
+
6
+
7
+ def default(cfg, dataset_idx=0, **kwargs): # pylint: disable=unused-argument
8
+ ds_cfg = cfg["datasets"][dataset_idx]["type"]
9
+ if not isinstance(ds_cfg, dict):
10
+ raise ValueError(
11
+ f"User-defined dataset type must be a dictionary. Got: {ds_cfg}"
12
+ )
13
+ field_prompt = ds_cfg.get("field_prompt", "prompt")
14
+ field_system = ds_cfg.get("field_system", "system")
15
+ field_completion = ds_cfg.get("field_completion", "completion")
16
+ field_label = ds_cfg.get("field_label", "label")
17
+ prompt_format = ds_cfg.get("prompt_format")
18
+ if not prompt_format:
19
+ prompt_format = "{" + field_prompt + "}"
20
+ completion_format = ds_cfg.get("completion_format")
21
+ if not completion_format:
22
+ chosen_format = "{" + field_completion + "}"
23
+
24
+ def transform_fn(sample):
25
+ if (
26
+ "{" + field_system + "}" in prompt_format
27
+ and field_system in sample
28
+ and sample[field_system]
29
+ ):
30
+ sample["prompt"] = prompt_format.format(
31
+ system=sample[field_system], prompt=sample[field_prompt]
32
+ )
33
+ else:
34
+ sample["prompt"] = prompt_format.format(prompt=sample["prompt"])
35
+ sample["completion"] = chosen_format.format(chosen=sample[field_completion])
36
+ sample["label"] = sample[field_label]
37
+ return sample
38
+
39
+ return transform_fn
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -24,6 +24,7 @@ class DeprecatedParameters(BaseModel):
24
  max_packed_sequence_len: Optional[int] = None
25
  rope_scaling: Optional[Any] = None
26
  noisy_embedding_alpha: Optional[float] = None
 
27
 
28
  @field_validator("max_packed_sequence_len")
29
  @classmethod
@@ -48,6 +49,13 @@ class DeprecatedParameters(BaseModel):
48
  LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
49
  return noisy_embedding_alpha
50
 
 
 
 
 
 
 
 
51
 
52
  class RemappedParameters(BaseModel):
53
  """parameters that have been remapped to other names"""
@@ -126,6 +134,26 @@ class DPODataset(BaseModel):
126
  data_files: Optional[List[str]] = None
127
 
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  class RLType(str, Enum):
130
  """RL trainer type configuration subset"""
131
 
@@ -133,6 +161,7 @@ class RLType(str, Enum):
133
  ipo = "ipo" # pylint: disable=invalid-name
134
  kto_pair = "kto_pair" # pylint: disable=invalid-name
135
  orpo = "orpo" # pylint: disable=invalid-name
 
136
 
137
 
138
  class ChatTemplate(str, Enum):
@@ -450,8 +479,8 @@ class AxolotlInputConfig(
450
 
451
  rl: Optional[RLType] = None
452
 
453
- datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
454
- test_datasets: Optional[conlist(Union[SFTDataset, DPODataset], min_length=1)] = None # type: ignore
455
  shuffle_merged_datasets: Optional[bool] = True
456
  dataset_prepared_path: Optional[str] = None
457
  dataset_shard_num: Optional[int] = None
@@ -585,6 +614,10 @@ class AxolotlInputConfig(
585
 
586
  orpo_alpha: Optional[float] = None
587
 
 
 
 
 
588
  max_memory: Optional[
589
  Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
590
  ] = None
@@ -884,6 +917,13 @@ class AxolotlInputConfig(
884
  raise ValueError("neftune_noise_alpha must be > 0.0")
885
  return neftune_noise_alpha
886
 
 
 
 
 
 
 
 
887
  @model_validator(mode="before")
888
  @classmethod
889
  def check_frozen(cls, data):
 
24
  max_packed_sequence_len: Optional[int] = None
25
  rope_scaling: Optional[Any] = None
26
  noisy_embedding_alpha: Optional[float] = None
27
+ dpo_beta: Optional[float] = None
28
 
29
  @field_validator("max_packed_sequence_len")
30
  @classmethod
 
49
  LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
50
  return noisy_embedding_alpha
51
 
52
+ @field_validator("dpo_beta")
53
+ @classmethod
54
+ def validate_dpo_beta(cls, dpo_beta):
55
+ if dpo_beta is not None:
56
+ LOG.warning("dpo_beta is deprecated, use rl_beta instead")
57
+ return dpo_beta
58
+
59
 
60
  class RemappedParameters(BaseModel):
61
  """parameters that have been remapped to other names"""
 
134
  data_files: Optional[List[str]] = None
135
 
136
 
137
+ class UserDefinedKTOType(BaseModel):
138
+ """User defined typing for KTO"""
139
+
140
+ field_system: Optional[str] = None
141
+ field_prompt: Optional[str] = None
142
+ field_completion: Optional[str] = None
143
+ field_label: Optional[bool] = None
144
+ prompt_format: Optional[str] = None
145
+ completion_format: Optional[str] = None
146
+
147
+
148
+ class KTODataset(BaseModel):
149
+ """KTO configuration subset"""
150
+
151
+ path: Optional[str] = None
152
+ split: Optional[str] = None
153
+ type: Optional[Union[UserDefinedKTOType, str]] = None
154
+ data_files: Optional[List[str]] = None
155
+
156
+
157
  class RLType(str, Enum):
158
  """RL trainer type configuration subset"""
159
 
 
161
  ipo = "ipo" # pylint: disable=invalid-name
162
  kto_pair = "kto_pair" # pylint: disable=invalid-name
163
  orpo = "orpo" # pylint: disable=invalid-name
164
+ kto = "kto" # pylint: disable=invalid-name
165
 
166
 
167
  class ChatTemplate(str, Enum):
 
479
 
480
  rl: Optional[RLType] = None
481
 
482
+ datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
483
+ test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
484
  shuffle_merged_datasets: Optional[bool] = True
485
  dataset_prepared_path: Optional[str] = None
486
  dataset_shard_num: Optional[int] = None
 
614
 
615
  orpo_alpha: Optional[float] = None
616
 
617
+ kto_desirable_weight: Optional[float] = None
618
+ kto_undesirable_weight: Optional[float] = None
619
+ rl_beta: Optional[float] = None
620
+
621
  max_memory: Optional[
622
  Dict[Union[int, Literal["cpu", "disk"]], Union[int, str]]
623
  ] = None
 
917
  raise ValueError("neftune_noise_alpha must be > 0.0")
918
  return neftune_noise_alpha
919
 
920
+ @model_validator(mode="after")
921
+ def check(self):
922
+ if self.dpo_beta and not self.rl_beta:
923
+ self.rl_beta = self.dpo_beta
924
+ del self.dpo_beta
925
+ return self
926
+
927
  @model_validator(mode="before")
928
  @classmethod
929
  def check_frozen(cls, data):
src/axolotl/utils/data/rl.py CHANGED
@@ -10,6 +10,7 @@ from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_
10
 
11
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
12
  from axolotl.prompt_strategies.dpo import load as load_dpo
 
13
  from axolotl.prompt_strategies.orpo import load as load_orpo
14
  from axolotl.utils.data.utils import md5
15
  from axolotl.utils.dict import DictDefault
@@ -55,6 +56,22 @@ def _save_preprocessed_ds(cfg, sub_cfg, dataset):
55
  dataset.save_to_disk(str(prepared_ds_path))
56
 
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  def load_prepare_dpo_datasets(cfg):
59
  def load_split(dataset_cfgs, _cfg):
60
  split_datasets: List[Any] = []
@@ -76,6 +93,7 @@ def load_prepare_dpo_datasets(cfg):
76
  split_datasets.insert(i, ds)
77
 
78
  tokenizer = None
 
79
  for i, data_set in enumerate(split_datasets):
80
  _type = dataset_cfgs[i]["type"]
81
  if _type:
@@ -83,21 +101,19 @@ def load_prepare_dpo_datasets(cfg):
83
  _type = "user_defined.default"
84
  if _cfg.rl == "orpo":
85
  ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
 
 
86
  else:
87
  ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
88
- sig = inspect.signature(ds_transform_fn)
89
- if "tokenizer" in sig.parameters:
90
- if not tokenizer:
91
- tokenizer = load_tokenizer(_cfg)
92
- ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
93
-
94
- data_set = data_set.map(
95
- ds_transform_fn,
96
- desc="Mapping RL Dataset",
97
  )
98
- if isinstance(data_set, DatasetDict):
99
- data_set = data_set["train"]
100
- split_datasets[i] = data_set
101
  else:
102
  # If no `type` is provided, assume the dataset is already in the expected format with
103
  # "prompt", "chosen" and "rejected" already preprocessed
 
10
 
11
  from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
12
  from axolotl.prompt_strategies.dpo import load as load_dpo
13
+ from axolotl.prompt_strategies.kto import load as load_kto
14
  from axolotl.prompt_strategies.orpo import load as load_orpo
15
  from axolotl.utils.data.utils import md5
16
  from axolotl.utils.dict import DictDefault
 
56
  dataset.save_to_disk(str(prepared_ds_path))
57
 
58
 
59
+ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
60
+ sig = inspect.signature(ds_transform_fn)
61
+ if "tokenizer" in sig.parameters:
62
+ if not tokenizer:
63
+ tokenizer = load_tokenizer(cfg)
64
+ ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
65
+
66
+ data_set = data_set.map(
67
+ ds_transform_fn,
68
+ desc="Mapping RL Dataset",
69
+ )
70
+ if isinstance(data_set, DatasetDict):
71
+ data_set = data_set["train"]
72
+ return data_set
73
+
74
+
75
  def load_prepare_dpo_datasets(cfg):
76
  def load_split(dataset_cfgs, _cfg):
77
  split_datasets: List[Any] = []
 
93
  split_datasets.insert(i, ds)
94
 
95
  tokenizer = None
96
+
97
  for i, data_set in enumerate(split_datasets):
98
  _type = dataset_cfgs[i]["type"]
99
  if _type:
 
101
  _type = "user_defined.default"
102
  if _cfg.rl == "orpo":
103
  ds_transform_fn = load_orpo(_type, _cfg, dataset_idx=i)
104
+ elif _cfg.rl == "kto":
105
+ ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
106
  else:
107
  ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
108
+
109
+ split_datasets[i] = map_dataset(
110
+ cfg, data_set, ds_transform_fn, tokenizer
111
+ )
112
+ elif _cfg.rl == "kto":
113
+ ds_transform_fn = load_kto(_type, _cfg, dataset_idx=i)
114
+ split_datasets[i] = map_dataset(
115
+ cfg, data_set, ds_transform_fn, tokenizer
 
116
  )
 
 
 
117
  else:
118
  # If no `type` is provided, assume the dataset is already in the expected format with
119
  # "prompt", "chosen" and "rejected" already preprocessed
src/axolotl/utils/models.py CHANGED
@@ -803,7 +803,11 @@ def load_model(
803
  if not reference_model or cfg.lora_model_dir:
804
  # if we're not loading the reference model, then we're loading the model for training
805
  # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
806
- if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
 
 
 
 
807
  _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
808
  else:
809
  model, lora_config = load_adapter(model, cfg, cfg.adapter)
 
803
  if not reference_model or cfg.lora_model_dir:
804
  # if we're not loading the reference model, then we're loading the model for training
805
  # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
806
+ if (
807
+ cfg.adapter
808
+ and cfg.rl in ["dpo", "ipo", "kto_pair", "kto"]
809
+ and not cfg.merge_lora
810
+ ):
811
  _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
812
  else:
813
  model, lora_config = load_adapter(model, cfg, cfg.adapter)
src/axolotl/utils/trainer.py CHANGED
@@ -428,7 +428,7 @@ def prepare_optim_env(cfg):
428
 
429
 
430
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
431
- if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo"]:
432
  trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
433
  trainer_builder.model_ref = model[1]
434
  trainer_builder.peft_config = model[2]
 
428
 
429
 
430
  def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
431
+ if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "kto"]:
432
  trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
433
  trainer_builder.model_ref = model[1]
434
  trainer_builder.peft_config = model[2]
tests/e2e/test_dpo.py CHANGED
@@ -205,3 +205,66 @@ class TestDPOLlamaLora(unittest.TestCase):
205
 
206
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
207
  assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
 
206
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
207
  assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
208
+
209
+ @with_temp_dir
210
+ def test_kto_lora(self, temp_dir):
211
+ # pylint: disable=duplicate-code
212
+ cfg = DictDefault(
213
+ {
214
+ "base_model": "JackFram/llama-68m",
215
+ "tokenizer_type": "LlamaTokenizer",
216
+ "sequence_len": 1024,
217
+ "load_in_8bit": True,
218
+ "adapter": "lora",
219
+ "lora_r": 64,
220
+ "lora_alpha": 32,
221
+ "lora_dropout": 0.1,
222
+ "lora_target_linear": True,
223
+ "special_tokens": {},
224
+ "rl": "kto",
225
+ "rl_beta": 0.5,
226
+ "kto_desirable_weight": 1.0,
227
+ "kto_undesirable_weight": 1.0,
228
+ "remove_unused_columns": False,
229
+ "datasets": [
230
+ # {
231
+ # "path": "argilla/kto-mix-15k",
232
+ # "type": "chatml.argilla_chat",
233
+ # "split": "train",
234
+ # },
235
+ {
236
+ "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
237
+ "type": "chatml.ultra",
238
+ "split": "train",
239
+ },
240
+ # {
241
+ # "path": "argilla/kto-mix-15k",
242
+ # "type": "llama3.argilla_chat",
243
+ # "split": "train",
244
+ # },
245
+ {
246
+ "path": "argilla/ultrafeedback-binarized-preferences-cleaned-kto",
247
+ "type": "llama3.ultra",
248
+ "split": "train",
249
+ },
250
+ ],
251
+ "num_epochs": 1,
252
+ "micro_batch_size": 4,
253
+ "gradient_accumulation_steps": 1,
254
+ "output_dir": temp_dir,
255
+ "learning_rate": 0.00001,
256
+ "optimizer": "paged_adamw_8bit",
257
+ "lr_scheduler": "cosine",
258
+ "max_steps": 20,
259
+ "save_steps": 10,
260
+ "warmup_steps": 5,
261
+ "gradient_checkpointing": True,
262
+ "gradient_checkpointing_kwargs": {"use_reentrant": True},
263
+ }
264
+ )
265
+ normalize_config(cfg)
266
+ cli_args = TrainerCliArgs()
267
+ dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
268
+
269
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
270
+ assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
tests/test_validation.py CHANGED
@@ -1117,6 +1117,15 @@ class TestValidation(BaseValidation):
1117
  validate_config(cfg)
1118
  assert len(self._caplog.records) == 0
1119
 
 
 
 
 
 
 
 
 
 
1120
 
1121
  class TestValidationCheckModelConfig(BaseValidation):
1122
  """
 
1117
  validate_config(cfg)
1118
  assert len(self._caplog.records) == 0
1119
 
1120
+ def test_dpo_beta_deprecation(self, minimal_cfg):
1121
+ cfg = DictDefault({"dpo_beta": 0.2}) | minimal_cfg
1122
+
1123
+ with self._caplog.at_level(logging.WARNING):
1124
+ new_cfg = validate_config(cfg)
1125
+ assert new_cfg["rl_beta"] == 0.2
1126
+ assert new_cfg["dpo_beta"] is None
1127
+ assert len(self._caplog.records) == 1
1128
+
1129
 
1130
  class TestValidationCheckModelConfig(BaseValidation):
1131
  """