ClipTest / app.py
TheLitttleThings's picture
Update app.py
9dbff11
import streamlit as st
import numpy as np
from html import escape
import torch
import torchvision.transforms as transforms
from transformers import BertModel, AutoTokenizer, CLIPVisionModel
from PIL import Image
import io
# Side length, in pixels, that query images are resized to before encoding.
IMAGE_SIZE = 224
# Per-channel normalization statistics (one value per image channel) used by
# the `Normalize` step of `preprocessImage`.
MEAN = torch.tensor((0.48145466, 0.4578275, 0.40821073))
STD = torch.tensor((0.26862954, 0.26130258, 0.27577711))
# NOTE(review): not referenced anywhere else in this file — the models and
# embeddings are never moved to this device.
device = 'cuda'
modelPath = 'TheLitttleThings/clip-archdaily-text'
# Text tower and its tokenizer; .eval() puts the model in inference mode.
tokenizer = AutoTokenizer.from_pretrained(modelPath)
text_encoder = BertModel.from_pretrained(modelPath).eval()
# Image tower.
vision_encoder = CLIPVisionModel.from_pretrained(
    'TheLitttleThings/clip-archdaily-vision').eval()
# Precomputed embeddings of the indexed corpus. The encoders above are never
# moved off the CPU, so query embeddings are CPU tensors; map_location='cpu'
# keeps these on the same device (and lets GPU-saved files load on CPU-only
# hosts) so cosine similarity between them does not hit a device mismatch.
image_embeddings = torch.load('image_embeddings.pt', map_location='cpu')
text_embeddings = torch.load('text_embeddings.pt', map_location='cpu')
# Object arrays (allow_pickle=True): only ever load these trusted local files.
# Presumably `links` is row-aligned with image_embeddings and `categories`
# with text_embeddings — TODO confirm against the indexing script.
links = np.load('links_list.npy', allow_pickle=True)
categories = np.load('categories_list.npy', allow_pickle=True)
# Active view selected by the buttons in main(): 0 = nothing chosen yet,
# 1 = text search, 2 = image evaluation. Initialized only once per session.
if 'tab' not in st.session_state:
    st.session_state.tab = 0
@st.experimental_memo
def image_search(query, top_k=24):
    """Return the `top_k` indexed entries whose image embeddings best match `query`.

    Results are memoized by Streamlit per (query, top_k).

    Args:
        query: Free-text search string.
        top_k: Number of results to return.

    Returns:
        A list of entries from `links`, ordered from most to least similar.
    """
    with torch.no_grad():
        # truncation=True clips over-long queries to the tokenizer's configured
        # maximum length instead of letting the encoder fail on them.
        tokens = tokenizer(query, return_tensors='pt', truncation=True)
        text_embedding = text_encoder(**tokens).pooler_output
        scores = torch.cosine_similarity(image_embeddings, text_embedding)
        _, indices = scores.sort(descending=True)
    return [links[i] for i in indices[:top_k].tolist()]
def text_query_embedding(query: str = 'architecture'):
    """Encode `query` with the text encoder and return its pooled embedding.

    Args:
        query: Text to encode.

    Returns:
        The `pooler_output` tensor produced by `text_encoder` for the query.
    """
    # Tokenize once; truncation keeps over-long input within the tokenizer's
    # configured maximum length.
    tokens = tokenizer(query, return_tensors='pt', truncation=True)
    with torch.no_grad():
        text_embedding = text_encoder(**tokens).pooler_output
    return text_embedding
# Preprocessing pipeline for query images fed to the vision encoder:
# resize to a fixed IMAGE_SIZE x IMAGE_SIZE square (aspect ratio is not
# preserved), convert to a tensor, then normalize per channel with MEAN/STD.
_preprocess_steps = [
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
]
preprocessImage = transforms.Compose(_preprocess_steps)
def image_query_embedding(image):
    """Encode a PIL image with the vision encoder and return its pooled embedding.

    Args:
        image: A PIL image in any mode.

    Returns:
        The `pooler_output` tensor produced by `vision_encoder` (batch of one).
    """
    # The Normalize step uses three-channel MEAN/STD, so RGBA, palette or
    # grayscale images (common for uploaded PNGs) would raise; force RGB.
    pixels = preprocessImage(image.convert('RGB')).unsqueeze(0)
    with torch.no_grad():
        image_embedding = vision_encoder(pixels).pooler_output
    return image_embedding
def most_similars(embeddings_1, embeddings_2):
    """Rank rows by cosine similarity, highest first.

    Args:
        embeddings_1: Tensor of embeddings, compared row-wise (dim=1).
        embeddings_2: Tensor broadcastable against `embeddings_1`.

    Returns:
        (scores, order): the sorted similarity scores and the positions they
        came from, both moved to the CPU.
    """
    similarity = torch.cosine_similarity(embeddings_1, embeddings_2)
    ranked_scores, ranked_positions = torch.sort(similarity, descending=True)
    return ranked_scores.cpu(), ranked_positions.cpu()
def analogy(input_image_path: str, top_k=64, additional_text: str = '', input_include=True):
    """Find indexed images most similar to the image at `input_image_path`.

    Meant for "analogies" via embedding arithmetic (image embedding plus a
    text embedding), but the text term is currently disabled, so this is a
    plain nearest-neighbour search on the image embedding.

    Args:
        input_image_path: Path to the query image.
        top_k: Number of results to return.
        additional_text: Text intended to be added to the image embedding;
            currently ignored while the arithmetic is disabled.
        input_include: Currently ignored.

    Returns:
        A list of entries from `links`, ordered from most to least similar.
    """
    # The context manager closes the file once the pixels have been encoded.
    with Image.open(input_image_path) as base_image:
        image_embedding = image_query_embedding(base_image)
    # The text term (text_query_embedding(additional_text)) is intentionally
    # not computed while the `+ additional_embedding` arithmetic is disabled.
    new_image_embedding = image_embedding
    _, indices = most_similars(image_embeddings, new_image_embedding)
    return [links[i] for i in indices[:top_k].tolist()]
def image_comparison(base_image, top_k=64):
    """Return the `top_k` indexed entries whose image embeddings are closest to `base_image`.

    Args:
        base_image: A PIL image used as the query.
        top_k: Number of results to return.

    Returns:
        A list of entries from `links`, ordered from most to least similar.
    """
    query_vector = image_query_embedding(base_image)
    _, ranking = most_similars(image_embeddings, query_vector)
    best = ranking[:top_k]
    return [links[pos] for pos in best]
def get_html(url_list, classOther=""):
    """Build an HTML grid of linked result cards.

    Args:
        url_list: Iterable of mappings with 'project_url', 'source_url',
            'title' and 'year' keys.
        classOther: Extra CSS class added to the wrapper div (main() passes
            "big" or "small").

    Returns:
        An HTML string: a wrapper div containing one link card per entry.
    """
    cards = []
    for entry in url_list:
        # Escape every interpolated value: this markup is rendered with
        # unsafe_allow_html, so a quote or '<' in a title or URL would
        # otherwise break the page or inject markup.
        project = escape(str(entry["project_url"]))
        image = escape(str(entry["source_url"]))
        title = escape(str(entry["title"]))
        year = escape(str(entry["year"]))
        cards.append(
            f"<a href='{project}' target='_blank' class='link'>"
            f"<div class='imageparent' style='background-image:url({image})'></div>"
            f"<div>{year}/{title}</div></a>"
        )
    return f"<div class='wrapper {escape(classOther)}'>" + "".join(cards) + "</div>"
def load_image(image_file):
    """Open `image_file` (anything accepted by PIL's Image.open) as a PIL image."""
    return Image.open(image_file)
# Markdown rendered in the sidebar by main().
description = '''
# Architecture-Clip
- Enter your query and hit enter
- Note: Quick demo of a CLIP model trained on architectural images
Trained and indexed on 30k images from [ArchDaily](https://www.archdaily.com/)
Based on code from
- [CLIP-fa](https://github.com/sajjjadayobi/CLIPfa)
- [Clip-Italian](https://github.com/clip-italian/clip-italian)
'''
def _inject_styles():
    """Inject the page-wide CSS for the layout and the result grid."""
    st.markdown('''
<style>
.block-container{
max-width: 1200px;
}
.wrapper {
display: grid;
grid-template-columns: repeat(3, 1fr);
gap: 10px;
grid-auto-rows: minmax(100px, auto);
max-width: 1100px;
justify-content: space-evenly;
}
.wrapper.small{
grid-template-columns: repeat(2, 1fr);
}
.imageparent{
overflow:hidden;
width:100%;
aspect-ratio : 1 / 1;
margin-bottom: 2px;
background-repeat:no-repeat;
background-size:cover;
background-position:center center;
}
a.link{
display:block;
}
section.main>div:first-child {
padding-top: 0px;
}
div.reportview-container > section:first-child{
max-width: 320px;
}
#MainMenu {
visibility: hidden;
}
footer {
visibility: hidden;
}
</style>''',
                unsafe_allow_html=True)


def _render_text_search(container):
    """Render the text-search tab (query box plus result grid) in `container`."""
    _, c, _ = container.columns((1, 6, 1))
    c.header("Text Search")
    query = c.text_input('Search Box', value='Architecture')
    if query:
        c.text("It'll take about 30s to load all new images")
        results = image_search(query)
        c.markdown(get_html(results, "big"),
                   unsafe_allow_html=True)


def _render_image_evaluation(container):
    """Render the image tab: upload, similar images and category scores."""
    _, d, _ = container.columns((1, 6, 1))
    d.header("Evaluate Image")
    # 'jpeg' is accepted alongside 'jpg' so both common JPEG extensions upload.
    image_file = d.file_uploader("Choose a file", type=['png', 'jpg', 'jpeg'])
    if image_file is None:
        return
    _, left, right, _ = container.columns((1, 2, 4, 1))
    img = load_image(image_file)
    left.image(img)
    left.markdown("It'll take about 30s to load all new images")
    results = image_comparison(img)
    right.markdown(get_html(results, "small"), unsafe_allow_html=True)
    # Score the image against every category text embedding; softmax turns
    # the sorted cosine similarities into relative percentages.
    image_embedding = image_query_embedding(img)
    values, indices = most_similars(image_embedding, text_embeddings)
    for i, sim in zip(indices, torch.softmax(values, dim=0)):
        left.markdown(f'{round(float(sim * 100), 1)} | {categories[i]}')


def main():
    """Draw the page chrome and render whichever tab is currently selected."""
    _inject_styles()
    st.sidebar.markdown(description)
    _, col1, col2, _ = st.columns((2, 2, 2, 2))
    # Created before the tabs render so tab content appears below the buttons.
    mainContain = st.container()
    if col1.button("Search by text"):
        st.session_state['tab'] = 1
    if col2.button("Evaluate Image"):
        st.session_state['tab'] = 2
    if st.session_state['tab'] == 1:
        _render_text_search(mainContain)
    if st.session_state['tab'] == 2:
        _render_image_evaluation(mainContain)
# Run the app when this file is executed as the main script.
if __name__ == '__main__':
    main()