Vivien committed on
Commit
0779f15
·
1 Parent(s): ee0cebf

Add eval and torch.no_grad (because inference only)

Browse files
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -2,13 +2,14 @@ from html import escape
2
  import re
3
  import streamlit as st
4
  import pandas as pd, numpy as np
 
5
  from transformers import CLIPProcessor, CLIPModel
6
  from st_clickable_images import clickable_images
7
 
8
  MODEL_NAMES = [
9
- # "base-patch32",
10
- # "base-patch16",
11
- # "large-patch14",
12
  "large-patch14-336"
13
  ]
14
 
@@ -20,7 +21,7 @@ def load():
20
  processors = {}
21
  embeddings = {}
22
  for name in MODEL_NAMES:
23
- models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}")
24
  processors[name] = CLIPProcessor.from_pretrained(f"openai/clip-vit-{name}")
25
  embeddings[name] = {
26
  0: np.load(f"embeddings-vit-{name}.npy"),
@@ -39,7 +40,8 @@ source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
39
 
40
  def compute_text_embeddings(list_of_strings, name):
41
  inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
42
- result = models[name].get_text_features(**inputs).detach().numpy()
 
43
  return result / np.linalg.norm(result, axis=1, keepdims=True)
44
 
45
 
@@ -158,9 +160,9 @@ def main():
158
  st.sidebar.markdown(description)
159
  with st.sidebar.expander("Advanced use"):
160
  st.markdown(howto)
161
- #mode = st.sidebar.selectbox(
162
  # "", ["Results for ViT-L/14@336px", "Comparison of 2 models"], index=0
163
- #)
164
 
165
  _, c, _ = st.columns((1, 3, 1))
166
  if "query" in st.session_state:
@@ -176,7 +178,7 @@ def main():
176
  "ViT-L/14@336px (slower)": "large-patch14-336",
177
  }
178
 
179
- if False:#"Comparison" in mode:
180
  c1, c2 = st.columns((1, 1))
181
  selection1 = c1.selectbox("", models_dict.keys(), index=0)
182
  selection2 = c2.selectbox("", models_dict.keys(), index=2)
@@ -187,7 +189,7 @@ def main():
187
 
188
  if len(query) > 0:
189
  results1 = image_search(query, corpus, name1)
190
- if False:#"Comparison" in mode:
191
  with c1:
192
  clicked1 = clickable_images(
193
  [result[0] for result in results1],
@@ -225,7 +227,7 @@ def main():
225
  if change_query:
226
  if clicked1 >= 0:
227
  st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
228
- #elif clicked2 >= 0:
229
  # st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
230
  st.experimental_rerun()
231
 
 
2
  import re
3
  import streamlit as st
4
  import pandas as pd, numpy as np
5
+ import torch
6
  from transformers import CLIPProcessor, CLIPModel
7
  from st_clickable_images import clickable_images
8
 
9
  MODEL_NAMES = [
10
+ # "base-patch32",
11
+ # "base-patch16",
12
+ # "large-patch14",
13
  "large-patch14-336"
14
  ]
15
 
 
21
  processors = {}
22
  embeddings = {}
23
  for name in MODEL_NAMES:
24
+ models[name] = CLIPModel.from_pretrained(f"openai/clip-vit-{name}").eval()
25
  processors[name] = CLIPProcessor.from_pretrained(f"openai/clip-vit-{name}")
26
  embeddings[name] = {
27
  0: np.load(f"embeddings-vit-{name}.npy"),
 
40
 
41
  def compute_text_embeddings(list_of_strings, name):
42
  inputs = processors[name](text=list_of_strings, return_tensors="pt", padding=True)
43
+ with torch.no_grad():
44
+ result = models[name].get_text_features(**inputs).detach().numpy()
45
  return result / np.linalg.norm(result, axis=1, keepdims=True)
46
 
47
 
 
160
  st.sidebar.markdown(description)
161
  with st.sidebar.expander("Advanced use"):
162
  st.markdown(howto)
163
+ # mode = st.sidebar.selectbox(
164
  # "", ["Results for ViT-L/14@336px", "Comparison of 2 models"], index=0
165
+ # )
166
 
167
  _, c, _ = st.columns((1, 3, 1))
168
  if "query" in st.session_state:
 
178
  "ViT-L/14@336px (slower)": "large-patch14-336",
179
  }
180
 
181
+ if False: # "Comparison" in mode:
182
  c1, c2 = st.columns((1, 1))
183
  selection1 = c1.selectbox("", models_dict.keys(), index=0)
184
  selection2 = c2.selectbox("", models_dict.keys(), index=2)
 
189
 
190
  if len(query) > 0:
191
  results1 = image_search(query, corpus, name1)
192
+ if False: # "Comparison" in mode:
193
  with c1:
194
  clicked1 = clickable_images(
195
  [result[0] for result in results1],
 
227
  if change_query:
228
  if clicked1 >= 0:
229
  st.session_state["query"] = f"[{corpus}:{results1[clicked1][2]}]"
230
+ # elif clicked2 >= 0:
231
  # st.session_state["query"] = f"[{corpus}:{results2[clicked2][2]}]"
232
  st.experimental_rerun()
233