File size: 3,062 Bytes
c10a93b
 
 
 
 
 
 
 
 
 
 
 
 
 
c09b02e
c10a93b
 
 
 
 
 
c09b02e
822e91a
c10a93b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c09b02e
 
 
 
c10a93b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
import SessionState
from prompts import PROMPT_LIST
from wit_index import WitIndex
import random
import time

# st.set_page_config(page_title="Image Search")

# vector_length = 128
# Locations of the artefacts handed to WitIndex, relative to the working
# directory: the Faiss index file, the encoder model directory and the
# dataset file (.pkl).
# Plain string literals: there is nothing to interpolate, so no f-prefix.
wit_index_path = "./models/wit_faiss.idx"
model_name = "./models/distilbert-base-wit"
wit_dataset_path = "./models/wit_dataset.pkl"


@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_wit_index():
    """Build the ``WitIndex`` used for searching and return it.

    The index is constructed from the module-level ``wit_index_path``,
    ``model_name`` and ``wit_dataset_path`` with ``gpu=False``. The function
    is wrapped in ``st.cache``; a status line is written to the page whenever
    the body actually executes.

    Returns:
        The newly constructed ``WitIndex`` instance.
    """
    st.write("Loading the WIT index, dataset and the DistillBERT model..")
    return WitIndex(wit_index_path, model_name, wit_dataset_path, gpu=False)

# st.cache is temporarily disabled because inference could take forever with newer Streamlit versions.
# @st.cache(suppress_st_warning=True)
def process(text: str, top_k: int = 10):
    """Run a search for ``text`` against the module-level ``wit_index``.

    Args:
        text: Query sentence forwarded to ``wit_index.search``.
        top_k: Value forwarded as the ``top_k`` keyword of the search.

    Returns:
        A 3-tuple ``(distance, index, image_info)`` holding the three values
        returned by ``wit_index.search``, in that order.
    """
    # st.write("Cache miss: process")
    # Unpack explicitly so the result is always returned as a 3-tuple.
    found_distance, found_index, found_info = wit_index.search(text, top_k=top_k)
    return found_distance, found_index, found_info


# Introductory blurb rendered as Markdown right below the page title.
_INTRO_MARKDOWN = """
    This application is a demo for sentence-based image search using 
    [WIT dataset](https://github.com/google-research-datasets/wit). We use DistillBert to encode the sentences 
    and Facebook's Faiss to search the vector embeddings.
    """

st.title("Image Search")
st.markdown(_INTRO_MARKDOWN)
# Session-scoped values: the selected prompt category, the text pre-filled
# into the text area, and the text the user submitted.
# NOTE(review): presumably SessionState.get keeps these across reruns - confirm.
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)

# Every prompt category plus a trailing "Custom" entry, which is selected by
# default because the initial index points at the last option.
ALL_PROMPTS = [*PROMPT_LIST.keys(), "Custom"]
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)

# Update prompt: the selection is always stored; the pre-filled text and the
# previous query are discarded only when a stored selection was replaced by a
# different one.
_prompt_switched = (
    session_state.prompt is not None and prompt != session_state.prompt
)
session_state.prompt = prompt
if _prompt_switched:
    session_state.prompt_box = None
    session_state.text = None

# Update prompt box: "Custom" always shows the placeholder; any other category
# gets a random example, picked only while no text has been chosen yet.
if session_state.prompt == "Custom":
    session_state.prompt_box = "Enter your text here"
elif session_state.prompt is not None and session_state.prompt_box is None:
    session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])

session_state.text = st.text_area("Enter text", session_state.prompt_box)

# Sidebar control for the number of results to request (1-10, default 6).
top_k = st.sidebar.number_input(
    "Top k",
    value=6,
    min_value=1,
    max_value=10,
)

# The index must exist before the button handler runs: process() reads the
# module-level name `wit_index`.
wit_index = get_wit_index()
if st.button("Run"):
    with st.spinner(text="Getting results..."):
        st.subheader("Result")
        # perf_counter() is a monotonic, high-resolution clock, so the
        # reported duration cannot be distorted by system clock adjustments
        # the way a time.time() difference can.
        time_start = time.perf_counter()
        distances, index, image_info = process(text=session_state.text, top_k=int(top_k))
        time_end = time.perf_counter()
        time_diff = time_end-time_start
        print(f"Search in {time_diff} seconds")
        st.markdown(f"*Search in {time_diff:.5f} seconds*")
        for i, distance in enumerate(distances):
            # Each entry is indexed as [0] = image URL, [1] = caption text.
            entry = image_info[i]
            try:
                # Request the image over https even if the stored URL is http.
                st.image(entry[0].replace("http:", "https:"), width=400)
            except FileNotFoundError:
                st.write(f"{entry[0]} can't be displayed")
            st.write(f"{entry[1]}. (D: {distance:.2f})")

        # Reset state
        session_state.prompt = None
        session_state.prompt_box = None
        session_state.text = None