ludusc commited on
Commit
a426c90
·
1 Parent(s): 625d699

added disentanglement also for vase art

Browse files
backend/disentangle_concepts.py CHANGED
@@ -28,10 +28,21 @@ def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=
28
  else:
29
  col = 'w_vectors'
30
 
31
- abstracts = np.array([float(ann) for ann in df[type_bin]])
32
- abstract_idxs = list(np.argsort(abstracts))[:samples]
33
- repr_idxs = list(np.argsort(abstracts))[-samples:]
34
- X = np.array([annotations[col][i] for i in abstract_idxs+repr_idxs])
 
 
 
 
 
 
 
 
 
 
 
35
  X = X.reshape((2*samples, 512))
36
  y = np.array([1]*samples + [0]*samples)
37
  x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
 
28
  else:
29
  col = 'w_vectors'
30
 
31
+ if type(type_bin) == str or len(type_bin) == 1:
32
+ abstracts = np.array([float(ann) for ann in df[type_bin]])
33
+ abstract_idxs = list(np.argsort(abstracts))[:samples]
34
+ repr_idxs = list(np.argsort(abstracts))[-samples:]
35
+ X = np.array([annotations[col][i] for i in abstract_idxs+repr_idxs])
36
+ elif len(type_bin) == 2:
37
+ print('Using two concepts for separation space')
38
+ first_concept = np.array([float(ann) for ann in df[type_bin[0]]])
39
+ second_concept = np.array([float(ann) for ann in df[type_bin[1]]])
40
+ first_idxs = list(np.argsort(first_concept))[:samples]
41
+ second_idxs = list(np.argsort(second_concept))[:samples]
42
+ X = np.array([annotations[col][i] for i in first_idxs+second_idxs])
43
+ else:
44
+ print('Error: type_bin must be either a string or a list of strings of len 2')
45
+ return
46
  X = X.reshape((2*samples, 512))
47
  y = np.array([1]*samples + [0]*samples)
48
  x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
data/vase_annotated_files/seeds0000-20000.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e790910bf45c0d5a84e74c9011b88012f59d0fc27b19987c890b891c57ab739c
3
+ size 125913423
data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e258361e0db7c208ae67654c08ed5b900df10980e82e84bcddd3de89428f679a
3
+ size 30853761
data/vase_model_files/network-snapshot-003800.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42be0a24e7021dc66a9353c3a904494bb8e64b62e00e535ad3b03ad18238b0d2
3
+ size 357349976
pages/{1_Disentanglement.py → 1_Omniart_Disentanglement.py} RENAMED
File without changes
pages/3_Oxford_Vases_Disentanglement.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import pandas as pd
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+
8
+ from matplotlib.backends.backend_agg import RendererAgg
9
+
10
+ from backend.disentangle_concepts import *
11
+ import torch_utils
12
+ import dnnlib
13
+ import legacy
14
+
15
+ _lock = RendererAgg.lock
16
+
17
+
18
+ st.set_page_config(layout='wide')
19
+ BACKGROUND_COLOR = '#bcd0e7'
20
+ SECONDARY_COLOR = '#bce7db'
21
+
22
+
23
+ st.title('Disentanglement studies on the Oxford Vases Dataset')
24
+ st.markdown(
25
+ """
26
+ This is a demo of the Disentanglement studies on the [Oxford Vases Dataset](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/).
27
+ """,
28
+ unsafe_allow_html=False,)
29
+
30
+ annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
31
+ with open(annotations_file, 'rb') as f:
32
+ annotations = pickle.load(f)
33
+
34
+ ann_df = pd.read_csv('./data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv')
35
+ labels = ann_df.columns
36
+
37
+ if 'image_id' not in st.session_state:
38
+ st.session_state.image_id = 0
39
+ if 'concept_ids' not in st.session_state:
40
+ st.session_state.concept_ids =['AMPHORA']
41
+ if 'space_id' not in st.session_state:
42
+ st.session_state.space_id = 'W'
43
+
44
+ # def on_change_random_input():
45
+ # st.session_state.image_id = st.session_state.image_id
46
+
47
+ # ----------------------------- INPUT ----------------------------------
48
+ st.header('Input')
49
+ input_col_1, input_col_2, input_col_3 = st.columns(3)
50
+ # --------------------------- INPUT column 1 ---------------------------
51
+ with input_col_1:
52
+ with st.form('text_form'):
53
+
54
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
55
+ st.write('**Choose two options to disentangle**')
56
+ concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=['AMPHORA', 'CHALICE'])
57
+
58
+ st.write('**Choose a latent space to disentangle**')
59
+ space_id = st.selectbox('Space:', tuple(['Z', 'W']))
60
+
61
+ choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
62
+
63
+ if choose_text_button:
64
+ concept_ids = list(concept_ids)
65
+ st.session_state.concept_ids = concept_ids
66
+ space_id = str(space_id)
67
+ st.session_state.space_id = space_id
68
+ # st.write(image_id, st.session_state.image_id)
69
+
70
+ # ---------------------------- SET UP OUTPUT ------------------------------
71
+ epsilon_container = st.empty()
72
+ st.header('Output')
73
+ st.subheader('Concept vector')
74
+
75
+ # perform attack container
76
+ # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
77
+ # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
78
+ header_col_1, header_col_2 = st.columns([5,1])
79
+ output_col_1, output_col_2 = st.columns([5,1])
80
+
81
+ st.subheader('Derivations along the concept vector')
82
+
83
+ # prediction error container
84
+ error_container = st.empty()
85
+ smoothgrad_header_container = st.empty()
86
+
87
+ # smoothgrad container
88
+ smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
89
+ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
90
+
91
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
92
+ with output_col_1:
93
+ separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id)
94
+ # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
95
+ st.write('Concept vector', separation_vector)
96
+ header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
97
+
98
+ # ----------------------------- INPUT column 2 & 3 ----------------------------
99
+ with input_col_2:
100
+ with st.form('image_form'):
101
+
102
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
103
+ st.write('**Choose or generate a random image to test the disentanglement**')
104
+ chosen_image_id_input = st.empty()
105
+ image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
106
+
107
+ choose_image_button = st.form_submit_button('Choose the defined image')
108
+ random_id = st.form_submit_button('Generate a random image')
109
+
110
+ if random_id:
111
+ image_id = random.randint(0, 50000)
112
+ st.session_state.image_id = image_id
113
+ chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
114
+
115
+ if choose_image_button:
116
+ image_id = int(image_id)
117
+ st.session_state.image_id = int(image_id)
118
+ # st.write(image_id, st.session_state.image_id)
119
+
120
+ with input_col_3:
121
+ with st.form('Variate along the disentangled concept'):
122
+ st.write('**Set range of change**')
123
+ chosen_epsilon_input = st.empty()
124
+ epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
125
+ epsilon_button = st.form_submit_button('Choose the defined lambda')
126
+ st.write('**Select hierarchical levels to manipulate**')
127
+ layers = st.multiselect('Layers:', tuple(range(14)))
128
+ if len(layers) == 0:
129
+ layers = None
130
+ print(layers)
131
+ layers_button = st.form_submit_button('Choose the defined layers')
132
+
133
+
134
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
135
+
136
+ #model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
137
+ with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
138
+ model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
139
+
140
+ if st.session_state.space_id == 'Z':
141
+ original_image_vec = annotations['z_vectors'][st.session_state.image_id]
142
+ else:
143
+ original_image_vec = annotations['w_vectors'][st.session_state.image_id]
144
+
145
+ img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
146
+ # input_image = original_image_dict['image']
147
+ # input_label = original_image_dict['label']
148
+ # input_id = original_image_dict['id']
149
+
150
+ with smoothgrad_col_3:
151
+ st.image(img)
152
+ smooth_head_3.write(f'Base image')
153
+
154
+
155
+ images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
156
+
157
+ with smoothgrad_col_1:
158
+ st.image(images[0])
159
+ smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
160
+
161
+ with smoothgrad_col_2:
162
+ st.image(images[1])
163
+ smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
164
+
165
+ with smoothgrad_col_4:
166
+ st.image(images[3])
167
+ smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
168
+
169
+ with smoothgrad_col_5:
170
+ st.image(images[4])
171
+ smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
pages/3_todo.py DELETED
@@ -1,124 +0,0 @@
1
- import streamlit as st
2
- import pandas as pd
3
- import numpy as np
4
- import random
5
- from backend.utils import make_grid, load_dataset, load_model, load_images
6
-
7
- from backend.smooth_grad import generate_smoothgrad_mask, ShowImage, fig2img
8
- from transformers import AutoFeatureExtractor, AutoModelForImageClassification
9
- import torch
10
-
11
- from matplotlib.backends.backend_agg import RendererAgg
12
- _lock = RendererAgg.lock
13
-
14
- st.set_page_config(layout='wide')
15
- BACKGROUND_COLOR = '#bcd0e7'
16
-
17
-
18
- st.title('Feature attribution visualization with SmoothGrad')
19
- st.write("""> **Which features are responsible for the current prediction of ConvNeXt?**
20
-
21
- In machine learning, it is helpful to identify the significant features of the input (e.g., pixels for images) that affect the model's prediction.
22
- If the model makes an incorrect prediction, we might want to determine which features contributed to the mistake.
23
- To do this, we can generate a feature importance mask, which is a grayscale image with the same size as the original image.
24
- The brightness of each pixel in the mask represents the importance of that feature to the model's prediction.
25
-
26
- There are various methods to calculate an image sensitivity mask for a specific prediction.
27
- One simple way is to use the gradient of a class prediction neuron concerning the input pixels, indicating how the prediction is affected by small pixel changes.
28
- However, this method usually produces a noisy mask.
29
- To reduce the noise, the SmoothGrad technique as described in [SmoothGrad: Removing noise by adding noise](https://arxiv.org/abs/1706.03825) by Daniel _et al_ is used,
30
- which adds Gaussian noise to multiple copies of the image and averages the resulting gradients.
31
- """)
32
-
33
- instruction_text = """Users need to input the model(s), type of image set and image set setting to use this functionality.
34
- 1. Choose model: Users can choose one or more models for comparison.
35
- There are 3 models supported: [ConvNeXt](https://huggingface.co/facebook/convnext-tiny-224),
36
- [ResNet](https://huggingface.co/microsoft/resnet-50) and [MobileNet](https://pytorch.org/hub/pytorch_vision_mobilenet_v2/).
37
- These 3 models have similar number of parameters.
38
- \n2. Choose type of Image set: There are 2 types of Image set. They are _User-defined set_ and _Random set_.
39
- \n3. Image set setting: If users choose _User-defined set_ in Image set,
40
- users need to enter a list of image IDs separated by commas (,). For example, `0,1,4,7` is a valid input.
41
- Check the page [ImageNet1k](/ImageNet1k) to see all the Image IDs.
42
- If users choose _Random set_ in Image set, users just need to choose the number of random images to display here.
43
- """
44
- with st.expander("See more instruction", expanded=False):
45
- st.write(instruction_text)
46
-
47
-
48
- imagenet_df = pd.read_csv('./data/ImageNet_metadata.csv')
49
-
50
- # --------------------------- LOAD function -----------------------------
51
-
52
-
53
- images = []
54
- image_ids = []
55
- # INPUT ------------------------------
56
- st.header('Input')
57
- with st.form('smooth_grad_form'):
58
- st.markdown('**Model and Input Setting**')
59
- selected_models = st.multiselect('Model', options=['ConvNeXt', 'ResNet', 'MobileNet'])
60
- selected_image_set = st.selectbox('Image set', ['User-defined set', 'Random set'])
61
-
62
- summit_button = st.form_submit_button('Set')
63
- if summit_button:
64
- setting_container = st.container()
65
- # for id in image_ids:
66
- # images = load_images(image_ids)
67
-
68
- with st.form('2nd_form'):
69
- st.markdown('**Image set setting**')
70
- if selected_image_set == 'Random set':
71
- no_images = st.slider('Number of images', 1, 50, value=10)
72
- image_ids = random.sample(list(range(50_000)), k=no_images)
73
- else:
74
- text = st.text_area('Specific Image IDs', value='0')
75
- image_ids = list(map(lambda x: int(x.strip()), text.split(',')))
76
-
77
- run_button = st.form_submit_button('Display output')
78
- if run_button:
79
- for id in image_ids:
80
- images = load_images(image_ids)
81
-
82
- st.header('Output')
83
-
84
- models = {}
85
- feature_extractors = {}
86
-
87
- for i, model_name in enumerate(selected_models):
88
- models[model_name], feature_extractors[model_name] = load_model(model_name)
89
-
90
-
91
- # DISPLAY ----------------------------------
92
- if run_button:
93
- header_cols = st.columns([1, 1] + [2]*len(selected_models))
94
- header_cols[0].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Image ID</b></div>', unsafe_allow_html=True)
95
- header_cols[1].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>Original Image</b></div>', unsafe_allow_html=True)
96
- for i, model_name in enumerate(selected_models):
97
- header_cols[i + 2].markdown(f'<div style="text-align: center;margin-bottom: 10px;background-color:{BACKGROUND_COLOR};"><b>{model_name}</b></div>', unsafe_allow_html=True)
98
-
99
- grids = make_grid(cols=2+len(selected_models)*2, rows=len(image_ids)+1)
100
-
101
-
102
- @st.cache(allow_output_mutation=True)
103
- # @st.cache_data
104
- def generate_images(image_id, model_name):
105
- j = image_ids.index(image_id)
106
- image = images[j]['image']
107
- return generate_smoothgrad_mask(
108
- image, model_name,
109
- models[model_name], feature_extractors[model_name], num_samples=10)
110
-
111
- with _lock:
112
- for j, (image_id, image_dict) in enumerate(zip(image_ids, images)):
113
- grids[j][0].write(f'{image_id}. {image_dict["label"]}')
114
- image = image_dict['image']
115
- ori_image = ShowImage(np.asarray(image))
116
- grids[j][1].image(ori_image)
117
-
118
- for i, model_name in enumerate(selected_models):
119
- # ori_image, heatmap_image, masked_image = generate_smoothgrad_mask(image,
120
- # model_name, models[model_name], feature_extractors[model_name], num_samples=10)
121
- heatmap_image, masked_image = generate_images(image_id, model_name)
122
- # grids[j][1].image(ori_image)
123
- grids[j][i*2+2].image(heatmap_image)
124
- grids[j][i*2+3].image(masked_image)