ludusc committed on
Commit
065996d
2 Parent(s): 2cafca2 5b6e083

Merge branch 'main' of https://huggingface.co/spaces/ludusc/latent-space-theories into main

Files changed (38)
  1. .gitignore +2 -1
  2. data/CLIP_vecs_vases.pkl +0 -3
  3. data/annotated_files/seeds0000-100000.pkl +0 -3
  4. data/annotated_files/seeds0000-50000.pkl +0 -3
  5. data/annotated_files/sim_seeds0000-100000.csv +0 -3
  6. data/annotated_files/sim_seeds0000-50000.csv +0 -3
  7. data/model_files/network-snapshot-010600.pkl +0 -3
  8. data/old/ImageNet_metadata.csv +0 -3
  9. data/old/activation/convnext_activation.json +0 -3
  10. data/old/activation/mobilenet_activation.json +0 -3
  11. data/old/activation/resnet_activation.json +0 -3
  12. data/old/dot_architectures/convnext_architecture.dot +0 -3
  13. data/old/layer_infos/convnext_layer_infos.json +0 -3
  14. data/old/layer_infos/mobilenet_layer_infos.json +0 -3
  15. data/old/layer_infos/resnet_layer_infos.json +0 -3
  16. data/old/preprocessed_image_net/val_data_0.pkl +0 -3
  17. data/old/preprocessed_image_net/val_data_1.pkl +0 -3
  18. data/old/preprocessed_image_net/val_data_2.pkl +0 -3
  19. data/old/preprocessed_image_net/val_data_3.pkl +0 -3
  20. data/old/preprocessed_image_net/val_data_4.pkl +0 -3
  21. data/{CLIP_vecs.pkl → stored_vectors/scores_colors_hsv.csv} +2 -2
  22. data/vase_annotated_files/seeds0000-20000.pkl +0 -3
  23. data/vase_annotated_files/sim_Fabric_seeds0000-20000.csv +0 -3
  24. data/vase_annotated_files/sim_Provenance_seeds0000-20000.csv +0 -3
  25. data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv +0 -3
  26. data/vase_annotated_files/sim_Technique_seeds0000-20000.csv +0 -3
  27. data/vase_model_files/network-snapshot-003800.pkl +0 -3
  28. ganspace_unsupervised_disentanglement.ipynb +174 -0
  29. interfacegan_colour_disentanglement.ipynb +520 -0
  30. pages/1_Omniart_Disentanglement.py +0 -202
  31. pages/{5_Textiles_Disentanglement.py → 1_Textiles_Disentanglement.py} +4 -4
  32. pages/{2_Concepts_comparison.py → 2_Colours_comparison.py} +0 -0
  33. pages/3_Oxford_Vases_Disentanglement.py +0 -178
  34. pages/4_Vase_Qualities_Comparison copy.py +0 -268
  35. structure_annotations.ipynb +382 -0
  36. stylespace_colour_disentanglement.ipynb +580 -0
  37. view_predictions.ipynb +10 -63
  38. view_segmentations.ipynb +77 -0
.gitignore CHANGED
@@ -188,4 +188,5 @@ cython_debug/
 
  data/images/
  tmp/
- figures/
+ figures/
+ archive/
data/CLIP_vecs_vases.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c6299cd65f0635d6077788b135520ee8e88063930c63db458e643d77cba7b6ee
3
- size 995715
 
 
 
 
data/annotated_files/seeds0000-100000.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7b3a4fd155fa86df0953ad1cb660d50729189606de307fcee09fd893ba047228
3
- size 420351795
 
 
 
 
data/annotated_files/seeds0000-50000.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cd1bd97b8ff508b1d4a7ef43323530368ace65b35d12d84a914913f541187298
3
- size 314939226
 
 
 
 
data/annotated_files/sim_seeds0000-100000.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6e0b08d729f87f827d3d88327b33ff22d8413cb7aa7057c0c4ccd384d72a2c21
3
- size 21090135
 
 
 
 
data/annotated_files/sim_seeds0000-50000.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:da82633e4296ae78ce9e3e208fae374ae8983137566101060aadd11ffd3b0ff7
3
- size 50535430
 
 
 
 
data/model_files/network-snapshot-010600.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9a46e8aecd50191b82632b5de7bf3b9e219a59564c54994dd203f016b7a8270e
3
- size 357344749
 
 
 
 
data/old/ImageNet_metadata.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e53b0fc17cd5c8811ca08b7ff908cd2bbd625147686ef8bc020cb85a5a4546e5
3
- size 3027633
 
 
 
 
data/old/activation/convnext_activation.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0354b28bcca4e3673888124740e3d82882cbf38af8cd3007f48a7a5db983f487
3
- size 33350177
 
 
 
 
data/old/activation/mobilenet_activation.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5abc76e9318fadee18f35bb54e90201bf28699cf75140b5d2482d42243fad302
3
- size 13564581
 
 
 
 
data/old/activation/resnet_activation.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:668bea355a5504d74f79d20d02954040ad572f50455361d7d17125c7c8b1561c
3
- size 23362905
 
 
 
 
data/old/dot_architectures/convnext_architecture.dot DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:41a258a40a93615638ae504770c14e44836c934badbe48f18148f5a750514ac9
3
- size 9108
 
 
 
 
data/old/layer_infos/convnext_layer_infos.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3e82ea48865493107b97f37da58e370f0eead5677bf10f25f237f10970aedb6f
3
- size 1678
 
 
 
 
data/old/layer_infos/mobilenet_layer_infos.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a11df5f0b23040d11ce817658a989c8faf19faa06a8cbad727b635bac824e917
3
- size 3578
 
 
 
 
data/old/layer_infos/resnet_layer_infos.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:21e1787382f1e1c206b81d2c4fe207fb6d41f4cf186d5afc32fc056dd21e10d6
3
- size 5155
 
 
 
 
data/old/preprocessed_image_net/val_data_0.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2698bdc240555e2a46a40936df87275bc5852142d30e921ae0dad9289b0f576f
3
- size 906108480
 
 
 
 
data/old/preprocessed_image_net/val_data_1.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:21780d77e212695dbee84d6d2ad17a5a520bc1634f68e1c8fd120f069ad76da1
3
- size 907109023
 
 
 
 
data/old/preprocessed_image_net/val_data_2.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2cfc83b78420baa1b2c3a8da92e7fba1f33443d506f483ecff13cdba2035ab3c
3
- size 907435149
 
 
 
 
data/old/preprocessed_image_net/val_data_3.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:2f5e2c7cb4d6bae17fbd062a0b46f2cee457ad466b725f7bdf0f8426069cafee
3
- size 906089333
 
 
 
 
data/old/preprocessed_image_net/val_data_4.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4ed53c87ec8b9945db31f910eb44b7e3092324643de25ea53a99fc29137df854
3
- size 905439763
 
 
 
 
data/{CLIP_vecs.pkl → stored_vectors/scores_colors_hsv.csv} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:e2971a01a74a391c752fff9ba91c2939ffc6b29165842a87b911e67d9658df53
- size 412234
+ oid sha256:93f5789a80465ca7b21713819bc444d72239fa1b7ae56adf69e3323e0f3bedd1
+ size 974247
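The rename above swaps the pickled CLIP vectors for a CSV of colour scores in HSV space; the notebooks added below load a colour table and convert it with a `tohsv` helper before binning hues into 0–255 ranges. A small sketch of what that hex-to-HSV conversion can look like, assuming hex colour strings such as the `#000000` fill value used in the notebooks (the column names here are illustrative, not necessarily the CSV's):

```python
import colorsys
import pandas as pd

def hex_to_hsv255(hex_colour):
    # '#RRGGBB' -> (H, S, V), each scaled to 0-255 to match the 0-256 hue
    # bins used in the notebooks below.
    r, g, b = (int(hex_colour[i:i + 2], 16) / 255.0 for i in (1, 3, 5))
    h, s, v = colorsys.rgb_to_hsv(r, g, b)
    return tuple(round(x * 255) for x in (h, s, v))

# Illustrative frame; the real CSV columns are assumptions here.
df = pd.DataFrame({"color1": ["#ff0000", "#3366cc", "#000000"]})
df[["H1", "S1", "V1"]] = df["color1"].apply(hex_to_hsv255).tolist()
print(df)
```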
data/vase_annotated_files/seeds0000-20000.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e790910bf45c0d5a84e74c9011b88012f59d0fc27b19987c890b891c57ab739c
3
- size 125913423
 
 
 
 
data/vase_annotated_files/sim_Fabric_seeds0000-20000.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:20fa48320e20998aad5665610083843705608a5f06ff081e4395ee4b5ac9cba3
3
- size 9731011
 
 
 
 
data/vase_annotated_files/sim_Provenance_seeds0000-20000.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a421ffd119eee312249c9fbd05ac65460849e71f538d05fad223cb55423f315f
3
- size 18066428
 
 
 
 
data/vase_annotated_files/sim_Shape Name_seeds0000-20000.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e258361e0db7c208ae67654c08ed5b900df10980e82e84bcddd3de89428f679a
3
- size 30853761
 
 
 
 
data/vase_annotated_files/sim_Technique_seeds0000-20000.csv DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e3d3425e15d76d47a8829783cadbd7072698083df199617a8423d5ccb9d88714
3
- size 2484876
 
 
 
 
data/vase_model_files/network-snapshot-003800.pkl DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:42be0a24e7021dc66a9353c3a904494bb8e64b62e00e535ad3b03ad18238b0d2
3
- size 357349976
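Each of the data files deleted above is a Git LFS pointer stub — three lines giving the spec version, the object id and the blob size — so these diffs remove only the pointers, not the large blobs themselves. A minimal sketch of reading those three fields from such a pointer (the commented-out path is just one example from the list above):

```python
# Minimal sketch: parse a Git LFS pointer file of the form shown above
# (version / oid sha256:<hash> / size <bytes>).
def parse_lfs_pointer(path):
    fields = {}
    with open(path) as fh:
        for line in fh:
            key, _, value = line.strip().partition(" ")
            fields[key] = value
    oid = fields["oid"].split(":", 1)[1]      # drop the 'sha256:' prefix
    size_bytes = int(fields["size"])
    return oid, size_bytes

# oid, size_bytes = parse_lfs_pointer("data/model_files/network-snapshot-010600.pkl")
# print(oid, round(size_bytes / 1e6, 1), "MB")
```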
 
 
 
 
ganspace_unsupervised_disentanglement.ipynb ADDED
@@ -0,0 +1,174 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "3722712c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%matplotlib inline \n",
11
+ "\n",
12
+ "import pandas as pd\n",
13
+ "import pickle\n",
14
+ "import random\n",
15
+ "\n",
16
+ "from PIL import Image, ImageColor\n",
17
+ "import matplotlib.pyplot as plt\n",
18
+ "\n",
19
+ "import numpy as np\n",
20
+ "import torch\n",
21
+ "\n",
22
+ "from backend.disentangle_concepts import *\n",
23
+ "import dnnlib \n",
24
+ "import legacy\n",
25
+ "from backend.color_annotations import *\n",
26
+ "\n",
27
+ "import random\n",
28
+ "\n",
29
+ "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
30
+ "\n",
31
+ "\n",
32
+ "%load_ext autoreload\n",
33
+ "%autoreload 2"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
44
+ "with open(annotations_file, 'rb') as f:\n",
45
+ " annotations = pickle.load(f)\n",
46
+ "\n",
47
+ "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
48
+ "\n",
49
+ "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
50
+ " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": null,
56
+ "id": "cd114cb1",
57
+ "metadata": {},
58
+ "outputs": [],
59
+ "source": [
60
+ "ann_df = tohsv(ann_df)\n",
61
+ "ann_df.head()"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "id": "feb64168",
68
+ "metadata": {},
69
+ "outputs": [],
70
+ "source": [
71
+ "X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
72
+ "print(X.shape)\n",
73
+ "y_h = np.array(ann_df['H1'].values)\n",
74
+ "y_s = np.array(ann_df['S1'].values)\n",
75
+ "y_v = np.array(ann_df['S1'].values)"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "id": "4e814959",
81
+ "metadata": {},
82
+ "source": [
83
+ "### Unsupervised approaches"
84
+ ]
85
+ },
86
+ {
87
+ "cell_type": "code",
88
+ "execution_count": null,
89
+ "id": "c9493f54",
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "from sklearn.decomposition import PCA"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "93596853",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "pca = PCA(n_components=20)\n",
104
+ "\n",
105
+ "dims_pca = pca.fit_transform(x_trainhc.T)\n",
106
+ "dims_pca /= np.linalg.norm(dims_pca, axis=0)\n",
107
+ "print(dims_pca.shape, np.linalg.norm(dims_pca, axis=0).shape)"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "dd0b1501",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "method = 'PCA dimension'\n",
118
+ "for sep, num in zip(dims_pca.T, range(10)):\n",
119
+ " images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-(int(4)), max_epsilon=int(4), count=5, latent_space='W')\n",
120
+ " fig, axs = plt.subplots(1, len(images), figsize=(50,10))\n",
121
+ " fig.suptitle(method +': ' + str(num), fontsize=20)\n",
122
+ " for i,im in enumerate(images):\n",
123
+ " axs[i].imshow(im)\n",
124
+ " axs[i].set_title(np.round(lambdas[i], 2))\n",
125
+ " plt.show()"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "id": "4e0c7808",
131
+ "metadata": {},
132
+ "source": [
133
+ "## dimensionality reduction e vediamo dove finiscono i vari colori"
134
+ ]
135
+ },
136
+ {
137
+ "cell_type": "markdown",
138
+ "id": "833ed31f",
139
+ "metadata": {},
140
+ "source": [
141
+ "## clustering per vedere quali sono i centroid di questo spazio e se ci sono regioni determinate dai colori"
142
+ ]
143
+ },
144
+ {
145
+ "cell_type": "code",
146
+ "execution_count": null,
147
+ "id": "7c19e820",
148
+ "metadata": {},
149
+ "outputs": [],
150
+ "source": []
151
+ }
152
+ ],
153
+ "metadata": {
154
+ "kernelspec": {
155
+ "display_name": "Python 3",
156
+ "language": "python",
157
+ "name": "python3"
158
+ },
159
+ "language_info": {
160
+ "codemirror_mode": {
161
+ "name": "ipython",
162
+ "version": 3
163
+ },
164
+ "file_extension": ".py",
165
+ "mimetype": "text/x-python",
166
+ "name": "python",
167
+ "nbconvert_exporter": "python",
168
+ "pygments_lexer": "ipython3",
169
+ "version": "3.8.16"
170
+ }
171
+ },
172
+ "nbformat": 4,
173
+ "nbformat_minor": 5
174
+ }
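The notebook above takes a GANSpace-style, unsupervised route: PCA is fitted on sampled latent (W) vectors and the leading components are used as edit directions (note that `x_trainhc`, referenced in the PCA cell, is only defined in the interfacegan notebook below). A self-contained sketch of the same idea on a stacked `(n_samples, 512)` W matrix; the repo's `regenerate_images` helper is assumed rather than reproduced, so the edit is written out as `w + ε·v`:

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
X = rng.normal(size=(5000, 512))      # stand-in for the stacked w_vectors above

# Fit PCA on the latent vectors; each principal component (already unit-norm)
# is a candidate unsupervised edit direction.
pca = PCA(n_components=20)
pca.fit(X)
directions = pca.components_          # shape (20, 512)

# Walk the first direction: w_edit = w + epsilon * v over a small epsilon range.
w = X[0]
for epsilon in np.linspace(-4, 4, 5):
    w_edit = w + epsilon * directions[0]
    # w_edit would be passed to the generator (e.g. G.synthesis) to render an image.
    print(round(float(epsilon), 2), round(float(np.linalg.norm(w_edit - w)), 3))
```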
interfacegan_colour_disentanglement.ipynb ADDED
@@ -0,0 +1,520 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "3722712c",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "%matplotlib inline \n",
11
+ "\n",
12
+ "import pandas as pd\n",
13
+ "import pickle\n",
14
+ "import random\n",
15
+ "\n",
16
+ "from PIL import Image, ImageColor\n",
17
+ "import matplotlib.pyplot as plt\n",
18
+ "\n",
19
+ "import numpy as np\n",
20
+ "import torch\n",
21
+ "\n",
22
+ "from backend.disentangle_concepts import *\n",
23
+ "import dnnlib \n",
24
+ "import legacy\n",
25
+ "from backend.color_annotations import *\n",
26
+ "\n",
27
+ "import random\n",
28
+ "\n",
29
+ "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
30
+ "\n",
31
+ "\n",
32
+ "%load_ext autoreload\n",
33
+ "%autoreload 2"
34
+ ]
35
+ },
36
+ {
37
+ "cell_type": "code",
38
+ "execution_count": null,
39
+ "id": "5630402a",
40
+ "metadata": {},
41
+ "outputs": [],
42
+ "source": [
43
+ "num_colors = 7"
44
+ ]
45
+ },
46
+ {
47
+ "cell_type": "code",
48
+ "execution_count": null,
49
+ "id": "00e57598",
50
+ "metadata": {},
51
+ "outputs": [],
52
+ "source": [
53
+ "values = [x*256/num_colors if x<num_colors else 256 for x in range(num_colors + 1)]\n",
54
+ "centers = [int((values[i-1]+values[i])/2) for i in range(len(values)) if i > 0]"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "1550ecd7",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "print(values)\n",
65
+ "print(centers)"
66
+ ]
67
+ },
68
+ {
69
+ "cell_type": "code",
70
+ "execution_count": null,
71
+ "id": "ab9be91e",
72
+ "metadata": {},
73
+ "outputs": [],
74
+ "source": [
75
+ "def create_color_image(hue, saturation, value, size=(20, 10)):\n",
76
+ " color_rgb = ImageColor.getrgb(\"hsv({}, {}%, {}%)\".format(hue, int(saturation * 100), int(value * 100)))\n",
77
+ " image = Image.new(\"RGB\", size, color_rgb)\n",
78
+ " return image"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "bf1c8ab5",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "def display_image(image, title=''):\n",
89
+ " plt.figure()\n",
90
+ " plt.suptitle(title)\n",
91
+ " plt.imshow(image)\n",
92
+ " plt.axis('off')\n",
93
+ " plt.show()"
94
+ ]
95
+ },
96
+ {
97
+ "cell_type": "code",
98
+ "execution_count": null,
99
+ "id": "519b16d4",
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "def to_256(val):\n",
104
+ " x = val*360/256\n",
105
+ " return x"
106
+ ]
107
+ },
108
+ {
109
+ "cell_type": "code",
110
+ "execution_count": null,
111
+ "id": "8f696758",
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "names = ['Red', 'Orange', 'Yellow', 'Yellow Green', 'Chartreuse Green',\n",
116
+ " 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',\n",
117
+ " 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']"
118
+ ]
119
+ },
120
+ {
121
+ "cell_type": "code",
122
+ "execution_count": null,
123
+ "id": "50825823",
124
+ "metadata": {},
125
+ "outputs": [],
126
+ "source": [
127
+ "saturation = 1 # Saturation value (0 to 1)\n",
128
+ "value = 1 # Value (brightness) value (0 to 1)\n",
129
+ "for hue, name in zip(centers, names[:num_colors]):\n",
130
+ " image = create_color_image(to_256(hue), saturation, value)\n",
131
+ " display_image(image, name) # Display the generated color image"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
138
+ "metadata": {},
139
+ "outputs": [],
140
+ "source": [
141
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
142
+ "with open(annotations_file, 'rb') as f:\n",
143
+ " annotations = pickle.load(f)\n",
144
+ "\n",
145
+ "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
146
+ "\n",
147
+ "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
148
+ " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
149
+ ]
150
+ },
151
+ {
152
+ "cell_type": "code",
153
+ "execution_count": null,
154
+ "id": "cd114cb1",
155
+ "metadata": {},
156
+ "outputs": [],
157
+ "source": [
158
+ "ann_df = tohsv(ann_df)\n",
159
+ "ann_df.head()"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "id": "feb64168",
166
+ "metadata": {},
167
+ "outputs": [],
168
+ "source": [
169
+ "X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
170
+ "print(X.shape)\n",
171
+ "y_h = np.array(ann_df['H1'].values)\n",
172
+ "y_s = np.array(ann_df['S1'].values)\n",
173
+ "y_v = np.array(ann_df['S1'].values)"
174
+ ]
175
+ },
176
+ {
177
+ "cell_type": "code",
178
+ "execution_count": null,
179
+ "id": "0ca08749",
180
+ "metadata": {},
181
+ "outputs": [],
182
+ "source": [
183
+ "np.unique(y_h)"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "id": "e8f33f14",
189
+ "metadata": {},
190
+ "source": [
191
+ "## Regression model"
192
+ ]
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "execution_count": null,
197
+ "id": "8da0a43d",
198
+ "metadata": {},
199
+ "outputs": [],
200
+ "source": [
201
+ "x_trainh, x_valh, y_trainh, y_valh = train_test_split(X, y_h, test_size=0.2)\n",
202
+ "x_trains, x_vals, y_trains, y_vals = train_test_split(X, y_s, test_size=0.2)\n",
203
+ "x_trainv, x_valv, y_trainv, y_valv = train_test_split(X, y_v, test_size=0.2)\n"
204
+ ]
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "execution_count": null,
209
+ "id": "8eddba20",
210
+ "metadata": {},
211
+ "outputs": [],
212
+ "source": [
213
+ "regh = LinearRegression().fit(x_trainh, y_trainh)\n",
214
+ "print('Val performance logistic regression', np.round(regh.score(x_valh, y_valh),2))\n",
215
+ "\n",
216
+ "separation_vectorh = regh.coef_ / np.linalg.norm(regh.coef_)\n",
217
+ "print(separation_vectorh.shape)\n",
218
+ "\n",
219
+ "regs = LinearRegression().fit(x_trains, y_trains)\n",
220
+ "print('Val performance logistic regression', np.round(regs.score(x_vals, y_vals),2))\n",
221
+ "\n",
222
+ "separation_vectors = regs.coef_ / np.linalg.norm(regs.coef_)\n",
223
+ "print(separation_vectors.shape)\n",
224
+ "\n",
225
+ "regv = LinearRegression().fit(x_trainv, y_trainv)\n",
226
+ "print('Val performance logistic regression', np.round(reg.score(x_valv, y_valv),2))\n",
227
+ "\n",
228
+ "separation_vectorv = regv.coef_ / np.linalg.norm(regv.coef_)\n",
229
+ "print(separation_vectorv.shape)\n"
230
+ ]
231
+ },
232
+ {
233
+ "cell_type": "code",
234
+ "execution_count": null,
235
+ "id": "c6a63345",
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "seed = random.randint(0,100000)\n",
240
+ "original_image_vec = annotations['w_vectors'][seed]\n",
241
+ "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
242
+ "img"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "id": "09f13e6a",
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "images, lambdas = regenerate_images(model, original_image_vec, separation_vectors, min_epsilon=-(int(5)), max_epsilon=int(5), count=7, latent_space='W')"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "id": "c66bcdde",
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
263
+ "for i,im in enumerate(images):\n",
264
+ " axs[i].imshow(im)\n",
265
+ " axs[i].set_title(np.round(lambdas[i], 2))"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "markdown",
270
+ "id": "4c44f0dd",
271
+ "metadata": {},
272
+ "source": [
273
+ "fourier per regolarità pattern\n",
274
+ "linear correlation con il colore\n",
275
+ "distribution dei colori original e non \n",
276
+ "neural network per vedere quanto riesce a classificare"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "markdown",
281
+ "id": "c2790c25",
282
+ "metadata": {},
283
+ "source": [
284
+ "## Multiclass model"
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "id": "afa0c100",
291
+ "metadata": {},
292
+ "outputs": [],
293
+ "source": [
294
+ "colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',\n",
295
+ " 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
296
+ " 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
297
+ ]
298
+ },
299
+ {
300
+ "cell_type": "markdown",
301
+ "id": "39a5668a",
302
+ "metadata": {},
303
+ "source": [
304
+ "double check colori"
305
+ ]
306
+ },
307
+ {
308
+ "cell_type": "code",
309
+ "execution_count": null,
310
+ "id": "5f2b48c0",
311
+ "metadata": {},
312
+ "outputs": [],
313
+ "source": [
314
+ "from sklearn import svm\n",
315
+ "\n",
316
+ "print([int(x*256/12) if x<12 else 255 for x in range(13)])\n",
317
+ "y_h_cat = pd.cut(y_h,bins=[x*256/12 if x<12 else 256 for x in range(13)],labels=colors_list).fillna('Warm Pink Red')\n",
318
+ "\n",
319
+ "print(y_h_cat.value_counts(dropna=False))\n",
320
+ "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "markdown",
325
+ "id": "67651454",
326
+ "metadata": {},
327
+ "source": [
328
+ "### SVR and LR"
329
+ ]
330
+ },
331
+ {
332
+ "cell_type": "code",
333
+ "execution_count": null,
334
+ "id": "7804f593",
335
+ "metadata": {},
336
+ "outputs": [],
337
+ "source": [
338
+ "clf = svm.LinearSVC().fit(x_trainhc, y_trainhc)\n",
339
+ "print('Val performance SVR regression', np.round(clf.score(x_valhc, y_valhc),2))"
340
+ ]
341
+ },
342
+ {
343
+ "cell_type": "code",
344
+ "execution_count": null,
345
+ "id": "e6e31b75",
346
+ "metadata": {},
347
+ "outputs": [],
348
+ "source": [
349
+ "clf_log = LogisticRegression(multi_class='ovr').fit(x_trainhc, y_trainhc)\n",
350
+ "print('Val performance logistic regression', np.round(clf_log.score(x_valhc, y_valhc),2))"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": null,
356
+ "id": "82e30f0c",
357
+ "metadata": {},
358
+ "outputs": [],
359
+ "source": [
360
+ "seed = random.randint(0,100000)\n",
361
+ "original_image_vec = annotations['w_vectors'][seed]\n",
362
+ "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
363
+ "img"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": null,
369
+ "id": "c8ce6086",
370
+ "metadata": {},
371
+ "outputs": [],
372
+ "source": [
373
+ "from sklearn.metrics import accuracy_score, confusion_matrix \n",
374
+ "\n",
375
+ "y_predhc = clf.predict(x_valhc)\n",
376
+ "print(y_predhc, y_valhc)\n",
377
+ "accuracy_score(y_valhc, y_predhc,)\n",
378
+ "\n",
379
+ "\n",
380
+ "#Get the confusion matrix\n",
381
+ "cm = confusion_matrix(y_valhc, y_predhc)\n",
382
+ "#array([[1, 0, 0],\n",
383
+ "# [1, 0, 0],\n",
384
+ "# [0, 1, 2]])\n",
385
+ "\n",
386
+ "#Now the normalize the diagonal entries\n",
387
+ "cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n",
388
+ "#array([[1. , 0. , 0. ],\n",
389
+ "# [1. , 0. , 0. ],\n",
390
+ "# [0. , 0.33333333, 0.66666667]])\n",
391
+ "\n",
392
+ "#The diagonal entries are the accuracies of each class\n",
393
+ "cm.diagonal()\n",
394
+ "#array([1. , 0. , 0.66666667])"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "id": "112f4b87",
401
+ "metadata": {},
402
+ "outputs": [],
403
+ "source": [
404
+ "print(clf.coef_, clf.coef_.shape)"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "execution_count": null,
410
+ "id": "6241bce1",
411
+ "metadata": {},
412
+ "outputs": [],
413
+ "source": [
414
+ "warm_blue = clf.coef_[-3, :] / np.linalg.norm(clf.coef_[-3, :])\n",
415
+ "\n",
416
+ "images, lambdas = regenerate_images(model, original_image_vec, warm_blue, min_epsilon=-(int(5)), max_epsilon=int(5), count=7, latent_space='W')\n",
417
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
418
+ "for i,im in enumerate(images):\n",
419
+ " axs[i].imshow(im)\n",
420
+ " axs[i].set_title(np.round(lambdas[i], 2))"
421
+ ]
422
+ },
423
+ {
424
+ "cell_type": "code",
425
+ "execution_count": null,
426
+ "id": "2fefcf0c",
427
+ "metadata": {},
428
+ "outputs": [],
429
+ "source": [
430
+ "warm_blue = clf.coef_[-4, :] / np.linalg.norm(clf.coef_[-4, :])\n",
431
+ "\n",
432
+ "images, lambdas = regenerate_images(model, original_image_vec, warm_blue, min_epsilon=-(int(50)), max_epsilon=int(50), count=2, latent_space='W')\n",
433
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
434
+ "for i,im in enumerate(images):\n",
435
+ " axs[i].imshow(im)\n",
436
+ " axs[i].set_title(np.round(lambdas[i], 2))"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "id": "e0f31e0b",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "from sklearn import svm\n",
447
+ "\n",
448
+ "y_h_cat = pd.cut(y_h,bins=[x*256/6 if x<6 else 256 for x in range(7)],labels=['Red', 'Yellow', 'Green', 'Blue',\n",
449
+ " 'Purple', 'Pink']).fillna('Red')\n",
450
+ "\n",
451
+ "print(y_h_cat.value_counts(dropna=False))\n",
452
+ "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)\n",
453
+ "\n",
454
+ "clf6 = svm.LinearSVC().fit(x_trainhc, y_trainhc)\n",
455
+ "print('Val performance logistic regression', np.round(clf6.score(x_valhc, y_valhc),2))\n"
456
+ ]
457
+ },
458
+ {
459
+ "cell_type": "code",
460
+ "execution_count": null,
461
+ "id": "f5f28b41",
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": [
465
+ "warm_blue = clf6.coef_[1, :] / np.linalg.norm(clf6.coef_[1, :])\n",
466
+ "\n",
467
+ "images, lambdas = regenerate_images(model, original_image_vec, warm_blue, min_epsilon=-(int(10)), max_epsilon=int(10), count=7, latent_space='W')\n",
468
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
469
+ "for i,im in enumerate(images):\n",
470
+ " axs[i].imshow(im)\n",
471
+ " axs[i].set_title(np.round(lambdas[i], 2))"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "id": "4e0c7808",
477
+ "metadata": {},
478
+ "source": [
479
+ "## dimensionality reduction e vediamo dove finiscono i vari colori"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "markdown",
484
+ "id": "833ed31f",
485
+ "metadata": {},
486
+ "source": [
487
+ "## clustering per vedere quali sono i centroid di questo spazio e se ci sono regioni determinate dai colori"
488
+ ]
489
+ },
490
+ {
491
+ "cell_type": "code",
492
+ "execution_count": null,
493
+ "id": "7c19e820",
494
+ "metadata": {},
495
+ "outputs": [],
496
+ "source": []
497
+ }
498
+ ],
499
+ "metadata": {
500
+ "kernelspec": {
501
+ "display_name": "Python 3",
502
+ "language": "python",
503
+ "name": "python3"
504
+ },
505
+ "language_info": {
506
+ "codemirror_mode": {
507
+ "name": "ipython",
508
+ "version": 3
509
+ },
510
+ "file_extension": ".py",
511
+ "mimetype": "text/x-python",
512
+ "name": "python",
513
+ "nbconvert_exporter": "python",
514
+ "pygments_lexer": "ipython3",
515
+ "version": "3.8.16"
516
+ }
517
+ },
518
+ "nbformat": 4,
519
+ "nbformat_minor": 5
520
+ }
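The notebook above follows an InterFaceGAN-style recipe: fit a linear model from the W vectors to the hue/saturation/value annotations (or a linear SVM on binned hue classes), normalise the coefficient vector, and use it as a separation direction for latent edits. A compact sketch of that pipeline on synthetic data, keeping the `(n_samples, 512)` layout and the six-colour hue bins used above; everything else (names, random targets) is illustrative:

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 512))            # stand-in W vectors
y_h = rng.integers(0, 256, size=2000)       # stand-in hue annotations in 0-255

# Continuous target: the normalised regression coefficients give one direction.
x_tr, x_val, y_tr, y_val = train_test_split(X, y_h, test_size=0.2, random_state=0)
reg = LinearRegression().fit(x_tr, y_tr)
direction_h = reg.coef_ / np.linalg.norm(reg.coef_)
print("val R^2:", round(reg.score(x_val, y_val), 2), direction_h.shape)

# Categorical target: bin hue into named classes; each class gets its own
# coefficient row, i.e. its own separation direction.
colors = ['Red', 'Yellow', 'Green', 'Blue', 'Purple', 'Pink']
bins = [x * 256 / 6 if x < 6 else 256 for x in range(7)]
y_cat = np.asarray(pd.cut(y_h, bins=bins, labels=colors).fillna('Red'))
clf = LinearSVC(max_iter=5000).fit(X, y_cat)
blue_idx = list(clf.classes_).index('Blue')
direction_blue = clf.coef_[blue_idx] / np.linalg.norm(clf.coef_[blue_idx])
print(clf.classes_, direction_blue.shape)

# An edit is then w + epsilon * direction, exactly as in the traversal cells above.
```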
pages/1_Omniart_Disentanglement.py DELETED
@@ -1,202 +0,0 @@
1
- import streamlit as st
2
- import pickle
3
- import pandas as pd
4
- import numpy as np
5
- import random
6
- import torch
7
-
8
- from matplotlib.backends.backend_agg import RendererAgg
9
-
10
- from backend.disentangle_concepts import *
11
- import torch_utils
12
- import dnnlib
13
- import legacy
14
-
15
- _lock = RendererAgg.lock
16
-
17
-
18
- st.set_page_config(layout='wide')
19
- BACKGROUND_COLOR = '#bcd0e7'
20
- SECONDARY_COLOR = '#bce7db'
21
-
22
-
23
- st.title('Disentanglement studies')
24
- st.write('> **What concepts can be disentangled in the latent space of a model?**')
25
- st.write("""Explanation on the functionalities to come.""")
26
-
27
- instruction_text = """Instruction to input:
28
- 1. Choosing concept:
29
- 2. Choosing image: Users can choose a specific image by entering **Image ID** and hit the _Choose the defined image_ button or can generate an image randomly by hitting the _Generate a random image_ button.
30
- 3. Choosing epsilon: **Epsilon** is the lambda amount of translation along the disentangled concept axis. A negative epsilon changes the image in the direction of the concept, a positive one pushes the image away from the concept.
31
- """
32
- st.write("To use the functionality below, users need to input the **concept** to disentangle, an **image** id and the **epsilon** of variation along the disentangled axis.")
33
- with st.expander("See more instruction", expanded=False):
34
- st.write(instruction_text)
35
-
36
-
37
- annotations_file = './data/annotated_files/seeds0000-50000.pkl'
38
- with open(annotations_file, 'rb') as f:
39
- annotations = pickle.load(f)
40
-
41
- ann_df = pd.read_csv('./data/annotated_files/sim_seeds0000-50000.csv')
42
- concepts = './data/concepts.txt'
43
-
44
- with open(concepts) as f:
45
- labels = [line.strip() for line in f.readlines()]
46
-
47
- if 'image_id' not in st.session_state:
48
- st.session_state.image_id = 0
49
- if 'projection' not in st.session_state:
50
- st.session_state.projection = False
51
- if 'concept_id' not in st.session_state:
52
- st.session_state.concept_id = 'Abstract'
53
- if 'space_id' not in st.session_state:
54
- st.session_state.space_id = 'Z'
55
-
56
- # def on_change_random_input():
57
- # st.session_state.image_id = st.session_state.image_id
58
-
59
- # ----------------------------- INPUT ----------------------------------
60
- st.header('Input')
61
- input_col_1, input_col_2, input_col_3 = st.columns(3)
62
- # --------------------------- INPUT column 1 ---------------------------
63
- with input_col_1:
64
- with st.form('text_form'):
65
-
66
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
67
- st.write('**Choose a concept to disentangle**')
68
- # chosen_text_id_input = st.empty()
69
- # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
70
- concept_id = st.selectbox('Concept:', tuple(labels))
71
-
72
- st.write('**Choose a latent space to disentangle**')
73
- # chosen_text_id_input = st.empty()
74
- # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
75
- space_id = st.selectbox('Space:', tuple(['Z', 'W']))
76
-
77
- choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
78
- # random_text = st.form_submit_button('Select a random concept')
79
-
80
- # if random_text:
81
- # concept_id = random.choice(labels)
82
- # st.session_state.concept_id = concept_id
83
- # chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
84
-
85
- if choose_text_button:
86
- concept_id = str(concept_id)
87
- st.session_state.concept_id = concept_id
88
- space_id = str(space_id)
89
- st.session_state.space_id = space_id
90
- # st.write(image_id, st.session_state.image_id)
91
-
92
- # ---------------------------- SET UP OUTPUT ------------------------------
93
- epsilon_container = st.empty()
94
- st.header('Output')
95
- st.subheader('Concept vector')
96
-
97
- # perform attack container
98
- # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
99
- # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
100
- header_col_1, header_col_2 = st.columns([5,1])
101
- output_col_1, output_col_2 = st.columns([5,1])
102
-
103
- st.subheader('Derivations along the concept vector')
104
-
105
- # prediction error container
106
- error_container = st.empty()
107
- smoothgrad_header_container = st.empty()
108
-
109
- # smoothgrad container
110
- smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
111
- smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
112
-
113
- # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
114
- with output_col_1:
115
- separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_id, annotations, ann_df, latent_space=st.session_state.space_id)
116
- # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
117
- st.write('Concept vector', separation_vector)
118
- header_col_1.write(f'Concept {concept_id} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
119
-
120
- # ----------------------------- INPUT column 2 & 3 ----------------------------
121
- with input_col_2:
122
- with st.form('image_form'):
123
-
124
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
125
- st.write('**Choose or generate a random image to test the disentanglement**')
126
- chosen_image_id_input = st.empty()
127
- image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
128
-
129
- choose_image_button = st.form_submit_button('Choose the defined image')
130
- random_id = st.form_submit_button('Generate a random image')
131
- projection_id = st.form_submit_button('Generate an image on the boundary')
132
-
133
- if random_id or projection_id:
134
- image_id = random.randint(0, 50000)
135
- st.session_state.image_id = image_id
136
- chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
137
- st.session_state.projection = False
138
-
139
- if projection_id:
140
- st.session_state.projection = True
141
-
142
- if choose_image_button:
143
- image_id = int(image_id)
144
- st.session_state.image_id = int(image_id)
145
- # st.write(image_id, st.session_state.image_id)
146
-
147
- with input_col_3:
148
- with st.form('Variate along the disentangled concept'):
149
- st.write('**Set range of change**')
150
- chosen_epsilon_input = st.empty()
151
- epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
152
- epsilon_button = st.form_submit_button('Choose the defined lambda')
153
- st.write('**Select hierarchical levels to manipulate**')
154
- layers = st.multiselect('Layers:', tuple(range(14)))
155
- if len(layers) == 0:
156
- layers = None
157
- print(layers)
158
- layers_button = st.form_submit_button('Choose the defined layers')
159
-
160
-
161
- # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
162
-
163
- #model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
164
- with dnnlib.util.open_url('./data/model_files/network-snapshot-010600.pkl') as f:
165
- model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
166
-
167
- if st.session_state.space_id == 'Z':
168
- original_image_vec = annotations['z_vectors'][st.session_state.image_id]
169
- else:
170
- original_image_vec = annotations['w_vectors'][st.session_state.image_id]
171
-
172
- if st.session_state.projection:
173
- original_image_vec = original_image_vec - np.dot(original_image_vec, separation_vector.T) * separation_vector
174
- print(original_image_vec.shape)
175
-
176
- img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
177
- # input_image = original_image_dict['image']
178
- # input_label = original_image_dict['label']
179
- # input_id = original_image_dict['id']
180
-
181
- with smoothgrad_col_3:
182
- st.image(img)
183
- smooth_head_3.write(f'Base image')
184
-
185
-
186
- images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
187
-
188
- with smoothgrad_col_1:
189
- st.image(images[0])
190
- smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
191
-
192
- with smoothgrad_col_2:
193
- st.image(images[1])
194
- smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
195
-
196
- with smoothgrad_col_4:
197
- st.image(images[3])
198
- smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
199
-
200
- with smoothgrad_col_5:
201
- st.image(images[4])
202
- smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
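The deleted Omniart page also offered a "Generate an image on the boundary" option, which projects the chosen latent onto the concept hyperplane before editing (the `st.session_state.projection` branch above). A minimal sketch of that projection step, assuming a unit-norm separation vector like the one returned by the page's `get_separation_space` helper (not reproduced here):

```python
import numpy as np

def project_to_boundary(w, v):
    # Remove the component of w along the (unit-norm) concept direction v,
    # i.e. project w onto the hyperplane {x : x . v = 0}.
    v = v / np.linalg.norm(v)
    return w - np.dot(w, v) * v

rng = np.random.default_rng(0)
w = rng.normal(size=512)                 # stand-in latent vector
v = rng.normal(size=512)                 # stand-in separation vector
w_boundary = project_to_boundary(w, v)
print(np.dot(w_boundary, v / np.linalg.norm(v)))   # ~0: the latent sits on the boundary
```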
pages/{5_Textiles_Disentanglement.py → 1_Textiles_Disentanglement.py} RENAMED
@@ -23,18 +23,18 @@ SECONDARY_COLOR = '#bce7db'
  st.title('Disentanglement studies on the Textile Dataset')
  st.markdown(
  """
- This is a demo of the Disentanglement studies on the [Oxford Vases Dataset](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/).
+ This is a demo of the Disentanglement studies on the [iMET Textiles Dataset](https://www.metmuseum.org/art/collection/search/85531).
  """,
  unsafe_allow_html=False,)
 
- annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
+ annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
  with open(annotations_file, 'rb') as f:
  annotations = pickle.load(f)
 
-
+ COLORS_LIST = []
  if 'image_id' not in st.session_state:
  st.session_state.image_id = 0
- if 'concept_ids' not in st.session_state:
+ if 'color_ids' not in st.session_state:
  st.session_state.concept_ids =['AMPHORA']
  if 'space_id' not in st.session_state:
  st.session_state.space_id = 'W'
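The renamed Textiles page, like the other disentanglement pages in this commit, initialises Streamlit state the same way: each key is seeded in `st.session_state` only when it is missing, so widget-triggered reruns keep the user's last selection. A small sketch of that pattern with placeholder defaults (the keys mirror the ones above):

```python
import streamlit as st

# Seed defaults once per session; later reruns keep whatever the user selected.
DEFAULTS = {"image_id": 0, "concept_ids": ["AMPHORA"], "space_id": "W"}
for key, value in DEFAULTS.items():
    if key not in st.session_state:
        st.session_state[key] = value

st.write(st.session_state.image_id, st.session_state.space_id)
```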
pages/{2_Concepts_comparison.py → 2_Colours_comparison.py} RENAMED
File without changes
pages/3_Oxford_Vases_Disentanglement.py DELETED
@@ -1,178 +0,0 @@
1
- import streamlit as st
2
- import pickle
3
- import pandas as pd
4
- import numpy as np
5
- import random
6
- import torch
7
-
8
- from matplotlib.backends.backend_agg import RendererAgg
9
-
10
- from backend.disentangle_concepts import *
11
- import torch_utils
12
- import dnnlib
13
- import legacy
14
-
15
- _lock = RendererAgg.lock
16
-
17
-
18
- st.set_page_config(layout='wide')
19
- BACKGROUND_COLOR = '#bcd0e7'
20
- SECONDARY_COLOR = '#bce7db'
21
-
22
-
23
- st.title('Disentanglement studies on the Oxford Vases Dataset')
24
- st.markdown(
25
- """
26
- This is a demo of the Disentanglement studies on the [Oxford Vases Dataset](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/).
27
- """,
28
- unsafe_allow_html=False,)
29
-
30
- annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
31
- with open(annotations_file, 'rb') as f:
32
- annotations = pickle.load(f)
33
-
34
-
35
- if 'image_id' not in st.session_state:
36
- st.session_state.image_id = 0
37
- if 'concept_ids' not in st.session_state:
38
- st.session_state.concept_ids =['AMPHORA']
39
- if 'space_id' not in st.session_state:
40
- st.session_state.space_id = 'W'
41
-
42
- # def on_change_random_input():
43
- # st.session_state.image_id = st.session_state.image_id
44
-
45
- # ----------------------------- INPUT ----------------------------------
46
- st.header('Input')
47
- input_col_1, input_col_2, input_col_3 = st.columns(3)
48
- # --------------------------- INPUT column 1 ---------------------------
49
- with input_col_1:
50
- with st.form('text_form'):
51
-
52
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
53
- st.write('**Choose two options to disentangle**')
54
- type_col = st.selectbox('Concept category:', tuple(['Provenance', 'Shape Name', 'Fabric', 'Technique']))
55
-
56
- ann_df = pd.read_csv(f'./data/vase_annotated_files/sim_{type_col}_seeds0000-20000.csv')
57
- labels = list(ann_df.columns)
58
- labels.remove('ID')
59
- labels.remove('Unnamed: 0')
60
-
61
- concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=[labels[2], labels[3]])
62
-
63
- st.write('**Choose a latent space to disentangle**')
64
- space_id = st.selectbox('Space:', tuple(['W', 'Z']))
65
-
66
- choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
67
-
68
- if choose_text_button:
69
- concept_ids = list(concept_ids)
70
- st.session_state.concept_ids = concept_ids
71
- space_id = str(space_id)
72
- st.session_state.space_id = space_id
73
- # st.write(image_id, st.session_state.image_id)
74
-
75
- # ---------------------------- SET UP OUTPUT ------------------------------
76
- epsilon_container = st.empty()
77
- st.header('Output')
78
- st.subheader('Concept vector')
79
-
80
- # perform attack container
81
- # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
82
- # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
83
- header_col_1, header_col_2 = st.columns([5,1])
84
- output_col_1, output_col_2 = st.columns([5,1])
85
-
86
- st.subheader('Derivations along the concept vector')
87
-
88
- # prediction error container
89
- error_container = st.empty()
90
- smoothgrad_header_container = st.empty()
91
-
92
- # smoothgrad container
93
- smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
94
- smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
95
-
96
- # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
97
- with output_col_1:
98
- separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id, samples=150)
99
- # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
100
- st.write('Concept vector', separation_vector)
101
- header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
102
-
103
- # ----------------------------- INPUT column 2 & 3 ----------------------------
104
- with input_col_2:
105
- with st.form('image_form'):
106
-
107
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
108
- st.write('**Choose or generate a random image to test the disentanglement**')
109
- chosen_image_id_input = st.empty()
110
- image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
111
-
112
- choose_image_button = st.form_submit_button('Choose the defined image')
113
- random_id = st.form_submit_button('Generate a random image')
114
-
115
- if random_id:
116
- image_id = random.randint(0, 20000)
117
- st.session_state.image_id = image_id
118
- chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
119
-
120
- if choose_image_button:
121
- image_id = int(image_id)
122
- st.session_state.image_id = int(image_id)
123
- # st.write(image_id, st.session_state.image_id)
124
-
125
- with input_col_3:
126
- with st.form('Variate along the disentangled concept'):
127
- st.write('**Set range of change**')
128
- chosen_epsilon_input = st.empty()
129
- epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
130
- epsilon_button = st.form_submit_button('Choose the defined lambda')
131
- st.write('**Select hierarchical levels to manipulate**')
132
- layers = st.multiselect('Layers:', tuple(range(14)))
133
- if len(layers) == 0:
134
- layers = None
135
- print(layers)
136
- layers_button = st.form_submit_button('Choose the defined layers')
137
-
138
-
139
- # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
140
-
141
- #model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
142
- with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
143
- model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
144
-
145
- if st.session_state.space_id == 'Z':
146
- original_image_vec = annotations['z_vectors'][st.session_state.image_id]
147
- else:
148
- original_image_vec = annotations['w_vectors'][st.session_state.image_id]
149
-
150
- img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
151
-
152
- top_pred = ann_df.loc[st.session_state.image_id, labels].astype(float).idxmax()
153
- # input_image = original_image_dict['image']
154
- # input_label = original_image_dict['label']
155
- # input_id = original_image_dict['id']
156
-
157
- with smoothgrad_col_3:
158
- st.image(img)
159
- smooth_head_3.write(f'Base image, predicted as {top_pred}')
160
-
161
-
162
- images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
163
-
164
- with smoothgrad_col_1:
165
- st.image(images[0])
166
- smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
167
-
168
- with smoothgrad_col_2:
169
- st.image(images[1])
170
- smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
171
-
172
- with smoothgrad_col_4:
173
- st.image(images[3])
174
- smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
175
-
176
- with smoothgrad_col_5:
177
- st.image(images[4])
178
- smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
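The deleted Oxford Vases page (like the Omniart one) exposed a "Select hierarchical levels to manipulate" control and forwarded the chosen layers to `regenerate_images`, i.e. the concept direction is applied only on some of the generator's W+ layers. A sketch of that idea under the page's assumption of 14 layers (`tuple(range(14))`); the repo's helper signature is assumed, not shown:

```python
import numpy as np

def layered_edit(w, direction, epsilon, layers=None, num_ws=14):
    # Broadcast a single 512-d W vector to W+ (one copy per synthesis layer),
    # then add epsilon * direction only on the selected layers.
    w_plus = np.tile(w, (num_ws, 1))                 # shape (14, 512)
    selected = range(num_ws) if layers is None else layers
    for layer in selected:
        w_plus[layer] += epsilon * direction
    return w_plus

w = np.zeros(512)
v = np.ones(512) / np.sqrt(512)                      # unit-norm direction
w_plus = layered_edit(w, v, epsilon=3.0, layers=[4, 5, 6])
print(np.abs(w_plus).sum(axis=1))                    # only layers 4-6 moved
```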
pages/4_Vase_Qualities_Comparison copy.py DELETED
@@ -1,268 +0,0 @@
1
- import streamlit as st
2
- import streamlit.components.v1 as components
3
-
4
- import dnnlib
5
- import legacy
6
-
7
- import pickle
8
- import pandas as pd
9
- import numpy as np
10
- from pyvis.network import Network
11
-
12
- import random
13
- from sklearn.metrics.pairwise import cosine_similarity
14
-
15
- from matplotlib.backends.backend_agg import RendererAgg
16
-
17
- from backend.disentangle_concepts import *
18
-
19
- _lock = RendererAgg.lock
20
-
21
- HIGHTLIGHT_COLOR = '#e7bcc5'
22
- st.set_page_config(layout='wide')
23
-
24
-
25
- st.title('Comparison among concept vectors')
26
- st.write('> **How do the concept vectors relate to each other?**')
27
- st.write('> **What is their joint impact on the image?**')
28
- st.write("""Description to write""")
29
-
30
-
31
- annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
32
- with open(annotations_file, 'rb') as f:
33
- annotations = pickle.load(f)
34
-
35
- if 'image_id' not in st.session_state:
36
- st.session_state.image_id = 0
37
- if 'concept_ids' not in st.session_state:
38
- st.session_state.concept_ids = ['Provenance ADRIA']
39
- if 'space_id' not in st.session_state:
40
- st.session_state.space_id = 'Z'
41
- if 'type_col' not in st.session_state:
42
- st.session_state.type_col = 'Provenance'
43
-
44
- # def on_change_random_input():
45
- # st.session_state.image_id = st.session_state.image_id
46
-
47
- # ----------------------------- INPUT ----------------------------------
48
- st.header('Input')
49
- input_col_1, input_col_2, input_col_3 = st.columns(3)
50
- # --------------------------- INPUT column 1 ---------------------------
51
- with input_col_1:
52
- with st.form('text_form'):
53
-
54
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
55
- st.write('**Choose two options to disentangle**')
56
- type_col = st.selectbox('Concept category:', tuple(['Provenance', 'Shape Name', 'Fabric', 'Technique']))
57
-
58
- ann_df = pd.read_csv(f'./data/vase_annotated_files/sim_{type_col}_seeds0000-20000.csv')
59
- labels = list(ann_df.columns)
60
- labels.remove('ID')
61
- labels.remove('Unnamed: 0')
62
-
63
- concept_ids = st.multiselect('Concepts:', tuple(labels), default=[labels[2], labels[3]])
64
-
65
- st.write('**Choose a latent space to disentangle**')
66
- # chosen_text_id_input = st.empty()
67
- # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
68
- space_id = st.selectbox('Space:', tuple(['Z', 'W']))
69
-
70
- choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
71
-
72
- if choose_text_button:
73
- st.session_state.concept_ids = list(concept_ids)
74
- space_id = str(space_id)
75
- st.session_state.space_id = space_id
76
- st.session_state.type_col = type_col
77
- # st.write(image_id, st.session_state.image_id)
78
-
79
- # ---------------------------- SET UP OUTPUT ------------------------------
80
- epsilon_container = st.empty()
81
- st.header('Output')
82
- st.subheader('Concept vector')
83
-
84
- # perform attack container
85
- # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
86
- # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
87
- header_col_1, header_col_2 = st.columns([1,1])
88
- output_col_1, output_col_2 = st.columns([1,1])
89
-
90
- st.subheader('Derivations along the concept vector')
91
-
92
- # prediction error container
93
- error_container = st.empty()
94
- smoothgrad_header_container = st.empty()
95
-
96
- # smoothgrad container
97
- smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
98
- smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
99
-
100
- # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
101
- with output_col_1:
102
- vectors, nodes_in_common, performances = get_concepts_vectors(concept_ids, annotations, ann_df, latent_space=space_id)
103
- header_col_1.write(f'Concepts {", ".join(concept_ids)} - Latent space {space_id} - Relevant nodes in common: {nodes_in_common} - Performance of the concept vectors: {performances}')# - Nodes {",".join(list(imp_nodes))}')
104
-
105
- edges = []
106
- for i in range(len(concept_ids)):
107
- for j in range(len(concept_ids)):
108
- if i != j:
109
- print(f'Similarity between {concept_ids[i]} and {concept_ids[j]}')
110
- similarity = cosine_similarity(vectors[i,:].reshape(1, -1), vectors[j,:].reshape(1, -1))
111
- print(np.round(similarity[0][0], 3))
112
- edges.append((concept_ids[i], concept_ids[j], np.round(similarity[0][0], 3)))
113
-
114
-
115
- net = Network(height="750px", width="100%",)
116
- for e in edges:
117
- src = e[0]
118
- dst = e[1]
119
- w = e[2]
120
-
121
- net.add_node(src, src, title=src)
122
- net.add_node(dst, dst, title=dst)
123
- net.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' +str(w))
124
-
125
- # Generate network with specific layout settings
126
- net.repulsion(
127
- node_distance=420,
128
- central_gravity=0.33,
129
- spring_length=110,
130
- spring_strength=0.10,
131
- damping=0.95
132
- )
133
-
134
- # Save and read graph as HTML file (on Streamlit Sharing)
135
- try:
136
- path = '/tmp'
137
- net.save_graph(f'{path}/pyvis_graph.html')
138
- HtmlFile = open(f'{path}/pyvis_graph.html', 'r', encoding='utf-8')
139
-
140
- # Save and read graph as HTML file (locally)
141
- except:
142
- path = '/html_files'
143
- net.save_graph(f'{path}/pyvis_graph.html')
144
- HtmlFile = open(f'{path}/pyvis_graph.html', 'r', encoding='utf-8')
145
-
146
- # Load HTML file in HTML component for display on Streamlit page
147
- components.html(HtmlFile.read(), height=435)
148
-
149
- with output_col_2:
150
- with open('data/CLIP_vecs_vases.pkl', 'rb') as f:
151
- vectors_CLIP = pickle.load(f)
152
-
153
- # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
154
- #st.write('Concept vector', separation_vector)
155
- header_col_2.write(f'Concepts {", ".join(concept_ids)} - Latent space CLIP')# - Nodes {",".join(list(imp_nodes))}')
156
-
157
- edges_clip = []
158
- for c1 in concept_ids:
159
- for c2 in concept_ids:
160
- if c1 != c2:
161
-
162
- print(f'Similarity between {c1} and {c2}')
163
- similarity = cosine_similarity(vectors_CLIP[st.session_state.type_col + ' ' + c1].reshape(1, -1), vectors_CLIP[st.session_state.type_col + ' ' + c2].reshape(1, -1))
164
- print(np.round(similarity[0][0], 3))
165
- edges_clip.append((c1, c2, np.round(float(np.round(similarity[0][0], 3)), 3)))
166
-
167
-
168
- net_clip = Network(height="750px", width="100%",)
169
- for e in edges_clip:
170
- src = e[0]
171
- dst = e[1]
172
- w = e[2]
173
-
174
- net_clip.add_node(src, src, title=src)
175
- net_clip.add_node(dst, dst, title=dst)
176
- net_clip.add_edge(src, dst, value=w, title=src + ' to ' + dst + ' similarity ' +str(w))
177
-
178
- # Generate network with specific layout settings
179
- net_clip.repulsion(
180
- node_distance=420,
181
- central_gravity=0.33,
182
- spring_length=110,
183
- spring_strength=0.10,
184
- damping=0.95
185
- )
186
-
187
- # Save and read graph as HTML file (on Streamlit Sharing)
188
- try:
189
- path = '/tmp'
190
- net_clip.save_graph(f'{path}/pyvis_graph_clip.html')
191
- HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
192
-
193
- # Save and read graph as HTML file (locally)
194
- except:
195
- path = '/html_files'
196
- net_clip.save_graph(f'{path}/pyvis_graph_clip.html')
197
- HtmlFile = open(f'{path}/pyvis_graph_clip.html', 'r', encoding='utf-8')
198
-
199
- # Load HTML file in HTML component for display on Streamlit page
200
- components.html(HtmlFile.read(), height=435)
201
-
202
- # ----------------------------- INPUT column 2 & 3 ----------------------------
203
- with input_col_2:
204
- with st.form('image_form'):
205
-
206
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
207
- st.write('**Choose or generate a random image to test the disentanglement**')
208
- chosen_image_id_input = st.empty()
209
- image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
210
-
211
- choose_image_button = st.form_submit_button('Choose the defined image')
212
- random_id = st.form_submit_button('Generate a random image')
213
-
214
- if random_id:
215
- image_id = random.randint(0, 50000)
216
- st.session_state.image_id = image_id
217
- chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
218
-
219
- if choose_image_button:
220
- image_id = int(image_id)
221
- st.session_state.image_id = int(image_id)
222
- # st.write(image_id, st.session_state.image_id)
223
-
224
- with input_col_3:
225
- with st.form('Variate along the disentangled concepts'):
226
- st.write('**Set range of change**')
227
- chosen_epsilon_input = st.empty()
228
- epsilon = chosen_epsilon_input.number_input('Epsilon:', min_value=1, step=1)
229
- epsilon_button = st.form_submit_button('Choose the defined epsilon')
230
-
231
- # # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
232
-
233
-
234
- with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
235
- model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
236
-
237
- if st.session_state.space_id == 'Z':
238
- original_image_vec = annotations['z_vectors'][st.session_state.image_id]
239
- else:
240
- original_image_vec = annotations['w_vectors'][st.session_state.image_id]
241
-
242
- img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
243
- # input_image = original_image_dict['image']
244
- # input_label = original_image_dict['label']
245
- # input_id = original_image_dict['id']
246
-
247
- with smoothgrad_col_3:
248
- st.image(img)
249
- smooth_head_3.write(f'Base image')
250
-
251
-
252
- images, lambdas = generate_joint_effect(model, original_image_vec, vectors, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id)
253
-
254
- with smoothgrad_col_1:
255
- st.image(images[0])
256
- smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
257
-
258
- with smoothgrad_col_2:
259
- st.image(images[1])
260
- smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
261
-
262
- with smoothgrad_col_4:
263
- st.image(images[3])
264
- smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
265
-
266
- with smoothgrad_col_5:
267
- st.image(images[4])
268
- smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
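For reference, the page removed above reduces to one pattern: pairwise cosine similarity between concept separation vectors, rendered as a pyvis graph and embedded in the Streamlit page. A minimal standalone sketch of that pattern (assuming a `concept_ids` list and a matching `vectors` array; not code from the repository):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pyvis.network import Network

def similarity_graph(concept_ids, vectors, out_path='pyvis_graph.html'):
    net = Network(height='750px', width='100%')
    # Nodes first: pyvis requires both endpoints to exist before an edge is added.
    for c in concept_ids:
        net.add_node(c, c, title=c)
    # One edge per ordered pair, weighted by cosine similarity of the two concept vectors.
    for i, c1 in enumerate(concept_ids):
        for j, c2 in enumerate(concept_ids):
            if i == j:
                continue
            w = float(np.round(cosine_similarity(vectors[i].reshape(1, -1),
                                                 vectors[j].reshape(1, -1))[0][0], 3))
            net.add_edge(c1, c2, value=w, title=f'{c1} to {c2} similarity {w}')
    net.repulsion(node_distance=420, central_gravity=0.33,
                  spring_length=110, spring_strength=0.10, damping=0.95)
    net.save_graph(out_path)  # the saved HTML is then read back and shown via Streamlit components
    return out_path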
structure_annotations.ipynb ADDED
@@ -0,0 +1,382 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os \n",
10
+ "from glob import glob \n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "\n",
14
+ "from PIL import Image, ImageColor\n",
15
+ "import matplotlib.pyplot as plt\n",
16
+ "\n",
17
+ "import torch\n",
18
+ "\n",
19
+ "from backend.disentangle_concepts import *\n",
20
+ "import dnnlib \n",
21
+ "import legacy\n",
22
+ "from backend.color_annotations import *\n",
23
+ "\n",
24
+ "\n",
25
+ "%load_ext autoreload\n",
26
+ "%autoreload 2"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "images_textiles = glob('/Users/ludovicaschaerf/Desktop/TextAIles/TextileGAN/Original Textiles/*')"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": null,
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "images = []\n",
45
+ "\n",
46
+ "for step in [10, 20]:\n",
47
+ " imh = np.zeros((200, 200))\n",
48
+ " imv = np.zeros((200, 200))\n",
49
+ " imb = np.ones((200, 200))\n",
50
+ " imb2 = np.ones((200, 200))\n",
51
+ " \n",
52
+ " for x,y in zip(range(0,200, step*2),range(step, 200, step*2)):\n",
53
+ " imh[x:y, :] = 255\n",
54
+ " imv[:, x:y] = 255\n",
55
+ " imb[x:y, :] = 0\n",
56
+ " imb[:, x:y] = 0\n",
57
+ " imb2[x:y, :] = 0\n",
58
+ " imb2[:, x*2:y*2] = 0\n",
59
+ " \n",
60
+ " images.append(imh) \n",
61
+ " images.append(imb) \n",
62
+ " images.append(imb2) \n",
63
+ " images.append(imv) \n",
64
+ "\n",
65
+ "for im in images:\n",
66
+ " plt.imshow(im, cmap='gray')\n",
67
+ " plt.title('Original Image')\n",
68
+ " plt.show()"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "metadata": {},
75
+ "outputs": [],
76
+ "source": [
77
+ "def get_freqs(image):\n",
78
+ " \n",
79
+ " # Load the image\n",
80
+ " # image = cv2.imread(image, cv2.IMREAD_GRAYSCALE)\n",
81
+ " # Center crop the image to remove black borders\n",
82
+ " height, width = image.shape\n",
83
+ " short_side = min(height, width)\n",
84
+ " crop_size = min(200, short_side)\n",
85
+ " center_x = width // 2\n",
86
+ " center_y = height // 2\n",
87
+ " crop_half_size = crop_size // 2\n",
88
+ " image = image[center_y - crop_half_size:center_y + crop_half_size,\n",
89
+ " center_x - crop_half_size:center_x + crop_half_size]\n",
90
+ "\n",
91
+ " # kernel = np.ones((5,5), np.float32) / 25\n",
92
+ " # image = cv2.filter2D( image, -1, kernel)\n",
93
+ " # image = cv2.GaussianBlur(image,(5,5),0)\n",
94
+ " # image = image - cv2.GaussianBlur(image, (21, 21), 1) + 127\n",
95
+ "\n",
96
+ " # print(np.unique(image))\n",
97
+ " # Perform 1D Fourier Transforms\n",
98
+ " horizontal_freq = np.fft.fftshift(np.fft.fft(image, axis=1))\n",
99
+ " vertical_freq = np.fft.fftshift(np.fft.fft(image, axis=0))\n",
100
+ " twod = np.fft.fftshift(np.fft.fft2(image))\n",
101
+ " \n",
102
+ " # Calculate corresponding frequencies\n",
103
+ " num_cols = image.shape[1]\n",
104
+ " num_rows = image.shape[0]\n",
105
+ " # horizontal_freqs = np.fft.fftshift(np.fft.fftfreq(num_cols))\n",
106
+ " # vertical_freqs = np.fft.fftshift(np.fft.fftfreq(num_rows))\n",
107
+ " \n",
108
+ " horizontal_freqs = np.fft.fftfreq(num_cols)\n",
109
+ " vertical_freqs = np.fft.fftfreq(num_rows)\n",
110
+ " \n",
111
+ " # # Sum power along the second axis\n",
112
+ " # twod = twod.real*twod.real + twod.imag*twod.imag\n",
113
+ " # twod = twod.sum(axis=1)/twod.shape[1]\n",
114
+ "\n",
115
+ " # # Round up the size along this axis to an even number\n",
116
+ " # n = int(math.ceil(image.shape[0] / 2.) * 2 )\n",
117
+ "\n",
118
+ " # # Generate a list of frequencies\n",
119
+ " # f = np.fft.fftfreq(n)\n",
120
+ "\n",
121
+ " # # Graph it\n",
122
+ " # plt.plot(f[1:],a[1:], label = 'sum of amplitudes over y vs f_x')\n",
123
+ "\n",
124
+ " \n",
125
+ "\n",
126
+ " # Calculate magnitude spectra\n",
127
+ " horizontal_magnitudes = np.abs(horizontal_freq)[0]\n",
128
+ " vertical_magnitudes = np.abs(vertical_freq)[0]\n",
129
+ " \n",
130
+ " # # Find peaks in the magnitudes\n",
131
+ " # horizontal_peaks, _ = find_peaks(horizontal_magnitudes, height=100) # Adjust height threshold as needed\n",
132
+ " # vertical_peaks, _ = find_peaks(vertical_magnitudes, height=10) # Adjust height threshold as needed\n",
133
+ "\n",
134
+ " print('Median horizontal frequency', np.median(horizontal_magnitudes), ', median vertical frequency', np.median(vertical_magnitudes))\n",
135
+ " # Plot frequency analysis with peaks\n",
136
+ " plt.figure(figsize=(25, 5))\n",
137
+ "\n",
138
+ " # Plot frequency analysis\n",
139
+ " plt.subplot(1, 6, 1)\n",
140
+ " plt.imshow(image, cmap='gray')\n",
141
+ " plt.title('Original Image')\n",
142
+ "\n",
143
+ " plt.subplot(1, 6, 2)\n",
144
+ " plt.imshow(np.log(1 + np.abs(horizontal_freq)), cmap='gray')\n",
145
+ " plt.title('Horizontal Frequency Analysis')\n",
146
+ "\n",
147
+ " plt.subplot(1, 6, 3)\n",
148
+ " plt.imshow(np.log(1 + np.abs(vertical_freq)), cmap='gray')\n",
149
+ " plt.title('Vertical Frequency Analysis')\n",
150
+ "\n",
151
+ " plt.subplot(1, 6, 4)\n",
152
+ " plt.imshow(np.log(1 + np.abs(twod)), cmap='gray')\n",
153
+ " plt.title('2D Frequency Analysis')\n",
154
+ "\n",
155
+ " plt.subplot(1, 6, 5)\n",
156
+ " plt.scatter(horizontal_freqs[1:], horizontal_magnitudes[1:], s=5)\n",
157
+ " # plt.plot(horizontal_freqs[horizontal_peaks], horizontal_magnitudes[horizontal_peaks], 'ro', markersize=5)\n",
158
+ " plt.xlabel('Horizontal Frequency')\n",
159
+ " plt.ylabel('Magnitude')\n",
160
+ " plt.title('Horizontal Frequency Analysis with Peaks')\n",
161
+ "\n",
162
+ " plt.subplot(1, 6, 6)\n",
163
+ " plt.scatter(vertical_freqs[1:], vertical_magnitudes[1:], s=5)\n",
164
+ " # plt.plot(vertical_freqs[vertical_peaks], vertical_magnitudes[vertical_peaks], 'ro', markersize=5)\n",
165
+ " plt.xlabel('Vertical Frequency')\n",
166
+ " plt.ylabel('Magnitude')\n",
167
+ " plt.title('Vertical Frequency Analysis with Peaks')\n",
168
+ "\n",
169
+ "\n",
170
+ " plt.show()\n"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "metadata": {},
176
+ "source": [
177
+ "When the absolute value of the frequency is high in the context of Fourier analysis, it indicates the presence of a rapidly changing or repeating pattern in the image. The frequency component's absolute value reflects how many cycles of the pattern occur within a fixed interval.\n",
178
+ "\n",
179
+ "In Fourier analysis:\n",
180
+ "\n",
181
+ "Higher Frequency Components: Higher absolute frequency values correspond to faster changes or repetitions in the image data. This means that the pattern or feature represented by that frequency component oscillates more rapidly across the image.\n",
182
+ "\n",
183
+ "Spatial Frequency: In image analysis, frequency is often associated with how quickly the intensity or color changes as you move across the image. High absolute frequency values indicate rapid changes or transitions in the image content.\n",
184
+ "\n",
185
+ "Fine Details and Textures: Rapidly oscillating frequency components are associated with fine details and textures in the image. For example, in textiles, high-frequency components might capture the intricate weave patterns or small-scale textures.\n",
186
+ "\n",
187
+ "Small-Scale Features: Patterns that repeat at small scales, such as fine lines or tiny structures, tend to result in high-frequency components with high absolute values.\n",
188
+ "\n",
189
+ "Edges and Transitions: Edges and sharp transitions between different colors or intensities are also associated with high-frequency components. These transitions involve rapid changes in intensity, leading to higher absolute frequency values."
190
+ ]
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "execution_count": null,
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "import cv2\n",
199
+ "from scipy.signal import find_peaks\n",
200
+ "\n",
201
+ "import numpy as np\n",
202
+ "import matplotlib.pyplot as plt\n",
203
+ "\n",
204
+ "\n",
205
+ "for image in images:\n",
206
+ " get_freqs(image)"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": [
215
+ "import cv2\n",
216
+ "import numpy as np\n",
217
+ "import matplotlib.pyplot as plt\n",
218
+ "from scipy.signal import find_peaks\n",
219
+ "from collections import Counter\n",
220
+ "import seaborn as sns\n",
221
+ "# Load the image\n",
222
+ "image = cv2.imread(images_textiles[0], cv2.IMREAD_GRAYSCALE)\n",
223
+ "\n",
224
+ "# Perform 1D Fourier Transforms\n",
225
+ "horizontal_freq = np.fft.fftshift(np.fft.fft(image, axis=1))\n",
226
+ "vertical_freq = np.fft.fftshift(np.fft.fft(image, axis=0))\n",
227
+ "\n",
228
+ "# Calculate magnitude spectra\n",
229
+ "horizontal_magnitudes = np.abs(horizontal_freq)[0]\n",
230
+ "vertical_magnitudes = np.abs(vertical_freq)[0]\n",
231
+ "\n",
232
+ "# Create histograms of magnitude recurrence\n",
233
+ "horizontal_magnitude_counter = Counter(np.round(horizontal_magnitudes).astype(int))\n",
234
+ "vertical_magnitude_counter = Counter(np.round(vertical_magnitudes).astype(int))\n",
235
+ "\n",
236
+ "# Plot magnitude recurrence histograms\n",
237
+ "plt.figure(figsize=(12, 5))\n",
238
+ "\n",
239
+ "plt.subplot(1, 2, 1)\n",
240
+ "sns.histplot(list(horizontal_magnitude_counter.elements()), kde=True, color='blue')\n",
241
+ "plt.xlim(0, 1000) # Limit x-axis range\n",
242
+ "plt.xlabel('Magnitude')\n",
243
+ "plt.ylabel('Recurrence')\n",
244
+ "plt.title('Horizontal Magnitude Recurrence')\n",
245
+ "\n",
246
+ "plt.subplot(1, 2, 2)\n",
247
+ "sns.histplot(list(vertical_magnitude_counter.elements()), kde=True, color='green')\n",
248
+ "plt.xlim(0, 1000) # Limit x-axis range\n",
249
+ "plt.xlabel('Magnitude')\n",
250
+ "plt.ylabel('Recurrence')\n",
251
+ "plt.title('Vertical Magnitude Recurrence')\n",
252
+ "\n",
253
+ "plt.tight_layout()\n",
254
+ "plt.show()\n"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "import cv2\n",
264
+ "import numpy as np\n",
265
+ "import matplotlib.pyplot as plt\n",
266
+ "from skimage.feature import hog\n",
267
+ "from skimage import exposure\n",
268
+ "\n",
269
+ "# Load the image\n",
270
+ "image = cv2.imread(images_textiles[4], cv2.IMREAD_GRAYSCALE)\n",
271
+ "\n",
272
+ "# Center crop the image to remove black borders\n",
273
+ "height, width = image.shape\n",
274
+ "crop_size = 200\n",
275
+ "center_x = width // 2\n",
276
+ "center_y = height // 2\n",
277
+ "crop_half_size = crop_size // 2\n",
278
+ "cropped_image = image[center_y - crop_half_size:center_y + crop_half_size,\n",
279
+ " center_x - crop_half_size:center_x + crop_half_size]\n",
280
+ "\n",
281
+ "# Compute HOG features\n",
282
+ "orientations = 9\n",
283
+ "pixels_per_cell = (5, 5)\n",
284
+ "cells_per_block = (1,1)\n",
285
+ "hog_features, hog_image = hog(cropped_image, orientations=orientations,\n",
286
+ " pixels_per_cell=pixels_per_cell,\n",
287
+ " cells_per_block=cells_per_block,\n",
288
+ " block_norm='L2-Hys',\n",
289
+ " visualize=True)\n",
290
+ "\n",
291
+ "# Plot the HOG image\n",
292
+ "plt.figure(figsize=(8, 4))\n",
293
+ "plt.subplot(1, 2, 1)\n",
294
+ "plt.imshow(cropped_image, cmap='gray')\n",
295
+ "plt.title('Original Image')\n",
296
+ "\n",
297
+ "plt.subplot(1, 2, 2)\n",
298
+ "plt.imshow(hog_image, cmap='gray')\n",
299
+ "plt.title('HOG Image')\n",
300
+ "plt.axis('off')\n",
301
+ "\n",
302
+ "plt.tight_layout()\n",
303
+ "plt.show()\n",
304
+ "\n",
305
+ "# Plot the orientation distribution\n",
306
+ "hog_histogram, _ = np.histogram(hog_features, bins=orientations)\n",
307
+ "angles_deg = np.arange(0, 180, 180/orientations)\n",
308
+ "plt.figure(figsize=(6, 4))\n",
309
+ "plt.bar(angles_deg, hog_histogram, width=180/orientations)\n",
310
+ "plt.xlabel('Orientation (degrees)')\n",
311
+ "plt.ylabel('Frequency')\n",
312
+ "plt.title('Orientation Distribution')\n",
313
+ "plt.xticks(angles_deg)\n",
314
+ "plt.show()"
315
+ ]
316
+ },
317
+ {
318
+ "cell_type": "code",
319
+ "execution_count": null,
320
+ "metadata": {},
321
+ "outputs": [],
322
+ "source": [
323
+ "import cv2\n",
324
+ "import numpy as np\n",
325
+ "import matplotlib.pyplot as plt\n",
326
+ "import pywt\n",
327
+ "\n",
328
+ "# Load the image\n",
329
+ "image = cv2.imread(images_textiles[0], cv2.IMREAD_GRAYSCALE)\n",
330
+ "\n",
331
+ "# Perform Continuous Wavelet Transform (CWT) on each row of the image\n",
332
+ "wavelet = 'morl' # You can choose a different wavelet basis\n",
333
+ "scales = np.arange(1, 20) # Scales to analyze\n",
334
+ "\n",
335
+ "cwt_coeffs = []\n",
336
+ "for row in image:\n",
337
+ " coeffs, _ = pywt.cwt(row, scales, wavelet)\n",
338
+ " cwt_coeffs.append(coeffs)\n",
339
+ "\n",
340
+ "cwt_coeffs = np.array(cwt_coeffs)\n",
341
+ "\n",
342
+ "# Plot scaleograms\n",
343
+ "plt.figure(figsize=(10, 6))\n",
344
+ "plt.imshow(np.abs(cwt_coeffs[:, 0, :]), extent=[0, image.shape[1], scales[-1], scales[0]], cmap='viridis', aspect='auto')\n",
345
+ "plt.colorbar(label='Magnitude')\n",
346
+ "plt.title('Wavelet Scaleogram')\n",
347
+ "plt.xlabel('Pixel')\n",
348
+ "plt.ylabel('Scale')\n",
349
+ "plt.show()\n"
350
+ ]
351
+ },
352
+ {
353
+ "cell_type": "code",
354
+ "execution_count": null,
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": []
358
+ }
359
+ ],
360
+ "metadata": {
361
+ "kernelspec": {
362
+ "display_name": "art-reco_x86",
363
+ "language": "python",
364
+ "name": "python3"
365
+ },
366
+ "language_info": {
367
+ "codemirror_mode": {
368
+ "name": "ipython",
369
+ "version": 3
370
+ },
371
+ "file_extension": ".py",
372
+ "mimetype": "text/x-python",
373
+ "name": "python",
374
+ "nbconvert_exporter": "python",
375
+ "pygments_lexer": "ipython3",
376
+ "version": "3.8.16"
377
+ },
378
+ "orig_nbformat": 4
379
+ },
380
+ "nbformat": 4,
381
+ "nbformat_minor": 2
382
+ }
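structure_annotations.ipynb reads the stripe structure of a textile off the 1D Fourier transform of its rows and columns. A self-contained sketch of that reading (plain NumPy on a synthetic 200x200 stripe image like the ones generated at the top of the notebook; not repository code):

import numpy as np

stripe = np.zeros((200, 200))
stripe[:, ::20] = 255                        # vertical stripes, one every 20 pixels

row = stripe[0]                              # any single row carries the horizontal pattern
spectrum = np.abs(np.fft.fft(row))
freqs = np.fft.fftfreq(row.size)

peak = np.argmax(spectrum[1:]) + 1           # skip the DC component at index 0
print(f'dominant horizontal frequency: {abs(freqs[peak]):.3f} cycles/pixel '
      f'(about a {1 / abs(freqs[peak]):.0f}-pixel repeat)')

A strong peak far from zero corresponds to a fine, rapidly repeating weave; peaks near zero correspond to broad, slowly varying structure, which is what the markdown cell in the notebook describes.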
stylespace_colour_disentanglement.ipynb ADDED
@@ -0,0 +1,580 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "fe83bcc2",
6
+ "metadata": {},
7
+ "source": [
8
+ "![image](/Users/ludovicaschaerf/Desktop/latent-space-theories/data/stylegan3.webp)"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "id": "3722712c",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "%matplotlib inline \n",
19
+ "\n",
20
+ "import pandas as pd\n",
21
+ "import pickle\n",
22
+ "import random\n",
23
+ "\n",
24
+ "from PIL import Image, ImageColor\n",
25
+ "import matplotlib.pyplot as plt\n",
26
+ "\n",
27
+ "import numpy as np\n",
28
+ "import torch\n",
29
+ "\n",
30
+ "from backend.disentangle_concepts import *\n",
31
+ "from backend.color_annotations import *\n",
32
+ "from backend.networks_stylegan3 import *\n",
33
+ "import dnnlib \n",
34
+ "import legacy\n",
35
+ "\n",
36
+ "import random\n",
37
+ "\n",
38
+ "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
39
+ "\n",
40
+ "\n",
41
+ "%load_ext autoreload\n",
42
+ "%autoreload 2"
43
+ ]
44
+ },
45
+ {
46
+ "cell_type": "code",
47
+ "execution_count": null,
48
+ "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
49
+ "metadata": {},
50
+ "outputs": [],
51
+ "source": [
52
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000.pkl'\n",
53
+ "with open(annotations_file, 'rb') as f:\n",
54
+ " annotations = pickle.load(f)\n",
55
+ "\n",
56
+ "ann_df = pd.read_csv('./data/textile_annotated_files/top_three_colours.csv').fillna('#000000')\n",
57
+ "\n",
58
+ "with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:\n",
59
+ " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "0e4b656e",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "z = torch.from_numpy(annotations['w_vectors'][0].copy()).to('cpu')\n",
70
+ "W = z.expand((16, -1)).unsqueeze(0)\n",
71
+ "img = model.synthesis(W, noise_mode='const')\n",
72
+ "img.shape"
73
+ ]
74
+ },
75
+ {
76
+ "cell_type": "code",
77
+ "execution_count": null,
78
+ "id": "1259f950",
79
+ "metadata": {},
80
+ "outputs": [],
81
+ "source": [
82
+ "in_ = model.synthesis.input(W[0, 0].unsqueeze(0))\n",
83
+ "l1 = model.synthesis.L0_36_512(in_, W[0, 1].unsqueeze(0))\n",
84
+ "l1.shape"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "id": "918feb0e",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "a = 'L0_36_512'\n",
95
+ "getattr(model.synthesis, a)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "execution_count": null,
101
+ "id": "1bf7bfa4",
102
+ "metadata": {},
103
+ "outputs": [],
104
+ "source": [
105
+ "def rest_from_style(x, styles, layer):\n",
106
+ " dtype = torch.float16 if (getattr(model.synthesis, layer).use_fp16 and device=='cuda') else torch.float32\n",
107
+ " if getattr(model.synthesis, layer).is_torgb:\n",
108
+ " print(layer, getattr(model.synthesis, layer).is_torgb)\n",
109
+ " weight_gain = 1 / np.sqrt(getattr(model.synthesis, layer).in_channels * (getattr(model.synthesis, layer).conv_kernel ** 2))\n",
110
+ " styles = styles * weight_gain\n",
111
+ " input_gain = getattr(model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)\n",
112
+ " # Execute modulated conv2d.\n",
113
+ " x = modulated_conv2d(x=x.to(dtype), w=getattr(model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),\n",
114
+ " padding=getattr(model.synthesis, layer).conv_kernel-1, demodulate=(not getattr(model.synthesis, layer).is_torgb), input_gain=input_gain.to(dtype))\n",
115
+ " # Execute bias, filtered leaky ReLU, and clamping.\n",
116
+ " gain = 1 if getattr(model.synthesis, layer).is_torgb else np.sqrt(2)\n",
117
+ " slope = 1 if getattr(model.synthesis, layer).is_torgb else 0.2\n",
118
+ " x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(model.synthesis, layer).up_filter, fd=getattr(model.synthesis, layer).down_filter, \n",
119
+ " b=getattr(model.synthesis, layer).bias.to(x.dtype),\n",
120
+ " up=getattr(model.synthesis, layer).up_factor, down=getattr(model.synthesis, layer).down_factor, \n",
121
+ " padding=getattr(model.synthesis, layer).padding,\n",
122
+ " gain=gain, slope=slope, clamp=getattr(model.synthesis, layer).conv_clamp)\n",
123
+ " return x"
124
+ ]
125
+ },
126
+ {
127
+ "cell_type": "code",
128
+ "execution_count": null,
129
+ "id": "c674780d",
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "x1 = rest_from_style(in_, model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)), 'L0_36_512')\n",
134
+ "x1.shape"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": null,
140
+ "id": "0305ce16",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "def getS(w):\n",
145
+ " w_torch = torch.from_numpy(w).to('cpu')\n",
146
+ " W = w_torch.expand((16, -1)).unsqueeze(0)\n",
147
+ " s = []\n",
148
+ " s.append(model.synthesis.input.affine(W[0, 0].unsqueeze(0)).numpy())\n",
149
+ " s.append(model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)).numpy())\n",
150
+ " s.append(model.synthesis.L1_36_512.affine(W[0, 2].unsqueeze(0)).numpy())\n",
151
+ " s.append(model.synthesis.L2_36_512.affine(W[0, 3].unsqueeze(0)).numpy())\n",
152
+ " s.append(model.synthesis.L3_52_512.affine(W[0, 4].unsqueeze(0)).numpy())\n",
153
+ " s.append(model.synthesis.L4_52_512.affine(W[0, 5].unsqueeze(0)).numpy())\n",
154
+ " s.append(model.synthesis.L5_84_512.affine(W[0, 6].unsqueeze(0)).numpy())\n",
155
+ " s.append(model.synthesis.L6_84_512.affine(W[0, 7].unsqueeze(0)).numpy())\n",
156
+ " s.append(model.synthesis.L7_148_512.affine(W[0, 8].unsqueeze(0)).numpy())\n",
157
+ " s.append(model.synthesis.L8_148_512.affine(W[0, 9].unsqueeze(0)).numpy())\n",
158
+ " s.append(model.synthesis.L9_148_362.affine(W[0, 10].unsqueeze(0)).numpy())\n",
159
+ " s.append(model.synthesis.L10_276_256.affine(W[0, 11].unsqueeze(0)).numpy())\n",
160
+ " s.append(model.synthesis.L11_276_181.affine(W[0, 12].unsqueeze(0)).numpy())\n",
161
+ " s.append(model.synthesis.L12_276_128.affine(W[0, 13].unsqueeze(0)).numpy())\n",
162
+ " s.append(model.synthesis.L13_256_128.affine(W[0, 14].unsqueeze(0)).numpy())\n",
163
+ " s.append(model.synthesis.L14_256_3.affine(W[0, 15].unsqueeze(0)).numpy())\n",
164
+ " return s"
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "id": "981f5215",
171
+ "metadata": {},
172
+ "outputs": [],
173
+ "source": [
174
+ "s = getS(annotations['w_vectors'][0])"
175
+ ]
176
+ },
177
+ {
178
+ "cell_type": "code",
179
+ "execution_count": null,
180
+ "id": "389ad35a",
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "shapes = [512] + [x.shape[1] for x in s]\n",
185
+ "layers = ['w', 'input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512', 'L4_52_512', 'L5_84_512', 'L6_84_512',\n",
186
+ " 'L7_148_512', 'L8_148_512', 'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128', 'L13_256_128',\n",
187
+ " 'L14_256_3']\n",
188
+ "sum(shapes), shapes"
189
+ ]
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "execution_count": null,
194
+ "id": "3c143e86",
195
+ "metadata": {},
196
+ "outputs": [],
197
+ "source": [
198
+ "def generate_flexible_images(w, change_vectors, lambdas=1, device='cpu'):\n",
199
+ " w_torch = torch.from_numpy(w).to('cpu')\n",
200
+ " # w_torch = w_torch + lambdas * change_vectors[0]\n",
201
+ " W = w_torch.expand((16, -1)).unsqueeze(0)\n",
202
+ " \n",
203
+ " x = model.synthesis.input(W[0,0].unsqueeze(0))\n",
204
+ " for i, layer in enumerate(layers):\n",
205
+ " if i < 2:\n",
206
+ " continue\n",
207
+ " style = getattr(model.synthesis, layer).affine(W[0, i-1].unsqueeze(0))\n",
208
+ " change = torch.from_numpy(change_vectors[i].copy()).unsqueeze(0).to(device)\n",
209
+ " style = torch.add(style, change, alpha=lambdas)\n",
210
+ " x = rest_from_style(x, style, layer)\n",
211
+ " \n",
212
+ " if model.synthesis.output_scale != 1:\n",
213
+ " x = x * model.synthesis.output_scale\n",
214
+ "\n",
215
+ " img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)\n",
216
+ " img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')\n",
217
+ " \n",
218
+ " return img"
219
+ ]
220
+ },
221
+ {
222
+ "cell_type": "code",
223
+ "execution_count": null,
224
+ "id": "f03915ff",
225
+ "metadata": {},
226
+ "outputs": [],
227
+ "source": [
228
+ "def get_original_pos(top_positions):\n",
229
+ " current_idx = 0\n",
230
+ " vectors = []\n",
231
+ " for i, (leng, layer) in enumerate(zip(shapes, layers)):\n",
232
+ " arr = np.zeros(leng)\n",
233
+ " for top_position in top_positions:\n",
234
+ " if top_position >= current_idx and top_position < current_idx + leng:\n",
235
+ " arr[top_position - current_idx] = 1\n",
236
+ " arr = arr / (np.linalg.norm(arr) + 0.000001)\n",
237
+ " vectors.append(arr)\n",
238
+ " current_idx += leng\n",
239
+ " return vectors \n"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": null,
245
+ "id": "e76d836d",
246
+ "metadata": {},
247
+ "outputs": [],
248
+ "source": [
249
+ "ss = []\n",
250
+ "for i in tqdm(range(len(annotations['w_vectors']))):\n",
251
+ " ss.append(getS(annotations['w_vectors'][i]))\n",
252
+ " \n",
253
+ "annotations['s_vectors'] = ss"
254
+ ]
255
+ },
256
+ {
257
+ "cell_type": "code",
258
+ "execution_count": null,
259
+ "id": "6ea1ca59",
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'\n",
264
+ "with open(annotations_file, 'wb') as f:\n",
265
+ " pickle.dump(annotations, f)\n"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "id": "12f78bdb",
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": [
275
+ "len(ss)"
276
+ ]
277
+ },
278
+ {
279
+ "cell_type": "code",
280
+ "execution_count": null,
281
+ "id": "cd114cb1",
282
+ "metadata": {},
283
+ "outputs": [],
284
+ "source": [
285
+ "ann_df = tohsv(ann_df)\n",
286
+ "ann_df.head()"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "code",
291
+ "execution_count": null,
292
+ "id": "0d470f83",
293
+ "metadata": {},
294
+ "outputs": [],
295
+ "source": [
296
+ "def getX(annotations, space='s'):\n",
297
+ " if space == 'x':\n",
298
+ " X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))\n",
299
+ " elif space == 's':\n",
300
+ " concat_v = []\n",
301
+ " for i in range(len(annotations['w_vectors'])):\n",
302
+ " concat_v.append(np.concatenate([annotations['w_vectors'][i]] + annotations['s_vectors'][i], axis=1))\n",
303
+ " \n",
304
+ " X = np.array(concat_v)\n",
305
+ " X = X[:, 0, :]\n",
306
+ " print(X.shape)\n",
307
+ " \n",
308
+ " return X\n",
309
+ " \n",
310
+ " "
311
+ ]
312
+ },
313
+ {
314
+ "cell_type": "code",
315
+ "execution_count": null,
316
+ "id": "feb64168",
317
+ "metadata": {},
318
+ "outputs": [],
319
+ "source": [
320
+ "X = getX(annotations)\n",
321
+ "print(X.shape)\n",
322
+ "y_h = np.array(ann_df['H1'].values)\n",
323
+ "y_s = np.array(ann_df['S1'].values)\n",
324
+ "y_v = np.array(ann_df['S1'].values)"
325
+ ]
326
+ },
327
+ {
328
+ "cell_type": "code",
329
+ "execution_count": null,
330
+ "id": "afa0c100",
331
+ "metadata": {},
332
+ "outputs": [],
333
+ "source": [
334
+ "colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',\n",
335
+ " 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
336
+ " 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "markdown",
341
+ "id": "39a5668a",
342
+ "metadata": {},
343
+ "source": [
344
+ "double check colori"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "execution_count": null,
350
+ "id": "5f2b48c0",
351
+ "metadata": {},
352
+ "outputs": [],
353
+ "source": [
354
+ "print([int(x*256/12) if x<12 else 255 for x in range(13)])\n",
355
+ "y_h_cat = pd.cut(y_h,bins=[x*256/12 if x<12 else 256 for x in range(13)],labels=colors_list).fillna('Warm Pink Red')\n",
356
+ "\n",
357
+ "print(y_h_cat.value_counts(dropna=False))\n",
358
+ "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
359
+ ]
360
+ },
361
+ {
362
+ "cell_type": "markdown",
363
+ "id": "6c2a1765",
364
+ "metadata": {},
365
+ "source": [
366
+ "### Variance based"
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "code",
371
+ "execution_count": null,
372
+ "id": "2be4202e",
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "positives = x_trainhc[np.where(y_trainhc == 'Warm Blue')]\n",
377
+ "print(positives.shape, x_trainhc.shape)\n",
378
+ "variations = detect_attribute_specific_channels(positives, x_trainhc, sign=True)\n",
379
+ "print(variations.shape, np.argmax(variations))"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": null,
385
+ "id": "7d0c129d",
386
+ "metadata": {},
387
+ "outputs": [],
388
+ "source": [
389
+ "argsorted_vars = np.argsort(variations)[-5:]\n",
390
+ "sorted_vars = np.sort(variations)[-5:]\n",
391
+ "argsorted_vars, sorted_vars"
392
+ ]
393
+ },
394
+ {
395
+ "cell_type": "code",
396
+ "execution_count": null,
397
+ "id": "e2c2ed49",
398
+ "metadata": {},
399
+ "outputs": [],
400
+ "source": [
401
+ "original_pos = get_original_pos(argsorted_vars)"
402
+ ]
403
+ },
404
+ {
405
+ "cell_type": "code",
406
+ "execution_count": null,
407
+ "id": "82e30f0c",
408
+ "metadata": {},
409
+ "outputs": [],
410
+ "source": [
411
+ "seed = random.randint(0,100000)\n",
412
+ "seed = 52722\n",
413
+ "original_image_vec = annotations['w_vectors'][seed]\n",
414
+ "img = generate_original_image(original_image_vec, model, latent_space='W')\n",
415
+ "img"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "code",
420
+ "execution_count": null,
421
+ "id": "cd71f2c8",
422
+ "metadata": {},
423
+ "outputs": [],
424
+ "source": [
425
+ "device = 'cpu'\n",
426
+ "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=-1)\n",
427
+ "img1"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": null,
433
+ "id": "abc5ac3f",
434
+ "metadata": {},
435
+ "outputs": [],
436
+ "source": [
437
+ "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=-2)\n",
438
+ "img1"
439
+ ]
440
+ },
441
+ {
442
+ "cell_type": "code",
443
+ "execution_count": null,
444
+ "id": "0602dcab",
445
+ "metadata": {},
446
+ "outputs": [],
447
+ "source": [
448
+ "len(original_pos)"
449
+ ]
450
+ },
451
+ {
452
+ "cell_type": "code",
453
+ "execution_count": null,
454
+ "id": "d7eb412f",
455
+ "metadata": {},
456
+ "outputs": [],
457
+ "source": [
458
+ "img1 = generate_flexible_images(original_image_vec, original_pos, lambdas=1)\n",
459
+ "img1"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": null,
465
+ "id": "03161270",
466
+ "metadata": {},
467
+ "outputs": [],
468
+ "source": [
469
+ "seps, vals = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True, space='s')\n",
470
+ "vals[2].shape"
471
+ ]
472
+ },
473
+ {
474
+ "cell_type": "code",
475
+ "execution_count": null,
476
+ "id": "ae1016d6",
477
+ "metadata": {},
478
+ "outputs": [],
479
+ "source": [
480
+ "warm_pink_val = get_verification_score(0, seps[0], model, annotations, samples=10, latent_space='W')\n",
481
+ "warm_pink_val"
482
+ ]
483
+ },
484
+ {
485
+ "cell_type": "code",
486
+ "execution_count": null,
487
+ "id": "b412cb25",
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "warm_blue_val = get_verification_score(8, seps[8], model, annotations, samples=10, latent_space='W')\n",
492
+ "warm_blue_val"
493
+ ]
494
+ },
495
+ {
496
+ "cell_type": "code",
497
+ "execution_count": null,
498
+ "id": "6812cb6b",
499
+ "metadata": {},
500
+ "outputs": [],
501
+ "source": [
502
+ "seps, _ = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True)\n",
503
+ "\n",
504
+ "for sep, color in zip(seps, colors_list):\n",
505
+ " images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-(int(4)), max_epsilon=int(4), count=5, latent_space='W')\n",
506
+ " fig, axs = plt.subplots(1, len(images), figsize=(50,10))\n",
507
+ " fig.suptitle(color, fontsize=20)\n",
508
+ " for i,im in enumerate(images):\n",
509
+ " axs[i].imshow(im)\n",
510
+ " axs[i].set_title(np.round(lambdas[i], 2))\n",
511
+ " plt.show()"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "code",
516
+ "execution_count": null,
517
+ "id": "24b8f275",
518
+ "metadata": {},
519
+ "outputs": [],
520
+ "source": [
521
+ "seps, _ = all_variance_based_disentanglements(colors_list, x_trainhc, y_trainhc, k=10, sign=True)\n",
522
+ "\n",
523
+ "for sep, color in zip(seps, colors_list):\n",
524
+ " images, lambdas = regenerate_images(model, original_image_vec, sep, min_epsilon=-(int(4)), max_epsilon=int(4), count=5, latent_space='W')\n",
525
+ " fig, axs = plt.subplots(1, len(images), figsize=(50,10))\n",
526
+ " fig.suptitle(color, fontsize=20)\n",
527
+ " for i,im in enumerate(images):\n",
528
+ " axs[i].imshow(im)\n",
529
+ " axs[i].set_title(np.round(lambdas[i], 2))\n",
530
+ " plt.show()"
531
+ ]
532
+ },
533
+ {
534
+ "cell_type": "code",
535
+ "execution_count": null,
536
+ "id": "b6c61fbb",
537
+ "metadata": {},
538
+ "outputs": [],
539
+ "source": [
540
+ "separation_vector_onehot = np.zeros(512)\n",
541
+ "separation_vector_onehot[argsorted_vars] = 1\n",
542
+ "\n",
543
+ "images, lambdas = regenerate_images(model, original_image_vec, separation_vector_onehot, min_epsilon=-(int(10)), max_epsilon=int(10), count=7, latent_space='W')\n",
544
+ "fig, axs = plt.subplots(1, len(images), figsize=(30,200))\n",
545
+ "for i,im in enumerate(images):\n",
546
+ " axs[i].imshow(im)\n",
547
+ " axs[i].set_title(np.round(lambdas[i], 2))"
548
+ ]
549
+ },
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": null,
553
+ "id": "7c19e820",
554
+ "metadata": {},
555
+ "outputs": [],
556
+ "source": []
557
+ }
558
+ ],
559
+ "metadata": {
560
+ "kernelspec": {
561
+ "display_name": "Python 3",
562
+ "language": "python",
563
+ "name": "python3"
564
+ },
565
+ "language_info": {
566
+ "codemirror_mode": {
567
+ "name": "ipython",
568
+ "version": 3
569
+ },
570
+ "file_extension": ".py",
571
+ "mimetype": "text/x-python",
572
+ "name": "python",
573
+ "nbconvert_exporter": "python",
574
+ "pygments_lexer": "ipython3",
575
+ "version": "3.8.16"
576
+ }
577
+ },
578
+ "nbformat": 4,
579
+ "nbformat_minor": 5
580
+ }
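stylespace_colour_disentanglement.ipynb concatenates the per-layer affine ("style") outputs into one long S vector, picks the channels whose variance best separates a colour class, and then maps those flattened indices back to (layer, channel) positions in get_original_pos. A minimal sketch of that index-to-layer mapping (assuming the `shapes` and `layers` lists defined in the notebook; illustrative only):

def locate_channel(flat_index, shapes, layers):
    """Map an index into the concatenated [w | per-layer styles] vector back to its layer."""
    offset = 0
    for layer, size in zip(layers, shapes):
        if offset <= flat_index < offset + size:
            return layer, flat_index - offset
        offset += size
    raise IndexError(f'index {flat_index} is outside the {sum(shapes)}-dim concatenated S space')

# usage: layer_name, channel = locate_channel(int(argsorted_vars[-1]), shapes, layers)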
view_predictions.ipynb CHANGED
@@ -2,31 +2,10 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
- "execution_count": 1,
6
  "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
7
  "metadata": {},
8
- "outputs": [
9
- {
10
- "name": "stdout",
11
- "output_type": "stream",
12
- "text": [
13
- "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n",
14
- "Intel MKL WARNING: Support of Intel(R) Streaming SIMD Extensions 4.2 (Intel(R) SSE4.2) enabled only processors has been deprecated. Intel oneAPI Math Kernel Library 2025.0 will require Intel(R) Advanced Vector Extensions (Intel(R) AVX) instructions.\n"
15
- ]
16
- },
17
- {
18
- "ename": "ModuleNotFoundError",
19
- "evalue": "No module named 'sklearn'",
20
- "output_type": "error",
21
- "traceback": [
22
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
23
- "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
24
- "Cell \u001b[0;32mIn[1], line 8\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mtorch\u001b[39;00m\n\u001b[1;32m 7\u001b[0m \u001b[39m#import ninja\u001b[39;00m\n\u001b[0;32m----> 8\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mbackend\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mdisentangle_concepts\u001b[39;00m \u001b[39mimport\u001b[39;00m \u001b[39m*\u001b[39m\n\u001b[1;32m 9\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mdnnlib\u001b[39;00m \n\u001b[1;32m 10\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mlegacy\u001b[39;00m\n",
25
- "File \u001b[0;32m~/Desktop/latent-space-theories/backend/disentangle_concepts.py:2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mnumpy\u001b[39;00m \u001b[39mas\u001b[39;00m \u001b[39mnp\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39msvm\u001b[39;00m \u001b[39mimport\u001b[39;00m SVC\n\u001b[1;32m 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlinear_model\u001b[39;00m \u001b[39mimport\u001b[39;00m LogisticRegression\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39msklearn\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmodel_selection\u001b[39;00m \u001b[39mimport\u001b[39;00m train_test_split\n",
26
- "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'sklearn'"
27
- ]
28
- }
29
- ],
30
  "source": [
31
  "import pandas as pd\n",
32
  "import pickle\n",
@@ -171,19 +150,10 @@
171
  },
172
  {
173
  "cell_type": "code",
174
- "execution_count": 1,
175
  "id": "0eae840f",
176
  "metadata": {},
177
- "outputs": [
178
- {
179
- "name": "stderr",
180
- "output_type": "stream",
181
- "text": [
182
- "/Users/ludovicaschaerf/anaconda3/envs/torch_arm/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
183
- " from .autonotebook import tqdm as notebook_tqdm\n"
184
- ]
185
- }
186
- ],
187
  "source": [
188
  "import open_clip\n",
189
  "import os\n",
@@ -194,7 +164,7 @@
194
  },
195
  {
196
  "cell_type": "code",
197
- "execution_count": 2,
198
  "id": "4d776015",
199
  "metadata": {},
200
  "outputs": [],
@@ -205,33 +175,10 @@
205
  },
206
  {
207
  "cell_type": "code",
208
- "execution_count": 3,
209
  "id": "e3f917a7",
210
  "metadata": {},
211
- "outputs": [
212
- {
213
- "name": "stdout",
214
- "output_type": "stream",
215
- "text": [
216
- "['Provenance ADRIA', 'Provenance AEGINA', 'Provenance AL MINA', 'Provenance ALICANTE', 'Provenance AMATHUS', 'Provenance AMPURIAS', 'Provenance APOLLONIA PONTICA', 'Provenance APULIA', 'Provenance ARGOLIS', 'Provenance ARGOS', 'Provenance ATHENS', 'Provenance ATHENS (?)', 'Provenance ATTICA', 'Provenance BARI', 'Provenance BENGHAZI', 'Provenance BEREZAN', 'Provenance BLACK SEA', 'Provenance BOEOTIA', 'Provenance BOLOGNA', 'Provenance CAPUA', 'Provenance CARIA', 'Provenance CARTHAGE', 'Provenance CHIUSI', 'Provenance CIVITAVECCHIA', 'Provenance CORINTH', 'Provenance CORSICA', 'Provenance CRETE', 'Provenance CRIMEA', 'Provenance CUMAE', 'Provenance CYCLADES', 'Provenance CYPRUS', 'Provenance CYRENAICA', 'Provenance CYRENE', 'Provenance EGYPT', 'Provenance ELIS', 'Provenance ENSERUNE', 'Provenance ETRURIA', 'Provenance EUBOEA', 'Provenance FALERII', 'Provenance GRANADA', 'Provenance GREECE', 'Provenance GREECE (?)', 'Provenance HISTRIA', 'Provenance ITALY', 'Provenance JAEN', 'Provenance KITION', 'Provenance LESBOS', 'Provenance LOCRI', 'Provenance LOCRIS', 'Provenance LYCIA', 'Provenance MACEDONIA', 'Provenance MARION', 'Provenance MARSEILLES', 'Provenance MARZABOTTO', 'Provenance METAPONTUM', 'Provenance MILETUS', 'Provenance NAPLES', 'Provenance NAUCRATIS', 'Provenance NOLA', 'Provenance NUMANA', 'Provenance OLD SMYRNA', 'Provenance ORVIETO', 'Provenance PAESTUM', 'Provenance PHOCIS', 'Provenance PITANE', 'Provenance POPULONIA', 'Provenance RHODES', 'Provenance ROME', 'Provenance RUVO', 'Provenance SALERNO', 'Provenance SAMOS', 'Provenance SARDINIA', 'Provenance SARDIS', 'Provenance SICILY', 'Provenance SMYRNA', 'Provenance SOUTH', 'Provenance SPARTA', 'Provenance SPINA', 'Provenance SUESSULA', 'Provenance TARANTO', 'Provenance TELL DEFENNEH', 'Provenance THASOS', 'Provenance THERA', 'Provenance THESSALY', 'Provenance THRACE', 'Provenance TOCRA', 'Provenance TODI', 'Provenance ULLASTRET', 'Provenance VALENCIA', 'Shape Name ALABASTRON', 'Shape Name ALABASTRON FRAGMENT', 'Shape Name AMPHORA', 'Shape Name AMPHORA (?) FRAGMENT', 'Shape Name AMPHORISKOS', 'Shape Name ARYBALLOS', 'Shape Name ASKOS', 'Shape Name ASKOS FRAGMENT', 'Shape Name BOTTLE', 'Shape Name BOWL', 'Shape Name BOWL FRAGMENT', 'Shape Name CHALICE', 'Shape Name CHALICE FRAGMENT', 'Shape Name CHOUS', 'Shape Name CHOUS FRAGMENT', 'Shape Name CUP', 'Shape Name CUP (?) FRAGMENT', 'Shape Name CUP A', 'Shape Name CUP A FRAGMENT', 'Shape Name CUP A FRAGMENTS', 'Shape Name CUP B', 'Shape Name CUP B FRAGMENT', 'Shape Name CUP B FRAGMENTS', 'Shape Name CUP C', 'Shape Name CUP C FRAGMENT', 'Shape Name CUP C FRAGMENTS', 'Shape Name CUP DROOP', 'Shape Name CUP DROOP FRAGMENT', 'Shape Name CUP FRAGMENT', 'Shape Name CUP FRAGMENTS', 'Shape Name CUP KASSEL', 'Shape Name CUP KASSEL FRAGMENT', 'Shape Name CUP LITTLE MASTER BAND', 'Shape Name CUP LITTLE MASTER BAND (?) 
FRAGMENT', 'Shape Name CUP LITTLE MASTER BAND FRAGMENT', 'Shape Name CUP LITTLE MASTER BAND FRAGMENTS', 'Shape Name CUP LITTLE MASTER FRAGMENT', 'Shape Name CUP LITTLE MASTER LIP', 'Shape Name CUP LITTLE MASTER LIP FRAGMENT', 'Shape Name CUP LITTLE MASTER LIP FRAGMENTS', 'Shape Name CUP SIANA', 'Shape Name CUP SIANA FRAGMENT', 'Shape Name CUP SIANA FRAGMENTS', 'Shape Name CUP SKYPHOS', 'Shape Name CUP SKYPHOS FRAGMENT', 'Shape Name CUP STEMLESS', 'Shape Name CUP STEMLESS FRAGMENT', 'Shape Name CUP STEMLESS FRAGMENTS', 'Shape Name DINOS', 'Shape Name DINOS FRAGMENT', 'Shape Name DISH', 'Shape Name EPICHYSIS', 'Shape Name EPINETRON FRAGMENT', 'Shape Name FEEDER', 'Shape Name FIGURE VASE', 'Shape Name FIGURE VASE ARYBALLOS', 'Shape Name FIGURE VASE ASKOS', 'Shape Name FIGURE VASE FRAGMENT', 'Shape Name FIGURE VASE KANTHAROS', 'Shape Name FIGURE VASE LEKYTHOS', 'Shape Name FIGURE VASE OINOCHOE', 'Shape Name FISH-PLATE', 'Shape Name FLASK', 'Shape Name FRAGMENT', 'Shape Name FRAGMENTS', 'Shape Name GUTTUS', 'Shape Name HYDRIA', 'Shape Name HYDRIA (?) FRAGMENT', 'Shape Name HYDRIA FRAGMENT', 'Shape Name HYDRIA FRAGMENTS', 'Shape Name INCENSE BURNER', 'Shape Name JAR', 'Shape Name JUG', 'Shape Name KALATHOS', 'Shape Name KANTHAROS', 'Shape Name KANTHAROS FRAGMENT', 'Shape Name KERNOS', 'Shape Name KOTYLE', 'Shape Name KOTYLE FRAGMENT', 'Shape Name KRATER', 'Shape Name KRATER (?) FRAGMENT', 'Shape Name KRATER FRAGMENT', 'Shape Name KRATER FRAGMENTS', 'Shape Name KYATHOS', 'Shape Name KYATHOS FRAGMENT', 'Shape Name LEBES', 'Shape Name LEBES FRAGMENT', 'Shape Name LEKANIS', 'Shape Name LEKANIS FRAGMENT', 'Shape Name LEKANIS FRAGMENTS', 'Shape Name LEKANIS LID', 'Shape Name LEKANIS LID FRAGMENT', 'Shape Name LEKYTHOS', 'Shape Name LEKYTHOS FRAGMENT', 'Shape Name LEKYTHOS FRAGMENTS', 'Shape Name LID', 'Shape Name LID FRAGMENT', 'Shape Name LOUTROPHOROS', 'Shape Name LOUTROPHOROS FRAGMENT', 'Shape Name LOUTROPHOROS FRAGMENTS', 'Shape Name LYDION', 'Shape Name MASTOID', 'Shape Name MINIATURE PANATHENAIC AMPHORA', 'Shape Name MUG', 'Shape Name MUG FRAGMENT', 'Shape Name NESTORIS', 'Shape Name OINOCHOE', 'Shape Name OINOCHOE (?) FRAGMENT', 'Shape Name OINOCHOE FRAGMENT', 'Shape Name OINOCHOE FRAGMENTS', 'Shape Name OLLA', 'Shape Name OLPE', 'Shape Name OLPE FRAGMENT', 'Shape Name OLPE FRAGMENTS', 'Shape Name PELIKE', 'Shape Name PELIKE FRAGMENT', 'Shape Name PELIKE FRAGMENTS', 'Shape Name PHIALE', 'Shape Name PHIALE FRAGMENT', 'Shape Name PITCHER', 'Shape Name PLAQUE FRAGMENT', 'Shape Name PLAQUE FRAGMENTS', 'Shape Name PLATE', 'Shape Name PLATE FRAGMENT', 'Shape Name PLATE FRAGMENTS', 'Shape Name PLEMOCHOE', 'Shape Name PSEUDO-PANATHENAIC AMPHORA', 'Shape Name PSEUDO-PANATHENAIC AMPHORA FRAGMENT', 'Shape Name PSYKTER', 'Shape Name PYXIS', 'Shape Name PYXIS FRAGMENT', 'Shape Name PYXIS FRAGMENTS', 'Shape Name PYXIS LID', 'Shape Name PYXIS LID FRAGMENT', 'Shape Name RHYTON', 'Shape Name SKYPHOS', 'Shape Name SKYPHOS (?) 
FRAGMENT', 'Shape Name SKYPHOS FRAGMENT', 'Shape Name SKYPHOS FRAGMENTS', 'Shape Name STAMNOS', 'Shape Name STAMNOS FRAGMENT', 'Shape Name STAMNOS FRAGMENTS', 'Shape Name STAND', 'Shape Name STAND FRAGMENT', 'Shape Name STEMLESS CUP FRAGMENT', 'Shape Name STIRRUP JAR', 'Shape Name TANKARD', 'Shape Name UNKNOWN', 'Shape Name URN', 'Shape Name VARIOUS', 'Shape Name VASE', 'Fabric ARGIVE GEOMETRIC', 'Fabric ARRETINE', 'Fabric ATHENIAN', 'Fabric ATHENIAN (?)', 'Fabric ATHENIAN GEOMETRIC', 'Fabric ATHENIAN PROTOGEOMETRIC', 'Fabric BOEOTIAN', 'Fabric BOEOTIAN GEOMETRIC', 'Fabric CALENIAN', 'Fabric CAMPANA', 'Fabric CANOSAN', 'Fabric CHALCIDIAN', 'Fabric CORINTHIAN', 'Fabric CRETAN', 'Fabric CYCLADIC', 'Fabric CYPRIOT', 'Fabric CYPRIOT, BRONZE AGE', 'Fabric CYPRIOT, IRON AGE', 'Fabric CYPRIOT, MYCENAEAN STYLE', 'Fabric DAUNIAN', 'Fabric EAST GREEK', 'Fabric EAST GREEK GEOMETRIC', 'Fabric EAST GREEK, CLAZOMENIAN', 'Fabric EAST GREEK, FIKELLURA', 'Fabric EAST GREEK, NAUCRATITE', 'Fabric EGYPTIAN', 'Fabric ETRUSCAN', 'Fabric ETRUSCO-CORINTHIAN', 'Fabric FALISCAN', 'Fabric GALLIC', 'Fabric GALLO-ROMAN', 'Fabric GREEK', 'Fabric HELLADIC', 'Fabric HELLENISTIC', 'Fabric IBERIAN', 'Fabric IONIAN', 'Fabric ITALIOTE', 'Fabric ITALO-CORINTHIAN', 'Fabric ITALO-GEOMETRIC', 'Fabric LACONIAN', 'Fabric LACONIAN GEOMETRIC', 'Fabric MESSAPIAN', 'Fabric MINOAN', 'Fabric MYCENEAN', 'Fabric PONTIC', 'Fabric PROTO-ELAMITE', 'Fabric PROTOATTIC', 'Fabric PROTOCORINTHIAN', 'Fabric PROTOCORINTHIAN, TRANSITIONAL', 'Fabric ROMAN', 'Fabric SOUTH ITALIAN', 'Fabric SOUTH ITALIAN, APULIAN', 'Fabric SOUTH ITALIAN, CAMPANIAN', 'Fabric SOUTH ITALIAN, GNATHIAN', 'Fabric SOUTH ITALIAN, LUCANIAN', 'Fabric SOUTH ITALIAN, PAESTAN', 'Fabric SOUTH ITALIAN, SICILIAN', 'Fabric TERRA SIGILLATA', 'Fabric UNCERTAIN', 'Fabric VEIAN', 'Fabric VILLA NOVA', 'Technique ADDED COLOUR', 'Technique BLACK GLAZE', 'Technique BLACK PATTERN', 'Technique BLACK-FIGURE', 'Technique BROWN GLAZE', 'Technique BUCCHERO', 'Technique IMPASTO', 'Technique OUTLINE', 'Technique PATTERN', 'Technique PLAIN', 'Technique PSEUDO RED-FIGURE', 'Technique RED POLISHED WARE', 'Technique RED-FIGURE', 'Technique RELIEF', 'Technique RESERVING', 'Technique SILHOUETTE', 'Technique WHITE PAINTED WARE']\n"
217
- ]
218
- },
219
- {
220
- "name": "stderr",
221
- "output_type": "stream",
222
- "text": [
223
- "/Users/ludovicaschaerf/anaconda3/envs/torch_arm/lib/python3.11/site-packages/torch/amp/autocast_mode.py:221: UserWarning: User provided device_type of 'cuda', but CUDA is not available. Disabling\n",
224
- " warnings.warn('User provided device_type of \\'cuda\\', but CUDA is not available. Disabling')\n"
225
- ]
226
- },
227
- {
228
- "name": "stdout",
229
- "output_type": "stream",
230
- "text": [
231
- "(318, 768)\n"
232
- ]
233
- }
234
- ],
235
  "source": [
236
  "model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
237
  "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
@@ -257,7 +204,7 @@
257
  },
258
  {
259
  "cell_type": "code",
260
- "execution_count": 4,
261
  "id": "f7858bbf",
262
  "metadata": {},
263
  "outputs": [],
@@ -267,7 +214,7 @@
267
  },
268
  {
269
  "cell_type": "code",
270
- "execution_count": 5,
271
  "id": "de6bd428",
272
  "metadata": {},
273
  "outputs": [],
@@ -277,7 +224,7 @@
277
  },
278
  {
279
  "cell_type": "code",
280
- "execution_count": 6,
281
  "id": "d19c8e4c",
282
  "metadata": {},
283
  "outputs": [],
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": null,
6
  "id": "fe7acfaf-dc61-4211-9c78-8e4433bc9deb",
7
  "metadata": {},
8
+ "outputs": [],
 
9
  "source": [
10
  "import pandas as pd\n",
11
  "import pickle\n",
 
150
  },
151
  {
152
  "cell_type": "code",
153
+ "execution_count": null,
154
  "id": "0eae840f",
155
  "metadata": {},
156
+ "outputs": [],
 
157
  "source": [
158
  "import open_clip\n",
159
  "import os\n",
 
164
  },
165
  {
166
  "cell_type": "code",
167
+ "execution_count": null,
168
  "id": "4d776015",
169
  "metadata": {},
170
  "outputs": [],
 
175
  },
176
  {
177
  "cell_type": "code",
178
+ "execution_count": null,
179
  "id": "e3f917a7",
180
  "metadata": {},
181
+ "outputs": [],
182
  "source": [
183
  "model_clip, _, preprocess = open_clip.create_model_and_transforms('ViT-L-14', pretrained='laion2b_s32b_b82k')\n",
184
  "tokenizer = open_clip.get_tokenizer('ViT-L-14')\n",
 
204
  },
205
  {
206
  "cell_type": "code",
207
+ "execution_count": null,
208
  "id": "f7858bbf",
209
  "metadata": {},
210
  "outputs": [],
 
214
  },
215
  {
216
  "cell_type": "code",
217
+ "execution_count": null,
218
  "id": "de6bd428",
219
  "metadata": {},
220
  "outputs": [],
 
224
  },
225
  {
226
  "cell_type": "code",
227
+ "execution_count": null,
228
  "id": "d19c8e4c",
229
  "metadata": {},
230
  "outputs": [],
view_segmentations.ipynb ADDED
@@ -0,0 +1,77 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from PIL import Image\n",
10
+ "from lang_sam import LangSAM"
11
+ ]
12
+ },
13
+ {
14
+ "cell_type": "code",
15
+ "execution_count": null,
16
+ "metadata": {},
17
+ "outputs": [],
18
+ "source": [
19
+ "\n",
20
+ "def save_mask_as_image(mask, output_path):\n",
21
+ " # Create a blank image with the same dimensions as the mask\n",
22
+ " width, height = mask.shape[1], mask.shape[0]\n",
23
+ " image = Image.new(\"L\", (width, height))\n",
24
+ "\n",
25
+ " # Set the pixel values based on the mask\n",
26
+ " for y in range(height):\n",
27
+ " for x in range(width):\n",
28
+ " pixel_value = 255 if mask[y, x] == 1 else 0\n",
29
+ " image.putpixel((x, y), pixel_value)\n",
30
+ "\n",
31
+ " # Save the image as a PNG\n",
32
+ " image.save(output_path)\n",
33
+ "\n",
34
+ "\n",
35
+ "\n",
36
+ "model = LangSAM()\n",
37
+ "TEST_DIR = '/Users/ludovicaschaerf/Desktop/Data/VA_textiles_masks/'\n",
38
+ "OUT_DIR = '/Users/ludovicaschaerf/Desktop/Data/VA_textiles/'\n",
39
+ "image_pil = Image.open(TEST_DIR + \"O25495.jpg\").convert(\"RGB\")\n",
40
+ "text_prompt = \"textile\"\n",
41
+ "masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)\n"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "metadata": {},
48
+ "outputs": [],
49
+ "source": [
50
+ "\n",
51
+ "save_mask_as_image(masks[0], OUT_DIR + \"O25495.jpg\")"
52
+ ]
53
+ }
54
+ ],
55
+ "metadata": {
56
+ "kernelspec": {
57
+ "display_name": "art-reco_x86",
58
+ "language": "python",
59
+ "name": "python3"
60
+ },
61
+ "language_info": {
62
+ "codemirror_mode": {
63
+ "name": "ipython",
64
+ "version": 3
65
+ },
66
+ "file_extension": ".py",
67
+ "mimetype": "text/x-python",
68
+ "name": "python",
69
+ "nbconvert_exporter": "python",
70
+ "pygments_lexer": "ipython3",
71
+ "version": "3.8.16"
72
+ },
73
+ "orig_nbformat": 4
74
+ },
75
+ "nbformat": 4,
76
+ "nbformat_minor": 2
77
+ }
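save_mask_as_image in view_segmentations.ipynb sets pixels one at a time; the same conversion can be done in a single vectorised step. A sketch of the alternative (assuming the mask is a 2-D array or CPU tensor of 0/1 values, as LangSAM returns; not part of the notebook):

import numpy as np
from PIL import Image

def save_mask_as_image_fast(mask, output_path):
    # Convert the boolean/0-1 mask to an 8-bit grayscale image in one operation.
    arr = (np.asarray(mask) > 0).astype(np.uint8) * 255
    Image.fromarray(arr, mode='L').save(output_path)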