File size: 994 Bytes
9119567
 
 
 
218a3ef
 
9119567
218a3ef
9119567
218a3ef
9119567
218a3ef
 
9119567
 
218a3ef
9119567
218a3ef
9119567
218a3ef
9119567
 
218a3ef
9119567
 
1a49255
9119567
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
# Load model directly
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image # Import the Image module
import torch # Import the torch module
import streamlit as st

# Hugging Face Hub checkpoint used for both the image processor and the model.
MODEL_ID = "mateoluksenberg/dit-base-Classifier_CM05"


@st.cache_resource
def load_model(model_id: str = MODEL_ID):
    """Load and return ``(processor, model)`` for the given Hub checkpoint.

    Streamlit re-executes this whole script on every widget interaction;
    ``st.cache_resource`` makes the (slow) download/instantiation happen once
    per server process instead of on every rerun.

    Args:
        model_id: Hugging Face Hub id of an image-classification checkpoint.

    Returns:
        A tuple of the checkpoint's image processor and classification model.
    """
    processor = AutoImageProcessor.from_pretrained(model_id)
    model = AutoModelForImageClassification.from_pretrained(model_id)
    return processor, model


st.title("Image Classification")

# Returns None until the user has picked a file.
uploaded_file = st.file_uploader("Choose an image file", type=["jpg", "jpeg", "png"])

if uploaded_file is None:
    # First render / no upload yet: nothing to classify, so don't touch
    # Image.open (it would fail on None).
    st.info("Upload a JPG or PNG image to classify it.")
else:
    processor, model = load_model()

    # Read the uploaded file object. Convert to RGB because PNGs may be
    # RGBA, palette or grayscale, while the processor normalizes per
    # colour channel.
    image = Image.open(uploaded_file).convert("RGB")

    inputs = processor(image, return_tensors="pt")

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Index of the highest-scoring class; its name comes from the checkpoint's
    # own id2label mapping (presumably the fine-tuned classes, not ImageNet's).
    predicted_label = logits.argmax(-1).item()
    label = model.config.id2label[predicted_label]

    print(label)  # still logged to the server's stdout, as before
    st.write(f"Predicted class: **{label}**")