image-search / app.py
cahya's picture
add try/except to display image
c09b02e
import streamlit as st
import SessionState
from prompts import PROMPT_LIST
from wit_index import WitIndex
import random
import time
# st.set_page_config(page_title="Image Search")
# vector_length = 128
wit_index_path = f"./models/wit_faiss.idx"
model_name = f"./models/distilbert-base-wit"
wit_dataset_path = "./models/wit_dataset.pkl"
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_wit_index():
st.write("Loading the WIT index, dataset and the DistillBERT model..")
wit_index = WitIndex(wit_index_path, model_name, wit_dataset_path, gpu=False)
return wit_index
# st.cache is disabled temporarily because the inference could take forever using newer streamlit version.
# @st.cache(suppress_st_warning=True)
def process(text: str, top_k: int = 10):
# st.write("Cache miss: process")
distance, index, image_info = wit_index.search(text, top_k=top_k)
return distance, index, image_info
st.title("Image Search")
st.markdown(
"""
This application is a demo for sentence-based image search using
[WIT dataset](https://github.com/google-research-datasets/wit). We use DistillBert to encode the sentences
and Facebook's Faiss to search the vector embeddings.
"""
)
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
# Update prompt
if session_state.prompt is None:
session_state.prompt = prompt
elif session_state.prompt is not None and (prompt != session_state.prompt):
session_state.prompt = prompt
session_state.prompt_box = None
session_state.text = None
else:
session_state.prompt = prompt
# Update prompt box
if session_state.prompt == "Custom":
session_state.prompt_box = "Enter your text here"
else:
if session_state.prompt is not None and session_state.prompt_box is None:
session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])
session_state.text = st.text_area("Enter text", session_state.prompt_box)
top_k = st.sidebar.number_input(
"Top k",
value=6,
min_value=1,
max_value=10
)
wit_index = get_wit_index()
if st.button("Run"):
with st.spinner(text="Getting results..."):
st.subheader("Result")
time_start = time.time()
distances, index, image_info = process(text=session_state.text, top_k=int(top_k))
time_end = time.time()
time_diff = time_end-time_start
print(f"Search in {time_diff} seconds")
st.markdown(f"*Search in {time_diff:.5f} seconds*")
for i, distance in enumerate(distances):
try:
st.image(image_info[i][0].replace("http:", "https:"), width=400)
except FileNotFoundError:
st.write(f"{image_info[i][0]} can't be displayed")
st.write(f"{image_info[i][1]}. (D: {distance:.2f})")
# Reset state
session_state.prompt = None
session_state.prompt_box = None
session_state.text = None