Spaces:
Runtime error
Runtime error
File size: 2,826 Bytes
c10a93b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import streamlit as st
import SessionState
from prompts import PROMPT_LIST
from wit_index import WitIndex
import random
import time
# st.set_page_config(page_title="Image Search")
# vector_length = 128

# Locations of the artifacts handed to WitIndex: the Faiss index file, the
# sentence-encoder model directory, and the pickled WIT dataset.
# Plain string literals: none of these paths interpolates anything.
wit_index_path = "./models/wit_faiss.idx"
model_name = "./models/distilbert-base-wit"
wit_dataset_path = "./models/wit_dataset.pkl"
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_wit_index():
    """Create the WitIndex used for searching and return it.

    The index is built from the module-level ``wit_index_path``,
    ``model_name`` and ``wit_dataset_path`` settings with ``gpu=False``.
    A status line is written to the page while the body runs; because the
    function is wrapped in ``st.cache``, the body is expected to run only
    on a cache miss.
    """
    st.write("Loading the WIT index, dataset and the DistillBERT model..")
    return WitIndex(wit_index_path, model_name, wit_dataset_path, gpu=False)
@st.cache(suppress_st_warning=True)
def process(text: str, top_k: int = 10):
    """Search the module-level ``wit_index`` for ``text``.

    Args:
        text: Query sentence forwarded to ``wit_index.search``.
        top_k: Number of results requested from the index.

    Returns:
        A ``(distance, index, image_info)`` tuple exactly as unpacked from
        ``wit_index.search``.

    Note:
        ``wit_index`` must already be assigned at module level before this
        function is called.
    """
    found = wit_index.search(text, top_k=top_k)
    distances, indices, infos = found
    return distances, indices, infos
# Page heading followed by a short Markdown description of the demo.
st.title("Image Search")

_intro_markdown = """
This application is a demo for sentence-based image search using
[WIT dataset](https://github.com/google-research-datasets/wit). We use DistillBert to encode the sentences
and Facebook's Faiss to search the vector embeddings.
"""
st.markdown(_intro_markdown)
# State that must survive Streamlit reruns: the selected prompt category,
# the default text shown in the text area, and the text last entered.
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)

# Every prompt category plus a trailing "Custom" entry, which is the
# pre-selected option.
ALL_PROMPTS = [*PROMPT_LIST.keys(), "Custom"]
prompt = st.selectbox("Prompt", ALL_PROMPTS, index=len(ALL_PROMPTS) - 1)

# Update prompt: when the user switched to a different category, drop the
# stored box/text so a fresh default is chosen below.
previous_prompt = session_state.prompt
if previous_prompt is not None and prompt != previous_prompt:
    session_state.prompt_box = None
    session_state.text = None
session_state.prompt = prompt

# Update prompt box: "Custom" always shows the placeholder; any other
# category gets one random example, chosen only while no default is stored.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt is not None and session_state.prompt_box is None:
    session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])

session_state.text = st.text_area("Enter text", session_state.prompt_box)
# Sidebar control for the number of results to retrieve (1 to 10, default 6).
top_k = st.sidebar.number_input("Top k", min_value=1, max_value=10, value=6)
# Obtain the search index; process() reads this module-level name.
wit_index = get_wit_index()

if st.button("Run"):
    # NOTE(review): the original indentation is not visible in this copy of
    # the file. Everything below, including the state reset, is assumed to
    # run only when "Run" is pressed -- confirm against the deployed app.
    with st.spinner(text="Getting results..."):
        st.subheader("Result")

        # perf_counter() is a monotonic, high-resolution clock, so the
        # elapsed time cannot be skewed by system clock adjustments the way
        # a time.time() difference can.
        time_start = time.perf_counter()
        distances, _index, image_info = process(
            text=session_state.text, top_k=int(top_k)
        )
        time_diff = time.perf_counter() - time_start

        print(f"Search in {time_diff} seconds")
        st.markdown(f"*Search in {time_diff:.5f} seconds*")

        for i, distance in enumerate(distances):
            info = image_info[i]
            # info[0] is used as the image URL (rewritten to https) and
            # info[1] as the text shown under the image.
            st.image(info[0].replace("http:", "https:"), width=400)
            st.write(f"{info[1]}. (D: {distance:.2f})")

        # Reset state so the next rerun starts from a fresh prompt.
        session_state.prompt = None
        session_state.prompt_box = None
        session_state.text = None
|