license: mit
metrics:
- accuracy
- f1
pipeline_tag: image-classification
widget:
- src: >-
    https://upload.wikimedia.org/wikipedia/commons/thumb/f/fb/Welchcorgipembroke.JPG/1200px-Welchcorgipembroke.JPG
  example_title: Pembroke Corgi
- src: >-
    https://upload.wikimedia.org/wikipedia/commons/d/df/Shihtzu_%28cropped%29.jpg
  example_title: Shih Tzu
- src: https://upload.wikimedia.org/wikipedia/commons/5/55/Beagle_600.jpg
  example_title: Beagle
Model Motivation
Recently, someone asked me whether you can classify dog images by breed instead of just telling cats from dogs, as in my last notebook. I say YES!
Due to the complexity of the problem, we will be using the Vision Transformer (ViT), a state-of-the-art computer vision architecture introduced in a 2020 Google paper.
The difference between the Vision Transformer and the traditional Convolutional Neural Network (CNN) is how each treats an image. A Vision Transformer splits the image into fixed-size patches, say 16 x 16 pixels, and feeds them into the Transformer as a sequence with positional embeddings and self-attention. A CNN works on the same image but uses convolution and pooling layers, which act as inductive biases. This means the Vision Transformer can use its own judgement to attend to any patch of the image globally through its self-attention mechanism, without us having to guide the network, as we might for a CNN, by centering, cropping, or drawing bounding boxes on our images to help its convolutions.
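To make the patch idea concrete, here is a minimal NumPy sketch of splitting a 224x224 image into 16x16 patches and running one round of self-attention over them. It is an illustration only, not the model's actual code: the helper names `patchify` and `self_attention` are made up for this example, the weights and position embeddings are random rather than learned, and a real ViT also applies a learned linear projection to each patch, prepends a classification token, and uses many multi-head layers.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches."""
    height, width, channels = image.shape
    assert height % patch_size == 0 and width % patch_size == 0
    rows, cols = height // patch_size, width // patch_size
    # (rows, patch, cols, patch, C) -> (rows, cols, patch, patch, C)
    patches = image.reshape(rows, patch_size, cols, patch_size, channels)
    patches = patches.transpose(0, 2, 1, 3, 4)
    # One row per patch, each flattened to patch_size * patch_size * C values.
    return patches.reshape(rows * cols, patch_size * patch_size * channels)


def self_attention(tokens: np.ndarray, rng: np.random.Generator):
    """Single-head scaled dot-product attention with random (untrained) weights."""
    dim = tokens.shape[1]
    w_q, w_k, w_v = (rng.normal(scale=dim ** -0.5, size=(dim, dim)) for _ in range(3))
    q, k, v = tokens @ w_q, tokens @ w_k, tokens @ w_v
    scores = q @ k.T / np.sqrt(dim)
    scores -= scores.max(axis=1, keepdims=True)  # for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=1, keepdims=True)  # softmax over all patches
    return weights @ v, weights


rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))  # stand-in for a real dog photo
tokens = patchify(image)  # 14 x 14 = 196 patches, 16 * 16 * 3 = 768 values each

# Position embeddings are learned in a real ViT; random values stand in here.
position_embeddings = rng.normal(scale=0.02, size=tokens.shape)
outputs, weights = self_attention(tokens + position_embeddings, rng)

print(tokens.shape)   # (196, 768)
print(weights.shape)  # (196, 196)
```

Each row of `weights` holds one patch's attention over all 196 patches, which is the "global" view described above: nothing restricts a patch to its local neighborhood the way a convolution kernel does.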
This makes the Vision Transformer architecture more flexible and scalable, allowing us to build foundation models in computer vision, similar to NLP foundation models like BERT and GPT. Pre-trained (self-supervised or supervised) on massive amounts of image data, these models generalize to different computer vision tasks such as image classification, recognition, and segmentation. This cross-pollination helps us move closer towards the goal of Artificial General Intelligence.
One thing about Vision Transformers is that they have weaker inductive biases than Convolutional Neural Networks, which is what enables their scalability and flexibility. This feature (or bug, depending on who you ask) means that most well-performing pre-trained models need more data, despite having fewer parameters than their CNN counterparts.
Luckily, for this model, we will use a Vision Transformer from Google, hosted on Hugging Face and pre-trained on the ImageNet-21k dataset (14 million images, 21k classes) with 16x16 patches at 224x224 resolution, to bypass that data limitation. We will fine-tune this model on our "small" dog breeds dataset of around 20 thousand images, the Stanford Dogs dataset imported into Kaggle by Jessica Li, to classify dog images into 120 dog breeds!
Model Description
This model is the Google Vision Transformer (vit-base-patch16-224-in21k) fine-tuned on the Stanford Dogs dataset from Kaggle to classify dog images into 120 dog breeds.
Intended Uses & Limitations
You can use this fine-tuned model to classify dog images into 120 dog breeds, limited to the breeds present in the dataset.
How to Use
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
# Download an example image
url = "https://upload.wikimedia.org/wikipedia/commons/8/8b/Husky_L.jpg"
image = Image.open(requests.get(url, stream=True).raw)
# Load the image processor and the fine-tuned model from the Hugging Face Hub
image_processor = AutoImageProcessor.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
model = AutoModelForImageClassification.from_pretrained("wesleyacheng/dog-breeds-multiclass-image-classification-with-vit")
# Preprocess the image and run it through the model
inputs = image_processor(images=image, return_tensors="pt")
outputs = model(**inputs)
logits = outputs.logits
# model predicts one of the 120 Stanford Dogs breed classes
predicted_class_idx = logits.argmax(-1).item()
print("Predicted class:", model.config.id2label[predicted_class_idx])
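If you would like more than the single best guess, the logits can be turned into probabilities and ranked. The sketch below is one possible way to do that with NumPy. The helper `top_k_predictions`, the logits, and the four-breed label map are all made up for illustration and are not output from this model.

```python
import numpy as np


def top_k_predictions(logits, id2label, k=3):
    """Return the k most likely (label, probability) pairs for one image."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max()  # for numerical stability
    probabilities = np.exp(shifted) / np.exp(shifted).sum()  # softmax
    top_indices = np.argsort(probabilities)[::-1][:k]
    return [(id2label[int(i)], float(probabilities[i])) for i in top_indices]


# Hypothetical logits and labels, NOT produced by the real model.
example_logits = [2.0, 0.5, 4.0, 1.0]
example_id2label = {0: "beagle", 1: "pug", 2: "siberian_husky", 3: "shih_tzu"}

for label, probability in top_k_predictions(example_logits, example_id2label):
    print(f"{label}: {probability:.3f}")
```

With the real model, you could pass `outputs.logits[0].detach().numpy()` and `model.config.id2label` in place of the toy values.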
Model Training Metrics
Epoch | Top-1 Accuracy | Top-3 Accuracy | Top-5 Accuracy | Macro F1 |
---|---|---|---|---|
1 | 79.8% | 95.1% | 97.5% | 77.2% |
2 | 83.8% | 96.7% | 98.2% | 81.9% |
3 | 84.8% | 96.7% | 98.3% | 83.4% |
Model Evaluation Metrics
Top-1 Accuracy | Top-3 Accuracy | Top-5 Accuracy | Macro F1 |
---|---|---|---|
84.0% | 97.1% | 98.7% | 83.0% |
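For readers unfamiliar with the metrics in these tables, here is a small sketch of how top-k accuracy and macro F1 can be computed. It is not the evaluation code used for this model. The function names and the toy three-class data are assumptions for illustration, and classes with no true or predicted samples are scored as 0 here, which is one common convention.

```python
import numpy as np


def top_k_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest logits."""
    logits = np.asarray(logits)
    labels = np.asarray(labels)
    top_k = np.argsort(logits, axis=1)[:, -k:]  # indices of the k largest logits
    hits = (top_k == labels[:, None]).any(axis=1)
    return float(hits.mean())


def macro_f1(predictions, labels, num_classes):
    """Unweighted mean of per-class F1 scores."""
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)
    scores = []
    for c in range(num_classes):
        tp = np.sum((predictions == c) & (labels == c))
        fp = np.sum((predictions == c) & (labels != c))
        fn = np.sum((predictions != c) & (labels == c))
        denominator = 2 * tp + fp + fn
        # Assumption: a class with no true or predicted samples scores 0.
        scores.append(2 * tp / denominator if denominator else 0.0)
    return float(np.mean(scores))


# Toy data: 4 samples, 3 classes (hypothetical, not from the dog dataset).
toy_logits = [[3.0, 1.0, 0.2],
              [0.1, 2.0, 1.5],
              [1.2, 0.3, 1.0],
              [0.5, 0.4, 2.0]]
toy_labels = [0, 2, 2, 2]
toy_predictions = np.argmax(toy_logits, axis=1)

print("Top-1 accuracy:", top_k_accuracy(toy_logits, toy_labels, k=1))
print("Top-2 accuracy:", top_k_accuracy(toy_logits, toy_labels, k=2))
print("Macro F1:", round(macro_f1(toy_predictions, toy_labels, num_classes=3), 4))
```

Macro F1 averages the per-class scores without weighting by class size, so a breed with few test images counts as much as a common one. That is why it can sit a little below top-1 accuracy, as it does in the tables above.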