RLOO Trainer

You are viewing v0.11.2 version. A newer version v0.13.0 is available.

TRL supports training LLMs with REINFORCE Leave-One-Out (RLOO). Instead of using a value function, RLOO generates K completions for each prompt. For each completion, it uses the mean score of the other K-1 completions as a baseline to calculate the advantage. RLOO also models the entire completion as a single action, whereas PPO models each token as an action. Note that REINFORCE / A2C is a special case of PPO when the number of PPO epochs is 1 and the number of mini-batches is 1, which is how we implement RLOO in TRL.
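
As a sketch of the baseline described above, write R_i for the reward of the i-th of the K completions sampled for one prompt, and A_i for its advantage. These symbols are notation introduced here for illustration and do not come from the TRL source. The leave-one-out advantage can then be written as:

```latex
A_i = R_i - \frac{1}{K-1} \sum_{j \neq i} R_j, \qquad i = 1, \dots, K
```

For example, with K = 4 and rewards 1, 2, 5, 8, the first completion has baseline (2 + 5 + 8) / 3 = 5 and advantage 1 - 5 = -4.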

Get started

To check that the trainer runs, use the following command to train an RLOO model with a dummy reward model.

python examples/scripts/rloo/rloo.py \
    --learning_rate 3e-6 \
    --output_dir models/minimal/rloo \
    --per_device_train_batch_size 64 \
    --gradient_accumulation_steps 1 \
    --total_episodes 10000 \
    --model_name_or_path EleutherAI/pythia-14m \
    --reward_model_path EleutherAI/pythia-14m \
    --missing_eos_penalty 1.0

Explanation of the logged metrics

The logged metrics are described below. An example tracked run is available at Weights and Biases.

  • eps: Tracks the number of episodes per second.
  • objective/kl: The mean Kullback-Leibler (KL) divergence between the current policy and reference policy.
  • objective/entropy: The mean entropy of the policy, indicating the randomness of the actions chosen by the policy.
  • objective/non_score_reward: The mean reward from non-score-related sources, basically beta * kl.sum(1), where beta is the KL penalty coefficient and kl is the per-token KL divergence.
  • objective/rlhf_reward: The mean RLHF reward, which is score - non_score_reward.
  • objective/scores: The mean scores returned by the reward model / environment.
  • policy/approxkl_avg: The average approximate KL divergence between consecutive PPO policies. Note that this is not the same as objective/kl.
  • policy/clipfrac_avg: The average fraction of policy updates that are clipped, indicating how often the policy updates are constrained to prevent large changes.
  • loss/policy_avg: The average policy loss, indicating how well the policy is performing.
  • val/clipfrac_avg: The average fraction of value function updates that are clipped, similar to policy/clipfrac_avg but for the value function.
  • policy/entropy_avg: The average entropy of the policy during training, indicating how diverse the policy’s actions are.
  • val/ratio: The mean ratio of the current policy probability to the old policy probability, providing a measure of how much the policy has changed.
  • val/ratio_var: The variance of the val/ratio, indicating the variability in policy changes.
  • val/num_eos_tokens: The number of end-of-sequence (EOS) tokens generated, which can indicate the number of complete responses.
  • lr: The current learning rate used by the optimizer.
  • episode: The current global step or episode count in the training process.
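
The relationship between objective/scores, objective/non_score_reward, and objective/rlhf_reward can be sketched in plain Python. This is a minimal illustration that follows the definitions in the list above, not the trainer's own code; the function names and the sample numbers are hypothetical.

```python
def rlhf_reward(score, per_token_kl, beta):
    """Return (non_score_reward, rlhf_reward) for a single completion."""
    # KL penalty summed over the tokens of the completion, scaled by beta.
    non_score_reward = beta * sum(per_token_kl)
    return non_score_reward, score - non_score_reward


def mean(values):
    return sum(values) / len(values)


# Illustrative numbers only: two completions with three tokens each.
scores = [1.0, 2.0]
per_token_kls = [[0.1, 0.2, 0.3], [0.0, 0.5, 0.5]]
beta = 0.05  # same value as the documented kl_coef default

pairs = [rlhf_reward(s, kl, beta) for s, kl in zip(scores, per_token_kls)]
logged = {
    "objective/scores": mean(scores),
    "objective/non_score_reward": mean([p[0] for p in pairs]),
    "objective/rlhf_reward": mean([p[1] for p in pairs]),
}
print(logged)
```

A growing KL divergence raises the penalty term, so objective/rlhf_reward can fall even while objective/scores rises.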

Cookbook

  • Debugging TIP: objective/rlhf_reward: this is the ultimate objective of the RLHF training. If training works as intended, this metric should keep going up.
  • Debugging TIP: val/ratio: this number should float around 1.0, and it gets clipped by --cliprange 0.2 with PPO’s surrogate loss. If this ratio is too high (like 2.0 or 1000.0) or too small (like 0.1), the updates between consecutive policies are too drastic. You should try to understand why this is happening and fix it.
  • Memory TIP: If you are running out of memory, you can try to reduce the --per_device_train_batch_size or increase the --gradient_accumulation_steps to reduce the memory footprint.
  • Memory TIP: If you have multiple GPUs, you can also run training with DeepSpeed stage 3 to reduce the memory footprint: accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml.
  • Usage TIP: We recommend using the “EOS trick” via --missing_eos_penalty, which subtracts a static scalar penalty from the score of completions that do not end with an EOS token. This can help the model learn to generate more coherent completions.
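
Two of the tips above can be illustrated with a small stdlib-only sketch: checking how far the probability ratio drifts from 1.0, and applying the EOS penalty. The helper names are hypothetical, and the clip fraction here simply counts ratios outside [1 - cliprange, 1 + cliprange], which is a simplification and may differ from how the trainer computes policy/clipfrac_avg.

```python
import math


def ratio_stats(new_logprobs, old_logprobs, cliprange=0.2):
    """Return (mean ratio, fraction of ratios outside the clip range)."""
    ratios = [math.exp(n - o) for n, o in zip(new_logprobs, old_logprobs)]
    mean_ratio = sum(ratios) / len(ratios)
    outside = sum(1 for r in ratios if abs(r - 1.0) > cliprange)
    return mean_ratio, outside / len(ratios)


def apply_missing_eos_penalty(scores, responses, eos_token_id, penalty=1.0):
    """Subtract a fixed penalty from scores of responses not ending in EOS."""
    return [
        s if resp and resp[-1] == eos_token_id else s - penalty
        for s, resp in zip(scores, responses)
    ]


# Illustrative values: token id 0 stands in for the EOS token.
mean_ratio, frac_outside = ratio_stats([math.log(2.0), 0.0], [0.0, 0.0])
penalized = apply_missing_eos_penalty(
    scores=[3.0, 4.0],
    responses=[[5, 6, 0], [5, 6, 7]],  # only the first ends with EOS
    eos_token_id=0,
)
print(mean_ratio, frac_outside, penalized)
```

Here the second completion is penalized because it was cut off before producing an EOS token, which is the case the `--missing_eos_penalty` flag targets.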

What is my model doing exactly?

To help you understand what your model is doing, we periodically log sample completions from the model. In an example tracked run at Weights and Biases, these samples let you see the model’s responses at different stages of training. By default, sample generations are logged 10 times during training (--num_sample_generations 10), but you can customize this number.

In the logs, the sampled generations look like this:

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┓
┃ query                           ┃ model response                  ┃ score    ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━┩
│  SUBREDDIT: r/AskReddit         │  I'm in love with a friend, and │ 3.921875 │
│                                 │ I don't know how to get rid of  │          │
│ TITLE: How do you get someone   │ those feelings. I'm             │          │
│ out of your head?               │ desperate.<|endoftext|>[PAD][P… │          │
│                                 │                                 │          │
│ POST: Hi,                       │                                 │          │
│ I'm 22, and I have been with my │                                 │          │
│ girlfriend for 5 years now. We  │                                 │          │
│ recently moved together. We've  │                                 │          │
│ always loved each other         │                                 │          │
│ intensely.                      │                                 │          │
│                                 │                                 │          │
│ Problem, I recently started to  │                                 │          │
│ have feelings for an other      │                                 │          │
│ person (a friend). This person  │                                 │          │
│ has had a boyfriend for now 3   │                                 │          │
│ years, and has absolutely no    │                                 │          │
│ ideas. Those feelings were so   │                                 │          │
│ strong, it was hard to hide     │                                 │          │
│ them. After 2 months of me      │                                 │          │
│ being distant and really sad,   │                                 │          │
│ my girlfriend forced me to say  │                                 │          │
│ what was bothering me. I'm not  │                                 │          │
│ a good liar, and now she knows. │                                 │          │
│                                 │                                 │          │
│ We decided to give us a week    │                                 │          │
│ alone, I went to my parents.    │                                 │          │
│                                 │                                 │          │
│ Now, I'm completely lost. I     │                                 │          │
│ keep on thinking about this     │                                 │          │
│ person, and I hate that. I      │                                 │          │
│ would like for those feelings   │                                 │          │
│ to go away, to leave me alone.  │                                 │          │
│ But I can't.                    │                                 │          │
│                                 │                                 │          │
│ What do I do? It's been 3       │                                 │          │
│ months now, and I'm just        │                                 │          │
│ desperate.                      │                                 │          │
│                                 │                                 │          │
│ TL;DR:                          │                                 │          │
├─────────────────────────────────┼─────────────────────────────────┼──────────┤
│  SUBREDDIT: r/pettyrevenge      │  My mom woke me up with a loud  │ 6.84375  │
│                                 │ TV. I blasted Gangnam Style on  │          │
│ TITLE: So, my mom woke me up    │ repeat, with the bass cranked   │          │
│ with a loud TV.                 │ up as high as it could          │          │
│                                 │ go.<|endoftext|>[PAD][PAD][PAD… │          │
│ POST: She was in her living     │                                 │          │
│ room, watching TV. This was at  │                                 │          │
│ about 8:30 in the morning, and  │                                 │          │
│ she was exercising. She turned  │                                 │          │
│ the TV up extra loud to hear it │                                 │          │
│ over her excercycle, and woke   │                                 │          │
│ me up. I went in there asking   │                                 │          │
│ for her to turn it down. She    │                                 │          │
│ said she didn't have to; I      │                                 │          │
│ explained that I always used    │                                 │          │
│ headphones so she didn't have   │                                 │          │
│ to deal with my noise and that  │                                 │          │
│ she should give me a little     │                                 │          │
│ more respect, given that I paid │                                 │          │
│ rent at the time.               │                                 │          │
│                                 │                                 │          │
│ She disagreed. I went back to   │                                 │          │
│ my room, rather pissed off at   │                                 │          │
│ the lack of equality. I had no  │                                 │          │
│ lock on my door; but I had a    │                                 │          │
│ dresser right next to it, so I  │                                 │          │
│ pulled one of the drawers out   │                                 │          │
│ enough so that it caused the    │                                 │          │
│ door to not be openable. Then,  │                                 │          │
│ I turned my speakers up really  │                                 │          │
│ loud and blasted Gangnam Style  │                                 │          │
│ on repeat, with the bass        │                                 │          │
│ cranked up as high as it could  │                                 │          │
│ go.                             │                                 │          │
│                                 │                                 │          │
│ If you hate Gangnam Style for   │                                 │          │
│ being overplayed, you will see  │                                 │          │
│ why I chose that particular     │                                 │          │
│ song. I personally don't mind   │                                 │          │
│ it. But here's the thing about  │                                 │          │
│ my bass; it vibrates the walls, │                                 │          │
│ making one hell of a lot of     │                                 │          │
│ noise. Needless to say, my mom  │                                 │          │
│ was not pleased and shut off    │                                 │          │
│ the internet. But it was oh so  │                                 │          │
│ worth it.                       │                                 │          │
│                                 │                                 │          │
│ TL;DR:                          │                                 │          │
└─────────────────────────────────┴─────────────────────────────────┴──────────┘

Implementation details

The bulk of RLOOTrainer is based on the PPO implementation, which in turn is based on The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization.

Below is a vectorized advantage calculation for RLOO:

import torch


def test_rloo_reward():
    local_batch_size = 3
    rloo_k = 4
    rlhf_reward = torch.tensor([
        1, 2, 3, # first rlhf reward for three prompts
        2, 3, 4, # second rlhf reward for three prompts
        5, 6, 7, # third rlhf reward for three prompts
        8, 9, 10, # fourth rlhf reward for three prompts
    ]).float() # here we have 3 prompts which have 4 completions each

    # Reference implementation: for each group of completions, subtract the
    # mean reward of the other rloo_k - 1 groups for the same prompts
    advantages = torch.zeros_like(rlhf_reward)
    for i in range(0, len(advantages), local_batch_size):
        other_response_rlhf_rewards = []
        for j in range(0, len(advantages), local_batch_size):
            if i != j:
                other_response_rlhf_rewards.append(rlhf_reward[j : j + local_batch_size])
        advantages[i : i + local_batch_size] = rlhf_reward[i : i + local_batch_size] - torch.stack(other_response_rlhf_rewards).mean(0)

    # abs() makes these checks fail on errors in either direction
    assert abs(1 - (2 + 5 + 8) / 3 - advantages[0].item()) < 1e-6  # First rlhf reward for the first prompt
    assert abs(6 - (3 + 2 + 9) / 3 - advantages[7].item()) < 1e-6  # Third rlhf reward for the second prompt

    # Vectorized implementation
    rlhf_reward = rlhf_reward.reshape(rloo_k, local_batch_size)
    baseline = (rlhf_reward.sum(0) - rlhf_reward) / (rloo_k - 1)
    vec_advantages = rlhf_reward - baseline
    torch.testing.assert_close(vec_advantages.flatten(), advantages)

Benchmark experiments

To validate that the RLOO implementation works, we ran an experiment on a 1B model. Here is the command we used to run the experiment. We take the SFT / RM models directly from The N+ Implementation Details of RLHF with PPO: A Case Study on TL;DR Summarization.

accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml \
    examples/scripts/rloo/rloo_tldr.py \
    --output_dir models/minimal/rloo_tldr \
    --num_ppo_epochs 2 \
    --num_mini_batches 2 \
    --learning_rate 3e-6 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 8 \
    --total_episodes 1000000 \
    --model_name_or_path EleutherAI/pythia-1b-deduped \
    --sft_model_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr \
    --reward_model_path cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr \
    --local_rollout_forward_batch_size 16 \
    --missing_eos_penalty 1.0 \
    --stop_token eos \
    --kl_coef 0.03

Checkpoints and experiment tracking are available at:

To evaluate, we use vLLM to load the checkpoints and GPT-4o mini as a judge model to evaluate the generated TL;DR against the reference TL;DR. For more information on how to use judges, see Judges.

$ python examples/scripts/evals/judge_tldr.py --model_name_or_path cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 33.00%
$ python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000
Model win rate: 51.20%

The RLOO checkpoint achieves a 51.2% win rate versus 33.0% for the SFT checkpoint. This is a good sign that the RLOO training is working as intended.

Metrics:

# pip install openrlbenchmark==0.2.1a5
# see https://github.com/openrlbenchmark/openrlbenchmark#get-started for documentation
# to use it, change `?we=huggingface&wpn=trl` to your own project and `?tag=pr-1540` to your own tag
python -m openrlbenchmark.rlops_multi_metrics \
    --filters '?we=huggingface&wpn=trl&xaxis=train/episode&ceik=output_dir&cen=sft_model_path&metrics=train/objective/rlhf_reward&metrics=train/objective/scores&metrics=train/objective/kl&metrics=train/objective/non_score_reward&metrics=train/objective/entropy&metrics=train/policy/approxkl_avg&metrics=train/policy/clipfrac_avg&metrics=train/loss/policy_avg&metrics=train/policy/entropy_avg&metrics=train/val/ratio&metrics=train/val/ratio_var&metrics=train/val/num_eos_tokens&metrics=train/lr&metrics=train/eps' \
        "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr?tag=pr-1540" \
    --env-ids models/minimal/rloo_tldr \
    --pc.ncols 4 \
    --pc.ncols-legend 1 \
    --pc.xlabel "Episode" \
    --output-filename benchmark/trl/pr-1540/rloo \
    --scan-history

RLOOTrainer

class trl.RLOOTrainer

( config: RLOOConfig, tokenizer: PreTrainedTokenizer, policy: Module, ref_policy: Module, reward_model: Module, train_dataset: Dataset, data_collator: Optional = None, eval_dataset: Union = None, optimizers: Tuple = (None, None), callbacks: Optional = None )

RLOOConfig

class trl.RLOOConfig

( output_dir: str overwrite_output_dir: bool = False do_train: bool = False do_eval: bool = False do_predict: bool = False eval_strategy: Union = 'no' prediction_loss_only: bool = False per_device_train_batch_size: int = 8 per_device_eval_batch_size: int = 8 per_gpu_train_batch_size: Optional = None per_gpu_eval_batch_size: Optional = None gradient_accumulation_steps: int = 1 eval_accumulation_steps: Optional = None eval_delay: Optional = 0 torch_empty_cache_steps: Optional = None learning_rate: float = 5e-05 weight_decay: float = 0.0 adam_beta1: float = 0.9 adam_beta2: float = 0.999 adam_epsilon: float = 1e-08 max_grad_norm: float = 1.0 num_train_epochs: float = 3.0 max_steps: int = -1 lr_scheduler_type: Union = 'linear' lr_scheduler_kwargs: Union = <factory> warmup_ratio: float = 0.0 warmup_steps: int = 0 log_level: Optional = 'passive' log_level_replica: Optional = 'warning' log_on_each_node: bool = True logging_dir: Optional = None logging_strategy: Union = 'steps' logging_first_step: bool = False logging_steps: float = 500 logging_nan_inf_filter: bool = True save_strategy: Union = 'steps' save_steps: float = 500 save_total_limit: Optional = None save_safetensors: Optional = True save_on_each_node: bool = False save_only_model: bool = False restore_callback_states_from_checkpoint: bool = False no_cuda: bool = False use_cpu: bool = False use_mps_device: bool = False seed: int = 42 data_seed: Optional = None jit_mode_eval: bool = False use_ipex: bool = False bf16: bool = False fp16: bool = False fp16_opt_level: str = 'O1' half_precision_backend: str = 'auto' bf16_full_eval: bool = False fp16_full_eval: bool = False tf32: Optional = None local_rank: int = -1 ddp_backend: Optional = None tpu_num_cores: Optional = None tpu_metrics_debug: bool = False debug: Union = '' dataloader_drop_last: bool = False eval_steps: Optional = None dataloader_num_workers: int = 0 dataloader_prefetch_factor: Optional = None past_index: int = -1 run_name: Optional = None disable_tqdm: 
Optional = None remove_unused_columns: Optional = True label_names: Optional = None load_best_model_at_end: Optional = False metric_for_best_model: Optional = None greater_is_better: Optional = None ignore_data_skip: bool = False fsdp: Union = '' fsdp_min_num_params: int = 0 fsdp_config: Union = None fsdp_transformer_layer_cls_to_wrap: Optional = None accelerator_config: Union = None deepspeed: Union = None label_smoothing_factor: float = 0.0 optim: Union = 'adamw_torch' optim_args: Optional = None adafactor: bool = False group_by_length: bool = False length_column_name: Optional = 'length' report_to: Union = None ddp_find_unused_parameters: Optional = None ddp_bucket_cap_mb: Optional = None ddp_broadcast_buffers: Optional = None dataloader_pin_memory: bool = True dataloader_persistent_workers: bool = False skip_memory_metrics: bool = True use_legacy_prediction_loop: bool = False push_to_hub: bool = False resume_from_checkpoint: Optional = None hub_model_id: Optional = None hub_strategy: Union = 'every_save' hub_token: Optional = None hub_private_repo: bool = False hub_always_push: bool = False gradient_checkpointing: bool = False gradient_checkpointing_kwargs: Union = None include_inputs_for_metrics: bool = False include_for_metrics: List = <factory> eval_do_concat_batches: bool = True fp16_backend: str = 'auto' evaluation_strategy: Union = None push_to_hub_model_id: Optional = None push_to_hub_organization: Optional = None push_to_hub_token: Optional = None mp_parameters: str = '' auto_find_batch_size: bool = False full_determinism: bool = False torchdynamo: Optional = None ray_scope: Optional = 'last' ddp_timeout: Optional = 1800 torch_compile: bool = False torch_compile_backend: Optional = None torch_compile_mode: Optional = None dispatch_batches: Optional = None split_batches: Optional = None include_tokens_per_second: Optional = False include_num_input_tokens_seen: Optional = False neftune_noise_alpha: Optional = None optim_target_modules: Union = None 
batch_eval_metrics: bool = False eval_on_start: bool = False use_liger_kernel: Optional = False eval_use_gather_object: Optional = False dataset_num_proc: Optional = None num_mini_batches: int = 1 total_episodes: Optional = None local_rollout_forward_batch_size: int = 64 num_sample_generations: int = 10 response_length: int = 53 stop_token: Optional = None stop_token_id: Optional = None temperature: float = 0.7 missing_eos_penalty: Optional = None sft_model_path: str = 'EleutherAI/pythia-160m' world_size: Optional = None num_total_batches: Optional = None micro_batch_size: Optional = None local_batch_size: Optional = None batch_size: Optional = None local_mini_batch_size: Optional = None mini_batch_size: Optional = None exp_name: str = 'rloo_config' reward_model_path: str = 'EleutherAI/pythia-160m' num_ppo_epochs: int = 4 whiten_rewards: bool = False kl_coef: float = 0.05 cliprange: float = 0.2 rloo_k: int = 2 )

Parameters

  • exp_name (str, optional, defaults to os.path.basename(__file__)[:-len(".py")]): Name of this experiment.
  • reward_model_path (str, optional, defaults to "EleutherAI/pythia-160m") — Path to the reward model.
  • num_ppo_epochs (int, optional, defaults to 4) — Number of epochs to train.
  • whiten_rewards (bool, optional, defaults to False) — Whether to whiten the rewards.
  • kl_coef (float, optional, defaults to 0.05) — KL coefficient.
  • cliprange (float, optional, defaults to 0.2) — Clip range.
  • rloo_k (int, optional, defaults to 2) — REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.

Configuration class for the RLOOTrainer.

Using HfArgumentParser we can turn this class into argparse arguments that can be specified on the command line.
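
The idea of turning a config dataclass into command-line arguments can be sketched with the standard library alone. This is not HfArgumentParser itself: TinyRLOOConfig and parse_config are hypothetical names, and the class only mirrors a few fields and defaults from the parameter list above.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class TinyRLOOConfig:
    # Hypothetical subset of fields; defaults copied from the parameter list.
    kl_coef: float = 0.05
    cliprange: float = 0.2
    rloo_k: int = 2
    num_ppo_epochs: int = 4


def parse_config(argv):
    """Build one --flag per dataclass field and parse argv into the config."""
    parser = argparse.ArgumentParser()
    for f in fields(TinyRLOOConfig):
        parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
    return TinyRLOOConfig(**vars(parser.parse_args(argv)))


# Flags not given on the command line keep their dataclass defaults.
config = parse_config(["--rloo_k", "4", "--kl_coef", "0.03"])
print(config)
```

This is why flags such as `--kl_coef` and `--missing_eos_penalty` in the commands above share their names with the config fields.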
