ludusc committed on
Commit
52bd88d
1 Parent(s): c1562ad
ganspace_unsupervised_disentanglement.ipynb ADDED
@@ -0,0 +1,174 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "3722712c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%matplotlib inline \n",
11
+ "\n",
12
+ "import pandas as pd\n",
13
+ "import pickle\n",
14
+ "import random\n",
15
+ "\n",
16
+ "from PIL import Image, ImageColor\n",
17
+ "import matplotlib.pyplot as plt\n",
18
+ "\n",
19
+ "import numpy as np\n",
20
+ "import torch\n",
21
+ "\n",
22
+ "from backend.disentangle_concepts import *\n",
23
+ "import dnnlib \n",
24
+ "import legacy\n",
25
+ "from backend.color_annotations import *\n",
26
+ "\n",
28
+ "\n",
29
+ "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
30
+ "\n",
31
+ "\n",
32
+ "%load_ext autoreload\n",
33
+ "%autoreload 2"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
44
+ "with open(annotations_file, 'rb') as f:\n",
45
+ " annotations = pickle.load(f)\n",
46
+ "\n",
47
+ "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
48
+ "\n",
49
+ "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
50
+ " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "cd114cb1",
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "ann_df = tohsv(ann_df)\n",
61
+ "ann_df.head()"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "id": "feb64168",
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
72
+ "print(X.shape)\n",
73
+ "y_h = np.array(ann_df['H1'].values)\n",
74
+ "y_s = np.array(ann_df['S1'].values)\n",
75
+ "y_v = np.array(ann_df['S1'].values)"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "id": "4e814959",
81
+ "metadata": {},
82
+ "source": [
83
+ "### Unsupervised approaches"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "c9493f54",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "from sklearn.decomposition import PCA"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "93596853",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "pca = PCA(n_components=20)\n",
104
+ "\n",
105
+ "dims_pca = pca.fit_transform(x_trainhc.T)\n",
106
+ "dims_pca /= np.linalg.norm(dims_pca, axis=0)\n",
107
+ "print(dims_pca.shape, np.linalg.norm(dims_pca, axis=0).shape)"
108
+ ]
109
+ },
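+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "0a1b2c3d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added sanity-check sketch (illustrative, not part of the original analysis):\n",
+ "# after the per-column normalisation above each candidate direction should be\n",
+ "# unit-norm, and PCA scores of centred data are mutually orthogonal, so the\n",
+ "# Gram matrix should be close to the identity.\n",
+ "gram = dims_pca.T @ dims_pca\n",
+ "print('direction norms:', np.round(np.diag(gram)[:5], 3))\n",
+ "print('max off-diagonal overlap:', np.round(np.abs(gram - np.diag(np.diag(gram))).max(), 3))"
+ ]
+ },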
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "dd0b1501",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "method = 'PCA dimension'\n",
118
+ "for sep, num in zip(dims_pca.T, range(10)):\n",
119
+ " images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-(int(4)), max_epsilon=int(4), count=5, latent_space='W')\n",
120
+ " fig, axs = plt.subplots(1, len(images), figsize=(50,10))\n",
121
+ " fig.suptitle(method +': ' + str(num), fontsize=20)\n",
122
+ " for i,im in enumerate(images):\n",
123
+ " axs[i].imshow(im)\n",
124
+ " axs[i].set_title(np.round(lambdas[i], 2))\n",
125
+ " plt.show()"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "id": "4e0c7808",
131
+ "metadata": {},
132
+ "source": [
133
+ "## dimensionality reduction e vediamo dove finiscono i vari colori"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "id": "833ed31f",
139
+ "metadata": {},
140
+ "source": [
141
+ "## clustering per vedere quali sono i centroid di questo spazio e se ci sono regioni determinate dai colori"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "7c19e820",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": []
151
+ }
152
+ ],
153
+ "metadata": {
154
+ "kernelspec": {
155
+ "display_name": "Python 3",
156
+ "language": "python",
157
+ "name": "python3"
158
+ },
159
+ "language_info": {
160
+ "codemirror_mode": {
161
+ "name": "ipython",
162
+ "version": 3
163
+ },
164
+ "file_extension": ".py",
165
+ "mimetype": "text/x-python",
166
+ "name": "python",
167
+ "nbconvert_exporter": "python",
168
+ "pygments_lexer": "ipython3",
169
+ "version": "3.8.16"
170
+ }
171
+ },
172
+ "nbformat": 4,
173
+ "nbformat_minor": 5
174
+ }
interfacegan_colour_disentanglement.ipynb ADDED
@@ -0,0 +1,520 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "3722712c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%matplotlib inline \n",
11
+ "\n",
12
+ "import pandas as pd\n",
13
+ "import pickle\n",
14
+ "import random\n",
15
+ "\n",
16
+ "from PIL import Image, ImageColor\n",
17
+ "import matplotlib.pyplot as plt\n",
18
+ "\n",
19
+ "import numpy as np\n",
20
+ "import torch\n",
21
+ "\n",
22
+ "from backend.disentangle_concepts import *\n",
23
+ "import dnnlib \n",
24
+ "import legacy\n",
25
+ "from backend.color_annotations import *\n",
26
+ "\n",
28
+ "\n",
29
+ "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
30
+ "\n",
31
+ "\n",
32
+ "%load_ext autoreload\n",
33
+ "%autoreload 2"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "5630402a",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "num_colors = 7"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "id": "00e57598",
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "values = [x*256/num_colors if x<num_colors else 256 for x in range(num_colors + 1)]\n",
54
+ "centers = [int((values[i-1]+values[i])/2) for i in range(len(values)) if i > 0]"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "1550ecd7",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "print(values)\n",
65
+ "print(centers)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "id": "ab9be91e",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "def create_color_image(hue, saturation, value, size=(20, 10)):\n",
76
+ " color_rgb = ImageColor.getrgb(\"hsv({}, {}%, {}%)\".format(hue, int(saturation * 100), int(value * 100)))\n",
77
+ " image = Image.new(\"RGB\", size, color_rgb)\n",
78
+ " return image"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "bf1c8ab5",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "def display_image(image, title=''):\n",
89
+ " plt.figure()\n",
90
+ " plt.suptitle(title)\n",
91
+ " plt.imshow(image)\n",
92
+ " plt.axis('off')\n",
93
+ " plt.show()"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "519b16d4",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "def to_256(val):\n",
104
+ " x = val*360/256\n",
105
+ " return x"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "8f696758",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "names = ['Red', 'Orange', 'Yellow', 'Yellow Green', 'Chartreuse Green',\n",
116
+ " 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',\n",
117
+ " 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "id": "50825823",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "saturation = 1 # Saturation value (0 to 1)\n",
128
+ "value = 1 # Value (brightness) value (0 to 1)\n",
129
+ "for hue, name in zip(centers, names[:num_colors]):\n",
130
+ " image = create_color_image(to_256(hue), saturation, value)\n",
131
+ " display_image(image, name) # Display the generated color image"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
142
+ "with open(annotations_file, 'rb') as f:\n",
143
+ " annotations = pickle.load(f)\n",
144
+ "\n",
145
+ "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
146
+ "\n",
147
+ "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
148
+ " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "cd114cb1",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "ann_df = tohsv(ann_df)\n",
159
+ "ann_df.head()"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "feb64168",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
170
+ "print(X.shape)\n",
171
+ "y_h = np.array(ann_df['H1'].values)\n",
172
+ "y_s = np.array(ann_df['S1'].values)\n",
173
+ "y_v = np.array(ann_df['S1'].values)"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "id": "0ca08749",
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "np.unique(y_h)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "id": "e8f33f14",
189
+ "metadata": {},
190
+ "source": [
191
+ "## Regression model"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "id": "8da0a43d",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "x_trainh, x_valh, y_trainh, y_valh = train_test_split(X, y_h, test_size=0.2)\n",
202
+ "x_trains, x_vals, y_trains, y_vals = train_test_split(X, y_s, test_size=0.2)\n",
203
+ "x_trainv, x_valv, y_trainv, y_valv = train_test_split(X, y_v, test_size=0.2)\n"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "id": "8eddba20",
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "regh = LinearRegression().fit(x_trainh, y_trainh)\n",
214
+ "print('Val performance logistic regression', np.round(regh.score(x_valh, y_valh),2))\n",
215
+ "\n",
216
+ "separation_vectorh = regh.coef_ / np.linalg.norm(regh.coef_)\n",
217
+ "print(separation_vectorh.shape)\n",
218
+ "\n",
219
+ "regs = LinearRegression().fit(x_trains, y_trains)\n",
220
+ "print('Val performance logistic regression', np.round(regs.score(x_vals, y_vals),2))\n",
221
+ "\n",
222
+ "separation_vectors = regs.coef_ / np.linalg.norm(regs.coef_)\n",
223
+ "print(separation_vectors.shape)\n",
224
+ "\n",
225
+ "regv = LinearRegression().fit(x_trainv, y_trainv)\n",
226
+ "print('Val performance logistic regression', np.round(reg.score(x_valv, y_valv),2))\n",
227
+ "\n",
228
+ "separation_vectorv = regv.coef_ / np.linalg.norm(regv.coef_)\n",
229
+ "print(separation_vectorv.shape)\n"
230
+ ]
231
+ },
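+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "4e5f6a7b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch of the editing step (added illustration): an InterfaceGAN-style\n",
+ "# edit moves a latent code along the normalised regression direction,\n",
+ "# w' = w + lambda * v. The backend's regenerate_images is assumed to do this\n",
+ "# for a sweep of lambda values before synthesising each image.\n",
+ "w_example = np.asarray(annotations['w_vectors'][0]).reshape(1, 512)\n",
+ "for lam in [-2.0, 0.0, 2.0]:\n",
+ " w_edited = w_example + lam * separation_vectorh.reshape(1, 512)\n",
+ " print('lambda', lam, '-> shift norm', np.round(np.linalg.norm(w_edited - w_example), 3))"
+ ]
+ },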
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "c6a63345",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "seed = random.randint(0,100000)\n",
240
+ "original_image_vec = annotations['w_vectors'][seed]\n",
241
+ "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
242
+ "img"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "id": "09f13e6a",
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "images, lambdas = regenerate_images(model, original_image_vec, separation_vectors, min_epsilon=-(int(5)), max_epsilon=int(5), count=7, latent_space='W')"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "id": "c66bcdde",
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
263
+ "for i,im in enumerate(images):\n",
264
+ " axs[i].imshow(im)\n",
265
+ " axs[i].set_title(np.round(lambdas[i], 2))"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "markdown",
270
+ "id": "4c44f0dd",
271
+ "metadata": {},
272
+ "source": [
273
+ "fourier per regolarità pattern\n",
274
+ "linear correlation con il colore\n",
275
+ "distribution dei colori original e non \n",
276
+ "neural network per vedere quanto riesce a classificare"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "markdown",
281
+ "id": "c2790c25",
282
+ "metadata": {},
283
+ "source": [
284
+ "## Multiclass model"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "afa0c100",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',\n",
295
+ " 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
296
+ " 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "markdown",
301
+ "id": "39a5668a",
302
+ "metadata": {},
303
+ "source": [
304
+ "double check colori"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": null,
310
+ "id": "5f2b48c0",
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "from sklearn import svm\n",
315
+ "\n",
316
+ "print([int(x*256/12) if x<12 else 255 for x in range(13)])\n",
317
+ "y_h_cat = pd.cut(y_h,bins=[x*256/12 if x<12 else 256 for x in range(13)],labels=colors_list).fillna('Warm Pink Red')\n",
318
+ "\n",
319
+ "print(y_h_cat.value_counts(dropna=False))\n",
320
+ "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "markdown",
325
+ "id": "67651454",
326
+ "metadata": {},
327
+ "source": [
328
+ "### SVR and LR"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "id": "7804f593",
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "clf = svm.LinearSVC().fit(x_trainhc, y_trainhc)\n",
339
+ "print('Val performance SVR regression', np.round(clf.score(x_valhc, y_valhc),2))"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "id": "e6e31b75",
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "clf_log = LogisticRegression(multi_class='ovr').fit(x_trainhc, y_trainhc)\n",
350
+ "print('Val performance logistic regression', np.round(clf_log.score(x_valhc, y_valhc),2))"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "82e30f0c",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "seed = random.randint(0,100000)\n",
361
+ "original_image_vec = annotations['w_vectors'][seed]\n",
362
+ "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
363
+ "img"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": null,
369
+ "id": "c8ce6086",
370
+ "metadata": {},
371
+ "outputs": [],
372
+ "source": [
373
+ "from sklearn.metrics import accuracy_score, confusion_matrix \n",
374
+ "\n",
375
+ "y_predhc = clf.predict(x_valhc)\n",
376
+ "print(y_predhc, y_valhc)\n",
377
+ "accuracy_score(y_valhc, y_predhc,)\n",
378
+ "\n",
379
+ "\n",
380
+ "#Get the confusion matrix\n",
381
+ "cm = confusion_matrix(y_valhc, y_predhc)\n",
382
+ "#array([[1, 0, 0],\n",
383
+ "# [1, 0, 0],\n",
384
+ "# [0, 1, 2]])\n",
385
+ "\n",
386
+ "#Now the normalize the diagonal entries\n",
387
+ "cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
388
+ "#array([[1. , 0. , 0. ],\n",
389
+ "# [1. , 0. , 0. ],\n",
390
+ "# [0. , 0.33333333, 0.66666667]])\n",
391
+ "\n",
392
+ "#The diagonal entries are the accuracies of each class\n",
393
+ "cm.diagonal()\n",
394
+ "#array([1. , 0. , 0.66666667])"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "112f4b87",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": [
404
+ "print(clf.coef_, clf.coef_.shape)"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": null,
410
+ "id": "6241bce1",
411
+ "metadata": {},
412
+ "outputs": [],
413
+ "source": [
414
+ "warm_blue = clf.coef_[-3, :] / np.linalg.norm(clf.coef_[-3, :])\n",
415
+ "\n",
416
+ "images, lambdas = regenerate_images(model, original_image_vec, warm_blue, min_epsilon=-(int(5)), max_epsilon=int(5), count=7, latent_space='W')\n",
417
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
418
+ "for i,im in enumerate(images):\n",
419
+ " axs[i].imshow(im)\n",
420
+ " axs[i].set_title(np.round(lambdas[i], 2))"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "id": "2fefcf0c",
427
+ "metadata": {},
428
+ "outputs": [],
429
+ "source": [
430
+ "warm_blue = clf.coef_[-4, :] / np.linalg.norm(clf.coef_[-4, :])\n",
431
+ "\n",
432
+ "images, lambdas = regenerate_images(model, original_image_vec, warm_blue, min_epsilon=-(int(50)), max_epsilon=int(50), count=2, latent_space='W')\n",
433
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
434
+ "for i,im in enumerate(images):\n",
435
+ " axs[i].imshow(im)\n",
436
+ " axs[i].set_title(np.round(lambdas[i], 2))"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "id": "e0f31e0b",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "from sklearn import svm\n",
447
+ "\n",
448
+ "y_h_cat = pd.cut(y_h,bins=[x*256/6 if x<6 else 256 for x in range(7)],labels=['Red', 'Yellow', 'Green', 'Blue',\n",
449
+ " 'Purple', 'Pink']).fillna('Red')\n",
450
+ "\n",
451
+ "print(y_h_cat.value_counts(dropna=False))\n",
452
+ "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)\n",
453
+ "\n",
454
+ "clf6 = svm.LinearSVC().fit(x_trainhc, y_trainhc)\n",
455
+ "print('Val performance logistic regression', np.round(clf6.score(x_valhc, y_valhc),2))\n"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": null,
461
+ "id": "f5f28b41",
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": [
465
+ "warm_blue = clf6.coef_[1, :] / np.linalg.norm(clf6.coef_[1, :])\n",
466
+ "\n",
467
+ "images, lambdas = regenerate_images(model, original_image_vec, warm_blue, min_epsilon=-(int(10)), max_epsilon=int(10), count=7, latent_space='W')\n",
468
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
469
+ "for i,im in enumerate(images):\n",
470
+ " axs[i].imshow(im)\n",
471
+ " axs[i].set_title(np.round(lambdas[i], 2))"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "id": "4e0c7808",
477
+ "metadata": {},
478
+ "source": [
479
+ "## dimensionality reduction e vediamo dove finiscono i vari colori"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "markdown",
484
+ "id": "833ed31f",
485
+ "metadata": {},
486
+ "source": [
487
+ "## clustering per vedere quali sono i centroid di questo spazio e se ci sono regioni determinate dai colori"
488
+ ]
489
+ },
490
+ {
491
+ "cell_type": "code",
492
+ "execution_count": null,
493
+ "id": "7c19e820",
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": []
497
+ }
498
+ ],
499
+ "metadata": {
500
+ "kernelspec": {
501
+ "display_name": "Python 3",
502
+ "language": "python",
503
+ "name": "python3"
504
+ },
505
+ "language_info": {
506
+ "codemirror_mode": {
507
+ "name": "ipython",
508
+ "version": 3
509
+ },
510
+ "file_extension": ".py",
511
+ "mimetype": "text/x-python",
512
+ "name": "python",
513
+ "nbconvert_exporter": "python",
514
+ "pygments_lexer": "ipython3",
515
+ "version": "3.8.16"
516
+ }
517
+ },
518
+ "nbformat": 4,
519
+ "nbformat_minor": 5
520
+ }
structure_annotations.ipynb ADDED
@@ -0,0 +1,382 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os \n",
10
+ "from glob import glob \n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "\n",
14
+ "from PIL import Image, ImageColor\n",
15
+ "import matplotlib.pyplot as plt\n",
16
+ "\n",
17
+ "import torch\n",
18
+ "\n",
19
+ "from backend.disentangle_concepts import *\n",
20
+ "import dnnlib \n",
21
+ "import legacy\n",
22
+ "from backend.color_annotations import *\n",
23
+ "\n",
24
+ "\n",
25
+ "%load_ext autoreload\n",
26
+ "%autoreload 2"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "images_textiles = glob('/Users/ludovicaschaerf/Desktop/TextAIles/TextileGAN/Original Textiles/*')"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "images = []\n",
45
+ "\n",
46
+ "for step in [10, 20]:\n",
47
+ " imh = np.zeros((200, 200))\n",
48
+ " imv = np.zeros((200, 200))\n",
49
+ " imb = np.ones((200, 200))\n",
50
+ " imb2 = np.ones((200, 200))\n",
51
+ " \n",
52
+ " for x,y in zip(range(0,200, step*2),range(step, 200, step*2)):\n",
53
+ " imh[x:y, :] = 255\n",
54
+ " imv[:, x:y] = 255\n",
55
+ " imb[x:y, :] = 0\n",
56
+ " imb[:, x:y] = 0\n",
57
+ " imb2[x:y, :] = 0\n",
58
+ " imb2[:, x*2:y*2] = 0\n",
59
+ " \n",
60
+ " images.append(imh) \n",
61
+ " images.append(imb) \n",
62
+ " images.append(imb2) \n",
63
+ " images.append(imv) \n",
64
+ "\n",
65
+ "for im in images:\n",
66
+ " plt.imshow(im, cmap='gray')\n",
67
+ " plt.title('Original Image')\n",
68
+ " plt.show()"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "def get_freqs(image):\n",
78
+ " \n",
79
+ " # Load the image\n",
80
+ " # image = cv2.imread(image, cv2.IMREAD_GRAYSCALE)\n",
81
+ " # Center crop the image to remove black borders\n",
82
+ " height, width = image.shape\n",
83
+ " short_side = min(height, width)\n",
84
+ " crop_size = min(200, short_side)\n",
85
+ " center_x = width // 2\n",
86
+ " center_y = height // 2\n",
87
+ " crop_half_size = crop_size // 2\n",
88
+ " image = image[center_y - crop_half_size:center_y + crop_half_size,\n",
89
+ " center_x - crop_half_size:center_x + crop_half_size]\n",
90
+ "\n",
91
+ " # kernel = np.ones((5,5), np.float32) / 25\n",
92
+ " # image = cv2.filter2D( image, -1, kernel)\n",
93
+ " # image = cv2.GaussianBlur(image,(5,5),0)\n",
94
+ " # image = image - cv2.GaussianBlur(image, (21, 21), 1) + 127\n",
95
+ "\n",
96
+ " # print(np.unique(image))\n",
97
+ " # Perform 1D Fourier Transforms\n",
98
+ " horizontal_freq = np.fft.fftshift(np.fft.fft(image, axis=1))\n",
99
+ " vertical_freq = np.fft.fftshift(np.fft.fft(image, axis=0))\n",
100
+ " twod = np.fft.fftshift(np.fft.fft2(image))\n",
101
+ " \n",
102
+ " # Calculate corresponding frequencies\n",
103
+ " num_cols = image.shape[1]\n",
104
+ " num_rows = image.shape[0]\n",
105
+ " # horizontal_freqs = np.fft.fftshift(np.fft.fftfreq(num_cols))\n",
106
+ " # vertical_freqs = np.fft.fftshift(np.fft.fftfreq(num_rows))\n",
107
+ " \n",
108
+ " horizontal_freqs = np.fft.fftfreq(num_cols)\n",
109
+ " vertical_freqs = np.fft.fftfreq(num_rows)\n",
110
+ " \n",
111
+ " # # Sum power along the second axis\n",
112
+ " # twod = twod.real*twod.real + twod.imag*twod.imag\n",
113
+ " # twod = twod.sum(axis=1)/twod.shape[1]\n",
114
+ "\n",
115
+ " # # Round up the size along this axis to an even number\n",
116
+ " # n = int(math.ceil(image.shape[0] / 2.) * 2 )\n",
117
+ "\n",
118
+ " # # Generate a list of frequencies\n",
119
+ " # f = np.fft.fftfreq(n)\n",
120
+ "\n",
121
+ " # # Graph it\n",
122
+ " # plt.plot(f[1:],a[1:], label = 'sum of amplitudes over y vs f_x')\n",
123
+ "\n",
124
+ " \n",
125
+ "\n",
126
+ " # Calculate magnitude spectra\n",
127
+ " horizontal_magnitudes = np.abs(horizontal_freq)[0]\n",
128
+ " vertical_magnitudes = np.abs(vertical_freq)[0]\n",
129
+ " \n",
130
+ " # # Find peaks in the magnitudes\n",
131
+ " # horizontal_peaks, _ = find_peaks(horizontal_magnitudes, height=100) # Adjust height threshold as needed\n",
132
+ " # vertical_peaks, _ = find_peaks(vertical_magnitudes, height=10) # Adjust height threshold as needed\n",
133
+ "\n",
134
+ " print('Median horizontal frequency', np.median(horizontal_magnitudes), ', median vertical frequency', np.median(vertical_magnitudes))\n",
135
+ " # Plot frequency analysis with peaks\n",
136
+ " plt.figure(figsize=(25, 5))\n",
137
+ "\n",
138
+ " # Plot frequency analysis\n",
139
+ " plt.subplot(1, 6, 1)\n",
140
+ " plt.imshow(image, cmap='gray')\n",
141
+ " plt.title('Original Image')\n",
142
+ "\n",
143
+ " plt.subplot(1, 6, 2)\n",
144
+ " plt.imshow(np.log(1 + np.abs(horizontal_freq)), cmap='gray')\n",
145
+ " plt.title('Horizontal Frequency Analysis')\n",
146
+ "\n",
147
+ " plt.subplot(1, 6, 3)\n",
148
+ " plt.imshow(np.log(1 + np.abs(vertical_freq)), cmap='gray')\n",
149
+ " plt.title('Vertical Frequency Analysis')\n",
150
+ "\n",
151
+ " plt.subplot(1, 6, 4)\n",
152
+ " plt.imshow(np.log(1 + np.abs(twod)), cmap='gray')\n",
153
+ " plt.title('2D Frequency Analysis')\n",
154
+ "\n",
155
+ " plt.subplot(1, 6, 5)\n",
156
+ " plt.scatter(horizontal_freqs[1:], horizontal_magnitudes[1:], s=5)\n",
157
+ " # plt.plot(horizontal_freqs[horizontal_peaks], horizontal_magnitudes[horizontal_peaks], 'ro', markersize=5)\n",
158
+ " plt.xlabel('Horizontal Frequency')\n",
159
+ " plt.ylabel('Magnitude')\n",
160
+ " plt.title('Horizontal Frequency Analysis with Peaks')\n",
161
+ "\n",
162
+ " plt.subplot(1, 6, 6)\n",
163
+ " plt.scatter(vertical_freqs[1:], vertical_magnitudes[1:], s=5)\n",
164
+ " # plt.plot(vertical_freqs[vertical_peaks], vertical_magnitudes[vertical_peaks], 'ro', markersize=5)\n",
165
+ " plt.xlabel('Vertical Frequency')\n",
166
+ " plt.ylabel('Magnitude')\n",
167
+ " plt.title('Vertical Frequency Analysis with Peaks')\n",
168
+ "\n",
169
+ "\n",
170
+ " plt.show()\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {},
176
+ "source": [
177
+ "When the absolute value of the frequency is high in the context of Fourier analysis, it indicates the presence of a rapidly changing or repeating pattern in the image. The frequency component's absolute value reflects how many cycles of the pattern occur within a fixed interval.\n",
178
+ "\n",
179
+ "In Fourier analysis:\n",
180
+ "\n",
181
+ "Higher Frequency Components: Higher absolute frequency values correspond to faster changes or repetitions in the image data. This means that the pattern or feature represented by that frequency component oscillates more rapidly across the image.\n",
182
+ "\n",
183
+ "Spatial Frequency: In image analysis, frequency is often associated with how quickly the intensity or color changes as you move across the image. High absolute frequency values indicate rapid changes or transitions in the image content.\n",
184
+ "\n",
185
+ "Fine Details and Textures: Rapidly oscillating frequency components are associated with fine details and textures in the image. For example, in textiles, high-frequency components might capture the intricate weave patterns or small-scale textures.\n",
186
+ "\n",
187
+ "Small-Scale Features: Patterns that repeat at small scales, such as fine lines or tiny structures, tend to result in high-frequency components with high absolute values.\n",
188
+ "\n",
189
+ "Edges and Transitions: Edges and sharp transitions between different colors or intensities are also associated with high-frequency components. These transitions involve rapid changes in intensity, leading to higher absolute frequency values."
190
+ ]
191
+ },
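+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal numeric illustration of the point above (added sketch): a stripe\n",
+ "# pattern that repeats every `period` pixels concentrates its energy at\n",
+ "# frequency 1/period, so finer stripes peak at a higher absolute frequency.\n",
+ "for period in [40, 10]:\n",
+ " row = np.where((np.arange(200) // (period // 2)) % 2 == 0, 255.0, 0.0)\n",
+ " mags = np.abs(np.fft.fft(row))\n",
+ " freqs = np.fft.fftfreq(200)\n",
+ " peak = np.argmax(mags[1:100]) + 1 # skip the DC component\n",
+ " print(f'period {period}px -> dominant |frequency| {abs(freqs[peak]):.3f} (expected {1/period:.3f})')"
+ ]
+ },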
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "import cv2\n",
199
+ "from scipy.signal import find_peaks\n",
200
+ "\n",
201
+ "import numpy as np\n",
202
+ "import matplotlib.pyplot as plt\n",
203
+ "\n",
204
+ "\n",
205
+ "for image in images:\n",
206
+ " get_freqs(image)"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": [
215
+ "import cv2\n",
216
+ "import numpy as np\n",
217
+ "import matplotlib.pyplot as plt\n",
218
+ "from scipy.signal import find_peaks\n",
219
+ "from collections import Counter\n",
220
+ "import seaborn as sns\n",
221
+ "# Load the image\n",
222
+ "image = cv2.imread(images_textiles[0], cv2.IMREAD_GRAYSCALE)\n",
223
+ "\n",
224
+ "# Perform 1D Fourier Transforms\n",
225
+ "horizontal_freq = np.fft.fftshift(np.fft.fft(image, axis=1))\n",
226
+ "vertical_freq = np.fft.fftshift(np.fft.fft(image, axis=0))\n",
227
+ "\n",
228
+ "# Calculate magnitude spectra\n",
229
+ "horizontal_magnitudes = np.abs(horizontal_freq)[0]\n",
230
+ "vertical_magnitudes = np.abs(vertical_freq)[0]\n",
231
+ "\n",
232
+ "# Create histograms of magnitude recurrence\n",
233
+ "horizontal_magnitude_counter = Counter(np.round(horizontal_magnitudes).astype(int))\n",
234
+ "vertical_magnitude_counter = Counter(np.round(vertical_magnitudes).astype(int))\n",
235
+ "\n",
236
+ "# Plot magnitude recurrence histograms\n",
237
+ "plt.figure(figsize=(12, 5))\n",
238
+ "\n",
239
+ "plt.subplot(1, 2, 1)\n",
240
+ "sns.histplot(list(horizontal_magnitude_counter.elements()), kde=True, color='blue')\n",
241
+ "plt.xlim(0, 1000) # Limit x-axis range\n",
242
+ "plt.xlabel('Magnitude')\n",
243
+ "plt.ylabel('Recurrence')\n",
244
+ "plt.title('Horizontal Magnitude Recurrence')\n",
245
+ "\n",
246
+ "plt.subplot(1, 2, 2)\n",
247
+ "sns.histplot(list(vertical_magnitude_counter.elements()), kde=True, color='green')\n",
248
+ "plt.xlim(0, 1000) # Limit x-axis range\n",
249
+ "plt.xlabel('Magnitude')\n",
250
+ "plt.ylabel('Recurrence')\n",
251
+ "plt.title('Vertical Magnitude Recurrence')\n",
252
+ "\n",
253
+ "plt.tight_layout()\n",
254
+ "plt.show()\n"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "import cv2\n",
264
+ "import numpy as np\n",
265
+ "import matplotlib.pyplot as plt\n",
266
+ "from skimage.feature import hog\n",
267
+ "from skimage import exposure\n",
268
+ "\n",
269
+ "# Load the image\n",
270
+ "image = cv2.imread(images_textiles[4], cv2.IMREAD_GRAYSCALE)\n",
271
+ "\n",
272
+ "# Center crop the image to remove black borders\n",
273
+ "height, width = image.shape\n",
274
+ "crop_size = 200\n",
275
+ "center_x = width // 2\n",
276
+ "center_y = height // 2\n",
277
+ "crop_half_size = crop_size // 2\n",
278
+ "cropped_image = image[center_y - crop_half_size:center_y + crop_half_size,\n",
279
+ " center_x - crop_half_size:center_x + crop_half_size]\n",
280
+ "\n",
281
+ "# Compute HOG features\n",
282
+ "orientations = 9\n",
283
+ "pixels_per_cell = (5, 5)\n",
284
+ "cells_per_block = (1,1)\n",
285
+ "hog_features, hog_image = hog(cropped_image, orientations=orientations,\n",
286
+ " pixels_per_cell=pixels_per_cell,\n",
287
+ " cells_per_block=cells_per_block,\n",
288
+ " block_norm='L2-Hys',\n",
289
+ " visualize=True)\n",
290
+ "\n",
291
+ "# Plot the HOG image\n",
292
+ "plt.figure(figsize=(8, 4))\n",
293
+ "plt.subplot(1, 2, 1)\n",
294
+ "plt.imshow(cropped_image, cmap='gray')\n",
295
+ "plt.title('Original Image')\n",
296
+ "\n",
297
+ "plt.subplot(1, 2, 2)\n",
298
+ "plt.imshow(hog_image, cmap='gray')\n",
299
+ "plt.title('HOG Image')\n",
300
+ "plt.axis('off')\n",
301
+ "\n",
302
+ "plt.tight_layout()\n",
303
+ "plt.show()\n",
304
+ "\n",
305
+ "# Plot the orientation distribution\n",
306
+ "hog_histogram, _ = np.histogram(hog_features, bins=orientations)\n",
307
+ "angles_deg = np.arange(0, 180, 180/orientations)\n",
308
+ "plt.figure(figsize=(6, 4))\n",
309
+ "plt.bar(angles_deg, hog_histogram, width=180/orientations)\n",
310
+ "plt.xlabel('Orientation (degrees)')\n",
311
+ "plt.ylabel('Frequency')\n",
312
+ "plt.title('Orientation Distribution')\n",
313
+ "plt.xticks(angles_deg)\n",
314
+ "plt.show()"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": [
323
+ "import cv2\n",
324
+ "import numpy as np\n",
325
+ "import matplotlib.pyplot as plt\n",
326
+ "import pywt\n",
327
+ "\n",
328
+ "# Load the image\n",
329
+ "image = cv2.imread(images_textiles[0], cv2.IMREAD_GRAYSCALE)\n",
330
+ "\n",
331
+ "# Perform Continuous Wavelet Transform (CWT) on each row of the image\n",
332
+ "wavelet = 'morl' # You can choose a different wavelet basis\n",
333
+ "scales = np.arange(1, 20) # Scales to analyze\n",
334
+ "\n",
335
+ "cwt_coeffs = []\n",
336
+ "for row in image:\n",
337
+ " coeffs, _ = pywt.cwt(row, scales, wavelet)\n",
338
+ " cwt_coeffs.append(coeffs)\n",
339
+ "\n",
340
+ "cwt_coeffs = np.array(cwt_coeffs)\n",
341
+ "\n",
342
+ "# Plot scaleograms\n",
343
+ "plt.figure(figsize=(10, 6))\n",
344
+ "plt.imshow(np.abs(cwt_coeffs[:, 0, :]), extent=[0, image.shape[1], scales[-1], scales[0]], cmap='viridis', aspect='auto')\n",
345
+ "plt.colorbar(label='Magnitude')\n",
346
+ "plt.title('Wavelet Scaleogram')\n",
347
+ "plt.xlabel('Pixel')\n",
348
+ "plt.ylabel('Scale')\n",
349
+ "plt.show()\n"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": []
358
+ }
359
+ ],
360
+ "metadata": {
361
+ "kernelspec": {
362
+ "display_name": "art-reco_x86",
363
+ "language": "python",
364
+ "name": "python3"
365
+ },
366
+ "language_info": {
367
+ "codemirror_mode": {
368
+ "name": "ipython",
369
+ "version": 3
370
+ },
371
+ "file_extension": ".py",
372
+ "mimetype": "text/x-python",
373
+ "name": "python",
374
+ "nbconvert_exporter": "python",
375
+ "pygments_lexer": "ipython3",
376
+ "version": "3.8.16"
377
+ },
378
+ "orig_nbformat": 4
379
+ },
380
+ "nbformat": 4,
381
+ "nbformat_minor": 2
382
+ }
stylespace_colour_disentanglement.ipynb ADDED
@@ -0,0 +1,580 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "fe83bcc2",
6
+ "metadata": {},
7
+ "source": [
8
+ "![image](/Users/ludovicaschaerf/Desktop/latent-space-theories/data/stylegan3.webp)"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "3722712c",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "%matplotlib inline \n",
19
+ "\n",
20
+ "import pandas as pd\n",
21
+ "import pickle\n",
22
+ "import random\n",
23
+ "\n",
24
+ "from PIL import Image, ImageColor\n",
25
+ "import matplotlib.pyplot as plt\n",
26
+ "\n",
27
+ "import numpy as np\n",
28
+ "import torch\n",
29
+ "\n",
30
+ "from backend.disentangle_concepts import *\n",
31
+ "from backend.color_annotations import *\n",
32
+ "from backend.networks_stylegan3 import *\n",
33
+ "import dnnlib \n",
34
+ "import legacy\n",
35
+ "\n",
37
+ "\n",
38
+ "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
39
+ "\n",
40
+ "\n",
41
+ "%load_ext autoreload\n",
42
+ "%autoreload 2"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
53
+ "with open(annotations_file, 'rb') as f:\n",
54
+ " annotations = pickle.load(f)\n",
55
+ "\n",
56
+ "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
57
+ "\n",
58
+ "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
59
+ " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "0e4b656e",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "z = torch.from_numpy(annotations['w_vectors'][0].copy()).to('cpu')\n",
70
+ "W = z.expand((16, -1)).unsqueeze(0)\n",
71
+ "img = model.synthesis(W, noise_mode='const')\n",
72
+ "img.shape"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "1259f950",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "in_ = model.synthesis.input(W[0, 0].unsqueeze(0))\n",
83
+ "l1 = model.synthesis.L0_36_512(in_, W[0, 1].unsqueeze(0))\n",
84
+ "l1.shape"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "918feb0e",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "a = 'L0_36_512'\n",
95
+ "getattr(model.synthesis, a)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "id": "1bf7bfa4",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def rest_from_style(x, styles, layer):\n",
106
+ " dtype = torch.float16 if (getattr(model.synthesis, layer).use_fp16 and device=='cuda') else torch.float32\n",
107
+ " if getattr(model.synthesis, layer).is_torgb:\n",
108
+ " print(layer, getattr(model.synthesis, layer).is_torgb)\n",
109
+ " weight_gain = 1 / np.sqrt(getattr(model.synthesis, layer).in_channels * (getattr(model.synthesis, layer).conv_kernel ** 2))\n",
110
+ " styles = styles * weight_gain\n",
111
+ " input_gain = getattr(model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)\n",
112
+ " # Execute modulated conv2d.\n",
113
+ " x = modulated_conv2d(x=x.to(dtype), w=getattr(model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),\n",
114
+ " padding=getattr(model.synthesis, layer).conv_kernel-1, demodulate=(not getattr(model.synthesis, layer).is_torgb), input_gain=input_gain.to(dtype))\n",
115
+ " # Execute bias, filtered leaky ReLU, and clamping.\n",
116
+ " gain = 1 if getattr(model.synthesis, layer).is_torgb else np.sqrt(2)\n",
117
+ " slope = 1 if getattr(model.synthesis, layer).is_torgb else 0.2\n",
118
+ " x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(model.synthesis, layer).up_filter, fd=getattr(model.synthesis, layer).down_filter, \n",
119
+ " b=getattr(model.synthesis, layer).bias.to(x.dtype),\n",
120
+ " up=getattr(model.synthesis, layer).up_factor, down=getattr(model.synthesis, layer).down_factor, \n",
121
+ " padding=getattr(model.synthesis, layer).padding,\n",
122
+ " gain=gain, slope=slope, clamp=getattr(model.synthesis, layer).conv_clamp)\n",
123
+ " return x"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "id": "c674780d",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "x1 = rest_from_style(in_, model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)), 'L0_36_512')\n",
134
+ "x1.shape"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "0305ce16",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "def getS(w):\n",
145
+ " w_torch = torch.from_numpy(w).to('cpu')\n",
146
+ " W = w_torch.expand((16, -1)).unsqueeze(0)\n",
147
+ " s = []\n",
148
+ " s.append(model.synthesis.input.affine(W[0, 0].unsqueeze(0)).numpy())\n",
149
+ " s.append(model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)).numpy())\n",
150
+ " s.append(model.synthesis.L1_36_512.affine(W[0, 2].unsqueeze(0)).numpy())\n",
151
+ " s.append(model.synthesis.L2_36_512.affine(W[0, 3].unsqueeze(0)).numpy())\n",
152
+ " s.append(model.synthesis.L3_52_512.affine(W[0, 4].unsqueeze(0)).numpy())\n",
153
+ " s.append(model.synthesis.L4_52_512.affine(W[0, 5].unsqueeze(0)).numpy())\n",
154
+ " s.append(model.synthesis.L5_84_512.affine(W[0, 6].unsqueeze(0)).numpy())\n",
155
+ " s.append(model.synthesis.L6_84_512.affine(W[0, 7].unsqueeze(0)).numpy())\n",
156
+ " s.append(model.synthesis.L7_148_512.affine(W[0, 8].unsqueeze(0)).numpy())\n",
157
+ " s.append(model.synthesis.L8_148_512.affine(W[0, 9].unsqueeze(0)).numpy())\n",
158
+ " s.append(model.synthesis.L9_148_362.affine(W[0, 10].unsqueeze(0)).numpy())\n",
159
+ " s.append(model.synthesis.L10_276_256.affine(W[0, 11].unsqueeze(0)).numpy())\n",
160
+ " s.append(model.synthesis.L11_276_181.affine(W[0, 12].unsqueeze(0)).numpy())\n",
161
+ " s.append(model.synthesis.L12_276_128.affine(W[0, 13].unsqueeze(0)).numpy())\n",
162
+ " s.append(model.synthesis.L13_256_128.affine(W[0, 14].unsqueeze(0)).numpy())\n",
163
+ " s.append(model.synthesis.L14_256_3.affine(W[0, 15].unsqueeze(0)).numpy())\n",
164
+ " return s"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "981f5215",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "s = getS(annotations['w_vectors'][0])"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "389ad35a",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "shapes = [512] + [x.shape[1] for x in s]\n",
185
+ "layers = ['w', 'input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512', 'L4_52_512', 'L5_84_512', 'L6_84_512',\n",
186
+ " 'L7_148_512', 'L8_148_512', 'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128', 'L13_256_128',\n",
187
+ " 'L14_256_3']\n",
188
+ "sum(shapes), shapes"
189
+ ]
190
+ },
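+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8c9d0e1f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Added helper sketch: the concatenated [w + per-layer styles] vector is indexed\n",
+ "# flat, so this maps a flat index back to its (layer name, within-layer channel).\n",
+ "# It mirrors the bookkeeping that get_original_pos further below relies on.\n",
+ "def locate_channel(flat_idx):\n",
+ " offset = 0\n",
+ " for leng, layer in zip(shapes, layers):\n",
+ " if flat_idx < offset + leng:\n",
+ " return layer, flat_idx - offset\n",
+ " offset += leng\n",
+ " raise IndexError('index beyond the concatenated style vector')\n",
+ "\n",
+ "print(locate_channel(0), locate_channel(600))"
+ ]
+ },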
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "3c143e86",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "def generate_flexible_images(w, change_vectors, lambdas=1, device='cpu'):\n",
199
+ " w_torch = torch.from_numpy(w).to('cpu')\n",
200
+ " # w_torch = w_torch + lambdas * change_vectors[0]\n",
201
+ " W = w_torch.expand((16, -1)).unsqueeze(0)\n",
202
+ " \n",
203
+ " x = model.synthesis.input(W[0,0].unsqueeze(0))\n",
204
+ " for i, layer in enumerate(layers):\n",
205
+ " if i < 2:\n",
206
+ " continue\n",
207
+ " style = getattr(model.synthesis, layer).affine(W[0, i-1].unsqueeze(0))\n",
208
+ " change = torch.from_numpy(change_vectors[i].copy()).unsqueeze(0).to(device)\n",
209
+ " style = torch.add(style, change, alpha=lambdas)\n",
210
+ " x = rest_from_style(x, style, layer)\n",
211
+ " \n",
212
+ " if model.synthesis.output_scale != 1:\n",
213
+ " x = x * model.synthesis.output_scale\n",
214
+ "\n",
215
+ " img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
216
+ " img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')\n",
217
+ " \n",
218
+ " return img"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "id": "f03915ff",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "def get_original_pos(top_positions):\n",
229
+ " current_idx = 0\n",
230
+ " vectors = []\n",
231
+ " for i, (leng, layer) in enumerate(zip(shapes, layers)):\n",
232
+ " arr = np.zeros(leng)\n",
233
+ " for top_position in top_positions:\n",
234
+ " if top_position >= current_idx and top_position < current_idx + leng:\n",
235
+ " arr[top_position - current_idx] = 1\n",
236
+ " arr = arr / (np.linalg.norm(arr) + 0.000001)\n",
237
+ " vectors.append(arr)\n",
238
+ " current_idx += leng\n",
239
+ " return vectors \n"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "id": "e76d836d",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "ss = []\n",
250
+ "for i in tqdm(range(len(annotations['w_vectors']))):\n",
251
+ " ss.append(getS(annotations['w_vectors'][i]))\n",
252
+ " \n",
253
+ "annotations['s_vectors'] = ss"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "id": "6ea1ca59",
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'\n",
264
+ "with open(annotations_file, 'wb') as f:\n",
265
+ " pickle.dump(annotations, f)\n"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "id": "12f78bdb",
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "len(ss)"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "id": "cd114cb1",
282
+ "metadata": {},
283
+ "outputs": [],
284
+ "source": [
285
+ "ann_df = tohsv(ann_df)\n",
286
+ "ann_df.head()"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "id": "0d470f83",
293
+ "metadata": {},
294
+ "outputs": [],
295
+ "source": [
296
+ "def getX(annotations, space='s'):\n",
297
+ " if space == 'x':\n",
298
+ " X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
299
+ " elif space == 's':\n",
300
+ " concat_v = []\n",
301
+ " for i in range(len(annotations['w_vectors'])):\n",
302
+ " concat_v.append(np.concatenate([annotations['w_vectors'][i]] + annotations['s_vectors'][i], axis=1))\n",
303
+ " \n",
304
+ " X = np.array(concat_v)\n",
305
+ " X = X[:, 0, :]\n",
306
+ " print(X.shape)\n",
307
+ " \n",
308
+ " return X\n",
309
+ " \n",
310
+ " "
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": null,
316
+ "id": "feb64168",
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "X = getX(annotations)\n",
321
+ "print(X.shape)\n",
322
+ "y_h = np.array(ann_df['H1'].values)\n",
323
+ "y_s = np.array(ann_df['S1'].values)\n",
324
+ "y_v = np.array(ann_df['S1'].values)"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "id": "afa0c100",
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
+ "colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',\n",
335
+ " 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
336
+ " 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "markdown",
341
+ "id": "39a5668a",
342
+ "metadata": {},
343
+ "source": [
344
+ "double check colori"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "id": "5f2b48c0",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "print([int(x*256/12) if x<12 else 255 for x in range(13)])\n",
355
+ "y_h_cat = pd.cut(y_h,bins=[x*256/12 if x<12 else 256 for x in range(13)],labels=colors_list).fillna('Warm Pink Red')\n",
356
+ "\n",
357
+ "print(y_h_cat.value_counts(dropna=False))\n",
358
+ "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "markdown",
363
+ "id": "6c2a1765",
364
+ "metadata": {},
365
+ "source": [
366
+ "### Variance based"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "id": "2be4202e",
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "positives = x_trainhc[np.where(y_trainhc == 'Warm Blue')]\n",
377
+ "print(positives.shape, x_trainhc.shape)\n",
378
+ "variations = detect_attribute_specific_channels(positives, x_trainhc, sign=True)\n",
379
+ "print(variations.shape, np.argmax(variations))"
380
+ ]
381
+ },
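+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a3b4c5d",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch of the variance-based idea (an assumption about what the\n",
+ "# backend's detect_attribute_specific_channels computes, in the spirit of\n",
+ "# StyleSpace): score each channel by how far the positive-class mean deviates\n",
+ "# from the overall mean, in units of the overall per-channel standard deviation.\n",
+ "def variance_based_scores(pos, all_x):\n",
+ " mu, sigma = all_x.mean(axis=0), all_x.std(axis=0) + 1e-6\n",
+ " return (pos.mean(axis=0) - mu) / sigma\n",
+ "\n",
+ "scores = variance_based_scores(positives, x_trainhc)\n",
+ "print(scores.shape, np.argmax(scores))"
+ ]
+ },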
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": null,
385
+ "id": "7d0c129d",
386
+ "metadata": {},
387
+ "outputs": [],
388
+ "source": [
389
+ "argsorted_vars = np.argsort(variations)[-5:]\n",
390
+ "sorted_vars = np.sort(variations)[-5:]\n",
391
+ "argsorted_vars, sorted_vars"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": null,
397
+ "id": "e2c2ed49",
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "original_pos = get_original_pos(argsorted_vars)"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "id": "82e30f0c",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "seed = random.randint(0,100000)\n",
412
+ "seed = 52722\n",
413
+ "original_image_vec = annotations['w_vectors'][seed]\n",
414
+ "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
415
+ "img"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": null,
421
+ "id": "cd71f2c8",
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "device = 'cpu'\n",
426
+ "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=-1)\n",
427
+ "img1"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": null,
433
+ "id": "abc5ac3f",
434
+ "metadata": {},
435
+ "outputs": [],
436
+ "source": [
437
+ "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=-2)\n",
438
+ "img1"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "id": "0602dcab",
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": [
448
+ "len(original_pos)"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": null,
454
+ "id": "d7eb412f",
455
+ "metadata": {},
456
+ "outputs": [],
457
+ "source": [
458
+ "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=1)\n",
459
+ "img1"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "id": "03161270",
466
+ "metadata": {},
467
+ "outputs": [],
468
+ "source": [
469
+ "seps, vals = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True, space='s')\n",
470
+ "vals[2].shape"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "id": "ae1016d6",
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "warm_pink_val = get_verification_score(0, seps[0], model, annotations, samples=10, latent_space='W')\n",
481
+ "warm_pink_val"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "id": "b412cb25",
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "warm_blue_val = get_verification_score(8, seps[8], model, annotations, samples=10, latent_space='W')\n",
492
+ "warm_blue_val"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": null,
498
+ "id": "6812cb6b",
499
+ "metadata": {},
500
+ "outputs": [],
501
+ "source": [
502
+ "seps, _ = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True)\n",
503
+ "\n",
504
+ "for sep, color in zip(seps, colors_list):\n",
505
+ " images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-(int(4)), max_epsilon=int(4), count=5, latent_space='W')\n",
506
+ " fig, axs = plt.subplots(1, len(images), figsize=(50,10))\n",
507
+ " fig.suptitle(color, fontsize=20)\n",
508
+ " for i,im in enumerate(images):\n",
509
+ " axs[i].imshow(im)\n",
510
+ " axs[i].set_title(np.round(lambdas[i], 2))\n",
511
+ " plt.show()"
512
+ ]
513
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": null,
536
+ "id": "b6c61fbb",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "separation_vector_onehot = np.zeros(512)\n",
541
+ "separation_vector_onehot[argsorted_vars] = 1\n",
542
+ "\n",
543
+ "images, lambdas = regenerate_images(model, original_image_vec, separation_vector_onehot, min_epsilon=-(int(10)), max_epsilon=int(10), count=7, latent_space='W')\n",
544
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
545
+ "for i,im in enumerate(images):\n",
546
+ " axs[i].imshow(im)\n",
547
+ " axs[i].set_title(np.round(lambdas[i], 2))"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": null,
553
+ "id": "7c19e820",
554
+ "metadata": {},
555
+ "outputs": [],
556
+ "source": []
557
+ }
558
+ ],
559
+ "metadata": {
560
+ "kernelspec": {
561
+ "display_name": "Python 3",
562
+ "language": "python",
563
+ "name": "python3"
564
+ },
565
+ "language_info": {
566
+ "codemirror_mode": {
567
+ "name": "ipython",
568
+ "version": 3
569
+ },
570
+ "file_extension": ".py",
571
+ "mimetype": "text/x-python",
572
+ "name": "python",
573
+ "nbconvert_exporter": "python",
574
+ "pygments_lexer": "ipython3",
575
+ "version": "3.8.16"
576
+ }
577
+ },
578
+ "nbformat": 4,
579
+ "nbformat_minor": 5
580
+ }
view_predictions.ipynb CHANGED
@@ -2,31 +2,10 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
7
  "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n",
14
- "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n"
15
- ]
16
- },
17
- {
18
- "ename": "ModuleNotFoundError",
19
- "evalue": "No module named 'sklearn'",
20
- "output_type": "error",
21
- "traceback": [
22
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
23
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
24
- "Cell \u001b[0;32mIn[1], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorch\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[39m#import ninja\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbackend\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mdisentangle_concepts\u001b[39;00m \u001b[39mimport\u001b[39;00m \u001b[39m*\u001b[39m\n\u001b[1;32m 9\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mdnnlib\u001b[39;00m \n\u001b[1;32m 10\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mlegacy\u001b[39;00m\n",
25
- "File \u001b[0;32m~/Desktop/latent-space-theories/backend/disentangle_concepts.py:2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mnp\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39msvm\u001b[39;00m \u001b[39mimport\u001b[39;00m SVC\n\u001b[1;32m 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlinear_model\u001b[39;00m \u001b[39mimport\u001b[39;00m LogisticRegression\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodel_selection\u001b[39;00m \u001b[39mimport\u001b[39;00m train_test_split\n",
26
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sklearn'"
27
- ]
28
- }
29
- ],
8
+ "outputs": [],
30
  "source": [
31
  "import pandas as pd\n",
32
  "import pickle\n",
@@ -171,19 +150,10 @@
171
  },
172
  {
173
  "cell_type": "code",
174
- "execution_count": 1,
175
  "id": "0eae840f",
176
  "metadata": {},
177
- "outputs": [
178
- {
179
- "name": "stderr",
180
- "output_type": "stream",
181
- "text": [
182
- "/Users/ludovicaschaerf/anaconda3/envs/torch_arm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
183
- " from .autonotebook import tqdm as notebook_tqdm\n"
184
- ]
185
- }
186
- ],
156
+ "outputs": [],
187
  "source": [
188
  "import open_clip\n",
189
  "import os\n",
@@ -194,7 +164,7 @@
194
  },
195
  {
196
  "cell_type": "code",
197
- "execution_count": 2,
198
  "id": "4d776015",
199
  "metadata": {},
200
  "outputs": [],
@@ -205,33 +175,10 @@
205
  },
206
  {
207
  "cell_type": "code",
208
- "execution_count": 3,
209
  "id": "e3f917a7",
210
  "metadata": {},
211
- "outputs": [
212
- {
213
- "name": "stdout",
214
- "output_type": "stream",
215
- "text": [
216
- "['Provenance ADRIA', 'Provenance AEGINA', 'Provenance AL MINA', 'Provenance ALICANTE', 'Provenance AMATHUS', 'Provenance AMPURIAS', 'Provenance APOLLONIA PONTICA', 'Provenance APULIA', 'Provenance ARGOLIS', 'Provenance ARGOS', 'Provenance ATHENS', 'Provenance ATHENS (?)', 'Provenance ATTICA', 'Provenance BARI', 'Provenance BENGHAZI', 'Provenance BEREZAN', 'Provenance BLACK SEA', 'Provenance BOEOTIA', 'Provenance BOLOGNA', 'Provenance CAPUA', 'Provenance CARIA', 'Provenance CARTHAGE', 'Provenance CHIUSI', 'Provenance CIVITAVECCHIA', 'Provenance CORINTH', 'Provenance CORSICA', 'Provenance CRETE', 'Provenance CRIMEA', 'Provenance CUMAE', 'Provenance CYCLADES', 'Provenance CYPRUS', 'Provenance CYRENAICA', 'Provenance CYRENE', 'Provenance EGYPT', 'Provenance ELIS', 'Provenance ENSERUNE', 'Provenance ETRURIA', 'Provenance EUBOEA', 'Provenance FALERII', 'Provenance GRANADA', 'Provenance GREECE', 'Provenance GREECE (?)', 'Provenance HISTRIA', 'Provenance ITALY', 'Provenance JAEN', 'Provenance KITION', 'Provenance LESBOS', 'Provenance LOCRI', 'Provenance LOCRIS', 'Provenance LYCIA', 'Provenance MACEDONIA', 'Provenance MARION', 'Provenance MARSEILLES', 'Provenance MARZABOTTO', 'Provenance METAPONTUM', 'Provenance MILETUS', 'Provenance NAPLES', 'Provenance NAUCRATIS', 'Provenance NOLA', 'Provenance NUMANA', 'Provenance OLD SMYRNA', 'Provenance ORVIETO', 'Provenance PAESTUM', 'Provenance PHOCIS', 'Provenance PITANE', 'Provenance POPULONIA', 'Provenance RHODES', 'Provenance ROME', 'Provenance RUVO', 'Provenance SALERNO', 'Provenance SAMOS', 'Provenance SARDINIA', 'Provenance SARDIS', 'Provenance SICILY', 'Provenance SMYRNA', 'Provenance SOUTH', 'Provenance SPARTA', 'Provenance SPINA', 'Provenance SUESSULA', 'Provenance TARANTO', 'Provenance TELL DEFENNEH', 'Provenance THASOS', 'Provenance THERA', 'Provenance THESSALY', 'Provenance THRACE', 'Provenance TOCRA', 'Provenance TODI', 'Provenance ULLASTRET', 'Provenance VALENCIA', 'Shape Name ALABASTRON', 'Shape Name ALABASTRON FRAGMENT', 'Shape Name AMPHORA', 'Shape Name AMPHORA (?) FRAGMENT', 'Shape Name AMPHORISKOS', 'Shape Name ARYBALLOS', 'Shape Name ASKOS', 'Shape Name ASKOS FRAGMENT', 'Shape Name BOTTLE', 'Shape Name BOWL', 'Shape Name BOWL FRAGMENT', 'Shape Name CHALICE', 'Shape Name CHALICE FRAGMENT', 'Shape Name CHOUS', 'Shape Name CHOUS FRAGMENT', 'Shape Name CUP', 'Shape Name CUP (?) FRAGMENT', 'Shape Name CUP A', 'Shape Name CUP A FRAGMENT', 'Shape Name CUP A FRAGMENTS', 'Shape Name CUP B', 'Shape Name CUP B FRAGMENT', 'Shape Name CUP B FRAGMENTS', 'Shape Name CUP C', 'Shape Name CUP C FRAGMENT', 'Shape Name CUP C FRAGMENTS', 'Shape Name CUP DROOP', 'Shape Name CUP DROOP FRAGMENT', 'Shape Name CUP FRAGMENT', 'Shape Name CUP FRAGMENTS', 'Shape Name CUP KASSEL', 'Shape Name CUP KASSEL FRAGMENT', 'Shape Name CUP LITTLE MASTER BAND', 'Shape Name CUP LITTLE MASTER BAND (?) 
FRAGMENT', 'Shape Name CUP LITTLE MASTER BAND FRAGMENT', 'Shape Name CUP LITTLE MASTER BAND FRAGMENTS', 'Shape Name CUP LITTLE MASTER FRAGMENT', 'Shape Name CUP LITTLE MASTER LIP', 'Shape Name CUP LITTLE MASTER LIP FRAGMENT', 'Shape Name CUP LITTLE MASTER LIP FRAGMENTS', 'Shape Name CUP SIANA', 'Shape Name CUP SIANA FRAGMENT', 'Shape Name CUP SIANA FRAGMENTS', 'Shape Name CUP SKYPHOS', 'Shape Name CUP SKYPHOS FRAGMENT', 'Shape Name CUP STEMLESS', 'Shape Name CUP STEMLESS FRAGMENT', 'Shape Name CUP STEMLESS FRAGMENTS', 'Shape Name DINOS', 'Shape Name DINOS FRAGMENT', 'Shape Name DISH', 'Shape Name EPICHYSIS', 'Shape Name EPINETRON FRAGMENT', 'Shape Name FEEDER', 'Shape Name FIGURE VASE', 'Shape Name FIGURE VASE ARYBALLOS', 'Shape Name FIGURE VASE ASKOS', 'Shape Name FIGURE VASE FRAGMENT', 'Shape Name FIGURE VASE KANTHAROS', 'Shape Name FIGURE VASE LEKYTHOS', 'Shape Name FIGURE VASE OINOCHOE', 'Shape Name FISH-PLATE', 'Shape Name FLASK', 'Shape Name FRAGMENT', 'Shape Name FRAGMENTS', 'Shape Name GUTTUS', 'Shape Name HYDRIA', 'Shape Name HYDRIA (?) FRAGMENT', 'Shape Name HYDRIA FRAGMENT', 'Shape Name HYDRIA FRAGMENTS', 'Shape Name INCENSE BURNER', 'Shape Name JAR', 'Shape Name JUG', 'Shape Name KALATHOS', 'Shape Name KANTHAROS', 'Shape Name KANTHAROS FRAGMENT', 'Shape Name KERNOS', 'Shape Name KOTYLE', 'Shape Name KOTYLE FRAGMENT', 'Shape Name KRATER', 'Shape Name KRATER (?) FRAGMENT', 'Shape Name KRATER FRAGMENT', 'Shape Name KRATER FRAGMENTS', 'Shape Name KYATHOS', 'Shape Name KYATHOS FRAGMENT', 'Shape Name LEBES', 'Shape Name LEBES FRAGMENT', 'Shape Name LEKANIS', 'Shape Name LEKANIS FRAGMENT', 'Shape Name LEKANIS FRAGMENTS', 'Shape Name LEKANIS LID', 'Shape Name LEKANIS LID FRAGMENT', 'Shape Name LEKYTHOS', 'Shape Name LEKYTHOS FRAGMENT', 'Shape Name LEKYTHOS FRAGMENTS', 'Shape Name LID', 'Shape Name LID FRAGMENT', 'Shape Name LOUTROPHOROS', 'Shape Name LOUTROPHOROS FRAGMENT', 'Shape Name LOUTROPHOROS FRAGMENTS', 'Shape Name LYDION', 'Shape Name MASTOID', 'Shape Name MINIATURE PANATHENAIC AMPHORA', 'Shape Name MUG', 'Shape Name MUG FRAGMENT', 'Shape Name NESTORIS', 'Shape Name OINOCHOE', 'Shape Name OINOCHOE (?) FRAGMENT', 'Shape Name OINOCHOE FRAGMENT', 'Shape Name OINOCHOE FRAGMENTS', 'Shape Name OLLA', 'Shape Name OLPE', 'Shape Name OLPE FRAGMENT', 'Shape Name OLPE FRAGMENTS', 'Shape Name PELIKE', 'Shape Name PELIKE FRAGMENT', 'Shape Name PELIKE FRAGMENTS', 'Shape Name PHIALE', 'Shape Name PHIALE FRAGMENT', 'Shape Name PITCHER', 'Shape Name PLAQUE FRAGMENT', 'Shape Name PLAQUE FRAGMENTS', 'Shape Name PLATE', 'Shape Name PLATE FRAGMENT', 'Shape Name PLATE FRAGMENTS', 'Shape Name PLEMOCHOE', 'Shape Name PSEUDO-PANATHENAIC AMPHORA', 'Shape Name PSEUDO-PANATHENAIC AMPHORA FRAGMENT', 'Shape Name PSYKTER', 'Shape Name PYXIS', 'Shape Name PYXIS FRAGMENT', 'Shape Name PYXIS FRAGMENTS', 'Shape Name PYXIS LID', 'Shape Name PYXIS LID FRAGMENT', 'Shape Name RHYTON', 'Shape Name SKYPHOS', 'Shape Name SKYPHOS (?) 
FRAGMENT', 'Shape Name SKYPHOS FRAGMENT', 'Shape Name SKYPHOS FRAGMENTS', 'Shape Name STAMNOS', 'Shape Name STAMNOS FRAGMENT', 'Shape Name STAMNOS FRAGMENTS', 'Shape Name STAND', 'Shape Name STAND FRAGMENT', 'Shape Name STEMLESS CUP FRAGMENT', 'Shape Name STIRRUP JAR', 'Shape Name TANKARD', 'Shape Name UNKNOWN', 'Shape Name URN', 'Shape Name VARIOUS', 'Shape Name VASE', 'Fabric ARGIVE GEOMETRIC', 'Fabric ARRETINE', 'Fabric ATHENIAN', 'Fabric ATHENIAN (?)', 'Fabric ATHENIAN GEOMETRIC', 'Fabric ATHENIAN PROTOGEOMETRIC', 'Fabric BOEOTIAN', 'Fabric BOEOTIAN GEOMETRIC', 'Fabric CALENIAN', 'Fabric CAMPANA', 'Fabric CANOSAN', 'Fabric CHALCIDIAN', 'Fabric CORINTHIAN', 'Fabric CRETAN', 'Fabric CYCLADIC', 'Fabric CYPRIOT', 'Fabric CYPRIOT, BRONZE AGE', 'Fabric CYPRIOT, IRON AGE', 'Fabric CYPRIOT, MYCENAEAN STYLE', 'Fabric DAUNIAN', 'Fabric EAST GREEK', 'Fabric EAST GREEK GEOMETRIC', 'Fabric EAST GREEK, CLAZOMENIAN', 'Fabric EAST GREEK, FIKELLURA', 'Fabric EAST GREEK, NAUCRATITE', 'Fabric EGYPTIAN', 'Fabric ETRUSCAN', 'Fabric ETRUSCO-CORINTHIAN', 'Fabric FALISCAN', 'Fabric GALLIC', 'Fabric GALLO-ROMAN', 'Fabric GREEK', 'Fabric HELLADIC', 'Fabric HELLENISTIC', 'Fabric IBERIAN', 'Fabric IONIAN', 'Fabric ITALIOTE', 'Fabric ITALO-CORINTHIAN', 'Fabric ITALO-GEOMETRIC', 'Fabric LACONIAN', 'Fabric LACONIAN GEOMETRIC', 'Fabric MESSAPIAN', 'Fabric MINOAN', 'Fabric MYCENEAN', 'Fabric PONTIC', 'Fabric PROTO-ELAMITE', 'Fabric PROTOATTIC', 'Fabric PROTOCORINTHIAN', 'Fabric PROTOCORINTHIAN, TRANSITIONAL', 'Fabric ROMAN', 'Fabric SOUTH ITALIAN', 'Fabric SOUTH ITALIAN, APULIAN', 'Fabric SOUTH ITALIAN, CAMPANIAN', 'Fabric SOUTH ITALIAN, GNATHIAN', 'Fabric SOUTH ITALIAN, LUCANIAN', 'Fabric SOUTH ITALIAN, PAESTAN', 'Fabric SOUTH ITALIAN, SICILIAN', 'Fabric TERRA SIGILLATA', 'Fabric UNCERTAIN', 'Fabric VEIAN', 'Fabric VILLA NOVA', 'Technique ADDED COLOUR', 'Technique BLACK GLAZE', 'Technique BLACK PATTERN', 'Technique BLACK-FIGURE', 'Technique BROWN GLAZE', 'Technique BUCCHERO', 'Technique IMPASTO', 'Technique OUTLINE', 'Technique PATTERN', 'Technique PLAIN', 'Technique PSEUDO RED-FIGURE', 'Technique RED POLISHED WARE', 'Technique RED-FIGURE', 'Technique RELIEF', 'Technique RESERVING', 'Technique SILHOUETTE', 'Technique WHITE PAINTED WARE']\n"
217
- ]
218
- },
219
- {
220
- "name": "stderr",
221
- "output_type": "stream",
222
- "text": [
223
- "/Users/ludovicaschaerf/anaconda3/envs/torch_arm/lib/python3.11/site-packages/torch/amp/autocast_mode.py:221: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
224
- " warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n"
225
- ]
226
- },
227
- {
228
- "name": "stdout",
229
- "output_type": "stream",
230
- "text": [
231
- "(318, 768)\n"
232
- ]
233
- }
234
- ],
181
+ "outputs": [],
235
  "source": [
236
  "model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
237
  "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
@@ -257,7 +204,7 @@
257
  },
258
  {
259
  "cell_type": "code",
260
- "execution_count": 4,
261
  "id": "f7858bbf",
262
  "metadata": {},
263
  "outputs": [],
@@ -267,7 +214,7 @@
267
  },
268
  {
269
  "cell_type": "code",
270
- "execution_count": 5,
271
  "id": "de6bd428",
272
  "metadata": {},
273
  "outputs": [],
@@ -277,7 +224,7 @@
277
  },
278
  {
279
  "cell_type": "code",
280
- "execution_count": 6,
281
  "id": "d19c8e4c",
282
  "metadata": {},
283
  "outputs": [],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
7
  "metadata": {},
8
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  "source": [
10
  "import pandas as pd\n",
11
  "import pickle\n",
 
150
  },
151
  {
152
  "cell_type": "code",
153
+ "execution_count": null,
154
  "id": "0eae840f",
155
  "metadata": {},
156
+ "outputs": [],
 
 
 
 
 
 
 
 
 
157
  "source": [
158
  "import open_clip\n",
159
  "import os\n",
 
164
  },
165
  {
166
  "cell_type": "code",
167
+ "execution_count": null,
168
  "id": "4d776015",
169
  "metadata": {},
170
  "outputs": [],
 
175
  },
176
  {
177
  "cell_type": "code",
178
+ "execution_count": null,
179
  "id": "e3f917a7",
180
  "metadata": {},
181
+ "outputs": [],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
  "source": [
183
  "model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
184
  "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
 
204
  },
205
  {
206
  "cell_type": "code",
207
+ "execution_count": null,
208
  "id": "f7858bbf",
209
  "metadata": {},
210
  "outputs": [],
 
214
  },
215
  {
216
  "cell_type": "code",
217
+ "execution_count": null,
218
  "id": "de6bd428",
219
  "metadata": {},
220
  "outputs": [],
 
224
  },
225
  {
226
  "cell_type": "code",
227
+ "execution_count": null,
228
  "id": "d19c8e4c",
229
  "metadata": {},
230
  "outputs": [],
view_segmentations.ipynb ADDED
@@ -0,0 +1,77 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from PIL import Image\n",
10
+ "from lang_sam import LangSAM"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "\n",
20
+ "def save_mask_as_image(mask, output_path):\n",
21
+ " # Create a blank image with the same dimensions as the mask\n",
22
+ " width, height = mask.shape[1], mask.shape[0]\n",
23
+ " image = Image.new(\"L\", (width, height))\n",
24
+ "\n",
25
+ " # Set the pixel values based on the mask\n",
26
+ " for y in range(height):\n",
27
+ " for x in range(width):\n",
28
+ " pixel_value = 255 if mask[y, x] == 1 else 0\n",
29
+ " image.putpixel((x, y), pixel_value)\n",
30
+ "\n",
31
+ " # Save the image as a PNG\n",
32
+ " image.save(output_path)\n",
33
+ "\n",
34
+ "\n",
35
+ "\n",
36
+ "model = LangSAM()\n",
37
+ "TEST_DIR = '/Users/ludovicaschaerf/Desktop/Data/VA_textiles_masks/'\n",
38
+ "OUT_DIR = '/Users/ludovicaschaerf/Desktop/Data/VA_textiles/'\n",
39
+ "image_pil = Image.open(TEST_DIR + \"O25495.jpg\").convert(\"RGB\")\n",
40
+ "text_prompt = \"textile\"\n",
41
+ "masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)\n"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "\n",
51
+ "save_mask_as_image(masks[0], OUT_DIR + \"O25495.jpg\")"
52
+ ]
53
+ }
54
+ ],
55
+ "metadata": {
56
+ "kernelspec": {
57
+ "display_name": "art-reco_x86",
58
+ "language": "python",
59
+ "name": "python3"
60
+ },
61
+ "language_info": {
62
+ "codemirror_mode": {
63
+ "name": "ipython",
64
+ "version": 3
65
+ },
66
+ "file_extension": ".py",
67
+ "mimetype": "text/x-python",
68
+ "name": "python",
69
+ "nbconvert_exporter": "python",
70
+ "pygments_lexer": "ipython3",
71
+ "version": "3.8.16"
72
+ },
73
+ "orig_nbformat": 4
74
+ },
75
+ "nbformat": 4,
76
+ "nbformat_minor": 2
77
+ }