santit96's picture
Stop versioning the model checkpoints, now they are downloaded from huggingface. Add env vars
dd14920
"""
Streamlit app that detects and classifies waste in an uploaded image
"""
import sys
import streamlit as st
from PIL import Image
from constants import CLASSES, OUTPUT_IMG_FILEPATH
sys.path.append("./efficientdet")
from efficientdet.efficientdet import plot_results
from trash_detector import detect_trash
def initial_config():
    """
    Configure the Streamlit page (browser-tab title and icon).
    """
    page_options = {
        "page_title": "Waste Classifier",
        "page_icon": "♻️",
    }
    st.set_page_config(**page_options)
def render():
    """
    Render the streamlit app.

    Shows an image uploader and a "Classify trash" button. When the button
    is pressed with a file uploaded, the image is decoded to RGB, passed to
    ``detect_trash``, and the detections are drawn by ``plot_results`` into
    ``OUTPUT_IMG_FILEPATH``. The uploaded and annotated images are then
    shown side by side. A missing upload, or one that cannot be decoded as
    an image, is reported with ``st.error`` instead of a traceback.
    """
    st.title("Waste classifier")
    st.markdown("""Classify your waste into different classes""")

    # Image loader and button
    uploaded_file = st.file_uploader(
        "Upload image with trash", type=["jpg", "jpeg", "png", "gif", "bmp"]
    )
    classify_button = st.button("Classify trash")
    if not classify_button:
        return
    if not uploaded_file:
        st.error("Upload an image")
        return

    # The uploader only filters by extension, so a corrupt or mislabeled
    # file can still get here. Decode it before drawing anything and report
    # failures in the UI. PIL's UnidentifiedImageError and truncated-file
    # errors are both OSError subclasses.
    try:
        img = Image.open(uploaded_file).convert("RGB")
    except OSError:
        st.error("Could not read the uploaded file as an image")
        return
    # Decoding advanced the stream; rewind so st.image reads the whole file.
    uploaded_file.seek(0)

    # Create two columns
    col1, col2 = st.columns(2)
    # Column 1: Uploaded image
    with col1:
        st.write("Uploaded image")
        st.image(uploaded_file, caption="Uploaded Image.", use_column_width=True)
    # Column 2: Classified image
    with col2:
        with st.spinner(text="Classifying the trash..."):
            cls_prob, bboxes_final = detect_trash(img)
            # Plot the detections and save the annotated image.
            # NOTE(review): OUTPUT_IMG_FILEPATH is one shared path, so
            # concurrent sessions may overwrite each other's output before
            # it is read back. Consider a per-request temporary file.
            plot_results(img, cls_prob, bboxes_final, CLASSES, OUTPUT_IMG_FILEPATH)
            # Copy into memory so the file handle is released immediately
            # instead of staying open until the lazy image is garbage-collected.
            with Image.open(OUTPUT_IMG_FILEPATH) as rendered:
                output_img = rendered.copy()
        st.write("Classified image")
        st.image(output_img, caption="Classified Image.", use_column_width=True)
if __name__ == "__main__":
    # Configure the page before any widget is drawn, then build the UI.
    initial_config()
    render()