Spaces:
Runtime error
Runtime error
import os | |
import streamlit as st | |
from PIL import Image | |
import requests | |
import io | |
import time | |
from model import ViTForImageClassification | |
# Streamlit requires page configuration to be the first st.* command of the script run.
st.set_page_config(
    initial_sidebar_state="expanded",
    page_icon="interface/shopping-cart.png",
    page_title="Grocery Classifier",
)
def _cache_model_resource(func):
    """Wrap *func* in Streamlit's cross-rerun cache so it runs once per server process."""
    # st.cache_resource exists from Streamlit 1.18; older releases only provide st.cache,
    # where allow_output_mutation is needed because the model object is mutated by retraining.
    if hasattr(st, "cache_resource"):
        return st.cache_resource(show_spinner=False)(func)
    return st.cache(allow_output_mutation=True, show_spinner=False)(func)


@_cache_model_resource
def _build_model():
    """Instantiate the ViT classifier wrapper and load its weights from ``model/``."""
    classifier = ViTForImageClassification('google/vit-base-patch16-224')
    classifier.load('model/')
    return classifier


def load_model():
    """Return the shared classifier, loading it from disk only on first use.

    Streamlit re-executes the whole script on every widget interaction. Without
    a cache, the weights were reloaded on each rerun, for example every button
    click. Any in-memory changes made by ``model.retrain_from_path`` were also
    discarded on the next rerun, unless that method persists them to ``model/``
    (not visible here; TODO confirm). The cached instance is shared by all
    sessions of the server process.

    Returns:
        The loaded ``ViTForImageClassification`` wrapper.
    """
    # The spinner stays outside the cached function so no Streamlit element is
    # created inside the cache; on cache hits it disappears almost immediately.
    with st.spinner("Loading model"):
        return _build_model()
# Root directory for user feedback images; submit_feedback() writes one sub-folder per label.
feedback_path = "feedback"
# Classifier shared by predict(), retrain_from_feedback() and main().
model = load_model()
def predict(image):
    """Classify an uploaded image with the module-level model.

    Args:
        image: anything ``PIL.Image.open`` accepts (here, the Streamlit upload).

    Returns:
        A tuple ``(result, pil_image)``. ``result`` is a dict
        ``{'prediction': <first predicted label>,
        'confidence': <first confidence rounded to 3 decimals>}``.
        ``pil_image`` is the opened PIL image, kept so it can be saved as feedback.
    """
    print("Predicting...")
    pil_image = Image.open(image)
    predicted_labels, confidences = model.predict(pil_image)
    result = {
        'prediction': predicted_labels[0],
        'confidence': round(confidences[0], 3),
    }
    return result, pil_image
def submit_feedback(correct_label, image, root=None):
    """Save *image* as a PNG training example for *correct_label*.

    Files go to ``<root>/<correct_label>/<correct_label>_<unix-seconds>.png``.
    The timestamp has one-second resolution, so two submissions for the same
    label within one second previously overwrote each other. A numeric ``_<n>``
    suffix is now appended when the name is already taken. The check is not
    atomic across concurrent sessions.

    Args:
        correct_label: label chosen by the user; also used as the folder name.
        image: object with a PIL-style ``save(path)`` method.
        root: feedback directory; defaults to the module-level ``feedback_path``.
    """
    if root is None:
        root = feedback_path
    folder = os.path.join(root, correct_label)
    os.makedirs(folder, exist_ok=True)

    stem = f"{correct_label}_{int(time.time())}"
    target = os.path.join(folder, stem + ".png")
    counter = 1
    # Never overwrite an earlier feedback image that shares the same second.
    while os.path.exists(target):
        target = os.path.join(folder, f"{stem}_{counter}.png")
        counter += 1
    image.save(target)
def retrain_from_feedback():
    """Retrain the shared model on the images collected under ``feedback_path``.

    ``remove_path=True`` is passed to ``model.retrain_from_path``. Presumably it
    deletes the feedback folder after training (TODO confirm in model.py), which
    would leave nothing to train on for the next click.

    Raises:
        FileNotFoundError: if the feedback folder is missing or empty. The UI then
            shows a clear message instead of relying on whatever
            ``retrain_from_path`` does with no data.
    """
    if not os.path.isdir(feedback_path) or not os.listdir(feedback_path):
        raise FileNotFoundError(
            f"No feedback found in '{feedback_path}'; submit feedback before retraining"
        )
    model.retrain_from_path(feedback_path, remove_path=True)
def main():
    """Render the app: label list, upload + prediction, feedback, and retraining.

    Changes from the previous version:

    * The "no labels" warning was unreachable, because a set is never ``None``.
      It now triggers on an empty label set.
    * A prediction stored in ``st.session_state`` is shown only for the upload it
      was made on. Feedback can no longer be filed against a stale image after a
      different file is uploaded.
    * The predicted label is filtered out of the "correct label" options instead
      of using ``set.remove``. ``set.remove`` raised ``KeyError`` whenever the
      prediction was not among the current labels.
    """
    labels = sorted(set(model.label_encoder.classes_))
    # NOTE(review): the title characters look mis-encoded (mojibake of two emoji,
    # probably a shopping cart and a vegetable) — restore the intended text.
    st.title("π Grocery Classifier π₯")
    if not labels:
        st.warning("Labels could not be retrieved from the model")
        return
    st.write("Labels:", labels)

    image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if image_file is not None:
        _classification_section(image_file, labels)
    _retrain_section()


def _classification_section(image_file, labels):
    """Show the upload, predict on demand, and collect a corrected label."""
    st.image(image_file)
    st.subheader("Classification")
    # Tie a stored result to the uploaded bytes so it is hidden for other uploads.
    upload_key = hash(image_file.getvalue())
    if st.button("Predict"):
        st.session_state['response_json'], st.session_state['image'] = predict(image_file)
        st.session_state['response_upload'] = upload_key

    response = st.session_state.get('response_json')
    if response is None or st.session_state.get('response_upload') != upload_key:
        return

    # Show the result
    st.markdown(f"**Prediction:** {response['prediction']}")
    st.markdown(f"**Confidence:** {response['confidence']}")

    # User feedback
    st.subheader("User Feedback")
    st.markdown("If this prediction was incorrect, please select below the correct label")
    correct_labels = [label for label in labels if label != response["prediction"]]
    if not correct_labels:
        # Otherwise the selectbox would be empty and submit_feedback would get None.
        st.info("There are no other labels to choose from")
        return
    correct_label = st.selectbox("Correct label", correct_labels)
    if st.button("Submit"):
        try:
            submit_feedback(correct_label, st.session_state['image'])
        except Exception as e:  # UI boundary: report any failure to the user
            st.error("Feedback could not be submitted. Error: {}".format(e))
        else:
            st.success("Feedback submitted")


def _retrain_section():
    """Offer a button that retrains the model on all saved feedback."""
    if not st.button("Retrain from feedback"):
        return
    try:
        with st.spinner('Retraining...'):
            retrain_from_feedback()
    except Exception as e:  # UI boundary: report any failure to the user
        st.warning("Model could not be retrained. Error: {}".format(e))
    else:
        st.success("Model retrained")
        st.balloons()
# `streamlit run` executes the script as __main__, so the UI still renders there.
# Importing this module no longer calls main(); the other module-level
# statements, such as loading the model, still run on import.
if __name__ == "__main__":
    main()