awacke1 commited on
Commit
91ba0dc
1 Parent(s): 8f848d2

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.txt +13 -0
  2. app.py +186 -0
  3. data.csv +0 -0
  4. data2.csv +0 -0
  5. embeddings.npy +3 -0
  6. embeddings2.npy +3 -0
  7. requirements.txt +5 -0
README.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: 🔎NLP Image Semantic Search SL🖼️
3
+ emoji: 🔎🖼️
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: streamlit
7
+ sdk_version: 1.2.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html import escape
2
+ import re
3
+ import streamlit as st
4
+ import pandas as pd, numpy as np
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ from st_clickable_images import clickable_images
7
+
8
+ @st.cache(
9
+ show_spinner=False,
10
+ hash_funcs={
11
+ CLIPModel: lambda _: None,
12
+ CLIPProcessor: lambda _: None,
13
+ dict: lambda _: None,
14
+ },
15
+ )
16
+ def load():
17
+ model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
18
+ processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
19
+ df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
20
+ embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
21
+ for k in [0, 1]:
22
+ embeddings[k] = embeddings[k] / np.linalg.norm(
23
+ embeddings[k], axis=1, keepdims=True
24
+ )
25
+ return model, processor, df, embeddings
26
+
27
+
28
+ model, processor, df, embeddings = load()
29
+ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
30
+
31
+
32
+ def compute_text_embeddings(list_of_strings):
33
+ inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
34
+ result = model.get_text_features(**inputs).detach().numpy()
35
+ return result / np.linalg.norm(result, axis=1, keepdims=True)
36
+
37
+
38
+ def image_search(query, corpus, n_results=24):
39
+ positive_embeddings = None
40
+
41
+ def concatenate_embeddings(e1, e2):
42
+ if e1 is None:
43
+ return e2
44
+ else:
45
+ return np.concatenate((e1, e2), axis=0)
46
+
47
+ splitted_query = query.split("EXCLUDING ")
48
+ dot_product = 0
49
+ k = 0 if corpus == "Unsplash" else 1
50
+ if len(splitted_query[0]) > 0:
51
+ positive_queries = splitted_query[0].split(";")
52
+ for positive_query in positive_queries:
53
+ match = re.match(r"\[(Movies|Unsplash):(\d{1,5})\](.*)", positive_query)
54
+ if match:
55
+ corpus2, idx, remainder = match.groups()
56
+ idx, remainder = int(idx), remainder.strip()
57
+ k2 = 0 if corpus2 == "Unsplash" else 1
58
+ positive_embeddings = concatenate_embeddings(
59
+ positive_embeddings, embeddings[k2][idx : idx + 1, :]
60
+ )
61
+ if len(remainder) > 0:
62
+ positive_embeddings = concatenate_embeddings(
63
+ positive_embeddings, compute_text_embeddings([remainder])
64
+ )
65
+ else:
66
+ positive_embeddings = concatenate_embeddings(
67
+ positive_embeddings, compute_text_embeddings([positive_query])
68
+ )
69
+ dot_product = embeddings[k] @ positive_embeddings.T
70
+ dot_product = dot_product - np.median(dot_product, axis=0)
71
+ dot_product = dot_product / np.max(dot_product, axis=0, keepdims=True)
72
+ dot_product = np.min(dot_product, axis=1)
73
+
74
+ if len(splitted_query) > 1:
75
+ negative_queries = (" ".join(splitted_query[1:])).split(";")
76
+ negative_embeddings = compute_text_embeddings(negative_queries)
77
+ dot_product2 = embeddings[k] @ negative_embeddings.T
78
+ dot_product2 = dot_product2 - np.median(dot_product2, axis=0)
79
+ dot_product2 = dot_product2 / np.max(dot_product2, axis=0, keepdims=True)
80
+ dot_product -= np.max(np.maximum(dot_product2, 0), axis=1)
81
+
82
+ results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
83
+ return [
84
+ (
85
+ df[k].iloc[i]["path"],
86
+ df[k].iloc[i]["tooltip"] + source[k],
87
+ i,
88
+ )
89
+ for i in results
90
+ ]
91
+
92
+
93
+ description = """
94
+ # Semantic image search
95
+ **Enter your query and hit enter**
96
+ """
97
+
98
+ howto = """
99
+ - Click image to find similar images
100
+ - Use "**;**" to combine multiple queries)
101
+ - Use "**EXCLUDING**", to exclude a query
102
+ """
103
+
104
+
105
+ def main():
106
+ st.markdown(
107
+ """
108
+ <style>
109
+ .block-container{
110
+ max-width: 1200px;
111
+ }
112
+ div.row-widget.stRadio > div{
113
+ flex-direction:row;
114
+ display: flex;
115
+ justify-content: center;
116
+ }
117
+ div.row-widget.stRadio > div > label{
118
+ margin-left: 5px;
119
+ margin-right: 5px;
120
+ }
121
+ section.main>div:first-child {
122
+ padding-top: 0px;
123
+ }
124
+ section:not(.main)>div:first-child {
125
+ padding-top: 30px;
126
+ }
127
+ div.reportview-container > section:first-child{
128
+ max-width: 320px;
129
+ }
130
+ #MainMenu {
131
+ visibility: hidden;
132
+ }
133
+ footer {
134
+ visibility: hidden;
135
+ }
136
+ </style>""",
137
+ unsafe_allow_html=True,
138
+ )
139
+ st.sidebar.markdown(description)
140
+ with st.sidebar.expander("Advanced use"):
141
+ st.markdown(howto)
142
+
143
+
144
+ st.sidebar.markdown(f"Try these test prompts: orange, blue, beach, lighthouse, mountain, sunset, parade")
145
+ st.sidebar.markdown(f"Unsplash has categories that match: backgrounds, photos, nature, iphone, etc")
146
+ st.sidebar.markdown(f"Unsplash images contain animals, apps, events, feelings, food, travel, nature, people, religion, sports, things, stock")
147
+ st.sidebar.markdown(f"Unsplash things include flag, tree, clock, money, tattoo, arrow, book, car, fireworks, ghost, health, kiss, dance, balloon, crown, eye, house, music, airplane, lighthouse, typewriter, toys")
148
+ st.sidebar.markdown(f"unsplash feelings include funny, heart, love, cool, congratulations, love, scary, cute, friendship, inspirational, hug, sad, cursed, beautiful, crazy, respect, transformation, peaceful, happy")
149
+ st.sidebar.markdown(f"unsplash people contain baby, life, women, family, girls, pregnancy, society, old people, musician, attractive, bohemian")
150
+ st.sidebar.markdown(f"imagenet queries include: photo of, photo of many, sculpture of, rendering of, graffiti of, tattoo of, embroidered, drawing of, plastic, black and white, painting, video game, doodle, origami, sketch, etc")
151
+
152
+
153
+ _, c, _ = st.columns((1, 3, 1))
154
+ if "query" in st.session_state:
155
+ query = c.text_input("", value=st.session_state["query"])
156
+ else:
157
+
158
+ query = c.text_input("", value="lighthouse")
159
+ corpus = st.radio("", ["Unsplash"])
160
+ #corpus = st.radio("", ["Unsplash", "Movies"])
161
+ if len(query) > 0:
162
+ results = image_search(query, corpus)
163
+ clicked = clickable_images(
164
+ [result[0] for result in results],
165
+ titles=[result[1] for result in results],
166
+ div_style={
167
+ "display": "flex",
168
+ "justify-content": "center",
169
+ "flex-wrap": "wrap",
170
+ },
171
+ img_style={"margin": "2px", "height": "200px"},
172
+ )
173
+ if clicked >= 0:
174
+ change_query = False
175
+ if "last_clicked" not in st.session_state:
176
+ change_query = True
177
+ else:
178
+ if clicked != st.session_state["last_clicked"]:
179
+ change_query = True
180
+ if change_query:
181
+ st.session_state["query"] = f"[{corpus}:{results[clicked][2]}]"
182
+ st.experimental_rerun()
183
+
184
+
185
+ if __name__ == "__main__":
186
+ main()
data.csv ADDED
The diff for this file is too large to render. See raw diff
 
data2.csv ADDED
The diff for this file is too large to render. See raw diff
 
embeddings.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64515f7d3d71137e2944f2c3d72c8df3e684b5d6a6ff7dcebb92370f7326ccfd
3
+ size 76800128
embeddings2.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d730b33e758c2648419a96ac86d39516c59795e613c35700d3a64079e5a9a27
3
+ size 25098368
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ numpy
4
+ pandas
5
+ st-clickable-images