ugmSorcero
Updates retraining streamlit procedure
5f05cbf
raw history blame
No virus
3.19 kB
import os
import streamlit as st
from PIL import Image
import requests
import io
import time
from model import ViTForImageClassification
st.set_page_config(
page_title="Grocery Classifier",
page_icon="interface/shopping-cart.png",
initial_sidebar_state="expanded"
)
@st.cache()
def load_model():
with st.spinner("Loading model"):
model = ViTForImageClassification('google/vit-base-patch16-224')
model.load('model/')
return model
model = load_model()
feedback_path = "feedback"
def predict(image):
print("Predicting...")
# Load using PIL
image = Image.open(image)
prediction, confidence = model.predict(image)
return {'prediction': prediction[0], 'confidence': round(confidence[0], 3)}, image
def submit_feedback(correct_label, image):
folder_path = feedback_path + "/" + correct_label + "/"
os.makedirs(folder_path, exist_ok=True)
image.save(folder_path + correct_label + "_" + str(int(time.time())) + ".png")
def retrain_from_feedback():
model.retrain_from_path(feedback_path, remove_path=True)
def main():
labels = set(list(model.label_encoder.classes_))
st.title("πŸ‡ Grocery Classifier πŸ₯‘")
if labels is None:
st.warning("Received error from server, labels could not be retrieved")
else:
st.write("Labels:", labels)
image_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
if image_file is not None:
st.image(image_file)
st.subheader("Classification")
if st.button("Predict"):
st.session_state['response_json'], st.session_state['image'] = predict(image_file)
if 'response_json' in st.session_state and st.session_state['response_json'] is not None:
# Show the result
st.markdown(f"**Prediction:** {st.session_state['response_json']['prediction']}")
st.markdown(f"**Confidence:** {st.session_state['response_json']['confidence']}")
# User feedback
st.subheader("User Feedback")
st.markdown("If this prediction was incorrect, please select below the correct label")
correct_labels = labels.copy()
correct_labels.remove(st.session_state['response_json']["prediction"])
correct_label = st.selectbox("Correct label", correct_labels)
if st.button("Submit"):
# Save feedback
try:
submit_feedback(correct_label, st.session_state['image'])
st.success("Feedback submitted")
except Exception as e:
st.error("Feedback could not be submitted. Error: {}".format(e))
# Retrain from feedback
if st.button("Retrain from feedback"):
try:
with st.spinner('Retraining...'):
retrain_from_feedback()
st.success("Model retrained")
st.balloons()
except Exception as e:
st.warning("Model could not be retrained. Error: {}".format(e))
main()