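# Architecture-Clip Streamlit demo: text-to-image and image-to-image search over
# ~30k ArchDaily project photos, using a CLIP model fine-tuned on architectural
# images (text encoder + vision encoder loaded below).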
import streamlit as st
import numpy as np
from html import escape
import torch
import torchvision.transforms as transforms
from transformers import BertModel, AutoTokenizer, CLIPVisionModel
from PIL import Image
import io
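# Standard CLIP preprocessing constants: 224x224 input and the normalisation
# statistics used by OpenAI's CLIP image pipeline.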
IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

device = 'cuda'
modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained(
    'TheLitttleThings/clip-archdaily-vision').eval()

image_embeddings = torch.load('image_embeddings.pt')
text_embeddings = torch.load('text_embeddings.pt')
links = np.load('links_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)
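# Precomputed artefacts (assumed layout): image_embeddings / text_embeddings are
# 2-D tensors whose rows line up index-wise with the entries in `links`
# (project metadata dicts) and `categories` respectively.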
if 'tab' not in st.session_state:
    st.session_state['tab'] = 0
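# Retrieval helpers: embed the query with the text or vision encoder, then rank
# the precomputed image embeddings by cosine similarity.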
def image_search(query, top_k=24):
    with torch.no_grad():
        text_embedding = text_encoder(
            **tokenizer(query, return_tensors='pt')).pooler_output
    _, indices = torch.cosine_similarity(
        image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]
def text_query_embedding(query: str = 'architecture'):
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding
preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])
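# Embed an uploaded image with the CLIP vision encoder; preprocessImage yields a
# (3, 224, 224) tensor and unsqueeze adds the batch dimension.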
def image_query_embedding(image):
    image = preprocessImage(image).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(image).pooler_output
    return image_embedding
def most_similars(embeddings_1, embeddings_2):
    values, indices = torch.cosine_similarity(
        embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()
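# Image-based retrieval. `analogy` was written for embedding-space arithmetic
# (image embedding + text embedding); the addition is currently commented out,
# so both functions below effectively rank by image-to-image similarity.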
def analogy(input_image_path: str, top_k=64, additional_text: str = '', input_include=True):
    """Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the original image.
        additional_text (str): Text whose embedding can be added to the image embedding.
    """
    base_image = Image.open(input_image_path)
    image_embedding = image_query_embedding(base_image)
    additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
def image_comparison(base_image, top_k=64):
    image_embedding = image_query_embedding(base_image)
    # additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
def get_html(url_list, classOther=""):
    html = f"<div class='wrapper {classOther}'>"
    for url in url_list:
        # Escape metadata before interpolating it into HTML attributes/text.
        project = escape(url["project_url"])
        image = escape(url["source_url"])
        title = escape(url["title"])
        year = url["year"]
        html += (
            f"<a href='{project}' target='_blank' class='link'>"
            f"<div class='imageparent' style='background-image:url({image})'></div>"
            f"<div>{year}/{title}</div></a>"
        )
    html += "</div>"
    return html
def load_image(image_file):
    # Convert to RGB so RGBA/greyscale uploads match the 3-channel normalisation.
    img = Image.open(image_file).convert('RGB')
    return img
description = '''
# Architecture-Clip
- Enter your query and hit enter
- Note: quick demo of a CLIP model trained on architectural images

Trained and indexed on 30k images from [ArchDaily](https://www.archdaily.com/)

Based on code from
- [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa)
- [Clip-Italian](https://github.com/clip-italian/clip-italian)
'''
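# Streamlit UI: custom CSS for the result grid, a sidebar description, and two
# "tabs" (text search / image evaluation) toggled via st.session_state['tab'].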
def main():
    st.markdown('''
        <style>
        .block-container{
            max-width: 1200px;
        }
        .wrapper {
            display: grid;
            grid-template-columns: repeat(3, 1fr);
            gap: 10px;
            grid-auto-rows: minmax(100px, auto);
            max-width: 1100px;
            justify-content: space-evenly;
        }
        .wrapper.small{
            grid-template-columns: repeat(2, 1fr);
        }
        .imageparent{
            overflow:hidden;
            width:100%;
            aspect-ratio: 1 / 1;
            margin-bottom: 2px;
            background-repeat:no-repeat;
            background-size:cover;
            background-position:center center;
        }
        a.link{
            display:block;
        }
        section.main>div:first-child {
            padding-top: 0px;
        }
        div.reportview-container > section:first-child{
            max-width: 320px;
        }
        #MainMenu {
            visibility: hidden;
        }
        footer {
            visibility: hidden;
        }
        </style>''',
        unsafe_allow_html=True)
    st.sidebar.markdown(description)

    _, col1, col2, _ = st.columns((2, 2, 2, 2))
    mainContain = st.container()
    if col1.button("Search by text"):
        st.session_state['tab'] = 1
    if col2.button("Evaluate Image"):
        st.session_state['tab'] = 2

    # def textSearch(mainContain):
    if st.session_state['tab'] == 1:
        _, c, _ = mainContain.columns((1, 6, 1))
        c.header("Text Search")
        query = c.text_input('Search Box', value='Architecture')
        if len(query) > 0:
            c.text("It'll take about 30s to load all new images")
            results = image_search(query)
            c.markdown(get_html(results, "big"),
                       unsafe_allow_html=True)
    if st.session_state['tab'] == 2:
        _, d, _ = mainContain.columns((1, 6, 1))
        d.header("Evaluate Image")
        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
        if image_file is not None:
            _, left, right, _ = mainContain.columns((1, 2, 4, 1))
            img = load_image(image_file)
            left.image(img)
            left.markdown("It'll take about 30s to load all new images")
            results = image_comparison(img)
            right.markdown(get_html(results, "small"), unsafe_allow_html=True)
            image_embedding = image_query_embedding(img)
            values, indices = most_similars(image_embedding, text_embeddings)
            for i, sim in zip(indices, torch.softmax(values, dim=0)):
                left.markdown(f'{round(float(sim * 100), 1)} | {categories[i]}')
if __name__ == '__main__':
    main()