Transformers documentation

Keras callbacks

You are viewing v4.27.2 version. A newer version v4.42.0 is available.

When training a Transformers model with Keras, there are some library-specific callbacks available to automate common tasks:


class transformers.KerasMetricCallback


( metric_fn: typing.Callable, eval_dataset: typing.Union[, numpy.ndarray, tensorflow.python.framework.ops.Tensor, tuple, dict], output_cols: typing.Optional[typing.List[str]] = None, label_cols: typing.Optional[typing.List[str]] = None, batch_size: typing.Optional[int] = None, predict_with_generate: bool = False, use_xla_generation: bool = False, generate_kwargs: typing.Optional[dict] = None )


  • metric_fn (Callable) — Metric function provided by the user. It will be called with two arguments, predictions and labels, which contain the model’s outputs and the matching labels from the dataset. It should return a dict mapping metric names to numerical values.
  • eval_dataset ( or dict or tuple or np.ndarray or tf.Tensor) — Validation data to be used to generate predictions for the metric_fn.
  • output_cols (List[str], optional) — A list of columns to be retained from the model output as the predictions. Defaults to all columns.
  • label_cols (List[str], optional) — A list of columns to be retained from the input dataset as the labels. Will be autodetected if this is not supplied.
  • batch_size (int, optional) — Batch size. Only used when the data is not already batched.
  • predict_with_generate (bool, optional, defaults to False) — Whether we should use model.generate() to get outputs for the model.
  • use_xla_generation (bool, optional, defaults to False) — If we’re generating, whether to compile model generation with XLA. This can greatly increase generation speed (up to 100x) but requires a new XLA compilation for each input shape. When using XLA generation, it’s a good idea to pad your inputs to the same size, or to use the pad_to_multiple_of argument in your tokenizer or DataCollator, which reduces the number of unique input shapes and saves a lot of compilation time. This option has no effect if predict_with_generate is False.
  • generate_kwargs (dict, optional) — Keyword arguments to pass to model.generate() when generating. Has no effect if predict_with_generate is False.
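
The effect of pad_to_multiple_of on XLA compilation can be illustrated with a small pure-Python sketch. It is not library code: it treats each value in a hypothetical list as the padded length of one evaluation batch, assumes each distinct length needs its own XLA compilation (as described for use_xla_generation above), and uses hypothetical helpers, padded_length and unique_shapes, to round lengths up to a multiple.

```python
import math
from typing import List, Optional, Set


def padded_length(length: int, multiple: Optional[int] = None) -> int:
    """Hypothetical helper: round a length up to the next multiple, if one is given."""
    if multiple is None:
        return length
    return math.ceil(length / multiple) * multiple


def unique_shapes(lengths: List[int], multiple: Optional[int] = None) -> Set[int]:
    """Distinct padded lengths; under XLA, each one would need its own compilation."""
    return {padded_length(n, multiple) for n in lengths}


# Hypothetical padded length of each evaluation batch (its longest sequence)
batch_lengths = [13, 17, 21, 9, 30, 14, 26, 11]

print(len(unique_shapes(batch_lengths)))     # 8 distinct shapes without rounding
print(len(unique_shapes(batch_lengths, 8)))  # 3 distinct shapes with a multiple of 8
```

Rounding up trades a little extra padding for far fewer distinct shapes, which is why padding to a multiple saves compilation time.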

Callback to compute metrics at the end of every epoch. Unlike normal Keras metrics, these do not need to be compilable by TF. It is particularly useful for common NLP metrics like BLEU and ROUGE that require string operations or generation loops that cannot be compiled. Predictions (or generations) will be computed on the eval_dataset before being passed to the metric_fn in np.ndarray format. The metric_fn should compute metrics and return a dict mapping metric names to metric values.
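
To make this data flow more concrete, here is a simplified NumPy sketch of what happens at the end of an epoch: predictions are gathered batch by batch, concatenated into np.ndarray objects, passed to metric_fn together with the labels (using the two-argument form described above), and the returned dict is merged into the Keras logs. The names run_metric_callback, predict_fn and accuracy_fn are hypothetical, and the toy "model" simply returns its inputs as logits. The sketch skips output_cols, label autodetection and generation, so it should not be read as the callback's actual implementation.

```python
from typing import Callable, Dict, Iterable, Optional, Tuple

import numpy as np


def run_metric_callback(
    predict_fn: Callable[[np.ndarray], np.ndarray],
    eval_batches: Iterable[Tuple[np.ndarray, np.ndarray]],
    metric_fn: Callable[[np.ndarray, np.ndarray], Dict[str, float]],
    logs: Optional[Dict[str, float]] = None,
) -> Dict[str, float]:
    """Simplified, hypothetical model of the end-of-epoch metric pass."""
    prediction_list, label_list = [], []
    for inputs, labels in eval_batches:
        # Run the model on one batch and keep the outputs as NumPy arrays
        prediction_list.append(np.asarray(predict_fn(inputs)))
        label_list.append(np.asarray(labels))
    # Concatenate per-batch results so metric_fn sees the whole evaluation set at once
    all_preds = np.concatenate(prediction_list, axis=0)
    all_labels = np.concatenate(label_list, axis=0)
    metrics = metric_fn(all_preds, all_labels)
    # Merge the returned dict into the Keras-style logs, next to the other metrics
    if logs is None:
        logs = {}
    logs.update(metrics)
    return logs


def accuracy_fn(predictions: np.ndarray, labels: np.ndarray) -> Dict[str, float]:
    """Example metric_fn: classification accuracy computed from logits."""
    return {"accuracy": float(np.mean(np.argmax(predictions, axis=-1) == labels))}


# Two toy batches of (inputs, labels); the identity "model" treats inputs as logits
batches = [
    (np.array([[2.0, 0.1], [0.3, 1.5]]), np.array([0, 1])),
    (np.array([[0.2, 0.9], [1.1, 0.4]]), np.array([0, 0])),
]
logs = run_metric_callback(lambda x: x, batches, accuracy_fn, logs={"loss": 0.42})
print(logs)  # {'loss': 0.42, 'accuracy': 0.75}
```

Because metric_fn only ever sees NumPy arrays, it can use arbitrary Python code (string decoding, external scoring libraries) that TF could not compile.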

We provide an example of a suitable metric_fn that computes ROUGE scores for a summarization model below. Note that this example skips some post-processing for readability and simplicity, and should probably not be used as-is!

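from datasets import load_metric

rouge_metric = load_metric("rouge")

# tokenizer is assumed to be the tokenizer of the model being evaluated, loaded earlier
def rouge_fn(predictions, labels):
    # Convert token IDs back into text
    decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge_metric.compute(predictions=decoded_predictions, references=decoded_labels)
    # Keep the F-measure of the "mid" aggregate score for each ROUGE variant, as a percentage
    return {key: value.mid.fmeasure * 100 for key, value in result.items()}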

The above function will return a dict containing values which will be logged like any other Keras metric:

{'rouge1': 37.4199, 'rouge2': 13.9768, 'rougeL': 34.361, 'rougeLsum': 35.0781}


class transformers.PushToHubCallback


( output_dir: typing.Union[str, pathlib.Path], save_strategy: typing.Union[str, transformers.trainer_utils.IntervalStrategy] = 'epoch', save_steps: typing.Optional[int] = None, tokenizer: typing.Optional[transformers.tokenization_utils_base.PreTrainedTokenizerBase] = None, hub_model_id: typing.Optional[str] = None, hub_token: typing.Optional[str] = None, checkpoint: bool = False, **model_card_args )


  • output_dir (str) — The output directory where the model predictions and checkpoints will be written and synced with the repository on the Hub.
  • save_strategy (str or IntervalStrategy, optional, defaults to "epoch") — The checkpoint save strategy to adopt during training. Possible values are:

    • "no": Save is done at the end of training.
    • "epoch": Save is done at the end of each epoch.
    • "steps": Save is done every save_steps training steps.
  • save_steps (int, optional) — The number of steps between saves when save_strategy is "steps".
  • tokenizer (PreTrainedTokenizerBase, optional) — The tokenizer used by the model. If supplied, will be uploaded to the repo alongside the weights.
  • hub_model_id (str, optional) — The name of the repository to keep in sync with the local output_dir. It can be a simple model ID, in which case the model will be pushed to your namespace. Otherwise it should be the whole repository name, for instance "user_name/model", which allows you to push to an organization you are a member of with "organization_name/model".

    Will default to the name of output_dir.

  • hub_token (str, optional) — The token to use to push the model to the Hub. Will default to the token in the cache folder obtained with huggingface-cli login.
  • checkpoint (bool, optional, defaults to False) — Whether to save full training checkpoints (including epoch and optimizer state) to allow training to be resumed. Only usable when save_strategy is "epoch".
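
The naming rules for hub_model_id can be summarized with a hypothetical helper, resolve_repo_id. This is a sketch under assumptions, not the callback's own code: username stands in for the namespace associated with your Hub token, and "the name of output_dir" is taken to mean the last component of the path.

```python
from pathlib import Path
from typing import Optional, Union


def resolve_repo_id(
    output_dir: Union[str, Path], hub_model_id: Optional[str], username: str
) -> str:
    """Hypothetical helper mirroring the hub_model_id rules described above."""
    if hub_model_id is None:
        # Default: use the name of output_dir (assumed to be its last path component)
        hub_model_id = Path(output_dir).name
    if "/" not in hub_model_id:
        # A simple model ID is pushed to the user's own namespace
        hub_model_id = f"{username}/{hub_model_id}"
    # A full "namespace/model" name, such as an organization repo, is used unchanged
    return hub_model_id


print(resolve_repo_id("./checkpoints/my-model", None, "user_name"))  # user_name/my-model
print(resolve_repo_id("./out", "my-model", "user_name"))  # user_name/my-model
print(resolve_repo_id("./out", "organization_name/model", "user_name"))  # organization_name/model
```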

Callback that will save and push the model to the Hub regularly. By default, it pushes once per epoch, but this can be changed with the save_strategy argument. Pushed models can be accessed like any other model on the Hub, for example with the from_pretrained method.

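from transformers.keras_callbacks import PushToHubCallback

# Hypothetical values; tokenizer, model and train_dataset are assumed to exist already
push_to_hub_callback = PushToHubCallback(
    output_dir="./model_save", tokenizer=tokenizer, hub_model_id="my-model"
)

model.fit(train_dataset, callbacks=[push_to_hub_callback])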
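
The three save_strategy values can be illustrated with a small simulation that only lists when a push would happen; it does not save or upload anything. The push_points function is hypothetical. It assumes that save_steps counts training steps globally across epochs, and it models only the schedule described in the parameter list above, not every push the real callback may perform.

```python
from typing import List, Optional


def push_points(
    save_strategy: str,
    num_epochs: int,
    steps_per_epoch: int,
    save_steps: Optional[int] = None,
) -> List[str]:
    """Hypothetical sketch listing when a push would happen for each save_strategy."""
    if save_strategy not in ("no", "epoch", "steps"):
        raise ValueError(f"Unknown save_strategy: {save_strategy!r}")
    if save_strategy == "steps" and not save_steps:
        raise ValueError('save_steps must be set when save_strategy is "steps"')
    events = []
    global_step = 0
    for epoch in range(1, num_epochs + 1):
        for _ in range(steps_per_epoch):
            global_step += 1
            # Assumption: steps are counted globally across epochs
            if save_strategy == "steps" and global_step % save_steps == 0:
                events.append(f"step {global_step}")
        if save_strategy == "epoch":
            events.append(f"end of epoch {epoch}")
    if save_strategy == "no":
        events.append("end of training")
    return events


print(push_points("epoch", num_epochs=2, steps_per_epoch=5))  # ['end of epoch 1', 'end of epoch 2']
print(push_points("steps", num_epochs=2, steps_per_epoch=5, save_steps=4))  # ['step 4', 'step 8']
print(push_points("no", num_epochs=2, steps_per_epoch=5))  # ['end of training']
```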