clip / app.py
Vivien
Escape strings to avoid quote problems
ba03fb2
raw history blame
No virus
3.95 kB
import streamlit as st
import pandas as pd, numpy as np
from html import escape
import os
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
@st.cache(show_spinner=False,
          hash_funcs={CLIPModel: lambda _: None,
                      CLIPTextModel: lambda _: None,
                      CLIPProcessor: lambda _: None,
                      dict: lambda _: None})
def load():
    """Load the CLIP model/processor and both image corpora.

    Returns:
        ``(model, processor, df, embeddings)`` where ``df`` and ``embeddings``
        are dicts keyed by corpus index: 0 -> ``data.csv`` /
        ``embeddings.npy``, 1 -> ``data2.csv`` / ``embeddings2.npy``.
        Every embedding row is divided by its L2 norm, so a dot product
        with a query vector ranks rows by cosine similarity.

    The ``hash_funcs`` make ``st.cache`` hash CLIP objects and dicts as the
    constant ``None`` instead of inspecting their contents.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)

    csv_paths = {0: 'data.csv', 1: 'data2.csv'}
    npy_paths = {0: 'embeddings.npy', 1: 'embeddings2.npy'}

    frames = {idx: pd.read_csv(path) for idx, path in csv_paths.items()}
    raw_vectors = {idx: np.load(path) for idx, path in npy_paths.items()}

    unit_vectors = {}
    for idx, vectors in raw_vectors.items():
        # Per-row Euclidean norm, kept 2-D so it broadcasts across columns.
        row_norms = np.sqrt(np.sum(vectors ** 2, axis=1, keepdims=True))
        unit_vectors[idx] = np.divide(vectors, row_norms)

    return clip_model, clip_processor, frames, unit_vectors
# Model, processor, metadata frames and normalized embeddings; ``load`` is
# wrapped in ``st.cache`` so reruns reuse the cached result.
model, processor, df, embeddings = load()

# Attribution line appended to every image tooltip, keyed by the same corpus
# index as ``df`` / ``embeddings`` (0 = Unsplash, 1 = TMDB movies).
source = {
    0: '\nSource: Unsplash',
    1: '\nSource: The Movie Database (TMDB)',
}
def get_html(url_list, height=200):
    """Render images as a centered, wrapping flexbox gallery.

    Args:
        url_list: iterable of ``(url, title, link)`` triples. ``title`` is
            used as the image tooltip. When ``link`` is a non-empty string,
            the image is wrapped in an anchor that opens ``link`` in a new
            tab. Otherwise (empty string, ``None``, or a float ``NaN`` such
            as pandas produces for empty CSV cells) the bare image is
            rendered.
        height: pixel height applied to every image.

    Returns:
        The gallery as an HTML string. ``url``, ``title`` and ``link`` go
        through ``html.escape`` (which escapes quotes by default), so a
        ``'`` inside them cannot terminate the single-quoted attributes.
    """
    tiles = []
    for url, title, link in url_list:
        tile = f"<img title='{escape(title)}' style='height: {height}px; margin: 5px' src='{escape(url)}'>"
        # Guard on type first: len() raises TypeError on None / float NaN.
        if isinstance(link, str) and link:
            tile = f"<a href='{escape(link)}' target='_blank'>" + tile + "</a>"
        tiles.append(tile)
    return (
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; flex-wrap: wrap; justify-content: space-evenly'>"
        + "".join(tiles)
        + "</div>"
    )
def compute_text_embeddings(list_of_strings):
    """Encode a batch of strings with the module-level CLIP text tower.

    The strings are tokenized by ``processor`` into padded PyTorch tensors
    and fed to ``model.get_text_features``. The model output is returned
    as-is.
    """
    tokenized = processor(
        text=list_of_strings,
        return_tensors="pt",
        padding=True,
    )
    text_features = model.get_text_features(**tokenized)
    return text_features
# Must be applied with ``@``: a bare ``st.cache(...)`` call is a discarded
# expression and leaves the function uncached. The hash_funcs match load():
# st.cache also hashes the globals the function (and compute_text_embeddings)
# references, i.e. the CLIP model/processor and the large df/embeddings dicts.
@st.cache(show_spinner=False,
          hash_funcs={CLIPModel: lambda _: None,
                      CLIPTextModel: lambda _: None,
                      CLIPProcessor: lambda _: None,
                      dict: lambda _: None})
def image_search(query, corpus, n_results=24):
    """Return the ``n_results`` images of ``corpus`` best matching ``query``.

    Args:
        query: free-text search string.
        corpus: ``'Unsplash'`` selects corpus 0; any other value selects
            corpus 1 (the movie corpus).
        n_results: maximum number of hits to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples, best match first, where
        the tooltip has the corpus attribution from ``source`` appended.
    """
    text_embeddings = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == 'Unsplash' else 1
    # One similarity score per corpus row (column 0 of the (N, 1) product).
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending; step backwards from the end to take the best first.
    results = np.argsort(scores)[-1:-n_results-1:-1]
    frame = df[k]
    return [(frame.iloc[i]['path'],
             frame.iloc[i]['tooltip'] + source[k],
             frame.iloc[i]['link']) for i in results]
# Markdown for the sidebar: title, usage hint and credits.
description = '''
# Semantic image search
**Enter your query and hit enter**
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)
Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
'''
def main():
    """Render the page: custom CSS, sidebar credits, query box, corpus picker and results."""
    # Page-level CSS overrides (layout width, horizontal radio, hidden menu/footer).
    st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # st.beta_columns was deprecated in favour of st.columns and later removed;
    # use the new name when available and fall back on older Streamlit.
    make_columns = getattr(st, 'columns', None) or st.beta_columns
    # Narrow middle column (1:3:1) keeps the query box centered.
    _, c, _ = make_columns((1, 3, 1))
    query = c.text_input('', value='clouds at sunset')
    corpus = st.radio('', ["Unsplash", "Movies"])
    if query:
        results = image_search(query, corpus)
        st.markdown(get_html(results), unsafe_allow_html=True)
# Render the page when the file runs as a script; `streamlit run` executes it
# as ``__main__``.
if __name__ == "__main__":
    main()