Spaces:
Runtime error
Runtime error
import os | |
import streamlit as st | |
from PIL import Image | |
from inference import get_predictions, get_nearest_k | |
# Page header for the demo.
st.title('Fashion accessories prediction and search Demo')

# Number of bundled sample images; the sample browser wraps around modulo this.
tot_index = 44065

# The sample images are split across several folders to stay under git's
# 10,000-files-per-folder limit.
sample_path1, sample_path2, sample_path3, sample_path4, sample_path5 = (
    './data/small_images_0_9999',
    './data/small_images_10000_19999',
    './data/small_images_20000_29999',
    './data/small_images_30000_39999',
    './data/small_images_40000_49999',
)

# Seed session-state keys only when they are absent, so existing values are kept.
_session_defaults = {'image_index': 0, 'which_button': 'sample_button'}
for _key, _value in _session_defaults.items():
    if _key not in st.session_state:
        st.session_state[_key] = _value
# Two tabs: browse the bundled sample images, or upload an image of your own.
sample_col, upload_col = st.tabs(['Select from sample images', 'Upload file'])

with upload_col:
    # Always define `img` so later code can test it instead of hitting a NameError;
    # it stays None until a file has actually been uploaded.
    img = None
    # Default to False: only an explicit click on "Use uploaded image" may switch
    # the app to upload mode. A True default risks selecting upload mode while no
    # file is present, which leaves `img` unusable for prediction/search.
    use_uploaded_image = False
    uploaded_file = st.file_uploader(
        "Select a picture from your computer(png/jpg) :",
        type=['png', 'jpg', 'jpeg'],
    )
    if uploaded_file is not None:
        img = Image.open(uploaded_file)
        st.image(img, caption='Uploaded Image')
        use_uploaded_image = st.button("Use uploaded image")
        if use_uploaded_image:
            # Record the choice in session state; the prediction/search handlers read it.
            st.session_state['which_button'] = 'upload_button'
# Search hits whose distance is not below this value are hidden from the user.
# NOTE(review): the unit/scale depends on get_nearest_k's distance metric — confirm.
MAX_MATCH_DISTANCE = 300

# Exclusive upper index bound for sample_path1..sample_path4; indices at or above
# the last bound live in sample_path5. Each folder therefore holds 9999 images, so
# these ranges do not exactly match the numbers in the folder names.
# NOTE(review): kept identical to the original thresholds — verify against the data.
_FOLDER_BOUNDS = (9999, 19998, 29997, 39996)


def _sample_image_path(index):
    """Return the path of sample image number *index* (``<folder>/<index>.jpg``).

    The folder is the first ``sample_pathN`` whose bound in ``_FOLDER_BOUNDS`` is
    greater than *index*; anything beyond the last bound maps to ``sample_path5``.
    """
    folders = (sample_path1, sample_path2, sample_path3, sample_path4)
    for bound, folder in zip(_FOLDER_BOUNDS, folders):
        if index < bound:
            return os.path.join(folder, str(index) + '.jpg')
    return os.path.join(sample_path5, str(index) + '.jpg')


def _selected_image():
    """Return the image chosen for prediction/search, or None if it is unavailable.

    Upload mode can be recorded in session state while ``uploaded_file`` is None;
    in that case there is no uploaded image to use and None is returned.
    """
    if st.session_state['which_button'] == 'upload_button':
        return img if uploaded_file is not None else None
    return sample_image


with sample_col:
    st.write("Select one from these available samples: ")
    current_index = st.session_state['image_index']
    prev_col, next_col = st.columns(2)
    with prev_col:
        prev_clicked = st.button('prev_image')
    with next_col:
        # Not named `next`, which would shadow the builtin.
        next_clicked = st.button('next_image')
    # Step through the samples, wrapping around at either end.
    if prev_clicked:
        current_index = (current_index - 1) % tot_index
    if next_clicked:
        current_index = (current_index + 1) % tot_index
    st.session_state['image_index'] = current_index
    sample_image = Image.open(_sample_image_path(current_index))
    st.image(sample_image, caption='Chosen image')
    use_sample_image = st.button("Use this Sample")
    if use_sample_image:
        st.session_state['which_button'] = 'sample_button'

classification_button, search_button = st.columns(2)
with classification_button:
    predict_clicked = st.button("Get categories predictions")
with search_button:
    search_clicked = st.button("Get similar looking products")

if predict_clicked or search_clicked:
    query_image = _selected_image()
    if query_image is None:
        st.warning('No uploaded image available. Upload a file or select a sample image first.')
    elif predict_clicked:
        predictions = get_predictions(query_image)
        st.markdown('**The model predictions along with their probabilities are :**')
        st.table(predictions)
    else:
        top_k_preds = get_nearest_k(query_image)
        # NOTE(review): assumes top_k_preds is (distances, indices), each with a
        # leading query axis of size 1 — confirm against inference.get_nearest_k.
        # Pair each distance with its index *before* filtering so the two stay
        # aligned even if the results are not sorted by distance.
        matches = [
            (distance, nearest_index)
            for distance, nearest_index in zip(top_k_preds[0][0], top_k_preds[1][0])
            if distance < MAX_MATCH_DISTANCE
        ]
        if not matches:
            st.markdown('No visually similar products found in the database.')
        else:
            st.markdown(f'**The top {len(matches)} similar product predictions are :**')
            pred_cols = st.columns(len(matches))
            for col, (distance, nearest_index) in zip(pred_cols, matches):
                with col:
                    temp_img = Image.open(_sample_image_path(nearest_index))
                    st.image(temp_img, caption=str(round(distance, 2)) + ' distance')