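# clip / app.py
# Author: Vivien
# Last commit message: "Move content up"
# Commit: 555584f
# (file-viewer links: raw / history / blame)
# Scan status: no virus
# File size: 3.69 kB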
import streamlit as st
import pandas as pd, numpy as np
import os
from transformers import CLIPProcessor, CLIPTextModel, CLIPModel
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPTextModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model, its processor, the metadata tables and embeddings.

    The two corpora are keyed ``0`` and ``1`` in both returned dicts
    (``data.csv`` / ``embeddings.npy`` and ``data2.csv`` / ``embeddings2.npy``).
    Every row of each embedding matrix is divided by its Euclidean norm, so
    rows come back with unit length.

    The ``hash_funcs`` above map the model, processor and dict types to a
    constant; presumably this keeps Streamlit from trying to hash those
    objects -- confirm against the st.cache documentation.

    Returns:
        A ``(model, processor, df, embeddings)`` tuple.
    """
    checkpoint = "openai/clip-vit-base-patch32"
    model = CLIPModel.from_pretrained(checkpoint)
    processor = CLIPProcessor.from_pretrained(checkpoint)

    df = {0: pd.read_csv('data.csv'), 1: pd.read_csv('data2.csv')}
    embeddings = {0: np.load('embeddings.npy'), 1: np.load('embeddings2.npy')}

    # Replace each matrix by its row-normalised version (values are
    # reassigned under existing keys, so iterating the dict is safe).
    for key in embeddings:
        matrix = embeddings[key]
        row_norms = np.sqrt(np.sum(matrix ** 2, axis=1, keepdims=True))
        embeddings[key] = matrix / row_norms

    return model, processor, df, embeddings
# Load the model, processor, metadata tables and embeddings via load().
model, processor, df, embeddings = load()

# Attribution text appended to each image tooltip, keyed by the same 0/1
# corpus index used for `df` and `embeddings`.
source = {
    0: '\nSource: Unsplash',
    1: '\nSource: The Movie Database (TMDB)',
}
def get_html(url_list, height=200):
    """Build an HTML flexbox gallery from ``(url, title, link)`` triples.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source, ``title`` the tooltip text, and ``link`` an optional
            target the image should link to.
        height: Image height in pixels.

    Returns:
        A single HTML string: a wrapping ``<div>`` containing one ``<img>``
        per triple, each wrapped in an ``<a target='_blank'>`` when ``link``
        is a non-empty string.
    """
    # Imported locally because the module has no top-level `html` import.
    from html import escape

    pieces = [
        "<div style='margin-top: 20px; max-width: 1200px; display: flex; "
        "flex-wrap: wrap; justify-content: space-evenly'>"
    ]
    for url, title, link in url_list:
        # Values land inside single-quoted attributes, so quotes (e.g. an
        # apostrophe in a title) must be escaped or the markup breaks.
        safe_url = escape(str(url), quote=True)
        safe_title = escape(str(title), quote=True)
        tag = (
            f"<img title='{safe_title}' "
            f"style='height: {height}px; margin: 5px' src='{safe_url}'>"
        )
        # Only non-empty strings count as links; a missing value (e.g. NaN
        # from an empty CSV cell) is treated as "no link" instead of raising.
        if isinstance(link, str) and link:
            tag = f"<a href='{escape(link, quote=True)}' target='_blank'>{tag}</a>"
        pieces.append(tag)
    pieces.append("</div>")
    return "".join(pieces)
def compute_text_embeddings(list_of_strings):
    """Return the CLIP text features for a list of strings.

    The strings are passed through the module-level ``processor`` (padded,
    PyTorch tensors) and the result is fed to ``model.get_text_features``.
    """
    encoded = processor(text=list_of_strings, return_tensors="pt", padding=True)
    text_features = model.get_text_features(**encoded)
    return text_features
# NOTE(review): the next statement has no leading `@`, so it only builds a
# decorator and discards it -- image_search is NOT cached. Making it a real
# decorator would have Streamlit hash the globals this function references
# (model, processor, ...); confirm that is safe before adding the `@`.
st.cache(show_spinner=False)
def image_search(query, corpus, n_results=24):
    """Return the best-matching images for a text query.

    Args:
        query: Text to search for.
        corpus: ``'Unsplash'`` selects corpus 0; any other value selects
            corpus 1.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples, highest score first. The
        tooltip has the corpus attribution from ``source`` appended.
    """
    text_embeddings = compute_text_embeddings([query]).detach().numpy()

    if corpus == 'Unsplash':
        k = 0
    else:
        k = 1

    # One score per image: dot product of each embedding row with the query.
    scores = (embeddings[k] @ text_embeddings.T)[:, 0]
    # argsort is ascending, so walk it backwards from the end to take the
    # n_results highest-scoring indices in descending order.
    best = np.argsort(scores)[-1:-n_results - 1:-1]

    table = df[k]
    hits = []
    for i in best:
        row = table.iloc[i]
        hits.append((row['path'], row['tooltip'] + source[k], row['link']))
    return hits
# Markdown blurb rendered in the sidebar by main(). The emoji before
# "Hugging Face" was previously mis-encoded as mojibake and is restored here.
description = '''
# Semantic image search
**Enter your query and hit enter**
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
'''
def main():
    """Render the page: CSS overrides, sidebar text, query box and results.

    Injects a block of CSS, shows ``description`` in the sidebar, then reads
    a text query and a corpus choice. When the query is non-empty it runs
    ``image_search`` and renders the gallery HTML from ``get_html``.
    """
    st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
div.row-widget.stRadio > div{
flex-direction:row;
display: flex;
justify-content: center;
}
div.row-widget.stRadio > div > label{
margin-left: 5px;
margin-right: 5px;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)

    # `st.beta_columns` was removed from newer Streamlit releases in favour
    # of `st.columns`; prefer the stable name and fall back to the beta one
    # so the app runs on both old and new versions.
    make_columns = getattr(st, 'columns', None) or st.beta_columns
    # Three columns in a 1:3:1 ratio; only the middle one is used, which
    # centres the query box.
    _, c, _ = make_columns((1, 3, 1))
    query = c.text_input('', value='clouds at sunset')
    corpus = st.radio('', ["Unsplash","Movies"])

    if query:
        results = image_search(query, corpus)
        st.markdown(get_html(results), unsafe_allow_html=True)
# Script entry point: build the page only when this file is executed as the
# main module.
if __name__ == "__main__":
    main()