niks-salodkar's picture
update numbers display
05793f6
raw
history blame contribute delete
No virus
4.25 kB
import os
import streamlit as st
from PIL import Image
from inference import get_predictions, get_nearest_k
st.title('Fashion accessories prediction and search Demo')

# Total number of bundled sample images; the sample browser index wraps
# around modulo this value.
tot_index = 44065

# The samples are split across several folders so that no single folder
# exceeds the 10,000-files-per-folder limit of the git hosting.
sample_path1 = './data/small_images_0_9999'
sample_path2 = './data/small_images_10000_19999'
sample_path3 = './data/small_images_20000_29999'
sample_path4 = './data/small_images_30000_39999'
sample_path5 = './data/small_images_40000_49999'

# Seed the per-session UI state only on the first run of the script, so that
# later reruns keep whatever the user has selected since.
_SESSION_DEFAULTS = (
    ('image_index', 0),
    ('which_button', 'sample_button'),
)
for _state_key, _state_default in _SESSION_DEFAULTS:
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Two tabs: browse the bundled sample images, or upload an image of your own.
sample_col, upload_col = st.tabs(['Select from sample images', 'Upload file'])
with upload_col:
    uploaded_file = st.file_uploader("Select a picture from your computer(png/jpg) :", type=['png', 'jpg', 'jpeg'])
    if uploaded_file is not None:
        img = Image.open(uploaded_file)
        st.image(img, caption='Uploaded Image')
        # Switch the query source to the upload only when the user explicitly
        # asks for it, and only when an uploaded image actually exists; a
        # default of "selected" could otherwise point later steps at an
        # undefined `img`.
        if st.button("Use uploaded image"):
            st.session_state['which_button'] = 'upload_button'
# Sample images live in several folders (the sample_path* settings). Each
# entry is (exclusive upper bound on the dataset index, folder); indices at or
# above the last bound live in sample_path5.
# NOTE(review): the bounds are multiples of 9999 while the folder names suggest
# 10000-image chunks (0_9999, 10000_19999, ...). They are kept unchanged on the
# assumption that they match the files on disk -- confirm against ./data.
_FOLDER_BOUNDS = (
    (9999, sample_path1),
    (19998, sample_path2),
    (29997, sample_path3),
    (39996, sample_path4),
)
# Nearest neighbours at or beyond this distance are not shown as "similar".
_MAX_SIMILAR_DISTANCE = 300


def _image_path(index):
    """Return the file path of the sample image with dataset index ``index``.

    Shared by the sample browser and the similar-products gallery so both map
    indices to folders the same way.
    """
    folder = sample_path5
    for upper, path in _FOLDER_BOUNDS:
        if index < upper:
            folder = path
            break
    return os.path.join(folder, str(index) + '.jpg')


with sample_col:
    st.write("Select one from these available samples: ")
    current_index = st.session_state['image_index']
    prev_col, next_col = st.columns(2)
    # Named *_clicked (not prev/next) so the builtin next() is not shadowed.
    with prev_col:
        prev_clicked = st.button('prev_image')
    with next_col:
        next_clicked = st.button('next_image')
    # Wrap around at either end so the browser cycles through all samples.
    if prev_clicked:
        current_index = (current_index - 1) % tot_index
    if next_clicked:
        current_index = (current_index + 1) % tot_index
    st.session_state['image_index'] = current_index
    sample_image = Image.open(_image_path(current_index))
    st.image(sample_image, caption='Chosen image')
    if st.button("Use this Sample"):
        st.session_state['which_button'] = 'sample_button'

classification_col, search_col = st.columns(2)
with classification_col:
    predict_clicked = st.button("Get categories predictions")
with search_col:
    search_clicked = st.button("Get similar looking products")


def _selected_image():
    """Return the image chosen as the query, or None if it is unavailable.

    ``which_button`` records whether the user last picked the sample or the
    uploaded file. ``img`` is only assigned while a file is uploaded, so an
    'upload_button' choice without an upload yields None instead of raising
    NameError.
    """
    if st.session_state['which_button'] == 'upload_button':
        return img if uploaded_file is not None else None
    return sample_image


if predict_clicked or search_clicked:
    query_image = _selected_image()
    if query_image is None:
        st.warning('Please upload an image first, or select one of the samples.')
    elif predict_clicked:
        predictions = get_predictions(query_image)
        st.markdown('**The model predictions along with their probabilities are :**')
        st.table(predictions)
    else:
        top_k_preds = get_nearest_k(query_image)
        # NOTE(review): assumes a (distances, indices) pair of 2-D arrays with
        # one row per query (FAISS-style search output) -- confirm in inference.
        distances, indices = top_k_preds[0][0], top_k_preds[1][0]
        # Pair each distance with its index before filtering, so the result
        # does not depend on the distances being sorted ascending.
        matches = [
            (dist, idx)
            for dist, idx in zip(distances, indices)
            if dist < _MAX_SIMILAR_DISTANCE
        ]
        if not matches:
            st.markdown('No visually similar products found in the database.')
        else:
            st.markdown('**The top ' + str(len(matches)) + ' similar product predictions are :**')
            for col, (dist, idx) in zip(st.columns(len(matches)), matches):
                with col:
                    st.image(Image.open(_image_path(idx)), caption=str(round(dist, 2)) + ' distance')