import streamlit as st
import pandas as pd, numpy as np
from html import escape
import os
from transformers import CLIPProcessor, CLIPModel
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def load():
    """Load the CLIP model, its processor, the metadata tables and embeddings.

    The result is memoised by ``st.cache``. The ``hash_funcs`` above give
    ``CLIPModel``, ``CLIPProcessor`` and ``dict`` objects a constant hash of
    ``None`` (presumably so Streamlit never tries to hash their contents).

    Returns:
        A 4-tuple ``(model, processor, df, embeddings)`` where ``df`` and
        ``embeddings`` are dicts keyed by corpus index ``0`` / ``1``:
        ``df[k]`` is the DataFrame read from ``data.csv`` / ``data2.csv`` and
        ``embeddings[k]`` is the array from ``embeddings.npy`` /
        ``embeddings2.npy`` with every row scaled to unit L2 norm.
    """
    checkpoint = "openai/clip-vit-base-patch16"
    clip_model = CLIPModel.from_pretrained(checkpoint)
    clip_processor = CLIPProcessor.from_pretrained(checkpoint)

    tables = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}

    raw_vectors = {0: np.load("embeddings.npy"), 1: np.load("embeddings2.npy")}
    unit_vectors = {}
    for key in (0, 1):
        vectors = raw_vectors[key]
        # Row-wise Euclidean length, kept 2-D so the division broadcasts.
        row_norms = np.sqrt(np.sum(vectors ** 2, axis=1, keepdims=True))
        unit_vectors[key] = np.divide(vectors, row_norms)

    return clip_model, clip_processor, tables, unit_vectors
# Attribution suffix appended to each result's tooltip, keyed by corpus index
# (0 = Unsplash, 1 = TMDB).
source = {
    0: "\nSource: Unsplash",
    1: "\nSource: The Movie Database (TMDB)",
}

# Shared state used by the search helpers below; load() is wrapped in
# st.cache, so this call is expected to be cheap after the first run.
model, processor, df, embeddings = load()
def get_html(url_list, height=200):
    """Build an HTML gallery snippet for a list of image results.

    NOTE(review): the markup that used to live in this function's string
    literals was lost (they were empty/unterminated), so the tags and inline
    styles below are a reconstruction -- confirm against the intended design.

    Args:
        url_list: Iterable of ``(url, title, link)`` triples. ``url`` is the
            image source, ``title`` the hover tooltip, and ``link`` an
            optional target URL; an empty or non-string ``link`` (e.g. a
            missing CSV cell) yields a plain, non-clickable image.
        height: Thumbnail height in pixels.

    Returns:
        A single HTML string: one wrapping ``<div>`` containing an ``<img>``
        per result, each wrapped in an ``<a>`` when it has a link.
    """
    tiles = []
    for url, title, link in url_list:
        # Every interpolated value is escaped (quotes included) because it is
        # placed inside a single-quoted attribute and rendered as raw HTML.
        tile = (
            f"<img title='{escape(title)}' "
            f"style='height: {height}px; margin: 5px' "
            f"src='{escape(url)}'>"
        )
        if isinstance(link, str) and link:
            tile = (
                f"<a href='{escape(link)}' target='_blank' "
                f"rel='noopener noreferrer'>{tile}</a>"
            )
        tiles.append(tile)

    return (
        "<div style='display: flex; flex-wrap: wrap; "
        "justify-content: space-evenly'>"
        + "".join(tiles)
        + "</div>"
    )
def compute_text_embeddings(list_of_strings):
    """Embed a batch of strings with the module-level CLIP model.

    The strings are passed to ``processor`` (padding enabled, PyTorch tensors
    requested) and the resulting inputs are forwarded to
    ``model.get_text_features``.

    Args:
        list_of_strings: The texts to embed.

    Returns:
        Whatever ``model.get_text_features`` returns for the batch
        (callers in this file treat it as a tensor supporting ``.detach()``).
    """
    encoded = processor(text=list_of_strings, return_tensors="pt", padding=True)
    text_features = model.get_text_features(**encoded)
    return text_features
# Streamlit's legacy cache hashes the globals a cached function touches as
# well as its arguments; the model, processor and dict-valued globals are
# given a constant hash here, using the same hash_funcs as `load`.
@st.cache(
    show_spinner=False,
    hash_funcs={
        CLIPModel: lambda _: None,
        CLIPProcessor: lambda _: None,
        dict: lambda _: None,
    },
)
def image_search(query, corpus, n_results=24):
    """Return the images of a corpus that best match a text query.

    Images are ranked by the dot product between the stored (row-normalised)
    image embeddings and the embedding of ``query``.

    Args:
        query: Free-text search string.
        corpus: ``"Unsplash"`` selects corpus 0; any other value selects
            corpus 1.
        n_results: Maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, link)`` tuples, best match first, where
        ``tooltip`` is the table's ``tooltip`` value followed by the corpus
        attribution from ``source``.
    """
    query_embedding = compute_text_embeddings([query]).detach().numpy()
    k = 0 if corpus == "Unsplash" else 1

    scores = (embeddings[k] @ query_embedding.T)[:, 0]
    # argsort is ascending, so reverse it and keep the leading n_results.
    best = np.argsort(scores)[::-1][:n_results]

    # One positional lookup for all hits instead of three per result.
    hits = df[k].iloc[best]
    attribution = source[k]
    return [
        (path, tooltip + attribution, link)
        for path, tooltip, link in zip(hits["path"], hits["tooltip"], hits["link"])
    ]
# Markdown text that main() renders in the sidebar: page title, a usage hint
# and credits for the model, libraries and image sources.
description = """
# Semantic image search
**Enter your query and hit enter**
*Built with OpenAI's [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), 25k images from [Unsplash](https://unsplash.com/) and 8k images from [The Movie Database (TMDB)](https://www.themoviedb.org/)*
*Inspired by [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) from Vladimir Haltakov and [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) from Travis Hoppe*
"""
def main():
    """Render the page: sidebar text, query box, corpus picker and results.

    When the query box is non-empty, the selected corpus is searched and the
    hits are displayed as raw HTML produced by ``get_html``.
    """
    # NOTE(review): this markdown block currently contains only a newline;
    # presumably custom page styling was meant to go here -- confirm.
    st.markdown(
        """
""",
        unsafe_allow_html=True,
    )
    st.sidebar.markdown(description)

    # Three columns with 1:3:1 widths; only the middle one holds the input.
    _left, centre, _right = st.columns((1, 3, 1))
    query = centre.text_input("", value="clouds at sunset")
    corpus = st.radio("", ["Unsplash", "Movies"])

    if not query:
        return

    hits = image_search(query, corpus)
    st.markdown(get_html(hits), unsafe_allow_html=True)


if __name__ == "__main__":
    main()