TRL documentation

Callbacks

You are viewing version v0.10.1. A newer version, v0.12.0, is available.

SyncRefModelCallback

class trl.SyncRefModelCallback

( ref_model: Union, accelerator: Optional )

RichProgressCallback

class trl.RichProgressCallback

( )

A TrainerCallback that displays the progress of training or evaluation using Rich.

WinRateCallback

class trl.WinRateCallback

( prompts: List[str], judge: BaseRankJudge, trainer: Trainer, generation_config: Optional[GenerationConfig] = None, batch_size: int = 4 )

Parameters

  • prompts (List[str]) — The prompts to generate completions for.
  • judge (BaseRankJudge) — The judge to use for comparing completions.
  • trainer (Trainer) — The trainer.
  • generation_config (GenerationConfig, optional) — The generation config to use for generating completions.
  • batch_size (int, optional) — The batch size to use for generating completions. Defaults to 4.

A TrainerCallback that computes the win rate of the model's completions against a reference's completions, as ranked by the judge.

Usage:

from trl import DPOTrainer, WinRateCallback

trainer = DPOTrainer(...)
win_rate_callback = WinRateCallback(..., trainer=trainer)
trainer.add_callback(win_rate_callback)
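
To make the notion of a win rate concrete, here is a minimal pure-Python sketch. It is not TRL's implementation: the `win_rate` helper, the `longest_wins_judge` toy judge, and the convention that index 0 denotes the model's completion are all assumptions made for illustration, and the real `BaseRankJudge` interface may differ.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical judge type (an assumption, not TRL's BaseRankJudge): for each
# prompt it receives a (model_completion, reference_completion) pair and
# returns the index of the preferred completion.
PairJudge = Callable[[Sequence[str], Sequence[Tuple[str, str]]], List[int]]


def longest_wins_judge(
    prompts: Sequence[str], completion_pairs: Sequence[Tuple[str, str]]
) -> List[int]:
    """Toy judge for illustration only: the longer completion wins, ties go to index 0."""
    return [0 if len(pair[0]) >= len(pair[1]) else 1 for pair in completion_pairs]


def win_rate(
    prompts: Sequence[str],
    model_completions: Sequence[str],
    ref_completions: Sequence[str],
    judge: PairJudge,
    model_index: int = 0,
) -> float:
    """Fraction of prompts for which the judge prefers the model's completion."""
    if not prompts:
        raise ValueError("need at least one prompt")
    # Pair each model completion with the reference completion for the same prompt.
    pairs = list(zip(model_completions, ref_completions))
    winners = judge(prompts, pairs)
    return sum(w == model_index for w in winners) / len(winners)


prompts = ["Hi", "Name a color", "2+2?"]
model_out = ["Hello there!", "Blue", "4"]
ref_out = ["Hey", "Crimson red", "4"]
rate = win_rate(prompts, model_out, ref_out, longest_wins_judge)
print(f"win rate: {rate:.2f}")
```

In practice the judge would be a learned or LLM-based ranker rather than a length heuristic; the toy judge only keeps the sketch self-contained.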