Vivien committed
Commit a55de09 • 1 Parent(s): b59b1d0

Switch from ViT-B32 to ViT-B16

Files changed (4)
  1. app.py +58 -32
  2. data.csv +0 -0
  3. embeddings.npy +1 -1
  4. embeddings2.npy +1 -1
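
For context on the switch: ViT-B/32 and ViT-B/16 differ only in the patch size of CLIP's vision tower (32×32 vs. 16×16 pixels), so ViT-B/16 encodes each image as a finer grid of patches, at roughly 4× the compute per image, and generally retrieves with higher fidelity. Both checkpoints project into the same 512-dimensional joint embedding space, which is why the two embedding files below change content but keep their exact byte size. A minimal sketch to verify the shared width (only the two checkpoint names come from this commit; the rest assumes the transformers config API):

    from transformers import CLIPConfig

    # Both CLIP checkpoints share projection_dim=512, so precomputed
    # embedding matrices keep the same shape across the switch.
    for name in ["openai/clip-vit-base-patch32", "openai/clip-vit-base-patch16"]:
        config = CLIPConfig.from_pretrained(name)
        print(name, config.projection_dim)  # 512 for both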
app.py CHANGED
@@ -4,21 +4,31 @@ from html import escape
 import os
 from transformers import CLIPProcessor, CLIPModel
 
-@st.cache(show_spinner=False,
-          hash_funcs={CLIPModel: lambda _: None,
-                      CLIPProcessor: lambda _: None,
-                      dict: lambda _: None})
+
+@st.cache(
+    show_spinner=False,
+    hash_funcs={
+        CLIPModel: lambda _: None,
+        CLIPProcessor: lambda _: None,
+        dict: lambda _: None,
+    },
+)
 def load():
-    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
-    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
-    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
-    embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
-    for k in [0, 1]:
-        embeddings[k] = np.divide(embeddings[k], np.sqrt(np.sum(embeddings[k]**2, axis=1, keepdims=True)))
-    return model, processor, df, embeddings
+    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
+    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")
+    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
+    embeddings = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
+    for k in [0, 1]:
+        embeddings[k] = np.divide(
+            embeddings[k], np.sqrt(np.sum(embeddings[k] ** 2, axis=1, keepdims=True))
+        )
+    return model, processor, df, embeddings
+
+
 model, processor, df, embeddings = load()
 
-source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}
+source = {0: "\nSource: Unsplash", 1: "\nSource: The Movie Database (TMDB)"}
+
 
 def get_html(url_list, height=200):
     html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
@@ -30,20 +40,32 @@ def get_html(url_list, height=200):
     html += "</div>"
     return html
 
+
 def compute_text_embeddings(list_of_strings):
     inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
     return model.get_text_features(**inputs)
 
+
 st.cache(show_spinner=False)
+
+
 def image_search(query, corpus, n_results=24):
     text_embeddings = compute_text_embeddings([query]).detach().numpy()
-    k = 0 if corpus == 'Unsplash' else 1
-    results = np.argsort((embeddings[k]@text_embeddings.T)[:, 0])[-1:-n_results-1:-1]
-    return [(df[k].iloc[i]['path'],
-             df[k].iloc[i]['tooltip'] + source[k],
-             df[k].iloc[i]['link']) for i in results]
+    k = 0 if corpus == "Unsplash" else 1
+    results = np.argsort((embeddings[k] @ text_embeddings.T)[:, 0])[
+        -1 : -n_results - 1 : -1
+    ]
+    return [
+        (
+            df[k].iloc[i]["path"],
+            df[k].iloc[i]["tooltip"] + source[k],
+            df[k].iloc[i]["link"],
+        )
+        for i in results
+    ]
+
 
-description = '''
+description = """
 # Semantic image search
 
 **Enter your query and hit enter**
@@ -51,10 +73,12 @@ description = '''
 *Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
 
 *Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
-'''
+"""
+
 
 def main():
-    st.markdown('''
+    st.markdown(
+        """
 <style>
 .block-container{
     max-width: 1200px;
@@ -83,15 +107,17 @@ def main():
 footer {
     visibility: hidden;
 }
-    </style>''',
-                unsafe_allow_html=True)
-    st.sidebar.markdown(description)
-    _, c, _ = st.columns((1, 3, 1))
-    query = c.text_input('', value='clouds at sunset')
-    corpus = st.radio('', ["Unsplash","Movies"])
-    if len(query) > 0:
-        results = image_search(query, corpus)
-        st.markdown(get_html(results), unsafe_allow_html=True)
-
-if __name__ == '__main__':
-    main()
+    </style>""",
+        unsafe_allow_html=True,
+    )
+    st.sidebar.markdown(description)
+    _, c, _ = st.columns((1, 3, 1))
+    query = c.text_input("", value="clouds at sunset")
+    corpus = st.radio("", ["Unsplash", "Movies"])
+    if len(query) > 0:
+        results = image_search(query, corpus)
+        st.markdown(get_html(results), unsafe_allow_html=True)
+
+
+if __name__ == "__main__":
+    main()
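
Two notes on the code above. First, load() L2-normalizes each row of the image-embedding matrices, so the embeddings[k] @ text_embeddings.T product inside image_search ranks images by cosine similarity up to one constant factor (the query vector itself is left unnormalized, which scales every score equally and leaves the ranking intact). A self-contained sketch of that ranking step, with random matrices standing in for the real embeddings:

    import numpy as np

    # Stand-ins for the precomputed image embeddings and one encoded query.
    image_emb = np.random.randn(25000, 512).astype(np.float32)
    text_emb = np.random.randn(1, 512).astype(np.float32)

    # Row-wise L2 normalization, equivalent to the np.divide/np.sqrt form in load().
    image_emb /= np.linalg.norm(image_emb, axis=1, keepdims=True)

    # With unit-norm rows, dot products are proportional to cosine similarity;
    # the reversed slice keeps the 24 best indices, highest score first.
    scores = (image_emb @ text_emb.T)[:, 0]
    top = np.argsort(scores)[-1:-25:-1]

Second, the bare statement st.cache(show_spinner=False) above image_search is missing its @: as written it builds a cached wrapper and throws it away, so image_search runs uncached on every query. If caching was intended there, the decorator form would be:

    @st.cache(show_spinner=False)
    def image_search(query, corpus, n_results=24):
        ...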
data.csv CHANGED
The diff for this file is too large to render. See raw diff
 
embeddings.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9f8c171e32276739be6b020592edc8a2c06e029ff6505a9d1d4efe3cafa073bd
+oid sha256:125430e11a4a415ec0c0fc5339f97544f0447e4b0a24c20f2e59f8852e706afc
 size 51200128
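
The unchanged size is consistent with the app description's "25k images from Unsplash": a .npy file stores a small header plus the raw float32 array, and 128 + 25,000 × 512 × 4 = 51,200,128 bytes. Switching checkpoints changes only the vector values, hence a new oid over an identical size (the 25,000 row count is inferred from the file size, not stated in the diff):

    # 25,000 rows x 512 dims x 4 bytes (float32) + 128-byte .npy header
    assert 128 + 25_000 * 512 * 4 == 51_200_128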
embeddings2.npy CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9664e980f31e81c4a34e07833539fea32795d83a4262c9828ceae445fa2e412a
+oid sha256:153cf3fae2385d51fe8729d3a1c059f611ca47a3fc501049708114d1bbf79049
 size 16732288
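
The same arithmetic fits this file: 128 + 8,170 × 512 × 4 = 16,732,288 bytes, matching the "8k images from TMDB" in the description. Both matrices were evidently recomputed with the new checkpoint; the generation script is not part of this commit, but a plausible sketch (reading image paths from data.csv and opening them as local files are assumptions) would be:

    import numpy as np
    import pandas as pd
    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch16")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

    df = pd.read_csv("data.csv")  # one row per image, as in app.py
    vectors = []
    with torch.no_grad():
        for path in df["path"]:
            image = Image.open(path).convert("RGB")  # assumes local image files
            inputs = processor(images=image, return_tensors="pt")
            vectors.append(model.get_image_features(**inputs)[0].numpy())

    # app.py normalizes at load time, so the raw features are saved as-is.
    np.save("embeddings.npy", np.stack(vectors).astype(np.float32))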