"""Streamlit demo: CLIP-style text and image search over ArchDaily images."""
import io

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as transforms
from html import escape
from PIL import Image
from transformers import AutoTokenizer, BertModel, CLIPVisionModel

IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
device = 'cuda'  # currently unused: all inference below runs on CPU

# Text and vision encoders fine-tuned on ArchDaily data.
modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained('TheLitttleThings/clip-archdaily-vision').eval()

# Precomputed image embeddings plus per-image metadata (project URL, image URL, title, year).
image_embeddings = torch.load('image_embeddings.pt')
links = np.load('data_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)

if 'tab' not in st.session_state:
    st.session_state['tab'] = 0


@st.experimental_memo
def image_search(query, top_k=24):
    """Return metadata for the top_k images closest to a text query."""
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
    _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]


def text_query_embedding(query: str = 'architecture'):
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding


preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])


def image_query_embedding(image):
    image = preprocessImage(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(image).pooler_output
    return image_embedding


def most_similars(embeddings_1, embeddings_2):
    values, indices = torch.cosine_similarity(embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()


def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
    """Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the original image.
        top_k (int): Number of results to return.
        additional_text (str): Text whose embedding can be added to the image embedding.
    """
    base_image = Image.open(input_image_path)
    image_embedding = image_query_embedding(base_image)
    additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]


def image_comparison(base_image, top_k=24):
    """Return metadata for the top_k images closest to an uploaded image."""
    image_embedding = image_query_embedding(base_image)
    # additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
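
# ---------------------------------------------------------------------------
# Offline indexing sketch (illustrative only; never called by the app).
# The app expects 'image_embeddings.pt' and 'data_list.npy' to exist already.
# Assuming each record is a dict with the fields get_html() reads
# (project_url, source_url, title, year) plus a hypothetical 'local_path'
# pointing at a downloaded copy of the image, the index could be built like this:
def build_image_index(records,
                      embeddings_path='image_embeddings.pt',
                      links_path='data_list.npy'):
    embeddings = []
    for record in records:
        img = Image.open(record['local_path'])         # 'local_path' is an assumed field
        embeddings.append(image_query_embedding(img))  # one (1, hidden_size) tensor per image
    torch.save(torch.cat(embeddings, dim=0), embeddings_path)
    np.save(links_path, np.array(records, dtype=object), allow_pickle=True)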
" for url in url_list: project = url["project_url"] image = url["source_url"] title = url["title"] year = url["year"] html2 = f"
{year}/{title}
" html = html + html2 html += "
" return html def load_image(image_file): img = Image.open(image_file) return img description = ''' # Architecture-Clip - Enter your query and hit enter - Note: Quick demo if Clip model trained on Architectural images Built with 5k images from [ArchDaily](https://www.archdaily.com/) Based on code from [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) [Clip-Italian](https://github.com/clip-italian/clip-italian) ''' def main(): st.markdown(''' ''', unsafe_allow_html=True) st.sidebar.markdown(description) col1, col2, col3 = st.columns(3) mainContain = st.container() if col1.button("Search by text"): st.session_state['tab'] = 1 if col2.button("Find Similar"): st.session_state['tab'] = 2 if col3.button("Classify"): st.session_state['tab'] = 3 #def textSearch(mainContain): if st.session_state['tab'] == 1: mainContain.header("Text Search") _, c, _ = mainContain.columns((1, 3, 1)) query = c.text_input('Search Box', value='Architecture') c.text("It'll take about 30s to load all new images") if len(query) > 0: results = image_search(query) mainContain.markdown(get_html(results), unsafe_allow_html=True) #def compare(mainContain): if st.session_state['tab'] == 2: mainContain.header("Image Relations") _, d, _ = mainContain.columns((1, 3, 1)) image_file = d.file_uploader("Choose a file", type=['png', 'jpg']) if image_file is not None: # To read file as bytes: #bytes_data = uploaded_file.getvalue() #st.write(bytes_data) img = load_image(image_file) d.image(img,width=300) d.text("It'll take about 30s to load all new images") results = image_comparison(img) mainContain.markdown(get_html(results), unsafe_allow_html=True) #def classify(mainContain): if st.session_state['tab'] == 3: mainContain.header("Classify Elements") _, d, _ = mainContain.columns((1, 3, 1)) d.text("Coming soon") #col1.button("Search by text", on_click=textSearch, args=(mainContain,)) #col2.button("Find Similar", on_click=compare, args=(mainContain,)) #col3.button("Classify", on_click=classify, args=(mainContain,)) if __name__ == '__main__': main()