Callbacks

Callbacks are objects that customize the behavior of the training loop in the PyTorch Trainer (this feature is not yet implemented in TensorFlow). They can inspect the training loop state (for progress reporting, logging to TensorBoard or other ML platforms…) and make decisions (like early stopping).

Callbacks are “read only” pieces of code: apart from the TrainerControl object they return, they cannot change anything in the training loop. For customizations that require changes in the training loop, you should subclass Trainer and override the methods you need (see Trainer for examples).

By default a Trainer will use the following callbacks:

The main class that implements callbacks is TrainerCallback. It gets the TrainingArguments used to instantiate the Trainer, can access that Trainer’s internal state via TrainerState, and can take some actions on the training loop via TrainerControl.
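The division of labour described above can be illustrated with a toy loop. This is only a minimal sketch in plain Python, not the library's implementation: MiniState, MiniControl, StopAfter and run_loop are hypothetical stand-ins introduced for illustration. It shows a callback that reads the state and influences the loop only through the control object it returns.

```python
from dataclasses import dataclass


@dataclass
class MiniControl:
    """Hypothetical stand-in for TrainerControl (a single flag only)."""
    should_training_stop: bool = False


@dataclass
class MiniState:
    """Hypothetical stand-in for TrainerState (a single field only)."""
    global_step: int = 0


class StopAfter:
    """Callback-like object: reads the state, answers only through control."""

    def __init__(self, max_steps: int):
        self.max_steps = max_steps

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step >= self.max_steps:
            control.should_training_stop = True
        return control


def run_loop(callbacks, total_steps: int) -> int:
    """Toy loop: fires on_step_end after each step and obeys the control flag."""
    args, state, control = {}, MiniState(), MiniControl()
    for _ in range(total_steps):
        state.global_step += 1  # only the loop itself mutates the state
        for callback in callbacks:
            control = callback.on_step_end(args, state, control)
        if control.should_training_stop:
            break
    return state.global_step


steps_run = run_loop([StopAfter(max_steps=3)], total_steps=10)
print(steps_run)
```

The callback never touches the loop counter or the data; it only flips a flag, and the loop decides what to do with it.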

Available Callbacks

Here is the list of the TrainerCallback classes available in the library:

class transformers.integrations.CometCallback

A TrainerCallback that sends the logs to Comet ML.

setup(args, state, model)

Set up the optional Comet.ml integration.

Environment:
  • COMET_MODE (str, optional) – “OFFLINE”, “ONLINE”, or “DISABLED”

  • COMET_PROJECT_NAME (str, optional) – Comet.ml project name for experiments

  • COMET_OFFLINE_DIRECTORY (str, optional) – Folder to use for saving offline experiments when COMET_MODE is “OFFLINE”

For a number of configurable items in the environment, see here.

class transformers.DefaultFlowCallback

A TrainerCallback that handles the default flow of the training loop for logs, evaluation and checkpoints.

class transformers.PrinterCallback

A bare TrainerCallback that just prints the logs.

class transformers.ProgressCallback

A TrainerCallback that displays the progress of training or evaluation.

class transformers.EarlyStoppingCallback(early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0)

A TrainerCallback that handles early stopping.

Parameters
  • early_stopping_patience (int) – Use with metric_for_best_model to stop training when the specified metric worsens for early_stopping_patience evaluation calls.

  • early_stopping_threshold (float, optional) – Use with TrainingArguments metric_for_best_model and early_stopping_patience to denote how much the specified metric must improve to satisfy early stopping conditions.

This callback depends on the load_best_model_at_end functionality of TrainingArguments to set best_metric in TrainerState.
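The interplay between patience and threshold can be sketched in plain Python. This is a hedged illustration of the idea, not the library's code: first_stop_index is a hypothetical helper, it tracks the best metric itself, and the greater_is_better flag is an assumption standing in for however the metric direction is configured.

```python
from typing import Optional, Sequence


def first_stop_index(
    metrics: Sequence[float],
    patience: int = 1,
    threshold: float = 0.0,
    greater_is_better: bool = True,
) -> Optional[int]:
    """Return the index of the evaluation at which training would stop, or None."""
    best = None
    evals_without_improvement = 0
    for index, value in enumerate(metrics):
        if best is None:
            improved = True
        elif greater_is_better:
            improved = value - best > threshold
        else:
            improved = best - value > threshold

        if improved:
            best = value
            evals_without_improvement = 0
        else:
            evals_without_improvement += 1
            if evals_without_improvement >= patience:
                return index
    return None


accuracy_per_eval = [0.70, 0.72, 0.71, 0.715, 0.73]
print(first_stop_index(accuracy_per_eval, patience=2))
print(first_stop_index(accuracy_per_eval, patience=3))
```

With a patience of 2 the sketch stops at the fourth evaluation (index 3), because two evaluations in a row fail to beat 0.72. With a patience of 3 the fifth evaluation improves on the best value in time, so no stop is requested.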

class transformers.integrations.TensorBoardCallback(tb_writer=None)

A TrainerCallback that sends the logs to TensorBoard.

Parameters

tb_writer (SummaryWriter, optional) – The writer to use. Will instantiate one if not set.

class transformers.integrations.WandbCallback

A TrainerCallback that sends the logs to Weights & Biases.

setup(args, state, model, **kwargs)

Set up the optional Weights & Biases (wandb) integration.

One can subclass and override this method to customize the setup if needed. Find more information here. You can also override the following environment variables:

Environment:
  • WANDB_LOG_MODEL (bool, optional, defaults to False) – Whether or not to log model as artifact at the end of training. Use along with TrainingArguments.load_best_model_at_end to upload best model.

  • WANDB_WATCH (str, optional, defaults to "gradients") – Can be "gradients", "all" or "false". Set to "false" to disable gradient logging or "all" to log gradients and parameters.

  • WANDB_PROJECT (str, optional, defaults to "huggingface") – Set this to a custom string to store results in a different project.

  • WANDB_DISABLED (bool, optional, defaults to False) – Whether or not to disable wandb entirely. Set WANDB_DISABLED=true to disable.
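As a small sketch of how these variables might be set from Python rather than from the shell: the project name below is a made-up placeholder, and the assumption is that the variables are read when the callback's setup runs, so they should be set before the Trainer is created.

```python
import os

# Hypothetical project name; replace it with your own.
os.environ["WANDB_PROJECT"] = "my-project"

# Log both gradients and parameters instead of the default "gradients".
os.environ["WANDB_WATCH"] = "all"

print(os.environ["WANDB_PROJECT"], os.environ["WANDB_WATCH"])
```

Environment variables are always strings, so values such as "all" or "false" are passed as text.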

class transformers.integrations.MLflowCallback

A TrainerCallback that sends the logs to MLflow.

setup(args, state, model)

Set up the optional MLflow integration.

Environment:
  • HF_MLFLOW_LOG_ARTIFACTS (str, optional) –

    Whether to use MLflow .log_artifact() facility to log artifacts.

    This only makes sense if logging to a remote server, e.g. s3 or GCS. If set to True or 1, will copy whatever is in TrainingArguments’s output_dir to the local or remote artifact storage. Using it without a remote storage will just copy the files to your artifact location.

class transformers.integrations.AzureMLCallback(azureml_run=None)

A TrainerCallback that sends the logs to AzureML.

TrainerCallback

class transformers.TrainerCallback

A class for objects that will inspect the state of the training loop at some events and take some decisions. At each of those events the following arguments are available:

Parameters
  • args (TrainingArguments) – The training arguments used to instantiate the Trainer.

  • state (TrainerState) – The current state of the Trainer.

  • control (TrainerControl) – The object that is returned to the Trainer and can be used to make some decisions.

  • model (PreTrainedModel or torch.nn.Module) – The model being trained.

  • tokenizer (PreTrainedTokenizer) – The tokenizer used for encoding the data.

  • optimizer (torch.optim.Optimizer) – The optimizer used for the training steps.

  • lr_scheduler (torch.optim.lr_scheduler.LambdaLR) – The scheduler used for setting the learning rate.

  • train_dataloader (torch.utils.data.DataLoader, optional) – The current dataloader used for training.

  • eval_dataloader (torch.utils.data.DataLoader, optional) – The current dataloader used for evaluation.

  • metrics (Dict[str, float]) –

    The metrics computed by the last evaluation phase.

    Those are only accessible in the event on_evaluate.

  • logs (Dict[str, float]) –

    The values to log.

    Those are only accessible in the event on_log.

The control object is the only one that can be changed by the callback, in which case the event that changes it should return the modified version.

The arguments args, state and control are positional for all events; all the others are grouped in kwargs. You can unpack the ones you need in the signature of the event using them. As an example, see the code of the simple PrinterCallback.

Example:

from transformers import TrainerCallback

class PrinterCallback(TrainerCallback):

    def on_log(self, args, state, control, logs=None, **kwargs):
        # Drop total_flos from the logs before printing them
        _ = logs.pop("total_flos", None)
        # Print only on the local main process
        if state.is_local_process_zero:
            print(logs)

on_epoch_begin(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the beginning of an epoch.

on_epoch_end(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the end of an epoch.

on_evaluate(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called after an evaluation phase.

on_init_end(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the end of the initialization of the Trainer.

on_log(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called after logging the last logs.

on_prediction_step(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called after a prediction step.

on_save(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called after a checkpoint save.

on_step_begin(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the beginning of a training step. If using gradient accumulation, one training step might take several inputs.

on_step_end(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.

on_substep_end(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the end of a substep during gradient accumulation.

on_train_begin(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the beginning of training.

on_train_end(args: transformers.training_args.TrainingArguments, state: transformers.trainer_callback.TrainerState, control: transformers.trainer_callback.TrainerControl, **kwargs)

Event called at the end of training.

Here is an example of how to register a custom callback with the PyTorch Trainer:

from transformers import Trainer, TrainerCallback

class MyCallback(TrainerCallback):
    "A callback that prints a message at the beginning of training"

    def on_train_begin(self, args, state, control, **kwargs):
        print("Starting training")

# model, args, train_dataset and eval_dataset are assumed to be defined beforehand
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    callbacks=[MyCallback],  # We can either pass the callback class this way or an instance of it (MyCallback())
)

Another way to register a callback is to call trainer.add_callback() as follows:

trainer = Trainer(...)
trainer.add_callback(MyCallback)
# Alternatively, we can pass an instance of the callback class
trainer.add_callback(MyCallback())

TrainerState

class transformers.TrainerState(epoch: Optional[float] = None, global_step: int = 0, max_steps: int = 0, num_train_epochs: int = 0, total_flos: float = 0, log_history: List[Dict[str, float]] = None, best_metric: Optional[float] = None, best_model_checkpoint: Optional[str] = None, is_local_process_zero: bool = True, is_world_process_zero: bool = True, is_hyper_param_search: bool = False, trial_name: str = None, trial_params: Dict[str, Union[str, float, int]] = None)

A class containing the Trainer inner state that will be saved along with the model and optimizer when checkpointing and passed to the TrainerCallback.

Note

Throughout this class, one step is to be understood as one update step. When using gradient accumulation, one update step may require several forward and backward passes: if you use gradient_accumulation_steps=n, then one update step requires going through n batches.
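The relationship between batches and update steps can be made concrete with a little arithmetic. The helper below is a hypothetical sketch, not how the Trainer counts steps internally; in particular, it simply ignores trailing batches that do not fill a complete accumulation group.

```python
def count_update_steps(num_batches: int, gradient_accumulation_steps: int) -> int:
    """Count optimizer updates when one update is applied every n batches."""
    global_step = 0
    for batch_index in range(1, num_batches + 1):
        # A forward and backward pass would run here for every batch.
        if batch_index % gradient_accumulation_steps == 0:
            global_step += 1  # the optimizer update would happen here
    return global_step


print(count_update_steps(num_batches=12, gradient_accumulation_steps=4))
```

Twelve batches with gradient_accumulation_steps=4 give three update steps, so a global_step of 3 in this sketch corresponds to 12 forward and backward passes.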

Parameters
  • epoch (float, optional) – Only set during training, will represent the epoch the training is at (the decimal part being the percentage of the current epoch completed).

  • global_step (int, optional, defaults to 0) – During training, represents the number of update steps completed.

  • max_steps (int, optional, defaults to 0) – The number of update steps to do during the current training.

  • total_flos (float, optional, defaults to 0) – The total number of floating-point operations done by the model since the beginning of training (stored as floats to avoid overflow).

  • log_history (List[Dict[str, float]], optional) – The list of logs done since the beginning of training.

  • best_metric (float, optional) – When tracking the best model, the value of the best metric encountered so far.

  • best_model_checkpoint (str, optional) – When tracking the best model, the name of the checkpoint for the best model encountered so far.

  • is_local_process_zero (bool, optional, defaults to True) – Whether or not this process is the local (e.g., on one machine if training in a distributed fashion on several machines) main process.

  • is_world_process_zero (bool, optional, defaults to True) – Whether or not this process is the global main process (when training in a distributed fashion on several machines, this is only going to be True for one process).

  • is_hyper_param_search (bool, optional, defaults to False) – Whether we are in the process of a hyper parameter search using Trainer.hyperparameter_search. This will impact the way data will be logged in TensorBoard.
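To illustrate the difference between is_local_process_zero and is_world_process_zero, here is a toy sketch for a setup with two machines running four processes each. The names node_rank and local_rank, and the convention that the global main process is local rank 0 on node 0, are assumptions made for this example only.

```python
from typing import Tuple


def process_flags(node_rank: int, local_rank: int) -> Tuple[bool, bool]:
    """Return (is_local_process_zero, is_world_process_zero) for one process."""
    is_local_process_zero = local_rank == 0
    # Assumption: the global main process is local rank 0 on the first machine.
    is_world_process_zero = node_rank == 0 and local_rank == 0
    return is_local_process_zero, is_world_process_zero


# Two machines (nodes) with four processes each.
flags = [process_flags(node, local) for node in range(2) for local in range(4)]
local_main_count = sum(is_local for is_local, _ in flags)
world_main_count = sum(is_world for _, is_world in flags)
print(local_main_count, world_main_count)
```

In this sketch there is one local main process per machine but only a single global main process, which is why something like writing a final report would typically be guarded by the world flag, while per-machine printing can use the local one.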

classmethod load_from_json(json_path: str)

Create an instance from the content of json_path.

save_to_json(json_path: str)

Save the content of this instance in JSON format inside json_path.
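The save and load pair amounts to a JSON round trip. The sketch below mimics that idea with a hypothetical MiniState dataclass holding a few of the fields documented above; it is not the library class, and the file name is only illustrative.

```python
import dataclasses
import json
import os
import tempfile
from dataclasses import dataclass, field
from typing import Dict, List, Optional


@dataclass
class MiniState:
    """Hypothetical stand-in holding a few of the documented fields."""
    global_step: int = 0
    epoch: Optional[float] = None
    best_metric: Optional[float] = None
    log_history: List[Dict[str, float]] = field(default_factory=list)

    def save_to_json(self, json_path: str) -> None:
        """Write all fields to json_path as a JSON object."""
        with open(json_path, "w", encoding="utf-8") as f:
            json.dump(dataclasses.asdict(self), f, indent=2, sort_keys=True)

    @classmethod
    def load_from_json(cls, json_path: str) -> "MiniState":
        """Rebuild an instance from the JSON object stored at json_path."""
        with open(json_path, "r", encoding="utf-8") as f:
            return cls(**json.load(f))


state = MiniState(global_step=500, epoch=1.25, log_history=[{"loss": 0.42}])
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "trainer_state.json")  # illustrative file name
    state.save_to_json(path)
    restored = MiniState.load_from_json(path)

print(restored == state)
```

Because every field is a plain number, string, list or dictionary, the restored object compares equal to the original.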

TrainerControl

class transformers.TrainerControl(should_training_stop: bool = False, should_epoch_stop: bool = False, should_save: bool = False, should_evaluate: bool = False, should_log: bool = False)

A class that handles the Trainer control flow. This class is used by the TrainerCallback to activate some switches in the training loop.

Parameters
  • should_training_stop (bool, optional, defaults to False) –

    Whether or not the training should be interrupted.

    If True, this variable will not be set back to False. The training will just stop.

  • should_epoch_stop (bool, optional, defaults to False) –

    Whether or not the current epoch should be interrupted.

    If True, this variable will be set back to False at the beginning of the next epoch.

  • should_save (bool, optional, defaults to False) –

    Whether or not the model should be saved at this step.

    If True, this variable will be set back to False at the beginning of the next step.

  • should_evaluate (bool, optional, defaults to False) –

    Whether or not the model should be evaluated at this step.

    If True, this variable will be set back to False at the beginning of the next step.

  • should_log (bool, optional, defaults to False) –

    Whether or not the logs should be reported at this step.

    If True, this variable will be set back to False at the beginning of the next step.
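The reset rules listed above can be summarized in a short sketch. MiniControl and its new_step and new_epoch methods are hypothetical names chosen for this illustration; they only mirror the documented behavior and are not the library's API.

```python
from dataclasses import dataclass


@dataclass
class MiniControl:
    """Hypothetical stand-in that mirrors the documented reset rules."""
    should_training_stop: bool = False
    should_epoch_stop: bool = False
    should_save: bool = False
    should_evaluate: bool = False
    should_log: bool = False

    def new_step(self) -> None:
        """Per-step switches are cleared at the beginning of the next step."""
        self.should_save = False
        self.should_evaluate = False
        self.should_log = False

    def new_epoch(self) -> None:
        """The epoch switch is cleared at the beginning of the next epoch."""
        self.should_epoch_stop = False


control = MiniControl(
    should_training_stop=True,
    should_epoch_stop=True,
    should_save=True,
    should_evaluate=True,
    should_log=True,
)
control.new_step()
after_step = dataclasses_snapshot = (control.should_save, control.should_evaluate, control.should_log)
control.new_epoch()
print(after_step, control.should_epoch_stop, control.should_training_stop)
```

After both resets only should_training_stop is still True, which matches the description that this switch is never set back and training simply stops.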