ludusc committed
Commit e6dc87e
1 Parent(s): edcca83

tried to add network viz

backend/disentangle_concepts.py CHANGED
@@ -80,13 +80,14 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=
         vectors[i,:] = vec
         important_nodes.append(set(imp_nodes))
 
-    reducer = UMAP(n_neighbors=3, # default 15, The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation.
-                   n_components=3, # default 2, The dimension of the space to embed into.
-                   min_dist=0.1, # default 0.1, The effective minimum distance between embedded points.
-                   spread=2.0, # default 1.0, The effective scale of embedded points. In combination with ``min_dist`` this determines how clustered/clumped the embedded points are.
-                   random_state=0, # default: None, If int, random_state is the seed used by the random number generator;
-                   )
+    # reducer = UMAP(n_neighbors=3, # default 15, The size of local neighborhood (in terms of number of neighboring sample points) used for manifold approximation.
+    #                n_components=3, # default 2, The dimension of the space to embed into.
+    #                min_dist=0.1, # default 0.1, The effective minimum distance between embedded points.
+    #                spread=2.0, # default 1.0, The effective scale of embedded points. In combination with ``min_dist`` this determines how clustered/clumped the embedded points are.
+    #                random_state=0, # default: None, If int, random_state is the seed used by the random number generator;
+    #                )
 
-    projection = reducer.fit_transform(vectors)
+    # projection = reducer.fit_transform(vectors)
     nodes_in_common = set.intersection(*important_nodes)
-    return vectors, projection, nodes_in_common
+    return vectors, nodes_in_common
+
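With the UMAP block commented out, `get_concepts_vectors` now returns two values instead of three, so any caller still unpacking `(vectors, projection, nodes_in_common)` will raise a `ValueError`. If the low-dimensional projection is still wanted for plotting, it can be recomputed from the returned `vectors` on the caller's side; a minimal sketch reusing the hyperparameters of the commented-out block (umap-learn remains listed in requirements.txt):

```python
from umap import UMAP

# get_concepts_vectors now returns (vectors, nodes_in_common);
# vectors has shape (n_concepts, n_features).
vectors, nodes_in_common = get_concepts_vectors(concepts, annotations, df)

# Recompute the projection outside the backend, with the same
# hyperparameters as the block commented out above.
reducer = UMAP(n_neighbors=3, n_components=3, min_dist=0.1,
               spread=2.0, random_state=0)
projection = reducer.fit_transform(vectors)  # shape: (n_concepts, 3)
```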
pages/1_Disentanglement.py CHANGED
@@ -47,7 +47,7 @@ with open(concepts) as f:
 if 'image_id' not in st.session_state:
     st.session_state.image_id = 0
 if 'concept_id' not in st.session_state:
-    st.session_state.concept_id = 'abstract'
+    st.session_state.concept_id = 'Abstract'
 
 # def on_change_random_input():
 #     st.session_state.image_id = st.session_state.image_id
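The default changes from 'abstract' to 'Abstract', matching the capitalized labels read from concepts.txt; with the lowercase value, the stored default never equals any loaded label. A hypothetical guard (not part of this commit) that would make such mismatches harmless:

```python
# Hypothetical fallback: if the stored default is not among the labels
# loaded from concepts.txt, reset it to the first available label.
if st.session_state.concept_id not in labels:
    st.session_state.concept_id = labels[0]
```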
pages/2_Concepts_comparison.py CHANGED
@@ -1,14 +1,19 @@
 import streamlit as st
+import streamlit.components.v1 as components
+
+import pickle
+import pandas as pd
 import numpy as np
+from pyvis.network import Network
+import networkx as nx
+
+from sklearn.metrics.pairwise import cosine_similarity
 
-from plotly.subplots import make_subplots
-import plotly.graph_objects as go
+from matplotlib.backends.backend_agg import RendererAgg
 
-import graphviz
+from backend.disentangle_concepts import *
 
-#from backend.maximally_activating_patches import load_layer_infos, load_activation, get_receptive_field_coordinates
-from frontend import on_click_graph
-#from backend.utils import load_dataset_dict
+_lock = RendererAgg.lock
 
 HIGHTLIGHT_COLOR = '#e7bcc5'
 st.set_page_config(layout='wide')
@@ -19,145 +24,191 @@ st.write('> **How do the concept vectors relate to each other?**')
 st.write('> **What is their joint impact on the image?**')
 st.write("""Description to write""")
 
-# -------------------------- LOAD DATASET ---------------------------------
-dataset_dict = load_dataset_dict()
-
-# -------------------------- LOAD GRAPH -----------------------------------
-
-def load_dot_to_graph(filename):
-    dot = graphviz.Source.from_file(filename)
-    source_lines = str(dot).splitlines()
-    source_lines.pop(0)
-    source_lines.pop(-1)
-    graph = graphviz.Digraph()
-    graph.body += source_lines
-    return graph, dot
-
-
-# st.header('ConvNeXt')
-convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
-convnext_graph = load_dot_to_graph(convnext_dot_file)[0]
-
-convnext_graph.graph_attr['size'] = '4,40'
-
-# -------------------------- DISPLAY GRAPH -----------------------------------
-
-def chosen_node_text(clicked_node_title):
-    clicked_node_title = clicked_node_title.replace('stage ', 'stage_').replace('block ', 'block_')
-    stage_id = clicked_node_title.split()[0].split('_')[1] if 'stage' in clicked_node_title else None
-    block_id = clicked_node_title.split()[1].split('_')[1] if 'block' in clicked_node_title else None
-    layer_id = clicked_node_title.split()[-1]
-
-    if 'embeddings' in layer_id:
-        display_text = 'Patchify layer'
-        activation_key = 'embeddings.patch_embeddings'
-    elif 'downsampling' in layer_id:
-        display_text = f'Stage {stage_id} > Downsampling layer'
-        activation_key = f'encoder.stages[{stage_id}].downsampling_layer[1]'
-    else:
-        display_text = f'Stage {stage_id} > Block {block_id} > {layer_id} layer'
-        activation_key = f'encoder.stages[{int(stage_id)-1}].layers[{int(block_id)-1}].{layer_id}'
-    return display_text, activation_key
-
-
-props = {
-    'hightlight_color': HIGHTLIGHT_COLOR,
-    'initial_state': {
-        'group_1_header': 'Choose an option from group 1',
-        'group_2_header': 'Choose an option from group 2'
-    }
-}
-
-
-col1, col2 = st.columns((2,5))
-col1.markdown("#### Architecture")
-col1.write('')
-col1.write('Click on a layer below to generate top-k maximally activating image patches')
-col1.graphviz_chart(convnext_graph)
-
-with col2:
-    st.markdown("#### Output")
-    nodes = on_click_graph(key='toggle_buttons', **props)
-
-# -------------------------- DISPLAY OUTPUT -----------------------------------
-
-if nodes != None:
-    clicked_node_title = nodes["choice"]["node_title"]
-    clicked_node_id = nodes["choice"]["node_id"]
-    display_text, activation_key = chosen_node_text(clicked_node_title)
-    col2.write(f'**Chosen layer:** {display_text}')
-    # col2.write(f'**Activation key:** {activation_key}')
-
-    hightlight_syle = f'''
-    <style>
-    div[data-stale]:has(iframe) {{
-        height: 0;
-    }}
-    #{clicked_node_id}>polygon {{
-        fill: {HIGHTLIGHT_COLOR};
-        stroke: {HIGHTLIGHT_COLOR};
-    }}
-    </style>
-    '''
-    col2.markdown(hightlight_syle, unsafe_allow_html=True)
-
-    with col2:
-        layer_infos = None
-        with st.form('top_k_form'):
-            activation_path = './data/activation/convnext_activation.json'
-            activation = load_activation(activation_path)
-            num_channels = activation[activation_key].shape[1]
-
-            top_k = st.slider('Choose K for top-K maximally activating patches', 1, 20, value=10)
-            channel_start, channel_end = st.slider(
-                'Choose channel range of this layer (recommend to choose small range less than 30)',
-                1, num_channels, value=(1, 30))
-            summit_button = st.form_submit_button('Generate image patches')
-            if summit_button:
-
-                activation = activation[activation_key][:top_k,:,:]
-                layer_infos = load_layer_infos('./data/layer_infos/convnext_layer_infos.json')
-                # st.write(channel_start, channel_end)
-                # st.write(activation.shape, activation.shape[1])
-
-        if layer_infos != None:
-            num_cols, num_rows = top_k, channel_end - channel_start + 1
-            # num_rows = activation.shape[1]
-            top_k_coor_max_ = activation
-            st.markdown(f"#### Top-{top_k} maximally activating image patches of {num_rows} channels ({channel_start}-{channel_end})")
-
-            for row in range(channel_start, channel_end+1):
-                if row == channel_start:
-                    top_margin = 50
-                    fig = make_subplots(
-                        rows=1, cols=num_cols,
-                        subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
-                else:
-                    top_margin = 0
-                    fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)
-                for col in range(1, num_cols+1):
-                    k, c = col-1, row-1
-                    img_index = int(top_k_coor_max_[k, c, 3])
-                    activation_value = top_k_coor_max_[k, c, 0]
-                    img = dataset_dict[img_index//10_000][img_index%10_000]['image']
-                    class_label = dataset_dict[img_index//10_000][img_index%10_000]['label']
-                    class_id = dataset_dict[img_index//10_000][img_index%10_000]['id']
-
-                    idx_x, idx_y = top_k_coor_max_[k, c, 1], top_k_coor_max_[k, c, 2]
-                    x1, x2, y1, y2 = get_receptive_field_coordinates(layer_infos, activation_key, idx_x, idx_y)
-                    img = np.array(img)[y1:y2, x1:x2, :]
-
-                    hovertemplate = f"""Top-{col}<br>Activation value: {activation_value:.5f}<br>Class Label: {class_label}<br>Class id: {class_id}<br>Image id: {img_index}"""
-                    fig.add_trace(go.Image(z=img, hovertemplate=hovertemplate), row=1, col=col)
-                fig.update_xaxes(showticklabels=False, showgrid=False)
-                fig.update_yaxes(showticklabels=False, showgrid=False)
-                fig.update_layout(margin={'b':0, 't':top_margin, 'r':0, 'l':0})
-                fig.update_layout(showlegend=False, yaxis_title=row)
-                fig.update_layout(height=100, plot_bgcolor='rgba(0,0,0,0)', paper_bgcolor='rgba(0,0,0,0)')
-                fig.update_layout(hoverlabel=dict(bgcolor="#e9f2f7"))
-                st.plotly_chart(fig, use_container_width=True)
-
-
-else:
-    col2.markdown(f'Chosen layer: <code>None</code>', unsafe_allow_html=True)
-    col2.markdown("""<style>div[data-stale]:has(iframe) {height: 0};""", unsafe_allow_html=True)
+
+annotations_file = './data/annotated_files/seeds0000-100000.pkl'
+with open(annotations_file, 'rb') as f:
+    annotations = pickle.load(f)
+
+ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-100000.csv')
+concepts = './data/concepts.txt'
+
+with open(concepts) as f:
+    labels = [line.strip() for line in f.readlines()]
+
+if 'image_id' not in st.session_state:
+    st.session_state.image_id = 0
+if 'concept_ids' not in st.session_state:
+    st.session_state.concept_ids = ['Abstract', 'Representational']
+
+# def on_change_random_input():
+#     st.session_state.image_id = st.session_state.image_id
+
+# ----------------------------- INPUT ----------------------------------
+st.header('Input')
+input_col_1, input_col_2, input_col_3 = st.columns(3)
+# --------------------------- INPUT column 1 ---------------------------
+with input_col_1:
+    with st.form('text_form'):
+
+        # image_id = st.number_input('Image ID: ', format='%d', step=1)
+        st.write('**Choose a series of concepts to compare**')
+        # chosen_text_id_input = st.empty()
+        # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
+        concept_ids = st.multiselect('Concept:', tuple(labels))
+
+        choose_text_button = st.form_submit_button('Choose the defined concepts')
+        # random_text = st.form_submit_button('Select a random concept')
+
+    # if random_text:
+    #     concept_id = random.choice(labels)
+    #     st.session_state.concept_id = concept_id
+    #     chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
+
+    if choose_text_button:
+        st.session_state.concept_ids = list(concept_ids)
+        # st.write(image_id, st.session_state.image_id)
+
+# ---------------------------- SET UP OUTPUT ------------------------------
+epsilon_container = st.empty()
+st.header('Output')
+st.subheader('Concept vector')
+
+# perform attack container
+# header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
+# output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
+header_col_1, header_col_2 = st.columns([5,1])
+output_col_1, output_col_2 = st.columns([5,1])
+
+st.subheader('Derivations along the concept vector')
+
+# prediction error container
+error_container = st.empty()
+smoothgrad_header_container = st.empty()
+
+# smoothgrad container
+smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
+smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
+
+# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
+with output_col_1:
+    vectors, nodes_in_common = get_concepts_vectors(concept_ids, annotations, ann_df)
+    # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
+    #st.write('Concept vector', separation_vector)
+    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Relevant nodes in common: {nodes_in_common}')# - Nodes {",".join(list(imp_nodes))}')
+
+    edges = []
+    # Compare the selected concepts pairwise (vectors[i] corresponds to concept_ids[i])
+    for i in range(len(concept_ids)):
+        for j in range(len(concept_ids)):
+            if i != j:
+                print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
+                similarity = cosine_similarity(vectors[i,:].reshape(1, -1), vectors[j,:].reshape(1, -1))
+                print(np.round(similarity[0][0], 3))
+                edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))
+
+    # Create an empty graph
+    G = nx.Graph()
+
+    # Add edges with weights to the graph
+    for edge in edges:
+        node1, node2, weight = edge
+        G.add_edge(node1, node2, weight=weight)
+
+
+    # Initiate PyVis network object
+    net = Network(
+        height='400px',
+        width='100%',
+        bgcolor='#222222',
+        font_color='white'
+    )
+
+    # Take Networkx graph and translate it to a PyVis graph format
+    net.from_nx(G)
+
+    # Generate network with specific layout settings
+    net.repulsion(
+        node_distance=420,
+        central_gravity=0.33,
+        spring_length=110,
+        spring_strength=0.10,
+        damping=0.95
+    )
+
+    # Save and read graph as HTML file (on Streamlit Sharing)
+    try:
+        path = '/tmp'
+        net.save_graph(f'{path}/pyvis_graph.html')
+        HtmlFile = open(f'{path}/pyvis_graph.html', 'r', encoding='utf-8')
+
+    # Save and read graph as HTML file (locally)
+    except:
+        path = '/html_files'
+        net.save_graph(f'{path}/pyvis_graph.html')
+        HtmlFile = open(f'{path}/pyvis_graph.html', 'r', encoding='utf-8')
+
+    # Load HTML file in HTML component for display on Streamlit page
+    components.html(HtmlFile.read(), height=435)
+
+# ----------------------------- INPUT column 2 & 3 ----------------------------
+# with input_col_2:
+# with st.form('image_form'):
+
+# # image_id = st.number_input('Image ID: ', format='%d', step=1)
+# st.write('**Choose or generate a random image to test the disentanglement**')
+# chosen_image_id_input = st.empty()
+# image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+# choose_image_button = st.form_submit_button('Choose the defined image')
+# random_id = st.form_submit_button('Generate a random image')
+
+# if random_id:
+# image_id = random.randint(0, 100000)
+# st.session_state.image_id = image_id
+# chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+# if choose_image_button:
+# image_id = int(image_id)
+# st.session_state.image_id = int(image_id)
+# # st.write(image_id, st.session_state.image_id)
+
+# with input_col_3:
+# with st.form('Variate along the disentangled concept'):
+# st.write('**Set range of change**')
+# chosen_epsilon_input = st.empty()
+# epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=1, step=1)
+# epsilon_button = st.form_submit_button('Choose the defined epsilon')
+
+# # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
+
+# #model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
+# with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
+# model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
+
+# original_image_vec = annotations['z_vectors'][st.session_state.image_id]
+# img = generate_original_image(original_image_vec, model)
+# # input_image = original_image_dict['image']
+# # input_label = original_image_dict['label']
+# # input_id = original_image_dict['id']
+
+# with smoothgrad_col_3:
+# st.image(img)
+# smooth_head_3.write(f'Base image')
+
+
+# images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon))
+
+# with smoothgrad_col_1:
+# st.image(images[0])
+# smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
+
+# with smoothgrad_col_2:
+# st.image(images[1])
+# smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
+
+# with smoothgrad_col_4:
+# st.image(images[3])
+# smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
+
+# with smoothgrad_col_5:
+# st.image(images[4])
+# smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
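The graph-building section above can be exercised outside Streamlit to check the layout before wiring it into the page. A minimal standalone sketch of the networkx-to-pyvis handoff, with dummy concept labels and similarity weights standing in for the cosine similarities computed above:

```python
import networkx as nx
from pyvis.network import Network

# Dummy edges: (concept, concept, similarity). In the page these weights
# come from cosine_similarity over the concept separation vectors.
edges = [
    ('Abstract', 'Representational', 0.12),
    ('Abstract', 'Geometric', 0.47),
    ('Representational', 'Geometric', 0.08),
]

G = nx.Graph()
for node1, node2, weight in edges:
    G.add_edge(node1, node2, weight=weight)

net = Network(height='400px', width='100%', bgcolor='#222222', font_color='white')
net.from_nx(G)                      # translate the networkx graph to pyvis
net.repulsion(node_distance=420)    # same layout family as the page uses
net.save_graph('pyvis_graph.html')  # open the file in a browser to inspect
```

Saving to a relative path sidesteps the `/tmp` versus `/html_files` branching in the page code, which exists only to handle the differing write permissions between hosted and local runs.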
requirements.txt CHANGED
@@ -16,3 +16,6 @@ altair==4.0
 #torch-utils
 opencv-python
 umap-learn
+graphviz
+networkx
+pyvis
nx.html → tmp/nx.html RENAMED
File without changes