ludusc commited on
Commit
ae2da92
1 Parent(s): 4098d00

working version

Browse files
data/stored_vectors/scores_colors_hsv.csv CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:93f5789a80465ca7b21713819bc444d72239fa1b7ae56adf69e3323e0f3bedd1
3
- size 974247
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2d237e5efaac5afd98d777681cbfbf77bf7c41b8e4f221557fc588ab17e5e42b
3
+ size 974255
pages/1_Textiles_Disentanglement.py CHANGED
@@ -34,6 +34,13 @@ with open(annotations_file, 'rb') as f:
34
  concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
35
  concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
36
  concept_vectors['score'] = concept_vectors['score'].astype(float)
 
 
 
 
 
 
 
37
  concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
38
 
39
  with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
@@ -53,14 +60,25 @@ if 'saturation_lambda' not in st.session_state:
53
  st.session_state.saturation_lambda = 0
54
  if 'value_lambda' not in st.session_state:
55
  st.session_state.value_lambda = 0
56
-
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  # def on_change_random_input():
59
  # st.session_state.image_id = st.session_state.image_id
60
 
61
  # ----------------------------- INPUT ----------------------------------
62
  st.header('Input')
63
- input_col_1, input_col_2, input_col_3 = st.columns(3)
64
  # --------------------------- INPUT column 1 ---------------------------
65
  with input_col_1:
66
  with st.form('image_form'):
@@ -80,44 +98,69 @@ with input_col_1:
80
 
81
  if choose_image_button:
82
  image_id = int(image_id)
83
- st.session_state.image_id = int(image_id)
84
 
85
  with input_col_2:
86
  with st.form('text_form_1'):
87
 
88
  st.write('**Choose color to vary**')
89
- type_col = st.selectbox('Color:', tuple(COLORS_LIST))
90
  colors_button = st.form_submit_button('Choose the defined color')
91
 
92
  st.write('**Set range of change**')
93
  chosen_color_lambda_input = st.empty()
94
- color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=0, step=1, value=7)
95
- color_lambda_button = st.form_submit_button('Choose the defined lambda')
96
 
97
  if colors_button or color_lambda_button:
 
98
  st.session_state.concept_ids = type_col
99
  st.session_state.color_lambda = color_lambda
100
 
101
-
102
-
103
 
104
  with input_col_3:
105
  with st.form('text_form'):
106
 
107
  st.write('**Saturation variation**')
108
  chosen_saturation_lambda_input = st.empty()
109
- saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=0, step=1, key=0)
110
- saturation_lambda_button = st.form_submit_button('Choose the defined lambda for Saturation')
111
 
112
  st.write('**Value variation**')
113
  chosen_value_lambda_input = st.empty()
114
- value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=0, step=1, key=1)
115
- value_lambda_button = st.form_submit_button('Choose the defined lambda for Value')
116
 
117
  if saturation_lambda_button or value_lambda_button:
118
  st.session_state.saturation_lambda = int(saturation_lambda)
119
  st.session_state.value_lambda = int(value_lambda)
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # with input_col_4:
122
  # with st.form('Network specifics:'):
123
  # st.write('**Choose a latent space to use**')
@@ -153,9 +196,19 @@ with header_col_1:
153
  st.write(f'Original image')
154
 
155
  with header_col_2:
156
- color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
157
- saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
158
- value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
 
 
 
 
 
 
 
 
 
 
159
  st.write(f'Change in {st.session_state.concept_ids} of {np.round(st.session_state.color_lambda, 2)}, in saturation of {np.round(st.session_state.saturation_lambda, 2)}, in value of {np.round(st.session_state.value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation/100}, value vector: {performance_value/100}')
160
 
161
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
 
34
  concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
35
  concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
36
  concept_vectors['score'] = concept_vectors['score'].astype(float)
37
+
38
+ concept_vectors['sign'] = [True if 'sign:True' in val else False for val in concept_vectors['kwargs']]
39
+ concept_vectors['extremes'] = [True if 'extremes method:True' in val else False for val in concept_vectors['kwargs']]
40
+ concept_vectors['regularization'] = [float(val.split(',')[1].strip('regularization: ')) if 'regularization:' in val else False for val in concept_vectors['kwargs']]
41
+ concept_vectors['cl_method'] = [val.split(',')[0].strip('classification method:') if 'classification method:' in val else False for val in concept_vectors['kwargs']]
42
+ concept_vectors['num_factors'] = [int(val.split(',')[1].strip('number of factors:')) if 'number of factors:' in val else False for val in concept_vectors['kwargs']]
43
+
44
  concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
45
 
46
  with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
 
60
  st.session_state.saturation_lambda = 0
61
  if 'value_lambda' not in st.session_state:
62
  st.session_state.value_lambda = 0
63
+ if 'sign' not in st.session_state:
64
+ st.session_state.sign = False
65
+ if 'extremes' not in st.session_state:
66
+ st.session_state.extremes = False
67
+ if 'regularization' not in st.session_state:
68
+ st.session_state.regularization = False
69
+ if 'cl_method' not in st.session_state:
70
+ st.session_state.cl_method = False
71
+ if 'num_factors' not in st.session_state:
72
+ st.session_state.num_factors = False
73
+ if 'best' not in st.session_state:
74
+ st.session_state.best = True
75
 
76
  # def on_change_random_input():
77
  # st.session_state.image_id = st.session_state.image_id
78
 
79
  # ----------------------------- INPUT ----------------------------------
80
  st.header('Input')
81
+ input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
82
  # --------------------------- INPUT column 1 ---------------------------
83
  with input_col_1:
84
  with st.form('image_form'):
 
98
 
99
  if choose_image_button:
100
  image_id = int(image_id)
101
+ st.session_state.image_id = image_id
102
 
103
  with input_col_2:
104
  with st.form('text_form_1'):
105
 
106
  st.write('**Choose color to vary**')
107
+ type_col = st.selectbox('Color:', tuple(COLORS_LIST), index=7)
108
  colors_button = st.form_submit_button('Choose the defined color')
109
 
110
  st.write('**Set range of change**')
111
  chosen_color_lambda_input = st.empty()
112
+ color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=7)
113
+ color_lambda_button = st.form_submit_button('Choose the defined lambda for color')
114
 
115
  if colors_button or color_lambda_button:
116
+ st.session_state.image_id = image_id
117
  st.session_state.concept_ids = type_col
118
  st.session_state.color_lambda = color_lambda
119
 
 
 
120
 
121
  with input_col_3:
122
  with st.form('text_form'):
123
 
124
  st.write('**Saturation variation**')
125
  chosen_saturation_lambda_input = st.empty()
126
+ saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=0, value=0)
127
+ saturation_lambda_button = st.form_submit_button('Choose the defined lambda for saturation')
128
 
129
  st.write('**Value variation**')
130
  chosen_value_lambda_input = st.empty()
131
+ value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=1, value=0)
132
+ value_lambda_button = st.form_submit_button('Choose the defined lambda for salue')
133
 
134
  if saturation_lambda_button or value_lambda_button:
135
  st.session_state.saturation_lambda = int(saturation_lambda)
136
  st.session_state.value_lambda = int(value_lambda)
137
 
138
+ with input_col_4:
139
+ with st.form('text_form_2'):
140
+ st.write('Use best options')
141
+ best = st.selectbox('Option:', tuple([True, False]), index=0)
142
+ st.write('Options for StyleSpace (not available for Saturation and Value)')
143
+ sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
144
+ num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
145
+ st.write('Options for InterFaceGAN (not available for Saturation and Value)')
146
+ cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
147
+ regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
148
+ st.write('Options for InterFaceGAN (only for Saturation and Value)')
149
+ extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
150
+
151
+ choose_options_button = st.form_submit_button('Choose the defined options')
152
+ # st.write('**Choose a latent space to disentangle**')
153
+ # # chosen_text_id_input = st.empty()
154
+ # # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
155
+ # space_id = st.selectbox('Space:', tuple(['Z', 'W']))
156
+ if choose_options_button:
157
+ st.session_state.sign = sign
158
+ st.session_state.num_factors = num_factors
159
+ st.session_state.cl_method = cl_method
160
+ st.session_state.regularization = regularization
161
+ st.session_state.extremes = extremes
162
+ st.session_state.best = best
163
+
164
  # with input_col_4:
165
  # with st.form('Network specifics:'):
166
  # st.write('**Choose a latent space to use**')
 
196
  st.write(f'Original image')
197
 
198
  with header_col_2:
199
+ if st.session_state.best:
200
+ color_separation_vector, performance_color = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids].reset_index().loc[0, ['vector', 'score']]
201
+ saturation_separation_vector, performance_saturation = concept_vectors[concept_vectors['color'] == 'Saturation'].reset_index().loc[0, ['vector', 'score']]
202
+ value_separation_vector, performance_value = concept_vectors[concept_vectors['color'] == 'Value'].reset_index().loc[0, ['vector', 'score']]
203
+ else:
204
+ tmp = concept_vectors[concept_vectors['color'] == st.session_state.concept_ids]
205
+ tmp = tmp[tmp['sign'] == st.session_state.sign][tmp['num_factors'] == st.session_state.num_factors][tmp['cl_method'] == st.session_state.cl_method][tmp['regularization'] == st.session_state.regularization]
206
+ color_separation_vector, performance_color = tmp.reset_index().loc[0, ['vector', 'score']]
207
+ tmp_value = concept_vectors[concept_vectors['color'] == 'Value'][concept_vectors['extremes'] == st.session_state.extremes]
208
+ value_separation_vector, performance_value = tmp_value.reset_index().loc[0, ['vector', 'score']]
209
+ tmp_sat = concept_vectors[concept_vectors['color'] == 'Saturation'][concept_vectors['extremes'] == st.session_state.extremes]
210
+ saturation_separation_vector, performance_saturation = tmp_sat.reset_index().loc[0, ['vector', 'score']]
211
+
212
  st.write(f'Change in {st.session_state.concept_ids} of {np.round(st.session_state.color_lambda, 2)}, in saturation of {np.round(st.session_state.saturation_lambda, 2)}, in value of {np.round(st.session_state.value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation/100}, value vector: {performance_value/100}')
213
 
214
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
pages/2_Colours_comparison.py CHANGED
@@ -22,91 +22,124 @@ HIGHTLIGHT_COLOR = '#e7bcc5'
22
  st.set_page_config(layout='wide')
23
 
24
 
25
- st.title('Comparison among concept vectors')
26
- st.write('> **How do the concept vectors relate to each other?**')
27
- st.write('> **What is their join impact on the image?**')
28
- st.write("""Description to write""")
29
 
30
-
31
- annotations_file = './data/annotated_files/seeds0000-50000.pkl'
32
  with open(annotations_file, 'rb') as f:
33
  annotations = pickle.load(f)
34
 
35
- ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
36
- concepts = './data/concepts.txt'
 
 
 
 
 
 
 
 
 
 
37
 
38
- with open(concepts) as f:
39
- labels = [line.strip() for line in f.readlines()]
40
 
41
  if 'image_id' not in st.session_state:
42
  st.session_state.image_id = 0
43
  if 'concept_ids' not in st.session_state:
44
- st.session_state.concept_ids = ['Abstract', 'Representational']
 
 
 
 
 
 
 
 
 
 
 
 
45
  if 'space_id' not in st.session_state:
46
- st.session_state.space_id = 'Z'
47
- # def on_change_random_input():
48
- # st.session_state.image_id = st.session_state.image_id
49
 
50
  # ----------------------------- INPUT ----------------------------------
51
  st.header('Input')
52
- input_col_1, input_col_2, input_col_3 = st.columns(3)
53
  # --------------------------- INPUT column 1 ---------------------------
54
  with input_col_1:
55
  with st.form('text_form'):
56
 
57
  # image_id = st.number_input('Image ID: ', format='%d', step=1)
58
- st.write('**Choose a series of concepts to compare**')
59
  # chosen_text_id_input = st.empty()
60
  # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
61
- concept_ids = st.multiselect('Concept:', tuple(labels))
62
-
63
- st.write('**Choose a latent space to disentangle**')
64
- # chosen_text_id_input = st.empty()
65
- # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
66
- space_id = st.selectbox('Space:', tuple(['Z', 'W']))
67
-
68
- choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
69
 
70
  if choose_text_button:
71
  st.session_state.concept_ids = list(concept_ids)
72
- space_id = str(space_id)
73
- st.session_state.space_id = space_id
74
- # st.write(image_id, st.session_state.image_id)
75
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  # ---------------------------- SET UP OUTPUT ------------------------------
77
  epsilon_container = st.empty()
78
- st.header('Output')
79
- st.subheader('Concept vector')
80
 
81
- # perform attack container
82
- # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
83
- # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
84
- header_col_1, header_col_2 = st.columns([1,1])
85
- output_col_1, output_col_2 = st.columns([1,1])
86
 
87
- st.subheader('Derivations along the concept vector')
88
-
89
- # prediction error container
90
- error_container = st.empty()
91
- smoothgrad_header_container = st.empty()
 
 
 
92
 
93
- # smoothgrad container
94
- smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
95
- smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
 
 
 
96
 
97
- # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
98
  with output_col_1:
99
- vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df, latent_space=space_id)
100
- header_col_1.write(f'Concepts {", ".join(concept_ids)} - Latent space {space_id} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
101
-
102
  edges = []
103
  for i in range(len(concept_ids)):
104
  for j in range(len(concept_ids)):
105
- if i != j:
106
  print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
107
- similarity = cosine_similarity(vectors[i,:].reshape(1, -1), vectors[j,:].reshape(1, -1))
108
  print(np.round(similarity[0][0], 3))
109
- edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))
110
 
111
 
112
  net = Network(height="750px", width="100%",)
@@ -142,123 +175,3 @@ with output_col_1:
142
 
143
  # Load HTML file in HTML component for display on Streamlit page
144
  components.html(HtmlFile.read(), height=435)
145
-
146
- with output_col_2:
147
- with open('data/CLIP_vecs.pkl', 'rb') as f:
148
- vectors_CLIP = pickle.load(f)
149
-
150
- # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
151
- #st.write('Concept vector', separation_vector)
152
- header_col_2.write(f'Concepts {", ".join(concept_ids)} - Latent space CLIP')# - Nodes {",".join(list(imp_nodes))}')
153
-
154
- edges_clip = []
155
- for c1 in concept_ids:
156
- for c2 in concept_ids:
157
- if c1 != c2:
158
- print(f'Similarity between {c1} and {c2}')
159
- similarity = cosine_similarity(vectors_CLIP[c1].reshape(1, -1), vectors_CLIP[c2].reshape(1, -1))
160
- print(np.round(similarity[0][0], 3))
161
- edges_clip.append((c1, c2, np.round(float(np.round(similarity[0][0], 3)), 3)))
162
-
163
-
164
- net_clip = Network(height="750px", width="100%",)
165
- for e in edges_clip:
166
- src = e[0]
167
- dst = e[1]
168
- w = e[2]
169
-
170
- net_clip.add_node(src, src, title=src)
171
- net_clip.add_node(dst, dst, title=dst)
172
- net_clip.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' +str(w))
173
-
174
- # Generate network with specific layout settings
175
- net_clip.repulsion(
176
- node_distance=420,
177
- central_gravity=0.33,
178
- spring_length=110,
179
- spring_strength=0.10,
180
- damping=0.95
181
- )
182
-
183
- # Save and read graph as HTML file (on Streamlit Sharing)
184
- try:
185
- path = '/tmp'
186
- net_clip.save_graph(f'{path}/pyvis_graph_clip.html')
187
- HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
188
-
189
- # Save and read graph as HTML file (locally)
190
- except:
191
- path = '/html_files'
192
- net_clip.save_graph(f'{path}/pyvis_graph_clip.html')
193
- HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
194
-
195
- # Load HTML file in HTML component for display on Streamlit page
196
- components.html(HtmlFile.read(), height=435)
197
-
198
- # ----------------------------- INPUT column 2 & 3 ----------------------------
199
- with input_col_2:
200
- with st.form('image_form'):
201
-
202
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
203
- st.write('**Choose or generate a random image to test the disentanglement**')
204
- chosen_image_id_input = st.empty()
205
- image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
206
-
207
- choose_image_button = st.form_submit_button('Choose the defined image')
208
- random_id = st.form_submit_button('Generate a random image')
209
-
210
- if random_id:
211
- image_id = random.randint(0, 50000)
212
- st.session_state.image_id = image_id
213
- chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
214
-
215
- if choose_image_button:
216
- image_id = int(image_id)
217
- st.session_state.image_id = int(image_id)
218
- # st.write(image_id, st.session_state.image_id)
219
-
220
- with input_col_3:
221
- with st.form('Variate along the disentangled concepts'):
222
- st.write('**Set range of change**')
223
- chosen_epsilon_input = st.empty()
224
- epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=1, step=1)
225
- epsilon_button = st.form_submit_button('Choose the defined epsilon')
226
-
227
- # # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
228
-
229
-
230
- with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
231
- model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
232
-
233
- if st.session_state.space_id == 'Z':
234
- original_image_vec = annotations['z_vectors'][st.session_state.image_id]
235
- else:
236
- original_image_vec = annotations['w_vectors'][st.session_state.image_id]
237
-
238
- img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
239
- # input_image = original_image_dict['image']
240
- # input_label = original_image_dict['label']
241
- # input_id = original_image_dict['id']
242
-
243
- with smoothgrad_col_3:
244
- st.image(img)
245
- smooth_head_3.write(f'Base image')
246
-
247
-
248
- images, lambdas = generate_joint_effect(model, original_image_vec, vectors, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id)
249
-
250
- with smoothgrad_col_1:
251
- st.image(images[0])
252
- smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
253
-
254
- with smoothgrad_col_2:
255
- st.image(images[1])
256
- smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
257
-
258
- with smoothgrad_col_4:
259
- st.image(images[3])
260
- smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
261
-
262
- with smoothgrad_col_5:
263
- st.image(images[4])
264
- smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
 
22
  st.set_page_config(layout='wide')
23
 
24
 
25
+ st.title('Comparison among color directions')
26
+ st.write('> **How do the color directions relate to each other?**')
27
+ st.write('> **What is their joint impact on the image?**')
 
28
 
29
+
30
+ annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
31
  with open(annotations_file, 'rb') as f:
32
  annotations = pickle.load(f)
33
 
34
+ concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
35
+ concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
36
+ concept_vectors['score'] = concept_vectors['score'].astype(float)
37
+ concept_vectors['sign'] = [True if 'sign:True' in val else False for val in concept_vectors['kwargs']]
38
+ concept_vectors['extremes'] = [True if 'extremes method:True' in val else False for val in concept_vectors['kwargs']]
39
+ concept_vectors['regularization'] = [float(val.split(',')[1].strip('regularization: ')) if 'regularization:' in val else False for val in concept_vectors['kwargs']]
40
+ concept_vectors['cl_method'] = [val.split(',')[0].strip('classification method:') if 'classification method:' in val else False for val in concept_vectors['kwargs']]
41
+ concept_vectors['num_factors'] = [int(val.split(',')[1].strip('number of factors:')) if 'number of factors:' in val else False for val in concept_vectors['kwargs']]
42
+ concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
43
+
44
+ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
45
+ model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
46
 
47
+ COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
 
48
 
49
  if 'image_id' not in st.session_state:
50
  st.session_state.image_id = 0
51
  if 'concept_ids' not in st.session_state:
52
+ st.session_state.concept_ids = [COLORS_LIST[-1], COLORS_LIST[-2], ]
53
+ if 'sign' not in st.session_state:
54
+ st.session_state.sign = False
55
+ if 'extremes' not in st.session_state:
56
+ st.session_state.extremes = False
57
+ if 'regularization' not in st.session_state:
58
+ st.session_state.regularization = False
59
+ if 'cl_method' not in st.session_state:
60
+ st.session_state.cl_method = False
61
+ if 'num_factors' not in st.session_state:
62
+ st.session_state.num_factors = False
63
+
64
+
65
  if 'space_id' not in st.session_state:
66
+ st.session_state.space_id = 'W'
 
 
67
 
68
  # ----------------------------- INPUT ----------------------------------
69
  st.header('Input')
70
+ input_col_1, input_col_2 = st.columns([1,1])
71
  # --------------------------- INPUT column 1 ---------------------------
72
  with input_col_1:
73
  with st.form('text_form'):
74
 
75
  # image_id = st.number_input('Image ID: ', format='%d', step=1)
76
+ st.write('**Choose a series of colors to compare**')
77
  # chosen_text_id_input = st.empty()
78
  # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
79
+ concept_ids = st.multiselect('Color (including Saturation and Value):', tuple(COLORS_LIST), default=[COLORS_LIST[-1], COLORS_LIST[-2], ])
80
+ choose_text_button = st.form_submit_button('Choose the defined colors')
 
 
 
 
 
 
81
 
82
  if choose_text_button:
83
  st.session_state.concept_ids = list(concept_ids)
84
+
85
+
86
+ with input_col_2:
87
+ with st.form('text_form_1'):
88
+ st.write('Options for StyleSpace (not available for Saturation and Value)')
89
+ sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
90
+ num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
91
+ st.write('Options for InterFaceGAN (not available for Saturation and Value)')
92
+ cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
93
+ regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
94
+ st.write('Options for InterFaceGAN (only for Saturation and Value)')
95
+ extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
96
+
97
+ choose_options_button = st.form_submit_button('Choose the defined options')
98
+ # st.write('**Choose a latent space to disentangle**')
99
+ # # chosen_text_id_input = st.empty()
100
+ # # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
101
+ # space_id = st.selectbox('Space:', tuple(['Z', 'W']))
102
+ if choose_options_button:
103
+ st.session_state.sign = sign
104
+ st.session_state.num_factors = num_factors
105
+ st.session_state.cl_method = cl_method
106
+ st.session_state.regularization = regularization
107
+ st.session_state.extremes = extremes
108
+
109
  # ---------------------------- SET UP OUTPUT ------------------------------
110
  epsilon_container = st.empty()
111
+ st.header('Comparison')
112
+ st.subheader('Color vectors')
113
 
114
+ header_col_1, header_col_2 = st.columns([3,1])
115
+ output_col_1, output_col_2 = st.columns([3,1])
 
 
 
116
 
117
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
118
+ tmp = concept_vectors[concept_vectors['color'].isin(st.session_state.concept_ids)]
119
+ tmp = tmp[tmp['sign'] == st.session_state.sign][tmp['extremes'] == st.session_state.extremes][tmp['num_factors'] == st.session_state.num_factors][tmp['cl_method'] == st.session_state.cl_method][tmp['regularization'] == st.session_state.regularization]
120
+ info = tmp.loc[:, ['vector', 'score', 'color', 'kwargs']].values
121
+ concept_ids = [i[2] for i in info] #+ ' ' + i[3]
122
+
123
+ with header_col_1:
124
+ st.write('Similarity graph')
125
 
126
+ with header_col_2:
127
+ st.write('Information')
128
+
129
+ with output_col_2:
130
+ for i,concept_id in enumerate(concept_ids):
131
+ st.write(f'Color {info[i][2]} - Settings: {info[i][3]} Performance of the color vector: {info[i][1]}')# - Nodes {",".join(list(imp_nodes))}')
132
 
 
133
  with output_col_1:
134
+
 
 
135
  edges = []
136
  for i in range(len(concept_ids)):
137
  for j in range(len(concept_ids)):
138
+ if i != j and info[i][2] != info[j][2]:
139
  print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
140
+ similarity = cosine_similarity(info[i][0].reshape(1, -1), info[j][0].reshape(1, -1))
141
  print(np.round(similarity[0][0], 3))
142
+ edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0] + 0.001, 3)))
143
 
144
 
145
  net = Network(height="750px", width="100%",)
 
175
 
176
  # Load HTML file in HTML component for display on Streamlit page
177
  components.html(HtmlFile.read(), height=435)