ludusc commited on
Commit
9a76dcd
·
2 Parent(s): a5df893 3826816

solved merge

Browse files
.gitignore CHANGED
@@ -189,4 +189,5 @@ cython_debug/
189
  data/images/
190
  tmp/
191
  figures/
192
- archive/
 
 
189
  data/images/
190
  tmp/
191
  figures/
192
+ archive/
193
+ segment-anything/
DisentanglementBase.py CHANGED
@@ -181,12 +181,20 @@ class DisentanglementBase:
181
  bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
182
  else 1 for x in range(len(self.colors_list) + 1)]
183
  bins[0] = 0
 
184
  y_cat = pd.cut(y,
185
  bins=bins,
186
  labels=self.colors_list,
187
  include_lowest=True
188
  )
189
  print(y_cat.value_counts())
 
 
 
 
 
 
 
190
  x_train, x_val, y_train, y_val = train_test_split(X, y_cat, test_size=0.2)
191
  else:
192
  if extremes:
@@ -567,11 +575,14 @@ def main():
567
  with dnnlib.util.open_url(model_file) as f:
568
  model = legacy.load_network_pkl(f)['G_ema'] # type: ignore
569
 
570
- colors_list = ['Red', 'Orange', 'Yellow', 'Yellow Green', 'Chartreuse Green',
571
- 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
572
- 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
573
- colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
574
- 'Blue', 'Violet', 'Pink']
 
 
 
575
 
576
  scores = []
577
  kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
 
181
  bins = [(x-1) * 360 / (len(self.colors_list) - 1) if x != 1
182
  else 1 for x in range(len(self.colors_list) + 1)]
183
  bins[0] = 0
184
+
185
  y_cat = pd.cut(y,
186
  bins=bins,
187
  labels=self.colors_list,
188
  include_lowest=True
189
  )
190
  print(y_cat.value_counts())
191
+
192
+ y_h_cat[y_s == 0] = 'Gray'
193
+ y_h_cat[y_s == 100] = 'Gray'
194
+ y_h_cat[y_v == 0] = 'Gray'
195
+ y_h_cat[y_v == 100] = 'Gray'
196
+
197
+ print(y_cat.value_counts())
198
  x_train, x_val, y_train, y_val = train_test_split(X, y_cat, test_size=0.2)
199
  else:
200
  if extremes:
 
575
  with dnnlib.util.open_url(model_file) as f:
576
  model = legacy.load_network_pkl(f)['G_ema'] # type: ignore
577
 
578
+
579
+ # colors_list = ['Red', 'Orange', 'Yellow', 'Yellow Green', 'Chartreuse Green',
580
+ # 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',
581
+ # 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']
582
+ # colors_list = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue',
583
+ # 'Blue', 'Purple', 'Pink']
584
+ colors_list = ['Gray', 'Red', 'Yellow', 'Green', 'Cyan',
585
+ 'Blue', 'Magenta']
586
 
587
  scores = []
588
  kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False],
Home.py CHANGED
@@ -6,24 +6,28 @@ st.set_page_config(layout='wide')
6
 
7
  st.title('About')
8
 
 
9
  # INTRO
10
- intro_text = """This project investigates the nature and nurture of latent spaces, with the aim of formulating a theory of this particular vectorial space. It draws together reflections on the inherent constraints of latent spaces in particular architectures and considers the learning-specific features that emerge.
 
11
  The thesis concentrates mostly on the second part, exploring different avenues for understanding the space. Using a multitude of vision generative models, it discusses possibilities for the systematic exploration of space, including disentanglement properties and coverage of various guidance methods.
12
  It also explores the possibility of comparison across latent spaces and investigates the differences and commonalities across different learning experiments. Furthermore, the thesis investigates the role of stochasticity in newer models.
13
  As a case study, this thesis adopts art historical data, spanning classic art, photography, and modern and contemporary art.
14
-
15
- The project aims to interpret the StyleGAN2 model by several techniques.
16
- > “What concepts are disentangled in the latent space of StyleGAN2”\n
17
- > “Can we quantify the complexity of such concepts?”.
18
-
19
  """
20
  st.write(intro_text)
 
 
 
 
 
 
21
 
22
  # 4 PAGES
23
  st.subheader('Pages')
24
- sections_text = """Overall, there are x features in this web app:
25
- 1) Disentanglement visualizer
26
- 2) Concept vectors comparison and aggregation
 
27
  ...
28
  """
29
  st.write(sections_text)
 
6
 
7
  st.title('About')
8
 
9
+ st.subheader('General aim of the Ph.D. (to be updated)')
10
  # INTRO
11
+ intro_text = """
12
+ This project investigates the nature and nurture of latent spaces, with the aim of formulating a theory of this particular vectorial space. It draws together reflections on the inherent constraints of latent spaces in particular architectures and considers the learning-specific features that emerge.
13
  The thesis concentrates mostly on the second part, exploring different avenues for understanding the space. Using a multitude of vision generative models, it discusses possibilities for the systematic exploration of space, including disentanglement properties and coverage of various guidance methods.
14
  It also explores the possibility of comparison across latent spaces and investigates the differences and commonalities across different learning experiments. Furthermore, the thesis investigates the role of stochasticity in newer models.
15
  As a case study, this thesis adopts art historical data, spanning classic art, photography, and modern and contemporary art.
 
 
 
 
 
16
  """
17
  st.write(intro_text)
18
+ st.subheader('On this experiment')
19
+ st.write(
20
+ """The project aims to interpret the StyleGAN3 model trained on Textiles using disentanglement methods.
21
+ > “What features are disentangled in the latent space of StyleGAN3”\n
22
+ > “Can we quantify the complexity, quality and relations of such features?”.
23
+ """)
24
 
25
  # 4 PAGES
26
  st.subheader('Pages')
27
+ sections_text = """Overall, there are 3 features in this web app:
28
+ 1) Textiles manipulation
29
+ 2) Features comparison
30
+ 3) Vectors algebra manipulation
31
  ...
32
  """
33
  st.write(sections_text)
README.md CHANGED
@@ -13,4 +13,5 @@ pinned: false
13
 
14
  To be change name: latent-space-theory
15
 
16
- This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
 
 
13
 
14
  To be change name: latent-space-theory
15
 
16
+ This app was built with Streamlit. To run the app, `streamlit run Home.py` in the terminal.
17
+ python -m streamlit run Home.py
backend/disentangle_concepts.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
 
8
 
9
 
10
- def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W'):
11
  """
12
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
13
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -33,9 +33,19 @@ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_spa
33
  repetitions = 16
34
  z_0 = z
35
 
36
- for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
37
- decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
38
- z_0 = z_0 + int(lmbd) * decision_boundary
 
 
 
 
 
 
 
 
 
 
39
 
40
 
41
  if latent_space == 'Z':
 
7
 
8
 
9
 
10
+ def generate_composite_images(model, z, decision_boundaries, lambdas, latent_space='W', negative_colors=None):
11
  """
12
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
13
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
 
33
  repetitions = 16
34
  z_0 = z
35
 
36
+ if negative_colors:
37
+ for decision_boundary, lmbd, neg_boundary in zip(decision_boundaries, lambdas, negative_colors):
38
+ decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
39
+ if neg_boundary != 'None':
40
+ neg_boundary = torch.from_numpy(neg_boundary.copy()).to(device)
41
+
42
+ z_0 = z_0 + int(lmbd) * (decision_boundary - (neg_boundary.T * decision_boundary) * neg_boundary)
43
+ else:
44
+ z_0 = z_0 + int(lmbd) * decision_boundary
45
+ else:
46
+ for decision_boundary, lmbd in zip(decision_boundaries, lambdas):
47
+ decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
48
+ z_0 = z_0 + int(lmbd) * decision_boundary
49
 
50
 
51
  if latent_space == 'Z':
config.toml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ [server]
2
+ enableCORS = false
3
+ headless = true
4
+
5
+ [browser]
6
+ serverAddress = 0.0.0.0
7
+ gatherUsageStats = false
8
+ serverPort = 8501
data/grammar_ornaments/1_colors_generally.png ADDED

Git LFS Details

  • SHA256: 8a486eb08408a776bf685267727e07012e03e9d58d2cd00cc52a17a0f4583deb
  • Pointer size: 131 Bytes
  • Size of remote file: 549 kB
data/grammar_ornaments/2_proportions_and_contrasts.png ADDED

Git LFS Details

  • SHA256: b70781936b5a79e4889de59c2012973fe72bcfa0576a6df15658f8c4dd3c8f29
  • Pointer size: 131 Bytes
  • Size of remote file: 608 kB
data/grammar_ornaments/3_positions_simultanous.png ADDED

Git LFS Details

  • SHA256: e9c7a4c0cf7c9e0de1e2c7a39793e518ef453fcb57ca52802c2c04e8b525c68a
  • Pointer size: 131 Bytes
  • Size of remote file: 582 kB
data/grammar_ornaments/4_juxtapositions.png ADDED

Git LFS Details

  • SHA256: 99ba0ce71ab96aa76dd924dd3d384112d7aa23ca858f4ec48474f5c3c67485cd
  • Pointer size: 131 Bytes
  • Size of remote file: 478 kB
grammar_ornament_test.ipynb ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os \n",
10
+ "from glob import glob \n",
11
+ "import pandas as pd\n",
12
+ "import numpy as np\n",
13
+ "\n",
14
+ "from PIL import Image, ImageColor\n",
15
+ "import extcolors\n",
16
+ "\n",
17
+ "import matplotlib.pyplot as plt\n",
18
+ "\n",
19
+ "import torch\n",
20
+ "\n",
21
+ "import dnnlib \n",
22
+ "import legacy\n",
23
+ "\n",
24
+ "\n",
25
+ "%load_ext autoreload\n",
26
+ "%autoreload 2"
27
+ ]
28
+ },
29
+ {
30
+ "cell_type": "code",
31
+ "execution_count": null,
32
+ "metadata": {},
33
+ "outputs": [],
34
+ "source": [
35
+ "images_textiles = glob('/Users/ludovicaschaerf/Desktop/TextAIles/TextileGAN/Original Textiles/*')"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "markdown",
40
+ "metadata": {},
41
+ "source": [
42
+ "### LAWS\n",
43
+ "\n",
44
+ "1. primary colours on small surfaces and secondary or tertiary colors on large backgrounds\n",
45
+ "2. primary in upper portions and sec/third in lower portions of objects\n",
46
+ "3. primaries of equal intensities harmonize, secondaries harmonized by opposite primary in equal intensity, tertiary by remaining secondary\n",
47
+ "4. a full colors contrasted by a lower tone color should have the latter in larger proportion\n",
48
+ "5. when a primary has a hue (second coloration) of another primary, the secondary must have the hue of the third primary\n",
49
+ "6. blue in concave surfaces, yellow in convex, red in undersites\n",
50
+ "7. if too much of a color, the other colors should have the hue version without that color\n",
51
+ "8. all three primaries should be present\n",
52
+ "9. ..."
53
+ ]
54
+ },
55
+ {
56
+ "cell_type": "markdown",
57
+ "metadata": {},
58
+ "source": [
59
+ "Test 1\n",
60
+ "\n",
61
+ "primary - secondary - tertiary "
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "execution_count": null,
67
+ "metadata": {},
68
+ "outputs": [],
69
+ "source": [
70
+ "def get_color_rank(hue, saturation, value):\n",
71
+ " if value < 5:\n",
72
+ " color = 'Black'\n",
73
+ " rank = 'None'\n",
74
+ " elif saturation < 3:\n",
75
+ " color = 'White'\n",
76
+ " rank = 'None'\n",
77
+ " elif saturation < 15:\n",
78
+ " color = 'Gray'\n",
79
+ " rank = 'None'\n",
80
+ " elif hue == 0:\n",
81
+ " color = 'Gray'\n",
82
+ " rank = 'None'\n",
83
+ " \n",
84
+ " elif hue >= 330 or hue <= 15:\n",
85
+ " color = 'Red'\n",
86
+ " rank = 'Primary'\n",
87
+ " elif hue > 15 and hue < 25:\n",
88
+ " color = 'Red Orange'\n",
89
+ " rank = 'Tertiary'\n",
90
+ " elif hue >= 25 and hue <= 40:\n",
91
+ " color = 'Orange'\n",
92
+ " rank = 'Secondary'\n",
93
+ " elif hue > 40 and hue < 50:\n",
94
+ " color = 'Orange Yellow'\n",
95
+ " rank = 'Tertiary'\n",
96
+ " elif hue >= 50 and hue <= 85:\n",
97
+ " color = 'Yellow'\n",
98
+ " rank = 'Primary'\n",
99
+ " elif hue > 85 and hue < 95:\n",
100
+ " color = 'Yellow Green'\n",
101
+ " rank = 'Tertiary'\n",
102
+ " elif hue >= 95 and hue <= 145:\n",
103
+ " color = 'Green'\n",
104
+ " rank = 'Secondary'\n",
105
+ " elif hue >= 145 and hue < 180:\n",
106
+ " color = 'Green Blue'\n",
107
+ " rank = 'Tertiary'\n",
108
+ " elif hue >= 180 and hue <= 245:\n",
109
+ " color = 'Blue'\n",
110
+ " rank = 'Primary'\n",
111
+ " elif hue > 245 and hue < 265:\n",
112
+ " color = 'Blue Violet'\n",
113
+ " rank = 'Tertiary'\n",
114
+ " elif hue >= 265 and hue <= 290:\n",
115
+ " color = 'Violet'\n",
116
+ " rank = 'Secondary'\n",
117
+ " elif hue > 290 and hue < 330:\n",
118
+ " color = 'Violet Red'\n",
119
+ " rank = 'Tertiary'\n",
120
+ " return color, rank"
121
+ ]
122
+ },
123
+ {
124
+ "cell_type": "code",
125
+ "execution_count": null,
126
+ "metadata": {},
127
+ "outputs": [],
128
+ "source": [
129
+ "def rgb2hsv(r, g, b):\n",
130
+ " # Normalize R, G, B values\n",
131
+ " r, g, b = r / 255.0, g / 255.0, b / 255.0\n",
132
+ " \n",
133
+ " # h, s, v = hue, saturation, value\n",
134
+ " max_rgb = max(r, g, b) \n",
135
+ " min_rgb = min(r, g, b) \n",
136
+ " difference = max_rgb-min_rgb \n",
137
+ " \n",
138
+ " # if max_rgb and max_rgb are equal then h = 0\n",
139
+ " if max_rgb == min_rgb:\n",
140
+ " h = 0\n",
141
+ " \n",
142
+ " # if max_rgb==r then h is computed as follows\n",
143
+ " elif max_rgb == r:\n",
144
+ " h = (60 * ((g - b) / difference) + 360) % 360\n",
145
+ " \n",
146
+ " # if max_rgb==g then compute h as follows\n",
147
+ " elif max_rgb == g:\n",
148
+ " h = (60 * ((b - r) / difference) + 120) % 360\n",
149
+ " \n",
150
+ " # if max_rgb=b then compute h\n",
151
+ " elif max_rgb == b:\n",
152
+ " h = (60 * ((r - g) / difference) + 240) % 360\n",
153
+ " \n",
154
+ " # if max_rgb==zero then s=0\n",
155
+ " if max_rgb == 0:\n",
156
+ " s = 0\n",
157
+ " else:\n",
158
+ " s = (difference / max_rgb) * 100\n",
159
+ " \n",
160
+ " # compute v\n",
161
+ " v = max_rgb * 100\n",
162
+ " # return rounded values of H, S and V\n",
163
+ " return tuple(map(round, (h, s, v)))\n",
164
+ " "
165
+ ]
166
+ },
167
+ {
168
+ "cell_type": "code",
169
+ "execution_count": null,
170
+ "metadata": {},
171
+ "outputs": [],
172
+ "source": [
173
+ "def obtain_hsv_colors(img):\n",
174
+ " colors = extcolors.extract_from_path(img, tolerance=7, limit=7)\n",
175
+ " colors = [(rgb2hsv(h[0][0], h[0][1], h[0][2]), h[1]) for h in colors[0] if h[0] != (0,0,0)]\n",
176
+ " return colors"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "colors = obtain_hsv_colors(images_textiles[0])\n",
186
+ "print(colors)"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "for col in colors:\n",
196
+ " print(get_color_rank(*col[0]))"
197
+ ]
198
+ },
199
+ {
200
+ "cell_type": "code",
201
+ "execution_count": null,
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": [
205
+ "for img in images_textiles[:30]:\n",
206
+ " colors = obtain_hsv_colors(img)\n",
207
+ " plt.imshow(plt.imread(img))\n",
208
+ " plt.show()\n",
209
+ " for col in colors:\n",
210
+ " print(col[0])\n",
211
+ " print(get_color_rank(*col[0]))\n",
212
+ " \n",
213
+ " print()"
214
+ ]
215
+ },
216
+ {
217
+ "cell_type": "markdown",
218
+ "metadata": {},
219
+ "source": [
220
+ "### use for training only images with medium saturation and value\n",
221
+ "\n",
222
+ "use codes and not only hue for color categorization\n",
223
+ "or remove colors that are creater with black and whites"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ }
233
+ ],
234
+ "metadata": {
235
+ "kernelspec": {
236
+ "display_name": "art-reco_x86",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.8.16"
251
+ },
252
+ "orig_nbformat": 4
253
+ },
254
+ "nbformat": 4,
255
+ "nbformat_minor": 2
256
+ }
interfacegan_colour_disentanglement.ipynb CHANGED
@@ -15,6 +15,7 @@
15
  "\n",
16
  "from PIL import Image, ImageColor\n",
17
  "import matplotlib.pyplot as plt\n",
 
18
  "\n",
19
  "import numpy as np\n",
20
  "import torch\n",
@@ -33,6 +34,87 @@
33
  "%autoreload 2"
34
  ]
35
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  {
37
  "cell_type": "code",
38
  "execution_count": null,
@@ -43,6 +125,21 @@
43
  "num_colors = 7"
44
  ]
45
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  {
47
  "cell_type": "code",
48
  "execution_count": null,
@@ -50,8 +147,7 @@
50
  "metadata": {},
51
  "outputs": [],
52
  "source": [
53
- "values = [x*256/num_colors if x<num_colors else 256 for x in range(num_colors + 1)]\n",
54
- "centers = [int((values[i-1]+values[i])/2) for i in range(len(values)) if i > 0]"
55
  ]
56
  },
57
  {
@@ -61,7 +157,7 @@
61
  "metadata": {},
62
  "outputs": [],
63
  "source": [
64
- "print(values)\n",
65
  "print(centers)"
66
  ]
67
  },
@@ -100,9 +196,9 @@
100
  "metadata": {},
101
  "outputs": [],
102
  "source": [
103
- "def to_256(val):\n",
104
- " x = val*360/256\n",
105
- " return x"
106
  ]
107
  },
108
  {
@@ -112,9 +208,7 @@
112
  "metadata": {},
113
  "outputs": [],
114
  "source": [
115
- "names = ['Red', 'Orange', 'Yellow', 'Yellow Green', 'Chartreuse Green',\n",
116
- " 'Kelly Green', 'Green Blue Seafoam', 'Cyan Blue',\n",
117
- " 'Warm Blue', 'Indigo', 'Purple Magenta', 'Magenta Pink']"
118
  ]
119
  },
120
  {
@@ -127,7 +221,7 @@
127
  "saturation = 1 # Saturation value (0 to 1)\n",
128
  "value = 1 # Value (brightness) value (0 to 1)\n",
129
  "for hue, name in zip(centers, names[:num_colors]):\n",
130
- " image = create_color_image(to_256(hue), saturation, value)\n",
131
  " display_image(image, name) # Display the generated color image"
132
  ]
133
  },
@@ -148,6 +242,37 @@
148
  " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
149
  ]
150
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  {
152
  "cell_type": "code",
153
  "execution_count": null,
@@ -155,7 +280,6 @@
155
  "metadata": {},
156
  "outputs": [],
157
  "source": [
158
- "ann_df = tohsv(ann_df)\n",
159
  "ann_df.head()"
160
  ]
161
  },
@@ -291,9 +415,7 @@
291
  "metadata": {},
292
  "outputs": [],
293
  "source": [
294
- "colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',\n",
295
- " 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',\n",
296
- " 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']"
297
  ]
298
  },
299
  {
@@ -313,10 +435,17 @@
313
  "source": [
314
  "from sklearn import svm\n",
315
  "\n",
316
- "print([int(x*256/12) if x<12 else 255 for x in range(13)])\n",
317
- "y_h_cat = pd.cut(y_h,bins=[x*256/12 if x<12 else 256 for x in range(13)],labels=colors_list).fillna('Warm Pink Red')\n",
318
  "\n",
319
  "print(y_h_cat.value_counts(dropna=False))\n",
 
 
 
 
 
 
 
 
320
  "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
321
  ]
322
  },
 
15
  "\n",
16
  "from PIL import Image, ImageColor\n",
17
  "import matplotlib.pyplot as plt\n",
18
+ "from sklearn.model_selection import train_test_split\n",
19
  "\n",
20
  "import numpy as np\n",
21
  "import torch\n",
 
34
  "%autoreload 2"
35
  ]
36
  },
37
+ {
38
+ "cell_type": "markdown",
39
+ "id": "03efb8c0",
40
+ "metadata": {},
41
+ "source": [
42
+ "0-60\n",
43
+ "\n",
44
+ "Red\n",
45
+ "\n",
46
+ "60-120\n",
47
+ "\n",
48
+ "Yellow\n",
49
+ "\n",
50
+ "120-180\n",
51
+ "\n",
52
+ "Green\n",
53
+ "\n",
54
+ "180-240\n",
55
+ "\n",
56
+ "Cyan\n",
57
+ "\n",
58
+ "240-300\n",
59
+ "\n",
60
+ "Blue\n",
61
+ "\n",
62
+ "300-360\n",
63
+ "\n",
64
+ "Magenta\n",
65
+ "\n",
66
+ "Standard classification"
67
+ ]
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "execution_count": null,
72
+ "id": "00a35126",
73
+ "metadata": {},
74
+ "outputs": [],
75
+ "source": [
76
+ "def hex2rgb(hex_value):\n",
77
+ " h = hex_value.strip(\"#\") \n",
78
+ " rgb = tuple(int(h[i:i+2], 16) for i in (0, 2, 4))\n",
79
+ " return rgb\n",
80
+ "\n",
81
+ "def rgb2hsv(r, g, b):\n",
82
+ " # Normalize R, G, B values\n",
83
+ " r, g, b = r / 255.0, g / 255.0, b / 255.0\n",
84
+ " \n",
85
+ " # h, s, v = hue, saturation, value\n",
86
+ " max_rgb = max(r, g, b) \n",
87
+ " min_rgb = min(r, g, b) \n",
88
+ " difference = max_rgb-min_rgb \n",
89
+ " \n",
90
+ " # if max_rgb and max_rgb are equal then h = 0\n",
91
+ " if max_rgb == min_rgb:\n",
92
+ " h = 0\n",
93
+ " \n",
94
+ " # if max_rgb==r then h is computed as follows\n",
95
+ " elif max_rgb == r:\n",
96
+ " h = (60 * ((g - b) / difference) + 360) % 360\n",
97
+ " \n",
98
+ " # if max_rgb==g then compute h as follows\n",
99
+ " elif max_rgb == g:\n",
100
+ " h = (60 * ((b - r) / difference) + 120) % 360\n",
101
+ " \n",
102
+ " # if max_rgb=b then compute h\n",
103
+ " elif max_rgb == b:\n",
104
+ " h = (60 * ((r - g) / difference) + 240) % 360\n",
105
+ " \n",
106
+ " # if max_rgb==zero then s=0\n",
107
+ " if max_rgb == 0:\n",
108
+ " s = 0\n",
109
+ " else:\n",
110
+ " s = (difference / max_rgb) * 100\n",
111
+ " \n",
112
+ " # compute v\n",
113
+ " v = max_rgb * 100\n",
114
+ " # return rounded values of H, S and V\n",
115
+ " return tuple(map(round, (h, s, v)))"
116
+ ]
117
+ },
118
  {
119
  "cell_type": "code",
120
  "execution_count": null,
 
125
  "num_colors = 7"
126
  ]
127
  },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "c8428918",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "bins = [(x-1) * 360 / (num_colors - 1) if x != 1 \n",
136
+ " else 1 for x in range(num_colors + 1)]\n",
137
+ "bins[0] = 0\n",
138
+ "\n",
139
+ "bins\n",
140
+ " "
141
+ ]
142
+ },
143
  {
144
  "cell_type": "code",
145
  "execution_count": null,
 
147
  "metadata": {},
148
  "outputs": [],
149
  "source": [
150
+ "centers = [int((bins[i-1]+bins[i])/2) for i in range(len(bins)) if i > 0]"
 
151
  ]
152
  },
153
  {
 
157
  "metadata": {},
158
  "outputs": [],
159
  "source": [
160
+ "print(bins)\n",
161
  "print(centers)"
162
  ]
163
  },
 
196
  "metadata": {},
197
  "outputs": [],
198
  "source": [
199
+ "# def to_256(val):\n",
200
+ "# x = val*360/256\n",
201
+ "# return int(x)"
202
  ]
203
  },
204
  {
 
208
  "metadata": {},
209
  "outputs": [],
210
  "source": [
211
+ "names = ['Gray', 'Red', 'Yellow', 'Green', 'Cyan', 'Blue','Magenta']"
 
 
212
  ]
213
  },
214
  {
 
221
  "saturation = 1 # Saturation value (0 to 1)\n",
222
  "value = 1 # Value (brightness) value (0 to 1)\n",
223
  "for hue, name in zip(centers, names[:num_colors]):\n",
224
+ " image = create_color_image(hue, saturation, value)\n",
225
  " display_image(image, name) # Display the generated color image"
226
  ]
227
  },
 
242
  " model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore\n"
243
  ]
244
  },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "id": "065cd656",
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "from DisentanglementBase import DisentanglementBase"
253
+ ]
254
+ },
255
+ {
256
+ "cell_type": "code",
257
+ "execution_count": null,
258
+ "id": "afb8a611",
259
+ "metadata": {},
260
+ "outputs": [],
261
+ "source": [
262
+ "variable = 'H1'\n",
263
+ "disentanglemnet_exp = DisentanglementBase('.', model, annotations, ann_df, space='W', colors_list=names, compute_s=False, variable=variable)\n"
264
+ ]
265
+ },
266
+ {
267
+ "cell_type": "code",
268
+ "execution_count": null,
269
+ "id": "a7398217",
270
+ "metadata": {},
271
+ "outputs": [],
272
+ "source": [
273
+ "ann_df = disentanglemnet_exp.df"
274
+ ]
275
+ },
276
  {
277
  "cell_type": "code",
278
  "execution_count": null,
 
280
  "metadata": {},
281
  "outputs": [],
282
  "source": [
 
283
  "ann_df.head()"
284
  ]
285
  },
 
415
  "metadata": {},
416
  "outputs": [],
417
  "source": [
418
+ "colors_list = names"
 
 
419
  ]
420
  },
421
  {
 
435
  "source": [
436
  "from sklearn import svm\n",
437
  "\n",
438
+ "y_h_cat = pd.cut(y_h,bins=bins,labels=colors_list, include_lowest=True)\n",
 
439
  "\n",
440
  "print(y_h_cat.value_counts(dropna=False))\n",
441
+ "\n",
442
+ "y_h_cat[y_s == 0] = 'Gray'\n",
443
+ "y_h_cat[y_s == 100] = 'Gray'\n",
444
+ "y_h_cat[y_v == 0] = 'Gray'\n",
445
+ "y_h_cat[y_v == 100] = 'Gray'\n",
446
+ "\n",
447
+ "print(y_h_cat.value_counts(dropna=False))\n",
448
+ "\n",
449
  "x_trainhc, x_valhc, y_trainhc, y_valhc = train_test_split(X, y_h_cat, test_size=0.2)"
450
  ]
451
  },
pages/{1_Textiles_Disentanglement.py → 1_Textiles_Manipulation.py} RENAMED
@@ -20,10 +20,14 @@ BACKGROUND_COLOR = '#bcd0e7'
20
  SECONDARY_COLOR = '#bce7db'
21
 
22
 
23
- st.title('Disentanglement studies on the Textile Dataset')
24
  st.markdown(
25
  """
26
- This is a demo of the Disentanglement studies on the [iMET Textiles Dataset](https://www.metmuseum.org/art/collection/search/85531).
 
 
 
 
27
  """,
28
  unsafe_allow_html=False,)
29
 
@@ -49,7 +53,7 @@ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pk
49
  COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink']
50
 
51
  if 'image_id' not in st.session_state:
52
- st.session_state.image_id = 0
53
  if 'color_ids' not in st.session_state:
54
  st.session_state.concept_ids = COLORS_LIST[-1]
55
  if 'space_id' not in st.session_state:
@@ -73,9 +77,6 @@ if 'num_factors' not in st.session_state:
73
  if 'best' not in st.session_state:
74
  st.session_state.best = True
75
 
76
- # def on_change_random_input():
77
- # st.session_state.image_id = st.session_state.image_id
78
-
79
  # ----------------------------- INPUT ----------------------------------
80
  st.header('Input')
81
  input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
@@ -83,8 +84,7 @@ input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
83
  with input_col_1:
84
  with st.form('image_form'):
85
 
86
- # image_id = st.number_input('Image ID: ', format='%d', step=1)
87
- st.write('**Choose or generate a random image to test the disentanglement**')
88
  chosen_image_id_input = st.empty()
89
  image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
90
 
@@ -103,16 +103,15 @@ with input_col_1:
103
  with input_col_2:
104
  with st.form('text_form_1'):
105
 
106
- st.write('**Choose color to vary**')
107
- type_col = st.selectbox('Color:', tuple(COLORS_LIST), index=7)
108
- colors_button = st.form_submit_button('Choose the defined color')
109
 
110
  st.write('**Set range of change**')
111
  chosen_color_lambda_input = st.empty()
112
  color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=7)
113
- color_lambda_button = st.form_submit_button('Choose the defined lambda for color')
114
 
115
- if colors_button or color_lambda_button:
116
  st.session_state.image_id = image_id
117
  st.session_state.concept_ids = type_col
118
  st.session_state.color_lambda = color_lambda
@@ -121,45 +120,48 @@ with input_col_2:
121
  with input_col_3:
122
  with st.form('text_form'):
123
 
124
- st.write('**Saturation variation**')
125
  chosen_saturation_lambda_input = st.empty()
126
  saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=0, value=0)
127
- saturation_lambda_button = st.form_submit_button('Choose the defined lambda for saturation')
128
 
129
- st.write('**Value variation**')
130
  chosen_value_lambda_input = st.empty()
131
  value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=1, value=0)
132
- value_lambda_button = st.form_submit_button('Choose the defined lambda for salue')
133
 
134
- if saturation_lambda_button or value_lambda_button:
135
  st.session_state.saturation_lambda = int(saturation_lambda)
136
  st.session_state.value_lambda = int(value_lambda)
137
 
138
  with input_col_4:
139
  with st.form('text_form_2'):
140
- st.write('Use best options')
141
  best = st.selectbox('Option:', tuple([True, False]), index=0)
142
- st.write('Options for StyleSpace (not available for Saturation and Value)')
143
- sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
144
- num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
145
- st.write('Options for InterFaceGAN (not available for Saturation and Value)')
146
- cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
147
- regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
148
- st.write('Options for InterFaceGAN (only for Saturation and Value)')
149
- extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
150
-
 
 
 
 
 
 
151
  choose_options_button = st.form_submit_button('Choose the defined options')
152
- # st.write('**Choose a latent space to disentangle**')
153
- # # chosen_text_id_input = st.empty()
154
- # # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
155
- # space_id = st.selectbox('Space:', tuple(['Z', 'W']))
156
  if choose_options_button:
157
- st.session_state.sign = sign
158
- st.session_state.num_factors = num_factors
159
- st.session_state.cl_method = cl_method
160
- st.session_state.regularization = regularization
161
- st.session_state.extremes = extremes
162
  st.session_state.best = best
 
 
 
 
 
 
 
163
 
164
  # with input_col_4:
165
  # with st.form('Network specifics:'):
@@ -178,7 +180,7 @@ with input_col_4:
178
  # ---------------------------- SET UP OUTPUT ------------------------------
179
  epsilon_container = st.empty()
180
  st.header('Image Manipulation')
181
- st.subheader('Using selected directions')
182
 
183
  header_col_1, header_col_2 = st.columns([1,1])
184
  output_col_1, output_col_2 = st.columns([1,1])
@@ -193,7 +195,7 @@ output_col_1, output_col_2 = st.columns([1,1])
193
 
194
  # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
195
  with header_col_1:
196
- st.write(f'Original image')
197
 
198
  with header_col_2:
199
  if st.session_state.best:
@@ -209,8 +211,15 @@ with header_col_2:
209
  tmp_sat = concept_vectors[concept_vectors['color'] == 'Saturation'][concept_vectors['extremes'] == st.session_state.extremes]
210
  saturation_separation_vector, performance_saturation = tmp_sat.reset_index().loc[0, ['vector', 'score']]
211
 
212
- st.write(f'Change in {st.session_state.concept_ids} of {np.round(st.session_state.color_lambda, 2)}, in saturation of {np.round(st.session_state.saturation_lambda, 2)}, in value of {np.round(st.session_state.value_lambda, 2)}. - Performance color vector: {performance_color}, saturation vector: {performance_saturation/100}, value vector: {performance_value/100}')
213
-
 
 
 
 
 
 
 
214
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
215
 
216
  if st.session_state.space_id == 'Z':
@@ -226,3 +235,4 @@ with output_col_1:
226
  with output_col_2:
227
  image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[st.session_state.color_lambda, st.session_state.saturation_lambda, st.session_state.value_lambda])
228
  st.image(image_updated)
 
 
20
  SECONDARY_COLOR = '#bce7db'
21
 
22
 
23
+ st.title('Disentanglement on Textile Datasets')
24
  st.markdown(
25
  """
26
+ This is a demo of the Disentanglement experiment on the [iMET Textiles Dataset](https://www.metmuseum.org/art/collection/search/85531).
27
+
28
+ In this page, the user can adjust the colors of textile images generated by an AI by simply traversing the latent space of the AI.
29
+ The colors can be adjusted following the human-intuitive encoding of HSV, adjusting the main Hue of the image with an option of 7 colors + Gray,
30
+ the saturation (the amount of Gray) and the value of the image (the amount of Black).
31
  """,
32
  unsafe_allow_html=False,)
33
 
 
53
  COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink']
54
 
55
  if 'image_id' not in st.session_state:
56
+ st.session_state.image_id = 52921
57
  if 'color_ids' not in st.session_state:
58
  st.session_state.concept_ids = COLORS_LIST[-1]
59
  if 'space_id' not in st.session_state:
 
77
  if 'best' not in st.session_state:
78
  st.session_state.best = True
79
 
 
 
 
80
  # ----------------------------- INPUT ----------------------------------
81
  st.header('Input')
82
  input_col_1, input_col_2, input_col_3, input_col_4 = st.columns(4)
 
84
  with input_col_1:
85
  with st.form('image_form'):
86
 
87
+ st.write('**Choose or generate a random base image**')
 
88
  chosen_image_id_input = st.empty()
89
  image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
90
 
 
103
  with input_col_2:
104
  with st.form('text_form_1'):
105
 
106
+ st.write('**Choose hue to vary**')
107
+ type_col = st.selectbox('Hue:', tuple(COLORS_LIST), index=7)
 
108
 
109
  st.write('**Set range of change**')
110
  chosen_color_lambda_input = st.empty()
111
  color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=7)
112
+ color_lambda_button = st.form_submit_button('Choose the defined hue and lambda')
113
 
114
+ if color_lambda_button:
115
  st.session_state.image_id = image_id
116
  st.session_state.concept_ids = type_col
117
  st.session_state.color_lambda = color_lambda
 
120
  with input_col_3:
121
  with st.form('text_form'):
122
 
123
+ st.write('**Choose saturation variation**')
124
  chosen_saturation_lambda_input = st.empty()
125
  saturation_lambda = chosen_saturation_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=0, value=0)
 
126
 
127
+ st.write('**Choose value variation**')
128
  chosen_value_lambda_input = st.empty()
129
  value_lambda = chosen_value_lambda_input.number_input('Lambda:', min_value=-100, step=1, key=1, value=0)
130
+ value_lambda_button = st.form_submit_button('Choose the defined lambda for value and saturation')
131
 
132
+ if value_lambda_button:
133
  st.session_state.saturation_lambda = int(saturation_lambda)
134
  st.session_state.value_lambda = int(value_lambda)
135
 
136
  with input_col_4:
137
  with st.form('text_form_2'):
138
+ st.write('Use the best vectors (after hyperparameter tuning)')
139
  best = st.selectbox('Option:', tuple([True, False]), index=0)
140
+ sign = True
141
+ num_factors=10
142
+ cl_method='LR'
143
+ regularization=0.1
144
+ extremes=True
145
+ if st.session_state.best is False:
146
+ st.write('Options for StyleSpace (not available for Saturation and Value)')
147
+ sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
148
+ num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
149
+ st.write('Options for InterFaceGAN (not available for Saturation and Value)')
150
+ cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
151
+ regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
152
+ st.write('Options for InterFaceGAN (only for Saturation and Value)')
153
+ extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
154
+
155
  choose_options_button = st.form_submit_button('Choose the defined options')
 
 
 
 
156
  if choose_options_button:
 
 
 
 
 
157
  st.session_state.best = best
158
+ if st.session_state.best is False:
159
+ st.session_state.sign = sign
160
+ st.session_state.num_factors = num_factors
161
+ st.session_state.cl_method = cl_method
162
+ st.session_state.regularization = regularization
163
+ st.session_state.extremes = extremes
164
+
165
 
166
  # with input_col_4:
167
  # with st.form('Network specifics:'):
 
180
  # ---------------------------- SET UP OUTPUT ------------------------------
181
  epsilon_container = st.empty()
182
  st.header('Image Manipulation')
183
+ st.write('Using selected vectors to modify the original image...')
184
 
185
  header_col_1, header_col_2 = st.columns([1,1])
186
  output_col_1, output_col_2 = st.columns([1,1])
 
195
 
196
  # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
197
  with header_col_1:
198
+ st.write(f'### Original image')
199
 
200
  with header_col_2:
201
  if st.session_state.best:
 
211
  tmp_sat = concept_vectors[concept_vectors['color'] == 'Saturation'][concept_vectors['extremes'] == st.session_state.extremes]
212
  saturation_separation_vector, performance_saturation = tmp_sat.reset_index().loc[0, ['vector', 'score']]
213
 
214
+ st.write('### Modified image')
215
+ st.write(f"""
216
+ Change in hue: {st.session_state.concept_ids} of amount: {np.round(st.session_state.color_lambda, 2)},
217
+ in: saturation of amount: {np.round(st.session_state.saturation_lambda, 2)},
218
+ in: value of amount: {np.round(st.session_state.value_lambda, 2)}.\
219
+ Verification performance of hue vector: {performance_color},
220
+ saturation vector: {performance_saturation/100},
221
+ value vector: {performance_value/100}""")
222
+
223
  # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
224
 
225
  if st.session_state.space_id == 'Z':
 
235
  with output_col_2:
236
  image_updated = generate_composite_images(model, original_image_vec, [color_separation_vector, saturation_separation_vector, value_separation_vector], lambdas=[st.session_state.color_lambda, st.session_state.saturation_lambda, st.session_state.value_lambda])
237
  st.image(image_updated)
238
+
pages/{2_Colours_comparison.py → 2_Network_comparison.py} RENAMED
@@ -24,7 +24,11 @@ st.set_page_config(layout='wide')
24
 
25
  st.title('Comparison among color directions')
26
  st.write('> **How do the color directions relate to each other?**')
27
- st.write('> **What is their joint impact on the image?**')
 
 
 
 
28
 
29
 
30
  annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
@@ -46,10 +50,8 @@ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pk
46
 
47
  COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
48
 
49
- if 'image_id' not in st.session_state:
50
- st.session_state.image_id = 0
51
  if 'concept_ids' not in st.session_state:
52
- st.session_state.concept_ids = [COLORS_LIST[-1], COLORS_LIST[-2], ]
53
  if 'sign' not in st.session_state:
54
  st.session_state.sign = False
55
  if 'extremes' not in st.session_state:
@@ -60,10 +62,8 @@ if 'cl_method' not in st.session_state:
60
  st.session_state.cl_method = False
61
  if 'num_factors' not in st.session_state:
62
  st.session_state.num_factors = False
63
-
64
-
65
- if 'space_id' not in st.session_state:
66
- st.session_state.space_id = 'W'
67
 
68
  # ----------------------------- INPUT ----------------------------------
69
  st.header('Input')
@@ -76,7 +76,7 @@ with input_col_1:
76
  st.write('**Choose a series of colors to compare**')
77
  # chosen_text_id_input = st.empty()
78
  # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
79
- concept_ids = st.multiselect('Color (including Saturation and Value):', tuple(COLORS_LIST), default=[COLORS_LIST[-1], COLORS_LIST[-2], ])
80
  choose_text_button = st.form_submit_button('Choose the defined colors')
81
 
82
  if choose_text_button:
@@ -85,27 +85,33 @@ with input_col_1:
85
 
86
  with input_col_2:
87
  with st.form('text_form_1'):
88
- st.write('Options for StyleSpace (not available for Saturation and Value)')
89
- sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
90
- num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
91
- st.write('Options for InterFaceGAN (not available for Saturation and Value)')
92
- cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
93
- regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
94
- st.write('Options for InterFaceGAN (only for Saturation and Value)')
95
- extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
96
-
 
 
 
 
 
 
 
 
97
  choose_options_button = st.form_submit_button('Choose the defined options')
98
- # st.write('**Choose a latent space to disentangle**')
99
- # # chosen_text_id_input = st.empty()
100
- # # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
101
- # space_id = st.selectbox('Space:', tuple(['Z', 'W']))
102
  if choose_options_button:
103
- st.session_state.sign = sign
104
- st.session_state.num_factors = num_factors
105
- st.session_state.cl_method = cl_method
106
- st.session_state.regularization = regularization
107
- st.session_state.extremes = extremes
108
-
 
 
109
  # ---------------------------- SET UP OUTPUT ------------------------------
110
  epsilon_container = st.empty()
111
  st.header('Comparison')
@@ -115,23 +121,28 @@ header_col_1, header_col_2 = st.columns([3,1])
115
  output_col_1, output_col_2 = st.columns([3,1])
116
 
117
  # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
118
- tmp = concept_vectors[concept_vectors['color'].isin(st.session_state.concept_ids)]
119
- tmp = tmp[tmp['sign'] == st.session_state.sign][tmp['extremes'] == st.session_state.extremes][tmp['num_factors'] == st.session_state.num_factors][tmp['cl_method'] == st.session_state.cl_method][tmp['regularization'] == st.session_state.regularization]
 
 
 
 
120
  info = tmp.loc[:, ['vector', 'score', 'color', 'kwargs']].values
121
  concept_ids = [i[2] for i in info] #+ ' ' + i[3]
122
 
123
  with header_col_1:
124
- st.write('Similarity graph')
125
 
126
  with header_col_2:
127
- st.write('Information')
128
 
129
  with output_col_2:
130
  for i,concept_id in enumerate(concept_ids):
131
- st.write(f'Color {info[i][2]} - Settings: {info[i][3]} Performance of the color vector: {info[i][1]}')# - Nodes {",".join(list(imp_nodes))}')
 
 
132
 
133
  with output_col_1:
134
-
135
  edges = []
136
  for i in range(len(concept_ids)):
137
  for j in range(len(concept_ids)):
 
24
 
25
  st.title('Comparison among color directions')
26
  st.write('> **How do the color directions relate to each other?**')
27
+ st.write("""
28
+ This page provides a simple network-based framework to inspect the vector similarity (cosine similarity) among the found color vectors.
29
+ The nodes are the colors chosen for comparison and the strength of the edge represents the similarity.
30
+
31
+ """)
32
 
33
 
34
  annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
 
50
 
51
  COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
52
 
 
 
53
  if 'concept_ids' not in st.session_state:
54
+ st.session_state.concept_ids = COLORS_LIST
55
  if 'sign' not in st.session_state:
56
  st.session_state.sign = False
57
  if 'extremes' not in st.session_state:
 
62
  st.session_state.cl_method = False
63
  if 'num_factors' not in st.session_state:
64
  st.session_state.num_factors = False
65
+ if 'best' not in st.session_state:
66
+ st.session_state.best = True
 
 
67
 
68
  # ----------------------------- INPUT ----------------------------------
69
  st.header('Input')
 
76
  st.write('**Choose a series of colors to compare**')
77
  # chosen_text_id_input = st.empty()
78
  # concept_id = chosen_text_id_input.text_input('Concept:', value=st.session_state.concept_id)
79
+ concept_ids = st.multiselect('Color (including Saturation and Value):', tuple(COLORS_LIST), default=COLORS_LIST)
80
  choose_text_button = st.form_submit_button('Choose the defined colors')
81
 
82
  if choose_text_button:
 
85
 
86
  with input_col_2:
87
  with st.form('text_form_1'):
88
+ st.write('Use the best vectors (after hyperparameter tuning)')
89
+ best = st.selectbox('Option:', tuple([True, False]), index=0)
90
+ sign = True
91
+ num_factors=10
92
+ cl_method='LR'
93
+ regularization=0.1
94
+ extremes=True
95
+ if st.session_state.best is False:
96
+ st.write('Options for StyleSpace (not available for Saturation and Value)')
97
+ sign = st.selectbox('Sign option:', tuple([True, False]), index=1)
98
+ num_factors = st.selectbox('Number of factors option:', tuple([1, 5, 10, 20, False]), index=4)
99
+ st.write('Options for InterFaceGAN (not available for Saturation and Value)')
100
+ cl_method = st.selectbox('Classification method option:', tuple(['LR', 'SVM', False]), index=2)
101
+ regularization = st.selectbox('Regularization option:', tuple([0.1, 1.0, False]), index=2)
102
+ st.write('Options for InterFaceGAN (only for Saturation and Value)')
103
+ extremes = st.selectbox('Extremes option:', tuple([True, False]), index=1)
104
+
105
  choose_options_button = st.form_submit_button('Choose the defined options')
 
 
 
 
106
  if choose_options_button:
107
+ st.session_state.best = best
108
+ if st.session_state.best is False:
109
+ st.session_state.sign = sign
110
+ st.session_state.num_factors = num_factors
111
+ st.session_state.cl_method = cl_method
112
+ st.session_state.regularization = regularization
113
+ st.session_state.extremes = extremes
114
+
115
  # ---------------------------- SET UP OUTPUT ------------------------------
116
  epsilon_container = st.empty()
117
  st.header('Comparison')
 
121
  output_col_1, output_col_2 = st.columns([3,1])
122
 
123
  # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
124
+ if st.session_state.best:
125
+ tmp = concept_vectors[concept_vectors['color'].isin(st.session_state.concept_ids)].groupby('color').first().reset_index()
126
+ else:
127
+ tmp = concept_vectors[concept_vectors['color'].isin(st.session_state.concept_ids)]
128
+ tmp = tmp[tmp['sign'] == st.session_state.sign][tmp['extremes'] == st.session_state.extremes][tmp['num_factors'] == st.session_state.num_factors][tmp['cl_method'] == st.session_state.cl_method][tmp['regularization'] == st.session_state.regularization]
129
+
130
  info = tmp.loc[:, ['vector', 'score', 'color', 'kwargs']].values
131
  concept_ids = [i[2] for i in info] #+ ' ' + i[3]
132
 
133
  with header_col_1:
134
+ st.write('### Similarity graph')
135
 
136
  with header_col_2:
137
+ st.write('### Information')
138
 
139
  with output_col_2:
140
  for i,concept_id in enumerate(concept_ids):
141
+ st.write(f'''Color: {info[i][2]}.\
142
+ Settings: {info[i][3]}\
143
+ ''')
144
 
145
  with output_col_1:
 
146
  edges = []
147
  for i in range(len(concept_ids)):
148
  for j in range(len(concept_ids)):
pages/3_Vectors_algebra.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pickle
3
+ import pandas as pd
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+
8
+ from matplotlib.backends.backend_agg import RendererAgg
9
+
10
+ from backend.disentangle_concepts import *
11
+ import torch_utils
12
+ import dnnlib
13
+ import legacy
14
+
15
+ _lock = RendererAgg.lock
16
+
17
+
18
+ st.set_page_config(layout='wide')
19
+ BACKGROUND_COLOR = '#bcd0e7'
20
+ SECONDARY_COLOR = '#bce7db'
21
+
22
+
23
+ st.title('Vector algebra using disentangled vectors')
24
+ st.markdown(
25
+ """
26
+ This page offers the possibility to edit the colors of a given textile image using vector algebra and projections.
27
+ It allows to select several colors to move towards and against (selecting a positive or negative lambda).
28
+ Furthermore, it offers the possibility of conditional manipulation, by moving in the direction of a color n1 without affecting the color n2.
29
+ This is done using a projected direction n1 - (n1.T n2) n2.
30
+ """,
31
+ unsafe_allow_html=False,)
32
+
33
+ annotations_file = './data/textile_annotated_files/seeds0000-100000_S.pkl'
34
+ with open(annotations_file, 'rb') as f:
35
+ annotations = pickle.load(f)
36
+
37
+ concept_vectors = pd.read_csv('./data/stored_vectors/scores_colors_hsv.csv')
38
+ concept_vectors['vector'] = [np.array([float(xx) for xx in x]) for x in concept_vectors['vector'].str.split(', ')]
39
+ concept_vectors['score'] = concept_vectors['score'].astype(float)
40
+
41
+ concept_vectors = concept_vectors.sort_values('score', ascending=False).reset_index()
42
+
43
+ with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
44
+ model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
45
+
46
+ COLORS_LIST = ['Gray', 'Red Orange', 'Yellow', 'Green', 'Light Blue', 'Blue', 'Purple', 'Pink', 'Saturation', 'Value']
47
+ COLORS_NEGATIVE = COLORS_LIST + ['None']
48
+
49
+ if 'image_id' not in st.session_state:
50
+ st.session_state.image_id = 52921
51
+ if 'colors' not in st.session_state:
52
+ st.session_state.colors = [COLORS_LIST[5], COLORS_LIST[7]]
53
+ if 'non_colors' not in st.session_state:
54
+ st.session_state.non_colors = ['None']
55
+ if 'color_lambda' not in st.session_state:
56
+ st.session_state.color_lambda = [5]
57
+
58
+ # ----------------------------- INPUT ----------------------------------
59
+ epsilon_container = st.empty()
60
+ st.header('Image Manipulation with Vector Algebra')
61
+
62
+ header_col_1, header_col_2, header_col_3, header_col_4 = st.columns([1,1,1,1])
63
+ input_col_1, output_col_2, output_col_3, input_col_4 = st.columns([1,1,1,1])
64
+
65
+ # --------------------------- INPUT column 1 ---------------------------
66
+ with input_col_1:
67
+ with st.form('image_form'):
68
+
69
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
70
+ st.write('**Choose or generate a random image**')
71
+ chosen_image_id_input = st.empty()
72
+ image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
73
+
74
+ choose_image_button = st.form_submit_button('Choose the defined image')
75
+ random_id = st.form_submit_button('Generate a random image')
76
+
77
+ if random_id:
78
+ image_id = random.randint(0, 100000)
79
+ st.session_state.image_id = image_id
80
+ chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
81
+
82
+ if choose_image_button:
83
+ image_id = int(image_id)
84
+ st.session_state.image_id = image_id
85
+
86
+ with header_col_1:
87
+ st.write('### Input image selection')
88
+
89
+ original_image_vec = annotations['w_vectors'][st.session_state.image_id]
90
+ img = generate_original_image(original_image_vec, model)
91
+
92
+ with output_col_2:
93
+ st.image(img)
94
+
95
+ with header_col_2:
96
+ st.write('### Original image')
97
+
98
+ with input_col_4:
99
+ with st.form('text_form_1'):
100
+
101
+ st.write('**Colors to vary (including Saturation and Value)**')
102
+ colors = st.multiselect('Color:', tuple(COLORS_LIST), default=[COLORS_LIST[5], COLORS_LIST[7]])
103
+ colors_button = st.form_submit_button('Choose the defined colors')
104
+
105
+ st.session_state.image_id = image_id
106
+ st.session_state.colors = colors
107
+ st.session_state.color_lambda = [5]*len(colors)
108
+ st.session_state.non_colors = ['None']*len(colors)
109
+
110
+ lambdas = []
111
+ negative_cols = []
112
+ for color in colors:
113
+ st.write('### '+color )
114
+ st.write('**Set range of change (can be negative)**')
115
+ chosen_color_lambda_input = st.empty()
116
+ color_lambda = chosen_color_lambda_input.number_input('Lambda:', min_value=-100, step=1, value=5, key=color+'_number')
117
+ lambdas.append(color_lambda)
118
+
119
+ st.write('**Set dimensions of change to not consider**')
120
+ chosen_color_negative_input = st.empty()
121
+ color_negative = chosen_color_negative_input.selectbox('Color:', tuple(COLORS_NEGATIVE), index=len(COLORS_NEGATIVE)-1, key=color+'_noncolor')
122
+ negative_cols.append(color_negative)
123
+
124
+ lambdas_button = st.form_submit_button('Submit options')
125
+ if lambdas_button:
126
+ st.session_state.color_lambda = lambdas
127
+ st.session_state.non_colors = negative_cols
128
+
129
+
130
+ with header_col_4:
131
+ st.write('### Color settings')
132
+ # print(st.session_state.colors)
133
+ # print(st.session_state.color_lambda)
134
+ # print(st.session_state.non_colors)
135
+
136
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
137
+
138
+ with header_col_3:
139
+ separation_vectors = []
140
+ for col in st.session_state.colors:
141
+ separation_vector, score_1 = concept_vectors[concept_vectors['color'] == col].reset_index().loc[0, ['vector', 'score']]
142
+ separation_vectors.append(separation_vector)
143
+
144
+ negative_separation_vectors = []
145
+ for non_col in st.session_state.non_colors:
146
+ if non_col != 'None':
147
+ negative_separation_vector, score_2 = concept_vectors[concept_vectors['color'] == non_col].reset_index().loc[0, ['vector', 'score']]
148
+ negative_separation_vectors.append(negative_separation_vector)
149
+ else:
150
+ negative_separation_vectors.append('None')
151
+ ## n1 − (n1T n2)n2
152
+ # print(negative_separation_vectors, separation_vectors)
153
+ st.write('### Output Image')
154
+ st.write(f'''Change in colors: {str(st.session_state.colors)},\
155
+ without affecting colors {str(st.session_state.non_colors)}''')
156
+
157
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
158
+
159
+ with output_col_3:
160
+ image_updated = generate_composite_images(model, original_image_vec, separation_vectors,
161
+ lambdas=st.session_state.color_lambda,
162
+ negative_colors=negative_separation_vectors)
163
+ st.image(image_updated)
pyproject.toml ADDED
File without changes
structure_annotations.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
test-docker.sh ADDED
@@ -0,0 +1,743 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/sh
2
+ set -e
3
+ # Docker Engine for Linux installation script.
4
+ #
5
+ # This script is intended as a convenient way to configure docker's package
6
+ # repositories and to install Docker Engine, This script is not recommended
7
+ # for production environments. Before running this script, make yourself familiar
8
+ # with potential risks and limitations, and refer to the installation manual
9
+ # at https://docs.docker.com/engine/install/ for alternative installation methods.
10
+ #
11
+ # The script:
12
+ #
13
+ # - Requires `root` or `sudo` privileges to run.
14
+ # - Attempts to detect your Linux distribution and version and configure your
15
+ # package management system for you.
16
+ # - Doesn't allow you to customize most installation parameters.
17
+ # - Installs dependencies and recommendations without asking for confirmation.
18
+ # - Installs the latest stable release (by default) of Docker CLI, Docker Engine,
19
+ # Docker Buildx, Docker Compose, containerd, and runc. When using this script
20
+ # to provision a machine, this may result in unexpected major version upgrades
21
+ # of these packages. Always test upgrades in a test environment before
22
+ # deploying to your production systems.
23
+ # - Isn't designed to upgrade an existing Docker installation. When using the
24
+ # script to update an existing installation, dependencies may not be updated
25
+ # to the expected version, resulting in outdated versions.
26
+ #
27
+ # Source code is available at https://github.com/docker/docker-install/
28
+ #
29
+ # Usage
30
+ # ==============================================================================
31
+ #
32
+ # To install the latest stable versions of Docker CLI, Docker Engine, and their
33
+ # dependencies:
34
+ #
35
+ # 1. download the script
36
+ #
37
+ # $ curl -fsSL https://get.docker.com -o install-docker.sh
38
+ #
39
+ # 2. verify the script's content
40
+ #
41
+ # $ cat install-docker.sh
42
+ #
43
+ # 3. run the script with --dry-run to verify the steps it executes
44
+ #
45
+ # $ sh install-docker.sh --dry-run
46
+ #
47
+ # 4. run the script either as root, or using sudo to perform the installation.
48
+ #
49
+ # $ sudo sh install-docker.sh
50
+ #
51
+ # Command-line options
52
+ # ==============================================================================
53
+ #
54
+ # --version <VERSION>
55
+ # Use the --version option to install a specific version, for example:
56
+ #
57
+ # $ sudo sh install-docker.sh --version 23.0
58
+ #
59
+ # --channel <stable|test>
60
+ #
61
+ # Use the --channel option to install from an alternative installation channel.
62
+ # The following example installs the latest versions from the "test" channel,
63
+ # which includes pre-releases (alpha, beta, rc):
64
+ #
65
+ # $ sudo sh install-docker.sh --channel test
66
+ #
67
+ # Alternatively, use the script at https://test.docker.com, which uses the test
68
+ # channel as default.
69
+ #
70
+ # --mirror <Aliyun|AzureChinaCloud>
71
+ #
72
+ # Use the --mirror option to install from a mirror supported by this script.
73
+ # Available mirrors are "Aliyun" (https://mirrors.aliyun.com/docker-ce), and
74
+ # "AzureChinaCloud" (https://mirror.azure.cn/docker-ce), for example:
75
+ #
76
+ # $ sudo sh install-docker.sh --mirror AzureChinaCloud
77
+ #
78
+ # ==============================================================================
79
+
80
+
81
+ # Git commit from https://github.com/docker/docker-install when
82
+ # the script was uploaded (Should only be modified by upload job):
83
+ SCRIPT_COMMIT_SHA="e5543d473431b782227f8908005543bb4389b8de"
84
+
85
+ # strip "v" prefix if present
86
+ VERSION="${VERSION#v}"
87
+
88
+ # The channel to install from:
89
+ # * stable
90
+ # * test
91
+ # * edge (deprecated)
92
+ # * nightly (deprecated)
93
+ DEFAULT_CHANNEL_VALUE="test"
94
+ if [ -z "$CHANNEL" ]; then
95
+ CHANNEL=$DEFAULT_CHANNEL_VALUE
96
+ fi
97
+
98
+ DEFAULT_DOWNLOAD_URL="https://download.docker.com"
99
+ if [ -z "$DOWNLOAD_URL" ]; then
100
+ DOWNLOAD_URL=$DEFAULT_DOWNLOAD_URL
101
+ fi
102
+
103
+ DEFAULT_REPO_FILE="docker-ce.repo"
104
+ if [ -z "$REPO_FILE" ]; then
105
+ REPO_FILE="$DEFAULT_REPO_FILE"
106
+ fi
107
+
108
+ mirror=''
109
+ DRY_RUN=${DRY_RUN:-}
110
+ while [ $# -gt 0 ]; do
111
+ case "$1" in
112
+ --channel)
113
+ CHANNEL="$2"
114
+ shift
115
+ ;;
116
+ --dry-run)
117
+ DRY_RUN=1
118
+ ;;
119
+ --mirror)
120
+ mirror="$2"
121
+ shift
122
+ ;;
123
+ --version)
124
+ VERSION="${2#v}"
125
+ shift
126
+ ;;
127
+ --*)
128
+ echo "Illegal option $1"
129
+ ;;
130
+ esac
131
+ shift $(( $# > 0 ? 1 : 0 ))
132
+ done
133
+
134
+ case "$mirror" in
135
+ Aliyun)
136
+ DOWNLOAD_URL="https://mirrors.aliyun.com/docker-ce"
137
+ ;;
138
+ AzureChinaCloud)
139
+ DOWNLOAD_URL="https://mirror.azure.cn/docker-ce"
140
+ ;;
141
+ "")
142
+ ;;
143
+ *)
144
+ >&2 echo "unknown mirror '$mirror': use either 'Aliyun', or 'AzureChinaCloud'."
145
+ exit 1
146
+ ;;
147
+ esac
148
+
149
+ case "$CHANNEL" in
150
+ stable|test)
151
+ ;;
152
+ edge|nightly)
153
+ >&2 echo "DEPRECATED: the $CHANNEL channel has been deprecated and is no longer supported by this script."
154
+ exit 1
155
+ ;;
156
+ *)
157
+ >&2 echo "unknown CHANNEL '$CHANNEL': use either stable or test."
158
+ exit 1
159
+ ;;
160
+ esac
161
+
162
+ command_exists() {
163
+ command -v "$@" > /dev/null 2>&1
164
+ }
165
+
166
+ # version_gte checks if the version specified in $VERSION is at least the given
167
+ # SemVer (Maj.Minor[.Patch]), or CalVer (YY.MM) version.It returns 0 (success)
168
+ # if $VERSION is either unset (=latest) or newer or equal than the specified
169
+ # version, or returns 1 (fail) otherwise.
170
+ #
171
+ # examples:
172
+ #
173
+ # VERSION=23.0
174
+ # version_gte 23.0 // 0 (success)
175
+ # version_gte 20.10 // 0 (success)
176
+ # version_gte 19.03 // 0 (success)
177
+ # version_gte 21.10 // 1 (fail)
178
+ version_gte() {
179
+ if [ -z "$VERSION" ]; then
180
+ return 0
181
+ fi
182
+ eval version_compare "$VERSION" "$1"
183
+ }
184
+
185
+ # version_compare compares two version strings (either SemVer (Major.Minor.Path),
186
+ # or CalVer (YY.MM) version strings. It returns 0 (success) if version A is newer
187
+ # or equal than version B, or 1 (fail) otherwise. Patch releases and pre-release
188
+ # (-alpha/-beta) are not taken into account
189
+ #
190
+ # examples:
191
+ #
192
+ # version_compare 23.0.0 20.10 // 0 (success)
193
+ # version_compare 23.0 20.10 // 0 (success)
194
+ # version_compare 20.10 19.03 // 0 (success)
195
+ # version_compare 20.10 20.10 // 0 (success)
196
+ # version_compare 19.03 20.10 // 1 (fail)
197
+ version_compare() (
198
+ set +x
199
+
200
+ yy_a="$(echo "$1" | cut -d'.' -f1)"
201
+ yy_b="$(echo "$2" | cut -d'.' -f1)"
202
+ if [ "$yy_a" -lt "$yy_b" ]; then
203
+ return 1
204
+ fi
205
+ if [ "$yy_a" -gt "$yy_b" ]; then
206
+ return 0
207
+ fi
208
+ mm_a="$(echo "$1" | cut -d'.' -f2)"
209
+ mm_b="$(echo "$2" | cut -d'.' -f2)"
210
+
211
+ # trim leading zeros to accommodate CalVer
212
+ mm_a="${mm_a#0}"
213
+ mm_b="${mm_b#0}"
214
+
215
+ if [ "${mm_a:-0}" -lt "${mm_b:-0}" ]; then
216
+ return 1
217
+ fi
218
+
219
+ return 0
220
+ )
221
+
222
+ is_dry_run() {
223
+ if [ -z "$DRY_RUN" ]; then
224
+ return 1
225
+ else
226
+ return 0
227
+ fi
228
+ }
229
+
230
+ is_wsl() {
231
+ case "$(uname -r)" in
232
+ *microsoft* ) true ;; # WSL 2
233
+ *Microsoft* ) true ;; # WSL 1
234
+ * ) false;;
235
+ esac
236
+ }
237
+
238
+ is_darwin() {
239
+ case "$(uname -s)" in
240
+ *darwin* ) true ;;
241
+ *Darwin* ) true ;;
242
+ * ) false;;
243
+ esac
244
+ }
245
+
246
+ deprecation_notice() {
247
+ distro=$1
248
+ distro_version=$2
249
+ echo
250
+ printf "\033[91;1mDEPRECATION WARNING\033[0m\n"
251
+ printf " This Linux distribution (\033[1m%s %s\033[0m) reached end-of-life and is no longer supported by this script.\n" "$distro" "$distro_version"
252
+ echo " No updates or security fixes will be released for this distribution, and users are recommended"
253
+ echo " to upgrade to a currently maintained version of $distro."
254
+ echo
255
+ printf "Press \033[1mCtrl+C\033[0m now to abort this script, or wait for the installation to continue."
256
+ echo
257
+ sleep 10
258
+ }
259
+
260
+ get_distribution() {
261
+ lsb_dist=""
262
+ # Every system that we officially support has /etc/os-release
263
+ if [ -r /etc/os-release ]; then
264
+ lsb_dist="$(. /etc/os-release && echo "$ID")"
265
+ fi
266
+ # Returning an empty string here should be alright since the
267
+ # case statements don't act unless you provide an actual value
268
+ echo "$lsb_dist"
269
+ }
270
+
271
+ echo_docker_as_nonroot() {
272
+ if is_dry_run; then
273
+ return
274
+ fi
275
+ if command_exists docker && [ -e /var/run/docker.sock ]; then
276
+ (
277
+ set -x
278
+ $sh_c 'docker version'
279
+ ) || true
280
+ fi
281
+
282
+ # intentionally mixed spaces and tabs here -- tabs are stripped by "<<-EOF", spaces are kept in the output
283
+ echo
284
+ echo "================================================================================"
285
+ echo
286
+ if version_gte "20.10"; then
287
+ echo "To run Docker as a non-privileged user, consider setting up the"
288
+ echo "Docker daemon in rootless mode for your user:"
289
+ echo
290
+ echo " dockerd-rootless-setuptool.sh install"
291
+ echo
292
+ echo "Visit https://docs.docker.com/go/rootless/ to learn about rootless mode."
293
+ echo
294
+ fi
295
+ echo
296
+ echo "To run the Docker daemon as a fully privileged service, but granting non-root"
297
+ echo "users access, refer to https://docs.docker.com/go/daemon-access/"
298
+ echo
299
+ echo "WARNING: Access to the remote API on a privileged Docker daemon is equivalent"
300
+ echo " to root access on the host. Refer to the 'Docker daemon attack surface'"
301
+ echo " documentation for details: https://docs.docker.com/go/attack-surface/"
302
+ echo
303
+ echo "================================================================================"
304
+ echo
305
+ }
306
+
307
+ # Check if this is a forked Linux distro
308
+ check_forked() {
309
+
310
+ # Check for lsb_release command existence, it usually exists in forked distros
311
+ if command_exists lsb_release; then
312
+ # Check if the `-u` option is supported
313
+ set +e
314
+ lsb_release -a -u > /dev/null 2>&1
315
+ lsb_release_exit_code=$?
316
+ set -e
317
+
318
+ # Check if the command has exited successfully, it means we're in a forked distro
319
+ if [ "$lsb_release_exit_code" = "0" ]; then
320
+ # Print info about current distro
321
+ cat <<-EOF
322
+ You're using '$lsb_dist' version '$dist_version'.
323
+ EOF
324
+
325
+ # Get the upstream release info
326
+ lsb_dist=$(lsb_release -a -u 2>&1 | tr '[:upper:]' '[:lower:]' | grep -E 'id' | cut -d ':' -f 2 | tr -d '[:space:]')
327
+ dist_version=$(lsb_release -a -u 2>&1 | tr '[:upper:]' '[:lower:]' | grep -E 'codename' | cut -d ':' -f 2 | tr -d '[:space:]')
328
+
329
+ # Print info about upstream distro
330
+ cat <<-EOF
331
+ Upstream release is '$lsb_dist' version '$dist_version'.
332
+ EOF
333
+ else
334
+ if [ -r /etc/debian_version ] && [ "$lsb_dist" != "ubuntu" ] && [ "$lsb_dist" != "raspbian" ]; then
335
+ if [ "$lsb_dist" = "osmc" ]; then
336
+ # OSMC runs Raspbian
337
+ lsb_dist=raspbian
338
+ else
339
+ # We're Debian and don't even know it!
340
+ lsb_dist=debian
341
+ fi
342
+ dist_version="$(sed 's/\/.*//' /etc/debian_version | sed 's/\..*//')"
343
+ case "$dist_version" in
344
+ 12)
345
+ dist_version="bookworm"
346
+ ;;
347
+ 11)
348
+ dist_version="bullseye"
349
+ ;;
350
+ 10)
351
+ dist_version="buster"
352
+ ;;
353
+ 9)
354
+ dist_version="stretch"
355
+ ;;
356
+ 8)
357
+ dist_version="jessie"
358
+ ;;
359
+ esac
360
+ fi
361
+ fi
362
+ fi
363
+ }
364
+
365
+ do_install() {
366
+ echo "# Executing docker install script, commit: $SCRIPT_COMMIT_SHA"
367
+
368
+ if command_exists docker; then
369
+ cat >&2 <<-'EOF'
370
+ Warning: the "docker" command appears to already exist on this system.
371
+
372
+ If you already have Docker installed, this script can cause trouble, which is
373
+ why we're displaying this warning and provide the opportunity to cancel the
374
+ installation.
375
+
376
+ If you installed the current Docker package using this script and are using it
377
+ again to update Docker, you can safely ignore this message.
378
+
379
+ You may press Ctrl+C now to abort this script.
380
+ EOF
381
+ ( set -x; sleep 20 )
382
+ fi
383
+
384
+ user="$(id -un 2>/dev/null || true)"
385
+
386
+ sh_c='sh -c'
387
+ if [ "$user" != 'root' ]; then
388
+ if command_exists sudo; then
389
+ sh_c='sudo -E sh -c'
390
+ elif command_exists su; then
391
+ sh_c='su -c'
392
+ else
393
+ cat >&2 <<-'EOF'
394
+ Error: this installer needs the ability to run commands as root.
395
+ We are unable to find either "sudo" or "su" available to make this happen.
396
+ EOF
397
+ exit 1
398
+ fi
399
+ fi
400
+
401
+ if is_dry_run; then
402
+ sh_c="echo"
403
+ fi
404
+
405
+ # perform some very rudimentary platform detection
406
+ lsb_dist=$( get_distribution )
407
+ lsb_dist="$(echo "$lsb_dist" | tr '[:upper:]' '[:lower:]')"
408
+
409
+ if is_wsl; then
410
+ echo
411
+ echo "WSL DETECTED: We recommend using Docker Desktop for Windows."
412
+ echo "Please get Docker Desktop from https://www.docker.com/products/docker-desktop/"
413
+ echo
414
+ cat >&2 <<-'EOF'
415
+
416
+ You may press Ctrl+C now to abort this script.
417
+ EOF
418
+ ( set -x; sleep 20 )
419
+ fi
420
+
421
+ case "$lsb_dist" in
422
+
423
+ ubuntu)
424
+ if command_exists lsb_release; then
425
+ dist_version="$(lsb_release --codename | cut -f2)"
426
+ fi
427
+ if [ -z "$dist_version" ] && [ -r /etc/lsb-release ]; then
428
+ dist_version="$(. /etc/lsb-release && echo "$DISTRIB_CODENAME")"
429
+ fi
430
+ ;;
431
+
432
+ debian|raspbian)
433
+ dist_version="$(sed 's/\/.*//' /etc/debian_version | sed 's/\..*//')"
434
+ case "$dist_version" in
435
+ 12)
436
+ dist_version="bookworm"
437
+ ;;
438
+ 11)
439
+ dist_version="bullseye"
440
+ ;;
441
+ 10)
442
+ dist_version="buster"
443
+ ;;
444
+ 9)
445
+ dist_version="stretch"
446
+ ;;
447
+ 8)
448
+ dist_version="jessie"
449
+ ;;
450
+ esac
451
+ ;;
452
+
453
+ centos|rhel|sles)
454
+ if [ -z "$dist_version" ] && [ -r /etc/os-release ]; then
455
+ dist_version="$(. /etc/os-release && echo "$VERSION_ID")"
456
+ fi
457
+ ;;
458
+
459
+ *)
460
+ if command_exists lsb_release; then
461
+ dist_version="$(lsb_release --release | cut -f2)"
462
+ fi
463
+ if [ -z "$dist_version" ] && [ -r /etc/os-release ]; then
464
+ dist_version="$(. /etc/os-release && echo "$VERSION_ID")"
465
+ fi
466
+ ;;
467
+
468
+ esac
469
+
470
+ # Check if this is a forked Linux distro
471
+ check_forked
472
+
473
+ # Print deprecation warnings for distro versions that recently reached EOL,
474
+ # but may still be commonly used (especially LTS versions).
475
+ case "$lsb_dist.$dist_version" in
476
+ debian.stretch|debian.jessie)
477
+ deprecation_notice "$lsb_dist" "$dist_version"
478
+ ;;
479
+ raspbian.stretch|raspbian.jessie)
480
+ deprecation_notice "$lsb_dist" "$dist_version"
481
+ ;;
482
+ ubuntu.xenial|ubuntu.trusty)
483
+ deprecation_notice "$lsb_dist" "$dist_version"
484
+ ;;
485
+ ubuntu.impish|ubuntu.hirsute|ubuntu.groovy|ubuntu.eoan|ubuntu.disco|ubuntu.cosmic)
486
+ deprecation_notice "$lsb_dist" "$dist_version"
487
+ ;;
488
+ fedora.*)
489
+ if [ "$dist_version" -lt 36 ]; then
490
+ deprecation_notice "$lsb_dist" "$dist_version"
491
+ fi
492
+ ;;
493
+ esac
494
+
495
+ # Run setup for each distro accordingly
496
+ case "$lsb_dist" in
497
+ ubuntu|debian|raspbian)
498
+ pre_reqs="apt-transport-https ca-certificates curl"
499
+ if ! command -v gpg > /dev/null; then
500
+ pre_reqs="$pre_reqs gnupg"
501
+ fi
502
+ apt_repo="deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.gpg] $DOWNLOAD_URL/linux/$lsb_dist $dist_version $CHANNEL"
503
+ (
504
+ if ! is_dry_run; then
505
+ set -x
506
+ fi
507
+ $sh_c 'apt-get update -qq >/dev/null'
508
+ $sh_c "DEBIAN_FRONTEND=noninteractive apt-get install -y -qq $pre_reqs >/dev/null"
509
+ $sh_c 'install -m 0755 -d /etc/apt/keyrings'
510
+ $sh_c "curl -fsSL \"$DOWNLOAD_URL/linux/$lsb_dist/gpg\" | gpg --dearmor --yes -o /etc/apt/keyrings/docker.gpg"
511
+ $sh_c "chmod a+r /etc/apt/keyrings/docker.gpg"
512
+ $sh_c "echo \"$apt_repo\" > /etc/apt/sources.list.d/docker.list"
513
+ $sh_c 'apt-get update -qq >/dev/null'
514
+ )
515
+ pkg_version=""
516
+ if [ -n "$VERSION" ]; then
517
+ if is_dry_run; then
518
+ echo "# WARNING: VERSION pinning is not supported in DRY_RUN"
519
+ else
520
+ # Will work for incomplete versions IE (17.12), but may not actually grab the "latest" if in the test channel
521
+ pkg_pattern="$(echo "$VERSION" | sed 's/-ce-/~ce~.*/g' | sed 's/-/.*/g')"
522
+ search_command="apt-cache madison docker-ce | grep '$pkg_pattern' | head -1 | awk '{\$1=\$1};1' | cut -d' ' -f 3"
523
+ pkg_version="$($sh_c "$search_command")"
524
+ echo "INFO: Searching repository for VERSION '$VERSION'"
525
+ echo "INFO: $search_command"
526
+ if [ -z "$pkg_version" ]; then
527
+ echo
528
+ echo "ERROR: '$VERSION' not found amongst apt-cache madison results"
529
+ echo
530
+ exit 1
531
+ fi
532
+ if version_gte "18.09"; then
533
+ search_command="apt-cache madison docker-ce-cli | grep '$pkg_pattern' | head -1 | awk '{\$1=\$1};1' | cut -d' ' -f 3"
534
+ echo "INFO: $search_command"
535
+ cli_pkg_version="=$($sh_c "$search_command")"
536
+ fi
537
+ pkg_version="=$pkg_version"
538
+ fi
539
+ fi
540
+ (
541
+ pkgs="docker-ce${pkg_version%=}"
542
+ if version_gte "18.09"; then
543
+ # older versions didn't ship the cli and containerd as separate packages
544
+ pkgs="$pkgs docker-ce-cli${cli_pkg_version%=} containerd.io"
545
+ fi
546
+ if version_gte "20.10"; then
547
+ pkgs="$pkgs docker-compose-plugin docker-ce-rootless-extras$pkg_version"
548
+ fi
549
+ if version_gte "23.0"; then
550
+ pkgs="$pkgs docker-buildx-plugin"
551
+ fi
552
+ if ! is_dry_run; then
553
+ set -x
554
+ fi
555
+ $sh_c "DEBIAN_FRONTEND=noninteractive apt-get install -y -qq $pkgs >/dev/null"
556
+ )
557
+ echo_docker_as_nonroot
558
+ exit 0
559
+ ;;
560
+ centos|fedora|rhel)
561
+ if [ "$(uname -m)" != "s390x" ] && [ "$lsb_dist" = "rhel" ]; then
562
+ echo "Packages for RHEL are currently only available for s390x."
563
+ exit 1
564
+ fi
565
+ if [ "$lsb_dist" = "fedora" ]; then
566
+ pkg_manager="dnf"
567
+ config_manager="dnf config-manager"
568
+ enable_channel_flag="--set-enabled"
569
+ disable_channel_flag="--set-disabled"
570
+ pre_reqs="dnf-plugins-core"
571
+ pkg_suffix="fc$dist_version"
572
+ else
573
+ pkg_manager="yum"
574
+ config_manager="yum-config-manager"
575
+ enable_channel_flag="--enable"
576
+ disable_channel_flag="--disable"
577
+ pre_reqs="yum-utils"
578
+ pkg_suffix="el"
579
+ fi
580
+ repo_file_url="$DOWNLOAD_URL/linux/$lsb_dist/$REPO_FILE"
581
+ (
582
+ if ! is_dry_run; then
583
+ set -x
584
+ fi
585
+ $sh_c "$pkg_manager install -y -q $pre_reqs"
586
+ $sh_c "$config_manager --add-repo $repo_file_url"
587
+
588
+ if [ "$CHANNEL" != "stable" ]; then
589
+ $sh_c "$config_manager $disable_channel_flag 'docker-ce-*'"
590
+ $sh_c "$config_manager $enable_channel_flag 'docker-ce-$CHANNEL'"
591
+ fi
592
+ $sh_c "$pkg_manager makecache"
593
+ )
594
+ pkg_version=""
595
+ if [ -n "$VERSION" ]; then
596
+ if is_dry_run; then
597
+ echo "# WARNING: VERSION pinning is not supported in DRY_RUN"
598
+ else
599
+ pkg_pattern="$(echo "$VERSION" | sed 's/-ce-/\\\\.ce.*/g' | sed 's/-/.*/g').*$pkg_suffix"
600
+ search_command="$pkg_manager list --showduplicates docker-ce | grep '$pkg_pattern' | tail -1 | awk '{print \$2}'"
601
+ pkg_version="$($sh_c "$search_command")"
602
+ echo "INFO: Searching repository for VERSION '$VERSION'"
603
+ echo "INFO: $search_command"
604
+ if [ -z "$pkg_version" ]; then
605
+ echo
606
+ echo "ERROR: '$VERSION' not found amongst $pkg_manager list results"
607
+ echo
608
+ exit 1
609
+ fi
610
+ if version_gte "18.09"; then
611
+ # older versions don't support a cli package
612
+ search_command="$pkg_manager list --showduplicates docker-ce-cli | grep '$pkg_pattern' | tail -1 | awk '{print \$2}'"
613
+ cli_pkg_version="$($sh_c "$search_command" | cut -d':' -f 2)"
614
+ fi
615
+ # Cut out the epoch and prefix with a '-'
616
+ pkg_version="-$(echo "$pkg_version" | cut -d':' -f 2)"
617
+ fi
618
+ fi
619
+ (
620
+ pkgs="docker-ce$pkg_version"
621
+ if version_gte "18.09"; then
622
+ # older versions didn't ship the cli and containerd as separate packages
623
+ if [ -n "$cli_pkg_version" ]; then
624
+ pkgs="$pkgs docker-ce-cli-$cli_pkg_version containerd.io"
625
+ else
626
+ pkgs="$pkgs docker-ce-cli containerd.io"
627
+ fi
628
+ fi
629
+ if version_gte "20.10"; then
630
+ pkgs="$pkgs docker-compose-plugin docker-ce-rootless-extras$pkg_version"
631
+ fi
632
+ if version_gte "23.0"; then
633
+ pkgs="$pkgs docker-buildx-plugin"
634
+ fi
635
+ if ! is_dry_run; then
636
+ set -x
637
+ fi
638
+ $sh_c "$pkg_manager install -y -q $pkgs"
639
+ )
640
+ echo_docker_as_nonroot
641
+ exit 0
642
+ ;;
643
+ sles)
644
+ if [ "$(uname -m)" != "s390x" ]; then
645
+ echo "Packages for SLES are currently only available for s390x"
646
+ exit 1
647
+ fi
648
+ if [ "$dist_version" = "15.3" ]; then
649
+ sles_version="SLE_15_SP3"
650
+ else
651
+ sles_minor_version="${dist_version##*.}"
652
+ sles_version="15.$sles_minor_version"
653
+ fi
654
+ repo_file_url="$DOWNLOAD_URL/linux/$lsb_dist/$REPO_FILE"
655
+ pre_reqs="ca-certificates curl libseccomp2 awk"
656
+ (
657
+ if ! is_dry_run; then
658
+ set -x
659
+ fi
660
+ $sh_c "zypper install -y $pre_reqs"
661
+ $sh_c "zypper addrepo $repo_file_url"
662
+ if ! is_dry_run; then
663
+ cat >&2 <<-'EOF'
664
+ WARNING!!
665
+ openSUSE repository (https://download.opensuse.org/repositories/security:SELinux) will be enabled now.
666
+ Do you wish to continue?
667
+ You may press Ctrl+C now to abort this script.
668
+ EOF
669
+ ( set -x; sleep 30 )
670
+ fi
671
+ opensuse_repo="https://download.opensuse.org/repositories/security:SELinux/$sles_version/security:SELinux.repo"
672
+ $sh_c "zypper addrepo $opensuse_repo"
673
+ $sh_c "zypper --gpg-auto-import-keys refresh"
674
+ $sh_c "zypper lr -d"
675
+ )
676
+ pkg_version=""
677
+ if [ -n "$VERSION" ]; then
678
+ if is_dry_run; then
679
+ echo "# WARNING: VERSION pinning is not supported in DRY_RUN"
680
+ else
681
+ pkg_pattern="$(echo "$VERSION" | sed 's/-ce-/\\\\.ce.*/g' | sed 's/-/.*/g')"
682
+ search_command="zypper search -s --match-exact 'docker-ce' | grep '$pkg_pattern' | tail -1 | awk '{print \$6}'"
683
+ pkg_version="$($sh_c "$search_command")"
684
+ echo "INFO: Searching repository for VERSION '$VERSION'"
685
+ echo "INFO: $search_command"
686
+ if [ -z "$pkg_version" ]; then
687
+ echo
688
+ echo "ERROR: '$VERSION' not found amongst zypper list results"
689
+ echo
690
+ exit 1
691
+ fi
692
+ search_command="zypper search -s --match-exact 'docker-ce-cli' | grep '$pkg_pattern' | tail -1 | awk '{print \$6}'"
693
+ # It's okay for cli_pkg_version to be blank, since older versions don't support a cli package
694
+ cli_pkg_version="$($sh_c "$search_command")"
695
+ pkg_version="-$pkg_version"
696
+ fi
697
+ fi
698
+ (
699
+ pkgs="docker-ce$pkg_version"
700
+ if version_gte "18.09"; then
701
+ if [ -n "$cli_pkg_version" ]; then
702
+ # older versions didn't ship the cli and containerd as separate packages
703
+ pkgs="$pkgs docker-ce-cli-$cli_pkg_version containerd.io"
704
+ else
705
+ pkgs="$pkgs docker-ce-cli containerd.io"
706
+ fi
707
+ fi
708
+ if version_gte "20.10"; then
709
+ pkgs="$pkgs docker-compose-plugin docker-ce-rootless-extras$pkg_version"
710
+ fi
711
+ if version_gte "23.0"; then
712
+ pkgs="$pkgs docker-buildx-plugin"
713
+ fi
714
+ if ! is_dry_run; then
715
+ set -x
716
+ fi
717
+ $sh_c "zypper -q install -y $pkgs"
718
+ )
719
+ echo_docker_as_nonroot
720
+ exit 0
721
+ ;;
722
+ *)
723
+ if [ -z "$lsb_dist" ]; then
724
+ if is_darwin; then
725
+ echo
726
+ echo "ERROR: Unsupported operating system 'macOS'"
727
+ echo "Please get Docker Desktop from https://www.docker.com/products/docker-desktop"
728
+ echo
729
+ exit 1
730
+ fi
731
+ fi
732
+ echo
733
+ echo "ERROR: Unsupported distribution '$lsb_dist'"
734
+ echo
735
+ exit 1
736
+ ;;
737
+ esac
738
+ exit 1
739
+ }
740
+
741
+ # wrapped up in a function so that we have some protection against only getting
742
+ # half the file during "curl | sh"
743
+ do_install