ludusc committed
Commit
3f788ef
1 Parent(s): 2a71f28

added disentanglement of W vector

backend/disentangle_concepts.py CHANGED
@@ -6,7 +6,7 @@ import torch
 from umap import UMAP
 import PIL
 
-def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1):
+def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=0.1, latent_space='Z'):
     """
     The get_separation_space function takes in a type_bin, annotations, and df.
     It then samples 100 of the most representative abstracts for that type_bin and 100 of the least representative abstracts for that type_bin.
@@ -22,10 +22,16 @@ def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=
     :return: The weights of the linear classifier
     :doc-author: Trelent
     """
+
+    if latent_space == 'Z':
+        col = 'z_vectors'
+    else:
+        col = 'w_vectors'
+
     abstracts = np.array([float(ann) for ann in df[type_bin]])
     abstract_idxs = list(np.argsort(abstracts))[:samples]
     repr_idxs = list(np.argsort(abstracts))[-samples:]
-    X = np.array([annotations['z_vectors'][i] for i in abstract_idxs+repr_idxs])
+    X = np.array([annotations[col][i] for i in abstract_idxs+repr_idxs])
     X = X.reshape((2*samples, 512))
     y = np.array([1]*samples + [0]*samples)
     x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
@@ -45,7 +51,7 @@ def get_separation_space(type_bin, annotations, df, samples=100, method='LR', C=
     return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val),2)
 
 
-def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5):
+def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z'):
     """
     The regenerate_images function takes a model, z, and decision_boundary as input. It then
     constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -69,6 +75,7 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
     # Labels.
     label = torch.zeros([1, G.c_dim], device=device)
 
+
     z = torch.from_numpy(z.copy()).to(device)
     decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
 
@@ -84,14 +91,19 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
         #m = make_transform(translate, rotate)
         #m = np.linalg.inv(m)
         #G.synthesis.input.transform.copy_(torch.from_numpy(m))
-
-        img = G(z_0, label, truncation_psi=0.7, noise_mode='random')
+        if latent_space == 'Z':
+            img = G(z_0, label, truncation_psi=0.7, noise_mode='const')
+
+        else:
+            W = z_0.expand((14, -1)).unsqueeze(0)
+            img = G.synthesis(W, noise_mode='const')
+
         img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
         images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
 
     return images, lambdas
 
-def generate_original_image(z, model):
+def generate_original_image(z, model, latent_space='Z'):
     """
     The generate_original_image function takes in a latent vector and the model,
     and returns an image generated from that latent vector.
@@ -106,13 +118,19 @@ def generate_original_image(z, model):
     G = model.to(device) # type: ignore
     # Labels.
     label = torch.zeros([1, G.c_dim], device=device)
-    z = torch.from_numpy(z.copy()).to(device)
-    img = G(z, label, truncation_psi=0.7, noise_mode='random')
+    if latent_space == 'Z':
+        z = torch.from_numpy(z.copy()).to(device)
+        img = G(z, label, truncation_psi=0.7, noise_mode='const')
+    else:
+        W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
+        print(W.shape)
+        img = G.synthesis(W, noise_mode='const')
+
     img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
     return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
 
 
-def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1):
+def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=0.1, latent_space='Z'):
     """
     The get_concepts_vectors function takes in a list of concepts, a dictionary of annotations, and the dataframe containing all the images.
     It returns two things:
@@ -132,7 +150,7 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=
     performances = []
     vectors = np.zeros((len(concepts), 512))
     for i, conc in enumerate(concepts):
-        vec, _, imp_nodes, performance = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C)
+        vec, _, imp_nodes, performance = get_separation_space(conc, annotations, df, samples=samples, method=method, C=C, latent_space=latent_space)
         vectors[i,:] = vec
         performances.append(performance)
         important_nodes.append(set(imp_nodes))
@@ -148,3 +166,60 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=
     nodes_in_common = set.intersection(*important_nodes)
     return vectors, nodes_in_common, performances
 
+
+def get_verification_score(concept, decision_boundary, model, annotations, samples=100, latent_space='Z'):
+    import open_clip
+    import os
+    import random
+    from tqdm import tqdm
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+    model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')
+    tokenizer = open_clip.get_tokenizer('ViT-L-14')
+
+    # Prepare the text queries
+    #@markdown _in the form pre_prompt {label}_:
+    pre_prompt = "Artwork, " #@param {type:"string"}
+    text_descriptions = [f"{pre_prompt}{label}" for label in [concept]]
+    text_tokens = tokenizer(text_descriptions)
+
+
+    listlen = len(annotations['fname'])
+    items = random.sample(range(listlen), samples)
+    changes = []
+    for iterator in tqdm(items):
+        chunk_imgs = []
+        chunk_ids = []
+
+        if latent_space == 'Z':
+            z = annotations['z_vectors'][iterator]
+        else:
+            z = annotations['w_vectors'][iterator]
+        images, lambdas = regenerate_images(model, z, decision_boundary, min_epsilon=0, max_epsilon=1, count=2, latent_space=latent_space)
+        for im,l in zip(images, lambdas):
+            chunk_imgs.append(preprocess(im.convert("RGB")))
+            chunk_ids.append(l)
+
+        image_input = torch.tensor(np.stack(chunk_imgs))
+
+        with torch.no_grad(), torch.cuda.amp.autocast():
+            text_features = model_clip.encode_text(text_tokens).float()
+            image_features = model_clip.encode_image(image_input).float()
+
+        # Rescale features
+        image_features /= image_features.norm(dim=-1, keepdim=True)
+        text_features /= text_features.norm(dim=-1, keepdim=True)
+
+        # Analyze features
+        text_probs = (100.0 * image_features.cpu().numpy() @ text_features.cpu().numpy().T)#.softmax(dim=-1)
+
+        change = max(text_probs[1][0].item() - text_probs[0][0].item(), 0)
+        changes.append(change)
+
+    return np.round(np.mean(np.array(changes)), 4)
+
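Note on the W-space branch above: the synthesis network takes one style vector per layer, so a single 512-d w has to be tiled to shape (1, num_ws, 512) before calling G.synthesis, and this commit hard-codes 14 layers, which ties the code to one generator configuration. Below is a minimal sketch of the same broadcast written against the generator's own attribute instead; the helper name `broadcast_w` is ours, not part of the commit, and it assumes a StyleGAN2/3-style generator exposing `num_ws`.

```python
# Sketch only (not part of this commit): broadcast one w vector to all
# synthesis layers without hard-coding 14.
import numpy as np
import torch

def broadcast_w(w_single, G, device='cpu'):
    # w_single: numpy array of shape (512,) or (1, 512)
    w = torch.from_numpy(np.asarray(w_single, dtype=np.float32).copy()).to(device)
    num_ws = getattr(G, 'num_ws', 14)  # 14 is the value this commit assumes
    return w.reshape(1, -1).expand(num_ws, -1).unsqueeze(0)  # (1, num_ws, 512)

# Usage, mirroring the diff:
#   W = broadcast_w(annotations['w_vectors'][i], G)
#   img = G.synthesis(W, noise_mode='const')
```

The new `get_verification_score` rides on the same switch: for each sampled latent it regenerates a pair of images along the decision boundary (count=2, epsilons 0 and 1) and scores the concept shift as `text_probs[1][0] - text_probs[0][0]` against a CLIP text query, so the indexing implicitly depends on `count` staying at 2.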
data/annotated_files/{annotations_seeds0000-1000.pkl → seeds0000-50000.pkl} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:ffcd38622c48bf91e0e193b846a181e6baee559633e6057df7765fe0ecd422cf
-size 4461349
+oid sha256:cd1bd97b8ff508b1d4a7ef43323530368ace65b35d12d84a914913f541187298
+size 314939226
data/annotated_files/sim_seeds0000-10000.csv DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:4e82d206b3aa231c00176a24c8de33a6299e92e65b23013a40538146b8d24ff8
-size 5645518
data/annotated_files/{annotations_parallel_seeds0000-10000.pkl → sim_seeds0000-50000.csv} RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cae08d2bfaa9be2002692efcacfdd10dfd480749e99d99c170a6de13f4811bad
-size 55986521
+oid sha256:c3faa3d75c2da1dbb2c5d90aeddee256e1f3324b24b902a54115d9b6aad0ae9d
+size 21965577
data/model_files/pytorch_model.bin DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:27d6840c1f9f11a0af97f6f1ff3809f7f3641d1e4ea7bc893ad15d9e4341caed
-size 120944973
pages/1_Disentanglement.py CHANGED
@@ -34,11 +34,11 @@ with st.expander("See more instruction", expanded=False):
     st.write(instruction_text)
 
 
-annotations_file = './data/annotated_files/seeds0000-100000.pkl'
+annotations_file = './data/annotated_files/seeds0000-50000.pkl'
 with open(annotations_file, 'rb') as f:
     annotations = pickle.load(f)
 
-ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-100000.csv')
+ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
 concepts = './data/concepts.txt'
 
 with open(concepts) as f:
@@ -48,6 +48,8 @@ if 'image_id' not in st.session_state:
     st.session_state.image_id = 0
 if 'concept_id' not in st.session_state:
     st.session_state.concept_id = 'Abstract'
+if 'space_id' not in st.session_state:
+    st.session_state.space_id = 'Z'
 
 # def on_change_random_input():
 #     st.session_state.image_id = st.session_state.image_id
@@ -65,7 +67,12 @@ with input_col_1:
         # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
         concept_id = st.selectbox('Concept:', tuple(labels))
 
-        choose_text_button = st.form_submit_button('Choose the defined concept')
+        st.write('**Choose a latent space to disentangle**')
+        # chosen_text_id_input = st.empty()
+        # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
+        space_id = st.selectbox('Space:', tuple(['Z', 'W']))
+
+        choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
         # random_text = st.form_submit_button('Select a random concept')
 
         # if random_text:
@@ -76,6 +83,8 @@ with input_col_1:
     if choose_text_button:
        concept_id = str(concept_id)
        st.session_state.concept_id = concept_id
+       space_id = str(space_id)
+       st.session_state.space_id = space_id
     # st.write(image_id, st.session_state.image_id)
 
 # ---------------------------- SET UP OUTPUT ------------------------------
@@ -101,10 +110,10 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgr
 
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_id, annotations, ann_df)
+    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_id, annotations, ann_df, latent_space=st.session_state.space_id)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
     st.write('Concept vector', separation_vector)
-    header_col_1.write(f'Concept {concept_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
+    header_col_1.write(f'Concept {concept_id} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
 
 # ----------------------------- INPUT column 2 & 3 ----------------------------
 with input_col_2:
@@ -141,8 +150,12 @@ with input_col_3:
     with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
         model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
 
-    original_image_vec = annotations['z_vectors'][st.session_state.image_id]
-    img = generate_original_image(original_image_vec, model)
+    if st.session_state.space_id == 'Z':
+        original_image_vec = annotations['z_vectors'][st.session_state.image_id]
+    else:
+        original_image_vec = annotations['w_vectors'][st.session_state.image_id]
+
+    img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
     # input_image = original_image_dict['image']
     # input_label = original_image_dict['label']
     # input_id = original_image_dict['id']
@@ -152,7 +165,7 @@ with smoothgrad_col_3:
     smooth_head_3.write(f'Base image')
 
 
-    images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon))
+    images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id)
 
     with smoothgrad_col_1:
         st.image(images[0])
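Taken together, the page now threads one extra choice (Z vs. W) from a selectbox into session state and through every backend call. The sketch below condenses that flow into a plain script runnable outside Streamlit; the paths, pickle keys, and the dnnlib/legacy loader are copied from the diff, and we assume those modules are importable from the repo root.

```python
# Condensed offline sketch of the pipeline this page wires up.
import pickle
import pandas as pd
import dnnlib
import legacy
from backend.disentangle_concepts import get_separation_space, regenerate_images

with open('./data/annotated_files/seeds0000-50000.pkl', 'rb') as f:
    annotations = pickle.load(f)
ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')

with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
    model = legacy.load_network_pkl(f)['G_ema'].to('cpu')  # type: ignore

space = 'W'  # 'Z' or 'W', as selected in the UI
# Fit the linear probe on the chosen latent space -> unit-norm concept direction.
sep_vec, n_feats, imp_nodes, acc = get_separation_space(
    'Abstract', annotations, ann_df, latent_space=space)

# Walk a base latent along that direction and render the image strip.
vec = annotations['w_vectors' if space == 'W' else 'z_vectors'][0]
images, lambdas = regenerate_images(
    model, vec, sep_vec, min_epsilon=-3, max_epsilon=3, latent_space=space)
```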
pages/2_Concepts_comparison.py CHANGED
@@ -39,7 +39,8 @@ if 'image_id' not in st.session_state:
     st.session_state.image_id = 0
 if 'concept_ids' not in st.session_state:
     st.session_state.concept_ids = ['Abstract', 'Representational']
-
+if 'space_id' not in st.session_state:
+    st.session_state.space_id = 'Z'
 # def on_change_random_input():
 #     st.session_state.image_id = st.session_state.image_id
 
@@ -63,9 +64,17 @@ with input_col_1:
         # concept_id = random.choice(labels)
         # st.session_state.concept_id = concept_id
         # chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
-
+        st.write('**Choose a latent space to disentangle**')
+        # chosen_text_id_input = st.empty()
+        # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
+        space_id = st.selectbox('Space:', tuple(['Z', 'W']))
+
+        choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
+
     if choose_text_button:
         st.session_state.concept_ids = list(concept_ids)
+        space_id = str(space_id)
+        st.session_state.space_id = space_id
     # st.write(image_id, st.session_state.image_id)
 
 # ---------------------------- SET UP OUTPUT ------------------------------
@@ -91,10 +100,10 @@ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgr
 
 # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
 with output_col_1:
-    vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df)
+    vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df, latent_space=space_id)
     # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
     #st.write('Concept vector', separation_vector)
-    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
+    header_col_1.write(f'Concepts {", ".join(concept_ids)} - Latent space {space_id} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
 
     edges = []
     for i in range(len(concept_ids)):
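The loop that opens at the end of the hunk builds the concept graph, but the diff context cuts off before the edge computation, so the following is a hypothetical reading rather than the repo's code. Because `get_separation_space` returns `clf.coef_ / np.linalg.norm(clf.coef_)`, the rows of `vectors` are unit-norm, and a plain dot product between two rows is the cosine similarity between concept directions; `concept_edges` and its `threshold` parameter are illustrative names.

```python
# Hypothetical sketch: score each concept pair by cosine similarity and
# keep the strong pairs as graph edges.
import numpy as np

def concept_edges(vectors, concept_ids, threshold=0.5):
    edges = []
    for i in range(len(concept_ids)):
        for j in range(i + 1, len(concept_ids)):
            sim = float(np.dot(vectors[i], vectors[j]))  # rows are unit-norm
            if abs(sim) >= threshold:
                edges.append((concept_ids[i], concept_ids[j], round(sim, 2)))
    return edges
```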
view_predictions.ipynb CHANGED
The diff for this file is too large to render. See raw diff