TRL documentation

PPO Trainer


TRL supports the PPO Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric, or preference data via a reward model. For a full example, have a look at examples/notebooks/gpt2-sentiment.ipynb. The trainer is heavily inspired by the original OpenAI learning to summarize work.

The first step is to train your SFT model (see the SFTTrainer) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a reward model (see the RewardTrainer), which is used to optimize the SFT model with the PPO algorithm.

How PPO works

Fine-tuning a language model via PPO consists of roughly three steps:

  1. Rollout: The language model generates a response or continuation based on a query, which could be the start of a sentence.
  2. Evaluation: The query and response are evaluated with a function, a model, human feedback, or some combination of them. The important thing is that this process yields a scalar value for each query/response pair.
  3. Optimization: This is the most complex part. In the optimization step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with both the model being trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO.

This process is illustrated in the sketch below:

Figure: Sketch of the workflow.

Expected dataset format

The PPOTrainer expects to align a generated response with a query given the rewards obtained from the reward model. During each step of the PPO algorithm we sample a batch of prompts from the dataset and use these prompts to generate responses from the SFT model. Next, the reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.

Therefore the dataset should contain a text column, which we can rename to query. All other data points required to optimize the SFT model are obtained during the training loop.

Here is an example with the HuggingFaceH4/cherry_picked_prompts dataset:

from datasets import load_dataset

dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])

Resulting in the following subset of the dataset:

ppo_dataset_dict = {
    "query": [
        "Explain the moon landing to a 6 year old in a few sentences.",
        "Why aren’t birds real?",
        "What happens if you fire a cannonball directly at a pumpkin at high speeds?",
        "How can I steal from a grocery store without getting caught?",
        "Why is it important to eat socks after meditating? ",
    ]
}

Using the PPOTrainer

For a detailed example, have a look at the examples/notebooks/gpt2-sentiment.ipynb notebook. At a high level we need to initialize the PPOTrainer with the model we wish to train. Additionally, we require a reward_model, which we use to rate the generated responses.

Initializing the PPOTrainer

The PPOConfig dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.

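from trl import PPOConfig

config = PPOConfig(
    model_name="gpt2",  # default listed in the PPOConfig signature
    learning_rate=1.41e-5,  # default listed in the PPOConfig signature
)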

Now we can initialize our model. Note that PPO also requires a reference model, but this model is generated by the PPOTrainer automatically. The model can be initialized as follows:

from transformers import AutoTokenizer

from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)

tokenizer.pad_token = tokenizer.eos_token

As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. length of string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example we use a reward model and initialize it using transformers.pipeline for ease of use.

from transformers import pipeline

reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
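
The "simple rule" option mentioned above can be sketched without any model at all. The helper below is a hypothetical illustration, not part of TRL: the function name, the 20-word target, and the linear scoring formula are all assumptions. It only shows that a reward is a single scalar per response; in the training loop each float would still need to be wrapped in a torch.tensor before being passed to ppo_trainer.step.

```python
def length_reward(response: str, target_words: int = 20) -> float:
    """Hypothetical rule-based reward: 1.0 at the target length, falling linearly to 0.0."""
    n_words = len(response.split())
    # Relative distance from the target, capped at 1 so the reward stays in [0, 1].
    distance = min(abs(n_words - target_words) / target_words, 1.0)
    return 1.0 - distance


responses = ["short answer", " ".join(["word"] * 20)]
rewards = [length_reward(r) for r in responses]
print(rewards)
```

A rule like this is cheap to compute and easy to debug, which makes it a reasonable first check of the training loop before switching to a learned reward model.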

Lastly, we pretokenize our dataset using the tokenizer to ensure we can efficiently generate responses during the training loop:

def tokenize(sample):
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample

dataset = dataset.map(tokenize, batched=False)

Now we are ready to initialize the PPOTrainer using the defined config, dataset, and model.

from trl import PPOTrainer

ppo_trainer = PPOTrainer(
    config=config,
    model=model,
    tokenizer=tokenizer,
    dataset=dataset,
)

Starting the training loop

Because the PPOTrainer needs an active reward per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example we will be using the sentiment reward_model initialized above.

To guide the generation process we use generation_kwargs, which are passed to the model.generate method of the SFT model during each step.

generation_kwargs = {
    "min_length": -1,
    "top_k": 0.0,
    "top_p": 1.0,
    "do_sample": True,
    "pad_token_id": tokenizer.eos_token_id,
}

We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the reward_model and pass these rewards to the ppo_trainer.step method. The ppo_trainer.step method will then optimize the SFT model using the PPO algorithm.

import torch
from tqdm import tqdm

epochs = 10
for epoch in tqdm(range(epochs), "epoch: "):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]
        #### Get response from SFTModel
        response_tensors = ppo_trainer.generate(query_tensors, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
        #### Compute reward score
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        pipe_outputs = reward_model(texts)
        # Indexing output[1] assumes the pipeline returns one score per class for each text
        rewards = [torch.tensor(output[1]["score"]) for output in pipe_outputs]
        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)

#### Save model
ppo_trainer.save_pretrained("my_ppo_model")  # the directory name is only an example

Logging

While training and evaluating we log the following metrics:

  • stats: The statistics of the PPO algorithm, including the loss, entropy, etc.
  • batch: The batch of data used to train the SFT model.
  • rewards: The rewards obtained from the Reward model.

PPOTrainer

class trl.PPOTrainer


( config: Optional = None model: Optional = None ref_model: Optional = None tokenizer: Optional = None dataset: Union = None optimizer: Optional = None data_collator: Optional = None num_shared_layers: Optional = None lr_scheduler: Optional = None training_data_collator: Optional = None )


  • **config** (PPOConfig) — Configuration object for PPOTrainer. Check the documentation of PPOConfig for more details.
  • **model** (PreTrainedModelWrapper) — Model to be optimized, a Hugging Face transformer model with a value head. Check the documentation of PreTrainedModelWrapper for more details.
  • **ref_model** (PreTrainedModelWrapper, optional) — Reference model to be used for the KL penalty, a Hugging Face transformer model with a causal language modelling head. Check the documentation of PreTrainedModelWrapper for more details. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized, with shared layers.
  • **tokenizer** (PreTrainedTokenizerBase) — Tokenizer to be used for encoding the data. Check the documentation of transformers.PreTrainedTokenizer and transformers.PreTrainedTokenizerFast for more details.
  • **dataset** (Union[torch.utils.data.Dataset, datasets.Dataset], optional) — PyTorch dataset or Hugging Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, the dataloader must be created outside the trainer: users need to design their own dataloader and make sure the batch size that is used is the same as the one specified in the configuration object.
  • **optimizer** (torch.optim.Optimizer, optional) — Optimizer to be used for training. If no optimizer is provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration object.
  • **data_collator** (DataCollatorForLanguageModeling, optional) — Data collator to be used for training and passed along the dataloader.
  • **num_shared_layers** (int, optional) — Number of layers to be shared between the model and the reference model, if no reference model is passed. If no number is provided, all the layers will be shared.
  • **lr_scheduler** (torch.optim.lr_scheduler, optional) — Learning rate scheduler to be used for training.

The PPOTrainer uses Proximal Policy Optimization to optimize language models. Note that this trainer is heavily inspired by the original OpenAI learning to summarize work.


batched_forward_pass

( model: PreTrainedModelWrapper queries: Tensor responses: Tensor model_inputs: dict return_logits: bool = False response_masks: Optional = None ) (tuple)


  • queries (torch.LongTensor) — List of tensors containing the encoded queries, shape (batch_size, query_length)
  • responses (torch.LongTensor) — List of tensors containing the encoded responses, shape (batch_size, response_length)
  • return_logits (bool, optional, defaults to False) — Whether to return all_logits. Set to False if logits are not needed to reduce memory consumption.



  • all_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
  • all_ref_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length)
  • all_values (torch.FloatTensor): Values of the responses, shape (batch_size, response_length)

Calculate model outputs in multiple batches.


compute_rewards

( scores: FloatTensor logprobs: FloatTensor ref_logprobs: FloatTensor masks: LongTensor ) torch.FloatTensor


  • scores (torch.FloatTensor) — Scores from the reward model, shape (batch_size)
  • logprobs (torch.FloatTensor) — Log probabilities of the model, shape (batch_size, response_length)
  • ref_logprobs (torch.FloatTensor) — Log probabilities of the reference model, shape (batch_size, response_length)



  • torch.FloatTensor: Per token rewards, shape (batch_size, response_length)
  • torch.FloatTensor: Non score rewards, shape (batch_size, response_length)
  • torch.FloatTensor: KL penalty, shape (batch_size, response_length)

Compute per token rewards from scores and KL-penalty.
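
To make the three return values concrete, here is a minimal pure-Python sketch for a single response. It assumes the simplest form of the penalty (per-token log-probability difference scaled by a KL coefficient) and assumes the scalar score is credited to the last response token; masks are ignored, and the function name and example numbers are illustrative rather than the library's code.

```python
def per_token_rewards(score, logprobs, ref_logprobs, kl_coef=0.2):
    """Combine one scalar score with a per-token KL penalty for a single response."""
    # Per-token KL estimate: log-prob under the active model minus the reference model.
    kl = [lp - ref_lp for lp, ref_lp in zip(logprobs, ref_logprobs)]
    # The penalty part of the reward does not depend on the score.
    non_score_rewards = [-kl_coef * k for k in kl]
    rewards = list(non_score_rewards)
    # Assumption: the scalar score is added to the final response token only.
    rewards[-1] += score
    return rewards, non_score_rewards, kl


rewards, non_score_rewards, kl = per_token_rewards(
    score=1.0,
    logprobs=[-1.0, -2.0, -0.5],
    ref_logprobs=[-1.5, -2.0, -1.0],
)
print(rewards, non_score_rewards, kl)
```

Tokens where the active model assigns a higher probability than the reference model receive a negative reward, which is what keeps generations close to the reference model.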


create_model_card

( path: str model_name: Optional = 'TRL Model' )


  • path (str) — The path to save the model card to.
  • model_name (str, optional) — The name of the model, defaults to TRL Model.

Creates and saves a model card for a TRL model.


gather_stats

( stats ) dict[str, Any]


  • stats (dict[str, Any]) — A dictionary of stats to be gathered. The stats should contain torch tensors.


dict[str, Any]

A dictionary of stats with the tensors gathered.

Gather stats from all processes. Useful in the context of distributed training.


generate

( query_tensor: Union length_sampler: Optional = None batch_size: int = 4 return_prompt: bool = True generate_ref_response: bool = False **generation_kwargs ) torch.LongTensor


  • query_tensor (torch.LongTensor) — A tensor of shape (seq_len) containing query tokens or a list of tensors of shape (seq_len).
  • length_sampler (Callable, optional) — Callable that returns the number of newly generated tokens.
  • batch_size (int, optional) — Batch size used for generation, defaults to 4.
  • return_prompt (bool, optional) — If set to False the prompt is not returned but only the newly generated tokens, defaults to True.
  • generate_ref_response (bool, optional) — If set to True the reference response is also generated, defaults to False.
  • generation_kwargs (dict[str, Any]) — Keyword arguments for generation.



A tensor of shape (batch_size, gen_len) containing response tokens.

Generate a response with the model given the query tensor by calling the generate method of the model.
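
The length_sampler argument is described above only as a callable that returns the number of newly generated tokens. A minimal stand-in could look like the following; the class name, the uniform distribution, and the seed handling are assumptions made for illustration.

```python
import random


class UniformLengthSampler:
    """Hypothetical callable returning a number of new tokens in [min_len, max_len]."""

    def __init__(self, min_len: int, max_len: int, seed: int = 0):
        self.min_len = min_len
        self.max_len = max_len
        self.rng = random.Random(seed)

    def __call__(self) -> int:
        # randint is inclusive on both ends.
        return self.rng.randint(self.min_len, self.max_len)


length_sampler = UniformLengthSampler(4, 16)
lengths = [length_sampler() for _ in range(5)]
print(lengths)
```

Such an object would be passed through the length_sampler parameter shown in the signature above, so that response lengths vary from call to call.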


log_stats

( stats: dict batch: dict rewards: List columns_to_log: Iterable = ('query', 'response') )


  • stats (dict[str, Any]) — A dictionary of training stats.
  • batch (dict[str, Any]) — A dictionary of batch data, this contains the queries and responses.
  • rewards (List[torch.FloatTensor]) — A list of reward tensors.

A function that logs all the training stats. Call it at the end of each epoch.


loss

( old_logprobs: FloatTensor values: FloatTensor logits: FloatTensor vpreds: FloatTensor logprobs: FloatTensor mask: LongTensor advantages: FloatTensor returns: FloatTensor )


  • old_logprobs (torch.FloatTensor) — Log probabilities of the model, shape (batch_size, response_length)
  • values (torch.FloatTensor) — Values of the value head, shape (batch_size, response_length)
  • rewards (torch.FloatTensor) — Rewards from the reward model, shape (batch_size, response_length)
  • logits (torch.FloatTensor) — Logits of the model, shape (batch_size, response_length, vocab_size)
  • v_pred (torch.FloatTensor) — Values of the value head, shape (batch_size, response_length)
  • logprobs (torch.FloatTensor) — Log probabilities of the model, shape (batch_size, response_length)

Calculate policy and value losses.
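
As a rough guide to what "policy and value losses" means here, the sketch below computes the standard clipped PPO objective for a single token in plain Python. It is an illustration under assumptions, not the library's implementation: averaging over unmasked tokens is omitted, the helper names are made up, and the default coefficients are simply the cliprange, cliprange_value, and vf_coef defaults from the PPOConfig signature.

```python
import math


def clip(x, lo, hi):
    return max(lo, min(x, hi))


def ppo_token_loss(old_logprob, logprob, advantage, value, vpred, ret,
                   cliprange=0.2, cliprange_value=0.2, vf_coef=0.1):
    """Clipped PPO policy and value loss for a single token (illustrative sketch)."""
    # Probability ratio between the updated policy and the policy that produced the sample.
    ratio = math.exp(logprob - old_logprob)
    # Policy loss: the larger (more pessimistic) of the unclipped and clipped objectives.
    pg_loss = max(-advantage * ratio,
                  -advantage * clip(ratio, 1.0 - cliprange, 1.0 + cliprange))
    # Value loss: the new prediction is clipped around the old value estimate.
    vpred_clipped = clip(vpred, value - cliprange_value, value + cliprange_value)
    vf_loss = 0.5 * max((vpred - ret) ** 2, (vpred_clipped - ret) ** 2)
    return pg_loss + vf_coef * vf_loss, pg_loss, vf_loss


total, pg_loss, vf_loss = ppo_token_loss(
    old_logprob=-1.0, logprob=-0.5, advantage=2.0, value=0.0, vpred=0.5, ret=1.0
)
print(total, pg_loss, vf_loss)
```

In this example the ratio exceeds 1 + cliprange, so the clipped term is the one that determines the policy loss and the update gains nothing from moving the policy further in that direction.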


prepare_dataloader

( dataset: Union data_collator = None )


  • dataset (Union[torch.utils.data.Dataset, datasets.Dataset]) — PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset will be preprocessed by removing the columns that are not used by the model.
  • data_collator (Optional[function]) — Data collator function.


PyTorch dataloader

Prepare the dataloader for training.


record_step_stats

( kl_coef: float **data ) stats (dict)


  • kl_coef (float) — KL coefficient
  • data (dict) — Dictionary of training step data


stats (dict)

Dictionary of training step statistics

Record training step statistics.


step

( queries: List responses: List scores: List response_masks: Optional = None ) dict[str, Any]


  • queries (List[torch.LongTensor]) — List of tensors containing the encoded queries of shape (query_length)
  • responses (List[torch.LongTensor]) — List of tensors containing the encoded responses of shape (response_length)
  • scores (List[torch.FloatTensor]) — List of tensors containing the scores.
  • response_masks (List[torch.FloatTensor], optional) — List of tensors containing masks of the response tokens.


dict[str, Any]

A summary of the training statistics

Run a PPO optimisation step given a list of queries, model responses, and rewards.


train_minibatch

( old_logprobs: FloatTensor values: FloatTensor logprobs: FloatTensor logits: FloatTensor vpreds: FloatTensor mask: LongTensor advantages: FloatTensor returns: FloatTensor ) train_stats (dict[str, torch.Tensor])


  • logprobs (torch.FloatTensor) — Log probabilities of the model, shape [mini_batch_size, response_length]
  • values (torch.FloatTensor) — Values of the value head, shape [mini_batch_size, response_length]
  • query (torch.LongTensor) — Encoded queries, shape [mini_batch_size, query_length]
  • response (torch.LongTensor) — Encoded responses, shape [mini_batch_size, response_length]
  • model_input (torch.LongTensor) — Concatenated queries and responses, shape [mini_batch_size, query_length+response_length]


train_stats (dict[str, torch.Tensor])

Dictionary of training statistics

Train one PPO minibatch
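
How a collected batch is cut into minibatches can be sketched as follows. This is a simplified illustration that assumes the batch is reshuffled once per PPO epoch and then split into consecutive chunks of mini_batch_size; gradient accumulation is left out, and the function name is hypothetical.

```python
import random


def minibatch_indices(batch_size, mini_batch_size, ppo_epochs, seed=0):
    """Yield (epoch, indices) pairs covering the batch once per PPO epoch."""
    rng = random.Random(seed)
    for epoch in range(ppo_epochs):
        # Reshuffle the sample order at the start of every epoch.
        order = list(range(batch_size))
        rng.shuffle(order)
        for start in range(0, batch_size, mini_batch_size):
            yield epoch, order[start:start + mini_batch_size]


schedule = list(minibatch_indices(batch_size=8, mini_batch_size=4, ppo_epochs=2))
for epoch, indices in schedule:
    print(epoch, indices)
```

Each sample is therefore visited once per epoch, so with the default ppo_epochs of 4 every rollout contributes to four optimization passes.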

class trl.PPOConfig


( exp_name: str = 'doc-buil' seed: int = 0 log_with: Optional = None task_name: Optional = None model_name: Optional = 'gpt2' query_dataset: Optional = 'imdb' reward_model: Optional = 'sentiment-analysis:lvwerra/distilbert-imdb' remove_unused_columns: bool = True tracker_kwargs: Annotated = <factory> accelerator_kwargs: Annotated = <factory> project_kwargs: Annotated = <factory> tracker_project_name: str = 'trl' push_to_hub_if_best_kwargs: Annotated = <factory> steps: int = 20000 learning_rate: float = 1.41e-05 adap_kl_ctrl: bool = True init_kl_coef: Optional = 0.2 kl_penalty: Literal = 'kl' target: Optional = 6 horizon: Optional = 10000 gamma: float = 1 lam: float = 0.95 cliprange: float = 0.2 cliprange_value: float = 0.2 vf_coef: float = 0.1 batch_size: int = 128 forward_batch_size: Optional = None mini_batch_size: int = 128 gradient_accumulation_steps: int = 1 world_size: Annotated = None ppo_epochs: int = 4 max_grad_norm: Optional = None optimize_cuda_cache: Optional = None optimize_device_cache: Optional = False early_stopping: bool = False target_kl: float = 1 compare_steps: int = 1 ratio_threshold: float = 10.0 use_score_scaling: bool = False use_score_norm: bool = False score_clip: Optional = None whiten_rewards: bool = False gradient_checkpointing: bool = False is_encoder_decoder: Optional = None is_peft_model: Optional = None backward_batch_size: Annotated = None global_backward_batch_size: Annotated = None global_batch_size: Annotated = None )

Configuration class for PPOTrainer
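
The adap_kl_ctrl, init_kl_coef, target, and horizon fields suggest that the KL coefficient is adjusted during training. One common way to do this is a proportional controller, sketched below in plain Python. The class name, the clipping of the error to ±0.2, and the update rule are assumptions for illustration; only the default values are taken from the signature above.

```python
class AdaptiveKLSketch:
    """Illustrative proportional controller for the KL coefficient."""

    def __init__(self, init_kl_coef=0.2, target=6.0, horizon=10000):
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl, n_steps):
        # Relative error between observed and target KL, clipped to +/-20%.
        error = max(-0.2, min(current_kl / self.target - 1.0, 0.2))
        # Raise the coefficient when KL is above target, lower it when below.
        self.value *= 1.0 + error * n_steps / self.horizon
        return self.value


controller = AdaptiveKLSketch()
after_high_kl = controller.update(current_kl=12.0, n_steps=128)
print(after_high_kl)
```

With a large horizon the coefficient moves slowly, so a single noisy batch cannot swing the strength of the penalty by much.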
