# Spaces: Runtime error (status banner captured with the page; not part of the program)
import streamlit as st | |
import SessionState | |
from prompts import PROMPT_LIST | |
from wit_index import WitIndex | |
import random | |
import time | |
# st.set_page_config(page_title="Image Search") | |
# vector_length = 128 | |
# Locations of the artifacts handed to WitIndex: the Faiss index file,
# the sentence-encoder model directory, and the pickled dataset.
wit_index_path = "./models/wit_faiss.idx"
model_name = "./models/distilbert-base-wit"
wit_dataset_path = "./models/wit_dataset.pkl"
def get_wit_index():
    """Announce the load on the page, then build and return the search index.

    The index is constructed from the module-level ``wit_index_path``,
    ``model_name`` and ``wit_dataset_path`` settings with ``gpu=False``.
    """
    st.write("Loading the WIT index, dataset and the DistillBERT model..")
    return WitIndex(wit_index_path, model_name, wit_dataset_path, gpu=False)
# st.cache is temporarily disabled because inference can take forever with newer Streamlit versions.
# @st.cache(suppress_st_warning=True) | |
def process(text: str, top_k: int = 10):
    """Query the module-level ``wit_index`` with ``text``.

    Args:
        text: The sentence passed to ``wit_index.search``.
        top_k: Forwarded to the search as ``top_k``.

    Returns:
        A 3-tuple with the values produced by ``wit_index.search``, in the
        order it yields them (named distance, index and image info here).
    """
    # Unpack explicitly so the caller always receives exactly three values.
    dist, idx, info = wit_index.search(text, top_k=top_k)
    return dist, idx, info
# Page heading followed by a short introduction rendered as markdown.
st.title("Image Search")

intro_markdown = """
This application is a demo for sentence-based image search using
[WIT dataset](https://github.com/google-research-datasets/wit). We use DistillBert to encode the sentences
and Facebook's Faiss to search the vector embeddings.
"""
st.markdown(intro_markdown)
# Per-session storage for the selected prompt, the text box default and the
# text the user entered.
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)

# Every predefined prompt category plus a trailing "Custom" entry, which is
# the one preselected in the dropdown.
ALL_PROMPTS = [*PROMPT_LIST.keys(), "Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

# Update prompt: a switch to a different category discards the previously
# chosen default text and the previously entered text.
previous_prompt = session_state.prompt
session_state.prompt = prompt
if previous_prompt is not None and previous_prompt != prompt:
    session_state.prompt_box = None
    session_state.text = None

# Update prompt box: "Custom" always shows the fixed hint; any other category
# gets a random example, picked only while no default has been chosen yet.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt_box is None and session_state.prompt is not None:
    session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])

session_state.text = st.text_area("Enter text", session_state.prompt_box)
# Sidebar control for the number of results: 1 to 10, starting at 6.
top_k = st.sidebar.number_input("Top k", min_value=1, max_value=10, value=6)

# Build the index that process() reads from module scope.
wit_index = get_wit_index()
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")

        # Time only the search call; rendering the results is not included.
        started_at = time.time()
        distances, index, image_info = process(text=session_state.text, top_k=int(top_k))
        elapsed = time.time() - started_at
        print(f"Search in {elapsed} seconds")
        st.markdown(f"*Search in {elapsed:.5f} seconds*")

        # Each hit is rendered as an image followed by a caption line that
        # includes the distance; entry[0] is used as the image location and
        # entry[1] as the caption text.
        for position, distance in enumerate(distances):
            entry = image_info[position]
            try:
                st.image(entry[0].replace("http:", "https:"), width=400)
            except FileNotFoundError:
                st.write(f"{entry[0]} can't be displayed")
            st.write(f"{entry[1]}. (D: {distance:.2f})")

        # Reset state
        # NOTE(review): the original indentation was lost; the reset is assumed
        # to belong to the Run handler rather than to run on every rerun - confirm.
        session_state.prompt = None
        session_state.prompt_box = None
        session_state.text = None