Vivien committed
Commit 0779f15
1 Parent(s): ee0cebf
Add eval and torch.no_grad (because inference only)
app.py CHANGED
@@ -2,13 +2,14 @@ from html import escape
 import re
 import streamlit as st
 import pandas as pd, numpy as np
+import torch
 from transformers import CLIPProcessor, CLIPModel
 from st_clickable_images import clickable_images
 
 MODEL_NAMES = [
-    # "base-patch32",
-    # "base-patch16",
-    # "large-patch14",
+    # "base-patch32",
+    # "base-patch16",
+    # "large-patch14",
     "large-patch14-336"
 ]
 
@@ -20,7 +21,7 @@ def load():
     processors = {}
     embeddings = {}
     for name in MODEL_NAMES:
-        models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}")
+        models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}").eval()
         processors[name] = CLIPProcessor.from_pretrained(f"openai/clip-vit-{name}")
         embeddings[name] = {
             0: np.load(f"embeddings-vit-{name}.npy"),
@@ -39,7 +40,8 @@ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
 
 def compute_text_embeddings(list_of_strings, name):
     inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
-    result = models[name].get_text_features(**inputs).detach().numpy()
+    with torch.no_grad():
+        result = models[name].get_text_features(**inputs).detach().numpy()
     return result / np.linalg.norm(result, axis=1, keepdims=True)
 
 
@@ -158,9 +160,9 @@ def main():
     st.sidebar.markdown(description)
     with st.sidebar.expander("Advanced use"):
         st.markdown(howto)
-    #mode = st.sidebar.selectbox(
+    # mode = st.sidebar.selectbox(
     #    "", ["Results for ViT-L/14@336px", "Comparison of 2 models"], index=0
-    #)
+    # )
 
     _, c, _ = st.columns((1, 3, 1))
     if "query" in st.session_state:
@@ -176,7 +178,7 @@ def main():
         "ViT-L/14@336px (slower)": "large-patch14-336",
     }
 
-    if False
+    if False:  # "Comparison" in mode:
         c1, c2 = st.columns((1, 1))
         selection1 = c1.selectbox("", models_dict.keys(), index=0)
         selection2 = c2.selectbox("", models_dict.keys(), index=2)
@@ -187,7 +189,7 @@ def main():
 
     if len(query) > 0:
         results1 = image_search(query, corpus, name1)
-        if False
+        if False:  # "Comparison" in mode:
             with c1:
                 clicked1 = clickable_images(
                     [result[0] for result in results1],
@@ -225,7 +227,7 @@ def main():
         if change_query:
             if clicked1 >= 0:
                 st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
-            #elif clicked2 >= 0:
+            # elif clicked2 >= 0:
             #     st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
             st.experimental_rerun()
 
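For reference, the change boils down to the standard PyTorch inference recipe: call .eval() on the model once at load time and run the forward pass under torch.no_grad() so no autograd graph is built. Below is a minimal, self-contained sketch of that recipe, not the app's exact code; the openai/clip-vit-base-patch32 checkpoint and the sample query are used purely for illustration (the Space itself loads large-patch14-336 and keeps models and processors in dicts keyed by name).

import numpy as np
import torch
from transformers import CLIPModel, CLIPProcessor

# Illustrative checkpoint; the Space uses "openai/clip-vit-large-patch14-336".
CKPT = "openai/clip-vit-base-patch32"

# .eval() switches the model to inference mode (disables dropout etc.).
model = CLIPModel.from_pretrained(CKPT).eval()
processor = CLIPProcessor.from_pretrained(CKPT)

def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # no_grad() skips gradient tracking: less memory, slightly faster inference.
    with torch.no_grad():
        result = model.get_text_features(**inputs).numpy()
    # L2-normalize so dot products against image embeddings are cosine similarities.
    return result / np.linalg.norm(result, axis=1, keepdims=True)

print(compute_text_embeddings(["clouds at sunset"]).shape)  # (1, 512) for base-patch32

Inside a torch.no_grad() block the output tensor already has requires_grad=False, so the .detach() kept in the diff is redundant but harmless.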