Perceiver IO image classifier (MNIST)
This model is a small Perceiver IO image classifier (907K parameters) trained from scratch on the MNIST dataset. It is a training example of the perceiver-io library.
Model description
Like krasserm/perceiver-io-img-clf, this model uses 2D Fourier features for position encoding and cross-attends to individual pixels of an input image. Unlike that model, it uses repeated cross-attention, a configuration that was described in the original Perceiver paper but dropped in the follow-up Perceiver IO paper (see building blocks for more details).
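To make the position encoding more concrete, here is a minimal pure-Python sketch of 2D Fourier features in the style of the original Perceiver paper. It is an illustration only and not the perceiver-io implementation: the function name `fourier_position_encoding` is hypothetical, and `num_bands=4` is an assumed value, since the number of frequency bands used by this model is not stated here.

```python
import math


def fourier_position_encoding(height, width, num_bands):
    """Return one feature vector per pixel, in row-major order.

    Each vector holds the raw (y, x) position scaled to [-1, 1], followed by
    the sin and cos of that position at `num_bands` frequencies per axis.
    Hypothetical helper for illustration, not part of perceiver-io.
    """

    def axis_coords(n):
        # n evenly spaced coordinates covering [-1, 1]
        return [0.0] if n == 1 else [-1.0 + 2.0 * i / (n - 1) for i in range(n)]

    def band_freqs(n):
        # frequencies spaced linearly from 1 up to the Nyquist frequency n / 2
        if num_bands == 1:
            return [1.0]
        return [1.0 + (n / 2.0 - 1.0) * k / (num_bands - 1) for k in range(num_bands)]

    ys, xs = axis_coords(height), axis_coords(width)
    fy, fx = band_freqs(height), band_freqs(width)

    features = []
    for y in ys:
        for x in xs:
            row = [y, x]
            for pos, freqs in ((y, fy), (x, fx)):
                row += [math.sin(math.pi * f * pos) for f in freqs]
                row += [math.cos(math.pi * f * pos) for f in freqs]
            features.append(row)
    return features


# num_bands=4 is an assumed value chosen for illustration
encoding = fourier_position_encoding(height=28, width=28, num_bands=4)
print(len(encoding), len(encoding[0]))  # 784 pixels, 2 + 2 * 2 * 4 = 18 features each
```

In a Perceiver-style model, a vector like this is typically concatenated with each pixel's value before the latent array cross-attends to the resulting sequence.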
Model training
The model was trained with randomly initialized weights on the MNIST handwritten digits dataset. Images were normalized and data augmentation was turned off. Training was done with PyTorch Lightning, and the resulting checkpoint was converted to this 🤗 model with a library-specific conversion utility.
Intended use and limitations
The model can be used for MNIST handwritten digit classification.
Usage examples
To use this model, first install the perceiver-io library with the vision extension:
pip install "perceiver-io[vision]"
Then the model can be used with PyTorch. Either use the model and image processor directly:
from datasets import load_dataset
from transformers import AutoModelForImageClassification, AutoImageProcessor
from perceiver.model.vision import image_classifier # auto-class registration
repo_id = "krasserm/perceiver-io-img-clf-mnist"
mnist_dataset = load_dataset("mnist", split="test")[:9]
images = mnist_dataset["image"]
labels = mnist_dataset["label"]
model = AutoModelForImageClassification.from_pretrained(repo_id)
processor = AutoImageProcessor.from_pretrained(repo_id)
inputs = processor(images, return_tensors="pt")
logits = model(**inputs).logits
print(f"Labels: {labels}")
print(f"Predictions: {logits.argmax(dim=-1).numpy().tolist()}")
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
or use an image-classification pipeline:
from datasets import load_dataset
from transformers import pipeline
from perceiver.model.vision import image_classifier # auto-class registration
repo_id = "krasserm/perceiver-io-img-clf-mnist"
mnist_dataset = load_dataset("mnist", split="test")[:9]
images = mnist_dataset["image"]
labels = mnist_dataset["label"]
classifier = pipeline("image-classification", model=repo_id)
predictions = [pred[0]["label"] for pred in classifier(images)]
print(f"Labels: {labels}")
print(f"Predictions: {predictions}")
Labels: [7, 2, 1, 0, 4, 1, 4, 9, 5]
Predictions: [7, 2, 1, 0, 4, 1, 4, 9, 5]
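When evaluating more than a handful of images, it can help to reduce the two printed lists to a single number. The following is a minimal sketch with a hypothetical `accuracy` helper (not part of perceiver-io or transformers), applied to the nine labels and predictions printed above.

```python
def accuracy(labels, predictions):
    """Return the fraction of predictions that match their label.

    Hypothetical helper for illustration only.
    """
    if len(labels) != len(predictions):
        raise ValueError("labels and predictions must have the same length")
    if not labels:
        raise ValueError("cannot compute accuracy of an empty list")
    # Compare as strings in case one side holds class names and the other integers.
    matches = sum(str(label) == str(pred) for label, pred in zip(labels, predictions))
    return matches / len(labels)


# Values copied from the example output above
labels = [7, 2, 1, 0, 4, 1, 4, 9, 5]
predictions = [7, 2, 1, 0, 4, 1, 4, 9, 5]
print(f"Accuracy: {accuracy(labels, predictions):.2f}")
```

Nine samples are far too few to characterize the model; the same helper can be applied to the full test split.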
Checkpoint conversion
The krasserm/perceiver-io-img-clf-mnist model was created from a training checkpoint with:
from perceiver.model.vision.image_classifier import convert_mnist_classifier_checkpoint
convert_mnist_classifier_checkpoint(
    save_dir="krasserm/perceiver-io-img-clf-mnist",
    ckpt_url="https://martin-krasser.com/perceiver/logs-0.8.0/img_clf/version_0/checkpoints/epoch=025-val_loss=0.065.ckpt",
    push_to_hub=True,
)