TheLitttleThings committed on
Commit d35396a
1 Parent(s): 7ce51b8

Update app.py

Files changed (1)
  1. app.py +214 -9
app.py CHANGED
@@ -1,19 +1,224 @@
 import streamlit as st
+import numpy as np
+from html import escape
+import torch
+import torchvision.transforms as transforms
+from transformers import BertModel, AutoTokenizer, CLIPVisionModel
+from PIL import Image
+import io
+
+IMAGE_SIZE = 224
+MEAN = torch.tensor([0.48145466, 0.4578275, 0.40821073])
+STD = torch.tensor([0.26862954, 0.26130258, 0.27577711])
+device = 'cuda'
+
+modelPath = 'TheLitttleThings/clip-archdaily-text'
+tokenizer = AutoTokenizer.from_pretrained(modelPath)
+text_encoder = BertModel.from_pretrained(modelPath).eval()
+vision_encoder = CLIPVisionModel.from_pretrained(
+    'TheLitttleThings/clip-archdaily-vision').eval()
+
+image_embeddings = torch.load('image_embeddings.pt')
+text_embeddings = torch.load('text_embeddings.pt')
+links = np.load('links_list.npy', allow_pickle=True)
+categories = np.load('categories_list.npy', allow_pickle=True)
+
+if 'tab' not in st.session_state:
+    st.session_state['tab'] = 0
+
+
+@st.experimental_memo
+def image_search(query, top_k=24):
+    with torch.no_grad():
+        text_embedding = text_encoder(
+            **tokenizer(query, return_tensors='pt')).pooler_output
+        _, indices = torch.cosine_similarity(
+            image_embeddings, text_embedding).sort(descending=True)
+
+    return [links[i] for i in indices[:top_k]]
+
+
+def text_query_embedding(query: str = 'architecture'):
+    tokens = tokenizer(query, return_tensors='pt')
+    with torch.no_grad():
+        text_embedding = text_encoder(
+            **tokenizer(query, return_tensors='pt')).pooler_output
+    return text_embedding
+
+
+preprocessImage = transforms.Compose([
+    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=MEAN, std=STD)
+])
+
+
+def image_query_embedding(image):
+    image = preprocessImage(image).unsqueeze(0)
+    with torch.no_grad():
+        image_embedding = vision_encoder(image).pooler_output
+    return image_embedding
+
+
+def most_similars(embeddings_1, embeddings_2):
+    values, indices = torch.cosine_similarity(
+        embeddings_1, embeddings_2).sort(descending=True)
+    return values.cpu(), indices.cpu()
+
+
+def analogy(input_image_path: str, top_k=24, additional_text: str = '', input_include=True):
+    """ Analogies with embedding space arithmetic.
+    Args:
+        input_image_path (str): The path to original image
+        image_paths (list[str]): A database of images
+    """
+    base_image = Image.open(input_image_path)
+    image_embedding = image_query_embedding(base_image)
+    additional_embedding = text_query_embedding(query=additional_text)
+    new_image_embedding = image_embedding  # + additional_embedding
+    _, indices = most_similars(image_embeddings, new_image_embedding)
+
+    return [links[i] for i in indices[:top_k]]
+
+
+def image_comparison(base_image, top_k=24):
+    image_embedding = image_query_embedding(base_image)
+    #additional_embedding = text_query_embedding(query=additional_text)
+    new_image_embedding = image_embedding  # + additional_embedding
+    _, indices = most_similars(image_embeddings, new_image_embedding)
+
+    return [links[i] for i in indices[:top_k]]
+
+
+def get_html(url_list, classOther=""):
+    html = f"<div class='wrapper {classOther}'>"
+    for url in url_list:
+        project = url["project_url"]
+        image = url["source_url"]
+        title = url["title"]
+        year = url["year"]
+        html2 = f"<a href='{project}' target='_blank' class='link'><div class='imageparent'><img src='{escape(image)}'/></div><div>{year}/{title}</div></a>"
+        html = html + html2
+    html += "</div>"
+    return html
+
+
+def load_image(image_file):
+    img = Image.open(image_file)
+    return img
+
+
+description = '''
+# Architecture-Clip
+- Enter your query and hit enter
+- Note: Quick demo of a Clip model trained on architectural images
+Built with 5k images from [ArchDaily](https://www.archdaily.com/)
+Based on code from
+[CLIP-fa](https://github.com/sajjjadayobi/CLIPfa)
+[Clip-Italian](https://github.com/clip-italian/clip-italian)
+'''
 
 
 def main():
+    st.markdown('''
+    <style>
+    .block-container{
+        max-width: 1200px;
+    }
+    .wrapper {
+        display: grid;
+        grid-template-columns: repeat(3, 1fr);
+        gap: 10px;
+        grid-auto-rows: minmax(100px, auto);
+        margin-top: 50px;
+        max-width: 1100px;
+        justify-content: space-evenly;
+    }
+    .wrapper.small{
+        grid-template-columns: repeat(2, 1fr);
+    }
+    .imageparent{
+        overflow:hidden;
+        width:100%;
+        height:100%;
+        aspect-ratio : 1 / 1;
+        margin-bottom: 2px;
 
-
-    #st.sidebar.markdown(description)
-    _, col1, col2, col3,_ = st.columns((1, 2,2,2, 1))
+    }
+    a.link{
+        display:block;
+    }
+    .wrapper a img{
+        width:100%;
+        display:block;
+        aspect-ratio : 1 / 1;
+    }
+    section.main>div:first-child {
+        padding-top: 0px;
+    }
+    section:not(.main)>div:first-child {
+        padding-top: 30px;
+    }
+    div.reportview-container > section:first-child{
+        max-width: 320px;
+    }
+    #MainMenu {
+        visibility: hidden;
+    }
+    footer {
+        visibility: hidden;
+    }
+    </style>''',
+
+        unsafe_allow_html=True)
+
+    st.sidebar.markdown(description)
+    _, col1, col2, col3, _ = st.columns((1, 2, 2, 2, 1))
     mainContain = st.container()
-
+
     if col1.button("Search by text"):
-        st.session_state['tab'] = 1
+        st.session_state['tab'] = 1
     if col2.button("Find Similar"):
-        st.session_state['tab'] = 2
+        st.session_state['tab'] = 2
     if col3.button("Classify"):
-        st.session_state['tab'] = 3
-    mainContain.text("mommy")
+        st.session_state['tab'] = 3
+
+    # def textSearch(mainContain):
+    if st.session_state['tab'] == 1:
+        _, c, _ = mainContain.columns((1, 6, 1))
+        c.header("Text Search")
+        query = c.text_input('Search Box', value='Architecture')
+        if len(query) > 0:
+            c.text("It'll take about 30s to load all new images")
+            results = image_search(query)
+            mainContain.markdown(get_html(results, "big"),
+                                 unsafe_allow_html=True)
+
+    if st.session_state['tab'] == 2:
+        _, d, _ = mainContain.columns((1, 6, 1))
+        d.header("Find Related")
+        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
+        if image_file is not None:
+            _, left, right, _ = mainContain.columns((1, 2, 4, 1))
+            img = load_image(image_file)
+            left.image(img, width=300)
+            left.text("It'll take about 30s to load all new images")
+            results = image_comparison(img)
+            right.markdown(get_html(results, "small"), unsafe_allow_html=True)
+
+    if st.session_state['tab'] == 3:
+        _, d, _ = mainContain.columns((1, 6, 1))
+        d.header("Classify Elements")
+        image_file = d.file_uploader("Choose a file", type=['png', 'jpg'])
+        if image_file is not None:
+            img = load_image(image_file)
+            _, left, right, _ = mainContain.columns((1, 4, 2, 1))
+            left.image(img, width=300)
+            image_embedding = image_query_embedding(img)
+            values, indices = most_similars(image_embedding, text_embeddings)
+            for i, sim in zip(indices, torch.softmax(values, dim=0)):
+                right.text(f'label: {categories[i]} | {round(float(sim), 3)}')
+
+
 if __name__ == '__main__':
-    main()
+    main()
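
For reference, the updated app.py ranks precomputed image embeddings by cosine similarity against a query embedding (the "Search by text" and "Find Similar" tabs) and, in the "Classify" tab, applies a softmax over the similarities between one image embedding and the stored label embeddings. The sketch below reproduces just those two steps with random stand-in tensors; the sizes, seed, and variable names are illustrative assumptions, not values taken from the Space's data files.

```python
import torch

torch.manual_seed(0)

# Random stand-ins for the embeddings the app loads from image_embeddings.pt,
# text_embeddings.pt, and for a text_encoder(...).pooler_output query vector.
num_images, num_labels, dim = 8, 4, 512
image_embeddings = torch.randn(num_images, dim)
label_embeddings = torch.randn(num_labels, dim)
query_embedding = torch.randn(1, dim)

# "Search by text": cosine similarity of every image against the query, sorted
# in descending order, exactly as image_search()/most_similars() do.
similarities = torch.cosine_similarity(image_embeddings, query_embedding)
values, indices = similarities.sort(descending=True)
print("top-3 image indices:", indices[:3].tolist())

# "Classify": similarities between one image and every label, squashed with
# softmax before display, as in the tab-3 loop.
image_embedding = image_embeddings[:1]  # pretend this is the uploaded image
label_sims = torch.cosine_similarity(label_embeddings, image_embedding)
probs = torch.softmax(label_sims, dim=0)
for i, p in sorted(enumerate(probs.tolist()), key=lambda t: -t[1]):
    print(f"label {i} | {round(p, 3)}")
```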