Transformers documentation

Fine-tune a pretrained model

You are viewing version v4.18.0. A newer version, v4.41.3, is available.

There are significant benefits to using a pretrained model. It reduces computation costs and your carbon footprint, and it lets you use state-of-the-art models without having to train one from scratch. 🤗 Transformers provides access to thousands of pretrained models for a wide range of tasks. When you use a pretrained model, you train it on a dataset specific to your task. This is known as fine-tuning, an incredibly powerful training technique. In this tutorial, you will fine-tune a pretrained model with a deep learning framework of your choice:

  • Fine-tune a pretrained model with 🤗 Transformers Trainer.
  • Fine-tune a pretrained model in TensorFlow with Keras.
  • Fine-tune a pretrained model in native PyTorch.

Prepare a dataset

Before you can fine-tune a pretrained model, download a dataset and prepare it for training. The previous tutorial showed you how to process data for training, and now you get an opportunity to put those skills to the test!

Begin by loading the Yelp Reviews dataset:

>>> from datasets import load_dataset

>>> dataset = load_dataset("yelp_review_full")
>>> dataset["train"][100]
{'label': 0,
 'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. I\'ve worked at more than one location. I expect bad days, bad moods, and the occasional mistake. But I have yet to have a decent experience at this store. It will remain a place I avoid unless someone in my party needs to avoid illness from low blood sugar. Perhaps I should go back to the racially biased service of Steak n Shake instead!'}

As you now know, you need a tokenizer to process the text and include a padding and truncation strategy to handle any variable sequence lengths. To process your dataset in one step, use 🤗 Datasets map method to apply a preprocessing function over the entire dataset:

>>> from transformers import AutoTokenizer

>>> tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

>>> def tokenize_function(examples):
...     return tokenizer(examples["text"], padding="max_length", truncation=True)

>>> tokenized_datasets = dataset.map(tokenize_function, batched=True)
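
To make the two tokenizer options concrete, here is a minimal pure-Python sketch of what padding="max_length" and truncation=True do to a sequence of token ids. The helper pad_and_truncate, the pad id of 0, the example ids, and the tiny max_length of 6 are illustrative assumptions, not the tokenizer's implementation (a real tokenizer also handles special tokens and pads to the model's maximum length).

```python
def pad_and_truncate(token_ids, max_length, pad_id=0):
    """Return (input_ids, attention_mask), both exactly max_length long.

    Hypothetical helper for illustration only.
    """
    # Truncation: drop any tokens beyond max_length.
    ids = token_ids[:max_length]
    # Attention mask: 1 marks a real token, 0 marks padding.
    mask = [1] * len(ids)
    # Padding: append pad_id until the sequence reaches max_length.
    n_pad = max_length - len(ids)
    return ids + [pad_id] * n_pad, mask + [0] * n_pad


# A short sequence is padded, a long one is truncated.
short_ids, short_mask = pad_and_truncate([101, 7592, 102], max_length=6)
long_ids, long_mask = pad_and_truncate(list(range(1, 10)), max_length=6)
```

Every example ends up with the same length, which is what lets the examples be stacked into fixed-shape batches later on.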

If you like, you can create a smaller subset of the full dataset to fine-tune on to reduce the time it takes:

>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
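
The shuffle-then-select pattern can be pictured with plain Python lists. The sketch below is an analogy only: the helper seeded_subset and the row count of 10,000 are hypothetical, and 🤗 Datasets uses its own shuffling, so the rows it picks will differ. The point is that a fixed seed yields the same subset on every run.

```python
import random


def seeded_subset(num_rows, subset_size, seed=42):
    """Pick subset_size distinct row indices, reproducibly (illustrative helper)."""
    indices = list(range(num_rows))
    # A dedicated Random instance keeps the shuffle independent of global state.
    random.Random(seed).shuffle(indices)
    return indices[:subset_size]


first = seeded_subset(num_rows=10_000, subset_size=1000)
second = seeded_subset(num_rows=10_000, subset_size=1000)
```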


🤗 Transformers provides a Trainer class optimized for training 🤗 Transformers models, making it easier to start training without manually writing your own training loop. The Trainer API supports a wide range of training options and features such as logging, gradient accumulation, and mixed precision.

Start by loading your model and specifying the number of expected labels. From the Yelp Review dataset card, you know there are five labels:

>>> from transformers import AutoModelForSequenceClassification

>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

You will see a warning about some of the pretrained weights not being used and some weights being randomly initialized. Don’t worry, this is completely normal! The pretrained head of the BERT model is discarded, and replaced with a randomly initialized classification head. You will fine-tune this new model head on your sequence classification task, transferring the knowledge of the pretrained model to it.

Training hyperparameters

Next, create a TrainingArguments object, which contains all the hyperparameters you can tune as well as flags for activating different training options. For this tutorial you can start with the default training hyperparameters, but feel free to experiment with these to find your optimal settings.

Specify where to save the checkpoints from your training:

>>> from transformers import TrainingArguments

>>> training_args = TrainingArguments(output_dir="test_trainer")


Trainer does not automatically evaluate model performance during training. You need to pass Trainer a function that computes and reports metrics. The 🤗 Datasets library provides a simple accuracy function that you can load with the load_metric function (see this tutorial for more information):

>>> import numpy as np
>>> from datasets import load_metric

>>> metric = load_metric("accuracy")

Call compute on metric to calculate the accuracy of your predictions. Before passing your predictions to compute, you need to convert the logits to predictions (remember all 🤗 Transformers models return logits):

>>> def compute_metrics(eval_pred):
...     logits, labels = eval_pred
...     predictions = np.argmax(logits, axis=-1)
...     return metric.compute(predictions=predictions, references=labels)
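
As a quick sanity check of the argmax step, the following sketch runs the same computation on a hand-made batch. The logits and labels are made-up values, and the accuracy is computed directly with NumPy rather than through load_metric so that the snippet has no other dependencies.

```python
import numpy as np

# Three examples, five classes (made-up logits).
logits = np.array([
    [0.1, 2.3, -1.0, 0.4, 0.0],   # highest score at index 1
    [1.5, 0.2, 0.3, -0.7, 0.9],   # highest score at index 0
    [-0.5, 0.0, 0.1, 0.2, 3.1],   # highest score at index 4
])
labels = np.array([1, 0, 3])

# The predicted class is the index of the largest logit in each row.
predictions = np.argmax(logits, axis=-1)

# Accuracy is the fraction of predictions that match the labels.
accuracy = float((predictions == labels).mean())
```

Here two of the three predictions match their labels, so the accuracy is 2/3.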

If you’d like to monitor your evaluation metrics during fine-tuning, specify the evaluation_strategy parameter in your training arguments to report the evaluation metric at the end of each epoch:

>>> from transformers import TrainingArguments, Trainer

>>> training_args = TrainingArguments(output_dir="test_trainer", evaluation_strategy="epoch")


Create a Trainer object with your model, training arguments, training and test datasets, and evaluation function:

>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     train_dataset=small_train_dataset,
...     eval_dataset=small_eval_dataset,
...     compute_metrics=compute_metrics,
... )

Then fine-tune your model by calling train():

>>> trainer.train()

🤗 Transformers models also support training in TensorFlow with the Keras API.

Convert dataset to TensorFlow format

The DefaultDataCollator assembles tensors into a batch for the model to train on. Make sure you specify return_tensors to return TensorFlow tensors:

>>> from transformers import DefaultDataCollator

>>> data_collator = DefaultDataCollator(return_tensors="tf")

Trainer uses DataCollatorWithPadding by default so you don’t need to explicitly specify a data collator.

Next, convert the tokenized datasets to TensorFlow datasets with the to_tf_dataset method. Specify your inputs in columns, and your label in label_cols:

>>> tf_train_dataset = small_train_dataset.to_tf_dataset(
...     columns=["attention_mask", "input_ids", "token_type_ids"],
...     label_cols=["labels"],
...     shuffle=True,
...     collate_fn=data_collator,
...     batch_size=8,
... )

>>> tf_validation_dataset = small_eval_dataset.to_tf_dataset(
...     columns=["attention_mask", "input_ids", "token_type_ids"],
...     label_cols=["labels"],
...     shuffle=False,
...     collate_fn=data_collator,
...     batch_size=8,
... )

Compile and fit

Load a TensorFlow model with the expected number of labels:

>>> import tensorflow as tf
>>> from transformers import TFAutoModelForSequenceClassification

>>> model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Then compile and fine-tune your model with fit as you would with any other Keras model:

>>> model.compile(
...     optimizer=tf.keras.optimizers.Adam(learning_rate=5e-5),
...     loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
...     metrics=tf.metrics.SparseCategoricalAccuracy(),
... )

>>> model.fit(tf_train_dataset, validation_data=tf_validation_dataset, epochs=3)
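
Because the model returns logits rather than probabilities, the loss is created with from_logits=True. Here is a minimal sketch of the per-example quantity that such a loss computes, written in plain Python for illustration. The function name sparse_ce_from_logits and the example logits are hypothetical, and Keras additionally reduces the per-example values over the batch.

```python
import math


def sparse_ce_from_logits(logits, label):
    """Cross-entropy of one example: -log(softmax(logits)[label])."""
    # Subtract the max logit for numerical stability; the result is unchanged.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[label]


# Equal scores for all five classes: the model is maximally unsure.
uniform_loss = sparse_ce_from_logits([0.0] * 5, label=2)
# A large score on the correct class: the loss is close to zero.
confident_loss = sparse_ce_from_logits([0.0, 0.0, 10.0, 0.0, 0.0], label=2)
```

With five equally scored classes the loss equals ln 5, roughly 1.609, which is a useful reference point for the loss of a freshly initialized classification head.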

Train in native PyTorch

Trainer takes care of the training loop and allows you to fine-tune a model in a single line of code. If you prefer to write your own training loop, you can also fine-tune a 🤗 Transformers model in native PyTorch.

At this point, you may need to restart your notebook or execute the following code to free some memory:

import torch

del model
del trainer
torch.cuda.empty_cache()

Next, manually postprocess tokenized_datasets to prepare it for training.

  1. Remove the text column because the model does not accept raw text as an input:

    >>> tokenized_datasets = tokenized_datasets.remove_columns(["text"])
  2. Rename the label column to labels because the model expects the argument to be named labels:

    >>> tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
  3. Set the format of the dataset to return PyTorch tensors instead of lists:

    >>> tokenized_datasets.set_format("torch")

Then create a smaller subset of the dataset as previously shown to speed up the fine-tuning:

>>> small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
>>> small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))


Create a DataLoader for your training and test datasets so you can iterate over batches of data:

>>> from torch.utils.data import DataLoader

>>> train_dataloader = DataLoader(small_train_dataset, shuffle=True, batch_size=8)
>>> eval_dataloader = DataLoader(small_eval_dataset, batch_size=8)

Load your model with the number of expected labels:

>>> from transformers import AutoModelForSequenceClassification

>>> model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Optimizer and learning rate scheduler

Create an optimizer and learning rate scheduler to fine-tune the model. Let’s use the AdamW optimizer from PyTorch:

>>> from torch.optim import AdamW

>>> optimizer = AdamW(model.parameters(), lr=5e-5)

Create the default learning rate scheduler from Trainer:

>>> from transformers import get_scheduler

>>> num_epochs = 3
>>> num_training_steps = num_epochs * len(train_dataloader)
>>> lr_scheduler = get_scheduler(
...     name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
... )
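
To see what the linear schedule does with these settings, here is a small sketch that recomputes the learning rate by hand. It assumes the 1,000-example training subset and batch size of 8 used above, and the function linear_lr is a hypothetical re-implementation for illustration, not the library code.

```python
import math


def linear_lr(step, base_lr, num_warmup_steps, num_training_steps):
    """Learning rate at a given step: linear warmup, then linear decay to zero."""
    if step < num_warmup_steps:
        return base_lr * step / max(1, num_warmup_steps)
    remaining = num_training_steps - step
    decay_steps = max(1, num_training_steps - num_warmup_steps)
    return base_lr * max(0.0, remaining / decay_steps)


# 1,000 examples in batches of 8 give 125 batches per epoch.
steps_per_epoch = math.ceil(1000 / 8)
# Three epochs give 375 optimizer steps in total.
num_training_steps = 3 * steps_per_epoch

lr_start = linear_lr(0, 5e-5, 0, num_training_steps)
lr_at_150 = linear_lr(150, 5e-5, 0, num_training_steps)
lr_end = linear_lr(num_training_steps, 5e-5, 0, num_training_steps)
```

With no warmup, the learning rate starts at 5e-5, has dropped to 60% of that (3e-5) after 150 of the 375 steps, and reaches zero at the final step.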

Lastly, specify device to use a GPU if you have access to one. Otherwise, training on a CPU may take several hours instead of a couple of minutes.

>>> import torch

>>> device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
>>> model.to(device)

If you don’t have a GPU, you can get free access to a cloud GPU with a hosted notebook like Colaboratory or SageMaker StudioLab.

Great, now you are ready to train! 🥳

Training loop

To keep track of your training progress, use the tqdm library to add a progress bar over the number of training steps:

>>> from tqdm.auto import tqdm

>>> progress_bar = tqdm(range(num_training_steps))

>>> model.train()
>>> for epoch in range(num_epochs):
...     for batch in train_dataloader:
...         # Move every tensor in the batch to the training device.
...         batch = {k: v.to(device) for k, v in batch.items()}
...         outputs = model(**batch)
...         loss = outputs.loss
...         loss.backward()
...
...         optimizer.step()
...         lr_scheduler.step()
...         optimizer.zero_grad()
...         progress_bar.update(1)


Just as you added an evaluation function to Trainer, you need to do the same when you write your own training loop. But instead of calculating and reporting the metric at the end of each epoch, this time you accumulate all the batches with add_batch and calculate the metric at the very end.

>>> from datasets import load_metric

>>> metric = load_metric("accuracy")
>>> model.eval()
>>> for batch in eval_dataloader:
...     batch = {k: v.to(device) for k, v in batch.items()}
...     # Evaluation only needs forward passes, so skip gradient tracking.
...     with torch.no_grad():
...         outputs = model(**batch)
...
...     logits = outputs.logits
...     predictions = torch.argmax(logits, dim=-1)
...     metric.add_batch(predictions=predictions, references=batch["labels"])

>>> metric.compute()
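
The accumulate-then-compute pattern can be sketched without any library. The class RunningAccuracy below is a hypothetical stand-in for the accuracy metric, with made-up predictions. It is shown only to illustrate why add_batch followed by a single compute gives the accuracy over the whole evaluation set rather than an average of per-batch accuracies.

```python
class RunningAccuracy:
    """Illustrative accumulator, not the 🤗 Datasets implementation."""

    def __init__(self):
        self.correct = 0
        self.total = 0

    def add_batch(self, predictions, references):
        # Count matches in this batch and remember how many examples were seen.
        self.correct += sum(int(p == r) for p, r in zip(predictions, references))
        self.total += len(references)

    def compute(self):
        # One division at the end covers every example exactly once.
        return {"accuracy": self.correct / self.total}


running = RunningAccuracy()
running.add_batch(predictions=[1, 0, 4, 2], references=[1, 0, 3, 2])  # 3 of 4 correct
running.add_batch(predictions=[0, 0], references=[1, 0])              # 1 of 2 correct
result = running.compute()
```

The result is 4/6, whereas averaging the two per-batch accuracies (0.75 and 0.5) would give 0.625. The difference matters whenever the last batch is smaller than the others.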

Additional resources

For more fine-tuning examples, refer to: