cahya commited on
Commit
c10a93b
1 Parent(s): 1855e20

add the application files

Browse files
Files changed (7) hide show
  1. Dockerfile +7 -0
  2. README.md +3 -3
  3. SessionState.py +107 -0
  4. app.py +83 -0
  5. prompts.py +18 -0
  6. requirements.txt +11 -0
  7. wit_index.py +40 -0
Dockerfile ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
# Slim Debian-based Python 3.8 base image.
FROM python:3.8-slim-buster
# Copy the application sources into the image and run everything from /app.
COPY . /app
WORKDIR /app
# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt
# Port the Streamlit server is expected to listen on.
EXPOSE 8501
# `streamlit run` is fixed; the script name below can be overridden at `docker run`.
ENTRYPOINT ["streamlit","run"]
CMD ["app.py"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Image Search
3
- emoji: 🔥
4
- colorFrom: pink
5
- colorTo: indigo
6
  sdk: streamlit
7
  app_file: app.py
8
  pinned: false
1
  ---
2
  title: Image Search
3
+ emoji: 😻
4
+ colorFrom: blue
5
+ colorTo: gray
6
  sdk: streamlit
7
  app_file: app.py
8
  pinned: false
SessionState.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Hack to add per-session state to Streamlit.
2
+ Usage
3
+ -----
4
+ >>> import SessionState
5
+ >>>
6
+ >>> session_state = SessionState.get(user_name='', favorite_color='black')
7
+ >>> session_state.user_name
8
+ ''
9
+ >>> session_state.user_name = 'Mary'
10
+ >>> session_state.favorite_color
11
+ 'black'
12
+ Since you set user_name above, next time your script runs this will be the
13
+ result:
14
+ >>> session_state = get(user_name='', favorite_color='black')
15
+ >>> session_state.user_name
16
+ 'Mary'
17
+ """
18
+ try:
19
+ import streamlit.ReportThread as ReportThread
20
+ from streamlit.server.Server import Server
21
+ except Exception:
22
+ # Streamlit >= 0.65.0
23
+ import streamlit.report_thread as ReportThread
24
+ from streamlit.server.server import Server
25
+
26
+
27
class SessionState:
    def __init__(self, **kwargs):
        """Create a state holder whose attributes are the given keywords.

        Parameters
        ----------
        **kwargs : any
            Initial values for the session state; each keyword becomes an
            attribute of the same name.

        Example
        -------
        >>> session_state = SessionState(user_name='', favorite_color='black')
        >>> session_state.user_name
        ''
        >>> session_state.user_name = 'Mary'
        >>> session_state.favorite_color
        'black'
        """
        for name, default in kwargs.items():
            setattr(self, name, default)
44
+
45
+
46
def get(**kwargs):
    """Return the SessionState attached to the current Streamlit session.

    The state object is created on first use and stored on the session, so
    later script runs within the same session get the same object back.

    Parameters
    ----------
    **kwargs : any
        Default values used only when the state object has to be created.

    Raises
    ------
    RuntimeError
        If no server session matches the current report context.

    Example
    -------
    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    ''
    >>> session_state.user_name = 'Mary'
    >>> session_state.favorite_color
    'black'

    Since user_name was set above, the next script run yields:

    >>> session_state = get(user_name='', favorite_color='black')
    >>> session_state.user_name
    'Mary'
    """
    # Hack: locate the session by comparing the report context against every
    # session the running server knows about.
    ctx = ReportThread.get_report_ctx()
    current_server = Server.get_current()

    if hasattr(current_server, '_session_infos'):
        # Streamlit < 0.56
        session_infos = current_server._session_infos.values()
    else:
        session_infos = current_server._session_info_by_id.values()

    this_session = None
    for session_info in session_infos:
        candidate = session_info.session
        if hasattr(candidate, '_main_dg'):
            # Streamlit < 0.54.0
            matched = candidate._main_dg == ctx.main_dg
        else:
            matched = (
                # Streamlit >= 0.54.0
                candidate.enqueue == ctx.enqueue
                # Streamlit >= 0.65.2
                or candidate._uploaded_file_mgr == ctx.uploaded_file_mgr
            )
        if matched:
            # No break: if several sessions match, the last one is kept.
            this_session = candidate

    if this_session is None:
        raise RuntimeError(
            "Oh noes. Couldn't get your Streamlit Session object. "
            'Are you doing something fancy with threads?')

    # Session found: create the custom state on it once, then reuse it.
    if not hasattr(this_session, '_custom_session_state'):
        this_session._custom_session_state = SessionState(**kwargs)

    return this_session._custom_session_state
106
+
107
+ __all__ = ['get']
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import SessionState
3
+ from prompts import PROMPT_LIST
4
+ from wit_index import WitIndex
5
+ import random
6
+ import time
7
+
8
+ # st.set_page_config(page_title="Image Search")
9
+
10
+ # vector_length = 128
11
# Locations of the Faiss index, the sentence-encoder model and the pickled dataset.
wit_index_path = "./models/wit_faiss.idx"
model_name = "./models/distilbert-base-wit"
wit_dataset_path = "./models/wit_dataset.pkl"
14
+
15
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
def get_wit_index():
    """Build the CPU-only WitIndex from the module-level paths.

    Wrapped in ``st.cache`` so the construction is not repeated on every
    script rerun; a status line is written to the page when it does run.
    """
    st.write("Loading the WIT index, dataset and the DistillBERT model..")
    return WitIndex(wit_index_path, model_name, wit_dataset_path, gpu=False)
20
+
21
@st.cache(suppress_st_warning=True)
def process(text: str, top_k: int = 10):
    """Search the module-level ``wit_index`` for ``text``.

    Returns the ``(distance, index, image_info)`` triple produced by
    ``wit_index.search``; results are cached by Streamlit.
    """
    distances, indices, infos = wit_index.search(text, top_k=top_k)
    return distances, indices, infos
26
+
27
+
28
+ st.title("Image Search")
29
+
30
+ st.markdown(
31
+ """
32
+ This application is a demo for sentence-based image search using
33
+ [WIT dataset](https://github.com/google-research-datasets/wit). We use DistillBert to encode the sentences
34
+ and Facebook's Faiss to search the vector embeddings.
35
+ """
36
+ )
37
+ session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
38
+ ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
39
+ prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
40
+ # Update prompt
41
+ if session_state.prompt is None:
42
+ session_state.prompt = prompt
43
+ elif session_state.prompt is not None and (prompt != session_state.prompt):
44
+ session_state.prompt = prompt
45
+ session_state.prompt_box = None
46
+ session_state.text = None
47
+ else:
48
+ session_state.prompt = prompt
49
+
50
+ # Update prompt box
51
+ if session_state.prompt == "Custom":
52
+ session_state.prompt_box = "Enter your text here"
53
+ else:
54
+ if session_state.prompt is not None and session_state.prompt_box is None:
55
+ session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])
56
+
57
+ session_state.text = st.text_area("Enter text", session_state.prompt_box)
58
+
59
+ top_k = st.sidebar.number_input(
60
+ "Top k",
61
+ value=6,
62
+ min_value=1,
63
+ max_value=10
64
+ )
65
+
66
+ wit_index = get_wit_index()
67
+ if st.button("Run"):
68
+ with st.spinner(text="Getting results..."):
69
+ st.subheader("Result")
70
+ time_start = time.time()
71
+ distances, index, image_info = process(text=session_state.text, top_k=int(top_k))
72
+ time_end = time.time()
73
+ time_diff = time_end-time_start
74
+ print(f"Search in {time_diff} seconds")
75
+ st.markdown(f"*Search in {time_diff:.5f} seconds*")
76
+ for i, distance in enumerate(distances):
77
+ st.image(image_info[i][0].replace("http:", "https:"), width=400)
78
+ st.write(f"{image_info[i][1]}. (D: {distance:.2f})")
79
+
80
+ # Reset state
81
+ session_state.prompt = None
82
+ session_state.prompt_box = None
83
+ session_state.text = None
prompts.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ PROMPT_LIST = {
2
+ "City": [
3
+ "Vienna is the national capital, largest city, and one of nine states of Austria. Vienna is Austria's most populous city, with about 2 million inhabitants, and its cultural, economic, and political centre. It is the 6th-largest city by population within city limits in the European Union",
4
+ "Sydney is the capital city of the state of New South Wales, and the most populous city in Australia and Oceania.",
5
+ "Ubud is a town on the Indonesian island of Bali in Ubud District, located amongst rice paddies and steep ravines in the central foothills of the Gianyar regency. Promoted as an arts and culture centre, it has developed a large tourism industry.",
6
+ "Jakarta is the capital of Indonesia"
7
+ ],
8
+ "People": [
9
+ "Albert Einstein was a German-born theoretical physicist, widely acknowledged to be one of the greatest physicists of all time. Einstein is known for developing the theory of relativity, but he also made important contributions to the development of the theory of quantum mechanics.",
10
+ "Geoffrey Everest Hinton is a British-Canadian cognitive psychologist and computer scientist, most noted for his work on artificial neural networks.",
11
+ "Pramoedya Ananta Toer was an Indonesian author of novels, short stories, essays, polemics and histories of his homeland and its people."
12
+ ],
13
+ "Building": [
14
+ "Borobudur is a 7th-century Mahayana Buddhist temple in Indonesia. It is the world's largest Buddhist temple. The temple consists of nine stacked platforms, six square and three circular, topped by a central dome. It is decorated with 2,672 relief panels and 504 Buddha statues.",
15
+ "The Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor within New York City, in the United States.",
16
+ "Machu Picchu is a 15th-century Inca citadel, located in the Eastern Cordillera of southern Peru, on a 2,430-meter mountain ridge."
17
+ ]
18
+ }
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy
2
+ torch
3
+ tensorflow
4
+ transformers
5
+ tensorflow
6
+ mtranslate
7
+ sentence-transformers
8
+ datasets
9
+ faiss-cpu
10
+ # streamlit version 0.67.1 is needed due to issue with caching
11
+ streamlit==0.67.1
wit_index.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+
3
+ # Used to create the dense document vectors.
4
+ import torch
5
+ from sentence_transformers import SentenceTransformer
6
+ import datasets
7
+
8
+ # Used to create and store the Faiss index.
9
+ import faiss
10
+ import numpy as np
11
+
12
class WitIndex:
    """
    Text-to-image lookup over a Faiss index.

    A query text is embedded with a SentenceTransformer model, its nearest
    neighbours are fetched from the Faiss index, and each hit is resolved to
    an image record through the pickled dataset's ``desc2image_map`` and
    ``image_info`` tables.
    """
    # Class-level so the pickled dataset is loaded once and shared by all instances.
    wit_dataset = None

    def __init__(self, wit_index_path: str, model_name: str, wit_dataset_path: str, gpu=True):
        """Load the index, the encoder model and (once) the dataset.

        Parameters
        ----------
        wit_index_path : str
            Path of the Faiss index file.
        model_name : str
            Name or path handed to ``SentenceTransformer``.
        wit_dataset_path : str
            Path of the pickled dataset; read only if no dataset is loaded yet.
        gpu : bool
            When true and CUDA is available, move the model and index to the GPU.
        """
        self.index = faiss.read_index(wit_index_path)
        self.model = SentenceTransformer(model_name)
        if WitIndex.wit_dataset is None:
            # NOTE(security): unpickling executes arbitrary code; only load a
            # dataset file from a trusted source.
            with open(wit_dataset_path, "rb") as dataset_file:
                WitIndex.wit_dataset = pickle.load(dataset_file)
        print(f"Gpu: {gpu}")
        if gpu and torch.cuda.is_available():
            print("Cuda is available")
            self.model = self.model.to(torch.device("cuda"))
            res = faiss.StandardGpuResources()
            self.index = faiss.index_cpu_to_gpu(res, 0, self.index)

    def search(self, text, top_k=6):
        """Return the ``top_k`` nearest entries for ``text``.

        Returns
        -------
        tuple
            ``(distance, index, image_info)``: flat lists of the distances and
            index ids reported by the Faiss index, and the matching
            ``image_info`` records. All three are empty when there are no hits.
        """
        print(f"> Search: {text}")
        embedding = self.model.encode(text, convert_to_numpy=True, show_progress_bar=False)
        # Retrieve the k nearest neighbours; the index expects a batch, hence the wrap.
        distance, index = self.index.search(np.array([embedding]), k=top_k)
        distance, index = distance.flatten().tolist(), index.flatten().tolist()
        dataset = WitIndex.wit_dataset
        index_url = [dataset['desc2image_map'][i] for i in index]
        image_info = [dataset['image_info'][i] for i in index_url]
        # Guard the debug print so an empty result does not raise IndexError.
        if image_info:
            print(f"> URL: {image_info[0]}")
        return distance, index, image_info