ludusc commited on
Commit
058ceb7
1 Parent(s): d49c6b3

added fourth page

Browse files
Files changed (1) hide show
  1. pages/4_Vase_Qualities_Comparison.py +270 -0
pages/4_Vase_Qualities_Comparison.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import streamlit.components.v1 as components
3
+
4
+ import dnnlib
5
+ import legacy
6
+
7
+ import pickle
8
+ import pandas as pd
9
+ import numpy as np
10
+ from pyvis.network import Network
11
+
12
+ import random
13
+ from sklearn.metrics.pairwise import cosine_similarity
14
+
15
+ from matplotlib.backends.backend_agg import RendererAgg
16
+
17
+ from backend.disentangle_concepts import *
18
+
19
# Matplotlib's Agg renderer is not thread-safe and Streamlit may render
# concurrently, so a shared lock is kept for any matplotlib drawing.
_lock = RendererAgg.lock

# NOTE(review): constant name is misspelled ("HIGHTLIGHT") and appears unused
# in this page — kept as-is in case another module references it; confirm
# before renaming.
HIGHTLIGHT_COLOR = '#e7bcc5'
st.set_page_config(layout='wide')
23
+
24
+
25
# ----------------------------- PAGE HEADER -----------------------------
st.title('Comparison among concept vectors')
st.write('> **How do the concept vectors relate to each other?**')
# Fixed typo in the user-facing text: "join impact" -> "joint impact".
st.write('> **What is their joint impact on the image?**')
st.write("""Description to write""")


# Pre-computed annotations (latent vectors, etc.) for the generated seeds.
annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
with open(annotations_file, 'rb') as f:
    annotations = pickle.load(f)
34
+
35
# Seed st.session_state with defaults on the first visit to this page;
# existing values from earlier reruns are left untouched.
_SESSION_DEFAULTS = {
    'image_id': 0,
    'concept_ids': ['Provenance ADRIA'],
    'space_id': 'Z',
    'type_col': 'Provenance',
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
43
+
44
# ----------------------------- INPUT ----------------------------------
st.header('Input')
input_col_1, input_col_2, input_col_3 = st.columns(3)

# --------------------------- INPUT column 1 ---------------------------
with input_col_1:
    with st.form('text_form'):
        st.write('**Choose two options to disentangle**')
        type_col = st.selectbox(
            'Concept category:',
            ('Provenance', 'Shape Name', 'Fabric', 'Technique'),
        )

        # Per-category CSV with one similarity column per concept label.
        ann_df = pd.read_csv(f'./data/vase_annotated_files/sim_{type_col}_seeds0000-20000.csv')
        labels = list(ann_df.columns)
        # Drop bookkeeping columns so only concept labels remain.
        labels.remove('ID')
        labels.remove('Unnamed: 0')

        # NOTE(review): the default assumes every category has at least four
        # labels — confirm for all four CSVs.
        concept_ids = st.multiselect('Concepts:', tuple(labels), default=[labels[2], labels[3]])

        st.write('**Choose a latent space to disentangle**')
        space_id = st.selectbox('Space:', ('Z', 'W'))

        choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')

        if choose_text_button:
            # Persist the chosen configuration across Streamlit reruns.
            st.session_state.concept_ids = list(concept_ids)
            space_id = str(space_id)
            st.session_state.space_id = space_id
            st.session_state.type_col = type_col
78
+
79
+ # ---------------------------- SET UP OUTPUT ------------------------------
80
+ epsilon_container = st.empty()
81
+ st.header('Output')
82
+ st.subheader('Concept vector')
83
+
84
+ # perform attack container
85
+ # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
86
+ # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
87
+ header_col_1, header_col_2 = st.columns([1,1])
88
+ output_col_1, output_col_2 = st.columns([1,1])
89
+
90
+ st.subheader('Derivations along the concept vector')
91
+
92
+ # prediction error container
93
+ error_container = st.empty()
94
+ smoothgrad_header_container = st.empty()
95
+
96
+ # smoothgrad container
97
+ smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
98
+ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
99
+
100
# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
with output_col_1:
    # Concept separation vectors in the chosen GAN latent space.
    vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df, latent_space=space_id)
    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Latent space {space_id} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')

    # Pairwise cosine similarities between the concept vectors. Both
    # directions (i->j and j->i) are kept, matching the original behaviour.
    edges = []
    for i in range(len(concept_ids)):
        for j in range(len(concept_ids)):
            if i != j:
                print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
                similarity = cosine_similarity(vectors[i, :].reshape(1, -1), vectors[j, :].reshape(1, -1))
                print(np.round(similarity[0][0], 3))
                edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))

    net = Network(height="750px", width="100%",)
    for src, dst, w in edges:
        net.add_node(src, src, title=src)
        net.add_node(dst, dst, title=dst)
        net.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' + str(w))

    # Generate network with specific layout settings
    net.repulsion(
        node_distance=420,
        central_gravity=0.33,
        spring_length=110,
        spring_strength=0.10,
        damping=0.95
    )

    # Save the graph as HTML: /tmp on Streamlit Sharing, ./html_files locally.
    # Fixed: bare `except:` narrowed to OSError, and the file handle is now
    # closed via a context manager instead of being leaked.
    try:
        path = '/tmp'
        net.save_graph(f'{path}/pyvis_graph.html')
    except OSError:
        path = '/html_files'
        net.save_graph(f'{path}/pyvis_graph.html')
    with open(f'{path}/pyvis_graph.html', 'r', encoding='utf-8') as html_file:
        graph_html = html_file.read()

    # Load HTML file in HTML component for display on Streamlit page
    components.html(graph_html, height=435)
148
+
149
with output_col_2:
    # Pre-computed CLIP embedding vectors keyed by '<category> <label>'.
    with open('data/CLIP_vecs_vases.pkl', 'rb') as f:
        vectors_CLIP = pickle.load(f)

    header_col_2.write(f'Concepts {", ".join(concept_ids)} - Latent space CLIP')

    # Pairwise cosine similarities in CLIP space.
    # Fixed bug: the original rebound the loop variables in place
    # (c1 = type_col + ' ' + c1) inside the inner loop, so the category
    # prefix accumulated on c1 across inner iterations and later lookup keys
    # were wrong. Separate key variables leave the loop variables untouched.
    edges_clip = []
    for c1 in concept_ids:
        for c2 in concept_ids:
            if c1 != c2:
                key1 = st.session_state.type_col + ' ' + c1
                key2 = st.session_state.type_col + ' ' + c2

                print(f'Similarity between {key1} and {key2}')
                similarity = cosine_similarity(vectors_CLIP[key1].reshape(1, -1), vectors_CLIP[key2].reshape(1, -1))
                print(np.round(similarity[0][0], 3))
                # The original double-rounded (np.round(float(np.round(...)))):
                # a single round plus float() is equivalent.
                edges_clip.append((key1, key2, float(np.round(similarity[0][0], 3))))

    net_clip = Network(height="750px", width="100%",)
    for src, dst, w in edges_clip:
        net_clip.add_node(src, src, title=src)
        net_clip.add_node(dst, dst, title=dst)
        net_clip.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' + str(w))

    # Generate network with specific layout settings
    net_clip.repulsion(
        node_distance=420,
        central_gravity=0.33,
        spring_length=110,
        spring_strength=0.10,
        damping=0.95
    )

    # Save the graph as HTML: /tmp on Streamlit Sharing, ./html_files locally.
    # Fixed: bare `except:` narrowed to OSError; file handle closed via `with`.
    try:
        path = '/tmp'
        net_clip.save_graph(f'{path}/pyvis_graph_clip.html')
    except OSError:
        path = '/html_files'
        net_clip.save_graph(f'{path}/pyvis_graph_clip.html')
    with open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8') as html_file:
        clip_graph_html = html_file.read()

    # Load HTML file in HTML component for display on Streamlit page
    components.html(clip_graph_html, height=435)
203
+
204
# ----------------------------- INPUT column 2 & 3 ----------------------------
with input_col_2:
    with st.form('image_form'):
        st.write('**Choose or generate a random image to test the disentanglement**')
        chosen_image_id_input = st.empty()
        image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)

        choose_image_button = st.form_submit_button('Choose the defined image')
        random_id = st.form_submit_button('Generate a random image')

        if random_id:
            # Fixed: the annotation file only covers seeds 0000-20000, but the
            # original drew from randint(0, 50000) and could index past the
            # stored vectors. Bound the draw by the number of stored latents
            # (annotations['z_vectors'] is indexed by image_id further below).
            image_id = random.randint(0, len(annotations['z_vectors']) - 1)
            st.session_state.image_id = image_id
            # Re-render the widget so the freshly drawn ID is displayed.
            chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)

        if choose_image_button:
            image_id = int(image_id)
            st.session_state.image_id = int(image_id)
225
+
226
with input_col_3:
    # Epsilon bounds how far the image is moved along each concept vector.
    with st.form('Variate along the disentangled concepts'):
        st.write('**Set range of change**')
        chosen_epsilon_input = st.empty()
        epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=1, step=1)
        epsilon_button = st.form_submit_button('Choose the defined epsilon')
232
+
233
# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------

# Load the pretrained StyleGAN generator (CPU inference only).
with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-010000.pkl') as f:
    model = legacy.load_network_pkl(f)['G_ema'].to('cpu')  # type: ignore

# Fetch the stored latent for the selected image in the chosen space.
if st.session_state.space_id == 'Z':
    original_image_vec = annotations['z_vectors'][st.session_state.image_id]
else:
    original_image_vec = annotations['w_vectors'][st.session_state.image_id]

img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)

# Centre column shows the unmodified image.
with smoothgrad_col_3:
    st.image(img)
    smooth_head_3.write('Base image')

images, lambdas = generate_joint_effect(model, original_image_vec, vectors, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id)

# Surround the base image with the variations. Index 2 is skipped —
# presumably the midpoint (~zero shift) that the base image replaces;
# confirm against generate_joint_effect.
_variant_slots = [
    (smoothgrad_col_1, smooth_head_1, 0),
    (smoothgrad_col_2, smooth_head_2, 1),
    (smoothgrad_col_4, smooth_head_4, 3),
    (smoothgrad_col_5, smooth_head_5, 4),
]
for _col, _head, _idx in _variant_slots:
    with _col:
        st.image(images[_idx])
        _head.write(f'Change of {np.round(lambdas[_idx], 2)}')