DPO Trainer
Overview
TRL supports the DPO Trainer for training language models from preference data, as described in the paper Direct Preference Optimization: Your Language Model is Secretly a Reward Model by Rafael Rafailov, Archit Sharma, Eric Mitchell, Stefano Ermon, Christopher D. Manning, Chelsea Finn.
The abstract from the paper is the following:
While large-scale unsupervised language models (LMs) learn broad world knowledge and some reasoning skills, achieving precise control of their behavior is difficult due to the completely unsupervised nature of their training. Existing methods for gaining such steerability collect human labels of the relative quality of model generations and fine-tune the unsupervised LM to align with these preferences, often with reinforcement learning from human feedback (RLHF). However, RLHF is a complex and often unstable procedure, first fitting a reward model that reflects the human preferences, and then fine-tuning the large unsupervised LM using reinforcement learning to maximize this estimated reward without drifting too far from the original model. In this paper we introduce a new parameterization of the reward model in RLHF that enables extraction of the corresponding optimal policy in closed form, allowing us to solve the standard RLHF problem with only a simple classification loss. The resulting algorithm, which we call Direct Preference Optimization (DPO), is stable, performant, and computationally lightweight, eliminating the need for sampling from the LM during fine-tuning or performing significant hyperparameter tuning. Our experiments show that DPO can fine-tune LMs to align with human preferences as well as or better than existing methods. Notably, fine-tuning with DPO exceeds PPO-based RLHF in ability to control sentiment of generations, and matches or improves response quality in summarization and single-turn dialogue while being substantially simpler to implement and train.
The first step is to train an SFT model, to ensure the data we train on is in-distribution for the DPO algorithm.
Then, fine-tuning a language model via DPO consists of two steps and is easier than PPO:
- Data collection: Gather a preference dataset of prompts. Each prompt is paired with a preferred (chosen) generation and a dispreferred (rejected) generation.
- Optimization: Minimize the DPO loss directly, which amounts to maximizing the log-likelihood of the preference data.
This process is illustrated in Figure 1 of the DPO paper.
Read more about the DPO algorithm in the original paper.
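For reference, the optimization step above minimizes the following objective, written in the DPO paper's notation. Here x is a prompt, y_w and y_l are the chosen and rejected completions, π_θ is the policy being trained, π_ref is the reference (SFT) model, σ is the logistic sigmoid, and β corresponds to the beta parameter of DPOConfig:

```latex
\mathcal{L}_{\mathrm{DPO}}(\pi_\theta; \pi_{\mathrm{ref}})
  = -\,\mathbb{E}_{(x, y_w, y_l) \sim \mathcal{D}}
    \left[ \log \sigma\!\left(
      \beta \log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)}
      - \beta \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}
    \right) \right]
```

Minimizing this loss maximizes the likelihood of the observed preferences under a Bradley-Terry model. In that model, the implicit reward of a completion y is β log(π_θ(y | x) / π_ref(y | x)).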
Quick start
This example demonstrates how to train a model using the DPO method. We use the Qwen 0.5B model as the base model and the preference data from the UltraFeedback dataset (trl-lib/ultrafeedback_binarized), which you can browse on the Hugging Face Hub.
Below is the script to train the model:
# train_dpo.py
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
Execute the script using the following command:
accelerate launch train_dpo.py
Distributed across 8 GPUs, the training takes approximately 3 minutes. You can verify the training progress by plotting the logged rewards. An increasing trend in the reward margin (rewards/margins) indicates that the model is improving and generating better responses over time.
To see how the trained model performs, you can use the TRL Chat CLI.
$ trl chat --model_name_or_path trl-lib/Qwen2-0.5B-DPO
<quentin_gallouedec>:
What is the best programming language?
<trl-lib/Qwen2-0.5B-DPO>:
The best programming language for specific applications can vary depending on the use case and knowledge level of the programmer. Here are some general factors that can be used as input to choose the best programming language:
1 Ease of use: Some programming languages are more user-friendly than others, such as Python, Java, or Ruby. Python is popular due to its simplicity and great scalability.
2 Versatility: The ability to work with a wide range of data structures and frameworks can define the language as versatile.
3 Ease of learning: Different programming languages have different learning curves, so users must be willing to take some time to master one.
4 Community support: The broader community of developers and enthusiasts in the selected programming language can provide great support and resources.
5 Reusability: Languages that emphasize code reuse and can be easily modifiable can be more suitable for software development.
The best programming language based on these factors is subjective and depends on what the programmer intends to accomplish.
Expected dataset type
DPO requires a preference dataset. The DPOTrainer supports both conversational and standard dataset formats. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset.
Although the DPOTrainer supports both explicit and implicit prompts, we recommend using explicit prompts. If provided with an implicit prompt dataset, the trainer will automatically extract the prompt from the "chosen" and "rejected" columns. For more information, refer to the preference style section.
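To make the distinction concrete, the sketch below shows one row in each layout, using made-up text. The column names prompt, chosen and rejected follow the preference format described above. The helper extract_shared_prompt is hypothetical. It only illustrates how a prompt can be recovered from an implicit-prompt row and is not the function the trainer uses.

```python
# Standard format with an explicit prompt.
standard_explicit = {
    "prompt": "The sky is",
    "chosen": " blue.",
    "rejected": " green.",
}

# Conversational format with an explicit prompt.
conversational_explicit = {
    "prompt": [{"role": "user", "content": "What color is the sky?"}],
    "chosen": [{"role": "assistant", "content": "It is blue."}],
    "rejected": [{"role": "assistant", "content": "It is green."}],
}

# Conversational format with an implicit prompt: the user turn is repeated in both columns.
conversational_implicit = {
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is green."},
    ],
}


def extract_shared_prompt(row):
    """Hypothetical helper: treat the longest common prefix of messages as the prompt."""
    chosen, rejected = row["chosen"], row["rejected"]
    n = 0
    while n < min(len(chosen), len(rejected)) and chosen[n] == rejected[n]:
        n += 1
    return {"prompt": chosen[:n], "chosen": chosen[n:], "rejected": rejected[n:]}
```

Applied to conversational_implicit, the helper returns exactly the explicit row above. This is why explicit prompts and implicit prompts carry the same information.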
Special considerations for vision-language models
The DPOTrainer supports fine-tuning vision-language models (VLMs). For these models, a vision dataset is required. To learn more about the specific format for vision datasets, refer to the Vision dataset format section.
Additionally, unlike standard text-based models where a tokenizer is used, for VLMs you should replace the tokenizer with a processor.
- model = AutoModelForCausalLM.from_pretrained(model_id)
+ model = AutoModelForVision2Seq.from_pretrained(model_id)
- tokenizer = AutoTokenizer.from_pretrained(model_id)
+ processor = AutoProcessor.from_pretrained(model_id)
trainer = DPOTrainer(
model,
args=training_args,
train_dataset=train_dataset,
- processing_class=tokenizer,
+ processing_class=processor,
)
For a complete example of fine-tuning a vision-language model, refer to the script in examples/scripts/dpo_vlm.py.
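As an illustration only, a single vision preference row might look like the following. The layout is our assumption of TRL's vision dataset format: an images column, plus conversational prompt, chosen and rejected columns whose content is a list of typed parts. A real dataset would store PIL images instead of the placeholder string. The count_image_slots helper is hypothetical.

```python
# Hypothetical vision preference row; a real dataset would hold PIL.Image objects in "images".
vision_row = {
    "images": ["<PIL.Image placeholder>"],
    "prompt": [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": "What color is the car?"}],
        },
    ],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "The car is red."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "The car is blue."}]}],
}


def count_image_slots(messages):
    """Hypothetical check: count {"type": "image"} parts, one per entry of "images"."""
    return sum(
        1
        for message in messages
        for part in message["content"]
        if part.get("type") == "image"
    )
```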
Example script
We provide an example script to train a model using the DPO method. The script is available in trl/scripts/dpo.py.
To test the DPO script with the Qwen2 0.5B model on the UltraFeedback dataset, run the following command:
accelerate launch trl/scripts/dpo.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --dataset_name trl-lib/ultrafeedback_binarized \
    --num_train_epochs 1 \
    --logging_steps 25 \
    --output_dir Qwen2-0.5B-DPO
Logged metrics
While training and evaluating, we record the following reward metrics:
- rewards/chosen: the mean difference between the log probabilities of the policy model and the reference model for the chosen responses, scaled by beta
- rewards/rejected: the mean difference between the log probabilities of the policy model and the reference model for the rejected responses, scaled by beta
- rewards/accuracies: how often, on average, the chosen rewards are greater than the corresponding rejected rewards
- rewards/margins: the mean difference between the chosen and corresponding rejected rewards
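The sketch below shows how these four numbers relate to each other. It uses plain Python over lists rather than the trainer's tensor code. The function name reward_metrics and the example values are hypothetical. beta plays the role of the beta in DPOConfig, and each log probability is assumed to be already summed over the completion tokens.

```python
def reward_metrics(chosen_logps, rejected_logps, ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Hypothetical helper: the four logged reward metrics for one batch.

    Each argument holds one sequence log probability per example.
    """
    # Implicit rewards: beta-scaled policy/reference log-probability differences.
    chosen_rewards = [beta * (p - r) for p, r in zip(chosen_logps, ref_chosen_logps)]
    rejected_rewards = [beta * (p - r) for p, r in zip(rejected_logps, ref_rejected_logps)]
    n = len(chosen_rewards)
    margins = [c - r for c, r in zip(chosen_rewards, rejected_rewards)]
    return {
        "rewards/chosen": sum(chosen_rewards) / n,
        "rewards/rejected": sum(rejected_rewards) / n,
        # Fraction of pairs where the chosen reward strictly exceeds the rejected one.
        "rewards/accuracies": sum(c > r for c, r in zip(chosen_rewards, rejected_rewards)) / n,
        "rewards/margins": sum(margins) / n,
    }


metrics = reward_metrics(
    chosen_logps=[-10.0, -12.0],
    rejected_logps=[-14.0, -11.0],
    ref_chosen_logps=[-11.0, -12.5],
    ref_rejected_logps=[-13.0, -11.5],
)
```

In this example the first pair is ranked correctly and the second is a tie, so rewards/accuracies is 0.5 while rewards/margins is still positive.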
Loss functions
The DPO algorithm supports several loss functions. The loss function can be set using the loss_type parameter in the DPOConfig. The following loss functions are supported:
loss_type | Description |
---|---|
"sigmoid" (default) | Given the preference data, we can fit a binary classifier according to the Bradley-Terry model and in fact the DPO authors propose the sigmoid loss on the normalized likelihood via the logsigmoid to fit a logistic regression. |
"hinge" | The RSO authors propose to use a hinge loss on the normalized likelihood from the SLiC paper. In this case, the beta is the reciprocal of the margin. |
"ipo" | The IPO authors provide a deeper theoretical understanding of the DPO algorithm. They identify an issue with overfitting and propose an alternative loss. In this case, beta is the reciprocal of the gap between the log-likelihood ratios of the chosen and rejected completions, so the smaller the beta, the larger this gap. As per the paper, the loss is averaged over the log-likelihoods of the completion (unlike DPO, where they are summed). |
"exo_pair" | The EXO authors propose to minimize the reverse KL instead of the negative log-sigmoid loss of DPO, which corresponds to the forward KL. Setting non-zero label_smoothing (default 1e-3) leads to a simplified version of EXO on pair-wise preferences (see Eqn. (16) of the EXO paper). The full version of EXO uses K>2 completions generated by the SFT policy, which becomes an unbiased estimator of the PPO objective (up to a constant) when K is sufficiently large. |
"nca_pair" | The NCA authors show that NCA optimizes the absolute likelihood for each response rather than the relative likelihood. |
"robust" | The Robust DPO authors propose an unbiased estimate of the DPO loss that is robust to preference noise in the data. Like in cDPO, it assumes that the preference labels are noisy with some probability. In this approach, the label_smoothing parameter in the DPOConfig is used to model the probability of existing label noise. To apply this conservative loss, set label_smoothing to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0). |
"bco_pair" | The BCO authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. For unpaired data, we recommend the dedicated BCOTrainer. |
"sppo_hard" | The SPPO authors claim that SPPO can find the Nash equilibrium iteratively by pushing the chosen rewards to be as large as 1/2 and the rejected rewards to be as small as -1/2. They also claim that it can alleviate data sparsity issues. The implementation approximates this algorithm by employing hard label probabilities, assigning 1 to the winner and 0 to the loser. |
"aot" or "aot_pair" | The AOT authors propose Distributional Preference Alignment via Optimal Transport. Traditionally, alignment algorithms use paired preferences at the sample level, which does not ensure alignment at the distributional level. AOT, on the other hand, can align LLMs on paired or unpaired preference data by making the reward distribution of the positive samples first-order stochastically dominant over the distribution of the negative samples. Specifically, loss_type="aot" is appropriate for paired datasets, where each prompt has both chosen and rejected responses; loss_type="aot_pair" is for unpaired datasets. In a nutshell, loss_type="aot" ensures that the log-likelihood ratio of chosen to rejected of the aligned model has higher quantiles than that ratio for the reference model. loss_type="aot_pair" ensures that the chosen reward is higher on all quantiles than the rejected reward. Note that in both cases quantiles are obtained via sorting. To fully leverage the advantages of the AOT algorithm, it is important to maximize the per-GPU batch size. |
"apo_zero" or "apo_down" | The APO method introduces an “anchored” version of the alignment objective. There are two variants: apo_zero and apo_down. The apo_zero loss increases the likelihood of winning outputs while decreasing the likelihood of losing outputs, making it suitable when the model is less performant than the winning outputs. On the other hand, apo_down decreases the likelihood of both winning and losing outputs, but with a stronger emphasis on reducing the likelihood of losing outputs. This variant is more effective when the model is better than the winning outputs. |
"discopop" | The DiscoPOP paper uses LLMs to discover more efficient offline preference optimization losses. In the paper the proposed DiscoPOP loss (which is a log-ratio modulated loss) outperformed other optimization losses on different tasks (IMDb positive text generation, Reddit TLDR summarization, and Alpaca Eval 2.0). |
Label smoothing
cDPO is a tweak on the DPO loss where we assume that the preference labels are noisy with some probability. In this approach, the label_smoothing parameter in the DPOConfig is used to model the probability of existing label noise. To apply this conservative loss, set label_smoothing to a value greater than 0.0 (between 0.0 and 0.5; the default is 0.0).
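The sketch below covers three of the loss types from the table ("sigmoid", "hinge" and "ipo") and the cDPO label smoothing described above. It works on a single example in plain Python. pairwise_loss is a hypothetical helper written from those descriptions, not the trainer's implementation. For "ipo", the table notes that log-likelihoods are averaged over completion tokens; the sketch simply uses whatever log probabilities it is given.

```python
import math


def log_sigmoid(z):
    # Numerically stable log(sigmoid(z)).
    if z >= 0:
        return -math.log1p(math.exp(-z))
    return z - math.log1p(math.exp(z))


def pairwise_loss(chosen_logp, rejected_logp, ref_chosen_logp, ref_rejected_logp,
                  beta=0.1, loss_type="sigmoid", label_smoothing=0.0):
    """Hypothetical helper: per-example loss for three of the loss_type values."""
    # How much more the policy prefers chosen over rejected, relative to the reference.
    logits = (chosen_logp - rejected_logp) - (ref_chosen_logp - ref_rejected_logp)
    if loss_type == "sigmoid":
        # Bradley-Terry logistic loss; label_smoothing > 0 gives the conservative cDPO mix,
        # which also penalizes the reverse preference with weight label_smoothing.
        return (-(1.0 - label_smoothing) * log_sigmoid(beta * logits)
                - label_smoothing * log_sigmoid(-beta * logits))
    if loss_type == "hinge":
        # Hinge loss with margin 1 / beta.
        return max(0.0, 1.0 - beta * logits)
    if loss_type == "ipo":
        # Squared regression of the log-ratio gap towards 1 / (2 * beta).
        return (logits - 1.0 / (2.0 * beta)) ** 2
    raise ValueError(f"loss_type not covered by this sketch: {loss_type}")
```

With label_smoothing=0.5, the sigmoid loss becomes symmetric in the chosen and rejected completions, so the pair no longer carries a preference signal. This is why the useful range stops at 0.5. In training, the same choices are made through DPOConfig(loss_type=..., label_smoothing=...).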
Syncing the reference model
The TR-DPO paper suggests syncing the reference model weights every ref_model_sync_steps steps of SGD, with mixing weight ref_model_mixup_alpha, during DPO training. To enable this callback, set sync_ref_model=True in the DPOConfig.
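The sketch below illustrates the soft update with parameters represented as flat lists of floats. It assumes the reference is moved toward the policy as ref ← α · policy + (1 − α) · ref, with α = ref_model_mixup_alpha, which is our reading of the TR-DPO update. The function names soft_sync and train_with_ref_sync are hypothetical. The defaults mirror those in the DPOConfig signature (ref_model_mixup_alpha=0.9, ref_model_sync_steps=64).

```python
def soft_sync(ref_params, policy_params, alpha):
    """Hypothetical helper: alpha * policy + (1 - alpha) * reference, per parameter."""
    return [alpha * p + (1.0 - alpha) * r for p, r in zip(policy_params, ref_params)]


def train_with_ref_sync(policy_snapshots, ref_params,
                        ref_model_sync_steps=64, ref_model_mixup_alpha=0.9):
    """Hypothetical loop: replay one policy snapshot per optimizer step and
    soft-update the reference copy every ref_model_sync_steps steps."""
    for step, policy_params in enumerate(policy_snapshots, start=1):
        if step % ref_model_sync_steps == 0:
            ref_params = soft_sync(ref_params, policy_params, ref_model_mixup_alpha)
    return ref_params
```

With alpha close to 1, the reference nearly tracks the policy. With alpha close to 0, it stays close to the original SFT model.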
RPO loss
The RPO paper implements an iterative preference tuning algorithm using a loss related to the RPO loss introduced in a separate paper. This loss essentially consists of a weighted SFT loss on the chosen responses added to the DPO loss. To use this loss, set rpo_alpha in the DPOConfig to an appropriate value. The paper suggests setting this weight to 1.0.
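As a minimal sketch of how the two terms combine for one example: the helper rpo_loss is hypothetical. We assume the SFT term is the negative log-likelihood of the chosen completion, averaged over its tokens. rpo_alpha defaults to None in DPOConfig, which leaves the extra term off.

```python
def rpo_loss(dpo_loss, chosen_token_logps, rpo_alpha=1.0):
    """Hypothetical helper: DPO loss plus a weighted SFT term on the chosen response."""
    # Token-averaged negative log-likelihood of the chosen completion under the policy.
    nll = -sum(chosen_token_logps) / len(chosen_token_logps)
    return dpo_loss + rpo_alpha * nll
```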
WPO loss
The WPO paper adapts off-policy data to resemble on-policy data more closely by reweighting preference pairs according to their probability under the current policy. To use this method, set the use_weighting flag to True in the DPOConfig.
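The sketch below is a simplified reading of this reweighting. Each completion's weight is its geometric-mean token probability under the current policy, and a pair's loss is scaled by the product of the chosen and rejected weights. Any additional normalization applied by the paper or the trainer is omitted. The function names are hypothetical.

```python
import math


def sequence_weight(token_logps):
    """Hypothetical helper: geometric-mean token probability of a completion."""
    return math.exp(sum(token_logps) / len(token_logps))


def weighted_pair_loss(pair_loss, chosen_token_logps, rejected_token_logps):
    """Hypothetical helper: scale a pair's loss by how on-policy both completions look.

    In a real implementation the weight would be detached, so no gradient flows through it.
    """
    weight = sequence_weight(chosen_token_logps) * sequence_weight(rejected_token_logps)
    return weight * pair_loss
```

Pairs that the current policy would be unlikely to generate get small weights. As a result, they contribute less to the update than pairs that look on-policy.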
For Mixture of Experts Models: Enabling the auxiliary loss
MoEs are most efficient when the load is roughly evenly distributed across experts.
To ensure that MoEs are trained similarly during preference tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss.
This option is enabled by setting output_router_logits=True in the model config (e.g. MixtralConfig).
To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter router_aux_loss_coef=... (default: 0.001) in the model config.
Accelerate DPO fine-tuning using unsloth
You can further accelerate QLoRA / LoRA (2x faster, 60% less memory) using the unsloth library, which is fully compatible with SFTTrainer. Currently unsloth supports only Llama (Yi, TinyLlama, Qwen, Deepseek, etc.) and Mistral architectures. Some DPO benchmarks are listed below:
GPU | Model | Dataset | 🤗 | 🤗 + Flash Attention 2 | 🦥 Unsloth | 🦥 VRAM saved |
---|---|---|---|---|---|---|
A100 40G | Zephyr 7b | Ultra Chat | 1x | 1.24x | 1.88x | -11.6% |
Tesla T4 | Zephyr 7b | Ultra Chat | 1x | 1.09x | 1.55x | -18.6% |
First install unsloth according to the official documentation. Once it is installed, you can incorporate unsloth into your workflow in a very simple manner: instead of loading AutoModelForCausalLM, you just need to load a FastLanguageModel as follows:
from datasets import load_dataset
from trl import DPOConfig, DPOTrainer
- from transformers import AutoModelForCausalLM, AutoTokenizer
+ from unsloth import FastLanguageModel
- model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
- tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model, tokenizer = FastLanguageModel.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
+ model = FastLanguageModel.get_peft_model(model)
train_dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")
- training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10)
+ training_args = DPOConfig(output_dir="Qwen2-0.5B-DPO", logging_steps=10, bf16=True)
trainer = DPOTrainer(model=model, args=training_args, processing_class=tokenizer, train_dataset=train_dataset)
trainer.train()
The saved model is fully compatible with Hugging Face’s transformers library. Learn more about unsloth in their official repository.
Reference model considerations with PEFT
You have three main options (plus several variants) for how the reference model works when using PEFT, assuming the model that you would like to further enhance with DPO was tuned using (Q)LoRA.
1. Simply create two instances of the model, each loading your adapter. This works fine but is very inefficient.
2. Merge the adapter into the base model, create another adapter on top, then leave the ref_model param as None. In this case, DPOTrainer will unload the adapter for reference inference. This is efficient, but has potential downsides discussed below.
3. Load the adapter twice with different names, then use set_adapter during training to swap between the adapter being DPO'd and the reference adapter. This is slightly less efficient than option 2 (~adapter size VRAM overhead), but avoids the pitfalls.
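A minimal sketch of option 2 could look like the following. It assumes an unquantized base model and uses PEFT's PeftModel.from_pretrained and merge_and_unload. The function name build_merged_adapter_trainer and the LoRA hyperparameters are hypothetical example choices. Imports are kept inside the function so the sketch can be defined without loading the heavy dependencies.

```python
def build_merged_adapter_trainer(base_model_id, adapter_path, train_dataset, output_dir):
    """Hypothetical helper for option 2: merge the SFT adapter, add a fresh LoRA adapter,
    and let DPOTrainer use the adapter-disabled model as the reference."""
    from peft import LoraConfig, PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer
    from trl import DPOConfig, DPOTrainer

    base = AutoModelForCausalLM.from_pretrained(base_model_id)
    # Fold the SFT adapter weights into the base model.
    model = PeftModel.from_pretrained(base, adapter_path).merge_and_unload()
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    return DPOTrainer(
        model=model,
        ref_model=None,  # reference log probs come from the model with the new adapter disabled
        args=DPOConfig(output_dir=output_dir),
        train_dataset=train_dataset,
        processing_class=tokenizer,
        # Example LoRA values for the new adapter trained with DPO.
        peft_config=LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM"),
    )
```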
Downsides to merging QLoRA before DPO (approach 2)
As suggested by Benjamin Marie, the best option for merging QLoRA adapters is to first dequantize the base model, then merge the adapter, using a merge script along those lines.
However, after using this approach, you will have an unquantized base model. Therefore, to use QLoRA for DPO, you will need to re-quantize the merged model or use the unquantized merge (resulting in higher memory demand).
Using option 3 - load the adapter twice
To avoid the downsides of option 2, you can load your fine-tuned adapter into the model twice, with different names, and set the model/ref adapter names in DPOTrainer.
For example:
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from trl import DPOConfig, DPOTrainer
# Load the base model.
bnb_config = BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.bfloat16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
)
model = AutoModelForCausalLM.from_pretrained(
"mistralai/mixtral-8x7b-v0.1",
quantization_config=bnb_config,  # 4-bit loading is configured here; don't also pass load_in_4bit
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
device_map="auto",
)
model.config.use_cache = False
# Load the adapter.
model = PeftModel.from_pretrained(
model,
"/path/to/peft",
is_trainable=True,
adapter_name="train",
)
# Load the adapter a second time, with a different name, which will be our reference model.
model.load_adapter("/path/to/peft", adapter_name="reference")
# Initialize the trainer, without a ref_model param.
training_args = DPOConfig(
output_dir="dpo-output",  # hypothetical output directory (required by DPOConfig)
model_adapter_name="train",
ref_adapter_name="reference",
)
dpo_trainer = DPOTrainer(
model,
args=training_args,
...
)
DPOTrainer
class trl.DPOTrainer
( model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module, str, NoneType] = None args: typing.Optional[trl.trainer.dpo_config.DPOConfig] = None data_collator: typing.Optional[transformers.data.data_collator.DataCollator] = None train_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None eval_dataset: typing.Union[datasets.arrow_dataset.Dataset, dict[str, datasets.arrow_dataset.Dataset], NoneType] = None processing_class: typing.Union[transformers.tokenization_utils_base.PreTrainedTokenizerBase, transformers.image_processing_utils.BaseImageProcessor, transformers.feature_extraction_utils.FeatureExtractionMixin, transformers.processing_utils.ProcessorMixin, NoneType] = None model_init: typing.Optional[typing.Callable[[], transformers.modeling_utils.PreTrainedModel]] = None compute_metrics: typing.Optional[typing.Callable[[transformers.trainer_utils.EvalLoopOutput], dict]] = None callbacks: typing.Optional[list[transformers.trainer_callback.TrainerCallback]] = None optimizers: tuple = (None, None) preprocess_logits_for_metrics: typing.Optional[typing.Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None peft_config: typing.Optional[dict] = None )
Parameters
- model (transformers.PreTrainedModel) — The model to train, preferably an AutoModelForCausalLM.
- ref_model (PreTrainedModelWrapper) — Hugging Face transformer model with a causal language modelling head. Used for implicit reward computation and loss. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized.
- args (DPOConfig) — The DPO config arguments to use for training.
- data_collator (transformers.DataCollator) — The data collator to use for training. If None is specified, the default data collator (PreferenceCollator) will be used. Given a dataset of paired sequences, it pads the sequences to the maximum sequence length in the batch.
- train_dataset (datasets.Dataset) — The dataset to use for training.
- eval_dataset (datasets.Dataset) — The dataset to use for evaluation.
- processing_class (PreTrainedTokenizerBase or BaseImageProcessor or FeatureExtractionMixin or ProcessorMixin, optional) — Processing class used to process the data. If provided, it will be used to automatically process the inputs for the model. It will also be saved along with the model, which makes it easier to rerun an interrupted training or reuse the fine-tuned model. This supersedes the tokenizer argument, which is now deprecated.
- model_init (Callable[[], transformers.PreTrainedModel]) — The model initializer to use for training. If None is specified, the default model initializer will be used.
- compute_metrics (Callable[[EvalPrediction], dict], optional) — The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
- callbacks (list[transformers.TrainerCallback]) — The callbacks to use for training.
- optimizers (tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]) — The optimizer and scheduler to use for training.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]) — The function to use to preprocess the logits before computing the metrics.
- peft_config (dict, defaults to None) — The PEFT configuration to use for training. If you pass a PEFT configuration, the model will be wrapped in a PEFT model.
Initialize DPOTrainer.
Computes log probabilities of the reference model for a single padded batch of a DPO specific dataset.
Run the given model on the given batch of inputs, concatenating the chosen and rejected inputs together.
We do this to avoid doing two forward passes, because it’s faster for FSDP.
concatenated_inputs
( batch: dict padding_value: int ) → dict[str, torch.LongTensor]
Parameters
- batch (dict[str, Union[list, torch.LongTensor]]) — A batch of input data. The batch must contain the following keys:
  - "prompt_input_ids": Tensor of shape (batch_size, prompt_length) representing the prompt input IDs.
  - "chosen_input_ids": Tensor of shape (batch_size, chosen_length) representing the chosen completion input IDs.
  - "rejected_input_ids": Tensor of shape (batch_size, rejected_length) representing the rejected completion input IDs.
  - "prompt_pixel_values" (optional): Tensor for pixel values, if available.
  - "prompt_pixel_attention_mask" (optional): Tensor for pixel attention masks, if available.
- padding_value (int) — The padding value to use for the concatenated completion sequences (chosen_input_ids and rejected_input_ids).
Returns
dict[str, torch.LongTensor]
A dictionary containing:
- "prompt_input_ids": Concatenated prompt input IDs of shape (2 * batch_size, prompt_length).
- "completion_input_ids": Concatenated chosen and rejected completion input IDs of shape (2 * batch_size, max_completion_length).
- "prompt_attention_mask": Concatenated prompt attention masks of shape (2 * batch_size, prompt_length).
- "completion_attention_mask": Concatenated chosen and rejected attention masks of shape (2 * batch_size, max_completion_length).
- "pixel_values" (optional): Concatenated pixel values if "prompt_pixel_values" are present.
- "pixel_attention_mask" (optional): Concatenated pixel attention masks if "prompt_pixel_attention_mask" are present.
Concatenate the chosen and rejected inputs from the batch into a single tensor for both the prompt and completion sequences.
Notes: The completion input IDs and attention masks are padded to the maximum completion length of the chosen or rejected sequences.
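The sketch below illustrates the same concatenation using nested Python lists instead of tensors. It assumes the prompts are already padded to a common length, and it omits the prompt attention mask and the pixel inputs. The helper names concatenate_pairs and pad_right are hypothetical.

```python
def pad_right(rows, length, value):
    """Hypothetical helper: right-pad each list of token ids to `length` with `value`."""
    return [row + [value] * (length - len(row)) for row in rows]


def concatenate_pairs(prompt_ids, chosen_ids, rejected_ids, padding_value):
    """Hypothetical helper: stack chosen and rejected examples along the batch axis.

    The first half of the output pairs each prompt with its chosen completion and the
    second half with its rejected completion; completions share one padded length.
    """
    completions = chosen_ids + rejected_ids
    max_completion_length = max(len(row) for row in completions)
    return {
        "prompt_input_ids": prompt_ids + prompt_ids,
        "completion_input_ids": pad_right(completions, max_completion_length, padding_value),
        # 1 for real tokens, 0 for padding.
        "completion_attention_mask": [
            [1] * len(row) + [0] * (max_completion_length - len(row)) for row in completions
        ],
    }


batch = concatenate_pairs(
    prompt_ids=[[1, 2], [3, 4]],
    chosen_ids=[[5], [6, 7, 8]],
    rejected_ids=[[9, 10], [11]],
    padding_value=0,
)
```

Stacking both halves lets one forward pass score chosen and rejected completions together, which is the motivation given above for concatenating them.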
create_model_card
( model_name: typing.Optional[str] = None dataset_name: typing.Optional[str] = None tags: typing.Union[str, list[str], NoneType] = None )
Creates a draft of a model card using the information available to the Trainer.
dpo_loss
( chosen_logps: FloatTensor rejected_logps: FloatTensor ref_chosen_logps: FloatTensor ref_rejected_logps: FloatTensor ) → A tuple of three tensors
Parameters
- chosen_logps (torch.FloatTensor) — Log probabilities of the model for the chosen responses. Shape: (batch_size,).
- rejected_logps (torch.FloatTensor) — Log probabilities of the model for the rejected responses. Shape: (batch_size,).
- ref_chosen_logps (torch.FloatTensor) — Log probabilities of the reference model for the chosen responses. Shape: (batch_size,).
- ref_rejected_logps (torch.FloatTensor) — Log probabilities of the reference model for the rejected responses. Shape: (batch_size,).
Returns
A tuple of three tensors (losses, chosen_rewards, rejected_rewards).
The losses tensor contains the DPO loss for each example in the batch. The chosen_rewards and rejected_rewards tensors contain the rewards for the chosen and rejected responses, respectively.
Compute the DPO loss for a batch of policy and reference model log probabilities.
evaluation_loop
( dataloader: DataLoader description: str prediction_loss_only: typing.Optional[bool] = None ignore_keys: typing.Optional[list[str]] = None metric_key_prefix: str = 'eval' )
Overrides the built-in evaluation loop to store metrics for each batch.
Prediction/evaluation loop, shared by Trainer.evaluate() and Trainer.predict().
Works both with or without labels.
Generate samples from the model and reference model for the given batch of inputs.
get_batch_loss_metrics
( model batch: dict train_eval: typing.Literal['train', 'eval'] = 'train' )
Compute the DPO loss and other metrics for the given batch of inputs for train or test.
get_eval_dataloader
( eval_dataset: typing.Optional[datasets.arrow_dataset.Dataset] = None )
Parameters
- eval_dataset (torch.utils.data.Dataset, optional) — If provided, will override self.eval_dataset. If it is a Dataset, columns not accepted by the model.forward() method are automatically removed. It must implement __len__.
Returns the evaluation torch.utils.data.DataLoader.
Overrides transformers.Trainer.get_eval_dataloader to precompute ref_log_probs.
Returns the training torch.utils.data.DataLoader.
Overrides transformers.Trainer.get_train_dataloader to precompute ref_log_probs.
log
( logs: dict start_time: typing.Optional[float] = None )
Log logs on the various objects watching training, including stored metrics.
Context manager for handling null reference model (that is, peft adapter manipulation).
process_row
( features processing_class max_prompt_length max_completion_length add_special_tokens )
Same as tokenize_row but for vision models. Please refer to tokenize_row for more information.
tokenize_row
( features processing_class max_prompt_length max_completion_length add_special_tokens ) → dict[str, list[int]]
Parameters
- features (dict[str, str]) — Row of the dataset, should contain the keys "prompt", "chosen", and "rejected".
- processing_class (PreTrainedTokenizerBase) — Processing class used to process the data.
- max_prompt_length (int or None) — Maximum length of the prompt sequence. If None, the prompt sequence is not truncated.
- max_completion_length (int or None) — Maximum length of the completion sequences. If None, the completion sequences are not truncated.
- add_special_tokens (bool) — Whether to add special tokens to the sequences. Typically used for encoder-decoder models. If True, the prompt sequence will have a bos token prepended and an eos token appended. In any case, the completion sequences will have an eos token appended.
Returns
dict[str, list[int]]
Tokenized sequences with the keys "prompt_input_ids", "chosen_input_ids", and "rejected_input_ids".
Tokenize a row of the dataset.
Example:
>>> from transformers import GPT2Tokenizer
>>> from trl import DPOTrainer
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
>>> features = {"prompt": "The sky is", "chosen": " blue", "rejected": " green"}
>>> DPOTrainer.tokenize_row(features, tokenizer, max_prompt_length=3, max_completion_length=3, add_special_tokens=False)
{'prompt_input_ids': [464, 6766, 318], 'chosen_input_ids': [4171, 50256], 'rejected_input_ids': [4077, 50256]}
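The token-level steps can be sketched without a tokenizer by starting from the GPT-2 ids in the example above. The helper truncate_row is hypothetical. It assumes the prompt keeps its last tokens (left truncation) and that completions keep their first tokens after the eos token is appended. With the example's limits, it reproduces the output shown.

```python
def truncate_row(prompt_ids, chosen_ids, rejected_ids, eos_token_id,
                 max_prompt_length=None, max_completion_length=None):
    """Hypothetical helper: append eos to each completion, then truncate.

    Assumption: prompts keep their end, completions keep their beginning.
    """
    chosen_ids = chosen_ids + [eos_token_id]
    rejected_ids = rejected_ids + [eos_token_id]
    if max_prompt_length is not None:
        prompt_ids = prompt_ids[-max_prompt_length:]
    if max_completion_length is not None:
        chosen_ids = chosen_ids[:max_completion_length]
        rejected_ids = rejected_ids[:max_completion_length]
    return {
        "prompt_input_ids": prompt_ids,
        "chosen_input_ids": chosen_ids,
        "rejected_input_ids": rejected_ids,
    }


# GPT-2 ids for "The sky is", " blue", " green"; 50256 is GPT-2's eos token id.
row = truncate_row([464, 6766, 318], [4171], [4077], 50256,
                   max_prompt_length=3, max_completion_length=3)
```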
DPOConfig
class trl.DPOConfig
( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: typing.Optional[int] = None per_gpu_eval_batch_size: typing.Optional[int] = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: typing.Optional[int] = None eval_delay: typing.Optional[float] = 0 torch_empty_cache_steps: typing.Optional[int] = None learning_rate: float = 1e-06 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: typing.Union[transformers.trainer_utils.SchedulerType, str] = 'linear' lr_scheduler_kwargs: typing.Union[dict, str, NoneType] = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: typing.Optional[str] = 'passive' log_level_replica: typing.Optional[str] = 'warning' log_on_each_node: bool = True logging_dir: typing.Optional[str] = None logging_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: typing.Union[transformers.trainer_utils.SaveStrategy, str] = 'steps' save_steps: float = 500 save_total_limit: typing.Optional[int] = None save_safetensors: typing.Optional[bool] = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: typing.Optional[int] = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto'
bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: typing.Optional[bool] = None local_rank: int = -1 ddp_backend: typing.Optional[str] = None tpu_num_cores: typing.Optional[int] = None tpu_metrics_debug: bool = False debug: typing.Union[str, typing.List[transformers.debug_utils.DebugOption]] = '' dataloader_drop_last: bool = False eval_steps: typing.Optional[float] = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: typing.Optional[int] = None past_index: int = -1 run_name: typing.Optional[str] = None disable_tqdm: typing.Optional[bool] = None remove_unused_columns: typing.Optional[bool] = True label_names: typing.Optional[typing.List[str]] = None load_best_model_at_end: typing.Optional[bool] = False metric_for_best_model: typing.Optional[str] = None greater_is_better: typing.Optional[bool] = None ignore_data_skip: bool = False fsdp: typing.Union[typing.List[transformers.trainer_utils.FSDPOption], str, NoneType] = '' fsdp_min_num_params: int = 0 fsdp_config: typing.Union[dict, str, NoneType] = None fsdp_transformer_layer_cls_to_wrap: typing.Optional[str] = None accelerator_config: typing.Union[dict, str, NoneType] = None deepspeed: typing.Union[dict, str, NoneType] = None label_smoothing_factor: float = 0.0 optim: typing.Union[transformers.training_args.OptimizerNames, str] = 'adamw_torch' optim_args: typing.Optional[str] = None adafactor: bool = False group_by_length: bool = False length_column_name: typing.Optional[str] = 'length' report_to: typing.Union[NoneType, str, typing.List[str]] = None ddp_find_unused_parameters: typing.Optional[bool] = None ddp_bucket_cap_mb: typing.Optional[int] = None ddp_broadcast_buffers: typing.Optional[bool] = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: typing.Optional[str] = None hub_model_id: typing.Optional[str] = None
hub_strategy: typing.Union[transformers.trainer_utils.HubStrategy, str] = 'every_save' hub_token: typing.Optional[str] = None hub_private_repo: typing.Optional[bool] = None hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: typing.Union[dict, str, NoneType] = None include_inputs_for_metrics: bool = False include_for_metrics: typing.List[str] = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: typing.Union[transformers.trainer_utils.IntervalStrategy, str] = None push_to_hub_model_id: typing.Optional[str] = None push_to_hub_organization: typing.Optional[str] = None push_to_hub_token: typing.Optional[str] = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: typing.Optional[str] = None ray_scope: typing.Optional[str] = 'last' ddp_timeout: typing.Optional[int] = 1800 torch_compile: bool = False torch_compile_backend: typing.Optional[str] = None torch_compile_mode: typing.Optional[str] = None dispatch_batches: typing.Optional[bool] = None split_batches: typing.Optional[bool] = None include_tokens_per_second: typing.Optional[bool] = False include_num_input_tokens_seen: typing.Optional[bool] = False neftune_noise_alpha: typing.Optional[float] = None optim_target_modules: typing.Union[NoneType, str, typing.List[str]] = None batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: typing.Optional[bool] = False eval_use_gather_object: typing.Optional[bool] = False average_tokens_across_devices: typing.Optional[bool] = False beta: float = 0.1 label_smoothing: float = 0.0 loss_type: typing.Literal['sigmoid', 'hinge', 'ipo', 'exo_pair', 'nca_pair', 'robust', 'bco_pair', 'sppo_hard', 'aot', 'aot_pair', 'discopop', 'apo_zero', 'apo_down'] = 'sigmoid' use_weighting: bool = False label_pad_token_id: int = -100 padding_value: typing.Optional[int] = None truncation_mode: str = 'keep_end' max_length: typing.Optional[int] = None
max_prompt_length: typing.Optional[int] = None max_completion_length: typing.Optional[int] = None is_encoder_decoder: typing.Optional[bool] = None disable_dropout: bool = True generate_during_eval: bool = False precompute_ref_log_probs: bool = False precompute_ref_batch_size: typing.Optional[int] = None dataset_num_proc: typing.Optional[int] = None model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None ref_model_init_kwargs: typing.Optional[dict[str, typing.Any]] = None model_adapter_name: typing.Optional[str] = None ref_adapter_name: typing.Optional[str] = None reference_free: bool = False force_use_ref_model: bool = False f_divergence_type: FDivergenceType = <FDivergenceType.REVERSE_KL: 'reverse_kl'> f_alpha_divergence_coef: float = 1.0 sync_ref_model: bool = False ref_model_mixup_alpha: float = 0.9 ref_model_sync_steps: int = 64 rpo_alpha: typing.Optional[float] = None discopop_tau: float = 0.05 use_num_logits_to_keep: bool = False )
Parameters
- learning_rate (float, optional, defaults to 1e-6) — Initial learning rate for the AdamW optimizer. The default value replaces that of TrainingArguments.
- beta (float, optional, defaults to 0.1) — Parameter controlling the deviation from the reference model. Higher β means less deviation from the reference model. For the IPO loss (loss_type="ipo"), β is the regularization parameter denoted by τ in the paper.
- label_smoothing (float, optional, defaults to 0.0) — Robust DPO label smoothing parameter from the cDPO report and Robust DPO paper; it should be between 0.0 and 0.5.
- loss_type (str, optional, defaults to "sigmoid") — Type of loss to use. Possible values are:
  - "sigmoid": sigmoid loss from the original DPO paper.
  - "hinge": hinge loss on the normalized likelihood from the SLiC paper.
  - "ipo": IPO loss from the IPO paper.
  - "exo_pair": pairwise EXO loss from the EXO paper.
  - "nca_pair": pairwise NCA loss from the NCA paper.
  - "robust": unbiased estimate of the DPO loss that is robust to preference noise, from the Robust DPO paper.
  - "bco_pair": pairwise BCO loss from the BCO paper.
  - "sppo_hard": SPPO loss with hard label from the SPPO paper.
  - "aot": AOT loss for paired datasets from the AOT paper.
  - "aot_pair": AOT loss for unpaired datasets from the AOT paper.
  - "discopop": DiscoPOP (a.k.a. Log-Ratio Modulated Loss, LRML) loss from the DiscoPOP paper.
  - "apo_zero": APO-zero loss from the APO paper.
  - "apo_down": APO-down loss from the APO paper.
- use_weighting (bool, optional, defaults to False) — Whether or not to weight the loss as done in the WPO paper.
- label_pad_token_id (int, optional, defaults to -100) — Label pad token id. This argument is required if you want to use the default data collator.
- padding_value (Optional[int], optional, defaults to None) — Padding value to use. If None, the padding value of the tokenizer is used.
- truncation_mode (str, optional, defaults to "keep_end") — Truncation mode to use, either keep_end or keep_start. This argument is required if you want to use the default data collator.
- max_length (Optional[int], optional, defaults to None) — Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want to use the default data collator.
- max_prompt_length (Optional[int], optional, defaults to None) — Maximum length of the prompt. This argument is required if you want to use the default data collator.
- max_completion_length (Optional[int], optional, defaults to None) — Maximum length of the target (completion). This argument is required if you want to use the default data collator and your model is an encoder-decoder.
- is_encoder_decoder (Optional[bool], optional, defaults to None) — When using the model_init argument (a callable) to instantiate the model instead of the model argument, you need to specify whether the model returned by the callable is an encoder-decoder model.
- disable_dropout (bool, optional, defaults to True) — Whether to disable dropout in the model and reference model.
- generate_during_eval (bool, optional, defaults to False) — If True, generates and logs completions from both the model and the reference model to W&B during evaluation.
- precompute_ref_log_probs (bool, optional, defaults to False) — Whether to precompute reference model log probabilities for the training and evaluation datasets. This is useful when training without the reference model to reduce the total GPU memory needed.
- precompute_ref_batch_size (Optional[int], optional, defaults to None) — Batch size to use when precomputing reference model log probabilities. This can be set higher than the training batch size to speed up preprocessing. If None, defaults to per_device_train_batch_size for training and per_device_eval_batch_size for evaluation.
- dataset_num_proc (Optional[int], optional, defaults to None) — Number of processes to use for processing the dataset.
- model_init_kwargs (Optional[dict[str, Any]], optional, defaults to None) — Keyword arguments to pass to AutoModelForCausalLM.from_pretrained when instantiating the model from a string.
- ref_model_init_kwargs (Optional[dict[str, Any]], optional, defaults to None) — Keyword arguments to pass to AutoModelForCausalLM.from_pretrained when instantiating the reference model from a string.
- model_adapter_name (Optional[str], optional, defaults to None) — Name of the train target PEFT adapter, when using LoRA with multiple adapters.
- ref_adapter_name (Optional[str], optional, defaults to None) — Name of the reference PEFT adapter, when using LoRA with multiple adapters.
- reference_free (bool, optional, defaults to False) — If True, the provided reference model is ignored and a reference model that assigns equal probability to all responses is used implicitly.
- force_use_ref_model (bool, optional, defaults to False) — If you pass a PEFT model as the active model and want to use a different model as the ref_model, set this flag to True.
- f_divergence_type (str, optional, defaults to FDivergenceType.REVERSE_KL) — Type of f-divergence regularization function used to compute the divergence between the policy and reference model.
- f_alpha_divergence_coef (float, optional, defaults to 1.0) — α coefficient in the α-divergence u^(-α) regularization function for the DPO loss.
- sync_ref_model (bool, optional, defaults to False) — When set to True, the reference model is synchronized with the active model every ref_model_sync_steps steps, using the ref_model_mixup_alpha parameter. This synchronization originates from the TR-DPO paper.
- ref_model_mixup_alpha (float, optional, defaults to 0.9) — α parameter from the TR-DPO paper, which controls the mix between the current policy and the previous reference policy during updates. The reference policy is updated according to the equation: π_ref = α * π_θ + (1 - α) * π_ref_prev. To use this parameter, you must set sync_ref_model=True.
- ref_model_sync_steps (int, optional, defaults to 64) — τ parameter from the TR-DPO paper, which determines how frequently the current policy is synchronized with the reference policy. To use this parameter, you must set sync_ref_model=True.
- rpo_alpha (float, optional, defaults to None) — α parameter from the RPO paper (v3), which controls the weighting of the NLL term in the loss. If None, no weighting is applied and the loss is the same as the DPO loss. The paper recommends rpo_alpha=1.0.
- discopop_tau (float, optional, defaults to 0.05) — τ/temperature parameter from the DiscoPOP paper, which controls the shape of the log-ratio modulated loss. The paper recommends the default value discopop_tau=0.05.
- use_num_logits_to_keep (bool, optional, defaults to False) — If True, only a specified number of logits are computed in the forward pass of the CausalLM. This can save memory and speed up training by not computing the logits for all tokens, especially when working with very long prompts whose labels are ignored (-100).
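The following scalar, pure-Python sketch makes the interplay of beta, label_smoothing, loss_type, rpo_alpha and ref_model_mixup_alpha concrete for a single preference pair. It assumes you already have the summed log-probabilities of the chosen and rejected completions under the policy and the reference model. It covers only three loss types ("sigmoid" with cDPO-style label smoothing, "hinge", "ipo"), and it treats the RPO NLL term as the negative log-likelihood of the chosen completion. The function names are hypothetical, and this is not DPOTrainer's batched PyTorch code.

```python
import math
from typing import Dict, Optional


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def preference_loss(
    policy_chosen_logp: float,
    policy_rejected_logp: float,
    ref_chosen_logp: float,
    ref_rejected_logp: float,
    beta: float = 0.1,
    label_smoothing: float = 0.0,
    loss_type: str = "sigmoid",
    rpo_alpha: Optional[float] = None,
    chosen_nll: float = 0.0,
) -> float:
    # Margin between the policy/reference log-ratios of the chosen and rejected completions.
    logits = (policy_chosen_logp - ref_chosen_logp) - (policy_rejected_logp - ref_rejected_logp)
    if loss_type == "sigmoid":
        # DPO loss; label_smoothing > 0 blends in the loss for the flipped preference label.
        loss = (
            -log_sigmoid(beta * logits) * (1 - label_smoothing)
            - log_sigmoid(-beta * logits) * label_smoothing
        )
    elif loss_type == "hinge":
        loss = max(0.0, 1.0 - beta * logits)
    elif loss_type == "ipo":
        # beta plays the role of the regularization parameter tau from the IPO paper.
        loss = (logits - 1.0 / (2.0 * beta)) ** 2
    else:
        raise ValueError(f"loss_type={loss_type!r} is not covered by this sketch")
    if rpo_alpha is not None:
        # RPO: add a weighted NLL term (assumed here to be the chosen completion's NLL).
        loss += rpo_alpha * chosen_nll
    return loss


def mix_reference(policy: Dict[str, float], ref: Dict[str, float], alpha: float = 0.9) -> Dict[str, float]:
    """TR-DPO soft update, per parameter: pi_ref = alpha * pi_theta + (1 - alpha) * pi_ref_prev."""
    return {name: alpha * policy[name] + (1.0 - alpha) * ref[name] for name in ref}


# A policy identical to the reference has a zero margin, so the sigmoid loss equals log(2).
print(preference_loss(-10.0, -12.0, -10.0, -12.0))
# Raising the chosen log-prob relative to the reference lowers the sigmoid loss.
print(preference_loss(-9.0, -12.0, -10.0, -12.0))
```

Because label smoothing mixes in the loss for the flipped label, values above 0.5 would push the policy toward the rejected completion. This is consistent with the recommended 0.0 to 0.5 range. In a real run, the mix_reference update would act on model weight tensors rather than on a toy dict of floats.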
Configuration class for the DPOTrainer.
Using HfArgumentParser, we can turn this class into argparse arguments that can be specified on the command line.
PreferenceCollator
class trl.trainer.dpo_trainer.PreferenceCollator
(pad_token_id: int, return_tensors: str = 'pt')
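Data collator used for preference data. Inputs are dynamically padded to the maximum length of a batch if they are not all of the same length: prompts are padded on the left, chosen and rejected completions are padded on the right, and a matching attention mask is returned for each field.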
Examples:
>>> from trl.trainer.dpo_trainer import PreferenceCollator
>>> collator = PreferenceCollator(pad_token_id=0)
>>> examples = [
... {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
... {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}
... ]
>>> collator(examples)
{'prompt_input_ids': tensor([[1, 2, 3],
                             [0, 7, 8]]),
 'prompt_attention_mask': tensor([[1, 1, 1],
                                  [0, 1, 1]]),
 'chosen_input_ids': tensor([[ 4,  5],
                             [ 9, 10]]),
 'chosen_attention_mask': tensor([[1, 1],
                                  [1, 1]]),
 'rejected_input_ids': tensor([[ 6,  0,  0],
                               [11, 12, 13]]),
 'rejected_attention_mask': tensor([[1, 0, 0],
                                    [1, 1, 1]])
}
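The padding rule in the example above can be reproduced with plain Python lists. The sketch below is a hypothetical stdlib re-implementation: pad_lists and collate_preference_batch are illustrative names, not TRL functions. It handles only the three input-id fields shown in the example, whereas the real PreferenceCollator returns PyTorch tensors.

```python
from typing import Dict, List


def pad_lists(seqs: List[List[int]], pad_value: int, side: str) -> List[List[int]]:
    """Pad every sequence to the length of the longest one, on the given side ("left" or "right")."""
    width = max(len(seq) for seq in seqs)
    padded = []
    for seq in seqs:
        fill = [pad_value] * (width - len(seq))
        padded.append(fill + seq if side == "left" else seq + fill)
    return padded


def collate_preference_batch(
    examples: List[Dict[str, List[int]]], pad_token_id: int
) -> Dict[str, List[List[int]]]:
    batch: Dict[str, List[List[int]]] = {}
    # Prompts are left-padded (so each prompt stays adjacent to the completion that
    # follows it when the two are concatenated); completions are right-padded.
    for field, side in (("prompt", "left"), ("chosen", "right"), ("rejected", "right")):
        ids = [example[f"{field}_input_ids"] for example in examples]
        masks = [[1] * len(seq) for seq in ids]  # 1 marks a real token
        batch[f"{field}_input_ids"] = pad_lists(ids, pad_token_id, side)
        batch[f"{field}_attention_mask"] = pad_lists(masks, 0, side)  # 0 marks padding
    return batch


examples = [
    {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
    {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]},
]
batch = collate_preference_batch(examples, pad_token_id=0)
print(batch["prompt_input_ids"])  # [[1, 2, 3], [0, 7, 8]]
print(batch["rejected_input_ids"])  # [[6, 0, 0], [11, 12, 13]]
```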