import streamlit as st
import numpy as np
from html import escape
import torch
import torchvision.transforms as transforms
from transformers import BertModel, AutoTokenizer, CLIPVisionModel
from PIL import Image
import io

# CLIP's standard 224x224 input resolution and image normalization statistics.
IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
device = 'cuda' if torch.cuda.is_available() else 'cpu'

modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained(
    'TheLitttleThings/clip-archdaily-vision').eval()
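
# The BERT text encoder and CLIP vision encoder are assumed to have been trained
# (CLIP-fa style) so that matching images and captions land close together in a
# shared embedding space; the cosine-similarity searches below rely on this.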

# Precomputed gallery embeddings and their metadata.
# map_location='cpu' keeps loading robust if the tensors were saved from a GPU;
# the encoders in this demo also run on CPU.
image_embeddings = torch.load('image_embeddings.pt', map_location='cpu')
text_embeddings = torch.load('text_embeddings.pt', map_location='cpu')
links = np.load('links_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)

if 'tab' not in st.session_state:
    st.session_state['tab'] = 0

@st.experimental_memo
def image_search(query, top_k=24):
    with torch.no_grad():
        text_embedding = text_encoder(
            **tokenizer(query, return_tensors='pt')).pooler_output
    _, indices = torch.cosine_similarity(
        image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]
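
# Example usage (sketch; 'brick museum facade' is just an illustrative query):
#   results = image_search('brick museum facade')
#   results is a list of up to 24 metadata records, best match first.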

def text_query_embedding(query: str = 'architecture'):
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding

preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])
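
# A PIL image run through preprocessImage becomes a normalized float tensor of
# shape (3, 224, 224); image_query_embedding below adds the batch dimension.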

def image_query_embedding(image):
    image = preprocessImage(image).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(pixel_values=image).pooler_output
    return image_embedding
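
# Example (sketch; 'house.jpg' is a hypothetical local file):
#   emb = image_query_embedding(Image.open('house.jpg').convert('RGB'))
#   emb has shape (1, hidden_size), the vision encoder's pooled output.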

def most_similars(embeddings_1, embeddings_2):
    values, indices = torch.cosine_similarity(
        embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()
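
# torch.cosine_similarity broadcasts here: a (1, D) query against an (N, D)
# matrix yields N scores, which most_similars returns sorted highest-first
# along with the original row indices.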

def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
    """Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the query image.
        top_k (int): Number of matches to return.
        additional_text (str): Optional text whose embedding is added to the
            image embedding before searching.
        input_include (bool): Currently unused.
    """
    base_image = Image.open(input_image_path).convert('RGB')
    image_embedding = image_query_embedding(base_image)
    new_image_embedding = image_embedding
    if additional_text:
        # Shift the image query towards the text concept before searching.
        new_image_embedding = image_embedding + \
            text_query_embedding(query=additional_text)
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
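
# Example (sketch; 'facade.jpg' and the query text are hypothetical):
#   analogy('facade.jpg', top_k=12, additional_text='timber cladding')
#   returns the 12 gallery entries closest to the image embedding shifted
#   towards the embedding of 'timber cladding'.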

def image_comparison(base_image, top_k=24):
    image_embedding = image_query_embedding(base_image)
    _, indices = most_similars(image_embeddings, image_embedding)
    return [links[i] for i in indices[:top_k]]

def get_html(url_list, classOther=""):
    html = f"<div class='wrapper {classOther}'>"
    for url in url_list:
        project = escape(url["project_url"])
        image = escape(url["source_url"])
        title = escape(url["title"])
        year = url["year"]
        html += (
            f"<a href='{project}' target='_blank' class='link'>"
            f"<div class='imageparent'><img src='{image}'/></div>"
            f"<div>{year}/{title}</div></a>"
        )
    html += "</div>"
    return html
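
# Each entry in url_list (and in links) is expected to be a dict with
# 'project_url', 'source_url', 'title' and 'year' keys, e.g.
#   {'project_url': 'https://www.archdaily.com/...', 'source_url': 'https://....jpg',
#    'title': 'Some Project', 'year': 2020}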

def load_image(image_file):
    # Force 3-channel RGB so greyscale or RGBA uploads work with the
    # 3-channel normalization above.
    img = Image.open(image_file).convert('RGB')
    return img

description = '''
# Architecture-Clip
- Enter your query and hit enter
- Note: quick demo of a CLIP model trained on architectural images

Built with 5k images from [ArchDaily](https://www.archdaily.com/)

Based on code from
[CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) and
[Clip-Italian](https://github.com/clip-italian/clip-italian)
'''

def main():
    st.markdown('''
        <style>
        .block-container{
            max-width: 1200px;
        }
        .wrapper {
            display: grid;
            grid-template-columns: repeat(3, 1fr);
            gap: 10px;
            grid-auto-rows: minmax(100px, auto);
            margin-top: 50px;
            max-width: 1100px;
            justify-content: space-evenly;
        }
        .wrapper.small{
            grid-template-columns: repeat(2, 1fr);
        }
        .imageparent{
            overflow: hidden;
            width: 100%;
            height: 100%;
            aspect-ratio: 1 / 1;
            margin-bottom: 2px;
        }
        a.link{
            display: block;
        }
        .wrapper a img{
            width: 100%;
            display: block;
            aspect-ratio: 1 / 1;
        }
        section.main>div:first-child {
            padding-top: 0px;
        }
        section:not(.main)>div:first-child {
            padding-top: 30px;
        }
        div.reportview-container > section:first-child{
            max-width: 320px;
        }
        #MainMenu {
            visibility: hidden;
        }
        footer {
            visibility: hidden;
        }
        </style>''',
                unsafe_allow_html=True)

    st.sidebar.markdown(description)
    _, col1, col2, col3, _ = st.columns((1, 2, 2, 2, 1))
    mainContain = st.container()

    # The three buttons act as tabs; the selection is kept in
    # st.session_state['tab'] so it survives Streamlit reruns.
    if col1.button("Search by text"):
        st.session_state['tab'] = 1
    if col2.button("Find Similar"):
        st.session_state['tab'] = 2
    if col3.button("Classify"):
        st.session_state['tab'] = 3

    if st.session_state['tab'] == 1:
        _, c, _ = mainContain.columns((1, 6, 1))
        c.header("Text Search")
        query = c.text_input('Search Box', value='Architecture')
        if len(query) > 0:
            c.text("It'll take about 30s to load all new images")
            results = image_search(query)
            mainContain.markdown(get_html(results, "big"),
                                 unsafe_allow_html=True)

    if st.session_state['tab'] == 2:
        _, d, _ = mainContain.columns((1, 6, 1))
        d.header("Find Related")
        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
        if image_file is not None:
            _, left, right, _ = mainContain.columns((1, 2, 4, 1))
            img = load_image(image_file)
            left.image(img, width=300)
            left.text("It'll take about 30s to load all new images")
            results = image_comparison(img)
            right.markdown(get_html(results, "small"), unsafe_allow_html=True)

    if st.session_state['tab'] == 3:
        _, d, _ = mainContain.columns((1, 6, 1))
        d.header("Classify Elements")
        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
        if image_file is not None:
            img = load_image(image_file)
            _, left, right, _ = mainContain.columns((1, 4, 2, 1))
            left.image(img, width=300)
            image_embedding = image_query_embedding(img)
            values, indices = most_similars(image_embedding, text_embeddings)
            # Softmax turns the sorted similarities into probability-like scores.
            for i, sim in zip(indices, torch.softmax(values, dim=0)):
                right.text(f'label: {categories[i]} | {round(float(sim), 3)}')

if __name__ == '__main__':
    main()