winglian commited on
Commit
c996881
·
unverified ·
1 Parent(s): 1f151c0

add support for rpo_alpha (#1681)

Browse files

* add support for rpo_alpha

* Add smoke test for dpo + nll loss

requirements.txt CHANGED
@@ -39,6 +39,6 @@ s3fs
39
  gcsfs
40
  # adlfs
41
 
42
- trl==0.8.6
43
  zstandard==0.22.0
44
  fastcore
 
39
  gcsfs
40
  # adlfs
41
 
42
+ trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
43
  zstandard==0.22.0
44
  fastcore
src/axolotl/core/trainer_builder.py CHANGED
@@ -30,7 +30,7 @@ from transformers import (
30
  )
31
  from transformers.trainer_utils import seed_worker
32
  from transformers.utils import is_sagemaker_mp_enabled
33
- from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
34
  from trl.trainer.utils import pad_to_length
35
 
36
  from axolotl.loraplus import create_loraplus_optimizer
@@ -238,6 +238,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
238
  """
239
 
240
 
 
 
 
 
 
 
 
241
  @dataclass
242
  class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
243
  """
@@ -1608,7 +1615,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
1608
  # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
1609
  training_args_kwargs["beta"] = self.cfg.orpo_alpha
1610
 
1611
- training_args_cls = AxolotlTrainingArguments
 
 
1612
  if self.cfg.rl == "orpo":
1613
  training_args_cls = AxolotlORPOConfig
1614
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
 
30
  )
31
  from transformers.trainer_utils import seed_worker
32
  from transformers.utils import is_sagemaker_mp_enabled
33
+ from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
34
  from trl.trainer.utils import pad_to_length
35
 
36
  from axolotl.loraplus import create_loraplus_optimizer
 
238
  """
239
 
240
 
241
+ @dataclass
242
+ class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
243
+ """
244
+ DPO config for DPO training
245
+ """
246
+
247
+
248
  @dataclass
249
  class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
250
  """
 
1615
  # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
1616
  training_args_kwargs["beta"] = self.cfg.orpo_alpha
1617
 
1618
+ training_args_cls = AxolotlDPOConfig
1619
+ if self.cfg.rpo_alpha is not None:
1620
+ training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
1621
  if self.cfg.rl == "orpo":
1622
  training_args_cls = AxolotlORPOConfig
1623
  training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
src/axolotl/utils/config/models/input/v0_4_1/__init__.py CHANGED
@@ -619,6 +619,7 @@ class AxolotlInputConfig(
619
  neftune_noise_alpha: Optional[float] = None
620
 
621
  orpo_alpha: Optional[float] = None
 
622
 
623
  kto_desirable_weight: Optional[float] = None
624
  kto_undesirable_weight: Optional[float] = None
 
619
  neftune_noise_alpha: Optional[float] = None
620
 
621
  orpo_alpha: Optional[float] = None
622
+ rpo_alpha: Optional[float] = None
623
 
624
  kto_desirable_weight: Optional[float] = None
625
  kto_undesirable_weight: Optional[float] = None
tests/e2e/test_dpo.py CHANGED
@@ -70,6 +70,51 @@ class TestDPOLlamaLora(unittest.TestCase):
70
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
71
  assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  @with_temp_dir
74
  def test_kto_pair_lora(self, temp_dir):
75
  # pylint: disable=duplicate-code
 
70
  train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
71
  assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
72
 
73
+ @with_temp_dir
74
+ def test_dpo_nll_lora(self, temp_dir):
75
+ # pylint: disable=duplicate-code
76
+ cfg = DictDefault(
77
+ {
78
+ "base_model": "JackFram/llama-68m",
79
+ "tokenizer_type": "LlamaTokenizer",
80
+ "sequence_len": 1024,
81
+ "load_in_8bit": True,
82
+ "adapter": "lora",
83
+ "lora_r": 64,
84
+ "lora_alpha": 32,
85
+ "lora_dropout": 0.1,
86
+ "lora_target_linear": True,
87
+ "special_tokens": {},
88
+ "rl": "dpo",
89
+ "rpo_alpha": 0.5,
90
+ "datasets": [
91
+ {
92
+ "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
93
+ "type": "chatml.ultra",
94
+ "split": "train",
95
+ },
96
+ ],
97
+ "num_epochs": 1,
98
+ "micro_batch_size": 4,
99
+ "gradient_accumulation_steps": 1,
100
+ "output_dir": temp_dir,
101
+ "learning_rate": 0.00001,
102
+ "optimizer": "paged_adamw_8bit",
103
+ "lr_scheduler": "cosine",
104
+ "max_steps": 20,
105
+ "save_steps": 10,
106
+ "warmup_steps": 5,
107
+ "gradient_checkpointing": True,
108
+ "gradient_checkpointing_kwargs": {"use_reentrant": True},
109
+ }
110
+ )
111
+ normalize_config(cfg)
112
+ cli_args = TrainerCliArgs()
113
+ dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
114
+
115
+ train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
116
+ assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
117
+
118
  @with_temp_dir
119
  def test_kto_pair_lora(self, temp_dir):
120
  # pylint: disable=duplicate-code