taquynhnga committed on
Commit 0c1e42b • 1 Parent(s): 8a287fa

added adversarial attacks & change to streamlit 1.19.0

.vscode/settings.json CHANGED
@@ -1,3 +1,5 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:ce4af7987964cdde596ea51e2af55a4f18b804fc77e090cedd3d6ae426c7e401
- size 92
+ {
+     "python.analysis.extraPaths": [
+         "./Visual-Explanation-Methods-PyTorch"
+     ]
+ }
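
(Note: the removed lines are a Git LFS pointer (version/oid/size), so the settings file had previously been stored via LFS; it is now committed as plain JSON that points Pylance's python.analysis.extraPaths at the vendored Visual-Explanation-Methods-PyTorch checkout.)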
README.md CHANGED
@@ -1,7 +1,7 @@
  ---
  title: CNNs Interpretation Visualization
  emoji: 💡
- colorFrom: yellow
+ colorFrom: blue
  colorTo: green
  sdk: streamlit
  sdk_version: 1.10.0
backend/adversarial_attack.py ADDED
@@ -0,0 +1,99 @@
+ import PIL
+ from PIL import Image
+ import numpy as np
+ from matplotlib import pylab as P
+ import cv2
+
+ import torch
+ from torch.utils.data import TensorDataset
+ from torchvision import transforms
+ import torch.nn.functional as F
+
+ from transformers.image_utils import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
+
+ from torchvex.base import ExplanationMethod
+ from torchvex.utils.normalization import clamp_quantile
+
+ from backend.utils import load_image, load_model
+ from backend.smooth_grad import generate_smoothgrad_mask
+
+ import streamlit as st
+
+ IMAGENET_DEFAULT_MEAN = np.asarray(IMAGENET_DEFAULT_MEAN).reshape([1, 3, 1, 1])
+ IMAGENET_DEFAULT_STD = np.asarray(IMAGENET_DEFAULT_STD).reshape([1, 3, 1, 1])
+
+ def deprocess_image(image_inputs):
+     # undo ImageNet normalization and rescale to [0, 255]
+     return (image_inputs * IMAGENET_DEFAULT_STD + IMAGENET_DEFAULT_MEAN) * 255
+
+
+ def feed_forward(input_image):
+     model, feature_extractor = load_model('ConvNeXt')
+     inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
+     logits = model(inputs).logits
+     prediction_prob = F.softmax(logits, dim=-1).max()  # prediction probability
+     # predicted class id; dataset class IDs run from 1 to 1000, so callers add +1
+     prediction_class = logits.argmax(-1).item()
+     prediction_label = model.config.id2label[prediction_class]  # predicted class label
+     return prediction_prob, prediction_class, prediction_label
+
+ # FGSM attack code
+ def fgsm_attack(image, epsilon, data_grad):
+     # Element-wise sign of the data gradient (+1/-1; zero gradients map to -1)
+     sign_data_grad = torch.gt(data_grad, 0).type(torch.FloatTensor) * 2.0 - 1.0
+     perturbed_image = image + epsilon * sign_data_grad
+     return perturbed_image
+
+ # perform the attack on the model
+ def perform_attack(input_image, target, epsilon):
+     model, feature_extractor = load_model("ConvNeXt")
+     # preprocess input image
+     inputs = feature_extractor(input_image, do_resize=False, return_tensors="pt")['pixel_values']
+     inputs.requires_grad = True
+
+     # predict
+     logits = model(inputs).logits
+     prediction_prob = F.softmax(logits, dim=-1).max()
+     prediction_class = logits.argmax(-1).item()
+     prediction_label = model.config.id2label[prediction_class]
+
+     # Calculate the loss (NLL on the raw logit of the true class)
+     loss = F.nll_loss(logits, torch.tensor([target]))
+
+     # Zero all existing gradients
+     model.zero_grad()
+
+     # Calculate gradients of the model in the backward pass
+     loss.backward()
+
+     # Collect the input gradient
+     data_grad = inputs.grad.data
+
+     # Call FGSM attack
+     perturbed_data = fgsm_attack(inputs, epsilon, data_grad)
+
+     # Re-classify the perturbed image
+     new_prediction = model(perturbed_data).logits
+     new_pred_prob = F.softmax(new_prediction, dim=-1).max()
+     new_pred_class = new_prediction.argmax(-1).item()
+     new_pred_label = model.config.id2label[new_pred_class]
+
+     return perturbed_data, new_pred_prob.item(), new_pred_class, new_pred_label
+
+
+ def find_smallest_epsilon(input_image, target):
+     epsilons = [i * 0.001 for i in range(1000)]
+
+     for epsilon in epsilons:
+         perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, target, epsilon)
+         if new_id != target:
+             return perturbed_data, new_prob, new_id, new_label, epsilon
+     return None
+
+ @st.cache_data
+ def generate_images(image_id, epsilon=0):
+     # NOTE: as written, epsilon participates only in the cache key;
+     # the mask below is computed on the clean image
+     model, feature_extractor = load_model("ConvNeXt")
+     original_image_dict = load_image(image_id)
+     image = original_image_dict['image']
+     return generate_smoothgrad_mask(
+         image, 'ConvNeXt',
+         model, feature_extractor, num_samples=10, return_mask=True)
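
For readers new to FGSM: the attack takes a single step of size epsilon in the direction of the sign of the input gradient of the loss. A minimal standalone sketch, assuming a generic classifier model, a normalized input batch x of shape (1, 3, H, W), and an integer label (names here are illustrative, not part of this commit):

    import torch
    import torch.nn.functional as F

    def fgsm_step(model, x, label, epsilon=0.01):
        x = x.clone().detach().requires_grad_(True)
        loss = F.cross_entropy(model(x), torch.tensor([label]))
        model.zero_grad()
        loss.backward()
        # move every pixel by epsilon in the direction that increases the loss
        return (x + epsilon * x.grad.sign()).detach()

This sketch uses cross_entropy and grad.sign(); the committed code instead takes nll_loss over raw logits and a gt()-based sign that maps zero gradients to -1, which behaves the same for attack purposes.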
backend/load_file.py CHANGED
@@ -19,7 +19,11 @@ def load_json(filename):
          loaded_dict = json.loads(read_file.read())
      loaded_dict = OrderedDict(loaded_dict)
      for k, v in loaded_dict.items():
-         loaded_dict[k] = np.asarray(v)
+         if type(v) == list:
+             loaded_dict[k] = np.asarray(v)
+         else:
+             for k_, v_ in v.items():
+                 v[k_] = np.asarray(v_)
      return loaded_dict
  
  class NumpyEncoder(json.JSONEncoder):
@@ -32,6 +36,6 @@ class NumpyEncoder(json.JSONEncoder):
  # save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
  # save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
  
- file = load_json('data/layer_infos/convnext_layer_infos.json')
- print(type(file))
- print(type(file['embeddings.patch_embeddings']))
+ # file = load_json('data/layer_infos/convnext_layer_infos.json')
+ # print(type(file))
+ # print(type(file['embeddings.patch_embeddings']))
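
load_json now accepts two value shapes per key. A hypothetical layer-info entry for each case (keys and numbers invented for illustration):

    {
        "layer.0": [[0.1, 0.2], [0.3, 0.4]],
        "layer.1": {"top_k": [1, 5, 9], "coords": [[2, 3]]}
    }

A plain list is converted to one ndarray; a nested dict has each of its values converted in place. The module-level test calls are also commented out, so importing backend.load_file no longer reads the JSON as a side effect.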
backend/maximally_activating_patches.py CHANGED
@@ -4,12 +4,14 @@ import streamlit as st
  from backend.load_file import load_json
  
  
- @st.cache(allow_output_mutation=True)
+ # @st.cache(allow_output_mutation=True)
+ @st.cache_data
  def load_activation(filename):
      activation = load_json(filename)
      return activation
  
- @st.cache(allow_output_mutation=True)
+ # @st.cache(allow_output_mutation=True)
+ @st.cache_data
  def load_dataset(data_index):
      with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
          dataset = pickle.load(file)
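
Background on the caching change: Streamlit 1.18+ deprecated st.cache in favor of two explicit primitives, and this commit migrates every decorator accordingly. A minimal sketch of the convention used across the repo:

    import streamlit as st

    @st.cache_data        # serializable results: arrays, dicts, DataFrames
    def load_metadata(path):
        ...

    @st.cache_resource    # shared global objects: models, connections
    def load_classifier(name):
        ...

allow_output_mutation and suppress_st_warning have no equivalents in the new API and are simply dropped.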
backend/smooth_grad.py CHANGED
@@ -1,5 +1,3 @@
- import sys
- import os
  import PIL
  from PIL import Image
  import numpy as np
@@ -10,8 +8,8 @@ import torch
  from torch.utils.data import TensorDataset
  from torchvision import transforms
  
- dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
- sys.path.append(dirpath_to_modules)
+ # dirpath_to_modules = './Visual-Explanation-Methods-PyTorch'
+ # sys.path.append(dirpath_to_modules)
  
  from torchvex.base import ExplanationMethod
  from torchvex.utils.normalization import clamp_quantile
@@ -212,7 +210,7 @@ def fig2img(fig):
      img = Image.open(buf)
      return img
  
- def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25):
+ def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=None, num_samples=25, return_mask=False):
      inputs, prediction_class = feed_forward(model_name, image, model, feature_extractor)
  
      smoothgrad_gen = SmoothGradient(
@@ -230,4 +228,8 @@ def generate_smoothgrad_mask(image, model_name, model=None, feature_extractor=No
      # ori_image = ShowImage(image)
      heat_map_image = ShowHeatMap(smoothgrad_mask)
      masked_image = ShowMaskedImage(smoothgrad_mask, image)
-     return heat_map_image, masked_image
+ 
+     if return_mask:
+         return heat_map_image, masked_image, smoothgrad_mask
+     else:
+         return heat_map_image, masked_image
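
The new return_mask flag exposes the raw SmoothGrad tensor so the adversarial-attack page can diff masks before and after an attack. A sketch of the intended call pattern (mirroring backend/adversarial_attack.py):

    heatmap, masked, mask = generate_smoothgrad_mask(
        image, 'ConvNeXt', model, feature_extractor,
        num_samples=10, return_mask=True)
    # the attack page computes a second mask for the perturbed input and shows
    # ShowHeatMap(abs(attacked_mask - mask)) as the "SmoothGrad difference"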
backend/utils.py CHANGED
@@ -14,12 +14,17 @@ from plotly import express as px
  from plotly.subplots import make_subplots
  from tqdm import trange
  
- @st.cache(allow_output_mutation=True)
+ import torch
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
+ 
+ # @st.cache(allow_output_mutation=True)
+ @st.cache_resource
  def load_dataset(data_index):
      with open(f'./data/preprocessed_image_net/val_data_{data_index}.pkl', 'rb') as file:
          dataset = pickle.load(file)
      return dataset
  
+ @st.cache_resource
  def load_dataset_dict():
      dataset_dict = {}
      progress_empty = st.empty()
@@ -33,6 +38,43 @@ def load_dataset_dict():
      text_empty.empty()
      return dataset_dict
  
+ 
+ @st.cache_data
+ def load_image(image_id):
+     dataset = load_dataset(image_id//10000)
+     image = dataset[image_id%10000]
+     return image
+ 
+ @st.cache_data
+ def load_images(image_ids):
+     images = []
+     for image_id in image_ids:
+         image = load_image(image_id)
+         images.append(image)
+     return images
+ 
+ 
+ # @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
+ @st.cache_resource
+ def load_model(model_name):
+     with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
+         if model_name == 'ResNet':
+             model_file_path = 'microsoft/resnet-50'
+             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
+             model = AutoModelForImageClassification.from_pretrained(model_file_path)
+             model.eval()
+         elif model_name == 'ConvNeXt':
+             model_file_path = 'facebook/convnext-tiny-224'
+             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
+             model = AutoModelForImageClassification.from_pretrained(model_file_path)
+             model.eval()
+         else:
+             model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
+             model.eval()
+             feature_extractor = None
+     return model, feature_extractor
+ 
+ 
  def make_grid(cols=None,rows=None):
      grid = [0]*rows
      for i in range(rows):
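
The loaders moved here from pages/2_SmoothGrad.py so every page shares one cached copy. How they compose, assuming the preprocessed ImageNet pickles exist on disk:

    model, feature_extractor = load_model('ConvNeXt')   # cached once per session
    sample = load_image(12345)   # val_data_1.pkl, item 2345; a dict with 'image', 'label', 'id'
    inputs = feature_extractor(sample['image'], do_resize=False, return_tensors='pt')
    logits = model(inputs['pixel_values']).logits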
frontend/images/equal-sign.png ADDED
frontend/images/minus-sign-2.png ADDED
frontend/images/minus-sign-3.png ADDED
frontend/images/minus-sign-4.png ADDED
frontend/images/minus-sign-5.png ADDED
frontend/images/minus-sign.png ADDED
frontend/images/plus-sign-2.png ADDED
frontend/images/plus-sign.png ADDED
load_file.py DELETED
@@ -1,37 +0,0 @@
- import json
- import pickle
- import numpy as np
- from collections import OrderedDict
- 
- def load_pickle(filename):
-     with open(filename, 'rb') as file:
-         data = pickle.load(file)
-     return data
- 
- def save_pickle_to_json(filename):
-     ordered_dict = load_pickle(filename)
-     json_obj = json.dumps(ordered_dict, cls=NumpyEncoder)
-     with open(filename.replace('.pkl', '.json'), 'w') as f:
-         f.write(json_obj)
- 
- def load_json(filename):
-     with open(filename, 'r') as read_file:
-         loaded_dict = json.loads(read_file.read())
-     loaded_dict = OrderedDict(loaded_dict)
-     for k, v in loaded_dict.items():
-         loaded_dict[k] = np.asarray(v)
-     return loaded_dict
- 
- class NumpyEncoder(json.JSONEncoder):
-     def default(self, obj):
-         if isinstance(obj, np.ndarray):
-             return obj.tolist()
-         return json.JSONEncoder.default(self, obj)
- 
- # save_pickle_to_json('data/layer_infos/convnext_layer_infos.pkl')
- # save_pickle_to_json('data/layer_infos/resnet_layer_infos.pkl')
- # save_pickle_to_json('data/layer_infos/mobilenet_layer_infos.pkl')
- 
- file = load_json('data/layer_infos/convnext_layer_infos.json')
- print(type(file))
- print(type(file['embeddings.patch_embeddings']))
pages/1_Maximally_activating_patches.py CHANGED
@@ -28,7 +28,7 @@ def load_dot_to_graph(filename):
      return graph, dot
  
  st.title('Maximally activating image patches')
- st.write('Visualize image patches that maximize the activation of layers in three models: ConvNeXt, ResNet, MobileNet')
+ st.write('Visualize image patches that maximize the activation of layers in the ConvNeXt model')
  
  # st.header('ConvNeXt')
  convnext_dot_file = './data/dot_architectures/convnext_architecture.dot'
@@ -130,7 +130,7 @@ if nodes != None:
          subplot_titles=tuple([f"#{i+1}" for i in range(top_k)]), shared_yaxes=True)
      else:
          top_margin = 0
-         fig = make_subplots(rows=1, cols=num_cols)
+         fig = make_subplots(rows=1, cols=num_cols, shared_yaxes=True)
      for col in range(1, num_cols+1):
          k, c = col-1, row-1
          img_index = int(top_k_coor_max_[k, c, 3])
pages/2_SmoothGrad.py CHANGED
@@ -2,7 +2,7 @@ import streamlit as st
  import pandas as pd
  import numpy as np
  import random
- from backend.utils import make_grid, load_dataset
+ from backend.utils import make_grid, load_dataset, load_model, load_images
  
  from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
  from transformers import AutoFeatureExtractor, AutoModelForImageClassification
@@ -22,32 +22,34 @@ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
  
  # --------------------------- LOAD function -----------------------------
  
- @st.cache(allow_output_mutation=True)
- def load_images(image_ids):
-     images = []
-     for image_id in image_ids:
-         dataset = load_dataset(image_id//10000)
-         images.append(dataset[image_id%10000])
-     return images
- 
- @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
- def load_model(model_name):
-     with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
-         if model_name == 'ResNet':
-             model_file_path = 'microsoft/resnet-50'
-             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
-             model = AutoModelForImageClassification.from_pretrained(model_file_path)
-             model.eval()
-         elif model_name == 'ConvNeXt':
-             model_file_path = 'facebook/convnext-tiny-224'
-             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
-             model = AutoModelForImageClassification.from_pretrained(model_file_path)
-             model.eval()
-         else:
-             model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
-             model.eval()
-             feature_extractor = None
-     return model, feature_extractor
+ # @st.cache(allow_output_mutation=True)
+ # @st.cache_data
+ # def load_images(image_ids):
+ #     images = []
+ #     for image_id in image_ids:
+ #         dataset = load_dataset(image_id//10000)
+ #         images.append(dataset[image_id%10000])
+ #     return images
+ 
+ # @st.cache(allow_output_mutation=True, suppress_st_warning=True, show_spinner=False)
+ # @st.cache_resource
+ # def load_model(model_name):
+ #     with st.spinner(f"Loading {model_name} model! This process might take 1-2 minutes..."):
+ #         if model_name == 'ResNet':
+ #             model_file_path = 'microsoft/resnet-50'
+ #             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
+ #             model = AutoModelForImageClassification.from_pretrained(model_file_path)
+ #             model.eval()
+ #         elif model_name == 'ConvNeXt':
+ #             model_file_path = 'facebook/convnext-tiny-224'
+ #             feature_extractor = AutoFeatureExtractor.from_pretrained(model_file_path, crop_pct=1.0)
+ #             model = AutoModelForImageClassification.from_pretrained(model_file_path)
+ #             model.eval()
+ #         else:
+ #             model = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
+ #             model.eval()
+ #             feature_extractor = None
+ #     return model, feature_extractor
  
  images = []
  image_ids = []
@@ -56,28 +58,7 @@ st.header('Input')
  with st.form('smooth_grad_form'):
      st.markdown('**Model and Input Setting**')
      selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
-     # selected_image_set = st.selectbox('Image set', ['Random set', 'User-defined set'])
      selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
- 
-     # if selected_image_set == 'Class set':
-     #     class_labels = imagenet_df.ClassLabel.unique().tolist()
-     #     class_labels.sort()
-     #     selected_classes = st.multiselect('Class filter', options=['All'] + class_labels)
-     #     if not ('All' in selected_classes or len(selected_classes) == 0):
-     #         imagenet_df = imagenet_df[imagenet_df['ClassLabel'].isin(selected_classes)]
-     #     no_images = st.slider('Number of images', 1, len(imagenet_df), value=10)
-     #     image_ids = random.sample(imagenet_df.index.tolist(), k=no_images)
- 
- 
-     # user_defined_button = st.form_submit_button('User-defined set')
-     # random_set_button = st.form_submit_button('Random set')
- 
-     # if user_defined_button:
-     #     text = st.text_area('Specific Image IDs', value='0')
-     #     image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
-     # if random_set_button:
-     #     no_images = st.slider('Number of images', 1, 50, value=10)
-     #     image_ids = random.sample(list(range(50_000)), k=no_images)
  
      summit_button = st.form_submit_button('Set')
      if summit_button:
@@ -123,8 +104,11 @@ grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
  #     models[model_name], feature_extractors[model_name] = load_model(model_name)
  
  
- @st.cache(allow_output_mutation=True)
- def generate_images(image, model_name):
+ # @st.cache(allow_output_mutation=True)
+ @st.cache_data
+ def generate_images(image_id, model_name):
+     j = image_ids.index(image_id)
+     image = images[j]['image']
      return generate_smoothgrad_mask(
          image, model_name,
          models[model_name], feature_extractors[model_name], num_samples=10)
@@ -139,7 +123,7 @@ with _lock:
      for i, model_name in enumerate(selected_models):
          # ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
          #         model_name, models[model_name], feature_extractors[model_name], num_samples=10)
-         heatmap_image, masked_image = generate_images(image, model_name)
+         heatmap_image, masked_image = generate_images(image_id, model_name)
          # grids[j][1].image(ori_image)
          grids[j][i*2+2].image(heatmap_image)
          grids[j][i*2+3].image(masked_image)
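
A likely motivation for the generate_images signature change: st.cache_data hashes the function's arguments to build its cache key, and an integer image_id is a cheap, stable key, whereas the PIL image object passed previously is costly to hash. The body now looks the image up by id from the already-loaded list instead.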
pages/3_Adversarial_attack.py ADDED
@@ -0,0 +1,184 @@
+ import streamlit as st
+ import pandas as pd
+ import numpy as np
+ import random
+ from backend.utils import make_grid, load_dataset, load_model, load_image
+
+ from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img, LoadImage, ShowHeatMap, ShowMaskedImage
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
+ import torch
+
+ from matplotlib.backends.backend_agg import RendererAgg
+
+ from backend.adversarial_attack import *
+
+ _lock = RendererAgg.lock
+
+ st.set_page_config(layout='wide')
+ BACKGROUND_COLOR = '#bcd0e7'
+ SECONDARY_COLOR = '#bce7db'
+
+
+ st.title('Adversarial Attack')
+ st.write('How do adversarial attacks affect ConvNeXt interpretation?')
+
+ imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
+ image_id = None
+
+ if 'image_id' not in st.session_state:
+     st.session_state.image_id = 0
+
+ # def on_change_random_input():
+ #     st.session_state.image_id = st.session_state.image_id
+
+ # ----------------------------- INPUT ----------------------------------
+ st.header('Input')
+ input_col_1, input_col_2, input_col_3 = st.columns(3)
+ # --------------------------- INPUT column 1 ---------------------------
+ with input_col_1:
+     with st.form('image_form'):
+
+         # image_id = st.number_input('Image ID: ', format='%d', step=1)
+         st.write('**Choose or generate a random image**')
+         chosen_image_id_input = st.empty()
+         image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+         choose_image_button = st.form_submit_button('Choose the defined image')
+         random_id = st.form_submit_button('Generate a random image')
+
+         if random_id:
+             image_id = random.randint(0, 49999)  # 50,000 validation images, ids 0-49999
+             st.session_state.image_id = image_id
+             chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
+
+         if choose_image_button:
+             image_id = int(image_id)
+             st.session_state.image_id = int(image_id)
+             # st.write(image_id, st.session_state.image_id)
+
+ # ---------------------------- SET UP OUTPUT ------------------------------
+ epsilon_container = st.empty()
+ st.header('Output')
+ st.subheader('Perform attack')
+
+ # perform attack container
+ header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
+ output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
+
+ # prediction error container
+ error_container = st.empty()
+ smoothgrad_header_container = st.empty()
+
+ # smoothgrad container
+ smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
+ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
+
+ original_image_dict = load_image(st.session_state.image_id)
+ input_image = original_image_dict['image']
+ input_label = original_image_dict['label']
+ input_id = original_image_dict['id']
+
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
+ with output_col_1:
+     pred_prob, pred_class_id, pred_class_label = feed_forward(input_image)
+     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
+     st.image(input_image)
+     header_col_1.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.1f}% confidence')
+
+ if pred_class_id != (input_id-1):
+     with error_container.container():
+         st.write(f'Predicted output: Class ID {pred_class_id} - {pred_class_label} {pred_prob*100:.1f}% confidence')
+         st.error('ConvNeXt misclassified the chosen image. Please choose or generate another image.',
+                  icon="🚫")
+
+ # ----------------------------- INPUT column 2 & 3 ----------------------------
+ with input_col_2:
+     with st.form('epsilon_form'):
+         st.write('**Set epsilon or find the smallest epsilon automatically**')
+         chosen_epsilon_input = st.empty()
+         epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=0.001, format='%.3f', step=0.001)
+
+         epsilon_button = st.form_submit_button('Choose the defined epsilon')
+         find_epsilon = st.form_submit_button('Find the smallest epsilon automatically')
+
+ with input_col_3:
+     with st.form('iterate_epsilon_form'):
+         max_epsilon = st.number_input('Maximum value of epsilon (Optional setting)', value=0.500, format='%.3f')
+         step_epsilon = st.number_input('Step (Optional setting)', value=0.001, format='%.3f')
+         setting_button = st.form_submit_button('Set iterating mode')
+
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
+ if pred_class_id == (input_id-1) and (epsilon_button or find_epsilon or setting_button):
+     with output_col_3:
+         if epsilon_button:
+             perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, epsilon)
+         else:
+             epsilons = [i*step_epsilon for i in range(1, 1001) if i*step_epsilon <= max_epsilon]
+             epsilon_container.progress(0, text='Checking epsilon')
+
+             for i, e in enumerate(epsilons):
+                 perturbed_data, new_prob, new_id, new_label = perform_attack(input_image, input_id-1, e)
+                 epsilon_container.progress(i/len(epsilons), text=f'Checking epsilon={e:.3f}. Confidence={new_prob*100:.1f}%')
+                 epsilon = e
+
+                 if new_id != input_id - 1:
+                     epsilon_container.empty()
+                     st.balloons()
+                     break
+                 if i == len(epsilons)-1:
+                     epsilon_container.error(f'FGSM failed to attack this image at epsilon={e:.3f}. Set a higher maximum value of epsilon or choose another image.',
+                                             icon="🚫")
+
+         perturbed_image = deprocess_image(perturbed_data.detach().numpy())[0].astype(np.uint8).transpose(1,2,0)
+         perturbed_amount = perturbed_image - input_image
+         header_col_3.write(f'Perturbed amount - epsilon={epsilon:.3f}')
+         st.image(ShowImage(perturbed_amount))
+
+     with output_col_2:
+         # st.write('plus sign')
+         st.image(LoadImage('frontend/images/plus-sign.png'))
+
+     with output_col_4:
+         # st.write('equal sign')
+         st.image(LoadImage('frontend/images/equal-sign.png'))
+
+     # ---------------------------- DISPLAY COL 5 ROW 1 ------------------------------
+     with output_col_5:
+         # st.write(f'ID {new_id+1} - {new_label}: {new_prob*100:.3f}% confidence')
+         st.image(ShowImage(perturbed_image))
+         header_col_5.write(f'Class ID {new_id+1} - {new_label}: {new_prob*100:.1f}% confidence')
+
+     # -------------------------- DISPLAY SMOOTHGRAD ---------------------------
+     smoothgrad_header_container.subheader('SmoothGrad visualization')
+
+     with smoothgrad_col_1:
+         smooth_head_1.write('SmoothGrad before attack')
+         heatmap_image, masked_image, mask = generate_images(st.session_state.image_id, epsilon=0)
+         st.image(heatmap_image)
+         st.image(masked_image)
+     with smoothgrad_col_3:
+         smooth_head_3.write('SmoothGrad after attack')
+         heatmap_image_attacked, masked_image_attacked, attacked_mask = generate_images(st.session_state.image_id, epsilon=epsilon)
+         st.image(heatmap_image_attacked)
+         st.image(masked_image_attacked)
+
+     with smoothgrad_col_2:
+         st.image(LoadImage('frontend/images/minus-sign-5.png'))
+
+     with smoothgrad_col_5:
+         smooth_head_5.write('SmoothGrad difference')
+         difference_mask = abs(attacked_mask-mask)
+         st.image(ShowHeatMap(difference_mask))
+         masked_image = ShowMaskedImage(difference_mask, perturbed_image)
+         st.image(masked_image)
+
+     with smoothgrad_col_4:
+         st.image(LoadImage('frontend/images/equal-sign.png'))
pages/{3_ImageNet1k.py → 4_ImageNet1k.py} RENAMED
File without changes
requirements.txt CHANGED
@@ -1,18 +1,17 @@
  captum==0.5.0
- deta==1.1.0
  graphviz==0.20.1
  Markdown==3.4.1
  matplotlib==3.6.2
  numpy==1.22.3
  opencv_python_headless==4.6.0.66
  pandas==1.5.2
- Pillow==9.3.0
+ Pillow==9.4.0
  plotly==5.11.0
- scipy==1.9.3
+ scipy==1.10.1
  setuptools==65.5.0
- # streamlit==1.15.2
- streamlit==1.10.0
+ streamlit==1.19.0
  torch==1.10.1
  torchvision==0.11.2
  tqdm==4.64.1
  transformers==4.25.1
+ git+https://github.com/vlue-c/Visual-Explanation-Methods-PyTorch.git
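
Pinning the torchvex fork as a VCS requirement is what makes the sys.path.append workaround removed from backend/smooth_grad.py unnecessary; pip resolves it like any other dependency:

    pip install -r requirements.txt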