winglian, plaguss (HF staff) committed
Commit 7523d1f (1 parent: 5439707)

DPO cleanup (#1126)


* cleanup dpo to be a little more extensible, add zephyr/nectar strategy

* fix eos slash

* support for eval split

* fix kwargs

* handle empty evals

* don't load peft model for dpo

* ensure dpo training args get bf16 for peft if applicable

* fix duplicate kwargs for bf16

* make sure to respect the configured lr scheduler

* support trainer callback to push config to wandb

* set dataloader preload args

* ensure that we are loading the lora when merging

* Update src/axolotl/utils/data.py

Co-authored-by: Agus <agustin.piqueres@gmail.com>

* support local datasets for dpo

Co-authored-by: Agus <agustin.piqueres@gmail.com>

* chore: lint

* dpo/kto/ipo smoke tests w lora, simplify dpo dataset type names

* add split to dpo tests

* fix rebase/merging error

* handle edge case w logging

* use accelerator for dpo datasets so it doesn't break the logger

* missing args

* validate checkpoint is an adapter for now

* log warning when dataset strategy is not loadable

---------

Co-authored-by: Agus <agustin.piqueres@gmail.com>
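
For context, the reworked pipeline reads DPO datasets from the regular `datasets` config plus an optional `test_datasets` eval split. A rough sketch of the shape it expects, based on the smoke tests added below and the new `load_prepare_dpo_datasets` helper; the local JSON entry and its file path are illustrative, not taken from this change:

# Illustrative config sketch only; field names follow load_prepare_dpo_datasets()
# and the e2e tests in this PR, the concrete paths/values are examples.
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "rl": "dpo",  # "ipo" and "kto_pair" route through the same builder
        "datasets": [
            # hub dataset mapped with one of the new strategy names
            {"path": "Intel/orca_dpo_pairs", "type": "chatml.intel", "split": "train"},
            # local JSON files are also supported now (hypothetical file path)
            {
                "ds_type": "json",
                "data_files": ["data/dpo_pairs.jsonl"],
                "split": "train",
                "type": "chatml.prompt_pairs",
            },
        ],
        # optional: an eval split enables evaluation_strategy="steps" in the DPO trainer
        "test_datasets": [
            {"path": "Intel/orca_dpo_pairs", "type": "chatml.intel", "split": "train"},
        ],
    }
)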

src/axolotl/cli/__init__.py CHANGED
@@ -17,7 +17,6 @@ import yaml
 # add src to the pythonpath so we don't need to pip install this
 from accelerate.commands.config import config_args
 from art import text2art
-from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
 from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
@@ -30,7 +29,7 @@ from axolotl.utils.config import (
     normalize_config,
     validate_config,
 )
-from axolotl.utils.data import prepare_dataset
+from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process
 from axolotl.utils.mlflow_ import setup_mlflow_env_vars
@@ -343,81 +342,7 @@ def load_rl_datasets(
     cfg: DictDefault,
     cli_args: TrainerCliArgs,  # pylint: disable=unused-argument
 ) -> TrainDatasetMeta:
-    train_datasets: List[Any] = []
-    for i, ds_cfg in enumerate(cfg.datasets):
-        train_datasets.insert(i, load_dataset(ds_cfg["path"], split=ds_cfg["split"]))
-    # eval_dataset = load_dataset(
-    #     cfg.test_datasets[0]["path"], split=cfg.test_datasets[0]["split"]
-    # )
-    eval_dataset = None
-
-    def argilla_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
-        return sample
-
-    def intel_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
-        return sample
-
-    def apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
-        return sample
-
-    def ultra_apply_chatml(sample):  # pylint: disable=possibly-unused-variable
-        if "system" in sample and sample["system"]:
-            sample["prompt"] = (
-                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-            )
-        else:
-            sample[
-                "prompt"
-            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
-        return sample
-
-    for i, data_set in enumerate(train_datasets):
-        _type = cfg.datasets[i]["type"]
-        ds_type_fn = locals()[_type]
-        train_datasets[i] = data_set.map(
-            ds_type_fn,
-            desc="Mapping RL Dataset",
-        )
-    train_dataset = concatenate_datasets(train_datasets)
-
-    # eval_dataset = eval_dataset.map(intel_apply_chatml)
-
+    train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
     total_num_steps = int(
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
src/axolotl/core/trainer_builder.py CHANGED
@@ -12,14 +12,19 @@ from abc import abstractmethod
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Optional, Type, Union
+from typing import List, Optional, Type, Union

 import torch
 import transformers
 from datasets import Dataset
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import (
+    EarlyStoppingCallback,
+    Trainer,
+    TrainerCallback,
+    TrainingArguments,
+)
 from transformers.trainer_utils import seed_worker
 from trl import DPOTrainer

@@ -460,6 +465,7 @@ class TrainerBuilderBase(abc.ABC):
     _train_dataset = None
     _eval_dataset = None
     _model_ref = None
+    _peft_config = None

     def __init__(self, cfg, model, tokenizer):
         self.cfg = cfg
@@ -490,13 +496,26 @@ class TrainerBuilderBase(abc.ABC):
     def eval_dataset(self, dataset):
         self._eval_dataset = dataset

+    @property
+    def peft_config(self):
+        return self._peft_config
+
+    @peft_config.setter
+    def peft_config(self, peft_config):
+        self._peft_config = peft_config
+
     @abstractmethod
     def build(self, total_num_steps):
         pass

-    @abstractmethod
-    def get_callbacks(self):
-        pass
+    def get_callbacks(self) -> List[TrainerCallback]:
+        callbacks = []
+        if self.cfg.use_wandb:
+            callbacks.append(
+                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
+            )
+
+        return callbacks

     @abstractmethod
     def get_post_trainer_create_callbacks(self, trainer):
@@ -504,12 +523,6 @@ class TrainerBuilderBase(abc.ABC):
         Callbacks added after the trainer is created, usually b/c these need access to the trainer
         """

-
-class HFCausalTrainerBuilder(TrainerBuilderBase):
-    """
-    Build the HuggingFace training args/trainer for Causal models
-    """
-
     def hook_pre_create_training_args(self, training_arguments_kwargs):
         # TODO
         return training_arguments_kwargs
@@ -526,10 +539,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         # TODO
         return trainer

+
+class HFCausalTrainerBuilder(TrainerBuilderBase):
+    """
+    Build the HuggingFace training args/trainer for Causal models
+    """
+
     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         callbacks.append(GPUStatsCallback(self.cfg))
-        callbacks.append(EvalFirstStepCallback)
+        callbacks.append(EvalFirstStepCallback())

         if self.cfg.relora_steps:
             callbacks.append(ReLoRACallback(self.cfg))
@@ -538,7 +557,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             hasattr(self.model, "use_bettertransformer")
             and self.model.use_bettertransformer is True
         ):
-            callbacks.append(SaveBetterTransformerModelCallback)
+            callbacks.append(SaveBetterTransformerModelCallback())

         if self.cfg.use_wandb:
             callbacks.append(
@@ -931,7 +950,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
     """

     def get_callbacks(self):
-        callbacks = []
+        callbacks = super().get_callbacks()
         return callbacks

     def get_post_trainer_create_callbacks(self, trainer):
@@ -949,21 +968,60 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
         ]:
             if hasattr(self.cfg, arg) and getattr(self.cfg, arg) is not None:
                 training_args_kwargs[arg] = getattr(self.cfg, arg)
+
+        if self.cfg.hub_model_id:
+            training_args_kwargs["hub_model_id"] = self.cfg.hub_model_id
+            training_args_kwargs["push_to_hub"] = True
+            training_args_kwargs["hub_private_repo"] = True
+            training_args_kwargs["hub_always_push"] = True
+
+        if self.cfg.hub_strategy:
+            training_args_kwargs["hub_strategy"] = self.cfg.hub_strategy
+
+        if self.cfg.save_safetensors is not None:
+            training_args_kwargs["save_safetensors"] = self.cfg.save_safetensors
+
+        if self.eval_dataset:
+            training_args_kwargs["evaluation_strategy"] = "steps"
+            training_args_kwargs["eval_steps"] = self.cfg.eval_steps
+        else:
+            training_args_kwargs["evaluation_strategy"] = "no"
+        if self.cfg.bf16 or self.cfg.bfloat16:
+            training_args_kwargs["bf16"] = True
+
+        training_args_kwargs["lr_scheduler_type"] = (
+            self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
+        )
+        training_args_kwargs["lr_scheduler_kwargs"] = (
+            self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
+        )
+
+        if self.cfg.dataloader_pin_memory is not None:
+            training_args_kwargs[
+                "dataloader_pin_memory"
+            ] = self.cfg.dataloader_pin_memory
+        if self.cfg.dataloader_num_workers is not None:
+            training_args_kwargs[
+                "dataloader_num_workers"
+            ] = self.cfg.dataloader_num_workers
+        if self.cfg.dataloader_prefetch_factor is not None:
+            training_args_kwargs[
+                "dataloader_prefetch_factor"
+            ] = self.cfg.dataloader_prefetch_factor
+
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
-            max_steps=total_num_steps,
+            max_steps=self.cfg.max_steps or total_num_steps,
             remove_unused_columns=False,
             gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
             learning_rate=self.cfg.learning_rate,
-            evaluation_strategy="no",
-            # eval_steps=self.cfg.eval_steps,
             save_strategy="steps",
             save_steps=self.cfg.save_steps,
             output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
-            bf16=True,
             gradient_checkpointing=self.cfg.gradient_checkpointing,
-            gradient_checkpointing_kwargs={"use_reentrant": False},
+            gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
+            or {"use_reentrant": False},
             logging_first_step=True,
             logging_steps=1,
             optim=self.cfg.optimizer,
@@ -982,22 +1040,27 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         elif self.cfg.rl == "kto_pair":
             dpo_trainer_kwargs["loss_type"] = "kto_pair"
-
+        if self.eval_dataset:
+            dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
+        if self.cfg.adapter and self.peft_config:
+            dpo_trainer_kwargs["peft_config"] = self.peft_config
         dpo_trainer = DPOTrainer(
             self.model,
             self.model_ref,
             args=training_args,
             beta=self.cfg.dpo_beta or 0.1,
             train_dataset=self.train_dataset,
-            # eval_dataset=self.eval_dataset,
-            eval_dataset=None,
             tokenizer=self.tokenizer,
             max_length=self.cfg.sequence_len,
             max_target_length=None,
             max_prompt_length=self.cfg.sequence_len,
             generate_during_eval=True,
+            callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
+        for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
+            dpo_trainer.add_callback(callback)

         return dpo_trainer

src/axolotl/prompt_strategies/dpo/__init__.py ADDED
@@ -0,0 +1,21 @@
+"""
+module for DPO style dataset transform strategies
+"""
+
+import importlib
+import logging
+
+LOG = logging.getLogger("axolotl")
+
+
+def load(strategy, cfg):
+    try:
+        load_fn = strategy.split(".")[-1]
+        strategy = ".".join(strategy.split(".")[:-1])
+        mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies.dpo")
+        func = getattr(mod, load_fn)
+        load_kwargs = {}
+        return func(cfg, **load_kwargs)
+    except Exception:  # pylint: disable=broad-exception-caught
+        LOG.warning(f"unable to load strategy {strategy}")
+        return None
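
For example, a dataset `type` of `chatml.intel` resolves to the `intel` function in `chatml.py` below; anything that fails to import just logs a warning and returns `None`. A minimal sketch, assuming axolotl is importable:

from axolotl.prompt_strategies.dpo import load as load_dpo
from axolotl.utils.dict import DictDefault

cfg = DictDefault({})  # the bundled strategies ignore cfg
transform_fn = load_dpo("chatml.intel", cfg)  # -> chatml.intel(cfg) -> per-sample transform
missing = load_dpo("does.not.exist", cfg)     # logs "unable to load strategy ..." and returns None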
src/axolotl/prompt_strategies/dpo/chatml.py ADDED
@@ -0,0 +1,85 @@
+"""
+DPO strategies for chatml
+"""
+
+
+def argilla(
+    cfg,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def intel(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    For Intel Orca DPO Pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def prompt_pairs(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
+def ultra(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    for ultrafeedback binarized conversations
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
+        return sample
+
+    return transform_fn
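
Applied to a toy row shaped like `Intel/orca_dpo_pairs` (values invented), the `intel` strategy produces the three columns `DPOTrainer` expects:

from axolotl.prompt_strategies.dpo.chatml import intel

transform_fn = intel(cfg=None)  # cfg is unused by these strategies
row = transform_fn(
    {
        "system": "You are a helpful assistant.",
        "question": "What is 2 + 2?",
        "chosen": "4",
        "rejected": "5",
    }
)
# row["prompt"]   -> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
#                    "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n"
# row["chosen"]   -> "4<|im_end|>"
# row["rejected"] -> "5<|im_end|>"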
src/axolotl/prompt_strategies/dpo/zephyr.py ADDED
@@ -0,0 +1,21 @@
+"""
+DPO strategies for zephyr
+"""
+
+
+def nectar(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
+    def transform_fn(sample):
+        data = {}
+        data["prompt"] = (
+            "<|system|>\n</s>\n"
+            "<|user|>\n"
+            f"{sample['prompt']}</s>\n"
+            "<|assistant|>\n"
+        )
+        answers = sorted(sample["answers"], key=lambda x: x["rank"])
+        data["chosen"] = answers[-1]["answer"]
+        data["rejected"] = answers[-2]["answer"]
+
+        return data
+
+    return transform_fn
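
Mechanically, `nectar` sorts the `answers` list by `rank` and takes the last two entries as chosen/rejected. A toy example; the field names mirror the code above, the values are made up:

from axolotl.prompt_strategies.dpo.zephyr import nectar

transform_fn = nectar(cfg=None)  # cfg is unused here as well
row = transform_fn(
    {
        "prompt": "Name a prime number.",
        "answers": [
            {"answer": "Four.", "rank": 1},
            {"answer": "Nine.", "rank": 4},
            {"answer": "Seven.", "rank": 5},
        ],
    }
)
# row["chosen"]   -> "Seven."  (largest rank after the ascending sort)
# row["rejected"] -> "Nine."   (second largest rank)
# row["prompt"] wraps the question in the zephyr chat template shown above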
src/axolotl/train.py CHANGED
@@ -96,7 +96,12 @@ def train(
         freeze_parameters_except(model, cfg.unfrozen_parameters)

     trainer = setup_trainer(
-        cfg, train_dataset, eval_dataset, (model, model_ref), tokenizer, total_num_steps
+        cfg,
+        train_dataset,
+        eval_dataset,
+        (model, model_ref, peft_config),
+        tokenizer,
+        total_num_steps,
     )

     if hasattr(model, "config"):
src/axolotl/utils/data.py CHANGED
@@ -4,7 +4,7 @@ import hashlib
 import logging
 from collections import defaultdict
 from pathlib import Path
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from datasets import (
@@ -21,6 +21,7 @@ from transformers import PreTrainedTokenizerBase
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies import load
+from axolotl.prompt_strategies.dpo import load as load_dpo
 from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
@@ -850,3 +851,50 @@ def encode_packed_pretraining(
         chunked_data[feature].append(collated_features[feature].squeeze(0))

     return chunked_data
+
+
+def load_prepare_dpo_datasets(cfg):
+    def load_split(dataset_cfgs, _cfg):
+        split_datasets: List[Any] = []
+        for i, ds_cfg in enumerate(dataset_cfgs):
+            if ds_cfg["ds_type"] == "json":
+                for data_file in ds_cfg["data_files"]:
+                    data_files = {ds_cfg["split"]: data_file}
+                    ds = load_dataset(  # pylint: disable=invalid-name
+                        "json",
+                        data_files=data_files,
+                        split=ds_cfg["split"],
+                    )
+                    split_datasets.insert(i, ds)
+            else:
+                ds = load_dataset(  # pylint: disable=invalid-name
+                    ds_cfg["path"],
+                    split=ds_cfg["split"],
+                )
+                split_datasets.insert(i, ds)
+
+        for i, data_set in enumerate(split_datasets):
+            _type = dataset_cfgs[i]["type"]
+            if _type:
+                ds_transform_fn = load_dpo(_type, _cfg)
+                split_datasets[i] = data_set.map(
+                    ds_transform_fn,
+                    desc="Mapping RL Dataset",
+                )
+            else:
+                # If no `type` is provided, assume the dataset is already in the expected format with
+                # "prompt", "chosen" and "rejected" already preprocessed
+                split_datasets[i] = data_set
+
+        return concatenate_datasets(split_datasets)
+
+    with zero_first(is_main_process()):
+        train_dataset = load_split(cfg.datasets, cfg)
+
+        eval_dataset = None
+        if cfg.test_datasets:
+            eval_dataset = load_split(cfg.test_datasets, cfg)
+        if not eval_dataset:
+            eval_dataset = None
+
+    return train_dataset, eval_dataset
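
`load_rl_datasets` in the CLI now just delegates to this helper; calling it directly looks roughly like this (the cfg mirrors the smoke tests in this PR; with no `test_datasets` configured the eval dataset comes back as `None`):

from axolotl.utils.data import load_prepare_dpo_datasets
from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "datasets": [
            {"path": "Intel/orca_dpo_pairs", "type": "chatml.intel", "split": "train"},
        ],
    }
)
train_ds, eval_ds = load_prepare_dpo_datasets(cfg)
# train_ds rows now carry "prompt", "chosen" and "rejected" columns; eval_ds is None here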
src/axolotl/utils/models.py CHANGED
@@ -682,7 +682,12 @@ def load_model(

     lora_config = None
     if not reference_model or cfg.lora_model_dir:
-        model, lora_config = load_adapter(model, cfg, cfg.adapter)
+        # if we're not loading the reference model, then we're loading the model for training
+        # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
+        if cfg.adapter and cfg.rl in ["dpo", "ipo", "kto_pair"] and not cfg.merge_lora:
+            _, lora_config = load_lora(model, cfg, inference=False, config_only=True)
+        else:
+            model, lora_config = load_adapter(model, cfg, cfg.adapter)

     if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
         model.to(f"cuda:{cfg.local_rank}")
@@ -770,8 +775,8 @@ def find_all_linear_names(model):
     return list(lora_module_names)


-def load_lora(model, cfg, inference=False):
-    # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
+def load_lora(model, cfg, inference=False, config_only=False):
+    # type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]

     from peft import LoraConfig, PeftModel, get_peft_model

@@ -794,6 +799,9 @@ def load_lora(model, cfg, inference=False):
         task_type="CAUSAL_LM",
     )

+    if config_only:
+        return None, lora_config
+
     if cfg.lora_model_dir:
         LOG.debug("Loading pretained PEFT - LoRA")
         model_kwargs: Any = {}
src/axolotl/utils/trainer.py CHANGED
@@ -316,9 +316,10 @@ def prepare_optim_env(cfg):


 def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
-    if cfg.rl:
+    if cfg.rl in ["dpo", "ipo", "kto_pair"]:
         trainer_builder = HFDPOTrainerBuilder(cfg, model[0], tokenizer)
         trainer_builder.model_ref = model[1]
+        trainer_builder.peft_config = model[2]
     else:
         trainer_builder = HFCausalTrainerBuilder(cfg, model[0], tokenizer)

tests/e2e/test_dpo.py ADDED
@@ -0,0 +1,157 @@
+"""
+E2E tests for lora llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_rl_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestDPOLlamaLora(unittest.TestCase):
+    """
+    Test case for DPO Llama models using LoRA
+    """
+
+    @with_temp_dir
+    def test_dpo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "dpo",
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "type": "chatml.intel",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_kto_pair_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "kto_pair",
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "type": "chatml.intel",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
+    @with_temp_dir
+    def test_ipo_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "ipo",
+                "datasets": [
+                    {
+                        "path": "Intel/orca_dpo_pairs",
+                        "type": "chatml.intel",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()