Image Classification with Vision Transformer (ViT)
This repository contains a Python script for training an image classification model using the Vision Transformer (ViT) architecture. We use the transformers and datasets libraries from Hugging Face along with PyTorch and TensorFlow for the implementation.
Functions and Usage
convert_to_tf_tensor(image: Image):
This function converts an image to a TensorFlow tensor resized to 224x224 pixels with three color channels.
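A minimal sketch of what such a conversion could look like is given here. It is not taken from the script itself: it assumes the input is a PIL image, that the RGB conversion is what yields the three channels, and that tf.image.resize with its default bilinear method does the resizing.

```python
import numpy as np
import tensorflow as tf
from PIL import Image


def convert_to_tf_tensor(image: Image.Image) -> tf.Tensor:
    # Assumption: force three channels so grayscale or RGBA inputs also work.
    rgb = image.convert("RGB")
    # PIL image -> NumPy array of shape (height, width, 3), dtype uint8.
    array = np.asarray(rgb, dtype=np.uint8)
    tensor = tf.convert_to_tensor(array)
    # Bilinear resize (the default) returns a float32 tensor of shape (224, 224, 3).
    return tf.image.resize(tensor, (224, 224))


# Usage with a synthetic 640x480 image (illustrative input only).
example = Image.new("RGB", (640, 480), color=(255, 0, 0))
tensor = convert_to_tf_tensor(example)
print(tensor.shape, tensor.dtype)
```

Pixel values stay in the 0 to 255 range in this sketch; any rescaling or normalization is left to the feature extractor.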
preprocess(batch):
Preprocesses the images in a batch, using the feature extractor to convert them to pixel values. It also adds the labels to the batch.
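The step described above might be sketched as follows. Several things here are assumptions rather than details from the script: the column names "image" and "label" are hypothetical, ViTImageProcessor stands in for the feature extractor, and it is built with its default settings instead of being loaded from a checkpoint with from_pretrained.

```python
from PIL import Image
from transformers import ViTImageProcessor

# Assumption: default settings (resize to 224x224, rescale, normalize).
# A training script would more likely load the processor that matches
# its checkpoint via ViTImageProcessor.from_pretrained(...).
processor = ViTImageProcessor()


def preprocess(batch):
    # Convert the list of PIL images into a "pixel_values" array.
    inputs = processor(batch["image"], return_tensors="np")
    # Carry the labels along so the model receives them during training.
    inputs["labels"] = batch["label"]
    return inputs


# Usage with two synthetic images of different sizes (illustrative only).
batch = {
    "image": [Image.new("RGB", (320, 240)), Image.new("RGB", (100, 100))],
    "label": [0, 1],
}
out = preprocess(batch)
print(out["pixel_values"].shape)
```

With the datasets library, a function of this shape is typically applied lazily through the dataset's with_transform method, so images are processed only when a batch is requested.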
collate_fn(batch):
This function prepares the batch for training or evaluation. It stacks the pixel values and labels.
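As an illustration, a collate function of this kind could be written with PyTorch as shown here. The dictionary keys "pixel_values" and "labels" are assumed, as is the idea that each example holds one image tensor and one integer label.

```python
import torch


def collate_fn(batch):
    # batch is a list of per-example dicts (assumed keys, see above).
    return {
        # Stack N tensors of shape (3, 224, 224) into one (N, 3, 224, 224) tensor.
        "pixel_values": torch.stack([example["pixel_values"] for example in batch]),
        # Gather the integer labels into a 1-D tensor of length N.
        "labels": torch.tensor([example["labels"] for example in batch]),
    }


# Usage with two dummy examples (illustrative only).
examples = [
    {"pixel_values": torch.zeros(3, 224, 224), "labels": 0},
    {"pixel_values": torch.ones(3, 224, 224), "labels": 1},
]
collated = collate_fn(examples)
print(collated["pixel_values"].shape, collated["labels"])
```

A function like this can be passed as the collate_fn argument of a torch.utils.data.DataLoader or as data_collator to the Hugging Face Trainer.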
compute_metrics(p):
Computes the evaluation metric (accuracy) from the model's predictions and the reference labels.
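A small sketch of an accuracy computation, using NumPy only, follows. It assumes that p exposes predictions (logits) and label_ids attributes, as the Trainer's EvalPrediction object does. The Prediction namedtuple is a hypothetical stand-in so the example runs without a trained model.

```python
from collections import namedtuple

import numpy as np

# Hypothetical stand-in for the object the Trainer passes to compute_metrics.
Prediction = namedtuple("Prediction", ["predictions", "label_ids"])


def compute_metrics(p):
    # The predicted class is the index of the largest logit in each row.
    predicted = np.argmax(p.predictions, axis=1)
    # Accuracy is the fraction of predictions that match the labels.
    accuracy = float(np.mean(predicted == p.label_ids))
    return {"accuracy": accuracy}


# Usage: four examples, two classes; three of the four predictions are correct.
logits = np.array([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.8]])
labels = np.array([0, 1, 1, 1])
result = compute_metrics(Prediction(predictions=logits, label_ids=labels))
print(result)  # {'accuracy': 0.75}
```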