import PIL.Image
import streamlit as st
from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification
import torch

# Hugging Face Hub repos for each supported classifier backbone
model_dict = {
    "vit": "uisikdag/weed_vit_balanced",
    "deit": "uisikdag/weed_deit_balanced",
    "swin": "uisikdag/weeds_swin_balanced",
    "beit": "uisikdag/weed_beit_balanced",
    "convnext": "uisikdag/weeds_convnext_balanced",
    "resnet": "uisikdag/weed_resnet_balanced",
}

# Setting page layout
st.set_page_config(
    page_title="Weed Classification",  # Setting page title
    page_icon="🤗",  # Setting page icon
    layout="wide",  # Setting layout to wide
    initial_sidebar_state="expanded",  # Expanding sidebar by default
)

# Creating sidebar
with st.sidebar:
    st.header("Settings")  # Adding header to sidebar
    model_idx = st.selectbox("Select Base Classifier",
                             list(model_dict.keys()))
    model_repo = model_dict[model_idx]  # Hub repo id for the chosen backbone

    # Adding file uploader to sidebar for selecting images
    source_img = st.file_uploader(
        "Choose an image...", type=("jpg", "jpeg", "png", "bmp", "webp"))

    with open('sample.zip', 'rb') as f:
        st.download_button('Sample Images', f, file_name='images.zip')

# Creating main page heading
st.title("Weed Classification with \N{HUGGING FACE} Transformers")

# Creating two columns on the main page
col1, col2 = st.columns(2)

# Adding image to the first column if image is uploaded
with col1:
    if source_img:
        # Opening the uploaded image
        uploaded_image = PIL.Image.open(source_img)
        # Adding the uploaded image to the page with a caption
        st.image(source_img,
                 caption="Uploaded Image",
                 use_column_width=True)
    else:
        uploaded_image = None
        st.write('Please upload an image')

with col2:
    if st.sidebar.button('Classify'):
        if uploaded_image is not None:
            # Preprocess the image and run it through the selected classifier
            image_processor = AutoImageProcessor.from_pretrained(model_repo)
            inputs = image_processor(uploaded_image.convert("RGB"), return_tensors="pt")
            model = AutoModelForImageClassification.from_pretrained(model_repo)
            # Run inference without tracking gradients
            with torch.no_grad():
                logits = model(**inputs).logits
            # Map the highest-scoring logit back to its class name
            predicted_label = logits.argmax(-1).item()
            out = model.config.id2label[predicted_label]
            out = 'The predicted class for the image is: ' + out
            st.text(out)
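            # Optional additions (not part of the original app, a sketch only):
            # surface the model's confidence for the predicted class, and guard
            # against the Classify button being pressed before an image exists.
            confidence = torch.softmax(logits, dim=-1)[0, predicted_label].item()
            st.text(f'Prediction confidence: {confidence:.1%}')
        else:
            st.error('Please upload an image before classifying.')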