TRL documentation

RLOO Trainer

You are viewing v0.9.6 version. A newer version v0.13.0 is available.

TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). Instead of using a value function, RLOO generates K completions for each prompt. For each completion, RLOO uses the mean score of the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
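
As a concrete illustration of the leave-one-out baseline, here is a minimal pure-Python sketch for a single prompt. rloo_advantages is a hypothetical helper, not part of the TRL API, and it works on a plain list of scores for one prompt rather than on the batched tensors the trainer uses.

```python
def rloo_advantages(scores):
    """Hypothetical helper: leave-one-out advantages for the K completions of one prompt."""
    k = len(scores)
    if k < 2:
        raise ValueError("RLOO needs at least two completions per prompt")
    total = sum(scores)
    # Baseline for completion i: mean score of the other K - 1 completions.
    return [s - (total - s) / (k - 1) for s in scores]


advantages = rloo_advantages([1.0, 2.0, 5.0, 8.0])  # K = 4 completions of one prompt
print(advantages[0])  # -4.0, since the other three scores average (2 + 5 + 8) / 3 = 5
```

Because each baseline excludes the completion it is applied to, the advantages for one prompt always sum to zero: a completion is pushed up only if it scores better than its siblings on average.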

Get started

To check that the trainer runs end to end, you can train an RLOO model with a dummy reward model using the following command:

python examples/scripts/rloo/rloo.py \
    --learning_rate 3e-6 \
    --output_dir models/minimal/rloo \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-14m \
    --reward_model_path EleutherAI/pythia-14m \
    --non_eos_penalty
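
The command above passes --non_eos_penalty, which enables the EOS trick: completions that do not end with an EOS token have their score replaced by a fixed penalty. A minimal sketch of that replacement follows. apply_eos_penalty is a hypothetical helper, the token ids and penalty value are illustrative, and it assumes that with --stop_token eos everything after the first EOS has already been replaced by padding, so "contains EOS" is the same as "ends with EOS".

```python
import torch


def apply_eos_penalty(responses, scores, eos_token_id, penalty_reward_value):
    """Hypothetical helper sketching the EOS trick enabled by --non_eos_penalty.

    responses: [batch, response_len] token ids; scores: [batch] reward model scores.
    """
    # A completion counts as finished if the EOS token appears anywhere in it
    # (assumes tokens after the first EOS were already turned into padding).
    contains_eos = torch.any(responses == eos_token_id, dim=-1)
    # Keep the score of finished completions; replace the rest with the penalty.
    return torch.where(contains_eos, scores, torch.full_like(scores, penalty_reward_value))


EOS, PAD = 0, 1  # illustrative token ids
responses = torch.tensor([[5, 6, EOS, PAD], [5, 6, 7, 8]])  # second completion never stops
scores = torch.tensor([3.9, 6.8])
penalized = apply_eos_penalty(responses, scores, EOS, penalty_reward_value=-1.0)
print(penalized)
```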

Explanation of the logged metrics

The logged metrics are as follows; an example tracked run is available at Weights and Biases.

  • eps: Tracks the number of episodes per second.
  • objective/kl: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
  • objective/entropy: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
  • objective/non_score_reward: The mean reward from non-score-related sources, basically beta * kl.sum(1), where beta is the KL penalty coefficient and kl is the per-token KL divergence.
  • objective/rlhf_reward: The mean RLHF reward, which is score - non_score_reward.
  • objective/scores: The mean scores returned by the reward model / environment.
  • policy/approxkl_avg: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as objective/kl.
  • policy/clipfrac_avg: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
  • loss/policy_avg: The average policy loss, indicating how well the policy is performing.
  • val/clipfrac_avg: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
  • policy/entropy_avg: The average entropy of the policy during training, indicating how diverse the policy’s actions are.
  • val/ratio: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
  • val/ratio_var: The variance of the val/ratio, indicating the variability in policy changes.
  • val/num_eos_tokens: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
  • lr: The current learning rate used by the optimizer.
  • episode: The current global step or episode count in the training process.
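
The three reward-related objective metrics are tied together by the definitions above. The following sketch recomputes them for a toy batch, following this page's convention that non_score_reward = beta * kl.sum(1) and rlhf_reward = score - non_score_reward. rlhf_metrics is a hypothetical helper; it assumes the per-token log-probabilities contain no padding, uses log pi - log pi_ref on the sampled tokens as the per-token KL estimate, and takes objective/kl as the per-completion KL summed over tokens and averaged over the batch (an assumption about how "mean" is computed).

```python
import torch


def rlhf_metrics(logprobs, ref_logprobs, scores, beta):
    """Hypothetical helper following the metric definitions listed above.

    logprobs / ref_logprobs: [batch, response_len] log-probabilities of the sampled
    tokens under the current and reference policies (assumed to contain no padding).
    scores: [batch] reward model scores. beta: KL penalty coefficient.
    """
    kl = logprobs - ref_logprobs             # per-token KL estimate
    non_score_reward = beta * kl.sum(1)      # beta * kl.sum(1), one value per completion
    rlhf_reward = scores - non_score_reward  # score - non_score_reward
    return {
        "objective/kl": kl.sum(1).mean().item(),
        "objective/non_score_reward": non_score_reward.mean().item(),
        "objective/rlhf_reward": rlhf_reward.mean().item(),
        "objective/scores": scores.mean().item(),
    }


metrics = rlhf_metrics(
    logprobs=torch.tensor([[-1.0, -1.0], [-2.0, -0.5]]),
    ref_logprobs=torch.tensor([[-1.5, -1.0], [-2.0, -1.0]]),
    scores=torch.tensor([1.0, 3.0]),
    beta=0.1,
)
print(metrics)
```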

Cookbook

  • Debugging TIP: objective/rlhf_reward: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
  • Debugging TIP: val/ratio: this number should float around 1.0, and it gets clipped by --cliprange 0.2 with PPO’s surrogate loss. If this ratio is too high (like 2.0 or 1000.0) or too low (like 0.1), the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
  • Memory TIP: If you are running out of memory, you can reduce --per_device_train_batch_size and increase --gradient_accumulation_steps by the same factor, which lowers the memory footprint while keeping the effective batch size unchanged.
  • Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml.
  • Usage TIP: We recommend using the “EOS trick” via --non_eos_penalty --stop_token eos, which replaces the score of completions that do not end with an EOS token with a static scalar penalty, --penalty_reward_value. This can help the model learn to generate more coherent completions.
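
The val/ratio tip, and the earlier note that REINFORCE is a special case of PPO with one epoch and one mini-batch, can both be illustrated with a small sketch of the clipped surrogate loss. clipped_policy_loss is a hypothetical helper, not TRL's code; it assumes each log-probability is already summed over the tokens of a completion (the single-action view), uses 0.2 as in --cliprange 0.2, and reports clipfrac as the fraction of samples whose clipped term is active.

```python
import torch


def clipped_policy_loss(new_logprobs, old_logprobs, advantages, cliprange=0.2):
    """PPO-style clipped surrogate loss over whole-completion log-probabilities.

    Hypothetical helper: each entry of new_logprobs / old_logprobs is assumed to
    be the summed log-probability of one completion.
    """
    ratio = torch.exp(new_logprobs - old_logprobs)  # the quantity val/ratio averages
    unclipped = -advantages * ratio
    clipped = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    loss = torch.maximum(unclipped, clipped).mean()  # pessimistic of the two terms
    # Fraction of samples where the clipped term is strictly larger, i.e. clipping is active.
    clipfrac = (clipped > unclipped).float().mean()
    return loss, ratio, clipfrac


# First update after a rollout: the policy has not moved yet, so ratio == 1 and
# the clipped loss has the same gradient as the plain REINFORCE loss.
logprobs = torch.tensor([-3.0, -5.0], requires_grad=True)
advantages = torch.tensor([1.0, -2.0])
loss, ratio, clipfrac = clipped_policy_loss(logprobs, logprobs.detach(), advantages)
loss.backward()
ppo_grad = logprobs.grad.clone()

logprobs.grad = None
reinforce_loss = -(advantages * logprobs).mean()
reinforce_loss.backward()
print(torch.allclose(ppo_grad, logprobs.grad))  # True
```

On that first update the clip range has no effect; it only starts to matter once additional epochs or mini-batches move the policy away from the one that generated the samples, which is when val/ratio drifts from 1.0.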

What is my model doing exactly?

To help you understand what your model is doing, we periodically log some sample completions from the model. In an example tracked run at Weights and Biases, these samples let you see the model’s responses at different stages of training. By default, we generate samples 10 times during training (--num_sample_generations 10), but you can customize this number.

In the logs, the sampled generations look like this:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query                           ┃ model response                  ┃ score    ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│  SUBREDDIT: r/AskReddit         │  I'm in love with a friend, and │ 3.921875 │
│                                 │ I don't know how to get rid of  │          │
│ TITLE: How do you get someone   │ those feelings. I'm             │          │
│ out of your head?               │ desperate.<|endoftext|>[PAD][P… │          │
│                                 │                                 │          │
│ POST: Hi,                       │                                 │          │
│ I'm 22, and I have been with my │                                 │          │
│ girlfriend for 5 years now. We  │                                 │          │
│ recently moved together. We've  │                                 │          │
│ always loved each other         │                                 │          │
│ intensely.                      │                                 │          │
│                                 │                                 │          │
│ Problem, I recently started to  │                                 │          │
│ have feelings for an other      │                                 │          │
│ person (a friend). This person  │                                 │          │
│ has had a boyfriend for now 3   │                                 │          │
│ years, and has absolutely no    │                                 │          │
│ ideas. Those feelings were so   │                                 │          │
│ strong, it was hard to hide     │                                 │          │
│ them. After 2 months of me      │                                 │          │
│ being distant and really sad,   │                                 │          │
│ my girlfriend forced me to say  │                                 │          │
│ what was bothering me. I'm not  │                                 │          │
│ a good liar, and now she knows. │                                 │          │
│                                 │                                 │          │
│ We decided to give us a week    │                                 │          │
│ alone, I went to my parents.    │                                 │          │
│                                 │                                 │          │
│ Now, I'm completely lost. I     │                                 │          │
│ keep on thinking about this     │                                 │          │
│ person, and I hate that. I      │                                 │          │
│ would like for those feelings   │                                 │          │
│ to go away, to leave me alone.  │                                 │          │
│ But I can't.                    │                                 │          │
│                                 │                                 │          │
│ What do I do? It's been 3       │                                 │          │
│ months now, and I'm just        │                                 │          │
│ desperate.                      │                                 │          │
│                                 │                                 │          │
│ TL;DR:                          │                                 │          │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│  SUBREDDIT: r/pettyrevenge      │  My mom woke me up with a loud  │ 6.84375  │
│                                 │ TV. I blasted Gangnam Style on  │          │
│ TITLE: So, my mom woke me up    │ repeat, with the bass cranked   │          │
│ with a loud TV.                 │ up as high as it could          │          │
│                                 │ go.<|endoftext|>[PAD][PAD][PAD… │          │
│ POST: She was in her living     │                                 │          │
│ room, watching TV. This was at  │                                 │          │
│ about 8:30 in the morning, and  │                                 │          │
│ she was exercising. She turned  │                                 │          │
│ the TV up extra loud to hear it │                                 │          │
│ over her excercycle, and woke   │                                 │          │
│ me up. I went in there asking   │                                 │          │
│ for her to turn it down. She    │                                 │          │
│ said she didn't have to; I      │                                 │          │
│ explained that I always used    │                                 │          │
│ headphones so she didn't have   │                                 │          │
│ to deal with my noise and that  │                                 │          │
│ she should give me a little     │                                 │          │
│ more respect, given that I paid │                                 │          │
│ rent at the time.               │                                 │          │
│                                 │                                 │          │
│ She disagreed. I went back to   │                                 │          │
│ my room, rather pissed off at   │                                 │          │
│ the lack of equality. I had no  │                                 │          │
│ lock on my door; but I had a    │                                 │          │
│ dresser right next to it, so I  │                                 │          │
│ pulled one of the drawers out   │                                 │          │
│ enough so that it caused the    │                                 │          │
│ door to not be openable. Then,  │                                 │          │
│ I turned my speakers up really  │                                 │          │
│ loud and blasted Gangnam Style  │                                 │          │
│ on repeat, with the bass        │                                 │          │
│ cranked up as high as it could  │                                 │          │
│ go.                             │                                 │          │
│                                 │                                 │          │
│ If you hate Gangnam Style for   │                                 │          │
│ being overplayed, you will see  │                                 │          │
│ why I chose that particular     │                                 │          │
│ song. I personally don't mind   │                                 │          │
│ it. But here's the thing about  │                                 │          │
│ my bass; it vibrates the walls, │                                 │          │
│ making one hell of a lot of     │                                 │          │
│ noise. Needless to say, my mom  │                                 │          │
│ was not pleased and shut off    │                                 │          │
│ the internet. But it was oh so  │                                 │          │
│ worth it.                       │                                 │          │
│                                 │                                 │          │
│ TL;DR:                          │                                 │          │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘

Implementation details

The bulk of RLOOTrainer is based on the PPO implementation, which in turn is based on The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization.

Below is a test that checks a vectorized advantage calculation for RLOO against a straightforward loop-based version:

import torch


def test_rloo_reward():
    local_batch_size = 3
    rloo_k = 4
    # fmt: off
    rlhf_reward = torch.tensor([
        1, 2, 3, # first rlhf reward for three prompts
        2, 3, 4, # second rlhf reward for three prompts
        5, 6, 7, # third rlhf reward for three prompts
        8, 9, 10, # fourth rlhf reward for three prompts
    ]).float() # here we have 3 prompts which have 4 completions each
    # fmt: on

    # loop-based reference: the baseline for each completion is the mean
    # rlhf reward of the other rloo_k - 1 completions of the same prompt
    advantages = torch.zeros_like(rlhf_reward)
    for i in range(0, len(advantages), local_batch_size):
        other_response_rlhf_rewards = []
        for j in range(0, len(advantages), local_batch_size):
            if i != j:
                other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
        advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(
            other_response_rlhf_rewards
        ).mean(0)
    assert abs(1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6
    assert abs(6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6

    # vectorized impl: row k holds the k-th completion of every prompt,
    # so summing over dim 0 gives the total rlhf reward per prompt
    rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
    baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
    vec_advantages = rlhf_reward - baseline
    torch.testing.assert_close(vec_advantages.flatten(), advantages)


if __name__ == "__main__":
    test_rloo_reward()
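
The reshape to (rloo_k, local_batch_size) works because completions of the same prompt sit local_batch_size positions apart in the flat reward tensor. The sketch below shows one way to get that layout, assuming the prompt batch is tiled K times along the batch dimension with Tensor.repeat (whole batch repeated, not each prompt interleaved). tile_prompts and rloo_advantages_flat are hypothetical helpers, not TRL functions.

```python
import torch


def tile_prompts(prompts: torch.Tensor, rloo_k: int) -> torch.Tensor:
    # Hypothetical helper: [B, L] prompts -> [K * B, L], where row k * B + b is
    # the k-th copy of prompt b (whole batch repeated K times, not interleaved).
    return prompts.repeat(rloo_k, 1)


def rloo_advantages_flat(rlhf_reward: torch.Tensor, rloo_k: int) -> torch.Tensor:
    # rlhf_reward: [K * B] rewards in the same order as tile_prompts' output.
    grouped = rlhf_reward.reshape(rloo_k, -1)             # [K, B]; column b = prompt b
    baseline = (grouped.sum(0) - grouped) / (rloo_k - 1)  # leave-one-out mean per column
    return (grouped - baseline).flatten()                 # back to the flat [K * B] order


prompts = torch.tensor([[10, 11], [20, 21], [30, 31]])  # B = 3 toy prompts of length 2
tiled = tile_prompts(prompts, rloo_k=4)
print(tiled[:, 0].tolist())  # [10, 20, 30, 10, 20, 30, 10, 20, 30, 10, 20, 30]
```

With interleaved tiling instead (each prompt repeated K times in a row), the same rewards would need a reshape to (local_batch_size, rloo_k) and a sum over dim 1.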

Benchmark experiments

To validate that the RLOO implementation works, we ran experiments on the 1B and 6.9B models. Here are the commands we used to run the experiments. We take the SFT / RM models directly from The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization.

# 1B RLOO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/rloo/rloo_tldr.py \
    --output_dir models/minimal/rloo_tldr \
    --num_ppo_epochs 2 \
    --num_mini_batches 2 \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 16 \
    --gradient_accumulation_steps 16 \
    --total_episodes 1000000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
    --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
    --local_rollout_forward_batch_size 16 \
    --non_eos_penalty \
    --stop_token eos \
    --kl_coef 0.03

# 6.9B RLOO experiment
accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
    examples/scripts/rloo/rloo_tldr.py \
    --output_dir models/minimal/rloo_tldr_6.9b \
    --num_ppo_epochs 2 \
    --num_mini_batches 2 \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 256 \
    --total_episodes 1000000 \
    --model_name_or_path EleutherAI/pythia-6.9b-deduped \
    --sft_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__sft__tldr \
    --reward_model_path cleanrl/EleutherAI_pythia-6.9b-deduped__reward__tldr \
    --local_rollout_forward_batch_size 2 \
    --non_eos_penalty \
    --stop_token eos \
    --kl_coef 0.03
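
The two commands trade per-device batch size for gradient accumulation: 16 × 16 and 1 × 256 both give 256 samples per optimizer step on each process, so the 6.9B run fits in memory without changing the batch. The sketch below extends that arithmetic to a rollout batch, assuming PPO-style bookkeeping in which each process collects per_device_train_batch_size × gradient_accumulation_steps × num_mini_batches completions per update; that formula and the 8-process count are assumptions, and rollout_batch_sizes is a hypothetical helper.

```python
def rollout_batch_sizes(per_device_train_batch_size, gradient_accumulation_steps,
                        num_mini_batches, num_processes):
    """Hypothetical helper: per-process and global rollout batch sizes.

    Assumes every process collects per_device * grad_accum * num_mini_batches
    completions per update (PPO-style bookkeeping, not taken from TRL's source).
    """
    local_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_mini_batches
    return local_batch_size, local_batch_size * num_processes


NUM_PROCESSES = 8  # hypothetical: the GPU count is not stated in the commands above
print(rollout_batch_sizes(16, 16, 2, NUM_PROCESSES))  # 1B:   (512, 4096)
print(rollout_batch_sizes(1, 256, 2, NUM_PROCESSES))  # 6.9B: (512, 4096)
```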

The 1B experiment can be found here:

To evaluate, we use vLLM to load the checkpoints and GPT-3.5 as a judge model to compare the generated TL;DRs against the reference TL;DRs.

python -i examples/scripts/evals/generate_tldr.py \
    --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
    --output_path examples/scripts/minimal/evals/sft_tldr.csv \
    --n 1000
# preferred
# response1    656
# response0    344
# Name: count, dtype: int64
python -i examples/scripts/evals/generate_tldr.py \
    --model_name_or_path vwxyzjn/rloo_tldr \
    --output_path examples/scripts/minimal/evals/rloo_tldr.csv \
    --n 1000
# preferred
# response0    532
# response1    468
# Name: count, dtype: int64

The RLOO checkpoint gets a 53.2% preferred rate versus the 34.4% preferred rate of the SFT checkpoint. This is a good sign that the RLOO training is working as intended.
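
The percentages come straight from the judge counts printed above over 1000 samples (--n 1000). The sketch below reproduces them, assuming response0 is the evaluated checkpoint's TL;DR and response1 the reference TL;DR; that mapping is inferred from the fact that it matches the reported numbers, and preferred_rate is a hypothetical helper.

```python
def preferred_rate(counts, model_key="response0"):
    # counts: judge preference counts as printed by the eval script above.
    # Assumption: "response0" is the evaluated checkpoint's TL;DR and
    # "response1" is the reference TL;DR.
    return counts[model_key] / sum(counts.values())


sft_counts = {"response1": 656, "response0": 344}
rloo_counts = {"response0": 532, "response1": 468}
print(preferred_rate(sft_counts), preferred_rate(rloo_counts))  # 0.344 0.532
```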

Metrics:

# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
        "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
    --env-ids models/minimal/rloo_tldr \
    --pc.ncols 4 \
    --pc.ncols-legend 1 \
    --pc.xlabel "Episode" \
    --output-filename benchmark/trl/pr-1540/rloo \
    --scan-history

The 6.9B experiment is still TBD (the runs were preempted due to resource constraints).
