Spaces:
Runtime error
Runtime error
small fixes, CLIP vecs graph
Browse files- data/CLIP_vecs.pkl +3 -0
- pages/1_Disentanglement.py +1 -1
- pages/2_Concepts_comparison.py +64 -11
- view_predictions.ipynb +67 -0
data/CLIP_vecs.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e2971a01a74a391c752fff9ba91c2939ffc6b29165842a87b911e67d9658df53
|
3 |
+
size 412234
|
pages/1_Disentanglement.py
CHANGED
@@ -128,7 +128,7 @@ with input_col_2:
|
|
128 |
random_id = st.form_submit_button('Generate a random image')
|
129 |
|
130 |
if random_id:
|
131 |
-
image_id = random.randint(0,
|
132 |
st.session_state.image_id = image_id
|
133 |
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
134 |
|
|
|
128 |
random_id = st.form_submit_button('Generate a random image')
|
129 |
|
130 |
if random_id:
|
131 |
+
image_id = random.randint(0, 50000)
|
132 |
st.session_state.image_id = image_id
|
133 |
chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
|
134 |
|
pages/2_Concepts_comparison.py
CHANGED
@@ -25,11 +25,11 @@ st.write('> **What is their join impact on the image?**')
|
|
25 |
st.write("""Description to write""")
|
26 |
|
27 |
|
28 |
-
annotations_file = './data/annotated_files/seeds0000-
|
29 |
with open(annotations_file, 'rb') as f:
|
30 |
annotations = pickle.load(f)
|
31 |
|
32 |
-
ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-
|
33 |
concepts = './data/concepts.txt'
|
34 |
|
35 |
with open(concepts) as f:
|
@@ -57,13 +57,6 @@ with input_col_1:
|
|
57 |
# concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
58 |
concept_ids = st.multiselect('Concept:', tuple(labels))
|
59 |
|
60 |
-
choose_text_button = st.form_submit_button('Choose the defined concepts')
|
61 |
-
# random_text = st.form_submit_button('Select a random concept')
|
62 |
-
|
63 |
-
# if random_text:
|
64 |
-
# concept_id = random.choice(labels)
|
65 |
-
# st.session_state.concept_id = concept_id
|
66 |
-
# chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
67 |
st.write('**Choose a latent space to disentangle**')
|
68 |
# chosen_text_id_input = st.empty()
|
69 |
# concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
@@ -85,8 +78,8 @@ st.subheader('Concept vector')
|
|
85 |
# perform attack container
|
86 |
# header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
|
87 |
# output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
|
88 |
-
header_col_1, header_col_2 = st.columns([
|
89 |
-
output_col_1, output_col_2 = st.columns([
|
90 |
|
91 |
st.subheader('Derivations along the concept vector')
|
92 |
|
@@ -157,6 +150,66 @@ with output_col_1:
|
|
157 |
# Load HTML file in HTML component for display on Streamlit page
|
158 |
components.html(HtmlFile.read(), height=435)
|
159 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
160 |
# ----------------------------- INPUT column 2 & 3 ----------------------------
|
161 |
# with input_col_2:
|
162 |
# with st.form('image_form'):
|
|
|
25 |
st.write("""Description to write""")
|
26 |
|
27 |
|
28 |
+
annotations_file = './data/annotated_files/seeds0000-50000.pkl'
|
29 |
with open(annotations_file, 'rb') as f:
|
30 |
annotations = pickle.load(f)
|
31 |
|
32 |
+
ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
|
33 |
concepts = './data/concepts.txt'
|
34 |
|
35 |
with open(concepts) as f:
|
|
|
57 |
# concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
58 |
concept_ids = st.multiselect('Concept:', tuple(labels))
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
st.write('**Choose a latent space to disentangle**')
|
61 |
# chosen_text_id_input = st.empty()
|
62 |
# concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
|
|
|
78 |
# perform attack container
|
79 |
# header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
|
80 |
# output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
|
81 |
+
header_col_1, header_col_2 = st.columns([1,1])
|
82 |
+
output_col_1, output_col_2 = st.columns([1,1])
|
83 |
|
84 |
st.subheader('Derivations along the concept vector')
|
85 |
|
|
|
150 |
# Load HTML file in HTML component for display on Streamlit page
|
151 |
components.html(HtmlFile.read(), height=435)
|
152 |
|
153 |
+
with output_col_2:
|
154 |
+
with open('data/CLIP_vecs.pkl', 'rb') as f:
|
155 |
+
vectors = pickle.load(f)
|
156 |
+
|
157 |
+
# st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
|
158 |
+
#st.write('Concept vector', separation_vector)
|
159 |
+
header_col_2.write(f'Concepts {", ".join(concept_ids)} - Latent space CLIP')# - Nodes {",".join(list(imp_nodes))}')
|
160 |
+
|
161 |
+
edges = []
|
162 |
+
for i in range(len(concept_ids)):
|
163 |
+
for j in range(len(concept_ids)):
|
164 |
+
if i != j:
|
165 |
+
print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
|
166 |
+
similarity = cosine_similarity(vectors[i,:].reshape(1, -1), vectors[j,:].reshape(1, -1))
|
167 |
+
print(np.round(similarity[0][0], 3))
|
168 |
+
edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))
|
169 |
+
|
170 |
+
# # Create an empty graph
|
171 |
+
# G = nx.Graph()
|
172 |
+
|
173 |
+
# # Add edges with weights to the graph
|
174 |
+
# for edge in edges:
|
175 |
+
# node1, node2, weight = edge
|
176 |
+
# G.add_edge(node1, node2, weight=weight)
|
177 |
+
|
178 |
+
net = Network(height="750px", width="100%",)
|
179 |
+
for e in edges:
|
180 |
+
src = e[0]
|
181 |
+
dst = e[1]
|
182 |
+
w = e[2]
|
183 |
+
|
184 |
+
net.add_node(src, src, title=src)
|
185 |
+
net.add_node(dst, dst, title=dst)
|
186 |
+
net.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' +str(w))
|
187 |
+
print(net)
|
188 |
+
|
189 |
+
# Generate network with specific layout settings
|
190 |
+
net.repulsion(
|
191 |
+
node_distance=420,
|
192 |
+
central_gravity=0.33,
|
193 |
+
spring_length=110,
|
194 |
+
spring_strength=0.10,
|
195 |
+
damping=0.95
|
196 |
+
)
|
197 |
+
|
198 |
+
# Save and read graph as HTML file (on Streamlit Sharing)
|
199 |
+
try:
|
200 |
+
path = '/tmp'
|
201 |
+
net.save_graph(f'{path}/pyvis_graph_clip.html')
|
202 |
+
HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
|
203 |
+
|
204 |
+
# Save and read graph as HTML file (locally)
|
205 |
+
except:
|
206 |
+
path = '/html_files'
|
207 |
+
net.save_graph(f'{path}/pyvis_graph_clip.html')
|
208 |
+
HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
|
209 |
+
|
210 |
+
# Load HTML file in HTML component for display on Streamlit page
|
211 |
+
components.html(HtmlFile.read(), height=435)
|
212 |
+
|
213 |
# ----------------------------- INPUT column 2 & 3 ----------------------------
|
214 |
# with input_col_2:
|
215 |
# with st.form('image_form'):
|
view_predictions.ipynb
CHANGED
@@ -208,6 +208,73 @@
|
|
208 |
"images[-1]"
|
209 |
]
|
210 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
{
|
212 |
"cell_type": "code",
|
213 |
"execution_count": 8,
|
|
|
208 |
"images[-1]"
|
209 |
]
|
210 |
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"execution_count": 20,
|
214 |
+
"id": "f5390d8f",
|
215 |
+
"metadata": {},
|
216 |
+
"outputs": [
|
217 |
+
{
|
218 |
+
"name": "stderr",
|
219 |
+
"output_type": "stream",
|
220 |
+
"text": [
|
221 |
+
"/Users/ludovicaschaerf/anaconda3/envs/art-reco_x86/lib/python3.8/site-packages/torch/amp/autocast_mode.py:204: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
|
222 |
+
" warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n"
|
223 |
+
]
|
224 |
+
},
|
225 |
+
{
|
226 |
+
"name": "stdout",
|
227 |
+
"output_type": "stream",
|
228 |
+
"text": [
|
229 |
+
"(132, 768)\n"
|
230 |
+
]
|
231 |
+
}
|
232 |
+
],
|
233 |
+
"source": [
|
234 |
+
"import open_clip\n",
|
235 |
+
"import os\n",
|
236 |
+
"import random\n",
|
237 |
+
"from tqdm import tqdm\n",
|
238 |
+
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"\"\n",
|
239 |
+
"\n",
|
240 |
+
"model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
|
241 |
+
"tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
|
242 |
+
"\n",
|
243 |
+
"pre_prompt = \"Artwork, \" #@param {type:\"string\"}\n",
|
244 |
+
"text_descriptions = [f\"{pre_prompt}{label}\" for label in labels]\n",
|
245 |
+
"text_tokens = tokenizer(text_descriptions)\n",
|
246 |
+
"\n",
|
247 |
+
"with torch.no_grad(), torch.cuda.amp.autocast():\n",
|
248 |
+
" text_features = model_clip.encode_text(text_tokens).float()\n",
|
249 |
+
" text_features /= text_features.norm(dim=-1, keepdim=True)\n",
|
250 |
+
" \n",
|
251 |
+
"text_features = text_features.cpu().numpy()\n",
|
252 |
+
"print(text_features.shape)\n",
|
253 |
+
"\n"
|
254 |
+
]
|
255 |
+
},
|
256 |
+
{
|
257 |
+
"cell_type": "code",
|
258 |
+
"execution_count": 22,
|
259 |
+
"id": "f7858bbf",
|
260 |
+
"metadata": {},
|
261 |
+
"outputs": [],
|
262 |
+
"source": [
|
263 |
+
"dic_clip_vecs = {l:v for l,v in zip(labels, text_features)}"
|
264 |
+
]
|
265 |
+
},
|
266 |
+
{
|
267 |
+
"cell_type": "code",
|
268 |
+
"execution_count": 26,
|
269 |
+
"id": "89b4a6fc",
|
270 |
+
"metadata": {},
|
271 |
+
"outputs": [],
|
272 |
+
"source": [
|
273 |
+
"dic_clip_vecs['Abstract'].shape\n",
|
274 |
+
"with open('data/CLIP_vecs.pkl', 'wb') as f:\n",
|
275 |
+
" pickle.dump(dic_clip_vecs, f)"
|
276 |
+
]
|
277 |
+
},
|
278 |
{
|
279 |
"cell_type": "code",
|
280 |
"execution_count": 8,
|