File size: 3,948 Bytes
c81898a ba03fb2 c81898a 000d238 c81898a ba03fb2 c81898a ba03fb2 c81898a 000d238 c81898a ff968d5 9ea8c8c ba03fb2 9ea8c8c c81898a ff968d5 c81898a 7600dc3 555584f 7600dc3 586f7e5 55fea56 586f7e5 c81898a d2df557 3f8dd94 ff968d5 c81898a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 |
import os
from html import escape

import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
@st.cache(
    show_spinner=False,
    # Each listed type is mapped to a constant hash so Streamlit's cache does
    # not try to hash the contents of these objects.
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPTextModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model, its processor, the metadata tables and embeddings.

    Returns:
        A ``(model, processor, df, embeddings)`` tuple. ``df`` and
        ``embeddings`` are dicts keyed by corpus index (0 and 1). Every
        embedding matrix is rescaled so that each row has unit L2 norm.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)

    tables = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
    raw_vectors = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}

    # Divide every row by its Euclidean length (sum of squares along axis 1).
    unit_vectors = {
        key: np.divide(mat, np.sqrt(np.sum(mat**2, axis=1, keepdims=True)))
        for key, mat in raw_vectors.items()
    }
    return clip_model, clip_processor, tables, unit_vectors
# Load the model, processor, metadata tables and embeddings via `load()`.
model, processor, df, embeddings = load()

# Attribution suffix appended to each result tooltip, keyed by corpus index.
source = {
    0: '\nSource: Unsplash',
    1: '\nSource: The Movie Database (TMDB)',
}
def get_html(url_list, height=200):
    """Render image results as a wrapping flexbox gallery of ``<img>`` tags.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source and ``title`` the hover tooltip. ``link`` is an
            optional click-through target: an empty string, ``None`` or any
            other non-string value (for example a NaN coming from a blank CSV
            cell) means "no link" and the image is emitted without an anchor.
        height: Image height in pixels.

    Returns:
        The HTML fragment as a single string.
    """
    parts = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, title, link in url_list:
        # Every interpolated value is escaped; attributes are single-quoted,
        # and html.escape also escapes both quote characters by default.
        tag = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        # Only wrap in an anchor when there is a real, non-empty string link;
        # calling len() on a non-string link would raise TypeError.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link)}' target='_blank'>{tag}</a>"
        parts.append(tag)
    parts.append("</div>")
    # Join once instead of growing a string inside the loop.
    return "".join(parts)
def compute_text_embeddings(list_of_strings):
    """Embed a batch of strings with the CLIP text encoder.

    Args:
        list_of_strings: Texts to embed; forwarded to the processor as
            ``text`` with padding enabled and PyTorch tensors requested.

    Returns:
        The result of ``model.get_text_features`` for the batch (presumably a
        tensor with one row per input string -- confirm against the
        transformers documentation). It is computed with gradient tracking
        disabled.
    """
    encoded = processor(text=list_of_strings, return_tensors="pt", padding=True)
    # Inference only: avoid building an autograd graph that the caller would
    # immediately discard with .detach().
    with torch.no_grad():
        return model.get_text_features(**encoded)
@st.cache(
    show_spinner=False,
    # This function reads the module-level model, processor and dict globals.
    # Streamlit's legacy cache also hashes objects a cached function refers
    # to, so they get the same constant-hash treatment as in `load`.
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPTextModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def image_search(query, corpus, n_results=24):
    """Return the best-matching images for a text query.

    The query is embedded with ``compute_text_embeddings`` and scored against
    the selected corpus by dot product. The stored embedding rows are
    unit-normalised in ``load``, so this ordering matches cosine-similarity
    ordering.

    Args:
        query: Free-text search string.
        corpus: ``'Unsplash'`` selects corpus 0; any other value selects
            corpus 1.
        n_results: Maximum number of hits to return, best first.

    Returns:
        A list of ``(path, tooltip, link)`` tuples taken from the corpus
        table, with the corpus attribution appended to each tooltip.
    """
    query_vec = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1
    scores = (embeddings[k] @ query_vec.T)[:, 0]
    # argsort is ascending: reverse it and keep the first n_results indices.
    top = np.argsort(scores)[::-1][:n_results]
    table = df[k]
    # Fetch each row once rather than doing three positional lookups per hit.
    rows = (table.iloc[i] for i in top)
    return [(row['path'], row['tooltip'] + source[k], row['link']) for row in rows]
# Markdown text rendered in the sidebar by main(): title, usage hint and credits.
description = '''
# Semantic image search
**Enter your query and hit enter**
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
'''
def main():
    """Build the page: inject CSS, draw the sidebar and inputs, show results."""
    # Style overrides are injected as raw HTML, hence unsafe_allow_html below.
    page_css = '''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>'''
    st.markdown(page_css, unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # The query box sits in the middle column of a 1:3:1 split.
    center = st.columns((1, 3, 1))[1]
    query = center.text_input('', value='clouds at sunset')
    corpus = st.radio('', ["Unsplash","Movies"])

    # Nothing to search for until the user has typed something.
    if len(query) == 0:
        return
    hits = image_search(query, corpus)
    st.markdown(get_html(hits), unsafe_allow_html=True)
# Build the page only when this file is executed as the entry script.
if __name__ == "__main__":
    main()
|