awacke1 commited on
Commit
5daa26e
1 Parent(s): 833b736

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -0
app.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html import escape
2
+ import re
3
+ import streamlit as st
4
+ import pandas as pd, numpy as np
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ from st_clickable_images import clickable_images
7
+
8
+
9
+ @st.cache(
10
+ show_spinner=False,
11
+ hash_funcs={
12
+ CLIPModel: lambda _: None,
13
+ CLIPProcessor: lambda _: None,
14
+ dict: lambda _: None,
15
+ },
16
+ )
17
+ def load():
18
+ model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
19
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
20
+ df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
21
+ embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
22
+ for k in [0, 1]:
23
+ embeddings[k] = embeddings[k] / np.linalg.norm(
24
+ embeddings[k], axis=1, keepdims=True
25
+ )
26
+ return model, processor, df, embeddings
27
+
28
+
29
+ model, processor, df, embeddings = load()
30
+ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
31
+
32
+
33
+ def compute_text_embeddings(list_of_strings):
34
+ inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
35
+ result = model.get_text_features(**inputs).detach().numpy()
36
+ return result / np.linalg.norm(result, axis=1, keepdims=True)
37
+
38
+
39
+ def image_search(query, corpus, n_results=24):
40
+ positive_embeddings = None
41
+
42
+ def concatenate_embeddings(e1, e2):
43
+ if e1 is None:
44
+ return e2
45
+ else:
46
+ return np.concatenate((e1, e2), axis=0)
47
+
48
+ splitted_query = query.split("EXCLUDING ")
49
+ dot_product = 0
50
+ k = 0 if corpus == "Unsplash" else 1
51
+ if len(splitted_query[0]) > 0:
52
+ positive_queries = splitted_query[0].split(";")
53
+ for positive_query in positive_queries:
54
+ match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
55
+ if match:
56
+ corpus2, idx, remainder = match.groups()
57
+ idx, remainder = int(idx), remainder.strip()
58
+ k2 = 0 if corpus2 == "Unsplash" else 1
59
+ positive_embeddings = concatenate_embeddings(
60
+ positive_embeddings, embeddings[k2][idx : idx + 1, :]
61
+ )
62
+ if len(remainder) > 0:
63
+ positive_embeddings = concatenate_embeddings(
64
+ positive_embeddings, compute_text_embeddings([remainder])
65
+ )
66
+ else:
67
+ positive_embeddings = concatenate_embeddings(
68
+ positive_embeddings, compute_text_embeddings([positive_query])
69
+ )
70
+ dot_product = embeddings[k] @ positive_embeddings.T
71
+ dot_product = dot_product - np.median(dot_product, axis=0)
72
+ dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
73
+ dot_product = np.min(dot_product, axis=1)
74
+
75
+ if len(splitted_query) > 1:
76
+ negative_queries = (" ".join(splitted_query[1:])).split(";")
77
+ negative_embeddings = compute_text_embeddings(negative_queries)
78
+ dot_product2 = embeddings[k] @ negative_embeddings.T
79
+ dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
80
+ dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
81
+ dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
82
+
83
+ results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
84
+ return [
85
+ (
86
+ df[k].iloc[i]["path"],
87
+ df[k].iloc[i]["tooltip"] + source[k],
88
+ i,
89
+ )
90
+ for i in results
91
+ ]
92
+
93
+
94
+ description = """
95
+ # Semantic image search
96
+ **Enter your query and hit enter**
97
+ *Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
98
+ *Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
99
+ """
100
+
101
+ howto = """
102
+ - Click on an image to use it as a query and find similar images
103
+ - Several queries, including one based on an image, can be combined (use "**;**" as a separator)
104
+ - If the input includes "**EXCLUDING**", the part right of it will be used as a negative query
105
+ """
106
+
107
+
108
+ def main():
109
+ st.markdown(
110
+ """
111
+ <style>
112
+ .block-container{
113
+ max-width: 1200px;
114
+ }
115
+ div.row-widget.stRadio > div{
116
+ flex-direction:row;
117
+ display: flex;
118
+ justify-content: center;
119
+ }
120
+ div.row-widget.stRadio > div > label{
121
+ margin-left: 5px;
122
+ margin-right: 5px;
123
+ }
124
+ section.main>div:first-child {
125
+ padding-top: 0px;
126
+ }
127
+ section:not(.main)>div:first-child {
128
+ padding-top: 30px;
129
+ }
130
+ div.reportview-container > section:first-child{
131
+ max-width: 320px;
132
+ }
133
+ #MainMenu {
134
+ visibility: hidden;
135
+ }
136
+ footer {
137
+ visibility: hidden;
138
+ }
139
+ </style>""",
140
+ unsafe_allow_html=True,
141
+ )
142
+ st.sidebar.markdown(description)
143
+ with st.sidebar.expander("Advanced use"):
144
+ st.markdown(howto)
145
+
146
+ _, c, _ = st.columns((1, 3, 1))
147
+ if "query" in st.session_state:
148
+ query = c.text_input("", value=st.session_state["query"])
149
+ else:
150
+ query = c.text_input("", value="clouds at sunset")
151
+ corpus = st.radio("", ["Unsplash", "Movies"])
152
+ if len(query) > 0:
153
+ results = image_search(query, corpus)
154
+ clicked = clickable_images(
155
+ [result[0] for result in results],
156
+ titles=[result[1] for result in results],
157
+ div_style={
158
+ "display": "flex",
159
+ "justify-content": "center",
160
+ "flex-wrap": "wrap",
161
+ },
162
+ img_style={"margin": "2px", "height": "200px"},
163
+ )
164
+ if clicked >= 0:
165
+ change_query = False
166
+ if "last_clicked" not in st.session_state:
167
+ change_query = True
168
+ else:
169
+ if clicked != st.session_state["last_clicked"]:
170
+ change_query = True
171
+ if change_query:
172
+ st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
173
+ st.experimental_rerun()
174
+
175
+
176
+ if __name__ == "__main__":
177
+ main()