File size: 3,670 Bytes
c81898a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ff968d5
 
 
 
c81898a
 
 
 
 
 
 
 
ff968d5
 
 
 
 
 
 
 
c81898a
7600dc3
 
 
586f7e5
55fea56
586f7e5
c81898a
 
 
 
 
 
 
 
 
ff968d5
 
 
c81898a
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import streamlit as st
import pandas as pd, numpy as np
import os
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel

@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPTextModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP components, the result metadata and the image embeddings.

    ``st.cache`` memoises the return value across script reruns. Its
    ``hash_funcs`` map the CLIP objects and plain dicts to a constant hash, so
    Streamlit does not try to hash their contents.

    Returns:
        ``(model, text_model, processor, df, embeddings)``.

        - ``df`` maps corpus index 0/1 to the DataFrames read from
          ``data.csv``/``data2.csv``.
        - ``embeddings`` maps the same indices to the arrays loaded from
          ``embeddings.npy``/``embeddings2.npy``, with every row scaled to
          unit L2 norm.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    model = CLIPModel.from_pretrained(checkpoint)
    text_model = CLIPTextModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)
    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}

    embeddings = {}
    for idx, npy_path in ((0, 'embeddings.npy'), (1, 'embeddings2.npy')):
        raw = np.load(npy_path)
        # Unit-length rows mean that, for a fixed query, ranking by dot product
        # equals ranking by cosine similarity.
        # NOTE(review): an all-zero row would divide by zero and yield NaN.
        row_norms = np.sqrt(np.sum(raw ** 2, axis=1, keepdims=True))
        embeddings[idx] = np.divide(raw, row_norms)
    return model, text_model, processor, df, embeddings


model, text_model, processor, df, embeddings = load()

# Attribution appended (after a newline) to each result's tooltip. Keyed by the
# corpus index image_search uses: 0 = Unsplash, 1 = The Movie Database (TMDB).
source = {
    0: '\nSource: Unsplash',
    1: '\nSource: The Movie Database (TMDB)',
}

def get_html(url_list, height=200):
    """Build an HTML gallery of result images, for rendering with ``st.markdown``.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples.
            ``url`` is the image source and ``title`` becomes the image's hover
            tooltip. ``link``, when it is a non-empty string, wraps the image in
            an anchor that opens in a new tab.
        height: Image height in CSS pixels.

    Returns:
        One flex-wrapped ``<div>`` string containing an ``<img>`` per entry.
    """
    from html import escape  # local import: leaves the module namespace unchanged

    pieces = ["<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"]
    for url, title, link in url_list:
        # Attribute values are single-quoted, so they must be escaped
        # (escape() also converts "'"). Otherwise a tooltip such as
        # "Ocean's Eleven" ends the attribute early and corrupts the markup.
        img = (f"<img title='{escape(str(title))}' "
               f"style='height: {height}px; margin: 5px' src='{escape(str(url))}'>")
        # pandas reads empty CSV cells as float NaN, which has no len().
        # Only real, non-empty strings become links.
        if isinstance(link, str) and link:
            img = f"<a href='{escape(link)}' target='_blank'>{img}</a>"
        pieces.append(img)
    pieces.append("</div>")
    return "".join(pieces)

def compute_text_embeddings(list_of_strings):
    """Embed text queries with the module-level CLIP text encoder.

    Tokenizes ``list_of_strings`` with ``processor`` (padded to a common
    length) and runs ``text_model``. Its pooled output is then projected
    through ``model.text_projection``.

    Args:
        list_of_strings: Query strings to embed.

    Returns:
        A torch tensor with one row per input string. It is computed under
        ``torch.no_grad()``, so it carries no autograd graph.
    """
    # torch is already required here: the processor is asked for "pt" tensors.
    import torch

    inputs = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Inference only: skip building an autograd graph (saves time and memory).
    with torch.no_grad():
        pooled = text_model(**inputs).pooler_output
        return model.text_projection(pooled)

# NOTE(review): this function is deliberately not wrapped in st.cache.
# A bare `st.cache(...)` call without '@' only builds a decorator and discards
# it, so it never cached anything. st.cache also hashes the globals a function
# references, including the CLIP models reached via compute_text_embeddings.
# Enabling caching here likely needs hash_funcs for those objects; verify first.
def image_search(query, corpus, n_results=24):
    """Return up to ``n_results`` images from ``corpus`` that best match ``query``.

    The query is embedded with ``compute_text_embeddings`` and scored against
    the corpus embeddings by dot product.

    Args:
        query: Free-text search query.
        corpus: ``'Unsplash'`` selects corpus 0; any other value selects corpus 1.
        n_results: Maximum number of results to return.

    Returns:
        List of ``(path, tooltip, link)`` values, best score first. Each tooltip
        has the corpus attribution from ``source`` appended.
    """
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending; walk it backwards to take the n_results best scores.
    results = np.argsort(scores)[-1:-n_results-1:-1]
    # Fetch each row once instead of calling iloc three times per result.
    rows = (df[k].iloc[i] for i in results)
    return [(row['path'], row['tooltip'] + source[k], row['link']) for row in rows]

# Markdown shown in the sidebar by main(): page title, a usage hint and credits.
description = (
    "\n"
    "# Semantic image search\n"
    "\n"
    "**Enter your query and hit enter**\n"
    "\n"
    "*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, "
    "🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), "
    "[Streamlit](https://streamlit.io/) and images from [Unsplash](https://unsplash.com/) "
    "and [The Movie Database (TMDB)](https://www.themoviedb.org/)*\n"
)

def main():
    """Render the search page.

    The page consists of: the custom CSS, the sidebar text, a centred query
    box, the corpus picker, and (once a query is entered) the result gallery
    built by ``image_search`` and ``get_html``.
    """
    st.markdown('''
              <style>
              .block-container{
                max-width: 1200px;
              }
              div.row-widget.stRadio > div{
                flex-direction:row;
                display: flex;
                justify-content: center;
              }
              div.row-widget.stRadio > div > label{
                margin-left: 5px;
                margin-right: 5px;
              }
              section.main>div:first-child {
                padding-top: 50px;
              }
              div.reportview-container > section:first-child{
                max-width: 320px;
              }
              #MainMenu {
                visibility: hidden;
              }
              footer {
                visibility: hidden;
              }
              </style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    # Current Streamlit only provides st.columns (the st.beta_columns alias was
    # removed). Very old releases only provide st.beta_columns, so fall back.
    columns = getattr(st, 'columns', None) or st.beta_columns
    _, center, _ = columns((1, 3, 1))
    query = center.text_input('')
    corpus = st.radio('', ["Unsplash", "Movies"])
    if query:
        results = image_search(query, corpus)
        st.markdown(get_html(results), unsafe_allow_html=True)

if __name__ == '__main__':
    # Render the page only when executed as a script, not when imported.
    main()