Vivien commited on
Commit
563e3ef
1 Parent(s): 8d4b675

Initial commit

Browse files
README.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
- title: Clip Owlvit
3
- emoji: 🐨
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: streamlit
1
  ---
2
+ title: Search and Detect (CLIP/Owl-ViT)
3
+ emoji: 🦉
4
  colorFrom: indigo
5
  colorTo: red
6
  sdk: streamlit
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from html import escape
2
+ import requests
3
+ from io import BytesIO
4
+ import base64
5
+ from multiprocessing.dummy import Pool
6
+ from PIL import Image, ImageDraw
7
+ import streamlit as st
8
+ import pandas as pd, numpy as np
9
+ import torch
10
+ from transformers import CLIPProcessor, CLIPModel
11
+ from transformers import OwlViTProcessor, OwlViTForObjectDetection
12
+ from transformers.image_utils import ImageFeatureExtractionMixin
13
+ import tokenizers
14
+
15
+ DEBUG = True
16
+ if DEBUG:
17
+ MODEL = "vit-base-patch32"
18
+ OWL_MODEL = f"google/owlvit-base-patch32"
19
+ else:
20
+ MODEL = "vit-large-patch14-336"
21
+ OWL_MODEL = f"google/owlvit-large-path14"
22
+ CLIP_MODEL = f"openai/clip-{MODEL}"
23
+
24
+ if not DEBUG and torch.cuda.is_available():
25
+ device = torch.device("cuda")
26
+ else:
27
+ device = torch.device("cpu")
28
+
29
+ HEIGHT = 200
30
+ N_RESULTS = 6
31
+
32
+ color = st.get_option("theme.primaryColor")
33
+ if color is None:
34
+ color = (255, 75, 75)
35
+ else:
36
+ color = tuple(int(color.lstrip("#")[i : i + 2], 16) for i in (0, 2, 4))
37
+
38
+
39
+ @st.cache(allow_output_mutation=True)
40
+ def load():
41
+ df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
42
+ clip_model = CLIPModel.from_pretrained(CLIP_MODEL)
43
+ clip_model.to(device)
44
+ clip_model.eval()
45
+ clip_processor = CLIPProcessor.from_pretrained(CLIP_MODEL)
46
+ owl_model = OwlViTForObjectDetection.from_pretrained(OWL_MODEL)
47
+ owl_model.to(device)
48
+ owl_model.eval()
49
+ owl_processor = OwlViTProcessor.from_pretrained(OWL_MODEL)
50
+ embeddings = {
51
+ 0: np.load(f"embeddings-{MODEL}.npy"),
52
+ 1: np.load(f"embeddings2-{MODEL}.npy"),
53
+ }
54
+ for k in [0, 1]:
55
+ embeddings[k] = embeddings[k] / np.linalg.norm(
56
+ embeddings[k], axis=1, keepdims=True
57
+ )
58
+ return clip_model, clip_processor, owl_model, owl_processor, df, embeddings
59
+
60
+
61
+ clip_model, clip_processor, owl_model, owl_processor, df, embeddings = load()
62
+ mixin = ImageFeatureExtractionMixin()
63
+ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
64
+
65
+
66
+ def compute_text_embeddings(list_of_strings):
67
+ inputs = clip_processor(text=list_of_strings, return_tensors="pt", padding=True).to(
68
+ device
69
+ )
70
+ with torch.no_grad():
71
+ result = clip_model.get_text_features(**inputs).detach().cpu().numpy()
72
+ return result / np.linalg.norm(result, axis=1, keepdims=True)
73
+
74
+
75
+ def image_search(query, corpus, n_results=N_RESULTS):
76
+ query_embedding = compute_text_embeddings([query])
77
+ corpus_id = 0 if corpus == "Unsplash" else 1
78
+ dot_product = (embeddings[corpus_id] @ query_embedding.T)[:, 0]
79
+ results = np.argsort(dot_product)[-1 : -n_results - 1 : -1]
80
+ return [
81
+ (
82
+ df[corpus_id].iloc[i].path,
83
+ df[corpus_id].iloc[i].tooltip + source[corpus_id],
84
+ df[corpus_id].iloc[i].link,
85
+ )
86
+ for i in results
87
+ ]
88
+
89
+
90
+ def make_square(img, fill_color=(255, 255, 255)):
91
+ x, y = img.size
92
+ size = max(x, y)
93
+ new_img = Image.new("RGB", (size, size), fill_color)
94
+ new_img.paste(img, (int((size - x) / 2), int((size - y) / 2)))
95
+ return new_img, x, y
96
+
97
+
98
+ @st.cache(allow_output_mutation=True, show_spinner=False)
99
+ def get_images(paths):
100
+ def process_image(path):
101
+ return make_square(Image.open(BytesIO(requests.get(path).content)))
102
+
103
+ processed = Pool(N_RESULTS).map(process_image, paths)
104
+ imgs, xs, ys = [], [], []
105
+ for img, x, y in processed:
106
+ imgs.append(img)
107
+ xs.append(x)
108
+ ys.append(y)
109
+ return imgs, xs, ys
110
+
111
+
112
+ @st.cache(
113
+ hash_funcs={
114
+ tokenizers.Tokenizer: lambda x: None,
115
+ tokenizers.AddedToken: lambda x: None,
116
+ torch.nn.parameter.Parameter: lambda x: None,
117
+ },
118
+ allow_output_mutation=True,
119
+ show_spinner=False,
120
+ )
121
+ def apply_owl_model(owl_queries, images):
122
+ inputs = owl_processor(text=owl_queries, images=images, return_tensors="pt").to(
123
+ device
124
+ )
125
+ with torch.no_grad():
126
+ results = owl_model(**inputs)
127
+ target_sizes = torch.Tensor([img.size[::-1] for img in images]).to(device)
128
+ return owl_processor.post_process(outputs=results, target_sizes=target_sizes)
129
+
130
+
131
+ def keep_best_boxes(boxes, scores, score_threshold=0.1, max_iou=0.8):
132
+ candidates = []
133
+ for box, score in zip(boxes, scores):
134
+ box = [round(i, 0) for i in box.tolist()]
135
+ if score >= score_threshold:
136
+ candidates.append((box, float(score)))
137
+
138
+ to_ignore = set()
139
+ for i in range(len(candidates) - 1):
140
+ if i in to_ignore:
141
+ continue
142
+ for j in range(i + 1, len(candidates)):
143
+ if j in to_ignore:
144
+ continue
145
+ xmin1, ymin1, xmax1, ymax1 = candidates[i][0]
146
+ xmin2, ymin2, xmax2, ymax2 = candidates[j][0]
147
+ if xmax1 < xmin2 or xmax2 < xmin1 or ymax1 < ymin2 or ymax2 < ymin1:
148
+ continue
149
+ else:
150
+ xmin_inter, xmax_inter = sorted([xmin1, xmax1, xmin2, xmax2])[1:3]
151
+ ymin_inter, ymax_inter = sorted([ymin1, ymax1, ymin2, ymax2])[1:3]
152
+ area_inter = (xmax_inter - xmin_inter) * (ymax_inter - ymin_inter)
153
+ area1 = (xmax1 - xmin1) * (ymax1 - ymin1)
154
+ area2 = (xmax2 - xmin2) * (ymax2 - ymin2)
155
+ iou = area_inter / (area1 + area2 - area_inter)
156
+ if iou > max_iou:
157
+ if candidates[i][1] > candidates[j][1]:
158
+ to_ignore.add(j)
159
+ else:
160
+ to_ignore.add(i)
161
+ break
162
+ else:
163
+ if area_inter / area1 > 0.9:
164
+ if candidates[i][1] < 1.1 * candidates[j][1]:
165
+ to_ignore.add(i)
166
+ if area_inter / area2 > 0.9:
167
+ if 1.1 * candidates[i][1] > candidates[j][1]:
168
+ to_ignore.add(j)
169
+ return [candidates[i][0] for i in range(len(candidates)) if i not in to_ignore]
170
+
171
+
172
+ def convert_pil_to_base64(image):
173
+ img_buffer = BytesIO()
174
+ image.save(img_buffer, format="JPEG")
175
+ byte_data = img_buffer.getvalue()
176
+ base64_str = base64.b64encode(byte_data)
177
+ return base64_str
178
+
179
+
180
+ def draw_reshape_encode(img, boxes, x, y):
181
+ image = img.copy()
182
+ draw = ImageDraw.Draw(image)
183
+ new_x, new_y = int(x * HEIGHT / y), HEIGHT
184
+ for box in boxes:
185
+ draw.rectangle(
186
+ (tuple(box[:2]), tuple(box[2:])), outline=color, width=2 * int(y / HEIGHT)
187
+ )
188
+ if x > y:
189
+ image = image.crop((0, (x - y) / 2, x, x - (x - y) / 2))
190
+ else:
191
+ image = image.crop(((y - x) / 2, 0, y - (y - x) / 2, y))
192
+ return convert_pil_to_base64(image.resize((new_x, new_y)))
193
+
194
+
195
+ def get_html(url_list, encoded_images):
196
+ html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
197
+ for i in range(len(url_list)):
198
+ title, link, encoded = url_list[i][1], url_list[i][2], encoded_images[i]
199
+ html2 = f"<img title='{escape(title)}' style='height: {HEIGHT}px; margin: 5px' src='data:image/jpeg;base64,{encoded.decode()}'>"
200
+ if len(link) > 0:
201
+ html2 = f"<a href='{escape(link)}' target='_blank'>" + html2 + "</a>"
202
+ html = html + html2
203
+ html += "</div>"
204
+ return html
205
+
206
+
207
+ description = """
208
+ # Search and Detect
209
+
210
+ This demo illustrates how you can both retrieve images containing certain objects and locate these objects with a simple natural language query.
211
+
212
+ **Enter your query and hit enter**
213
+
214
+ **Tip 1**: if your query includes "/", the part left (resp. right) of "/" will be used to retrieve images (resp. locate objects). For example, if you want to retrieve pictures with several cats but locate individual cats, you can type "cats / cat".
215
+
216
+ **Tip 2**: change the score threshold below to adjust the sensitivity of the object detection.
217
+
218
+ *Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model and Google's [Owl-ViT](https://arxiv.org/abs/2205.06230) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
219
+
220
+ """
221
+
222
+ div_style = {
223
+ "display": "flex",
224
+ "justify-content": "center",
225
+ "flex-wrap": "wrap",
226
+ }
227
+
228
+
229
+ def main():
230
+ st.markdown(
231
+ """
232
+ <style>
233
+ .block-container{
234
+ max-width: 1200px;
235
+ }
236
+ div.row-widget.stRadio > div{
237
+ flex-direction:row;
238
+ display: flex;
239
+ justify-content: center;
240
+ }
241
+ div.row-widget.stRadio > div > label{
242
+ margin-left: 5px;
243
+ margin-right: 5px;
244
+ }
245
+ .row-widget {
246
+ margin-top: -25px;
247
+ }
248
+ section>div:first-child {
249
+ padding-top: 30px;
250
+ }
251
+ div.reportview-container > section:first-child{
252
+ max-width: 320px;
253
+ }
254
+ #MainMenu {
255
+ visibility: hidden;
256
+ }
257
+ footer {
258
+ visibility: hidden;
259
+ }
260
+ </style>""",
261
+ unsafe_allow_html=True,
262
+ )
263
+ st.sidebar.markdown(description)
264
+ score_threshold = st.sidebar.slider(
265
+ "Score threshold", min_value=0.01, max_value=0.3, value=0.1, step=0.01
266
+ )
267
+
268
+ _, c, _ = st.columns((1, 3, 1))
269
+ query = c.text_input("", value="clouds at sunset")
270
+ corpus = st.radio("", ["Unsplash", "Movies"])
271
+
272
+ if len(query) > 0:
273
+ if "/" in query:
274
+ queries = query.split("/")
275
+ clip_query, owl_query = ("/").join(queries[:-1]), queries[-1]
276
+ else:
277
+ clip_query, owl_query = query, query
278
+ retrieved = image_search(clip_query, corpus)
279
+ imgs, xs, ys = get_images([x[0] for x in retrieved])
280
+ results = apply_owl_model([[owl_query]] * len(imgs), imgs)
281
+ encoded_images = []
282
+ for image_idx in range(len(imgs)):
283
+ img0, x, y = imgs[image_idx], xs[image_idx], ys[image_idx]
284
+ boxes = keep_best_boxes(
285
+ results[image_idx]["boxes"],
286
+ results[image_idx]["scores"],
287
+ score_threshold=score_threshold,
288
+ )
289
+ encoded_images.append(draw_reshape_encode(img0, boxes, x, y))
290
+ st.markdown(get_html(retrieved, encoded_images), unsafe_allow_html=True)
291
+
292
+
293
+ if __name__ == "__main__":
294
+ main()
data.csv ADDED
The diff for this file is too large to render. See raw diff
data2.csv ADDED
The diff for this file is too large to render. See raw diff
embeddings-vit-base-patch32.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3f7ebdff24079665faf58d07045056a63b5499753e3ffbda479691d53de3ab38
3
+ size 51200128
embeddings-vit-large-patch14-336.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f79f10ebe267b4ee7acd553dfe0ee31df846123630058a6d58c04bf22e0ad068
3
+ size 76800128
embeddings2-vit-base-patch32.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e7d545bed86121dac1cedcc1de61ea5295f5840c1eb751637e6628ac54faef81
3
+ size 16732288
embeddings2-vit-large-patch14-336.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e66eb377465fbfaa56cec079aa3e214533ceac43646f2ca78028ae4d8ad6d03
3
+ size 25098368
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ tokenizers
4
+ Pillow
5
+ ftfy
6
+ numpy
7
+ pandas