classifyStream / app.py
mateoluksenberg's picture
Update app.py
9119567 verified
raw
history blame
No virus
994 Bytes
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image # Import the Image module
import torch # Import the torch module
import streamlit as st
MODEL_ID = "mateoluksenberg/dit-base-Classifier_CM05"


@st.cache_resource
def _load_model():
    """Load the image processor and classifier once and reuse them.

    Streamlit re-executes this whole script on every user interaction;
    caching the resource avoids re-loading the model weights each time.

    Returns:
        A ``(processor, model)`` tuple for ``MODEL_ID``.
    """
    processor = AutoImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    return processor, model


def _classify(image, processor, model):
    """Return the label name that *model* assigns to the PIL *image*."""
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits
    # A single image was passed, so argmax over the class axis yields one
    # index; id2label maps it to the label names stored in the model config
    # (this fine-tuned checkpoint's own classes, not ImageNet's).
    predicted_label = logits.argmax(-1).item()
    return model.config.id2label[predicted_label]


st.title("Image Classification")
uploaded_file = st.file_uploader(
    "Choose an image file", type=["jpg", "jpeg", "png"]
)

# file_uploader returns None until the user has uploaded something (including
# on the very first page load), so only classify once a file is present.
if uploaded_file is not None:
    try:
        # PNGs may be RGBA, paletted or grayscale; normalize to 3-channel RGB
        # so the processor's per-channel normalization applies cleanly.
        # convert() also forces a full decode, surfacing truncated files here.
        image = Image.open(uploaded_file).convert("RGB")
    except OSError:  # also covers PIL.UnidentifiedImageError (an OSError)
        st.error("The uploaded file could not be read as an image.")
    else:
        processor, model = _load_model()
        # Show the result in the page; print() would only reach the server log.
        st.write(f"Predicted class: {_classify(image, processor, model)}")