# Hugging Face Spaces page header (status: Running)
import streamlit as st
import torch
from PIL import Image
from transformers import ViTForImageClassification, ViTImageProcessor

# Hugging Face Hub checkpoint: a ViT image classifier fine-tuned for CIFAR-10.
model_name = "aaraki/vit-base-patch16-224-in21k-finetuned-cifar10"


@st.cache_resource
def _load_model(name):
    """Load the image processor and classification model for checkpoint ``name``.

    Streamlit re-executes this whole script on every widget interaction, so
    loading at module level would reload the weights on every rerun.
    ``st.cache_resource`` keeps one shared instance per server process instead.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = ViTImageProcessor.from_pretrained(name)
    classifier = ViTForImageClassification.from_pretrained(name)
    classifier.eval()  # inference only: make sure dropout etc. are disabled
    return processor, classifier


# Module-level names kept for the rest of the script.
feature_extractor, model = _load_model(model_name)

# CIFAR-10 class names in the dataset's conventional label order.
# NOTE(review): assumed to match the model's output indices — compare with
# model.config.id2label if predictions look mislabeled.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# Streamlit app
st.title("CIFAR-10 Image Classification with Pre-trained Vision Transformer")

# Prediction on uploaded image
st.subheader("Make Predictions")
uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    try:
        # Force 3-channel RGB so grayscale, palette, and RGBA (PNG) uploads
        # all match what the processor expects.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files it
        # cannot identify, and OSError for truncated/corrupt image data.
        st.error("Could not read the uploaded file as an image.")
    else:
        st.image(image, caption='Uploaded Image', use_column_width=True)

        if st.button("Predict"):
            with st.spinner("Classifying..."):
                # Preprocess only when a prediction is requested: Streamlit
                # reruns this script on every interaction.
                inputs = feature_extractor(images=image, return_tensors="pt")
                # Inference only — skip building the autograd graph.
                with torch.no_grad():
                    logits = model(**inputs).logits
                predicted_class_idx = logits.argmax(-1).item()

            # Guard against a model whose output size differs from class_names.
            if 0 <= predicted_class_idx < len(class_names):
                st.write(f"Predicted Class: {predicted_class_idx} ({class_names[predicted_class_idx]})")
            else:
                st.error("Prediction index out of range.")