PPO Trainer
TRL supports the PPO Trainer for training language models on any reward signal with RL. The reward signal can come from a handcrafted rule, a metric, or preference data using a Reward Model. For a full example, have a look at examples/notebooks/gpt2-sentiment.ipynb. The trainer is heavily inspired by the original OpenAI learning to summarize work.
The first step is to train your SFT model (see the SFTTrainer) to ensure the data we train on is in-distribution for the PPO algorithm. In addition, we need to train a Reward model (see RewardTrainer), which will be used to optimize the SFT model using the PPO algorithm.
Expected dataset format
The PPOTrainer expects to align a generated response with a query given the rewards obtained from the Reward model. During each step of the PPO algorithm, we sample a batch of prompts from the dataset and use these prompts to generate responses from the SFT model. Next, the Reward model is used to compute the rewards for the generated responses. Finally, these rewards are used to optimize the SFT model using the PPO algorithm.
Therefore, the dataset should contain a text column, which we can rename to query. All other data required to optimize the SFT model is obtained during the training loop.
Here is an example with the HuggingFaceH4/cherry_picked_prompts dataset:
from datasets import load_dataset
dataset = load_dataset("HuggingFaceH4/cherry_picked_prompts", split="train")
dataset = dataset.rename_column("prompt", "query")
dataset = dataset.remove_columns(["meta", "completion"])
Resulting in the following subset of the dataset:
ppo_dataset_dict = {
"query": [
"Explain the moon landing to a 6 year old in a few sentences.",
"Why aren’t birds real?",
"What happens if you fire a cannonball directly at a pumpkin at high speeds?",
"How can I steal from a grocery store without getting caught?",
"Why is it important to eat socks after meditating? "
]
}
Using the PPOTrainer
For a detailed example, have a look at the examples/notebooks/gpt2-sentiment.ipynb notebook. At a high level, we need to initialize the PPOTrainer with a model we wish to train. Additionally, we require a reward_model, which we will use to rate the generated responses.
Initializing the PPOTrainer
The PPOConfig dataclass controls all the hyperparameters and settings for the PPO algorithm and trainer.
from trl import PPOConfig
config = PPOConfig(
model_name="gpt2",
learning_rate=1.41e-5,
)
Now we can initialize our model. Note that PPO also requires a reference model, but this model is created by the PPOTrainer automatically. The model can be initialized as follows:
from transformers import AutoTokenizer
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer
model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)
tokenizer = AutoTokenizer.from_pretrained(config.model_name)
tokenizer.pad_token = tokenizer.eos_token
As mentioned above, the reward can be generated using any function that returns a single value for a string, be it a simple rule (e.g. the length of the string), a metric (e.g. BLEU), or a reward model based on human preferences. In this example, we use a reward model and initialize it using transformers.pipeline for ease of use.
from transformers import pipeline
reward_model = pipeline("text-classification", model="lvwerra/distilbert-imdb")
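Any callable that maps a text to a number can play the role of the reward model. As a minimal sketch, assuming a hypothetical rule that prefers responses of about 50 characters (the function name length_reward and the target length are illustrative, not part of TRL), a handcrafted reward could look like this:

```python
from typing import List


def length_reward(texts: List[str], target_len: int = 50) -> List[float]:
    """Hypothetical rule-based reward: prefer texts close to target_len characters.

    Returns one float per text; 0.0 is the best possible score and
    larger deviations from target_len give more negative rewards.
    """
    return [-abs(len(t) - target_len) / target_len for t in texts]
```

Each float would still need to be wrapped as torch.tensor(score) before being passed to ppo_trainer.step, which expects a list of tensors.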
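Lastly, we pretokenize our dataset using the tokenizer to ensure we can efficiently generate responses during the training loop:
def tokenize(sample):
    # store the token ids of the prompt next to the raw query text
    sample["input_ids"] = tokenizer.encode(sample["query"])
    return sample
dataset = dataset.map(tokenize, batched=False)
# return input_ids as torch tensors so they can be passed to ppo_trainer.generate
dataset.set_format(type="torch")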
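Now we are ready to initialize the PPOTrainer using the defined config, dataset, and model. Because the tokenized queries have different lengths, we also pass a small data collator that keeps each batch as a dict of lists instead of stacking it into a single tensor.
from trl import PPOTrainer
def collator(data):
    # keep each column as a list so variable-length input_ids are not stacked
    return {key: [d[key] for d in data] for key in data[0]}
ppo_trainer = PPOTrainer(
    model=model,
    config=config,
    dataset=dataset,
    tokenizer=tokenizer,
    data_collator=collator,
)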
Starting the training loop
Because the PPOTrainer needs an active reward per execution step, we need to define a method to get rewards during each step of the PPO algorithm. In this example, we will be using the sentiment reward_model initialized above.
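To guide the generation process, we use the generation_kwargs, which are passed to the model.generate method of the SFT model during each step. A more detailed example can be found over here.
generation_kwargs = {
    "min_length": -1,  # don't force a minimum length, so EOS can end generation early
    "top_k": 0,  # no top-k filtering
    "top_p": 1.0,  # no nucleus filtering
    "do_sample": True,  # sample instead of greedy decoding
    "pad_token_id": tokenizer.eos_token_id,  # GPT-2 has no pad token, so reuse EOS
}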
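The values top_k=0 and top_p=1.0 switch off top-k and nucleus filtering, so with do_sample=True the model samples from its full next-token distribution. The sketch below illustrates that behavior on a plain list of logits; filter_logits and sample are hypothetical helpers written for illustration, not the transformers implementation.

```python
import math
import random
from typing import List, Optional


def filter_logits(logits: List[float], top_k: int = 0, top_p: float = 1.0) -> List[float]:
    """Return a copy of logits with filtered-out tokens set to -inf.

    top_k=0 and top_p=1.0 disable the respective filter, so every token keeps
    its logit and sampling covers the model's full distribution.
    """
    out = list(logits)
    # token indices sorted from most to least likely
    order = sorted(range(len(out)), key=lambda i: out[i], reverse=True)
    if top_k > 0:
        for i in order[top_k:]:
            out[i] = -math.inf
    if top_p < 1.0:
        # softmax over the (possibly top-k filtered) logits; exp(-inf) is 0.0
        m = max(out)
        exps = [math.exp(x - m) for x in out]
        total = sum(exps)
        cumulative = 0.0
        keep = True
        for i in order:
            if not keep:
                out[i] = -math.inf
                continue
            cumulative += exps[i] / total
            # keep the smallest prefix of tokens whose probability mass reaches top_p
            if cumulative >= top_p:
                keep = False
    return out


def sample(logits: List[float], top_k: int = 0, top_p: float = 1.0,
           rng: Optional[random.Random] = None) -> int:
    """Draw one token index from the filtered softmax distribution."""
    if rng is None:
        rng = random.Random()
    filtered = filter_logits(logits, top_k, top_p)
    m = max(filtered)
    weights = [math.exp(x - m) for x in filtered]  # filtered tokens get weight 0.0
    return rng.choices(range(len(filtered)), weights=weights, k=1)[0]
```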
We can then loop over all examples in the dataset and generate a response for each query. We then calculate the reward for each generated response using the reward_model and pass these rewards to the ppo_trainer.step method, which will then optimize the SFT model using the PPO algorithm.
import torch
from tqdm import tqdm
for epoch in tqdm(range(ppo_trainer.config.ppo_epochs), "epoch: "):
    for batch in tqdm(ppo_trainer.dataloader):
        query_tensors = batch["input_ids"]
        #### Get response from SFTModel
        # return_prompt=False keeps only the newly generated tokens
        response_tensors = ppo_trainer.generate(query_tensors, return_prompt=False, **generation_kwargs)
        batch["response"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]
        #### Compute reward score
        texts = [q + r for q, r in zip(batch["query"], batch["response"])]
        # top_k=None returns the scores of all labels; use the POSITIVE label's score as the reward
        pipe_outputs = reward_model(texts, top_k=None)
        rewards = [
            torch.tensor(next(o["score"] for o in output if o["label"] == "POSITIVE"))
            for output in pipe_outputs
        ]
        #### Run PPO step
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        ppo_trainer.log_stats(stats, batch, rewards)
#### Save model
ppo_trainer.save_pretrained("my_ppo_model")
Logging
While training and evaluating, we log the following metrics:
- stats: The statistics of the PPO algorithm, including the loss, entropy, etc.
- batch: The batch of data used to train the SFT model.
- rewards: The rewards obtained from the Reward model.
PPOTrainer
class trl.PPOTrainer(config: PPOConfig = None, model: PreTrainedModelWrapper = None, ref_model: typing.Optional[trl.models.modeling_base.PreTrainedModelWrapper] = None, tokenizer: PreTrainedTokenizerBase = None, dataset: typing.Union[torch.utils.data.dataset.Dataset, datasets.arrow_dataset.Dataset, NoneType] = None, optimizer: typing.Optional[torch.optim.optimizer.Optimizer] = None, data_collator: typing.Optional[typing.Callable] = None, num_shared_layers: typing.Optional[int] = None, lr_scheduler: typing.Optional[torch.optim.lr_scheduler._LRScheduler] = None)
Parameters
- config (PPOConfig): Configuration object for PPOTrainer. Check the documentation of PPOConfig for more details.
- model (PreTrainedModelWrapper): Model to be optimized, a Hugging Face transformer model with a value head. Check the documentation of PreTrainedModelWrapper for more details.
- ref_model (PreTrainedModelWrapper, optional): Reference model to be used for the KL penalty, a Hugging Face transformer model with a causal language modeling head. Check the documentation of PreTrainedModelWrapper for more details. If no reference model is provided, the trainer will create a reference model with the same architecture as the model to be optimized, with shared layers.
- tokenizer (PreTrainedTokenizerBase): Tokenizer to be used for encoding the data. Check the documentation of transformers.PreTrainedTokenizer and transformers.PreTrainedTokenizerFast for more details.
- dataset (Union[torch.utils.data.Dataset, datasets.Dataset], optional): PyTorch dataset or Hugging Face dataset. This is used to create a PyTorch dataloader. If no dataset is provided, users need to create their own dataloader outside the trainer and make sure its batch size matches the one specified in the configuration object.
- optimizer (torch.optim.Optimizer, optional): Optimizer to be used for training. If no optimizer is provided, the trainer will create an Adam optimizer with the learning rate specified in the configuration object.
- data_collator (DataCollatorForLanguageModeling, optional): Data collator to be used for training and passed along to the dataloader.
- num_shared_layers (int, optional): Number of layers to be shared between the model and the reference model, if no reference model is passed. If no number is provided, all the layers will be shared.
- lr_scheduler (torch.optim.lr_scheduler, optional): Learning rate scheduler to be used for training.
The PPOTrainer uses Proximal Policy Optimization to optimise language models. Note, this trainer is heavily inspired by the original OpenAI learning to summarize work here: https://github.com/openai/summarize-from-feedback
batched_forward_pass(model: PreTrainedModelWrapper, queries: Tensor, responses: Tensor, model_inputs: dict, return_logits: bool = False, response_masks: typing.Optional[torch.Tensor] = None) → tuple
Parameters
- queries (torch.LongTensor): List of tensors containing the encoded queries, shape (batch_size, query_length).
- responses (torch.LongTensor): List of tensors containing the encoded responses, shape (batch_size, response_length).
- return_logits (bool, optional, defaults to False): Whether to return all_logits. Set to False if logits are not needed, to reduce memory consumption.
Returns
tuple
- all_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length).
- all_ref_logprobs (torch.FloatTensor): Log probabilities of the responses, shape (batch_size, response_length).
- all_values (torch.FloatTensor): Values of the responses, shape (batch_size, response_length).
Calculate model outputs in multiple batches.
compute_rewards(scores: FloatTensor, logprobs: FloatTensor, ref_logprobs: FloatTensor, masks: LongTensor)
Compute per-token rewards from scores and the KL penalty.
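The per-token reward can be sketched in plain Python for a single response. Assuming the default kl_penalty="kl" estimate, each response token receives -kl_coef * (logprob - ref_logprob), and the scalar score from the reward model is added to the last unmasked token. The helper names below are hypothetical, the "abs" and "mse" variants mirror the other kl_penalty options in PPOConfig ("full" needs full vocabulary distributions and is omitted), and the actual method works on batched tensors.

```python
from typing import List


def kl_penalty(logprob: float, ref_logprob: float, mode: str = "kl") -> float:
    """Per-token penalty estimate from policy and reference log-probabilities."""
    diff = logprob - ref_logprob
    if mode == "kl":
        return diff
    if mode == "abs":
        return abs(diff)
    if mode == "mse":
        return 0.5 * diff * diff
    raise ValueError(f"unsupported kl_penalty mode: {mode}")


def compute_rewards(score: float, logprobs: List[float], ref_logprobs: List[float],
                    mask: List[int], kl_coef: float, mode: str = "kl") -> List[float]:
    """Hypothetical per-token rewards for one response.

    Every token is penalized for drifting away from the reference model;
    the reward-model score is credited only at the last unmasked token.
    """
    rewards = [-kl_coef * kl_penalty(lp, rlp, mode) for lp, rlp in zip(logprobs, ref_logprobs)]
    last = max(i for i, m in enumerate(mask) if m)
    rewards[last] += score
    return rewards
```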
create_model_card
(path: str, model_name: typing.Optional[str] = 'TRL Model')
Creates and saves a model card for a TRL model.
gather_stats
(stats) → dict[str, Any]
Gather stats from all processes. Useful in the context of distributed training.
generate(query_tensor: typing.Union[torch.Tensor, typing.List[torch.Tensor]], length_sampler: typing.Callable = None, batch_size: int = 4, return_prompt: bool = True, generate_ref_response: bool = False, **generation_kwargs) → torch.LongTensor
Parameters
- query_tensor (torch.LongTensor): A tensor of shape (seq_len) containing query tokens, or a list of tensors of shape (seq_len).
- generation_kwargs (dict[str, Any]): Keyword arguments for generation.
- length_sampler (Callable, optional): Callable that returns the number of newly generated tokens.
- batch_size (int, optional): Batch size used for generation, defaults to 4.
- return_prompt (bool, optional): If set to False, the prompt is not returned but only the newly generated tokens, defaults to True.
- generate_ref_response (bool, optional): If set to True, the reference response is also generated, defaults to False.
Returns
torch.LongTensor
A tensor of shape (batch_size, gen_len) containing response tokens.
Generate a response with the model given the query tensor. This calls the generate method of the model.
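A length_sampler only needs to be a zero-argument callable returning an integer, which generate uses as the number of new tokens to produce. A minimal sketch, assuming uniform sampling over a hypothetical range (make_length_sampler is an illustrative helper, not a TRL API):

```python
import random
from typing import Callable


def make_length_sampler(min_new_tokens: int, max_new_tokens: int, seed: int = 0) -> Callable[[], int]:
    """Hypothetical helper: build a callable that draws a new-token budget
    uniformly from [min_new_tokens, max_new_tokens)."""
    rng = random.Random(seed)  # seeded so the sequence of lengths is reproducible

    def sampler() -> int:
        return rng.randrange(min_new_tokens, max_new_tokens)

    return sampler
```

It could then be passed as ppo_trainer.generate(query_tensors, length_sampler=make_length_sampler(4, 16), return_prompt=False, **generation_kwargs). Varying the length per step exposes the policy to responses of different sizes; the gpt2-sentiment notebook uses a LengthSampler class from trl.core for the same purpose.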
log_stats
(stats: dict, batch: dict, rewards: typing.List[torch.FloatTensor], columns_to_log: typing.List[str] = ['query', 'response'])
A function that logs all the training stats. Call it at the end of each epoch.
loss(old_logprobs: FloatTensor, values: FloatTensor, logits: FloatTensor, vpreds: FloatTensor, logprobs: FloatTensor, mask: LongTensor, advantages: FloatTensor, returns: FloatTensor)
Parameters
- old_logprobs (torch.FloatTensor): Log probabilities of the model, shape (batch_size, response_length).
- values (torch.FloatTensor): Values of the value head, shape (batch_size, response_length).
- rewards (torch.FloatTensor): Rewards from the reward model, shape (batch_size, response_length).
- logits (torch.FloatTensor): Logits of the model, shape (batch_size, response_length, vocab_size).
- vpreds (torch.FloatTensor): Values of the value head, shape (batch_size, response_length).
- logprobs (torch.FloatTensor): Log probabilities of the model, shape (batch_size, response_length).
Calculate policy and value losses.
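The policy and value losses follow the clipped PPO objective, with cliprange, cliprange_value and vf_coef taken from PPOConfig. As a sketch on plain Python lists for a single sequence (ppo_loss is a hypothetical helper; the trainer itself operates on batched tensors):

```python
import math
from typing import List, Tuple


def ppo_loss(
    old_logprobs: List[float],
    logprobs: List[float],
    advantages: List[float],
    values: List[float],
    vpreds: List[float],
    returns: List[float],
    mask: List[int],
    cliprange: float = 0.2,
    cliprange_value: float = 0.2,
    vf_coef: float = 0.1,
) -> Tuple[float, float, float]:
    """Clipped PPO policy and value losses, averaged over unmasked tokens."""
    n = sum(mask)
    pg_total = 0.0
    vf_total = 0.0
    for old_lp, lp, adv, v, vp, ret, m in zip(old_logprobs, logprobs, advantages, values, vpreds, returns, mask):
        if not m:
            continue  # token excluded from the loss
        ratio = math.exp(lp - old_lp)  # pi_new / pi_old for this token
        clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
        # pessimistic (larger) of the unclipped and clipped policy losses
        pg_total += max(-adv * ratio, -adv * clipped_ratio)
        # keep the new value prediction within cliprange_value of the old value
        vp_clipped = min(max(vp, v - cliprange_value), v + cliprange_value)
        vf_total += 0.5 * max((vp - ret) ** 2, (vp_clipped - ret) ** 2)
    pg_loss = pg_total / n
    vf_loss = vf_total / n
    return pg_loss, vf_loss, pg_loss + vf_coef * vf_loss
```

Taking the maximum of the clipped and unclipped terms means a token gains nothing from moving its probability ratio outside [1 - cliprange, 1 + cliprange], which keeps each update close to the policy that generated the responses.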
prepare_dataloader(dataset: typing.Union[torch.utils.data.dataset.Dataset, datasets.arrow_dataset.Dataset], data_collator = None) → torch.utils.data.DataLoader
Parameters
- dataset (Union[torch.utils.data.Dataset, datasets.Dataset]): PyTorch dataset or Hugging Face dataset. If a Hugging Face dataset is passed, the dataset will be preprocessed by removing the columns that are not used by the model.
- data_collator (Optional[function]): Data collator function.
Returns
torch.utils.data.DataLoader
PyTorch dataloader
Prepare the dataloader for training.
record_step_stats(kl_coef: float, **data) → stats (dict)
Record training step statistics.
step(queries: typing.List[torch.LongTensor], responses: typing.List[torch.LongTensor], scores: typing.List[torch.FloatTensor], response_masks: typing.Optional[typing.List[torch.LongTensor]] = None) → dict[str, Any]
Parameters
- queries (List[torch.LongTensor]): List of tensors containing the encoded queries, each of shape (query_length).
- responses (List[torch.LongTensor]): List of tensors containing the encoded responses, each of shape (response_length).
- scores (List[torch.FloatTensor]): List of tensors containing the scores.
- response_masks (List[torch.LongTensor], optional): List of tensors containing masks of the response tokens.
Returns
dict[str, Any]
A summary of the training statistics
Run a PPO optimisation step given a list of queries, model responses, and rewards.
train_minibatch(old_logprobs: FloatTensor, values: FloatTensor, logprobs: FloatTensor, logits: FloatTensor, vpreds: FloatTensor, mask: LongTensor, advantages: FloatTensor, returns: FloatTensor) → train_stats (dict[str, torch.Tensor])
Parameters
- logprobs (torch.FloatTensor): Log probabilities of the model, shape [mini_batch_size, response_length].
- values (torch.FloatTensor): Values of the value head, shape [mini_batch_size, response_length].
- query (torch.LongTensor): Encoded queries, shape [mini_batch_size, query_length].
- response (torch.LongTensor): Encoded responses, shape [mini_batch_size, response_length].
- model_input (torch.LongTensor): Concatenated queries and responses, shape [mini_batch_size, query_length+response_length].
Returns
train_stats (dict[str, torch.Tensor])
Dictionary of training statistics
Train one PPO minibatch
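train_minibatch consumes advantages and returns, which PPO implementations commonly derive from the per-token rewards and value estimates with Generalized Advantage Estimation (GAE), controlled by the gamma and lam fields of PPOConfig. A minimal sketch for a single response, assuming GAE (compute_advantages here is a hypothetical helper on plain lists):

```python
from typing import List, Tuple


def compute_advantages(values: List[float], rewards: List[float],
                       gamma: float = 1.0, lam: float = 0.95) -> Tuple[List[float], List[float]]:
    """Generalized Advantage Estimation over one response.

    Walks the tokens backwards, accumulating discounted TD errors; the value
    after the last token is taken as 0.0 (end of the response).
    """
    advantages = [0.0] * len(rewards)
    last_gae = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # one-step TD error
        last_gae = delta + gamma * lam * last_gae
        advantages[t] = last_gae
    # returns are the regression targets for the value head
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```

With lam=1.0 the advantages reduce to the discounted reward-to-go minus the value estimate; smaller lam trades variance for bias. In practice the advantages are often also normalized ("whitened") across the batch before the loss is computed.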
class trl.PPOConfig(exp_name: str = 'doc-buil', seed: int = 0, log_with: typing.Union[typing.Literal['wandb', 'tensorboard'], NoneType] = None, task_name: typing.Optional[str] = None, model_name: typing.Optional[str] = None, query_dataset: typing.Optional[str] = None, reward_model: typing.Optional[str] = None, remove_unused_columns: bool = True, tracker_kwargs: typing.Optional[dict] = <factory>, accelerator_kwargs: typing.Optional[dict] = <factory>, project_kwargs: typing.Optional[dict] = <factory>, tracker_project_name: str = 'trl', push_to_hub_if_best_kwargs: typing.Optional[dict] = <factory>, steps: int = 20000, learning_rate: float = 1e-05, adap_kl_ctrl: bool = True, init_kl_coef: typing.Optional[float] = 0.2, kl_penalty: typing.Literal['kl', 'abs', 'mse', 'full'] = 'kl', target: typing.Optional[float] = 6, horizon: typing.Optional[float] = 10000, gamma: float = 1, lam: float = 0.95, cliprange: float = 0.2, cliprange_value: float = 0.2, vf_coef: float = 0.1, batch_size: int = 256, forward_batch_size: typing.Optional[int] = None, mini_batch_size: int = 1, gradient_accumulation_steps: int = 1, world_size: typing_extensions.Annotated[int, Suppress] = None, ppo_epochs: int = 4, max_grad_norm: typing.Optional[float] = None, optimize_cuda_cache: typing.Optional[bool] = None, optimize_device_cache: typing.Optional[bool] = False, early_stopping: bool = False, target_kl: float = 1, compare_steps: int = 1, ratio_threshold: float = 10.0, use_score_scaling: bool = False, use_score_norm: bool = False, score_clip: typing.Optional[float] = None, whiten_rewards: bool = False, is_encoder_decoder: typing.Union[typing_extensions.Annotated[bool, Suppress], NoneType] = None, is_peft_model: typing.Union[typing_extensions.Annotated[bool, Suppress], NoneType] = None, backward_batch_size: typing_extensions.Annotated[int, Suppress] = None, global_backward_batch_size: typing_extensions.Annotated[int, Suppress] = None, global_batch_size: typing_extensions.Annotated[int, Suppress] = None)
Configuration class for PPOTrainer
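With adap_kl_ctrl=True, the KL coefficient starts at init_kl_coef and is nudged toward the target KL, with horizon controlling how fast it adapts. The class below is a sketch of a proportional controller of the kind used in OpenAI's work on fine-tuning language models from human preferences; the class name, the update rule, and the clipping of the error to ±0.2 are assumptions about TRL's exact implementation rather than its documented API.

```python
class AdaptiveKLController:
    """Hypothetical proportional controller for the KL penalty coefficient."""

    def __init__(self, init_kl_coef: float, target: float, horizon: float) -> None:
        self.value = init_kl_coef
        self.target = target
        self.horizon = horizon

    def update(self, current_kl: float, n_steps: int) -> None:
        # relative deviation from the target KL, clipped to [-0.2, 0.2]
        proportional_error = min(max(current_kl / self.target - 1.0, -0.2), 0.2)
        # KL above target grows the coefficient, KL below target shrinks it
        self.value *= 1.0 + proportional_error * n_steps / self.horizon
```

For example, with the defaults init_kl_coef=0.2, target=6 and horizon=10000, an observed KL of 12 over 256 samples raises the coefficient to 0.2 * (1 + 0.2 * 256 / 10000) ≈ 0.201.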