# aryadytm's picture
# Add application file
# 1716a7d
import streamlit as st
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image
import torch
import torch.nn.functional as F
# Load the feature extractor and model.
# Checkpoint whose preprocessing configuration is used for the feature extractor.
model_name_or_path = 'google/vit-base-patch16-224-in21k'
# Path the classifier itself is loaded from.
_CLASSIFIER_PATH = "./aryadytm-vit-vehicle-classifier"


@st.cache_resource
def _load_vit_components():
    """Load the ViT feature extractor and classifier, cached across reruns.

    Streamlit re-executes this script from the top on every user
    interaction (for example each file upload). Without caching, both
    ``from_pretrained`` calls would run again on every rerun;
    ``st.cache_resource`` keeps a single shared instance instead.

    Returns:
        A ``(feature_extractor, model)`` tuple.
    """
    # NOTE(review): the extractor config comes from a different checkpoint
    # than the classifier -- confirm both use the same preprocessing.
    extractor = ViTFeatureExtractor.from_pretrained(model_name_or_path)
    classifier = ViTForImageClassification.from_pretrained(_CLASSIFIER_PATH)
    return extractor, classifier


feature_extractor, model = _load_vit_components()
def predict_image(image):
    """Classify one image with the module-level ViT model.

    Args:
        image: Image handed to ``feature_extractor`` via ``images=``.
            The caller in this file passes an RGB PIL image.

    Returns:
        A ``(label, confidence)`` tuple: the ``id2label`` entry for the
        highest-probability class and that softmax probability as a float.
    """
    encoded = feature_extractor(images=image, return_tensors="pt")

    # Inference only: skip gradient tracking for the forward pass.
    with torch.no_grad():
        class_logits = model(**encoded).logits

    class_probs = F.softmax(class_logits, dim=-1)
    # ``.item()`` requires a single element, i.e. a batch of one image.
    best_class_id = class_probs.argmax(-1).item()
    return model.config.id2label[best_class_id], class_probs.max().item()
# Streamlit UI: static introduction, then an uploader that triggers prediction.
st.markdown("## Vehicle Image Classification")
st.image('./assets/spotlight.png')
# The blank line before the description paragraph keeps Markdown from
# rendering it as a continuation of the last team-member bullet.
st.markdown("""
### Group 2
- Arya Adyatma - 2501985836
- Aldre Muhammad Keyzar - 2502006543
- Devin Eldrian Wijaya - 2501961363
- Rollando Marcellino Himmel Madison - 2502006575

This app lets you classify vehicle images using a pre-trained ViT model. You need to upload your own image.
- Kaggle dataset: https://www.kaggle.com/code/rydytm/vehicle-classification/edit.
- Colab Notebook: https://colab.research.google.com/drive/1El7RhY69KvE9Nj9vAxUPGjg42NwuNcPu?usp=sharing
""")
st.image('./assets/vit.png')
st.markdown("### Upload Your Image Here")
uploaded_file = st.file_uploader("Choose an image...", type=['png', 'jpg', 'jpeg'])

if uploaded_file is not None:
    try:
        # Force three channels and a fixed display size before showing
        # the image and passing it to the classifier.
        image = (
            Image.open(uploaded_file)
            .convert("RGB")
            .resize((512, 512))
        )
    except OSError:
        # Pillow signals unreadable or truncated image data with OSError
        # (UnidentifiedImageError is a subclass). The uploader filters by
        # file extension only, so a bad file can still reach this point.
        st.error("The uploaded file could not be read as an image. Please upload a valid PNG or JPEG file.")
    else:
        st.image(image, caption='Uploaded Image', use_column_width=True)
        st.write("")

        predicted_label, confidence = predict_image(image)

        st.write("### Prediction Result")
        st.write(f"Predicted label: **{predicted_label}**")
        st.write(f"Confidence: **{confidence:.2f}**")