TRL documentation

Callbacks


SyncRefModelCallback

class trl.SyncRefModelCallback

( ref_model: typing.Union[transformers.modeling_utils.PreTrainedModel, torch.nn.modules.module.Module], accelerator: typing.Optional[accelerate.accelerator.Accelerator] )

RichProgressCallback

class trl.RichProgressCallback

( )

A TrainerCallback that displays the progress of training or evaluation using Rich.

WinRateCallback

class trl.WinRateCallback

( judge: BasePairwiseJudge, trainer: Trainer, generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None, num_prompts: typing.Optional[int] = None, shuffle_order: bool = True, use_soft_judge: bool = False )

Parameters

  • judge (BasePairwiseJudge) — The judge to use for comparing completions.
  • trainer (Trainer) — Trainer to which the callback will be attached. The trainer’s evaluation dataset must include a "prompt" column containing the prompts for generating completions. If the Trainer has a reference model (via the ref_model attribute), it will use this reference model for generating the reference completions; otherwise, it defaults to using the initial model.
  • generation_config (GenerationConfig, optional) — The generation config to use for generating completions.
  • num_prompts (int or None, optional, defaults to None) — The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
  • shuffle_order (bool, optional, defaults to True) — Whether to shuffle the order of the completions before judging.
  • use_soft_judge (bool, optional, defaults to False) — Whether to use a soft judge that returns a win probability between 0 and 1 for the first completion vs the second.

A TrainerCallback that computes the win rate of the trained model against a reference.

It generates completions using prompts from the evaluation dataset and compares the trained model’s outputs against a reference. The reference is either the initial version of the model (before training) or the reference model, if available in the trainer. During each evaluation step, a judge determines how often the trained model’s completions win against the reference. The win rate is then logged in the trainer’s logs under the key "eval_win_rate".

Usage:

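from trl import DPOTrainer, PairRMJudge, WinRateCallback

trainer = DPOTrainer(...)  # the evaluation dataset must include a "prompt" column
judge = PairRMJudge()
win_rate_callback = WinRateCallback(judge=judge, trainer=trainer)
trainer.add_callback(win_rate_callback)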
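
The win-rate computation described above can be sketched in plain Python. This is only an illustration under stated assumptions, not TRL's implementation: the `win_rate` helper, the `judge(prompt, pair)` call signature, and the `longest_wins` toy judge are all hypothetical names introduced here. It shows one way shuffling the pair order can hide which completion came from the trained model, while still attributing each win to the right side.

```python
import random


def win_rate(prompts, ref_completions, model_completions, judge,
             shuffle_order=True, seed=0):
    """Fraction of prompts where the judge prefers the trained model's completion.

    `judge` is a hypothetical callable (prompt, [a, b]) -> 0 or 1 giving the
    index of the preferred completion. It is NOT TRL's judge interface.
    """
    rng = random.Random(seed)
    wins = 0
    for prompt, ref, new in zip(prompts, ref_completions, model_completions):
        pair = [ref, new]  # index 1 is the trained model's completion
        flipped = shuffle_order and rng.random() < 0.5
        if flipped:
            pair.reverse()  # the judge no longer sees a fixed position
        winner = judge(prompt, pair)
        if flipped:
            winner = 1 - winner  # map the verdict back to the original order
        wins += winner == 1
    return wins / len(prompts)


def longest_wins(prompt, pair):
    """Toy stand-in judge (assumption): prefers the longer completion."""
    return 0 if len(pair[0]) >= len(pair[1]) else 1


prompts = ["a", "b", "c", "d"]
ref = ["short", "a much longer answer", "mid", "x"]
new = ["a longer answer", "tiny", "middle", "xy"]
rate = win_rate(prompts, ref, new, longest_wins)
print(rate)
```

Because the toy judge only looks at length, the result does not depend on whether the pairs were shuffled; a real judge with position bias is the case where `shuffle_order` matters.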

LogCompletionsCallback

class trl.LogCompletionsCallback

( trainer: Trainer, generation_config: typing.Optional[transformers.generation.configuration_utils.GenerationConfig] = None, num_prompts: typing.Optional[int] = None, freq: typing.Optional[int] = None )

Parameters

  • trainer (Trainer) — Trainer to which the callback will be attached. The trainer’s evaluation dataset must include a "prompt" column containing the prompts for generating completions.
  • generation_config (GenerationConfig, optional) — The generation config to use for generating completions.
  • num_prompts (int or None, optional) — The number of prompts to generate completions for. If not provided, defaults to the number of examples in the evaluation dataset.
  • freq (int or None, optional) — The frequency at which to log completions. If not provided, defaults to the trainer’s eval_steps.

A TrainerCallback that logs completions to Weights & Biases.

Usage:

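from trl import DPOTrainer, LogCompletionsCallback

trainer = DPOTrainer(...)  # the evaluation dataset must include a "prompt" column
completions_callback = LogCompletionsCallback(trainer=trainer)
trainer.add_callback(completions_callback)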
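
To make the `freq` fallback concrete, here is a small sketch of how a logging schedule could be derived from `freq` and the trainer's `eval_steps`. The `should_log` helper and the modulo-on-global-step rule are assumptions made for illustration; the callback's actual scheduling logic may differ.

```python
def should_log(global_step, freq=None, eval_steps=None, last_logged_step=None):
    """Decide whether completions would be logged at `global_step`.

    Hypothetical helper: falls back to `eval_steps` when `freq` is not given,
    and skips a step that has already been logged.
    """
    effective_freq = freq if freq is not None else eval_steps
    if not effective_freq:
        return False  # no usable frequency, so never log
    if global_step == last_logged_step:
        return False  # avoid logging the same step twice
    return global_step % effective_freq == 0


# With no explicit freq, the schedule follows eval_steps.
default_steps = [s for s in range(1, 11) if should_log(s, eval_steps=5)]
# An explicit freq takes precedence over eval_steps.
custom_steps = [s for s in range(1, 11) if should_log(s, freq=2, eval_steps=5)]
print(default_steps, custom_steps)
```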

MergeModelCallback

class trl.MergeModelCallback

( merge_config: typing.Optional[MergeConfig] = None, merge_at_every_checkpoint: bool = False, push_to_hub: bool = False )

Parameters

  • merge_config (MergeConfig, optional, defaults to None) — Configuration used for the merging process. If not provided, the default MergeConfig is used.
  • merge_at_every_checkpoint (bool, optional, defaults to False) — Whether to merge the model at every checkpoint.
  • push_to_hub (bool, optional, defaults to False) — Whether to push the merged model to the Hub after merging.

A TrainerCallback that merges the policy model (the model being trained) with another model based on a merge configuration.

Example:

!pip install trl[mergekit]

from trl.mergekit_utils import MergeConfig
from trl import DPOTrainer, MergeModelCallback

config = MergeConfig()
merge_callback = MergeModelCallback(merge_config=config)
trainer = DPOTrainer(..., callbacks=[merge_callback])
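
The effect of `merge_at_every_checkpoint` can be pictured with a toy event simulation. This is a sketch under an explicit assumption that the text above does not state: when the flag is `False`, the merge is taken to run once at the end of training. The `MergeTimingSketch` class and `simulate` helper are hypothetical and do not perform any real merging.

```python
class MergeTimingSketch:
    """Toy stand-in that records when a merge would be triggered."""

    def __init__(self, merge_at_every_checkpoint=False):
        self.merge_at_every_checkpoint = merge_at_every_checkpoint
        self.merged_at = []

    def _merge(self, label):
        # Placeholder for the real merge; here we only record the trigger.
        self.merged_at.append(label)

    def on_save(self, step):
        # Called whenever a checkpoint is written.
        if self.merge_at_every_checkpoint:
            self._merge(f"checkpoint-{step}")

    def on_train_end(self, step):
        # Assumption: a single merge at the end when per-checkpoint merging is off.
        if not self.merge_at_every_checkpoint:
            self._merge(f"final-{step}")


def simulate(flag, save_steps=(100, 200, 300)):
    """Replay a sequence of checkpoint saves followed by the end of training."""
    callback = MergeTimingSketch(flag)
    for step in save_steps:
        callback.on_save(step)
    callback.on_train_end(save_steps[-1])
    return callback.merged_at


print(simulate(True))
print(simulate(False))
```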