Vivien committed
Commit c81898a • 1 parent: 59898ea

Initial commit
Files changed (7)
  1. .gitattributes +1 -0
  2. README.md +4 -4
  3. app.py +78 -0
  4. data.csv +0 -0
  5. data2.csv +0 -0
  6. embeddings.npy +3 -0
  7. embeddings2.npy +3 -0
.gitattributes CHANGED
@@ -25,3 +25,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
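
This new rule routes the .npy embedding files through Git LFS, so the repository stores short text pointers (shown at the bottom of this commit) rather than the ~51 MB and ~16 MB binaries. A minimal sketch, illustrative rather than part of this commit, for checking whether a checkout holds the real array or still an unresolved pointer:

def is_lfs_pointer(path):
    # Real LFS pointers are tiny text files starting with the spec line
    # visible in the embeddings.npy diff below.
    with open(path, "rb") as f:
        return f.read(64).startswith(b"version https://git-lfs.github.com/spec/v1")

# If this prints True, run `git lfs pull` before launching the app.
print(is_lfs_pointer("embeddings.npy"))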
README.md CHANGED
@@ -1,11 +1,11 @@
 ---
-title: Clip
+title: Clip Demo
 emoji: 👍
-colorFrom: yellow
-colorTo: yellow
+colorFrom: indigo
+colorTo: blue
 sdk: streamlit
 app_file: app.py
-pinned: false
+pinned: true
 ---
 
 # Configuration
app.py ADDED
import numpy as np
import pandas as pd
import streamlit as st
from transformers import CLIPModel, CLIPProcessor, CLIPTextModel

# Cache the heavyweight objects across reruns; the hash_funcs entries tell
# Streamlit not to attempt hashing the model, processor, or data dictionaries.
@st.cache(show_spinner=False,
          hash_funcs={CLIPModel: lambda _: None,
                      CLIPTextModel: lambda _: None,
                      CLIPProcessor: lambda _: None,
                      dict: lambda _: None})
def load():
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
    # Corpus 0 is Unsplash, corpus 1 is TMDB: per-image metadata plus
    # precomputed image embeddings.
    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
    embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}
    # L2-normalize the rows so that a dot product is a cosine similarity.
    for k in [0, 1]:
        embeddings[k] = embeddings[k] / np.linalg.norm(embeddings[k], axis=1, keepdims=True)
    return model, text_model, processor, df, embeddings

model, text_model, processor, df, embeddings = load()

source = {0: '\nSource: Unsplash', 1: '\nSource: The Movie Database (TMDB)'}

def get_html(url_list, height=200):
    # Render the results as a wrapping flexbox of thumbnails; each image is
    # wrapped in a link when one is available.
    html = "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    for url, title, link in url_list:
        html2 = f"<img title='{title}' style='height: {height}px; margin: 5px' src='{url}'>"
        if len(link) > 0:
            html2 = f"<a href='{link}' target='_blank'>" + html2 + "</a>"
        html = html + html2
    html += "</div>"
    return html

def compute_text_embeddings(list_of_strings):
    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Project the pooled text features into CLIP's shared image-text space.
    return model.text_projection(text_model(**inputs).pooler_output)

@st.cache(show_spinner=False)
def image_search(query, corpus, n_results=24):
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1
    # Score every image against the query and keep the top n_results indices.
    results = np.argsort((embeddings[k] @ text_embeddings.T)[:, 0])[-1:-n_results - 1:-1]
    return [(df[k].iloc[i]['path'],
             df[k].iloc[i]['tooltip'] + source[k],
             df[k].iloc[i]['link']) for i in results]

description = '''
# Semantic image search
Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), and images from [Unsplash](https://unsplash.com/) and [The Movie Database (TMDB)](https://www.themoviedb.org/).
'''

def main():
    # Widen the layout, cap the text input width, and hide Streamlit chrome.
    st.markdown('''
                <style>
                .block-container{
                  max-width: 1200px;
                }
                .stTextInput{
                  max-width: 600px;
                }
                #MainMenu {
                  visibility: hidden;
                }
                footer {
                  visibility: hidden;
                }
                </style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    _, col1, col2, _ = st.columns([2, 10, 2, 2])
    query = col1.text_input('')
    corpus = col2.radio('', ["Unsplash", "Movies"])
    if len(query) > 0:
        results = image_search(query, corpus)
        st.markdown(get_html(results), unsafe_allow_html=True)

if __name__ == '__main__':
    main()
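
app.py only consumes embeddings.npy and embeddings2.npy; the commit does not show how they were produced. A plausible sketch of that precomputation, assuming local copies of the corpus images and the same checkpoint the app loads (the embed_images helper and the path list are hypothetical):

import numpy as np
from PIL import Image
from transformers import CLIPModel, CLIPProcessor

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

def embed_images(image_paths, batch_size=32):
    chunks = []
    for i in range(0, len(image_paths), batch_size):
        images = [Image.open(p).convert("RGB") for p in image_paths[i:i + batch_size]]
        inputs = processor(images=images, return_tensors="pt")
        # get_image_features runs the vision tower plus CLIP's visual projection,
        # the image-side counterpart of compute_text_embeddings in app.py.
        chunks.append(model.get_image_features(**inputs).detach().numpy())
    return np.concatenate(chunks)

# np.save("embeddings.npy", embed_images(unsplash_paths))  # hypothetical path list

The app re-normalizes the rows at load time, so the precomputation would not need to.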
data.csv ADDED
The diff for this file is too large to render.
data2.csv ADDED
The diff for this file is too large to render.
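
Neither CSV renders here, but image_search fixes their minimum schema: each row must provide a path (the image URL rendered by get_html), a tooltip (hover text, suffixed with the corpus source), and a link (click-through target; an empty string disables the link). A purely illustrative row, since the real file contents are not shown:

import pandas as pd

row = pd.DataFrame([{
    "path": "https://example.com/thumb.jpg",  # hypothetical image URL
    "tooltip": "A dog catching a frisbee",    # hypothetical caption
    "link": "",                               # empty string = no <a> wrapper
}])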
embeddings.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:9f8c171e32276739be6b020592edc8a2c06e029ff6505a9d1d4efe3cafa073bd
+size 51200128
embeddings2.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:389f1012d8980c48d3e193dbed13435bbf249adc842c9e67c2ab1e3c5292cb76
+size 15739008
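
Both pointer sizes are consistent with float32 vectors of dimension 512 (the projection width of CLIP ViT-B/32) plus the usual 128-byte .npy version-1 header: (51200128 - 128) / (512 * 4) = 25,000 rows for the Unsplash corpus and (15739008 - 128) / (512 * 4) = 7,685 for TMDB. These row counts are inferred from that arithmetic, not stated anywhere in the commit; a quick check:

HEADER = 128   # typical .npy v1 header size (assumption)
ROW = 512 * 4  # one 512-dim float32 vector

for name, size in [("embeddings.npy", 51200128), ("embeddings2.npy", 15739008)]:
    rows, rem = divmod(size - HEADER, ROW)
    print(name, rows, "rows, exact" if rem == 0 else "rows, inexact")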