Image classification
Image classification assigns a label or class to an image. Unlike text or audio classification, the inputs are the pixel values that represent an image. There are many uses for image classification, like detecting damage after a disaster, monitoring crop health, or helping screen medical images for signs of disease.
This guide will show you how to fine-tune ViT on the Food-101 dataset to classify a food item in an image.
See the image classification task page for more information about its associated models, datasets, and metrics.
Load Food-101 dataset
Load only the first 5,000 images of the Food-101 training split from the 🤗 Datasets library, since the full dataset is quite large:
>>> from datasets import load_dataset
>>> food = load_dataset("food101", split="train[:5000]")
Split this dataset into a train and test set:
>>> food = food.train_test_split(test_size=0.2)
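To make the split sizes concrete, here is a minimal pure-Python sketch of a shuffled 80/20 split over 5,000 row indices. It is not how 🤗 Datasets implements train_test_split; the helper name split_indices and the fixed seed are assumptions made for illustration only.

```python
import random


def split_indices(n_rows, test_size=0.2, seed=0):
    """Shuffle row indices and cut them into a train part and a test part."""
    indices = list(range(n_rows))
    # A seeded generator keeps this toy split reproducible.
    random.Random(seed).shuffle(indices)
    n_test = round(n_rows * test_size)
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(5000, test_size=0.2)
print(len(train_idx), len(test_idx))
```

Because the rows are shuffled before they are split, the first training example you see may differ from the one shown in this guide.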
Then take a look at an example:
>>> food["train"][0]
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=512x512 at 0x7F52AFC8AC50>,
'label': 79}
The image field contains a PIL image, and each label is an integer that represents a class. Create a dictionary that maps a label name to an integer and vice versa. The mapping will help the model recover the label name from the label number:
>>> labels = food["train"].features["label"].names
>>> label2id, id2label = dict(), dict()
>>> for i, label in enumerate(labels):
...     label2id[label] = str(i)
...     id2label[str(i)] = label
Now you can convert the label number to a label name for more information:
>>> id2label[str(79)]
'prime_rib'
Each food class (or label) corresponds to a number; 79 indicates a prime rib in the example above.
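The two dictionaries are inverses of each other, which a small self-contained sketch can show. The three-item label list below is a hypothetical stand-in for the full Food-101 list, so the code runs without downloading anything.

```python
# Hypothetical three-class label list standing in for the 101 Food-101 names.
labels = ["apple_pie", "baby_back_ribs", "baklava"]

# Same construction as in the guide: ids are stored as strings.
label2id = {label: str(i) for i, label in enumerate(labels)}
id2label = {str(i): label for i, label in enumerate(labels)}

print(id2label["2"])
print(label2id["baklava"])

# Mapping a name to its id and back returns the original name.
round_trip_ok = all(id2label[label2id[name]] == name for name in labels)
print(round_trip_ok)
```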
Preprocess
Load the ViT feature extractor to process the image into a tensor:
>>> from transformers import AutoFeatureExtractor
>>> feature_extractor = AutoFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
Apply several image transformations to the dataset to make the model more robust against overfitting. Here you’ll use torchvision’s transforms module. Crop a random part of the image, resize it, and normalize it with the image mean and standard deviation:
>>> from torchvision.transforms import RandomResizedCrop, Compose, Normalize, ToTensor
>>> normalize = Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std)
>>> _transforms = Compose([RandomResizedCrop(feature_extractor.size), ToTensor(), normalize])
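The arithmetic behind the last two transforms can be sketched in plain Python: ToTensor scales 8-bit pixel values into the range 0 to 1, and Normalize then computes (value - mean) / std per channel. The mean and standard deviation of 0.5 used below are assumed values for illustration; read the real ones from feature_extractor.image_mean and feature_extractor.image_std.

```python
def to_unit_range(pixel):
    """Scale an 8-bit pixel value (0 to 255) into the range 0.0 to 1.0."""
    return pixel / 255.0


def normalize(value, mean, std):
    """Shift and scale a value the way a per-channel normalization does."""
    return (value - mean) / std


# Assumed statistics for illustration only.
mean, std = 0.5, 0.5

for pixel in (0, 128, 255):
    print(pixel, normalize(to_unit_range(pixel), mean, std))
```

With these assumed statistics, the darkest pixel maps to -1.0 and the brightest to 1.0, so the inputs to the model are centered around zero.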
Create a preprocessing function that will apply the transforms and return the pixel_values (the inputs to the model) of the image:
>>> def transforms(examples):
...     examples["pixel_values"] = [_transforms(img.convert("RGB")) for img in examples["image"]]
...     del examples["image"]
...     return examples
Use the 🤗 Datasets with_transform method to apply the transforms over the entire dataset. The transforms are applied on the fly, each time you load an element of the dataset:
>>> food = food.with_transform(transforms)
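To illustrate what "on the fly" means, here is a toy stand-in for a dataset with a lazy transform. The class LazyTransformDataset and the function double_pixels are hypothetical names, not part of 🤗 Datasets; the sketch only mimics the idea that nothing is computed until a row is read, and that the function receives a batch in column form (a dictionary of lists), which is why the transforms function in this guide loops over examples["image"].

```python
class LazyTransformDataset:
    """Toy dataset that applies a function only when a row is read."""

    def __init__(self, rows, transform):
        self.rows = rows
        self.transform = transform
        self.calls = 0  # counts how often the transform has run

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, index):
        self.calls += 1
        # Wrap the single row as a batch of one, in column form.
        batch = {key: [value] for key, value in self.rows[index].items()}
        transformed = self.transform(batch)
        # Unwrap the batch of one back into a single example.
        return {key: values[0] for key, values in transformed.items()}


def double_pixels(examples):
    """Stand-in for the image transforms: doubles every toy pixel value."""
    examples["pixel_values"] = [[2 * p for p in img] for img in examples["image"]]
    del examples["image"]
    return examples


rows = [{"image": [1, 2, 3], "label": 0}, {"image": [4, 5, 6], "label": 1}]
dataset = LazyTransformDataset(rows, double_pixels)
print(dataset.calls)  # nothing has been transformed yet
print(dataset[0])
print(dataset.calls)
```

Because the transform runs on every access, a random crop produces a different view of the same image each time the row is read. The stored rows themselves are left unchanged.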
Use DefaultDataCollator to create a batch of examples. Unlike other data collators in 🤗 Transformers, the DefaultDataCollator does not apply additional preprocessing such as padding.
>>> from transformers import DefaultDataCollator
>>> data_collator = DefaultDataCollator()
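As a rough picture of what collating without padding involves, the sketch below groups a list of per-example dictionaries into one dictionary of columns and renames label to labels, the key the model expects for computing the loss. It is a simplified stand-in, not the DefaultDataCollator implementation: the function name collate is hypothetical, and the real collator stacks values into tensors rather than keeping Python lists.

```python
def collate(features):
    """Turn a list of per-example dicts into a single dict of columns."""
    batch = {}
    for key in features[0]:
        # Rename the per-example "label" column to "labels".
        out_key = "labels" if key == "label" else key
        batch[out_key] = [example[key] for example in features]
    return batch


# Toy examples: two "images" of identical size, so no padding is needed.
features = [
    {"pixel_values": [0.1, 0.2], "label": 3},
    {"pixel_values": [0.3, 0.4], "label": 7},
]
batch = collate(features)
print(batch)
```

Stacking without padding only works when every example has the same shape, which the fixed-size random crop in the preprocessing step provides.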
Train
Load ViT with AutoModelForImageClassification. Specify the number of labels, and pass the model the mapping between label number and label class:
>>> from transformers import AutoModelForImageClassification, TrainingArguments, Trainer
>>> model = AutoModelForImageClassification.from_pretrained(
...     "google/vit-base-patch16-224-in21k",
...     num_labels=len(labels),
...     id2label=id2label,
...     label2id=label2id,
... )
If you aren’t familiar with fine-tuning a model with the Trainer, take a look at the basic fine-tuning tutorial first!
At this point, only three steps remain:
- Define your training hyperparameters in TrainingArguments. It is important that you don’t remove unused columns, because this would drop the image column. Without the image column, you can’t create pixel_values. Set remove_unused_columns=False to prevent this behavior!
- Pass the training arguments to Trainer along with the model, datasets, feature extractor (passed through the tokenizer argument), and data collator.
- Call train() to fine-tune your model.
>>> training_args = TrainingArguments(
...     output_dir="./results",
...     per_device_train_batch_size=16,
...     evaluation_strategy="steps",
...     num_train_epochs=4,
...     fp16=True,
...     save_steps=100,
...     eval_steps=100,
...     logging_steps=10,
...     learning_rate=2e-4,
...     save_total_limit=2,
...     remove_unused_columns=False,
... )
>>> trainer = Trainer(
...     model=model,
...     args=training_args,
...     data_collator=data_collator,
...     train_dataset=food["train"],
...     eval_dataset=food["test"],
...     tokenizer=feature_extractor,
... )
>>> trainer.train()
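To get a feel for how long this run is, you can work out the step counts implied by the settings above. The sketch below assumes a single device and no gradient accumulation, neither of which is stated in this guide, so treat the numbers as an estimate for that setup only.

```python
import math

# Values taken from this guide.
num_examples = 5000          # images loaded from Food-101
test_size = 0.2              # fraction held out for evaluation
per_device_batch_size = 16
num_train_epochs = 4
eval_steps = 100
save_steps = 100

# Assumptions, not stated in the guide.
num_devices = 1
gradient_accumulation_steps = 1

train_examples = num_examples - round(num_examples * test_size)
effective_batch = per_device_batch_size * num_devices * gradient_accumulation_steps
steps_per_epoch = math.ceil(train_examples / effective_batch)
total_steps = steps_per_epoch * num_train_epochs

print("training examples:", train_examples)
print("steps per epoch:", steps_per_epoch)
print("total steps:", total_steps)
print("evaluations:", total_steps // eval_steps)
print("checkpoints written:", total_steps // save_steps)
```

Under these assumptions, the model is evaluated and checkpointed every 100 steps, and save_total_limit=2 keeps only the two most recent checkpoints on disk.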
For a more in-depth example of how to fine-tune a model for image classification, take a look at the corresponding PyTorch notebook.