import streamlit as st
import numpy as np
from html import escape
import torch
import torchvision.transforms as transforms
from transformers import BertModel, AutoTokenizer, CLIPVisionModel
from PIL import Image
import io
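
# Preprocessing constants for the CLIP vision encoder: 224x224 input and CLIP's image normalization stats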
IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
device='cuda'
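
# Load the fine-tuned BERT text encoder and CLIP vision encoder from the Hugging Face Hub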
modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained('TheLitttleThings/clip-archdaily-vision').eval()
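
# Precomputed image embeddings and per-project metadata (links, categories) for the ArchDaily image set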
image_embeddings = torch.load('image_embeddings.pt')
links = np.load('data_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)
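
# Remember which tab is active (1 = text search, 2 = find similar, 3 = classify) across Streamlit reruns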
if 'tab' not in st.session_state:
    st.session_state['tab'] = 0
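
# Cached text-to-image search: embed the query text and rank all images by cosine similarity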
@st.experimental_memo
def image_search(query, top_k=24):
    with torch.no_grad():
        text_embedding = text_encoder(**tokenizer(query, return_tensors='pt')).pooler_output
    _, indices = torch.cosine_similarity(image_embeddings, text_embedding).sort(descending=True)
    return [links[i] for i in indices[:top_k]]
def text_query_embedding(query: str = 'architecture'):
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding
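
# Resize and normalize a PIL image into the tensor format the vision encoder expects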
preprocessImage = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD)
])
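
# Embed a single PIL image with the CLIP vision encoder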
def image_query_embedding(image):
    # Convert to RGB so PNGs with an alpha channel (or grayscale uploads) match the 3-channel normalization
    image = preprocessImage(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(image).pooler_output
    return image_embedding
def most_similars(embeddings_1, embeddings_2):
    values, indices = torch.cosine_similarity(embeddings_1, embeddings_2).sort(descending=True)
    return values.cpu(), indices.cpu()
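
# Retrieve the images most similar to an image loaded from disk; adding the text embedding is currently disabled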
def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
    """Analogies with embedding-space arithmetic.

    Args:
        input_image_path (str): Path to the original image.
        additional_text (str): Text whose embedding can be added to the image embedding.
    """
    base_image = Image.open(input_image_path)
    image_embedding = image_query_embedding(base_image)
    additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
def image_comparison(base_image, top_k=24):
    image_embedding = image_query_embedding(base_image)
    # additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k]]
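
# Render the results as an HTML grid of thumbnails linking to each project page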
def get_html(url_list):
    html = "<div class='wrapper' style=''>"
    for url in url_list:
        project = url["project_url"]
        image = url["source_url"]
        title = url["title"]
        year = url["year"]
        html2 = f"<a href='{project}' target='_blank' style='display:block; max-width: 320px;'><div><img style='height: 180px; margin: 2px' src='{escape(image)}'/></div><div>{year}/{title}</div></a>"
        html = html + html2
    html += "</div>"
    return html
def load_image(image_file):
    img = Image.open(image_file)
    return img
description = '''
# Architecture-Clip
- Enter your query and hit enter
- Note: quick demo of a CLIP model trained on architectural images

Built with 5k images from [ArchDaily](https://www.archdaily.com/)

Based on code from
[CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) and
[Clip-Italian](https://github.com/clip-italian/clip-italian)
'''
def main():
    st.markdown('''
<style>
.block-container {
    max-width: 1200px;
}
.wrapper {
    display: grid;
    grid-template-columns: repeat(3, 1fr);
    gap: 10px;
    grid-auto-rows: minmax(100px, auto);
    margin-top: 50px;
    max-width: 1100px;
    justify-content: space-evenly;
}
section.main>div:first-child {
    padding-top: 0px;
}
section:not(.main)>div:first-child {
    padding-top: 30px;
}
div.reportview-container > section:first-child {
    max-width: 320px;
}
#MainMenu {
    visibility: hidden;
}
footer {
    visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)
    st.sidebar.markdown(description)
    col1, col2, col3 = st.columns(3)
    mainContain = st.container()
    if col1.button("Search by text"):
        st.session_state['tab'] = 1
    if col2.button("Find Similar"):
        st.session_state['tab'] = 2
    if col3.button("Classify"):
        st.session_state['tab'] = 3
    # def textSearch(mainContain):
    if st.session_state['tab'] == 1:
        mainContain.header("Text Search")
        _, c, _ = mainContain.columns((1, 3, 1))
        query = c.text_input('Search Box', value='Architecture')
        c.text("It'll take about 30s to load all new images")
        if len(query) > 0:
            results = image_search(query)
            mainContain.markdown(get_html(results), unsafe_allow_html=True)
    # def compare(mainContain):
    if st.session_state['tab'] == 2:
        mainContain.header("Image Relations")
        _, d, _ = mainContain.columns((1, 3, 1))
        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
        if image_file is not None:
            # To read file as bytes:
            # bytes_data = uploaded_file.getvalue()
            # st.write(bytes_data)
            img = load_image(image_file)
            d.image(img, width=300)
            d.text("It'll take about 30s to load all new images")
            results = image_comparison(img)
            mainContain.markdown(get_html(results), unsafe_allow_html=True)
    # def classify(mainContain):
    if st.session_state['tab'] == 3:
        mainContain.header("Classify Elements")
        _, d, _ = mainContain.columns((1, 3, 1))
        d.text("Coming soon")
    # col1.button("Search by text", on_click=textSearch, args=(mainContain,))
    # col2.button("Find Similar", on_click=compare, args=(mainContain,))
    # col3.button("Classify", on_click=classify, args=(mainContain,))
if __name__ == '__main__':
    main()