Spaces:
Runtime error
Runtime error
add the application files
Browse files- Dockerfile +7 -0
- README.md +3 -3
- SessionState.py +107 -0
- app.py +83 -0
- prompts.py +18 -0
- requirements.txt +11 -0
- wit_index.py +40 -0
Dockerfile
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8-slim-buster
|
2 |
+
COPY . /app
|
3 |
+
WORKDIR /app
|
4 |
+
RUN pip install -r requirements.txt
|
5 |
+
EXPOSE 8501
|
6 |
+
ENTRYPOINT ["streamlit","run"]
|
7 |
+
CMD ["app.py"]
|
README.md
CHANGED
@@ -1,8 +1,8 @@
|
|
1 |
---
|
2 |
title: Image Search
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: streamlit
|
7 |
app_file: app.py
|
8 |
pinned: false
|
1 |
---
|
2 |
title: Image Search
|
3 |
+
emoji: 😻
|
4 |
+
colorFrom: blue
|
5 |
+
colorTo: gray
|
6 |
sdk: streamlit
|
7 |
app_file: app.py
|
8 |
pinned: false
|
SessionState.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Hack to add per-session state to Streamlit.
|
2 |
+
Usage
|
3 |
+
-----
|
4 |
+
>>> import SessionState
|
5 |
+
>>>
|
6 |
+
>>> session_state = SessionState.get(user_name='', favorite_color='black')
|
7 |
+
>>> session_state.user_name
|
8 |
+
''
|
9 |
+
>>> session_state.user_name = 'Mary'
|
10 |
+
>>> session_state.favorite_color
|
11 |
+
'black'
|
12 |
+
Since you set user_name above, next time your script runs this will be the
|
13 |
+
result:
|
14 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
15 |
+
>>> session_state.user_name
|
16 |
+
'Mary'
|
17 |
+
"""
|
18 |
+
try:
|
19 |
+
import streamlit.ReportThread as ReportThread
|
20 |
+
from streamlit.server.Server import Server
|
21 |
+
except Exception:
|
22 |
+
# Streamlit >= 0.65.0
|
23 |
+
import streamlit.report_thread as ReportThread
|
24 |
+
from streamlit.server.server import Server
|
25 |
+
|
26 |
+
|
27 |
+
class SessionState(object):
|
28 |
+
def __init__(self, **kwargs):
|
29 |
+
"""A new SessionState object.
|
30 |
+
Parameters
|
31 |
+
----------
|
32 |
+
**kwargs : any
|
33 |
+
Default values for the session state.
|
34 |
+
Example
|
35 |
+
-------
|
36 |
+
>>> session_state = SessionState(user_name='', favorite_color='black')
|
37 |
+
>>> session_state.user_name = 'Mary'
|
38 |
+
''
|
39 |
+
>>> session_state.favorite_color
|
40 |
+
'black'
|
41 |
+
"""
|
42 |
+
for key, val in kwargs.items():
|
43 |
+
setattr(self, key, val)
|
44 |
+
|
45 |
+
|
46 |
+
def get(**kwargs):
|
47 |
+
"""Gets a SessionState object for the current session.
|
48 |
+
Creates a new object if necessary.
|
49 |
+
Parameters
|
50 |
+
----------
|
51 |
+
**kwargs : any
|
52 |
+
Default values you want to add to the session state, if we're creating a
|
53 |
+
new one.
|
54 |
+
Example
|
55 |
+
-------
|
56 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
57 |
+
>>> session_state.user_name
|
58 |
+
''
|
59 |
+
>>> session_state.user_name = 'Mary'
|
60 |
+
>>> session_state.favorite_color
|
61 |
+
'black'
|
62 |
+
Since you set user_name above, next time your script runs this will be the
|
63 |
+
result:
|
64 |
+
>>> session_state = get(user_name='', favorite_color='black')
|
65 |
+
>>> session_state.user_name
|
66 |
+
'Mary'
|
67 |
+
"""
|
68 |
+
# Hack to get the session object from Streamlit.
|
69 |
+
|
70 |
+
ctx = ReportThread.get_report_ctx()
|
71 |
+
|
72 |
+
this_session = None
|
73 |
+
|
74 |
+
current_server = Server.get_current()
|
75 |
+
if hasattr(current_server, '_session_infos'):
|
76 |
+
# Streamlit < 0.56
|
77 |
+
session_infos = Server.get_current()._session_infos.values()
|
78 |
+
else:
|
79 |
+
session_infos = Server.get_current()._session_info_by_id.values()
|
80 |
+
|
81 |
+
for session_info in session_infos:
|
82 |
+
s = session_info.session
|
83 |
+
if (
|
84 |
+
# Streamlit < 0.54.0
|
85 |
+
(hasattr(s, '_main_dg') and s._main_dg == ctx.main_dg)
|
86 |
+
or
|
87 |
+
# Streamlit >= 0.54.0
|
88 |
+
(not hasattr(s, '_main_dg') and s.enqueue == ctx.enqueue)
|
89 |
+
or
|
90 |
+
# Streamlit >= 0.65.2
|
91 |
+
(not hasattr(s, '_main_dg') and s._uploaded_file_mgr == ctx.uploaded_file_mgr)
|
92 |
+
):
|
93 |
+
this_session = s
|
94 |
+
|
95 |
+
if this_session is None:
|
96 |
+
raise RuntimeError(
|
97 |
+
"Oh noes. Couldn't get your Streamlit Session object. "
|
98 |
+
'Are you doing something fancy with threads?')
|
99 |
+
|
100 |
+
# Got the session object! Now let's attach some state into it.
|
101 |
+
|
102 |
+
if not hasattr(this_session, '_custom_session_state'):
|
103 |
+
this_session._custom_session_state = SessionState(**kwargs)
|
104 |
+
|
105 |
+
return this_session._custom_session_state
|
106 |
+
|
107 |
+
__all__ = ['get']
|
app.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import SessionState
|
3 |
+
from prompts import PROMPT_LIST
|
4 |
+
from wit_index import WitIndex
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
|
8 |
+
# st.set_page_config(page_title="Image Search")
|
9 |
+
|
10 |
+
# vector_length = 128
|
11 |
+
wit_index_path = f"./models/wit_faiss.idx"
|
12 |
+
model_name = f"./models/distilbert-base-wit"
|
13 |
+
wit_dataset_path = "./models/wit_dataset.pkl"
|
14 |
+
|
15 |
+
@st.cache(suppress_st_warning=True, allow_output_mutation=True)
|
16 |
+
def get_wit_index():
|
17 |
+
st.write("Loading the WIT index, dataset and the DistillBERT model..")
|
18 |
+
wit_index = WitIndex(wit_index_path, model_name, wit_dataset_path, gpu=False)
|
19 |
+
return wit_index
|
20 |
+
|
21 |
+
@st.cache(suppress_st_warning=True)
|
22 |
+
def process(text: str, top_k: int = 10):
|
23 |
+
# st.write("Cache miss: process")
|
24 |
+
distance, index, image_info = wit_index.search(text, top_k=top_k)
|
25 |
+
return distance, index, image_info
|
26 |
+
|
27 |
+
|
28 |
+
st.title("Image Search")
|
29 |
+
|
30 |
+
st.markdown(
|
31 |
+
"""
|
32 |
+
This application is a demo for sentence-based image search using
|
33 |
+
[WIT dataset](https://github.com/google-research-datasets/wit). We use DistillBert to encode the sentences
|
34 |
+
and Facebook's Faiss to search the vector embeddings.
|
35 |
+
"""
|
36 |
+
)
|
37 |
+
session_state = SessionState.get(prompt=None, prompt_box=None, text=None)
|
38 |
+
ALL_PROMPTS = list(PROMPT_LIST.keys())+["Custom"]
|
39 |
+
prompt = st.selectbox('Prompt', ALL_PROMPTS, index=len(ALL_PROMPTS)-1)
|
40 |
+
# Update prompt
|
41 |
+
if session_state.prompt is None:
|
42 |
+
session_state.prompt = prompt
|
43 |
+
elif session_state.prompt is not None and (prompt != session_state.prompt):
|
44 |
+
session_state.prompt = prompt
|
45 |
+
session_state.prompt_box = None
|
46 |
+
session_state.text = None
|
47 |
+
else:
|
48 |
+
session_state.prompt = prompt
|
49 |
+
|
50 |
+
# Update prompt box
|
51 |
+
if session_state.prompt == "Custom":
|
52 |
+
session_state.prompt_box = "Enter your text here"
|
53 |
+
else:
|
54 |
+
if session_state.prompt is not None and session_state.prompt_box is None:
|
55 |
+
session_state.prompt_box = random.choice(PROMPT_LIST[session_state.prompt])
|
56 |
+
|
57 |
+
session_state.text = st.text_area("Enter text", session_state.prompt_box)
|
58 |
+
|
59 |
+
top_k = st.sidebar.number_input(
|
60 |
+
"Top k",
|
61 |
+
value=6,
|
62 |
+
min_value=1,
|
63 |
+
max_value=10
|
64 |
+
)
|
65 |
+
|
66 |
+
wit_index = get_wit_index()
|
67 |
+
if st.button("Run"):
|
68 |
+
with st.spinner(text="Getting results..."):
|
69 |
+
st.subheader("Result")
|
70 |
+
time_start = time.time()
|
71 |
+
distances, index, image_info = process(text=session_state.text, top_k=int(top_k))
|
72 |
+
time_end = time.time()
|
73 |
+
time_diff = time_end-time_start
|
74 |
+
print(f"Search in {time_diff} seconds")
|
75 |
+
st.markdown(f"*Search in {time_diff:.5f} seconds*")
|
76 |
+
for i, distance in enumerate(distances):
|
77 |
+
st.image(image_info[i][0].replace("http:", "https:"), width=400)
|
78 |
+
st.write(f"{image_info[i][1]}. (D: {distance:.2f})")
|
79 |
+
|
80 |
+
# Reset state
|
81 |
+
session_state.prompt = None
|
82 |
+
session_state.prompt_box = None
|
83 |
+
session_state.text = None
|
prompts.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
PROMPT_LIST = {
|
2 |
+
"City": [
|
3 |
+
"Vienna is the national capital, largest city, and one of nine states of Austria. Vienna is Austria's most populous city, with about 2 million inhabitants, and its cultural, economic, and political centre. It is the 6th-largest city by population within city limits in the European Union",
|
4 |
+
"Sydney is the capital city of the state of New South Wales, and the most populous city in Australia and Oceania.",
|
5 |
+
"Ubud is a town on the Indonesian island of Bali in Ubud District, located amongst rice paddies and steep ravines in the central foothills of the Gianyar regency. Promoted as an arts and culture centre, it has developed a large tourism industry.",
|
6 |
+
"Jakarta is the capital of Indonesia"
|
7 |
+
],
|
8 |
+
"People": [
|
9 |
+
"Albert Einstein was a German-born theoretical physicist, widely acknowledged to be one of the greatest physicists of all time. Einstein is known for developing the theory of relativity, but he also made important contributions to the development of the theory of quantum mechanics.",
|
10 |
+
"Geoffrey Everest Hinton is a British-Canadian cognitive psychologist and computer scientist, most noted for his work on artificial neural networks.",
|
11 |
+
"Pramoedya Ananta Toer was an Indonesian author of novels, short stories, essays, polemics and histories of his homeland and its people."
|
12 |
+
],
|
13 |
+
"Building": [
|
14 |
+
"Borobudur is a 7th-century Mahayana Buddhist temple in Indonesia. It is the world's largest Buddhist temple. The temple consists of nine stacked platforms, six square and three circular, topped by a central dome. It is decorated with 2,672 relief panels and 504 Buddha statues.",
|
15 |
+
"The Statue of Liberty is a colossal neoclassical sculpture on Liberty Island in New York Harbor within New York City, in the United States.",
|
16 |
+
"Machu Picchu is a 15th-century Inca citadel, located in the Eastern Cordillera of southern Peru, on a 2,430-meter mountain ridge."
|
17 |
+
]
|
18 |
+
}
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
torch
|
3 |
+
tensorflow
|
4 |
+
transformers
|
5 |
+
tensorflow
|
6 |
+
mtranslate
|
7 |
+
sentence-transformers
|
8 |
+
datasets
|
9 |
+
faiss-cpu
|
10 |
+
# streamlit version 0.67.1 is needed due to issue with caching
|
11 |
+
streamlit==0.67.1
|
wit_index.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pickle
|
2 |
+
|
3 |
+
# Used to create the dense document vectors.
|
4 |
+
import torch
|
5 |
+
from sentence_transformers import SentenceTransformer
|
6 |
+
import datasets
|
7 |
+
|
8 |
+
# Used to create and store the Faiss index.
|
9 |
+
import faiss
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
class WitIndex:
|
13 |
+
"""
|
14 |
+
WitIndex is a class to search the wiki snippets from the given text. It can also return link to the
|
15 |
+
wiki page or the image.
|
16 |
+
"""
|
17 |
+
wit_dataset = None
|
18 |
+
|
19 |
+
def __init__(self, wit_index_path: str, model_name: str, wit_dataset_path: str, gpu=True):
|
20 |
+
self.index = faiss.read_index(wit_index_path)
|
21 |
+
self.model = SentenceTransformer(model_name)
|
22 |
+
if WitIndex.wit_dataset is None:
|
23 |
+
WitIndex.wit_dataset = pickle.load(open(wit_dataset_path, "rb"))
|
24 |
+
print(f"Gpu: {gpu}")
|
25 |
+
if gpu and torch.cuda.is_available():
|
26 |
+
print("Cuda is available")
|
27 |
+
self.model = self.model.to(torch.device("cuda"))
|
28 |
+
res = faiss.StandardGpuResources()
|
29 |
+
self.index = faiss.index_cpu_to_gpu(res, 0, self.index)
|
30 |
+
|
31 |
+
def search(self, text, top_k=6):
|
32 |
+
print(f"> Search: {text}")
|
33 |
+
embedding = self.model.encode(text, convert_to_numpy=True, show_progress_bar=False)
|
34 |
+
# Retrieve the k nearest neighbours
|
35 |
+
distance, index = self.index.search(np.array([embedding]), k=top_k)
|
36 |
+
distance, index = distance.flatten().tolist(), index.flatten().tolist()
|
37 |
+
index_url = [WitIndex.wit_dataset['desc2image_map'][i] for i in index]
|
38 |
+
image_info = [WitIndex.wit_dataset['image_info'][i] for i in index_url]
|
39 |
+
print(f"> URL: {image_info[0]}")
|
40 |
+
return distance, index, image_info
|