import streamlit as st
import numpy as np
from html import escape
import torch
import torchvision.transforms as transforms
from transformers import BertModel, AutoTokenizer, CLIPVisionModel
from PIL import Image
import io
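# CLIP's standard input resolution and image normalization statistics.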
IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
# Prefer CUDA when available; otherwise everything below runs on CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
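# Paired towers trained into a shared embedding space: a BERT text encoder
# and a CLIP vision encoder, both from the clip-archdaily checkpoints.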
modelPath = 'TheLitttleThings/clip-archdaily-text'
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
vision_encoder = CLIPVisionModel.from_pretrained(
'TheLitttleThings/clip-archdaily-vision').eval()
# Precomputed embeddings for the image gallery and the category labels;
# map to CPU so loading also works on GPU-less hosts.
image_embeddings = torch.load('image_embeddings.pt', map_location='cpu')
text_embeddings = torch.load('text_embeddings.pt', map_location='cpu')
links = np.load('links_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)
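# Track the active tab (1 = text search, 2 = find similar, 3 = classify)
# across Streamlit reruns.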
if 'tab' not in st.session_state:
st.session_state['tab'] = 0
@st.experimental_memo
def image_search(query, top_k=24):
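    """Embed a text query and return links for the top_k most similar gallery images."""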
with torch.no_grad():
text_embedding = text_encoder(
**tokenizer(query, return_tensors='pt')).pooler_output
_, indices = torch.cosine_similarity(
image_embeddings, text_embedding).sort(descending=True)
return [links[i] for i in indices[:top_k]]
def text_query_embedding(query: str = 'architecture'):
    """Return the pooled BERT embedding for a free-form text query."""
    tokens = tokenizer(query, return_tensors='pt')
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding
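# Reproduce CLIP's image preprocessing: resize, to-tensor, then normalize
# with the statistics defined above.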
preprocessImage = transforms.Compose([
transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
transforms.ToTensor(),
transforms.Normalize(mean=MEAN, std=STD)
])
def image_query_embedding(image):
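    """Preprocess a PIL image and return its pooled CLIP vision embedding."""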
image = preprocessImage(image).unsqueeze(0)
with torch.no_grad():
image_embedding = vision_encoder(image).pooler_output
return image_embedding
def most_similars(embeddings_1, embeddings_2):
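    """Cosine similarities between two broadcastable embedding batches, sorted descending."""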
values, indices = torch.cosine_similarity(
embeddings_1, embeddings_2).sort(descending=True)
return values.cpu(), indices.cpu()
def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
""" Analogies with embedding space arithmetic.
Args:
input_image_path (str): The path to original image
image_paths (list[str]): A database of images
"""
base_image = Image.open(input_image_path)
image_embedding = image_query_embedding(base_image)
additional_embedding = text_query_embedding(query=additional_text)
    new_image_embedding = image_embedding  # + additional_embedding (text term currently disabled)
_, indices = most_similars(image_embeddings, new_image_embedding)
return [links[i] for i in indices[:top_k]]
def image_comparison(base_image, top_k=24):
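    """Return links for the top_k gallery images most similar to base_image."""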
image_embedding = image_query_embedding(base_image)
#additional_embedding = text_query_embedding(query=additional_text)
new_image_embedding = image_embedding # + additional_embedding
_, indices = most_similars(image_embeddings, new_image_embedding)
return [links[i] for i in indices[:top_k]]
def get_html(url_list, classOther=""):
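    """Build an HTML grid of linked project cards for st.markdown rendering."""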
html = f"<div class='wrapper {classOther}'>"
for url in url_list:
project = url["project_url"]
image = url["source_url"]
title = url["title"]
year = url["year"]
html2 = f"<a href='{project}' target='_blank' class='link'><div class='imageparent'><img style=' src='{escape(image)}'/></div><div>{year}/{title}</div></a>"
html = html + html2
html += "</div>"
return html
def load_image(image_file):
    # Force RGB so RGBA/greyscale uploads match the 3-channel normalization.
    img = Image.open(image_file).convert('RGB')
    return img
description = '''
# Architecture-Clip
- Enter your query and hit enter
- Note: a quick demo of a CLIP model trained on architectural images

Built with 5k images from [ArchDaily](https://www.archdaily.com/)

Based on code from
[CLIP-fa](https://github.com/sajjjadayobi/CLIPfa) and
[Clip-Italian](https://github.com/clip-italian/clip-italian)
'''
def main():
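    # Inject CSS: a responsive result grid plus hiding Streamlit's default chrome.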
st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
.wrapper {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 10px;
grid-auto-rows: minmax(100px, auto);
margin-top: 50px;
max-width: 1100px;
justify-content: space-evenly;
}
.wrapper.small{
grid-template-columns: repeat(2, 1fr);
}
.imageparent{
overflow:hidden;
width:100%;
height:100%;
aspect-ratio : 1 / 1;
margin-bottom: 2px;
}
a.link{
display:block;
}
.wrapper a img{
width:100%;
display:block;
aspect-ratio : 1 / 1;
}
section.main>div:first-child {
padding-top: 0px;
}
section:not(.main)>div:first-child {
padding-top: 30px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
unsafe_allow_html=True)
st.sidebar.markdown(description)
_, col1, col2, col3, _ = st.columns((1, 2, 2, 2, 1))
mainContain = st.container()
if col1.button("Search by text"):
st.session_state['tab'] = 1
if col2.button("Find Similar"):
st.session_state['tab'] = 2
if col3.button("Classify"):
st.session_state['tab'] = 3
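    # Tab 1: free-text search over the precomputed gallery embeddings.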
if st.session_state['tab'] == 1:
_, c, _ = mainContain.columns((1, 6, 1))
c.header("Text Search")
query = c.text_input('Search Box', value='Architecture')
if len(query) > 0:
c.text("It'll take about 30s to load all new images")
results = image_search(query)
mainContain.markdown(get_html(results, "big"),
unsafe_allow_html=True)
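    # Tab 2: find gallery images similar to an uploaded photo.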
if st.session_state['tab'] == 2:
_, d, _ = mainContain.columns((1, 6, 1))
d.header("Find Related")
image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
if image_file is not None:
_, left, right, _ = mainContain.columns((1, 2, 4, 1))
img = load_image(image_file)
left.image(img, width=300)
left.text("It'll take about 30s to load all new images")
results = image_comparison(img)
right.markdown(get_html(results, "small"), unsafe_allow_html=True)
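    # Tab 3: zero-shot classification of an upload against the category embeddings.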
if st.session_state['tab'] == 3:
_, d, _ = mainContain.columns((1, 6, 1))
d.header("Classify Elements")
image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
if image_file is not None:
img = load_image(image_file)
_, left, right, _ = mainContain.columns((1, 4, 2, 1))
left.image(img, width=300)
image_embedding = image_query_embedding(img)
values, indices = most_similars(image_embedding, text_embeddings)
for i, sim in zip(indices, torch.softmax(values, dim=0)):
right.text(f'label: {categories[i]} | {round(float(sim), 3)}')
if __name__ == '__main__':
main()