ludusc committed
Commit 8a2c29c
Parent: 3f788ef

small fixes, CLIP vecs graph

data/CLIP_vecs.pkl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2971a01a74a391c752fff9ba91c2939ffc6b29165842a87b911e67d9658df53
+size 412234
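
The new LFS-tracked pickle holds the CLIP text embeddings used by the updated concept-comparison page. A minimal loading sketch, assuming the file contains the label-keyed dict of 768-dimensional ViT-L/14 text vectors built in the view_predictions.ipynb cells added in this commit:

import pickle

# Assumption: the pickle maps each concept label to a 768-dim CLIP text embedding,
# as dumped via dic_clip_vecs in view_predictions.ipynb below.
with open('data/CLIP_vecs.pkl', 'rb') as f:
    clip_vecs = pickle.load(f)

print(clip_vecs['Abstract'].shape)  # expected: (768,); 'Abstract' is a label used in the notebook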
pages/1_Disentanglement.py CHANGED
@@ -128,7 +128,7 @@ with input_col_2:
         random_id = st.form_submit_button('Generate a random image')
 
         if random_id:
-            image_id = random.randint(0, 100000)
+            image_id = random.randint(0, 50000)
             st.session_state.image_id = image_id
             chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
 
pages/2_Concepts_comparison.py CHANGED
@@ -25,11 +25,11 @@ st.write('> **What is their join impact on the image?**')
 st.write("""Description to write""")
 
 
-annotations_file = './data/annotated_files/seeds0000-100000.pkl'
+annotations_file = './data/annotated_files/seeds0000-50000.pkl'
 with open(annotations_file, 'rb') as f:
     annotations = pickle.load(f)
 
-ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-100000.csv')
+ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
 concepts = './data/concepts.txt'
 
 with open(concepts) as f:
@@ -57,13 +57,6 @@ with input_col_1:
         # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
         concept_ids = st.multiselect('Concept:', tuple(labels))
 
-        choose_text_button = st.form_submit_button('Choose the defined concepts')
-        # random_text = st.form_submit_button('Select a random concept')
-
-        # if random_text:
-        #     concept_id = random.choice(labels)
-        #     st.session_state.concept_id = concept_id
-        #     chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
         st.write('**Choose a latent space to disentangle**')
         # chosen_text_id_input = st.empty()
         # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
@@ -85,8 +78,8 @@ st.subheader('Concept vector')
 # perform attack container
 # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
 # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
-header_col_1, header_col_2 = st.columns([5,1])
-output_col_1, output_col_2 = st.columns([5,1])
+header_col_1, header_col_2 = st.columns([1,1])
+output_col_1, output_col_2 = st.columns([1,1])
 
 st.subheader('Derivations along the concept vector')
 
@@ -157,6 +150,66 @@ with output_col_1:
     # Load HTML file in HTML component for display on Streamlit page
     components.html(HtmlFile.read(), height=435)
 
+with output_col_2:
+    with open('data/CLIP_vecs.pkl', 'rb') as f:
+        vectors = pickle.load(f)
+
+    # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
+    #st.write('Concept vector', separation_vector)
+    header_col_2.write(f'Concepts {", ".join(concept_ids)} - Latent space CLIP')# - Nodes {",".join(list(imp_nodes))}')
+
+    edges = []
+    for i in range(len(concept_ids)):
+        for j in range(len(concept_ids)):
+            if i != j:
+                print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
+                similarity = cosine_similarity(vectors[i,:].reshape(1, -1), vectors[j,:].reshape(1, -1))
+                print(np.round(similarity[0][0], 3))
+                edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))
+
+    # # Create an empty graph
+    # G = nx.Graph()
+
+    # # Add edges with weights to the graph
+    # for edge in edges:
+    #     node1, node2, weight = edge
+    #     G.add_edge(node1, node2, weight=weight)
+
+    net = Network(height="750px", width="100%",)
+    for e in edges:
+        src = e[0]
+        dst = e[1]
+        w = e[2]
+
+        net.add_node(src, src, title=src)
+        net.add_node(dst, dst, title=dst)
+        net.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' + str(w))
+    print(net)
+
+    # Generate network with specific layout settings
+    net.repulsion(
+        node_distance=420,
+        central_gravity=0.33,
+        spring_length=110,
+        spring_strength=0.10,
+        damping=0.95
+    )
+
+    # Save and read graph as HTML file (on Streamlit Sharing)
+    try:
+        path = '/tmp'
+        net.save_graph(f'{path}/pyvis_graph_clip.html')
+        HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
+
+    # Save and read graph as HTML file (locally)
+    except:
+        path = '/html_files'
+        net.save_graph(f'{path}/pyvis_graph_clip.html')
+        HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
+
+    # Load HTML file in HTML component for display on Streamlit page
+    components.html(HtmlFile.read(), height=435)
+
 # ----------------------------- INPUT column 2 & 3 ----------------------------
 # with input_col_2:
 #     with st.form('image_form'):
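
For context, a standalone sketch of the pairwise-similarity step that feeds the new pyvis graph, assuming CLIP_vecs.pkl is the label-keyed dict produced in the notebook below ('Baroque' is only a hypothetical second concept; 'Abstract' appears in the notebook):

import pickle
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

with open('data/CLIP_vecs.pkl', 'rb') as f:
    clip_vecs = pickle.load(f)  # assumed: {label: 768-dim CLIP text embedding}

concept_ids = ['Abstract', 'Baroque']  # 'Baroque' is hypothetical
edges = []
for i, a in enumerate(concept_ids):
    for j, b in enumerate(concept_ids):
        if i != j:
            # cosine similarity between the two concept vectors, rounded as in the page code
            sim = cosine_similarity(clip_vecs[a].reshape(1, -1), clip_vecs[b].reshape(1, -1))
            edges.append((a, b, float(np.round(sim[0][0], 3))))

print(edges)  # (source, target, weight) triples, as consumed by Network.add_edge above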
view_predictions.ipynb CHANGED
@@ -208,6 +208,73 @@
     "images[-1]"
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "f5390d8f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/Users/ludovicaschaerf/anaconda3/envs/art-reco_x86/lib/python3.8/site-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
+      " warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "(132, 768)\n"
+     ]
+    }
+   ],
+   "source": [
+    "import open_clip\n",
+    "import os\n",
+    "import random\n",
+    "from tqdm import tqdm\n",
+    "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
+    "\n",
+    "model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
+    "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
+    "\n",
+    "pre_prompt = \"Artwork, \" #@param {type:\"string\"}\n",
+    "text_descriptions = [f\"{pre_prompt}{label}\" for label in labels]\n",
+    "text_tokens = tokenizer(text_descriptions)\n",
+    "\n",
+    "with torch.no_grad(), torch.cuda.amp.autocast():\n",
+    "    text_features = model_clip.encode_text(text_tokens).float()\n",
+    "    text_features /= text_features.norm(dim=-1, keepdim=True)\n",
+    "    \n",
+    "text_features = text_features.cpu().numpy()\n",
+    "print(text_features.shape)\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "f7858bbf",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dic_clip_vecs = {l:v for l,v in zip(labels, text_features)}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "89b4a6fc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dic_clip_vecs['Abstract'].shape\n",
+    "with open('data/CLIP_vecs.pkl', 'wb') as f:\n",
+    "    pickle.dump(dic_clip_vecs, f)"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": 8,