# app.py — uploaded by uisikdag (commit 9a87132, verified; 2.6 kB)
import PIL
import PIL.Image
import streamlit as st
import torch
from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification
# NOTE(review): model_path is not referenced anywhere else in this file; it
# looks like a leftover from a YOLOv8 template. Kept so module names are
# unchanged.
model_path = 'weights/yolov8n.pt'

# Short backbone name (the option shown in the sidebar selector) -> Hugging
# Face Hub repo id passed to AutoImageProcessor / AutoModelForImageClassification.
model_dict = dict(
    vit="uisikdag/weed_vit_balanced",
    deit="uisikdag/weed_deit_balanced",
    swin="uisikdag/weeds_swin_balanced",
    beit="uisikdag/weed_beit_balanced",
    convnext="uisikdag/weeds_convnext_balanced",
    resnet="uisikdag/weed_resnet_balanced",
)
# Setting page layout
st.set_page_config(
    page_title="Weed Classification",  # Setting page title
    # Robot-face emoji as the tab icon. Spelled as a named escape (like the
    # title's \N{hugging face}) so re-encoding the file cannot mangle it; the
    # raw literal had been corrupted into the mojibake "πŸ€–".
    page_icon="\N{ROBOT FACE}",
    layout="wide",  # Setting layout to wide
    initial_sidebar_state="expanded",  # Expanding sidebar by default
)
# Creating sidebar: model choice, image upload and a sample-image download.
with st.sidebar:
    st.header("Settings")  # Adding header to sidebar
    # Offer the models in model_dict's insertion order. A set literal here
    # would make the option order (and hence the default selection) depend on
    # string hash randomization, so it could change between server restarts.
    model_idx = st.selectbox("Select Base Classifier", list(model_dict))
    model = model_dict[model_idx]  # Hub repo id of the selected classifier
    # Adding file uploader to sidebar for selecting images
    source_img = st.file_uploader(
        "Choose an image...", type=("jpg", "jpeg", "png", 'bmp', 'webp'))
    # Let users grab the bundled sample images as a zip archive.
    with open('sample.zip', 'rb') as f:
        st.download_button('Sample Images', f, file_name='images.zip',
                           mime='application/zip')
# Main-page heading, then two equal-width columns: the left one shows the
# uploaded image, the right one shows the prediction.
st.title(
    "Weed Classification with \N{hugging face} Transformers"
)
col1, col2 = st.columns([1, 1])
# Left column: decode the uploaded file for the classifier and preview it.
# `uploaded_image` stays None when nothing usable was uploaded; the Classify
# step checks for that.
with col1:
    uploaded_image = None
    if source_img:
        try:
            # Normalize to 3-channel RGB: PNG/WEBP/BMP uploads may be RGBA,
            # palette ("P") or grayscale ("L"), while the processors'
            # per-channel normalization expects three channels.
            uploaded_image = PIL.Image.open(source_img).convert("RGB")
        except OSError:
            # Pillow raises OSError (UnidentifiedImageError is a subclass) for
            # files it cannot decode, e.g. corrupt or mislabelled uploads.
            st.error('Could not read the uploaded file as an image')
        else:
            # Adding the uploaded image to the page with a caption
            st.image(source_img,
                     caption="Uploaded Image",
                     use_column_width=True
                     )
    else:
        st.write('Please upload an image')
@st.cache_resource
def _load_classifier(repo_id):
    """Return ``(image_processor, model)`` for a Hugging Face Hub repo id.

    Wrapped in ``st.cache_resource`` so each checkpoint is loaded once and
    reused across reruns and sessions, instead of being reloaded on every
    click of "Classify". In this app the repo ids come from ``model_dict``,
    so at most one entry per listed model is cached.
    """
    processor = AutoImageProcessor.from_pretrained(repo_id)
    classifier = AutoModelForImageClassification.from_pretrained(repo_id)
    return processor, classifier


# Right column: when "Classify" is pressed, run the selected model on the
# uploaded image and show the predicted label.
with col2:
    if st.sidebar.button('Classify'):
        if uploaded_image is not None:
            # Separate names so the repo-id string `model` is not shadowed by
            # the loaded model object.
            image_processor, classifier = _load_classifier(model)
            inputs = image_processor(uploaded_image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping.
            with torch.no_grad():
                logits = classifier(**inputs).logits
            # Single-image batch: argmax over the last (class) axis -> one id.
            predicted_label = logits.argmax(-1).item()
            out = classifier.config.id2label[predicted_label]
            out = 'The predicted class for the image is: ' + out
            st.text(out)
        else:
            st.warning('Please upload an image first')