Spaces:
Sleeping
Sleeping
| from html import escape | |
| import re | |
| import streamlit as st | |
| import pandas as pd, numpy as np | |
| import torch | |
| from transformers import CLIPProcessor, CLIPModel | |
| from st_clickable_images import clickable_images | |
# Suffixes of the Hugging Face checkpoints "openai/clip-vit-<name>" to load,
# each paired with precomputed "embeddings-vit-<name>.npy" /
# "embeddings2-vit-<name>.npy" files. The smaller variants "base-patch32",
# "base-patch16" and "large-patch14" are currently disabled.
MODEL_NAMES = ["large-patch14-336"]
def _cache_resource(func):
    """Keep ``func``'s return value alive across Streamlit script reruns.

    Uses ``st.cache_resource`` when this Streamlit version provides it and
    falls back to the older ``st.cache(allow_output_mutation=True)``
    (the returned models/DataFrames are unhashable, mutable objects).
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


# Streamlit re-executes the whole script on every interaction; without a
# cache, every keystroke/click would reload the CLIP model and embeddings.
@_cache_resource
def load():
    """Load corpus metadata, CLIP models/processors and image embeddings.

    Returns:
        A tuple ``(models, processors, df, embeddings)``:
        - ``models[name]`` / ``processors[name]``: CLIP model (in eval mode)
          and processor for each name in ``MODEL_NAMES``.
        - ``df[k]``: DataFrame read from ``data.csv`` (k=0) or ``data2.csv``
          (k=1).
        - ``embeddings[name][k]``: array loaded from
          ``embeddings-vit-<name>.npy`` (k=0) or ``embeddings2-vit-<name>.npy``
          (k=1), with every row scaled to unit L2 norm. Rows are looked up
          with the same index as ``df[k]``.
    """
    df = {0: pd.read_csv("data.csv"), 1: pd.read_csv("data2.csv")}
    models, processors, embeddings = {}, {}, {}
    for name in MODEL_NAMES:
        checkpoint = f"openai/clip-vit-{name}"
        models[name] = CLIPModel.from_pretrained(checkpoint).eval()
        processors[name] = CLIPProcessor.from_pretrained(checkpoint)
        embeddings[name] = {}
        for k, prefix in ((0, "embeddings"), (1, "embeddings2")):
            matrix = np.load(f"{prefix}-vit-{name}.npy")
            # Unit-normalize rows so dot products are cosine similarities.
            embeddings[name][k] = matrix / np.linalg.norm(
                matrix, axis=1, keepdims=True
            )
    return models, processors, df, embeddings
# Module-level handles read by the search helpers below.
models, processors, df, embeddings = load()

# Attribution appended to each result tooltip, keyed like ``df``:
# 0 is the Unsplash corpus, 1 the Movies (TMDB) corpus.
source = dict(
    enumerate(["\nSource: Unsplash", "\nSource: The Movie Database (TMDB)"])
)
def compute_text_embeddings(list_of_strings, name):
    """Encode texts with the CLIP text encoder ``name``.

    Args:
        list_of_strings: the texts to encode.
        name: key into the module-level ``models`` / ``processors`` dicts.

    Returns:
        A numpy array with one row per input string, each scaled to unit
        L2 norm so dot products with the image embeddings are cosine
        similarities.
    """
    # CLIP's text encoder only accepts a bounded number of tokens (77 for the
    # OpenAI checkpoints); truncate long queries instead of letting the model
    # raise on them.
    inputs = processors[name](
        text=list_of_strings, return_tensors="pt", padding=True, truncation=True
    )
    with torch.no_grad():
        # Under no_grad the output carries no autograd graph, so .numpy() is
        # safe without an explicit .detach().
        features = models[name].get_text_features(**inputs).numpy()
    return features / np.linalg.norm(features, axis=1, keepdims=True)
# "[Unsplash:123] optional text" / "[Movies:45]" -> reference to a stored image.
_IMAGE_REFERENCE_RE = re.compile(r"\[(Movies|Unsplash):(\d{1,5})\](.*)")


def _corpus_index(corpus):
    """Map a corpus label to its key in ``df`` / ``embeddings`` (Unsplash -> 0, else 1)."""
    return 0 if corpus == "Unsplash" else 1


def _relative_scores(corpus_embeddings, query_embeddings):
    """Return per-query normalized similarity scores, shape (items, queries).

    Each query's column is shifted so its median is 0 and divided by its
    maximum, so scores of different queries are on a comparable scale.
    """
    scores = corpus_embeddings @ query_embeddings.T
    scores = scores - np.median(scores, axis=0)
    return scores / np.max(scores, axis=0, keepdims=True)


def _positive_embeddings(text, name):
    """Embed the ';'-separated positive terms of a query, or return None.

    A term matching ``[Unsplash:idx]`` / ``[Movies:idx]`` contributes the
    stored embedding of that image plus, when present, the embedding of the
    text after the bracket; any other term is embedded as text. Terms are
    stripped first (so "cats; [Unsplash:12]" still resolves the reference),
    empty terms are skipped, and references past the end of the stored
    embeddings are ignored instead of producing an empty, crashing slice.
    """
    parts = []
    for term in (part.strip() for part in text.split(";")):
        if not term:
            continue
        match = _IMAGE_REFERENCE_RE.match(term)
        if match is None:
            parts.append(compute_text_embeddings([term], name))
            continue
        ref_corpus, idx, remainder = match.groups()
        stored = embeddings[name][_corpus_index(ref_corpus)]
        idx, remainder = int(idx), remainder.strip()
        if idx < len(stored):
            parts.append(stored[idx : idx + 1, :])
        if remainder:
            parts.append(compute_text_embeddings([remainder], name))
    return np.concatenate(parts, axis=0) if parts else None


def image_search(query, corpus, name, n_results=24):
    """Rank images of ``corpus`` against a free-form query.

    Query syntax: ';'-separated positive terms (text or ``[Corpus:idx]``
    image references), optionally followed by "EXCLUDING " and
    ';'-separated negative text terms.

    Args:
        query: the raw query string.
        corpus: "Unsplash" selects corpus 0; any other value selects 1.
        name: model key into ``embeddings`` / ``models``.
        n_results: maximum number of results to return.

    Returns:
        A list of ``(path, tooltip, index)`` tuples, best match first, where
        ``tooltip`` has the corpus attribution from ``source`` appended.
    """
    k = _corpus_index(corpus)
    corpus_embeddings = embeddings[name][k]
    positive_text, *negative_texts = query.split("EXCLUDING ")

    # Items start neutral; positive terms set the score, negatives lower it.
    scores = np.zeros(len(corpus_embeddings))
    positive = _positive_embeddings(positive_text, name)
    if positive is not None:
        # An item must match all positive terms: keep its worst score.
        scores = np.min(_relative_scores(corpus_embeddings, positive), axis=1)

    negative_queries = [
        term.strip()
        for term in " ".join(negative_texts).split(";")
        if term.strip()
    ]
    if negative_queries:
        negative = compute_text_embeddings(negative_queries, name)
        penalties = _relative_scores(corpus_embeddings, negative)
        # Penalize only above-median matches, by the strongest negative term.
        scores = scores - np.max(np.maximum(penalties, 0), axis=1)

    top = np.argsort(scores)[-1 : -n_results - 1 : -1]
    return [
        (df[k].iloc[i]["path"], df[k].iloc[i]["tooltip"] + source[k], i)
        for i in top
    ]
# Sidebar header (Markdown, Japanese UI text) rendered by main().
description = """
# 意味による画像検索
**検索語を入力してから Enter キーを押してください**
*OpenAI の [CLIP](https://openai.com/blog/clip/) model, 🤗 Hugging Face's [transformers library](https://huggingface.co/transformers/), [Streamlit](https://streamlit.io/), [Unsplash](https://unsplash.com/) の 25k images と [The Movie Database (TMDB)](https://www.themoviedb.org/) の 8k images を使用して構築しています。*
*Vladimir Haltakov の [Unsplash Image Search](https://github.com/haltakov/natural-language-image-search) と Travis Hoppe の [Alph, The Sacred River](https://github.com/thoppe/alph-the-sacred-river) に触発されました。*
"""

# Usage tips (Markdown list) shown in the sidebar's advanced-usage expander:
# click-to-search, ';' to combine terms, and EXCLUDING for negative terms.
howto = """
- 画像をクリックすると、それをクエリとして使用し、類似画像を検索できます。
- 複数の検索語を組み合わせることができます(区切り文字として「**;**」を使用します)。
- 検索語に 「**EXCLUDING**」 が含まれている場合、その右側の部分が否定クエリとして使用されます。
"""

# CSS for the result grid container: centered flexbox rows that wrap.
div_style = dict(
    [
        ("display", "flex"),
        ("justify-content", "center"),
        ("flex-wrap", "wrap"),
    ]
)
# Page-level CSS injected by main(): width limits, horizontal radio buttons,
# tighter spacing, and hidden Streamlit menu/footer.
_PAGE_CSS = """
<style>
.block-container{
    max-width: 1200px;
}
div.row-widget.stRadio > div{
    flex-direction:row;
    display: flex;
    justify-content: center;
}
div.row-widget.stRadio > div > label{
    margin-left: 5px;
    margin-right: 5px;
}
.row-widget {
    margin-top: -25px;
}
section>div:first-child {
    padding-top: 30px;
}
div.reportview-container > section:first-child{
    max-width: 320px;
}
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
</style>"""


def _rerun():
    """Restart the script run: ``st.rerun`` if available, else ``st.experimental_rerun``."""
    rerun = getattr(st, "rerun", None) or st.experimental_rerun
    rerun()


def main():
    """Render the search page and turn image clicks into similarity queries.

    The text box is pre-filled with ``st.session_state["query"]`` when set
    (otherwise "clouds at sunset"). Clicking a result stores the query
    ``[<corpus>:<index>]`` for that image and reruns the script, so the next
    run searches for images similar to the clicked one.
    """
    st.markdown(_PAGE_CSS, unsafe_allow_html=True)
    st.sidebar.markdown(description)
    with st.sidebar.expander("高度な使用方法"):
        st.markdown(howto)

    _, c, _ = st.columns((1, 3, 1))
    query = c.text_input("", value=st.session_state.get("query", "clouds at sunset"))
    corpus = st.radio("", ["Unsplash", "Movies"])
    # Search with the last configured model.
    name = MODEL_NAMES[-1]
    if not query:
        return

    results = image_search(query, corpus, name)
    clicked = clickable_images(
        [path for path, _, _ in results],
        titles=[title for _, title, _ in results],
        div_style=div_style,
        img_style={"margin": "2px", "height": "200px"},
        # The key identifies the component instance for this particular search.
        key=query + corpus + name + "1",
    )
    if clicked < 0:
        return

    new_query = f"[{corpus}:{results[clicked][2]}]"
    # Rerun only when the query actually changes. Re-submitting the same query
    # recreates the component under the same key, which reports the same click
    # again; rerunning unconditionally would loop forever.
    if new_query != query:
        st.session_state["query"] = new_query
        _rerun()
# Script entry point: build the page when this file runs as the main module.
if __name__ == "__main__":
    main()