Spaces: Runtime error
Commit d35396a
1 Parent(s): 7ce51b8
Update app.py

app.py CHANGED
@@ -1,19 +1,224 @@
import streamlit as st


def main():

-
-
-
    mainContain = st.container()
-
    if col1.button("Search by text"):
-
    if col2.button("Find Similar"):
-
    if col3.button("Classify"):
-
-
if __name__ == '__main__':
-    main()

import streamlit as st
+import numpy as np
+from html import escape
+import torch
+import torchvision.transforms as transforms
+from transformers import BertModel, AutoTokenizer, CLIPVisionModel
+from PIL import Image
+import io
+
+IMAGE_SIZE = 224
+MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
+STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
+device = 'cuda'
+
+modelPath = 'TheLitttleThings/clip-archdaily-text'
+tokenizer = AutoTokenizer.from_pretrained(modelPath)
+text_encoder = BertModel.from_pretrained(modelPath).eval()
+vision_encoder = CLIPVisionModel.from_pretrained(
+    'TheLitttleThings/clip-archdaily-vision').eval()
+
+image_embeddings = torch.load('image_embeddings.pt')
+text_embeddings = torch.load('text_embeddings.pt')
+links = np.load('links_list.npy', allow_pickle=True)
+categories = np.load('categories_list.npy', allow_pickle=True)
+
+if 'tab' not in st.session_state:
+    st.session_state['tab'] = 0
+
+
+@st.experimental_memo
+def image_search(query, top_k=24):
+    with torch.no_grad():
+        text_embedding = text_encoder(
+            **tokenizer(query, return_tensors='pt')).pooler_output
+    _, indices = torch.cosine_similarity(
+        image_embeddings, text_embedding).sort(descending=True)
+
+    return [links[i] for i in indices[:top_k]]
+
+
+def text_query_embedding(query: str = 'architecture'):
+    tokens = tokenizer(query, return_tensors='pt')
+    with torch.no_grad():
+        text_embedding = text_encoder(**tokens).pooler_output
+    return text_embedding
+
+
+preprocessImage = transforms.Compose([
+    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=MEAN, std=STD)
+])
+
+
+def image_query_embedding(image):
+    image = preprocessImage(image).unsqueeze(0)
+    with torch.no_grad():
+        image_embedding = vision_encoder(image).pooler_output
+    return image_embedding
+
+
+def most_similars(embeddings_1, embeddings_2):
+    values, indices = torch.cosine_similarity(
+        embeddings_1, embeddings_2).sort(descending=True)
+    return values.cpu(), indices.cpu()
+
+
+def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
+    """Analogies with embedding-space arithmetic.
+    Args:
+        input_image_path (str): Path to the original image.
+        additional_text (str): Text whose embedding can be combined with the image embedding.
+    """
+    base_image = Image.open(input_image_path)
+    image_embedding = image_query_embedding(base_image)
+    additional_embedding = text_query_embedding(query=additional_text)
+    new_image_embedding = image_embedding  # + additional_embedding
+    _, indices = most_similars(image_embeddings, new_image_embedding)
+
+    return [links[i] for i in indices[:top_k]]
+
+
+def image_comparison(base_image, top_k=24):
+    image_embedding = image_query_embedding(base_image)
+    # additional_embedding = text_query_embedding(query=additional_text)
+    new_image_embedding = image_embedding  # + additional_embedding
+    _, indices = most_similars(image_embeddings, new_image_embedding)
+
+    return [links[i] for i in indices[:top_k]]
+
+
+def get_html(url_list, classOther=""):
+    html = f"<div class='wrapper {classOther}'>"
+    for url in url_list:
+        project = url["project_url"]
+        image = url["source_url"]
+        title = url["title"]
+        year = url["year"]
+        html2 = f"<a href='{project}' target='_blank' class='link'><div class='imageparent'><img src='{escape(image)}'/></div><div>{year}/{title}</div></a>"
+        html = html + html2
+    html += "</div>"
+    return html
+
+
+def load_image(image_file):
+    img = Image.open(image_file)
+    return img
+
+
+description = '''
+# Architecture-Clip
+- Enter your query and hit enter
+- Note: quick demo of a CLIP model trained on architectural images
+Built with 5k images from [ArchDaily](https://www.archdaily.com/)
+Based on code from
+[CLIP-fa](https://github.com/sajjjadayobi/CLIPfa)
+[Clip-Italian](https://github.com/clip-italian/clip-italian)
+'''


def main():
+    st.markdown('''
+    <style>
+    .block-container{
+        max-width: 1200px;
+    }
+    .wrapper {
+        display: grid;
+        grid-template-columns: repeat(3, 1fr);
+        gap: 10px;
+        grid-auto-rows: minmax(100px, auto);
+        margin-top: 50px;
+        max-width: 1100px;
+        justify-content: space-evenly;
+    }
+    .wrapper.small{
+        grid-template-columns: repeat(2, 1fr);
+    }
+    .imageparent{
+        overflow:hidden;
+        width:100%;
+        height:100%;
+        aspect-ratio : 1 / 1;
+        margin-bottom: 2px;

+    }
+    a.link{
+        display:block;
+    }
+    .wrapper a img{
+        width:100%;
+        display:block;
+        aspect-ratio : 1 / 1;
+    }
+    section.main>div:first-child {
+        padding-top: 0px;
+    }
+    section:not(.main)>div:first-child {
+        padding-top: 30px;
+    }
+    div.reportview-container > section:first-child{
+        max-width: 320px;
+    }
+    #MainMenu {
+        visibility: hidden;
+    }
+    footer {
+        visibility: hidden;
+    }
+    </style>''',
+
+                unsafe_allow_html=True)
+
+    st.sidebar.markdown(description)
+    _, col1, col2, col3, _ = st.columns((1, 2, 2, 2, 1))
    mainContain = st.container()
+
    if col1.button("Search by text"):
+        st.session_state['tab'] = 1
    if col2.button("Find Similar"):
+        st.session_state['tab'] = 2
    if col3.button("Classify"):
+        st.session_state['tab'] = 3
+
+    # def textSearch(mainContain):
+    if st.session_state['tab'] == 1:
+        _, c, _ = mainContain.columns((1, 6, 1))
+        c.header("Text Search")
+        query = c.text_input('Search Box', value='Architecture')
+        if len(query) > 0:
+            c.text("It'll take about 30s to load all new images")
+            results = image_search(query)
+            mainContain.markdown(get_html(results, "big"),
+                                 unsafe_allow_html=True)
+
+    if st.session_state['tab'] == 2:
+        _, d, _ = mainContain.columns((1, 6, 1))
+        d.header("Find Related")
+        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
+        if image_file is not None:
+            _, left, right, _ = mainContain.columns((1, 2, 4, 1))
+            img = load_image(image_file)
+            left.image(img, width=300)
+            left.text("It'll take about 30s to load all new images")
+            results = image_comparison(img)
+            right.markdown(get_html(results, "small"), unsafe_allow_html=True)
+
+    if st.session_state['tab'] == 3:
+        _, d, _ = mainContain.columns((1, 6, 1))
+        d.header("Classify Elements")
+        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
+        if image_file is not None:
+            img = load_image(image_file)
+            _, left, right, _ = mainContain.columns((1, 4, 2, 1))
+            left.image(img, width=300)
+            image_embedding = image_query_embedding(img)
+            values, indices = most_similars(image_embedding, text_embeddings)
+            for i, sim in zip(indices, torch.softmax(values, dim=0)):
+                right.text(f'label: {categories[i]} | {round(float(sim), 3)}')
+
+
if __name__ == '__main__':
+    main()
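
Note: app.py loads four precomputed artifacts at startup (image_embeddings.pt, text_embeddings.pt, links_list.npy, categories_list.npy) that this commit does not touch. The sketch below shows one way such files could be generated offline with the same encoders; the projects list, image paths, and category names are illustrative assumptions, not files from this repository.

# Illustrative offline sketch (not part of this commit): build the embedding
# and metadata files that app.py loads at startup.
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image
from transformers import AutoTokenizer, BertModel, CLIPVisionModel

IMAGE_SIZE = 224
MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])

preprocess = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=MEAN, std=STD),
])

tokenizer = AutoTokenizer.from_pretrained('TheLitttleThings/clip-archdaily-text')
text_encoder = BertModel.from_pretrained('TheLitttleThings/clip-archdaily-text').eval()
vision_encoder = CLIPVisionModel.from_pretrained('TheLitttleThings/clip-archdaily-vision').eval()

# Hypothetical scraped metadata: one dict per project, matching the keys
# that get_html() reads (project_url, source_url, title, year).
projects = [
    {"project_url": "https://www.archdaily.com/...", "source_url": "images/0001.jpg",
     "title": "Example House", "year": 2020},
]
# Hypothetical label set for the "Classify" tab.
categories = ["facade", "interior", "floor plan"]

with torch.no_grad():
    image_embeddings = torch.cat([
        vision_encoder(preprocess(Image.open(p["source_url"]).convert("RGB"))
                       .unsqueeze(0)).pooler_output
        for p in projects
    ])
    text_embeddings = torch.cat([
        text_encoder(**tokenizer(c, return_tensors='pt')).pooler_output
        for c in categories
    ])

torch.save(image_embeddings, 'image_embeddings.pt')
torch.save(text_embeddings, 'text_embeddings.pt')
np.save('links_list.npy', np.array(projects, dtype=object))
np.save('categories_list.npy', np.array(categories, dtype=object))

With files of this shape, image_search, image_comparison, and the classification loop in main() index into links and categories exactly as written in the committed code.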