"""Streamlit app: classify an uploaded image with a selectable Hugging Face model.

The sidebar picks one of the checkpoints listed in ``model_dict`` and accepts
an image upload; pressing "Classify" runs the chosen checkpoint on the image
and prints the predicted label.
"""

# `import PIL` on its own does not load the `Image` submodule; import it
# explicitly so `PIL.Image.open` never depends on another package's imports.
import PIL.Image
import streamlit as st
import torch
from transformers import AutoImageProcessor
from transformers import AutoModelForImageClassification

# Replace the relative path to your weight file
# NOTE(review): `model_path` is not referenced anywhere in this script.
model_path = 'weights/yolov8n.pt'

# Maps the short name shown in the sidebar to a Hugging Face Hub repo id.
model_dict = {
    "vit": "uisikdag/weed_vit_balanced",
    "deit": "uisikdag/weed_deit_balanced",
    "swin": "uisikdag/weeds_swin_balanced",
    "beit": "uisikdag/weed_beit_balanced",
    "convnext": "uisikdag/weeds_convnext_balanced",
    "resnet": "uisikdag/weed_resnet_balanced",
}

# Setting page layout (must run before any other Streamlit command)
st.set_page_config(
    page_title="Weed Classification",  # Setting page title
    page_icon="🤖",  # Setting page icon
    layout="wide",  # Setting layout to wide
    initial_sidebar_state="expanded",  # Expanding sidebar by default
)

# Creating sidebar
with st.sidebar:
    st.header("Settings")  # Adding header to sidebar

    # Offer the dict keys as an ordered list: a set literal has no defined
    # order, so the option order (and the default choice) would be arbitrary.
    model_idx = st.selectbox("Select Base Classifier", list(model_dict))
    model_id = model_dict[model_idx]

    # Adding file uploader to sidebar for selecting images
    source_img = st.file_uploader(
        "Choose an image...", type=("jpg", "jpeg", "png", 'bmp', 'webp'))

    # Offer the bundled sample archive for download.
    with open('sample.zip', 'rb') as f:
        st.download_button('Sample Images', f, file_name='images.zip')

# Creating main page heading
st.title("Weed Classification with \N{hugging face} Transformers")

# Creating two columns on the main page
col1, col2 = st.columns(2)

# Adding image to the first column if image is uploaded
with col1:
    if source_img:
        # Opening the uploaded image
        uploaded_image = PIL.Image.open(source_img)
        # Adding the uploaded image to the page with a caption
        st.image(source_img,
                 caption="Uploaded Image",
                 use_column_width=True
                 )
    else:
        uploaded_image = None
        st.write('Please upload an image')

# Running the classifier and showing the result in the second column
with col2:
    if st.sidebar.button('Classify'):
        if uploaded_image is not None:
            image_processor = AutoImageProcessor.from_pretrained(model_id)
            # Accepted upload formats may be RGBA, palette or greyscale;
            # convert to 3-channel RGB before preprocessing.
            # NOTE(review): assumes the checkpoints expect RGB input - confirm.
            inputs = image_processor(
                uploaded_image.convert("RGB"), return_tensors="pt")

            # Use a separate name so the repo id string is not overwritten
            # by the loaded model object.
            classifier = AutoModelForImageClassification.from_pretrained(
                model_id)

            # Inference only: no gradients needed.
            with torch.no_grad():
                logits = classifier(**inputs).logits

            predicted_label = logits.argmax(-1).item()
            out = classifier.config.id2label[predicted_label]
            out = 'The predicted class for the image is: ' + out
            st.text(out)
        else:
            # Previously a click without an upload did nothing at all.
            st.warning('Please upload an image before classifying.')