import streamlit as st
import numpy as np
from html import escape
import torch
import torchvision.transforms as transforms
from transformers import BertModel, AutoTokenizer, CLIPVisionModel
from PIL import Image
import io

IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

# Not currently used for inference (everything below runs on CPU); kept for completeness.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained(
    'TheLitttleThings/clip-archdaily-vision').eval()

# Precomputed artifacts: per-image CLIP embeddings, per-category text embeddings,
# and the project metadata / category labels they index into.
image_embeddings = torch.load('image_embeddings.pt')
text_embeddings = torch.load('text_embeddings.pt')
links = np.load('links_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)

if 'tab' not in st.session_state:
    st.session_state['tab'] = 0


@st.experimental_memo
def image_search(query, top_k=24):
    """Return the top_k projects whose image embeddings are closest to the query text."""
    with torch.no_grad():
        text_embedding = text_encoder(
            **tokenizer(query, return_tensors='pt')).pooler_output
        _, indices = torch.cosine_similarity(
            image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]


def text_query_embedding(query: str = 'architecture'):
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding


preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])


def image_query_embedding(image):
    image = preprocessImage(image).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(image).pooler_output
    return image_embedding


def most_similars(embeddings_1, embeddings_2):
    values, indices = torch.cosine_similarity(
        embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()


def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
    """
    Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the original image.
        top_k (int): Number of results to return.
        additional_text (str): Text whose embedding could be added to the image embedding.
        input_include (bool): Unused; kept for API compatibility.
    """
    base_image = Image.open(input_image_path)
    image_embedding = image_query_embedding(base_image)
    additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]


def image_comparison(base_image, top_k=24):
    image_embedding = image_query_embedding(base_image)
    # additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
" for url in url_list: project = url["project_url"] image = url["source_url"] title = url["title"] year = url["year"] html2 = f"
{year}/{title}
" html = html + html2 html += "
" return html def load_image(image_file): img = Image.open(image_file) return img description = ''' # Architecture-Clip - Enter your query and hit enter - Note: Quick demo if Clip model trained on Architectural images Built with 5k images from [ArchDaily](https://www.archdaily.com/) Based on code from [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) [Clip-Italian](https://github.com/clip-italian/clip-italian) ''' def main(): st.markdown(''' ''', unsafe_allow_html=True) st.sidebar.markdown(description) _, col1, col2, col3, _ = st.columns((1, 2, 2, 2, 1)) mainContain = st.container() if col1.button("Search by text"): st.session_state['tab'] = 1 if col2.button("Find Similar"): st.session_state['tab'] = 2 if col3.button("Classify"): st.session_state['tab'] = 3 # def textSearch(mainContain): if st.session_state['tab'] == 1: _, c, _ = mainContain.columns((1, 6, 1)) c.header("Text Search") query = c.text_input('Search Box', value='Architecture') if len(query) > 0: c.text("It'll take about 30s to load all new images") results = image_search(query) mainContain.markdown(get_html(results, "big"), unsafe_allow_html=True) if st.session_state['tab'] == 2: _, d, _ = mainContain.columns((1, 6, 1)) d.header("Find Related") image_file = d.file_uploader("Choose a file", type=['png', 'jpg']) if image_file is not None: _, left, right, _ = mainContain.columns((1, 2, 4, 1)) img = load_image(image_file) left.image(img, width=300) left.text("It'll take about 30s to load all new images") results = image_comparison(img) right.markdown(get_html(results, "small"), unsafe_allow_html=True) if st.session_state['tab'] == 3: _, d, _ = mainContain.columns((1, 6, 1)) d.header("Classify Elements") image_file = d.file_uploader("Choose a file", type=['png', 'jpg']) if image_file is not None: img = load_image(image_file) _, left, right, _ = mainContain.columns((1, 4, 2, 1)) left.image(img, width=300) image_embedding = image_query_embedding(img) values, indices = most_similars(image_embedding, text_embeddings) for i, sim in zip(indices, torch.softmax(values, dim=0)): right.text(f'label: {categories[i]} | {round(float(sim), 3)}') if __name__ == '__main__': main()