Callbacks

Callbacks are objects that can customize the behavior of the training loop in the PyTorch Trainer (this feature is not yet implemented in TensorFlow). They can inspect the training loop state (for progress reporting, logging on TensorBoard or other ML platforms, and so on) and make decisions (like early stopping).

Callbacks are “read only” pieces of code: apart from the TrainerControl object they return, they cannot change anything in the training loop. For customizations that require changes in the training loop, you should subclass Trainer and override the methods you need (see Trainer for examples).
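The “read only” contract can be illustrated with a toy training loop in plain Python. This is a simplified sketch, not the Trainer implementation: MiniState, MiniControl, MiniCallback, StopAfterSteps and run_loop are hypothetical names, and only the flag name should_training_stop mirrors an attribute of the real TrainerControl. In this sketch the loop hands each callback a copy of the state, so the only way a callback can influence training is through the control object it returns.

```python
# Toy sketch: all classes and functions here are hypothetical stand-ins,
# not the transformers implementation.
from dataclasses import dataclass, field, replace
from typing import List, Optional


@dataclass
class MiniState:
    """Stand-in for TrainerState: what callbacks may inspect."""
    global_step: int = 0
    losses: List[float] = field(default_factory=list)


@dataclass
class MiniControl:
    """Stand-in for TrainerControl: the only channel back to the loop."""
    should_training_stop: bool = False


class MiniCallback:
    """Base class: the hook does nothing by default."""

    def on_step_end(self, state: MiniState, control: MiniControl) -> Optional[MiniControl]:
        return None


class StopAfterSteps(MiniCallback):
    """Takes a decision (stop training) once a step budget is reached."""

    def __init__(self, max_steps: int) -> None:
        self.max_steps = max_steps

    def on_step_end(self, state: MiniState, control: MiniControl) -> Optional[MiniControl]:
        if state.global_step >= self.max_steps:
            control.should_training_stop = True
        return control


def run_loop(num_steps: int, callbacks: List[MiniCallback]) -> MiniState:
    state, control = MiniState(), MiniControl()
    for _ in range(num_steps):
        state.global_step += 1
        state.losses.append(1.0 / state.global_step)  # fake, decreasing loss
        for cb in callbacks:
            # Each callback receives a copy of the state, so mutating it has no
            # effect here; only the returned control object is read back.
            snapshot = replace(state, losses=list(state.losses))
            result = cb.on_step_end(snapshot, control)
            if result is not None:
                control = result
        if control.should_training_stop:
            break
    return state
```

For example, run_loop(100, [StopAfterSteps(3)]) stops after the third step, because the callback flips should_training_stop on the control object rather than touching the loop itself.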

By default, a Trainer will use the following callbacks:

  • DefaultFlowCallback which handles the default behavior for logging, saving and evaluation.

  • PrinterCallback or ProgressCallback to display progress and print the logs (PrinterCallback is used if you deactivate tqdm through the TrainingArguments, ProgressCallback otherwise).

  • TensorBoardCallback if tensorboard is accessible (either through PyTorch >= 1.4 or tensorboardX).

  • WandbCallback if wandb is installed.

  • CometCallback if comet_ml is installed.

  • MLflowCallback if mlflow is installed.

  • AzureMLCallback if azureml-sdk is installed.
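The selection rules in the list above can be summarized as a small function. This is a sketch under assumptions, not the library's code: default_callback_names and _is_installed are hypothetical helpers, the PyTorch >= 1.4 version check is omitted, availability is approximated by whether a module can be located with importlib.util.find_spec, and the module provided by azureml-sdk is assumed to be azureml. The is_installed parameter lets you inject a fake check for testing.

```python
# Hypothetical approximation of how the default callbacks are chosen;
# it returns callback names only and does not import transformers.
import importlib.util
from typing import Callable, List


def _is_installed(module_name: str) -> bool:
    """True if `module_name` can be found on the import path (without importing it)."""
    return importlib.util.find_spec(module_name) is not None


def default_callback_names(
    disable_tqdm: bool,
    is_installed: Callable[[str], bool] = _is_installed,
) -> List[str]:
    names = ["DefaultFlowCallback"]
    # Progress display: plain printing when tqdm is disabled, progress bars otherwise.
    names.append("PrinterCallback" if disable_tqdm else "ProgressCallback")
    # Assumption: PyTorch's TensorBoard writer needs the `tensorboard` package,
    # so we check for it or for tensorboardX (PyTorch version not checked).
    if is_installed("tensorboard") or is_installed("tensorboardX"):
        names.append("TensorBoardCallback")
    integrations = [
        ("WandbCallback", "wandb"),
        ("CometCallback", "comet_ml"),
        ("MLflowCallback", "mlflow"),
        ("AzureMLCallback", "azureml"),  # assumed module name for azureml-sdk
    ]
    for callback_name, module_name in integrations:
        if is_installed(module_name):
            names.append(callback_name)
    return names
```

For instance, default_callback_names(True, lambda m: m == "wandb") yields ["DefaultFlowCallback", "PrinterCallback", "WandbCallback"].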

The main class that implements callbacks is TrainerCallback. It gets the TrainingArguments used to instantiate the Trainer, can access that Trainer’s internal state via TrainerState, and can take some actions on the training loop via TrainerControl.
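The division of roles between the three objects (arguments for configuration, state for inspection, control for actions) can be sketched with an early-stopping rule. Everything below is a hypothetical, self-contained stand-in (MiniArgs with patience and min_delta, MiniState, MiniControl, EarlyStoppingSketch, simulate); it only mimics the shape of a callback hook and is not how transformers implements early stopping.

```python
# Toy sketch: MiniArgs, MiniState, MiniControl, EarlyStoppingSketch and simulate
# are hypothetical stand-ins, not classes or functions from transformers.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class MiniArgs:
    """Stand-in for TrainingArguments: read-only configuration."""
    patience: int = 2        # evaluations without improvement before stopping
    min_delta: float = 0.0   # required decrease in eval loss to count as improvement


@dataclass
class MiniState:
    """Stand-in for TrainerState: the loop's internal state."""
    global_step: int = 0
    eval_history: List[Dict[str, float]] = field(default_factory=list)


@dataclass
class MiniControl:
    """Stand-in for TrainerControl: the callback's only way to act on the loop."""
    should_training_stop: bool = False


class EarlyStoppingSketch:
    """Reads args and state, and acts on the loop only through control."""

    def __init__(self) -> None:
        self.best: Optional[float] = None
        self.best_step = 0
        self.bad_evals = 0

    def on_evaluate(self, args: MiniArgs, state: MiniState,
                    control: MiniControl, metrics: Dict[str, float]) -> MiniControl:
        loss = metrics["eval_loss"]
        if self.best is None or loss < self.best - args.min_delta:
            # Improvement: remember where it happened (read from the state).
            self.best, self.best_step, self.bad_evals = loss, state.global_step, 0
        else:
            self.bad_evals += 1
        if self.bad_evals >= args.patience:
            control.should_training_stop = True
        return control


def simulate(eval_losses: List[float], args: MiniArgs) -> Tuple[MiniState, EarlyStoppingSketch]:
    """Fake loop: one evaluation every 100 steps, with precomputed eval losses."""
    state, control, callback = MiniState(), MiniControl(), EarlyStoppingSketch()
    for i, loss in enumerate(eval_losses, start=1):
        state.global_step = i * 100
        metrics = {"eval_loss": loss}
        state.eval_history.append(metrics)
        control = callback.on_evaluate(args, state, control, metrics)
        if control.should_training_stop:
            break
    return state, callback
```

With eval losses [1.0, 0.8, 0.85, 0.9, 0.7] and patience 2, the sketch stops at step 400, after two evaluations without improvement over 0.8. With the real library, such logic would live in a subclass of TrainerCallback whose hooks take (self, args, state, control, **kwargs), registered through the callbacks argument of Trainer or with Trainer.add_callback.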

Available Callbacks

Here is the list of the available TrainerCallback subclasses in the library:

TrainerCallback

TrainerState

TrainerControl