ludusc committed on
Commit
4347fc3
•
1 Parent(s): fcc16aa

added most files for entanglement project

Files changed (24)
  1. .gitignore +1 -0
  2. Home.py +11 -23
  3. Visual-Explanation-Methods-PyTorch +0 -1
  4. backend/disentangle_concepts.py +47 -0
  5. data/annotated_files/annotations_seeds0000-1000.pkl +3 -0
  6. data/concepts.txt +3 -0
  7. data/model_files/pytorch_model.bin +3 -0
  8. data/{ImageNet_metadata.csv → old/ImageNet_metadata.csv} +0 -0
  9. data/{activation → old/activation}/convnext_activation.json +0 -0
  10. data/{activation → old/activation}/mobilenet_activation.json +0 -0
  11. data/{activation → old/activation}/resnet_activation.json +0 -0
  12. data/{dot_architectures → old/dot_architectures}/convnext_architecture.dot +0 -0
  13. data/{layer_infos → old/layer_infos}/convnext_layer_infos.json +0 -0
  14. data/{layer_infos → old/layer_infos}/mobilenet_layer_infos.json +0 -0
  15. data/{layer_infos → old/layer_infos}/resnet_layer_infos.json +0 -0
  16. data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_0.pkl +0 -0
  17. data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_1.pkl +0 -0
  18. data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_2.pkl +0 -0
  19. data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_3.pkl +0 -0
  20. data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_4.pkl +0 -0
  21. pages/{3_Adversarial_attack.py → 1_Disentanglement.py} +127 -129
  22. pages/{1_Maximally_activating_patches.py → 2_todo.py} +0 -0
  23. pages/{2_SmoothGrad.py → 3_todo.py} +0 -0
  24. pages/{4_ImageNet1k.py → 4_todo.py} +0 -0
.gitignore CHANGED
@@ -32,6 +32,7 @@ git-large-file
 deta_drive.py
 secret_keys.py
 
+data/old
 # Large files
 # data/preprocessed_image_net/
 # data/activation/*.pkl
Home.py CHANGED
@@ -7,37 +7,25 @@ st.set_page_config(layout='wide')
 st.title('About')
 
 # INTRO
-intro_text = """Convolutional neural networks (ConvNets) have evolved at a rapid pace since the 2010s.
-Some representative ConvNet models are VGGNet, Inception, ResNe(X)t, DenseNet, MobileNet, EfficientNet and RegNet, which focus on different factors of accuracy, efficiency, and scalability.
-In 2020, the Vision Transformer (ViT) was introduced as a Transformer model for solving computer vision problems.
-Larger model and dataset sizes allow ViT to perform significantly better than ResNet; however, ViT still encountered challenges in generic computer vision tasks such as object detection and semantic segmentation.
-The Swin Transformer's success made Transformers be adopted as a generic vision backbone, showing outstanding performance in a wide range of computer vision tasks.
-Nevertheless, the success of this approach is still primarily attributed to Transformers' inherent superiority, rather than to the intrinsic inductive biases of convolutions.
-
-In 2022, Zhuang Liu et al. proposed a pure convolutional model dubbed ConvNeXt, discovered by modernizing a standard ResNet toward the design of Vision Transformers, and claimed to outperform them.
-
-The project aims to interpret the ConvNeXt model with several visualization techniques.
-After that, a web interface would be built to demonstrate the interpretations, helping us look inside the deep ConvNeXt model and answer the questions:
-> “What patterns maximally activated this filter (channel) in this layer?”\n
-> “Which features are responsible for the current prediction?”
-
-Due to limited time and resources, the project only used the tiny-sized ConvNeXt model, which was trained on ImageNet-1k at resolution 224x224, and used the 50,000 images of the ImageNet-1k validation set for demo purposes.
-
-In this web app, two visualization techniques are implemented and demonstrated: **Maximally activating patches** and **SmoothGrad**.
-Besides, this web app also helps investigate the effect of **adversarial attacks** on ConvNeXt interpretations.
-Last but not least, a final page stores the 50,000 images of the **ImageNet-1k** validation set, facilitating the two pages above in searching and referencing.
+intro_text = """This project investigates the nature and nurture of latent spaces, with the aim of formulating a theory of this particular vectorial space. It draws together reflections on the inherent constraints of latent spaces in particular architectures and considers the learning-specific features that emerge.
+The thesis concentrates mostly on the second part, exploring different avenues for understanding the space. Using a multitude of vision generative models, it discusses possibilities for the systematic exploration of the space, including disentanglement properties and the coverage of various guidance methods.
+It also explores the possibility of comparison across latent spaces and investigates the differences and commonalities across different learning experiments. Furthermore, the thesis investigates the role of stochasticity in newer models.
+As a case study, this thesis adopts art historical data, spanning classic art, photography, and modern and contemporary art.
 
+The project aims to interpret the StyleGAN2 model with several techniques.
+> “What concepts are disentangled in the latent space of StyleGAN2?”\n
+> “Can we quantify the complexity of such concepts?”
 """
 st.write(intro_text)
 
 # 4 PAGES
-st.subheader('Features')
+st.subheader('Pages')
 sections_text = """Overall, there are 4 features in this web app:
-1) Maximally activating patches: The visualization method on this page answers the question “what patterns maximally activated this filter (channel)?”.
-2) SmoothGrad: The visualization method on this page answers the question “which features are responsible for the current prediction?”.
-3) Adversarial attack: How do adversarial attacks affect ConvNeXt interpretation?
-4) ImageNet1k: The storage of 50,000 images in the validation set.
+1) Disentanglement visualizer
+...
 """
 st.write(sections_text)
 
 
-add_footer('Developed with ❤ by ', 'Hanna Ta Quynh Nga', 'https://www.linkedin.com/in/ta-quynh-nga-hanna/')
+add_footer('Developed by ', 'Ludovica Schaerf', 'https://www.linkedin.com/in/ludovica-schaerf-59063422b/')
Visual-Explanation-Methods-PyTorch DELETED
@@ -1 +0,0 @@
-Subproject commit 5cb88902729af1d9d85259879b47cb238b841881
backend/disentangle_concepts.py ADDED
@@ -0,0 +1,47 @@
+import numpy as np
+from sklearn.svm import SVC
+from sklearn.model_selection import train_test_split
+import torch
+import PIL.Image
+
+def get_separation_space(type_bin, annotations):
+    abstracts = np.array([ann[type_bin] for ann in annotations['annotations']])
+    abstract_idxs = list(np.argsort(abstracts))[:200]
+    repr_idxs = list(np.argsort(abstracts))[-200:]
+    X = np.array([annotations['z_vectors'][i] for i in abstract_idxs + repr_idxs])
+    X = X.reshape((400, 512))
+    y = np.array([1]*200 + [0]*200)
+    x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
+    svc = SVC(gamma='auto', kernel='linear')
+    svc.fit(x_train, y_train)
+    print(svc.score(x_val, y_val))
+    imp_features = np.count_nonzero(np.abs(svc.coef_) > 2)
+    return svc.coef_, imp_features
+
+def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=11):
+    device = torch.device('cpu')
+    G = model.to(device) # type: ignore
+
+    # Labels (unconditional generation).
+    label = torch.zeros([1, G.c_dim], device=device)
+
+    z = torch.from_numpy(z.copy()).to(device)
+    decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
+
+    lambdas = np.linspace(min_epsilon, max_epsilon, count)
+    images = []
+    # Generate one image per step along the concept direction.
+    for lambda_ in lambdas:
+        z_0 = z + lambda_ * decision_boundary
+        # Construct an inverse rotation/translation matrix and pass it to the generator. The
+        # generator expects this matrix as an inverse to avoid potentially failing numerical
+        # operations in the network.
+        #if hasattr(G.synthesis, 'input'):
+        #    m = make_transform(translate, rotate)
+        #    m = np.linalg.inv(m)
+        #    G.synthesis.input.transform.copy_(torch.from_numpy(m))
+
+        img = G(z_0, label, truncation_psi=0.7, noise_mode='random')
+        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+        images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
+    return images
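Taken together, these two helpers implement the page's linear-probe approach to disentanglement: fit a linear SVM that separates latents scored low vs. high on a concept, then walk a seed latent along the hyperplane normal and re-render. Below is a minimal usage sketch under stated assumptions: the concept name 'Abstract' is illustrative, and the commit does not show how data/model_files/pytorch_model.bin deserializes into a StyleGAN2 generator, so the torch.load call is a guess.

import pickle
import numpy as np
import torch

from backend.disentangle_concepts import get_separation_space, regenerate_images

# Annotation pickle added in this commit: the code above expects keys
# 'annotations' (per-seed concept scores) and 'z_vectors' (512-d z latents).
with open('data/annotated_files/annotations_seeds0000-1000.pkl', 'rb') as f:
    annotations = pickle.load(f)

# Assumption: the serialization format of the generator is not shown here;
# torch.load is a plausible but unverified way to recover it.
G = torch.load('data/model_files/pytorch_model.bin')

# 'Abstract' is a hypothetical entry from data/concepts.txt.
direction, n_relevant = get_separation_space('Abstract', annotations)

# Traverse one seed latent along the concept direction and save the frames.
z = np.array(annotations['z_vectors'][0]).reshape((1, 512))
frames = regenerate_images(G, z, direction, min_epsilon=-3, max_epsilon=3, count=11)
frames[0].save('traversal.gif', save_all=True, append_images=frames[1:], duration=200)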
data/annotated_files/annotations_seeds0000-1000.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ffcd38622c48bf91e0e193b846a181e6baee559633e6057df7765fe0ecd422cf
+size 4461349
data/concepts.txt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:d2383e924d76b10523b9f7b6800a5653c2cc78777bbd9a071de79bf14ab43078
+size 1407
data/model_files/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:27d6840c1f9f11a0af97f6f1ff3809f7f3641d1e4ea7bc893ad15d9e4341caed
+size 120944973
data/{ImageNet_metadata.csv → old/ImageNet_metadata.csv} RENAMED
File without changes
data/{activation → old/activation}/convnext_activation.json RENAMED
File without changes
data/{activation → old/activation}/mobilenet_activation.json RENAMED
File without changes
data/{activation → old/activation}/resnet_activation.json RENAMED
File without changes
data/{dot_architectures → old/dot_architectures}/convnext_architecture.dot RENAMED
File without changes
data/{layer_infos → old/layer_infos}/convnext_layer_infos.json RENAMED
File without changes
data/{layer_infos → old/layer_infos}/mobilenet_layer_infos.json RENAMED
File without changes
data/{layer_infos → old/layer_infos}/resnet_layer_infos.json RENAMED
File without changes
data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_0.pkl RENAMED
File without changes
data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_1.pkl RENAMED
File without changes
data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_2.pkl RENAMED
File without changes
data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_3.pkl RENAMED
File without changes
data/{preprocessed_image_net → old/preprocessed_image_net}/val_data_4.pkl RENAMED
File without changes
pages/{3_Adversarial_attack.py → 1_Disentanglement.py} RENAMED
@@ -1,4 +1,5 @@
 import streamlit as st
+import pickle
 import pandas as pd
 import numpy as np
 import random
@@ -10,7 +11,7 @@ import torch
 
 from matplotlib.backends.backend_agg import RendererAgg
 
-from backend.adversarial_attack import *
+from backend.disentangle_concepts import *
 
 _lock = RendererAgg.lock
 
@@ -19,18 +20,9 @@ BACKGROUND_COLOR = '#bcd0e7'
 SECONDARY_COLOR = '#bce7db'
 
 
-st.title('Adversarial Attack')
-st.write('> **How do adversarial attacks affect ConvNeXt interpretation?**')
-st.write("""Adversarial examples are inputs crafted to confuse neural networks, causing them to misclassify a given input.
-These examples are not distinguishable to humans, but they cause the network to fail to recognize the image content.
-One such attack is the fast gradient sign method (FGSM), a white-box attack that aims to ensure misclassification.
-A white-box attack is one where the attacker has full access to the model being attacked.
-
-The FGSM attack is one of the earliest and most popular adversarial attacks.
-It is described by Goodfellow _et al._ in their work on [Explaining and Harnessing Adversarial Examples](https://arxiv.org/abs/1412.6572).
-The attack is simple yet powerful, using the very gradients that neural networks use to learn.
-Instead of adjusting the weights based on the backpropagated gradients to minimize the loss, the attack adjusts the input data to maximize the loss, using the gradient of the loss with respect to the input data.
-""")
+st.title('Disentanglement studies')
+st.write('> **What concepts can be disentangled in the latent space of a model?**')
+st.write("""Explain more in depth""")
 
 instruction_text = """Instruction to input:
 1. Choosing image: Users can choose a specific image by entering **Image ID** and hitting the _Choose the defined image_ button, or can generate an image randomly by hitting the _Generate a random image_ button.
@@ -46,12 +38,13 @@ with st.expander("See more instruction", expanded=False):
     st.write(instruction_text)
 
 
-
-imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
-image_id = None
-
-if 'image_id' not in st.session_state:
-    st.session_state.image_id = 0
+annotations_file = './data/annotated_files/annotations_seeds0000-1000.pkl'
+with open(annotations_file, 'rb') as f:
+    annotations = pickle.load(f)
+
+concepts = './data/concepts.txt'
+with open(concepts) as f:
+    labels = [line.strip() for line in f.readlines()]
 
 # def on_change_random_input():
 #     st.session_state.image_id = st.session_state.image_id
@@ -61,30 +54,30 @@ st.header('Input')
 input_col_1, input_col_2, input_col_3 = st.columns(3)
 # --------------------------- INPUT column 1 ---------------------------
 with input_col_1:
-    with st.form('image_form'):
+    with st.form('text_form'):
 
         # image_id = st.number_input('Image ID: ', format='%d', step=1)
-        st.write('**Choose or generate a random image**')
-        chosen_image_id_input = st.empty()
-        image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+        st.write('**Choose a concept to disentangle**')
+        chosen_text_id_input = st.empty()
+        concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.get('concept_id', labels[0]))
 
-        choose_image_button = st.form_submit_button('Choose the defined image')
-        random_id = st.form_submit_button('Generate a random image')
+        choose_image_button = st.form_submit_button('Choose the defined concept')
+        random_text = st.form_submit_button('Select a random concept')
 
-        if random_id:
-            image_id = random.randint(0, 50000)
-            st.session_state.image_id = image_id
-            chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+        if random_text:
+            concept_id = random.choice(labels)
+            st.session_state.concept_id = concept_id
+            chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
 
         if choose_image_button:
-            image_id = int(image_id)
-            st.session_state.image_id = int(image_id)
+            concept_id = str(concept_id)
+            st.session_state.concept_id = concept_id
             # st.write(image_id, st.session_state.image_id)
 
 # ---------------------------- SET UP OUTPUT ------------------------------
 epsilon_container = st.empty()
 st.header('Output')
-st.subheader('Perform attack')
+st.subheader('Concept vector')
 
 # perform attack container
 header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
@@ -98,118 +91,123 @@ smoothgrad_header_container = st.empty()
 smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
 smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
 
-original_image_dict = load_image(st.session_state.image_id)
-input_image = original_image_dict['image']
-input_label = original_image_dict['label']
-input_id = original_image_dict['id']
-
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    pred_prob, pred_class_id, pred_class_label = feed_forward(input_image)
+    separation_vector, number_important_features = get_separation_space(concept_id, annotations)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
-    st.image(input_image)
-    header_col_1.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.1f}% confidence')
+    st.write('Separation vector', separation_vector)
+    header_col_1.write(f'Concept {concept_id} - Number of relevant nodes: {number_important_features}')
 
-
-
-    if pred_class_id != (input_id-1):
-        with error_container.container():
-            st.write(f'Predicted output: Class ID {pred_class_id} - {pred_class_label} {pred_prob*100:.1f}% confidence')
-            st.error('ConvNeXt misclassified the chosen image. Please choose or generate another image.',
-                     icon = "🚫")
-
-# ----------------------------- INPUT column 2 & 3 ----------------------------
+# ----------------------------- INPUT column 2 & 3 ----------------------------
 with input_col_2:
-    with st.form('epsilon_form'):
-        st.write('**Set epsilon or find the smallest epsilon automatically**')
-        chosen_epsilon_input = st.empty()
-        epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=0.001, format='%.3f', step=0.001)
-
-        epsilon_button = st.form_submit_button('Choose the defined epsilon')
-        find_epsilon = st.form_submit_button('Find the smallest epsilon automatically')
+    with st.form('image_form'):
+
+        # image_id = st.number_input('Image ID: ', format='%d', step=1)
+        st.write('**Choose or generate a random image to test the disentanglement**')
+        chosen_image_id_input = st.empty()
+        image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.get('image_id', 0))
+
+        choose_image_button = st.form_submit_button('Choose the defined image')
+        random_id = st.form_submit_button('Generate a random image')
+
+        if random_id:
+            image_id = random.randint(0, 50000)
+            st.session_state.image_id = image_id
+            chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+        if choose_image_button:
+            image_id = int(image_id)
+            st.session_state.image_id = int(image_id)
+            # st.write(image_id, st.session_state.image_id)
 
 with input_col_3:
-    with st.form('iterate_epsilon_form'):
-        max_epsilon = st.number_input('Maximum value of epsilon (Optional setting)', value=0.500, format='%.3f')
-        step_epsilon = st.number_input('Step (Optional setting)', value=0.001, format='%.3f')
-        setting_button = st.form_submit_button('Set iterating mode')
-
+    with st.form('Variate along the disentangled concept'):
+        st.write('**Set range of change**')
+        chosen_epsilon_input = st.empty()
+        epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=1.0, format='%.1f')
+        epsilon_button = st.form_submit_button('Choose the defined epsilon')
 
 # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
-if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
-    with output_col_3:
-        if epsilon_button:
-            perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
-        else:
-            epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
-            with epsilon_container.container():
-                epsilon_container_text = 'Checking epsilon'
-                st.write(epsilon_container_text)
-                st.progress(0)
-
-            for i, e in enumerate(epsilons):
-
-                perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
-                with epsilon_container.container():
-                    epsilon_container_text = f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%'
-                    st.write(epsilon_container_text)
-                    st.progress(i/len(epsilons))
-
-                epsilon = e
-
-                if new_id != input_id - 1:
-                    epsilon_container.empty()
-                    st.balloons()
-                    break
-                if i == len(epsilons)-1:
-                    epsilon_container.error(f'FGSM failed to attack this image at epsilon={e:.3f}. Set a higher maximum value of epsilon or choose another image',
-                                            icon = "🚫")
-
-        perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
-        perturbed_amount = perturbed_image - input_image
-        header_col_3.write(f'Perturbed amount - epsilon={epsilon:.3f}')
-        st.image(ShowImage(perturbed_amount))
-
-with output_col_2:
-    # st.write('plus sign')
-    st.image(LoadImage('frontend/images/plus-sign.png'))
-
-with output_col_4:
-    # st.write('equal sign')
-    st.image(LoadImage('frontend/images/equal-sign.png'))
-
-# ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
-with output_col_5:
-    # st.write(f'ID {new_id+1} - {new_label}: {new_prob*100:.3f}% confidence')
-    st.image(ShowImage(perturbed_image))
-    header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')
-
-# -------------------------- DISPLAY SMOOTHGRAD ---------------------------
-smoothgrad_header_container.subheader('SmoothGrad visualization')
-
-with smoothgrad_col_1:
-    smooth_head_1.write(f'SmoothGrad before attack')
-    heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
-    st.image(heatmap_image)
-    st.image(masked_image)
-with smoothgrad_col_3:
-    smooth_head_3.write('SmoothGrad after attack')
-    heatmap_image_attacked, masked_image_attacked, attacked_mask = generate_images(st.session_state.image_id, epsilon=epsilon)
-    st.image(heatmap_image_attacked)
-    st.image(masked_image_attacked)
-
-with smoothgrad_col_2:
-    st.image(LoadImage('frontend/images/minus-sign-5.png'))
-
-with smoothgrad_col_5:
-    smooth_head_5.write('SmoothGrad difference')
-    difference_mask = abs(attacked_mask-mask)
-    st.image(ShowHeatMap(difference_mask))
-    masked_image = ShowMaskedImage(difference_mask, perturbed_image)
-    st.image(masked_image)
-
-with smoothgrad_col_4:
-    st.image(LoadImage('frontend/images/equal-sign.png'))
+# original_image_dict = load_image(st.session_state.image_id)
+# input_image = original_image_dict['image']
+# input_label = original_image_dict['label']
+# input_id = original_image_dict['id']
+
+
+
+# if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
+#     with output_col_3:
+#         if epsilon_button:
+#             perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
+#         else:
+#             epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
+#             with epsilon_container.container():
+#                 epsilon_container_text = 'Checking epsilon'
+#                 st.write(epsilon_container_text)
+#                 st.progress(0)
+
+#             for i, e in enumerate(epsilons):
+
+#                 perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
+#                 with epsilon_container.container():
+#                     epsilon_container_text = f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%'
+#                     st.write(epsilon_container_text)
+#                     st.progress(i/len(epsilons))
+
+#                 epsilon = e
+
+#                 if new_id != input_id - 1:
+#                     epsilon_container.empty()
+#                     st.balloons()
+#                     break
+#                 if i == len(epsilons)-1:
+#                     epsilon_container.error(f'FGSM failed to attack this image at epsilon={e:.3f}. Set a higher maximum value of epsilon or choose another image',
+#                                             icon = "🚫")
+
+#         perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
+#         perturbed_amount = perturbed_image - input_image
+#         header_col_3.write(f'Perturbed amount - epsilon={epsilon:.3f}')
+#         st.image(ShowImage(perturbed_amount))
+
+# with output_col_2:
+#     # st.write('plus sign')
+#     st.image(LoadImage('frontend/images/plus-sign.png'))
+
+# with output_col_4:
+#     # st.write('equal sign')
+#     st.image(LoadImage('frontend/images/equal-sign.png'))
+
+# # ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
+# with output_col_5:
+#     # st.write(f'ID {new_id+1} - {new_label}: {new_prob*100:.3f}% confidence')
+#     st.image(ShowImage(perturbed_image))
+#     header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')
+
+# # -------------------------- DISPLAY SMOOTHGRAD ---------------------------
+# smoothgrad_header_container.subheader('SmoothGrad visualization')
+
+# with smoothgrad_col_1:
+#     smooth_head_1.write(f'SmoothGrad before attack')
+#     heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
+#     st.image(heatmap_image)
+#     st.image(masked_image)
+# with smoothgrad_col_3:
+#     smooth_head_3.write('SmoothGrad after attack')
+#     heatmap_image_attacked, masked_image_attacked, attacked_mask = generate_images(st.session_state.image_id, epsilon=epsilon)
+#     st.image(heatmap_image_attacked)
+#     st.image(masked_image_attacked)
+
+# with smoothgrad_col_2:
+#     st.image(LoadImage('frontend/images/minus-sign-5.png'))
+
+# with smoothgrad_col_5:
+#     smooth_head_5.write('SmoothGrad difference')
+#     difference_mask = abs(attacked_mask-mask)
+#     st.image(ShowHeatMap(difference_mask))
+#     masked_image = ShowMaskedImage(difference_mask, perturbed_image)
+#     st.image(masked_image)
+
+# with smoothgrad_col_4:
+#     st.image(LoadImage('frontend/images/equal-sign.png'))
 
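The prose removed from this page described the FGSM attack: perturb the input along the sign of the loss gradient rather than updating the weights against it. For reference, a minimal sketch of that textbook step (a generic formulation after Goodfellow et al.; the repo's own perform_attack helper in backend/adversarial_attack.py is not part of this commit, so this is not its implementation):

import torch
import torch.nn.functional as F

def fgsm_attack(model, image, true_label, epsilon):
    """image: (N,C,H,W) float tensor in [0,1]; true_label: (N,) class indices."""
    image = image.clone().detach().requires_grad_(True)
    loss = F.cross_entropy(model(image), true_label)
    loss.backward()
    # Adjust the *input* to maximize the loss, instead of adjusting the
    # weights to minimize it: x' = x + epsilon * sign(dLoss/dx).
    adversarial = image + epsilon * image.grad.sign()
    return adversarial.clamp(0, 1).detach()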
pages/{1_Maximally_activating_patches.py → 2_todo.py} RENAMED
File without changes
pages/{2_SmoothGrad.py → 3_todo.py} RENAMED
File without changes
pages/{4_ImageNet1k.py → 4_todo.py} RENAMED
File without changes