
A lightweight Vision Transformer (ViT) model for image classification on small 32×32 images.

This is a custom Vision Transformer (ViT) model trained for image classification. It works well with low-resolution (32×32) images.
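To make the 32×32 setting concrete: a ViT does not slide filters over the whole image. Instead it cuts the image into fixed-size square patches, flattens each one, and treats the resulting vectors as a token sequence. The card does not state this model's patch size, so the sketch below assumes a hypothetical 4×4 patch on a 3-channel image, which gives (32/4)² = 64 patches of length 4·4·3 = 48. It is a plain-Python illustration of patch extraction, not this model's code. If the checkpoint uses a standard `ViTConfig`, the real values should be readable from `model.config.image_size` and `model.config.patch_size`.

```python
from typing import List

# An image as nested lists: image[row][col] is a list of channel values (H x W x C).
Image = List[List[List[float]]]


def patchify(image: Image, patch_size: int) -> List[List[float]]:
    """Split an H x W x C image into non-overlapping patch_size x patch_size
    patches, each flattened row by row into one vector of length P*P*C.

    Patches are returned in raster order (left to right, top to bottom),
    the order a standard ViT uses for its patch sequence.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")

    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patch = []
            for row in range(top, top + patch_size):
                for col in range(left, left + patch_size):
                    patch.extend(image[row][col])
            patches.append(patch)
    return patches


# Hypothetical configuration: a 32x32 RGB image split into 4x4 patches.
IMAGE_SIZE, PATCH_SIZE, CHANNELS = 32, 4, 3
dummy = [[[0.0] * CHANNELS for _ in range(IMAGE_SIZE)] for _ in range(IMAGE_SIZE)]
patches = patchify(dummy, PATCH_SIZE)
print(len(patches), len(patches[0]))  # 64 patches, each of length 4*4*3 = 48
```

A standard ViT then projects each patch vector to its hidden size and, in the common design, prepends a [CLS] token, so this hypothetical configuration would yield 65 positions. Short sequences like this are one reason small images keep a ViT cheap, since self-attention cost grows quadratically with the number of tokens.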


Test Images & Model Predictions

Here are the test images included in the repository along with the model's predicted output for each:

test_image_1.png
Prediction: Man

test_image_2.png
Prediction: Cat

test_image_3.png
Prediction: Bird

test_image_4.png
Prediction: Man

test_image_5.png
(No prediction provided)

test_image_6.png
Prediction: Dog

test_image_7.png
Prediction: Cow

test_image_8.png
Prediction: Automobile

test_image_9.png
Prediction: Car

test_image_10.png
Prediction: Cat

Summary of Predictions

| Image File | Predicted Label |
|---|---|
| test_image_1.png | Man |
| test_image_2.png | Cat |
| test_image_3.png | Bird |
| test_image_4.png | Man |
| test_image_6.png | Dog |
| test_image_7.png | Cow |
| test_image_8.png | Automobile |
| test_image_9.png | Car |
| test_image_10.png | Cat |

Quick Start – Inference Code

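```python
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
import torch
import os

# The repository is gated: accept its access conditions on the Hub and log in
# (for example with `huggingface-cli login`) before downloading.
model_name = "sKT-AI-LABS/SKT-VIT"
processor = AutoImageProcessor.from_pretrained(model_name)
model = AutoModelForImageClassification.from_pretrained(model_name)

# Classify every image in a local folder (for example, the repository's test images).
test_dir = "test_images"
for img_file in sorted(os.listdir(test_dir)):
    if img_file.lower().endswith((".png", ".jpg", ".jpeg")):
        # Convert to RGB so grayscale or RGBA files become 3-channel input.
        image = Image.open(os.path.join(test_dir, img_file)).convert("RGB")
        # The processor applies the preprocessing saved with the checkpoint.
        inputs = processor(images=image, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
        # Highest-scoring class index; fall back to "Class N" if id2label lacks it.
        predicted_idx = outputs.logits.argmax(-1).item()
        label = model.config.id2label.get(predicted_idx, f"Class {predicted_idx}")

        print(f"{img_file:20} → {label}")
```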
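The loop above reports only the single highest-scoring class. A minimal sketch of turning one row of logits into probabilities and a top-k list is shown below in plain Python; the logits and label map are made up for illustration and do not come from this model. With the real model you could pass `outputs.logits[0].tolist()` and `model.config.id2label`, or stay in PyTorch with `outputs.logits.softmax(-1).topk(3)`.

```python
import math
from typing import Dict, List, Tuple


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_k(logits: List[float], id2label: Dict[int, str], k: int = 3) -> List[Tuple[str, float]]:
    """Return the k most probable (label, probability) pairs, highest first.

    Indices missing from id2label fall back to "Class {idx}", mirroring the
    id2label.get() fallback in the inference loop.
    """
    probs = softmax(logits)
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[:k]
    return [(id2label.get(i, f"Class {i}"), probs[i]) for i in ranked]


# Hypothetical logits and label map, for illustration only.
example_logits = [2.0, 0.5, 1.0, -1.0]
example_labels = {0: "cat", 1: "dog", 2: "bird"}  # index 3 deliberately missing
for label, p in top_k(example_logits, example_labels, k=3):
    print(f"{label:10} {p:.3f}")
```

Subtracting the maximum logit leaves the probabilities unchanged but keeps `math.exp` from overflowing on large logits.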
Model Details

Weights: Safetensors format, 3.2M parameters, tensor type F32.
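As a rough check on what the model size means in memory: F32 stores each parameter in 4 bytes, so about 3.2M parameters come to roughly 12.8 MB of raw weights. The helper below just does that multiplication. Because "3.2M" is a rounded figure and a Safetensors file also carries a small JSON header, treat the result as an estimate. The exact count for the loaded model can be obtained with `sum(p.numel() for p in model.parameters())`.

```python
# Bytes per element for common Safetensors dtype codes.
BYTES_PER_ELEMENT = {"F64": 8, "F32": 4, "F16": 2, "BF16": 2, "I8": 1}


def weight_bytes(num_params: int, dtype: str) -> int:
    """Raw size of the weights alone (no optimizer state, activations, or file header)."""
    return num_params * BYTES_PER_ELEMENT[dtype]


# Figures from the model card: about 3.2M parameters stored as F32.
size = weight_bytes(3_200_000, "F32")
print(f"{size / 1e6:.1f} MB")  # 12.8 MB
```

Casting the same weights to F16 or BF16 would halve this to about 6.4 MB, at some cost in numerical precision.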