ludusc committed on
Commit
973a4da
1 Parent(s): aeb8017

created global cleaned file

.gitignore CHANGED
@@ -184,4 +184,5 @@ dmypy.json
  cython_debug/

  data/images/
- tmp/
+ tmp/
+ figures/
DisentanglementBase.py ADDED
@@ -0,0 +1,482 @@
1
+ import numpy as np
2
+ import pandas as pd
3
+
4
+ from sklearn.svm import SVC
5
+ from sklearn.decomposition import PCA
6
+ from sklearn.linear_model import LogisticRegression
7
+ from sklearn.model_selection import train_test_split
8
+
9
+ from tqdm import tqdm
10
+ import random
11
+ from os.path import join
12
+ import os
13
+ import pickle
14
+
15
+ import torch
16
+
17
+ import matplotlib.pyplot as plt
18
+ import PIL
19
+ from PIL import Image, ImageColor
20
+
21
+ import sys
22
+ sys.path.append('backend')
23
+ from color_annotations import extract_color
24
+ from networks_stylegan3 import *
25
+ sys.path.append('.')
26
+
27
+ import dnnlib
28
+ import legacy
29
+
30
+ class DisentanglementBase:
31
+ def __init__(self, repo_folder, model, annotations, df, space, colors_list, compute_s):
32
+ self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
33
+ print('Using device', self.device)
34
+ self.repo_folder = repo_folder
35
+ self.model = model.to(self.device)
36
+ self.annotations = annotations
37
+ self.df = df
38
+ self.space = space
39
+
40
+ self.layers = ['input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512',
41
+ 'L4_52_512', 'L5_84_512', 'L6_84_512', 'L7_148_512', 'L8_148_512',
42
+ 'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128',
43
+ 'L13_256_128', 'L14_256_3']
44
+ self.layers_shapes = [4, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 362, 256, 181, 128, 128]
45
+ self.decoding_layers = 16
46
+ self.colors_list = colors_list
47
+
48
+ self.to_hsv()
49
+ if compute_s:
50
+ self.get_s_space()
51
+
52
+
53
+ def to_hsv(self):
54
+ """
55
+ The tohsv function takes the top 3 colors of each image and converts them to HSV values.
56
+ It then adds these values as new columns in the dataframe.
57
+
58
+ :param self: Allow the function to access the dataframe
59
+ :return: The dataframe with the new columns added
60
+ :doc-author: Trelent
61
+ """
62
+ print('Adding HSV encoding')
63
+ self.df['H1'] = self.df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
64
+ self.df['H2'] = self.df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
65
+ self.df['H3'] = self.df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
66
+
67
+ self.df['S1'] = self.df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
68
+ self.df['S2'] = self.df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
69
+ self.df['S3'] = self.df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
70
+
71
+ self.df['V1'] = self.df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
72
+ self.df['V2'] = self.df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
73
+ self.df['V3'] = self.df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
74
+
75
+ def get_s_space(self):
76
+ """
77
+ The get_s_space function takes the w_vectors from the annotations dictionary and uses them to generate s_vectors.
78
+ The s_space is a space of vectors that are generated by passing each w vector through each layer of the model.
79
+ This allows us to see how much information about a particular class is contained in different layers.
80
+
81
+ :param self: Bind the method to a class
82
+ :return: A list of lists of s vectors
83
+ :doc-author: Trelent
84
+ """
85
+ print('Getting S space from W')
86
+ ss = []
87
+ for w in tqdm(self.annotations['w_vectors']):
88
+ w_torch = torch.from_numpy(w).to(self.device)
89
+ W = w_torch.expand((16, -1)).unsqueeze(0)
90
+ s = []
91
+ for i,layer in enumerate(self.layers):
92
+ s.append(getattr(self.model.synthesis, layer).affine(W[0, i].unsqueeze(0)).numpy())
93
+
94
+ ss.append(s)
95
+ self.annotations['s_vectors'] = ss
96
+ annotations_file = join(self.repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl')
97
+ print('Storing s for future use here:', annotations_file)
98
+ with open(annotations_file, 'wb') as f:
99
+ pickle.dump(self.annotations, f)
100
+
101
+ def get_encoded_latent(self):
102
+ # ... (existing code for getX)
103
+ if self.space.lower() == 'w':
104
+ X = np.array(self.annotations['w_vectors']).reshape((len(self.annotations['w_vectors']), 512))
105
+ elif self.space.lower() == 'z':
106
+ X = np.array(self.annotations['z_vectors']).reshape((len(self.annotations['z_vectors']), 512))
107
+ elif self.space.lower() == 's':
108
+ concat_v = []
109
+ for i in range(len(self.annotations['w_vectors'])):
110
+ concat_v.append(np.concatenate(self.annotations['s_vectors'][i], axis=1))
111
+ X = np.array(concat_v)
112
+ X = X[:, 0, :]
113
+ else:
114
+ Exception("Sorry, option not available, select among Z, W, S")
115
+
116
+ print('Shape embedding:', X.shape)
117
+ return X
118
+
119
+ def get_train_val(self, var='H1', cat=True):
120
+ X = self.get_encoded_latent()
121
+ y = np.array(self.df[var].values)
122
+ if cat:
123
+ y_cat = pd.cut(y,
124
+ bins=[x*256/12 if x<12 else 256 for x in range(13)],
125
+ labels=self.colors_list
126
+ ).fillna('Warm Pink Red')
127
+ x_train, x_val, y_train, y_val = train_test_split(X, y_cat, test_size=0.2)
128
+ else:
129
+ x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
130
+ return x_train, x_val, y_train, y_val
131
+
132
+ def InterFaceGAN_separation_vector(self, method='LR', C=0.1):
133
+ """
134
+ Method from InterfaceGAN
135
+ The get_separation_space function takes in a type_bin, annotations, and df.
136
+ It then samples 100 of the most representative abstracts for that type_bin and 100 of the least representative abstracts for that type_bin.
137
+ It then trains an SVM or logistic regression model on these 200 samples to find a separation space between them.
138
+ The function returns this separation space as well as how many nodes are important in this separation space.
139
+
140
+ :param type_bin: Select the type of abstracts to be used for training
141
+ :param annotations: Access the z_vectors
142
+ :param df: Get the abstracts that are used for training
143
+ :param samples: Determine how many samples to take from the top and bottom of the distribution
144
+ :param method: Specify the classifier to use
145
+ :param C: Control the regularization strength
146
+ :return: The weights of the linear classifier
147
+ :doc-author: Trelent
148
+ """
149
+ x_train, x_val, y_train, y_val = self.get_train_val()
150
+
151
+ if method == 'SVM':
152
+ svc = SVC(gamma='auto', kernel='linear', random_state=0, C=C)
153
+ svc.fit(x_train, y_train)
154
+ print('Val performance SVM', np.round(svc.score(x_val, y_val), 2))
155
+ return svc.coef_ / np.linalg.norm(clf.coef_)
156
+ elif method == 'LR':
157
+ clf = LogisticRegression(random_state=0, C=C)
158
+ clf.fit(x_train, y_train)
159
+ print('Val performance logistic regression', np.round(clf.score(x_val, y_val), 2))
160
+ return clf.coef_ / np.linalg.norm(clf.coef_)
161
+
162
+ def get_original_position_latent(self, positive_idxs, negative_idxs):
163
+ # ... (existing code for get_original_pos)
164
+ separation_vectors = []
165
+ for i in range(len(self.colors_list)):
166
+ if self.space.lower() == 's':
167
+ current_idx = 0
168
+ vectors = []
169
+ for j, (leng, layer) in enumerate(zip(self.layers_shapes, self.layers)):
170
+ arr = np.zeros(leng)
171
+ for positive_idx in positive_idxs[i]:
172
+ if positive_idx >= current_idx and positive_idx < current_idx + leng:
173
+ arr[positive_idx - current_idx] = 1
174
+ for negative_idx in negative_idxs[i]:
175
+ if negative_idx >= current_idx and negative_idx < current_idx + leng:
176
+ arr[negative_idx - current_idx] = 1
177
+ arr = arr / (np.linalg.norm(arr) + 0.000001)
178
+ vectors.append(arr)
179
+ current_idx += leng
180
+ elif self.space.lower() == 'z' or self.space.lower() == 'w':
181
+ vectors = np.zeros(512)
182
+ vectors[positive_idxs[i]] = 1
183
+ vectors[negative_idxs[i]] = -1
184
+ vectors = vectors / (np.linalg.norm(vectors) + 0.000001)
185
+ else:
186
+ raise Exception("""This space is not allowed in this function,
187
+ select among Z, W, S""")
188
+ separation_vectors.append(vectors)
189
+
190
+ return separation_vectors
191
+
192
+ def StyleSpace_separation_vector(self, sign=True, num_factors=20, cutout=0.25):
193
+ """ Formula from StyleSpace Analysis """
194
+ x_train, x_val, y_train, y_val = self.get_train_val()
195
+
196
+ positive_idxs = []
197
+ negative_idxs = []
198
+ for color in self.colors_list:
199
+ x_col = x_train[np.where(y_train == color)]
200
+ mp = np.mean(x_train, axis=0)
201
+ sp = np.std(x_train, axis=0)
202
+ de = (x_col - mp) / sp
203
+ meu = np.mean(de, axis=0)
204
+ seu = np.std(de, axis=0)
205
+ if sign:
206
+ thetau = meu / seu
207
+ positive_idx = np.argsort(thetau)[-num_factors//2:]
208
+ negative_idx = np.argsort(thetau)[:num_factors//2]
209
+
210
+ else:
211
+ thetau = np.abs(meu) / seu
212
+ positive_idx = np.argsort(thetau)[-num_factors:]
213
+ negative_idx = []
214
+
215
+
216
+ if cutout:
217
+ beyond_cutout = np.where(np.abs(thetau) > cutout)
218
+ positive_idx = np.intersect1d(positive_idx, beyond_cutout)
219
+ negative_idx = np.intersect1d(negative_idx, beyond_cutout)
220
+
221
+ if len(positive_idx) == 0 and len(negative_idx) == 0:
222
+ print('No values found above the current cutout', cutout, 'for color', color, '.\n Disentangled vector will be all zeros.' )
223
+
224
+ positive_idxs.append(positive_idx)
225
+ negative_idxs.append(negative_idx)
226
+
227
+ separation_vectors = self.get_original_position_latent(positive_idxs, negative_idxs)
228
+ return separation_vectors
229
+
230
+ def GANSpace_separation_vectors(self, num_components):
231
+ x_train, x_val, y_train, y_val = self.get_train_val()
232
+ if self.space.lower() == 'w':
233
+ pca = PCA(n_components=num_components)
234
+
235
+ dims_pca = pca.fit_transform(x_train.T)
236
+ dims_pca /= np.linalg.norm(dims_pca, axis=0)
237
+
238
+ return dims_pca
239
+
240
+ else:
241
+ raise Exception("""This space is not allowed in this function,
242
+ only W""")
243
+
244
+ def generate_images(self, seed, separation_vector=None, lambd=0):
245
+ """
246
+ The generate_original_image function takes in a latent vector and the model,
247
+ and returns an image generated from that latent vector.
248
+
249
+
250
+ :param z: Generate the image
251
+ :param model: Generate the image
252
+ :return: A pil image
253
+ :doc-author: Trelent
254
+ """
255
+ G = self.model.to(self.device) # type: ignore
256
+ # Labels.
257
+ label = torch.zeros([1, G.c_dim], device=self.device)
258
+ if self.space.lower() == 'z':
259
+ vec = self.annotations['z_vectors'][seed]
260
+ Z = torch.from_numpy(vec.copy()).to(self.device)
261
+ if separation_vector is not None:
262
+ change = torch.from_numpy(separation_vector.copy()).unsqueeze(0).to(self.device)
263
+ Z = torch.add(Z, change, alpha=lambd)
264
+ img = G(Z, label, truncation_psi=1, noise_mode='const')
265
+ elif self.space.lower() == 'w':
266
+ vec = self.annotations['w_vectors'][seed]
267
+ W = torch.from_numpy(np.repeat(vec, self.decoding_layers, axis=0)
268
+ .reshape(1, self.decoding_layers, vec.shape[1]).copy()).to(self.device)
269
+ if separation_vector is not None:
270
+ change = torch.from_numpy(separation_vector.copy()).unsqueeze(0).to(self.device)
271
+ W = torch.add(W, change, alpha=lambd)
272
+ img = G.synthesis(W, noise_mode='const')
273
+ else:
274
+ raise Exception("""This space is not allowed in this function,
275
+ select either W or Z or use generate_flexible_images""")
276
+
277
+ img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
278
+ return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
279
+
280
+ def forward_from_style(self, x, styles, layer):
281
+ dtype = torch.float16 if (getattr(self.model.synthesis, layer).use_fp16 and self.device=='cuda') else torch.float32
282
+
283
+ if getattr(self.model.synthesis, layer).is_torgb:
284
+ weight_gain = 1 / np.sqrt(getattr(self.model.synthesis, layer).in_channels * (getattr(self.model.synthesis, layer).conv_kernel ** 2))
285
+ styles = styles * weight_gain
286
+
287
+ input_gain = getattr(self.model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)
288
+
289
+ # Execute modulated conv2d.
290
+ x = modulated_conv2d(x=x.to(dtype), w=getattr(self.model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),
291
+ padding=getattr(self.model.synthesis, layer).conv_kernel-1,
292
+ demodulate=(not getattr(self.model.synthesis, layer).is_torgb),
293
+ input_gain=input_gain.to(dtype))
294
+
295
+ # Execute bias, filtered leaky ReLU, and clamping.
296
+ gain = 1 if getattr(self.model.synthesis, layer).is_torgb else np.sqrt(2)
297
+ slope = 1 if getattr(self.model.synthesis, layer).is_torgb else 0.2
298
+
299
+ x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(self.model.synthesis, layer).up_filter, fd=getattr(self.model.synthesis, layer).down_filter,
300
+ b=getattr(self.model.synthesis, layer).bias.to(x.dtype),
301
+ up=getattr(self.model.synthesis, layer).up_factor, down=getattr(self.model.synthesis, layer).down_factor,
302
+ padding=getattr(self.model.synthesis, layer).padding,
303
+ gain=gain, slope=slope, clamp=getattr(self.model.synthesis, layer).conv_clamp)
304
+ return x
305
+
306
+ def generate_flexible_images(self, seed, separation_vector=None, lambd=0):
307
+ if self.space.lower() != 's':
308
+ raise Exception("""This space is not allowed in this function,
309
+ select S or use generate_images""")
310
+
311
+ vec = self.annotations['w_vectors'][seed]
312
+ w_torch = torch.from_numpy(vec).to(self.device)
313
+ W = w_torch.expand((self.decoding_layers, -1)).unsqueeze(0)
314
+ x = self.model.synthesis.input(W[0,0].unsqueeze(0))
315
+ for i, layer in enumerate(self.layers[1:]):
316
+ style = getattr(self.model.synthesis, layer).affine(W[0, i].unsqueeze(0))
317
+ if separation_vector is not None:
318
+ change = torch.from_numpy(separation_vector[i+1].copy()).unsqueeze(0).to(self.device)
319
+ style = torch.add(style, change, alpha=lambd)
320
+ x = self.forward_from_style(x, style, layer)
321
+
322
+ if self.model.synthesis.output_scale != 1:
323
+ x = x * self.model.synthesis.output_scale
324
+
325
+ img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
326
+ img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
327
+
328
+ return img
329
+
330
+ def generate_changes(self, seed, separation_vector, min_epsilon=-3, max_epsilon=3, count=5, savefig=True, feature=None, method=None):
331
+ """
332
+ The regenerate_images function takes a model, z, and decision_boundary as input. It then
333
+ constructs an inverse rotation/translation matrix and passes it to the generator. The generator
334
+ expects this matrix as an inverse to avoid potentially failing numerical operations in the network.
335
+ The function then generates images using G(z_0, label) where z_0 is a linear combination of z and the decision boundary.
336
+
337
+ :param model: Pass in the model to be used for image generation
338
+ :param z: Generate the starting point of the line
339
+ :param decision_boundary: Generate images along the direction of the decision boundary
340
+ :param min_epsilon: Set the minimum value of lambda
341
+ :param max_epsilon: Set the maximum distance from the original image to generate
342
+ :param count: Determine the number of images that are generated
343
+ :return: A list of images and a list of lambdas
344
+ :doc-author: Trelent
345
+ """
346
+
347
+ os.makedirs(join(self.repo_folder, 'figures'), exist_ok=True)
348
+ lambdas = np.linspace(min_epsilon, max_epsilon, count)
349
+ images = []
350
+ # Generate images.
351
+ for _, lambd in enumerate(tqdm(lambdas)):
352
+ if self.space.lower() == 's':
353
+ images.append(self.generate_flexible_images(seed, separation_vector=separation_vector, lambd=lambd))
354
+ elif self.space.lower() in ['z', 'w']:
355
+ images.append(self.generate_images(seed, separation_vector=separation_vector, lambd=lambd))
356
+
357
+ if savefig:
358
+ print('Generating image for color', feature)
359
+ fig, axs = plt.subplots(1, len(images), figsize=(90,20))
360
+ title = 'Disentanglement method: '+ method + ', on feature: ' + feature + ' on space: ' + self.space + ', image seed: ' + str(seed)
361
+ name = '_'.join([method, feature, self.space, str(seed), str(lambdas[-1])])
362
+ fig.suptitle(title, fontsize=20)
363
+
364
+ for i, (image, lambd) in enumerate(zip(images, lambdas)):
365
+ axs[i].imshow(image)
366
+ axs[i].set_title(np.round(lambd, 2))
367
+ plt.tight_layout()
368
+ plt.savefig(join(self.repo_folder, 'figures', name+'.jpg'))
369
+ return images, lambdas
370
+
371
+ def get_verification_score(self, separation_vector, feature_id, samples=10, lambd=1, savefig=False, feature=None, method=None):
372
+ items = random.sample(range(100000), samples)
373
+ hue_low = feature_id * 256 / 12
374
+ hue_high = (feature_id + 1) * 256 / 12
375
+
376
+ matches = 0
377
+
378
+ for seed in tqdm(items):
379
+ images, lambdas = self.generate_changes(seed, separation_vector, min_epsilon=-lambd, max_epsilon=lambd, count=3, savefig=savefig, feature=feature, method=method)
380
+ colors_negative = extract_color(images[0], 5, 1, None)
381
+ h0, s0, v0 = ImageColor.getcolor(colors_negative[0], 'HSV')
382
+
383
+ colors_orig = extract_color(images[1], 5, 1, None)
384
+ h1, s1, v1 = ImageColor.getcolor(colors_orig[0], 'HSV')
385
+
386
+ colors_positive = extract_color(images[2], 5, 1, None)
387
+ h2, s2, v2 = ImageColor.getcolor(colors_positive[0], 'HSV')
388
+
389
+ if h1 > hue_low and h1 < hue_high:
390
+ samples -= 1
391
+ else:
392
+ if (h0 > hue_low and h0 < hue_high) or (h2 > hue_low and h2 < hue_high):
393
+ matches += 1
394
+
395
+ return np.round(matches / samples, 2)
396
+
397
+
398
+ def main():
399
+ repo_folder = '.'
400
+ annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl')
401
+ with open(annotations_file, 'rb') as f:
402
+ annotations = pickle.load(f)
403
+
404
+ df_file = join(repo_folder, 'data/textile_annotated_files/top_three_colours.csv')
405
+ df = pd.read_csv(df_file).fillna('#000000')
406
+
407
+ model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
408
+ with dnnlib.util.open_url(model_file) as f:
409
+ model = legacy.load_network_pkl(f)['G_ema'] # type: ignore
410
+
411
+ colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',
412
+ 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',
413
+ 'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']
414
+
415
+ scores = []
416
+ kwargs = {'CL method':['LR', 'SVM'], 'C':[0.1, 1], 'sign':[True, False], 'num_factors':[1, 10, 20, 50], 'cutout': [None, 0.2], 'max_lambda':[6, 10, 1], 'samples':50, 'lambda_verif':[1, 3, 6]}
417
+
418
+ for space in ['w', 'z', 's']:
419
+ print('Launching experiment with space:', space)
420
+ disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space=space, colors_list=colors_list, compute_s=False)
421
+
422
+ for method in ['StyleSpace', 'InterFaceGAN', 'GANSpace']:
423
+ if space != 's' and method == 'InterFaceGAN':
424
+ print('Now obtaining separation vector for using InterfaceGAN')
425
+ for met in kwargs['CL method']:
426
+ for c in kwargs['C']:
427
+ separation_vectors = disentanglemnet_exp.InterFaceGAN_separation_vector(method=met, C=c)
428
+ for i, color in enumerate(colors_list):
429
+ print('Generating images with variations')
430
+ seed = random.randint(0,100000)
431
+ for eps in kwargs['max_lambda']:
432
+ disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-eps, max_epsilon=eps, savefig=True, feature=color, method=method)
433
+
434
+ print('Finally obtaining verification score')
435
+ for verif in kwargs['lambda_verif']:
436
+ score = disentanglemnet_exp.get_verification_score(separation_vectors[i], i, samples=kwargs['samples'], lambd=verif, savefig=True, feature=color, method=method)
437
+ print('Score for method', method, 'on space', space, 'for color', color, ':', score)
438
+
439
+ scores.append([space, method, color, score, 'classification method:' + met + ', regularization: ' + str(c) + ', verification lambda:' + str(verif)])
440
+
441
+ elif method == 'StyleSpace':
442
+ print('Now obtaining separation vector for using StyleSpace')
443
+ for sign in kwargs['sign']:
444
+ for num_factors in kwargs['num_factors']:
445
+ for cutout in kwargs['cutout']:
446
+ separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=sign, num_factors=num_factors, cutout=cutout)
447
+ for i, color in enumerate(colors_list):
448
+ print('Generating images with variations')
449
+ seed = random.randint(0,100000)
450
+ for eps in kwargs['max_lambda']:
451
+ disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-eps, max_epsilon=eps, savefig=True, feature=color, method=method)
452
+
453
+ print('Finally obtaining verification score')
454
+ for verif in kwargs['lambda_verif']:
455
+ score = disentanglemnet_exp.get_verification_score(separation_vectors[i], i, samples=kwargs['samples'], lambd=verif, savefig=True, feature=color, method=method)
456
+ print('Score for method', method, 'on space', space, 'for color', color, ':', score)
457
+
458
+ scores.append([space, method, color, score, 'using sign:' + str(sign) + ', number of factors: ' + str(num_factors) + ', using cutout: ' + str(cutout) + ', verification lambda:' + str(verif)])
459
+
460
+ if space == 'w' and method == 'GANSpace':
461
+ print('Now obtaining separation vector for using GANSpace')
462
+ separation_vectors = disentanglemnet_exp.GANSpace_separation_vectors(100)
463
+ for i in range(100):
464
+ print('Generating images with variations')
465
+ seed = random.randint(0,100000)
466
+ for eps in kwargs['max_lambda']:
467
+ disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-eps, max_epsilon=eps, savefig=True, feature=color, method=method)
468
+
469
+ score = None
470
+ scores.append([space, method, color, score, '100'])
471
+ else:
472
+ print('Skipping', method, 'on space', space)
473
+ continue
474
+
475
+
476
+
477
+ score_df = pd.DataFrame(scores, columns=['space', 'method', 'color', 'score', 'kwargs'])
478
+ print(score_df)
479
+ score_df.to_csv(join(repo_folder, 'data/scores.csv'))
480
+
481
+ if __name__ == "__main__":
482
+ main()
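
For reference, a condensed sketch of how the new DisentanglementBase class is meant to be driven, mirroring the InterFaceGAN branch of main() above. The data paths, the pickle/CSV layout, and the assumption that the script runs from the repository root (so that dnnlib, legacy and DisentanglementBase are importable) are all taken from this commit; treat it as an illustration rather than a canonical entry point.

# Condensed driver for DisentanglementBase, mirroring main() above.
# Paths and the annotation layout are assumptions taken from this commit.
import pickle
import pandas as pd
from os.path import join

import dnnlib
import legacy
from DisentanglementBase import DisentanglementBase

repo_folder = '.'
with open(join(repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl'), 'rb') as f:
    annotations = pickle.load(f)
df = pd.read_csv(join(repo_folder, 'data/textile_annotated_files/top_three_colours.csv')).fillna('#000000')

with dnnlib.util.open_url(join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')) as f:
    model = legacy.load_network_pkl(f)['G_ema']

colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow',
               'Chartreuse Green', 'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',
               'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']

# Work in W space and reuse the precomputed S vectors instead of recomputing them.
exp = DisentanglementBase(repo_folder, model, annotations, df,
                          space='w', colors_list=colors_list, compute_s=False)

# One separation vector per colour class from the linear classifier.
separation_vectors = exp.InterFaceGAN_separation_vector(method='LR', C=0.1)

# Walk one seed along the 'Warm Blue' direction and save the strip under figures/.
images, lambdas = exp.generate_changes(42, separation_vectors[colors_list.index('Warm Blue')],
                                       min_epsilon=-6, max_epsilon=6, count=5,
                                       savefig=True, feature='Warm Blue', method='InterFaceGAN')
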
backend/color_annotations.py ADDED
@@ -0,0 +1,97 @@
1
+ #!/usr/bin/env python
2
+
3
+ """Extract color features from the generated textile images."""
4
+
5
+ import os
6
+ #os.environ["TOKENIZERS_PARALLELISM"] = "false"
7
+ from tqdm import tqdm
8
+ #from transformers import pipeline
9
+ import numpy as np
10
+ import pandas as pd
11
+ import time
12
+
13
+ import click
14
+ from PIL import Image
15
+ import math
16
+ import pickle
17
+ from glob import glob
18
+
19
+ import matplotlib.pyplot as plt
20
+ import matplotlib.patches as patches
21
+ import matplotlib.image as mpimg
22
+ import cv2
23
+ import extcolors
24
+
25
+ from colormap import rgb2hex
26
+ from PIL import Image
27
+ from matplotlib.offsetbox import OffsetImage, AnnotationBbox
28
+
29
+ def color_to_df(input):
30
+ colors_pre_list = str(input).replace('([(','').split(', (')[0:-1]
31
+ df_rgb = [i.split('), ')[0] + ')' for i in colors_pre_list]
32
+ df_percent = [i.split('), ')[1].replace(')','') for i in colors_pre_list]
33
+
34
+ #convert RGB to HEX code
35
+ df_color_up = [rgb2hex(int(i.split(", ")[0].replace("(","")),
36
+ int(i.split(", ")[1]),
37
+ int(i.split(", ")[2].replace(")",""))) for i in df_rgb]
38
+
39
+ df = pd.DataFrame(zip(df_color_up, df_percent), columns = ['c_code','occurence'])
40
+ return df
41
+
42
+ def extract_color(input_image, tolerance, zoom, outpath, save=None):
43
+ colors_x = extcolors.extract_from_image(input_image, tolerance = tolerance, limit = 13)
44
+ df_color = color_to_df(colors_x)
45
+
46
+ #annotate text
47
+ list_color = list(df_color['c_code'])
48
+ list_precent = [int(i) for i in list(df_color['occurence'])]
49
+ text_c = [c + ' ' + str(round(p*100/sum(list_precent),1)) +'%' for c, p in zip(list_color, list_precent)]
50
+ colors = list(df_color['c_code'])
51
+ if '#000000' in colors:
52
+ colors.remove('#000000')
53
+ return colors[:3]
54
+
55
+
56
+ @click.command()
57
+ @click.option('--genimages_dir', help='Where the output images are saved', type=str, required=True, metavar='DIR')
58
+
59
+ def annotate_textile_images(
60
+ genimages_dir: str,
61
+
62
+ ):
63
+ """Produce annotations for the generated images.
64
+ \b
65
+ #
66
+ python annotate_textiles.py --genimages_dir /home/ludosc/data/stylegan-10000-textile-upscale
67
+ """
68
+ colours = []
69
+ pickle_files = glob(genimages_dir + '/imgs0000*.pkl')
70
+ for pickle_file in pickle_files:
71
+ print('Using pickle file: ', pickle_file)
72
+ with open(pickle_file, 'rb') as f:
73
+ info = pickle.load(f)
74
+
75
+ listlen = len(info['fname'])
76
+ os.makedirs('/data/ludosc/colour_palettes/', exist_ok=True)
77
+ for i,im in enumerate(tqdm(info['fname'])):
78
+ try:
79
+ top_cols = extract_color(Image.open(im), 12, 5, '/data/ludosc/colour_palettes/' + im.split('/')[-1])
80
+ colours.append([im]+top_cols)
81
+ except Exception as e:
82
+ print(e)
83
+ if i % 1000 == 0:
84
+ df = pd.DataFrame(colours, columns=['fname', 'top1col', 'top2col', 'top3col'])
85
+ print(df.head())
86
+ df.to_csv(genimages_dir + f'/top_three_colours.csv', index=False)
87
+
88
+ df = pd.DataFrame(colours, columns=['fname', 'top1col', 'top2col', 'top3col'])
89
+ print(df.head())
90
+ df.to_csv(genimages_dir + f'/final_sim_{os.path.basename(pickle_file.split(".")[0])}.csv', index=False)
91
+
92
+ #----------------------------------------------------------------------------
93
+
94
+ if __name__ == "__main__":
95
+ annotate_textile_images() # pylint: disable=no-value-for-parameter
96
+
97
+ #----------------------------------------------------------------------------
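
A short usage note for the helper above: extract_color operates on an in-memory PIL image (not a file path), the zoom and outpath arguments are currently unused, and it returns up to three hex codes with pure black filtered out. A minimal sketch, assuming backend/ is on sys.path (as DisentanglementBase.py arranges) and using a hypothetical image path:

# Sketch: dominant colours of one generated textile image.
from PIL import Image
from color_annotations import extract_color

img = Image.open('data/images/seed0001.png')  # hypothetical example image
top3 = extract_color(img, tolerance=12, zoom=5, outpath=None)
print(top3)  # up to three hex codes, e.g. ['#1a2b3c', '#aabbcc', '#445566'], with '#000000' removed
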
backend/disentangle_concepts.py CHANGED
@@ -5,6 +5,13 @@ from sklearn.model_selection import train_test_split
5
  import torch
6
  from umap import UMAP
7
  import PIL
8
 
9
  def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=0.1, latent_space='Z'):
10
  """
@@ -65,7 +72,7 @@ def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=
65
  return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val),2)
66
 
67
 
68
- def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z', layers=None):
69
  """
70
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
71
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -92,19 +99,18 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
92
  z = torch.from_numpy(z.copy()).to(device)
93
  decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
94
 
 
95
  lambdas = np.linspace(min_epsilon, max_epsilon, count)
96
  images = []
97
  # Generate images.
98
- for _, lambda_ in enumerate(lambdas):
99
  z_0 = z + lambda_ * decision_boundary
100
  if latent_space == 'Z':
101
  W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
102
  W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
103
- print(W.dtype)
104
  else:
105
- W_0 = z_0.expand((14, -1)).unsqueeze(0).to(torch.float32)
106
- W = z.expand((14, -1)).unsqueeze(0).to(torch.float32)
107
- print(W.dtype)
108
 
109
  if layers:
110
  W_f = torch.empty_like(W).copy_(W).to(torch.float32)
@@ -117,14 +123,14 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
117
  images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
118
 
119
  return images, lambdas
120
-
121
-
122
  def generate_joint_effect(model, z, decision_boundaries, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z'):
123
  decision_boundary_joint = np.sum(decision_boundaries, axis=0)
124
  print(decision_boundary_joint.shape)
125
  return regenerate_images(model, z, decision_boundary_joint, min_epsilon=min_epsilon, max_epsilon=max_epsilon, count=count, latent_space=latent_space)
126
 
127
- def generate_original_image(z, model, latent_space='Z'):
128
  """
129
  The generate_original_image function takes in a latent vector and the model,
130
  and returns an image generated from that latent vector.
@@ -135,6 +141,8 @@ def generate_original_image(z, model, latent_space='Z'):
135
  :return: A pil image
136
  :doc-author: Trelent
137
  """
 
 
138
  device = torch.device('cpu')
139
  G = model.to(device) # type: ignore
140
  # Labels.
@@ -143,10 +151,10 @@ def generate_original_image(z, model, latent_space='Z'):
143
  z = torch.from_numpy(z.copy()).to(device)
144
  img = G(z, label, truncation_psi=1, noise_mode='const')
145
  else:
146
- W = torch.from_numpy(np.repeat(z, 14, axis=0).reshape(1, 14, z.shape[1]).copy()).to(device)
147
  print(W.shape)
148
  img = G.synthesis(W, noise_mode='const')
149
-
150
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
151
  return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
152
 
@@ -188,8 +196,36 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=
188
  return vectors, nodes_in_common, performances
189
 
190
 
191
 
192
- def get_verification_score(concept, decision_boundary, model, annotations, samples=100, latent_space='Z'):
193
  import open_clip
194
  import os
195
  import random
@@ -243,5 +279,177 @@ def get_verification_score(concept, decision_boundary, model, annotations, sampl
243
  return np.round(np.mean(np.array(changes)), 4)
244
 
245
 
246
 
247
-
5
  import torch
6
  from umap import UMAP
7
  import PIL
8
+ from tqdm import tqdm
9
+ import random
10
+ from PIL import Image, ImageColor
11
+
12
+ from .color_annotations import extract_color
13
+
14
+
15
 
16
  def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=0.1, latent_space='Z'):
17
  """
 
72
  return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val),2)
73
 
74
 
75
+ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z', layers=None, number=3):
76
  """
77
  The regenerate_images function takes a model, z, and decision_boundary as input. It then
78
  constructs an inverse rotation/translation matrix and passes it to the generator. The generator
 
99
  z = torch.from_numpy(z.copy()).to(device)
100
  decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)
101
 
102
+ repetitions = 16 if number == 3 else 14
103
  lambdas = np.linspace(min_epsilon, max_epsilon, count)
104
  images = []
105
  # Generate images.
106
+ for _, lambda_ in enumerate(tqdm(lambdas)):
107
  z_0 = z + lambda_ * decision_boundary
108
  if latent_space == 'Z':
109
  W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
110
  W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
 
111
  else:
112
+ W_0 = z_0.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
113
+ W = z.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
 
114
 
115
  if layers:
116
  W_f = torch.empty_like(W).copy_(W).to(torch.float32)
 
123
  images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))
124
 
125
  return images, lambdas
126
+
127
+
128
  def generate_joint_effect(model, z, decision_boundaries, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z'):
129
  decision_boundary_joint = np.sum(decision_boundaries, axis=0)
130
  print(decision_boundary_joint.shape)
131
  return regenerate_images(model, z, decision_boundary_joint, min_epsilon=min_epsilon, max_epsilon=max_epsilon, count=count, latent_space=latent_space)
132
 
133
+ def generate_original_image(z, model, latent_space='Z', number=3):
134
  """
135
  The generate_original_image function takes in a latent vector and the model,
136
  and returns an image generated from that latent vector.
 
141
  :return: A pil image
142
  :doc-author: Trelent
143
  """
144
+ repetitions = 16 if number == 3 else 14
145
+
146
  device = torch.device('cpu')
147
  G = model.to(device) # type: ignore
148
  # Labels.
 
151
  z = torch.from_numpy(z.copy()).to(device)
152
  img = G(z, label, truncation_psi=1, noise_mode='const')
153
  else:
154
+ W = torch.from_numpy(np.repeat(z, repetitions, axis=0).reshape(1, repetitions, z.shape[1]).copy()).to(device)
155
  print(W.shape)
156
  img = G.synthesis(W, noise_mode='const')
157
+
158
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
159
  return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
160
 
 
196
  return vectors, nodes_in_common, performances
197
 
198
 
199
+ def get_verification_score(color_id, decision_boundary, model, annotations, samples=100, latent_space='W'):
200
+ listlen = len(annotations['fname'])
201
+ items = random.sample(range(listlen), samples)
202
+ hue_low = color_id * 256 / 12
203
+ hue_high = (color_id + 1) * 256 / 12
204
+ hue_mean = (hue_low + hue_high) / 2
205
+ print(int(hue_low), int(hue_high), int(hue_mean))
206
+ distances = []
207
+ distances_orig = []
208
+ for iterator in tqdm(items):
209
+ if latent_space == 'Z':
210
+ z = annotations['z_vectors'][iterator]
211
+ else:
212
+ z = annotations['w_vectors'][iterator]
213
+
214
+ images, lambdas = regenerate_images(model, z, decision_boundary, min_epsilon=0, max_epsilon=1, count=2, latent_space=latent_space)
215
+ colors_orig = extract_color(images[0], 5, 1, None)
216
+ h_old, s_old, v_old = ImageColor.getcolor(colors_orig[0], 'HSV')
217
+ colors_new = extract_color(images[1], 5, 1, None)
218
+ h_new, s_new, v_new = ImageColor.getcolor(colors_new[0], 'HSV')
219
+ print(h_old, h_new)
220
+ distance = np.abs(hue_mean - h_new)
221
+ distances.append(distance)
222
+ distance_orig = np.abs(hue_mean - h_old)
223
+ distances_orig.append(distance_orig)
224
+
225
+ return np.round(np.mean(np.array(distances)), 4), np.round(np.mean(np.array(distances_orig)), 4)
226
+
227
 
228
+ def get_verification_score_clip(concept, decision_boundary, model, annotations, samples=100, latent_space='Z'):
229
  import open_clip
230
  import os
231
  import random
 
279
  return np.round(np.mean(np.array(changes)), 4)
280
 
281
 
282
+
283
+ def tohsv(df):
284
+ df['H1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
285
+ df['H2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
286
+ df['H3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
287
+
288
+ df['S1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
289
+ df['S2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
290
+ df['S3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
291
+
292
+ df['V1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
293
+ df['V2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
294
+ df['V3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
295
+ return df
296
 
297
+
298
+ def rest_from_style(x, styles, layer):
299
+ dtype = torch.float16 if (getattr(model.synthesis, layer).use_fp16 and device=='cuda') else torch.float32
300
+ if getattr(model.synthesis, layer).is_torgb:
301
+ print(layer, getattr(model.synthesis, layer).is_torgb)
302
+ weight_gain = 1 / np.sqrt(getattr(model.synthesis, layer).in_channels * (getattr(model.synthesis, layer).conv_kernel ** 2))
303
+ styles = styles * weight_gain
304
+ input_gain = getattr(model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)
305
+ # Execute modulated conv2d.
306
+ x = modulated_conv2d(x=x.to(dtype), w=getattr(model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),
307
+ padding=getattr(model.synthesis, layer).conv_kernel-1, demodulate=(not getattr(model.synthesis, layer).is_torgb), input_gain=input_gain.to(dtype))
308
+ # Execute bias, filtered leaky ReLU, and clamping.
309
+ gain = 1 if getattr(model.synthesis, layer).is_torgb else np.sqrt(2)
310
+ slope = 1 if getattr(model.synthesis, layer).is_torgb else 0.2
311
+ x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(model.synthesis, layer).up_filter, fd=getattr(model.synthesis, layer).down_filter,
312
+ b=getattr(model.synthesis, layer).bias.to(x.dtype),
313
+ up=getattr(model.synthesis, layer).up_factor, down=getattr(model.synthesis, layer).down_factor,
314
+ padding=getattr(model.synthesis, layer).padding,
315
+ gain=gain, slope=slope, clamp=getattr(model.synthesis, layer).conv_clamp)
316
+ return x
317
+
318
+
319
+ def getS(w):
320
+ w_torch = torch.from_numpy(w).to('cpu')
321
+ W = w_torch.expand((16, -1)).unsqueeze(0)
322
+ s = []
323
+ s.append(model.synthesis.input.affine(W[0, 0].unsqueeze(0)).numpy())
324
+ s.append(model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)).numpy())
325
+ s.append(model.synthesis.L1_36_512.affine(W[0, 2].unsqueeze(0)).numpy())
326
+ s.append(model.synthesis.L2_36_512.affine(W[0, 3].unsqueeze(0)).numpy())
327
+ s.append(model.synthesis.L3_52_512.affine(W[0, 4].unsqueeze(0)).numpy())
328
+ s.append(model.synthesis.L4_52_512.affine(W[0, 5].unsqueeze(0)).numpy())
329
+ s.append(model.synthesis.L5_84_512.affine(W[0, 6].unsqueeze(0)).numpy())
330
+ s.append(model.synthesis.L6_84_512.affine(W[0, 7].unsqueeze(0)).numpy())
331
+ s.append(model.synthesis.L7_148_512.affine(W[0, 8].unsqueeze(0)).numpy())
332
+ s.append(model.synthesis.L8_148_512.affine(W[0, 9].unsqueeze(0)).numpy())
333
+ s.append(model.synthesis.L9_148_362.affine(W[0, 10].unsqueeze(0)).numpy())
334
+ s.append(model.synthesis.L10_276_256.affine(W[0, 11].unsqueeze(0)).numpy())
335
+ s.append(model.synthesis.L11_276_181.affine(W[0, 12].unsqueeze(0)).numpy())
336
+ s.append(model.synthesis.L12_276_128.affine(W[0, 13].unsqueeze(0)).numpy())
337
+ s.append(model.synthesis.L13_256_128.affine(W[0, 14].unsqueeze(0)).numpy())
338
+ s.append(model.synthesis.L14_256_3.affine(W[0, 15].unsqueeze(0)).numpy())
339
+ return s
340
+
341
+ def detect_attribute_specific_channels(positives, all, sign=False):
342
+ """ Formula from StyleSpace Analysis """
343
+ mp = np.mean(all, axis=0)
344
+ sp = np.std(all, axis=0)
345
+ de = (positives - mp) / sp
346
+ meu = np.mean(de, axis=0)
347
+ seu = np.std(de, axis=0)
348
+ if sign:
349
+ thetau = meu / seu
350
+ else:
351
+ thetau = np.abs(meu) / seu
352
+ return thetau
353
+
354
+ def all_variance_based_disentanglements(labels, x, y, k=10, sign=False, cutout=0.28):
355
+ seps = []
356
+ sorted_vals = []
357
+ for lbl in labels:
358
+ positives = x[np.where(y == lbl)]
359
+ variations = detect_attribute_specific_channels(positives, x, sign=sign)
360
+ if sign:
361
+ argsorted_vars_pos = np.argsort(variations)[-k//2:]
362
+ # print(argsorted_vars_pos)
363
+ argsorted_vars_neg = np.argsort(variations)[:k//2]
364
+ if cutout:
365
+ beyond_cutout = np.where(np.abs(variations) > cutout)
366
+ # print(beyond_cutout)
367
+ argsorted_vars_pos_int = np.intersect1d(argsorted_vars_pos, beyond_cutout)
368
+ argsorted_vars_neg_int = np.intersect1d(argsorted_vars_neg, beyond_cutout)
369
+ # print(argsorted_vars_pos)
370
+ if len(argsorted_vars_neg_int) > 0:
371
+ argsorted_vars_neg = np.array(argsorted_vars_neg_int)
372
+ if len(argsorted_vars_pos_int) > 0:
373
+ argsorted_vars_pos = np.array(argsorted_vars_pos_int)
374
+
375
+
376
+ else:
377
+ argsorted_vars = np.argsort(variations)[-k:]
378
+
379
+
380
+ sorted_vals.append(np.sort(variations))
381
+ separation_vector_onehot /= np.linalg.norm(separation_vector_onehot)
382
+ seps.append(separation_vector_onehot)
383
+ return seps, sorted_vals
384
+
385
+ def generate_flexible_images(w, change_vectors, lambdas=1, device='cpu'):
386
+ w_torch = torch.from_numpy(w).to('cpu')
387
+ if len(change_vectors) != 17:
388
+ w_torch = w_torch + lambdas * change_vectors[0]
389
+ W = w_torch.expand((16, -1)).unsqueeze(0)
390
+
391
+ x = model.synthesis.input(W[0,0].unsqueeze(0))
392
+ for i, layer in enumerate(layers):
393
+ if i < 2:
394
+ continue
395
+ style = getattr(model.synthesis, layer).affine(W[0, i-1].unsqueeze(0))
396
+ if len(change_vectors) != 17:
397
+ change = torch.from_numpy(change_vectors[i].copy()).unsqueeze(0).to(device)
398
+ style = torch.add(style, change, alpha=lambdas)
399
+ x = rest_from_style(x, style, layer)
400
+
401
+ if model.synthesis.output_scale != 1:
402
+ x = x * model.synthesis.output_scale
403
+
404
+ img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
405
+ img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
406
+
407
+ return img
408
+
409
+ def get_original_pos(top_positions, bottom_positions=None, space='s', sign=True,
410
+ shapes=[[512, 4, 512, 512, 512, 512, 512, 512, 512,
411
+ 512, 512, 512, 362, 256, 181, 128, 128]],
412
+ layers=['w', 'input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512',
413
+ 'L4_52_512', 'L5_84_512', 'L6_84_512', 'L7_148_512', 'L8_148_512',
414
+ 'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128',
415
+ 'L13_256_128', 'L14_256_3'], ):
416
+ if space == 's':
417
+ current_idx = 0
418
+ vectors = []
419
+ for i, (leng, layer) in enumerate(zip(shapes, layers)):
420
+ arr = np.zeros(leng)
421
+ for top_position in top_positions:
422
+ if top_position >= current_idx and top_position < current_idx + leng:
423
+ arr[top_position - current_idx] = 1
424
+ for bottom_position in bottom_positions:
425
+ if sign:
426
+ if bottom_position >= current_idx and bottom_position < current_idx + leng:
427
+ arr[bottom_position - current_idx] = 1
428
+ arr = arr / (np.linalg.norm(arr) + 0.000001)
429
+ vectors.append(arr)
430
+ current_idx += leng
431
+ else:
432
+ if sign:
433
+ vectors = np.zeros(512)
434
+ vectors[top_positions] = 1
435
+ vectors[bottom_positions] = -1
436
+ else:
437
+ vectors = np.zeros(512)
438
+ vectors[top_positions] = 1
439
+ return vectors
440
+
441
+ def getX(annotations, space='s'):
442
+ if space == 'x':
443
+ X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))
444
+ elif space == 's':
445
+ concat_v = []
446
+ for i in range(len(annotations['w_vectors'])):
447
+ concat_v.append(np.concatenate([annotations['w_vectors'][i]] + annotations['s_vectors'][i], axis=1))
448
+
449
+ X = np.array(concat_v)
450
+ X = X[:, 0, :]
451
+ print(X.shape)
452
+
453
+ return X
454
+
455
+
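
To make the "Formula from StyleSpace Analysis" used by detect_attribute_specific_channels above concrete: each channel of the positively-labelled codes is standardised against the whole population, d = (x_pos - mu_all) / sigma_all, and the channel is scored by theta = mean(d) / std(d) (or |mean(d)| / std(d) when the sign is ignored); the top-k channels by theta are taken as attribute-specific. A tiny self-contained illustration on synthetic data (not part of the pipeline):

# Illustrative only: score latent channels with the StyleSpace-style statistic
# theta = |mean(d)| / std(d), where d = (x_pos - mu_all) / sigma_all.
import numpy as np

rng = np.random.default_rng(0)
x_all = rng.normal(size=(1000, 512))   # stand-in for the latent codes of all images
x_pos = x_all[:100].copy()             # stand-in for the codes labelled with one attribute
x_pos[:, 7] += 2.0                     # make channel 7 attribute-specific on purpose

mu, sigma = x_all.mean(axis=0), x_all.std(axis=0)
d = (x_pos - mu) / sigma
theta = np.abs(d.mean(axis=0)) / d.std(axis=0)   # the sign=False variant above

top_channels = np.argsort(theta)[-10:]           # k = 10 most attribute-specific channels
print(top_channels)                              # channel 7 should rank near the top
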
backend/networks_stylegan3.py ADDED
@@ -0,0 +1,515 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Generator architecture from the paper
10
+ "Alias-Free Generative Adversarial Networks"."""
11
+
12
+ import numpy as np
13
+ import scipy.signal
14
+ import scipy.optimize
15
+ import torch
16
+ from torch_utils import misc
17
+ from torch_utils import persistence
18
+ from torch_utils.ops import conv2d_gradfix
19
+ from torch_utils.ops import filtered_lrelu
20
+ from torch_utils.ops import bias_act
21
+
22
+ #----------------------------------------------------------------------------
23
+
24
+ @misc.profiled_function
25
+ def modulated_conv2d(
26
+ x, # Input tensor: [batch_size, in_channels, in_height, in_width]
27
+ w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
28
+ s, # Style tensor: [batch_size, in_channels]
29
+ demodulate = True, # Apply weight demodulation?
30
+ padding = 0, # Padding: int or [padH, padW]
31
+ input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
32
+ ):
33
+ with misc.suppress_tracer_warnings(): # this value will be treated as a constant
34
+ batch_size = int(x.shape[0])
35
+ out_channels, in_channels, kh, kw = w.shape
36
+ misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
37
+ misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
38
+ misc.assert_shape(s, [batch_size, in_channels]) # [NI]
39
+
40
+ # Pre-normalize inputs.
41
+ if demodulate:
42
+ w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
43
+ s = s * s.square().mean().rsqrt()
44
+
45
+ # Modulate weights.
46
+ w = w.unsqueeze(0) # [NOIkk]
47
+ w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
48
+
49
+ # Demodulate weights.
50
+ if demodulate:
51
+ dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
52
+ w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
53
+
54
+ # Apply input scaling.
55
+ if input_gain is not None:
56
+ input_gain = input_gain.expand(batch_size, in_channels) # [NI]
57
+ w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
58
+
59
+ # Execute as one fused op using grouped convolution.
60
+ x = x.reshape(1, -1, *x.shape[2:])
61
+ w = w.reshape(-1, in_channels, kh, kw)
62
+ x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
63
+ x = x.reshape(batch_size, -1, *x.shape[2:])
64
+ return x
65
+
66
+ #----------------------------------------------------------------------------
67
+
68
+ @persistence.persistent_class
69
+ class FullyConnectedLayer(torch.nn.Module):
70
+ def __init__(self,
71
+ in_features, # Number of input features.
72
+ out_features, # Number of output features.
73
+ activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
74
+ bias = True, # Apply additive bias before the activation function?
75
+ lr_multiplier = 1, # Learning rate multiplier.
76
+ weight_init = 1, # Initial standard deviation of the weight tensor.
77
+ bias_init = 0, # Initial value of the additive bias.
78
+ ):
79
+ super().__init__()
80
+ self.in_features = in_features
81
+ self.out_features = out_features
82
+ self.activation = activation
83
+ self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
84
+ bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
85
+ self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
86
+ self.weight_gain = lr_multiplier / np.sqrt(in_features)
87
+ self.bias_gain = lr_multiplier
88
+
89
+ def forward(self, x):
90
+ w = self.weight.to(x.dtype) * self.weight_gain
91
+ b = self.bias
92
+ if b is not None:
93
+ b = b.to(x.dtype)
94
+ if self.bias_gain != 1:
95
+ b = b * self.bias_gain
96
+ if self.activation == 'linear' and b is not None:
97
+ x = torch.addmm(b.unsqueeze(0), x, w.t())
98
+ else:
99
+ x = x.matmul(w.t())
100
+ x = bias_act.bias_act(x, b, act=self.activation)
101
+ return x
102
+
103
+ def extra_repr(self):
104
+ return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
105
+
106
+ #----------------------------------------------------------------------------
107
+
108
+ @persistence.persistent_class
109
+ class MappingNetwork(torch.nn.Module):
110
+ def __init__(self,
111
+ z_dim, # Input latent (Z) dimensionality.
112
+ c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
113
+ w_dim, # Intermediate latent (W) dimensionality.
114
+ num_ws, # Number of intermediate latents to output.
115
+ num_layers = 2, # Number of mapping layers.
116
+ lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
117
+ w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
118
+ ):
119
+ super().__init__()
120
+ self.z_dim = z_dim
121
+ self.c_dim = c_dim
122
+ self.w_dim = w_dim
123
+ self.num_ws = num_ws
124
+ self.num_layers = num_layers
125
+ self.w_avg_beta = w_avg_beta
126
+
127
+ # Construct layers.
128
+ self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
129
+ features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
130
+ for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
131
+ layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
132
+ setattr(self, f'fc{idx}', layer)
133
+ self.register_buffer('w_avg', torch.zeros([w_dim]))
134
+
135
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
136
+ misc.assert_shape(z, [None, self.z_dim])
137
+ if truncation_cutoff is None:
138
+ truncation_cutoff = self.num_ws
139
+
140
+ # Embed, normalize, and concatenate inputs.
141
+ x = z.to(torch.float32)
142
+ x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
143
+ if self.c_dim > 0:
144
+ misc.assert_shape(c, [None, self.c_dim])
145
+ y = self.embed(c.to(torch.float32))
146
+ y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
147
+ x = torch.cat([x, y], dim=1) if x is not None else y
148
+
149
+ # Execute layers.
150
+ for idx in range(self.num_layers):
151
+ x = getattr(self, f'fc{idx}')(x)
152
+
153
+ # Update moving average of W.
154
+ if update_emas:
155
+ self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
156
+
157
+ # Broadcast and apply truncation.
158
+ x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
159
+ if truncation_psi != 1:
160
+ x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
161
+ return x
162
+
163
+ def extra_repr(self):
164
+ return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
165
+
166
+ #----------------------------------------------------------------------------
167
+
168
+ @persistence.persistent_class
169
+ class SynthesisInput(torch.nn.Module):
170
+ def __init__(self,
171
+ w_dim, # Intermediate latent (W) dimensionality.
172
+ channels, # Number of output channels.
173
+ size, # Output spatial size: int or [width, height].
174
+ sampling_rate, # Output sampling rate.
175
+ bandwidth, # Output bandwidth.
176
+ ):
177
+ super().__init__()
178
+ self.w_dim = w_dim
179
+ self.channels = channels
180
+ self.size = np.broadcast_to(np.asarray(size), [2])
181
+ self.sampling_rate = sampling_rate
182
+ self.bandwidth = bandwidth
183
+
184
+ # Draw random frequencies from uniform 2D disc.
185
+ freqs = torch.randn([self.channels, 2])
186
+ radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
187
+ freqs /= radii * radii.square().exp().pow(0.25)
188
+ freqs *= bandwidth
189
+ phases = torch.rand([self.channels]) - 0.5
190
+
191
+ # Setup parameters and buffers.
192
+ self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
193
+ self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
194
+ self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
195
+ self.register_buffer('freqs', freqs)
196
+ self.register_buffer('phases', phases)
197
+
198
+ def forward(self, w):
199
+ # Introduce batch dimension.
200
+ transforms = self.transform.unsqueeze(0) # [batch, row, col]
201
+ freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
202
+ phases = self.phases.unsqueeze(0) # [batch, channel]
203
+
204
+ # Apply learned transformation.
205
+ t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
206
+ t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
207
+ m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
208
+ m_r[:, 0, 0] = t[:, 0] # r'_c
209
+ m_r[:, 0, 1] = -t[:, 1] # r'_s
210
+ m_r[:, 1, 0] = t[:, 1] # r'_s
211
+ m_r[:, 1, 1] = t[:, 0] # r'_c
212
+ m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
213
+ m_t[:, 0, 2] = -t[:, 2] # t'_x
214
+ m_t[:, 1, 2] = -t[:, 3] # t'_y
215
+ transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
216
+
217
+ # Transform frequencies.
218
+ phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
219
+ freqs = freqs @ transforms[:, :2, :2]
220
+
221
+ # Dampen out-of-band frequencies that may occur due to the user-specified transform.
222
+ amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
223
+
224
+ # Construct sampling grid.
225
+ theta = torch.eye(2, 3, device=w.device)
226
+ theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
227
+ theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
228
+ grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
229
+
230
+ # Compute Fourier features.
231
+ x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
232
+ x = x + phases.unsqueeze(1).unsqueeze(2)
233
+ x = torch.sin(x * (np.pi * 2))
234
+ x = x * amplitudes.unsqueeze(1).unsqueeze(2)
235
+
236
+ # Apply trainable mapping.
237
+ weight = self.weight / np.sqrt(self.channels)
238
+ x = x @ weight.t()
239
+
240
+ # Ensure correct shape.
241
+ x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
242
+ misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
243
+ return x
244
+
245
+ def extra_repr(self):
246
+ return '\n'.join([
247
+ f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
248
+ f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])
249
+
250
+ #----------------------------------------------------------------------------
251
+
252
+ @persistence.persistent_class
253
+ class SynthesisLayer(torch.nn.Module):
254
+ def __init__(self,
255
+ w_dim, # Intermediate latent (W) dimensionality.
256
+ is_torgb, # Is this the final ToRGB layer?
257
+ is_critically_sampled, # Does this layer use critical sampling?
258
+ use_fp16, # Does this layer use FP16?
259
+
260
+ # Input & output specifications.
261
+ in_channels, # Number of input channels.
262
+ out_channels, # Number of output channels.
263
+ in_size, # Input spatial size: int or [width, height].
264
+ out_size, # Output spatial size: int or [width, height].
265
+ in_sampling_rate, # Input sampling rate (s).
266
+ out_sampling_rate, # Output sampling rate (s).
267
+ in_cutoff, # Input cutoff frequency (f_c).
268
+ out_cutoff, # Output cutoff frequency (f_c).
269
+ in_half_width, # Input transition band half-width (f_h).
270
+ out_half_width, # Output transition band half-width (f_h).
271
+
272
+ # Hyperparameters.
273
+ conv_kernel = 3, # Convolution kernel size. Ignored for the final ToRGB layer.
274
+ filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
275
+ lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer.
276
+ use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
277
+ conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
278
+ magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
279
+ ):
280
+ super().__init__()
281
+ self.w_dim = w_dim
282
+ self.is_torgb = is_torgb
283
+ self.is_critically_sampled = is_critically_sampled
284
+ self.use_fp16 = use_fp16
285
+ self.in_channels = in_channels
286
+ self.out_channels = out_channels
287
+ self.in_size = np.broadcast_to(np.asarray(in_size), [2])
288
+ self.out_size = np.broadcast_to(np.asarray(out_size), [2])
289
+ self.in_sampling_rate = in_sampling_rate
290
+ self.out_sampling_rate = out_sampling_rate
291
+ self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
292
+ self.in_cutoff = in_cutoff
293
+ self.out_cutoff = out_cutoff
294
+ self.in_half_width = in_half_width
295
+ self.out_half_width = out_half_width
296
+ self.conv_kernel = 1 if is_torgb else conv_kernel
297
+ self.conv_clamp = conv_clamp
298
+ self.magnitude_ema_beta = magnitude_ema_beta
299
+
300
+ # Setup parameters and buffers.
301
+ self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
302
+ self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
303
+ self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
304
+ self.register_buffer('magnitude_ema', torch.ones([]))
305
+
306
+ # Design upsampling filter.
307
+ self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
308
+ assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
309
+ self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
310
+ self.register_buffer('up_filter', self.design_lowpass_filter(
311
+ numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))
312
+
313
+ # Design downsampling filter.
314
+ self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
315
+ assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
316
+ self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
317
+ self.down_radial = use_radial_filters and not self.is_critically_sampled
318
+ self.register_buffer('down_filter', self.design_lowpass_filter(
319
+ numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))
320
+
321
+ # Compute padding.
322
+ pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
323
+ pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
324
+ pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
325
+ pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
326
+ pad_hi = pad_total - pad_lo
327
+ self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]
328
+
329
+ def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
330
+ assert noise_mode in ['random', 'const', 'none'] # unused
331
+ misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
332
+ misc.assert_shape(w, [x.shape[0], self.w_dim])
333
+
334
+ # Track input magnitude.
335
+ if update_emas:
336
+ with torch.autograd.profiler.record_function('update_magnitude_ema'):
337
+ magnitude_cur = x.detach().to(torch.float32).square().mean()
338
+ self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
339
+ input_gain = self.magnitude_ema.rsqrt()
340
+
341
+ # Execute affine layer.
342
+ styles = self.affine(w)
343
+ if self.is_torgb:
344
+ weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
345
+ styles = styles * weight_gain
346
+
347
+ # Execute modulated conv2d.
348
+ dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
349
+ x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
350
+ padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)
351
+
352
+ # Execute bias, filtered leaky ReLU, and clamping.
353
+ gain = 1 if self.is_torgb else np.sqrt(2)
354
+ slope = 1 if self.is_torgb else 0.2
355
+ x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
356
+ up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)
357
+
358
+ # Ensure correct shape and dtype.
359
+ misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
360
+ assert x.dtype == dtype
361
+ return x
362
+
363
+ @staticmethod
364
+ def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
365
+ assert numtaps >= 1
366
+
367
+ # Identity filter.
368
+ if numtaps == 1:
369
+ return None
370
+
371
+ # Separable Kaiser low-pass filter.
372
+ if not radial:
373
+ f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
374
+ return torch.as_tensor(f, dtype=torch.float32)
375
+
376
+ # Radially symmetric jinc-based filter.
377
+ x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
378
+ r = np.hypot(*np.meshgrid(x, x))
379
+ f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
380
+ beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
381
+ w = np.kaiser(numtaps, beta)
382
+ f *= np.outer(w, w)
383
+ f /= np.sum(f)
384
+ return torch.as_tensor(f, dtype=torch.float32)
385
+
386
+ def extra_repr(self):
387
+ return '\n'.join([
388
+ f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
389
+ f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
390
+ f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
391
+ f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
392
+ f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
393
+ f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
394
+ f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])
395
+
396
+ #----------------------------------------------------------------------------
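+ # For reference, the separable branch of SynthesisLayer.design_lowpass_filter boils down to a
+ # single scipy.signal.firwin call (a Kaiser window is used whenever a transition width is given).
+ # A standalone sketch with made-up filter values; numtaps, cutoff, half-width and sampling rate
+ # below are illustrative, not taken from any real layer configuration.
+ #
+ # import scipy.signal
+ # import torch
+ #
+ # numtaps, cutoff, half_width, fs = 12, 8.0, 2.0, 32.0   # hypothetical filter specification
+ # taps = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=half_width * 2, fs=fs)
+ # lowpass = torch.as_tensor(taps, dtype=torch.float32)   # 1D taps, applied separably in x and y
+ # print(lowpass.shape)                                    # torch.Size([12])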
397
+
398
+ @persistence.persistent_class
399
+ class SynthesisNetwork(torch.nn.Module):
400
+ def __init__(self,
401
+ w_dim, # Intermediate latent (W) dimensionality.
402
+ img_resolution, # Output image resolution.
403
+ img_channels, # Number of color channels.
404
+ channel_base = 32768, # Overall multiplier for the number of channels.
405
+ channel_max = 512, # Maximum number of channels in any layer.
406
+ num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
407
+ num_critical = 2, # Number of critically sampled layers at the end.
408
+ first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
409
+ first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}).
410
+ last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
411
+ margin_size = 10, # Number of additional pixels outside the image.
412
+ output_scale = 0.25, # Scale factor for the output image.
413
+ num_fp16_res = 4, # Use FP16 for the N highest resolutions.
414
+ **layer_kwargs, # Arguments for SynthesisLayer.
415
+ ):
416
+ super().__init__()
417
+ self.w_dim = w_dim
418
+ self.num_ws = num_layers + 2
419
+ self.img_resolution = img_resolution
420
+ self.img_channels = img_channels
421
+ self.num_layers = num_layers
422
+ self.num_critical = num_critical
423
+ self.margin_size = margin_size
424
+ self.output_scale = output_scale
425
+ self.num_fp16_res = num_fp16_res
426
+
427
+ # Geometric progression of layer cutoffs and min. stopbands.
428
+ last_cutoff = self.img_resolution / 2 # f_{c,N}
429
+ last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
430
+ exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
431
+ cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
432
+ stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]
433
+
434
+ # Compute remaining layer parameters.
435
+ sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
436
+ half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
437
+ sizes = sampling_rates + self.margin_size * 2
438
+ sizes[-2:] = self.img_resolution
439
+ channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
440
+ channels[-1] = self.img_channels
441
+
442
+ # Construct layers.
443
+ self.input = SynthesisInput(
444
+ w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
445
+ sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
446
+ self.layer_names = []
447
+ for idx in range(self.num_layers + 1):
448
+ prev = max(idx - 1, 0)
449
+ is_torgb = (idx == self.num_layers)
450
+ is_critically_sampled = (idx >= self.num_layers - self.num_critical)
451
+ use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
452
+ layer = SynthesisLayer(
453
+ w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
454
+ in_channels=int(channels[prev]), out_channels=int(channels[idx]),
455
+ in_size=int(sizes[prev]), out_size=int(sizes[idx]),
456
+ in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
457
+ in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
458
+ in_half_width=half_widths[prev], out_half_width=half_widths[idx],
459
+ **layer_kwargs)
460
+ name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
461
+ setattr(self, name, layer)
462
+ self.layer_names.append(name)
463
+
464
+ def forward(self, ws, **layer_kwargs):
465
+ misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
466
+ ws = ws.to(torch.float32).unbind(dim=1)
467
+
468
+ # Execute layers.
469
+ x = self.input(ws[0])
470
+ for name, w in zip(self.layer_names, ws[1:]):
471
+ x = getattr(self, name)(x, w, **layer_kwargs)
472
+ if self.output_scale != 1:
473
+ x = x * self.output_scale
474
+
475
+ # Ensure correct shape and dtype.
476
+ misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
477
+ x = x.to(torch.float32)
478
+ return x
479
+
480
+ def extra_repr(self):
481
+ return '\n'.join([
482
+ f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
483
+ f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
484
+ f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
485
+ f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])
486
+
487
+ #----------------------------------------------------------------------------
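+ # The geometric layer schedule computed in SynthesisNetwork.__init__ is easy to inspect on its
+ # own. A small numeric sketch for a hypothetical 256x256 configuration with the default
+ # hyperparameters (the resolution is illustrative only):
+ #
+ # import numpy as np
+ #
+ # num_layers, num_critical, img_resolution = 14, 2, 256
+ # first_cutoff, first_stopband, last_stopband_rel = 2, 2 ** 2.1, 2 ** 0.3
+ # last_cutoff = img_resolution / 2                                            # f_{c,N}
+ # last_stopband = last_cutoff * last_stopband_rel                             # f_{t,N}
+ # exponents = np.minimum(np.arange(num_layers + 1) / (num_layers - num_critical), 1)
+ # cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents          # f_c[i]
+ # stopbands = first_stopband * (last_stopband / first_stopband) ** exponents  # f_t[i]
+ # sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, img_resolution))))  # s[i]
+ # print(np.round(cutoffs, 1))    # rises geometrically from 2 to 128, flat over the critical layers
+ # print(sampling_rates)          # powers of two, capped at the image resolution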
488
+
489
+ @persistence.persistent_class
490
+ class Generator(torch.nn.Module):
491
+ def __init__(self,
492
+ z_dim, # Input latent (Z) dimensionality.
493
+ c_dim, # Conditioning label (C) dimensionality.
494
+ w_dim, # Intermediate latent (W) dimensionality.
495
+ img_resolution, # Output resolution.
496
+ img_channels, # Number of output color channels.
497
+ mapping_kwargs = {}, # Arguments for MappingNetwork.
498
+ **synthesis_kwargs, # Arguments for SynthesisNetwork.
499
+ ):
500
+ super().__init__()
501
+ self.z_dim = z_dim
502
+ self.c_dim = c_dim
503
+ self.w_dim = w_dim
504
+ self.img_resolution = img_resolution
505
+ self.img_channels = img_channels
506
+ self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
507
+ self.num_ws = self.synthesis.num_ws
508
+ self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
509
+
510
+ def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
511
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
512
+ img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
513
+ return img
514
+
515
+ #----------------------------------------------------------------------------
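+ # Hypothetical usage sketch for the Generator defined above, assuming the snapshot added in this
+ # commit under data/textile_model_files is a standard StyleGAN3 pickle and that dnnlib/legacy are
+ # importable from the repository root (as in the Streamlit page below):
+ #
+ # import torch
+ # import dnnlib
+ # import legacy
+ #
+ # with dnnlib.util.open_url('./data/textile_model_files/network-snapshot-005000.pkl') as f:
+ #     G = legacy.load_network_pkl(f)['G_ema'].to('cpu')   # generator with EMA weights
+ # z = torch.randn([1, G.z_dim])                            # random input latent
+ # c = torch.zeros([1, G.c_dim])                            # empty label for an unconditional model
+ # with torch.no_grad():
+ #     img = G(z, c, truncation_psi=0.7)                    # [1, img_channels, H, W], roughly in [-1, 1]
+ # img = (img.clamp(-1, 1) + 1) * 127.5                     # rescale to [0, 255] for inspection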
data/stylegan3.webp ADDED

Git LFS Details

  • SHA256: f2ac8f58158a27eeef16c18ced280bee758f0fb61b89f42b11b2bc531ea2aa99
  • Pointer size: 130 Bytes
  • Size of remote file: 38.5 kB
data/textile_annotated_files/final_sim_seeds0000-100000.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dde6f1168825424eed5aa328c6442873df35637e74786f9c9af956c9c0a97ed
3
+ size 7886477
data/textile_annotated_files/hsv_info.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:36016190017d38cde71f35c267df9f6b0ab40d74ce17e022195e96d32d2f2f71
3
+ size 1112635
data/textile_annotated_files/seeds0000-100000.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dd321307f3d332193630f823a3e0db79d533156dfbc0d446eab0d5c212b1360
3
+ size 630151183
data/textile_annotated_files/seeds0000-100000_S.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8763836ea1142f6f2e3d36b7fe92bcf9a4549e9ef8e0a83a02b4772d64e95d54
3
+ size 3178623075
data/textile_annotated_files/top_three_colours.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dde6f1168825424eed5aa328c6442873df35637e74786f9c9af956c9c0a97ed
3
+ size 7886477
data/textile_model_files/network-snapshot-005000.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:717bdd11871c0383d6e28c54b2d61cd485ef236dd1de34d3194323c843b11b62
3
+ size 343479704
pages/{4_Vase_Qualities_Comparison.py → 4_Vase_Qualities_Comparison copy.py} RENAMED
File without changes
pages/5_Textiles_Disentanglement.py ADDED
@@ -0,0 +1,178 @@
1
+ import streamlit as st
2
+ import pickle
3
+ import pandas as pd
4
+ import numpy as np
5
+ import random
6
+ import torch
7
+
8
+ from matplotlib.backends.backend_agg import RendererAgg
9
+
10
+ from backend.disentangle_concepts import *
11
+ import torch_utils
12
+ import dnnlib
13
+ import legacy
14
+
15
+ _lock = RendererAgg.lock
16
+
17
+
18
+ st.set_page_config(layout='wide')
19
+ BACKGROUND_COLOR = '#bcd0e7'
20
+ SECONDARY_COLOR = '#bce7db'
21
+
22
+
23
+ st.title('Disentanglement studies on the Textile Dataset')
24
+ st.markdown(
25
+ """
26
+ This is a demo of the Disentanglement studies on the [Oxford Vases Dataset](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/).
27
+ """,
28
+ unsafe_allow_html=False,)
29
+
30
+ annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
31
+ with open(annotations_file, 'rb') as f:
32
+ annotations = pickle.load(f)
33
+
34
+
35
+ if 'image_id' not in st.session_state:
36
+ st.session_state.image_id = 0
37
+ if 'concept_ids' not in st.session_state:
38
+ st.session_state.concept_ids = ['AMPHORA']
39
+ if 'space_id' not in st.session_state:
40
+ st.session_state.space_id = 'W'
41
+
42
+ # def on_change_random_input():
43
+ # st.session_state.image_id = st.session_state.image_id
44
+
45
+ # ----------------------------- INPUT ----------------------------------
46
+ st.header('Input')
47
+ input_col_1, input_col_2, input_col_3 = st.columns(3)
48
+ # --------------------------- INPUT column 1 ---------------------------
49
+ with input_col_1:
50
+ with st.form('text_form'):
51
+
52
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
53
+ st.write('**Choose two options to disentangle**')
54
+ type_col = st.selectbox('Concept category:', tuple(['Provenance', 'Shape Name', 'Fabric', 'Technique']))
55
+
56
+ ann_df = pd.read_csv(f'./data/vase_annotated_files/sim_{type_col}_seeds0000-20000.csv')
57
+ labels = list(ann_df.columns)
58
+ labels.remove('ID')
59
+ labels.remove('Unnamed: 0')
60
+
61
+ concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=[labels[2], labels[3]])
62
+
63
+ st.write('**Choose a latent space to disentangle**')
64
+ space_id = st.selectbox('Space:', tuple(['W', 'Z']))
65
+
66
+ choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')
67
+
68
+ if choose_text_button:
69
+ concept_ids = list(concept_ids)
70
+ st.session_state.concept_ids = concept_ids
71
+ space_id = str(space_id)
72
+ st.session_state.space_id = space_id
73
+ # st.write(image_id, st.session_state.image_id)
74
+
75
+ # ---------------------------- SET UP OUTPUT ------------------------------
76
+ epsilon_container = st.empty()
77
+ st.header('Output')
78
+ st.subheader('Concept vector')
79
+
80
+ # perform attack container
81
+ # header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
82
+ # output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
83
+ header_col_1, header_col_2 = st.columns([5,1])
84
+ output_col_1, output_col_2 = st.columns([5,1])
85
+
86
+ st.subheader('Derivations along the concept vector')
87
+
88
+ # prediction error container
89
+ error_container = st.empty()
90
+ smoothgrad_header_container = st.empty()
91
+
92
+ # smoothgrad container
93
+ smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
94
+ smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])
95
+
96
+ # ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
97
+ with output_col_1:
98
+ separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id, samples=150)
99
+ # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
100
+ st.write('Concept vector', separation_vector)
101
+ header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')
102
+
103
+ # ----------------------------- INPUT column 2 & 3 ----------------------------
104
+ with input_col_2:
105
+ with st.form('image_form'):
106
+
107
+ # image_id = st.number_input('Image ID: ', format='%d', step=1)
108
+ st.write('**Choose or generate a random image to test the disentanglement**')
109
+ chosen_image_id_input = st.empty()
110
+ image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
111
+
112
+ choose_image_button = st.form_submit_button('Choose the defined image')
113
+ random_id = st.form_submit_button('Generate a random image')
114
+
115
+ if random_id:
116
+ image_id = random.randint(0, 20000)
117
+ st.session_state.image_id = image_id
118
+ chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)
119
+
120
+ if choose_image_button:
121
+ image_id = int(image_id)
122
+ st.session_state.image_id = int(image_id)
123
+ # st.write(image_id, st.session_state.image_id)
124
+
125
+ with input_col_3:
126
+ with st.form('Variate along the disentangled concept'):
127
+ st.write('**Set range of change**')
128
+ chosen_epsilon_input = st.empty()
129
+ epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
130
+ epsilon_button = st.form_submit_button('Choose the defined lambda')
131
+ st.write('**Select hierarchical levels to manipulate**')
132
+ layers = st.multiselect('Layers:', tuple(range(14)))
133
+ if len(layers) == 0:
134
+ layers = None
135
+ print(layers)
136
+ layers_button = st.form_submit_button('Choose the defined layers')
137
+
138
+
139
+ # ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------
140
+
141
+ #model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
142
+ with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
143
+ model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore
144
+
145
+ if st.session_state.space_id == 'Z':
146
+ original_image_vec = annotations['z_vectors'][st.session_state.image_id]
147
+ else:
148
+ original_image_vec = annotations['w_vectors'][st.session_state.image_id]
149
+
150
+ img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)
151
+
152
+ top_pred = ann_df.loc[st.session_state.image_id, labels].astype(float).idxmax()
153
+ # input_image = original_image_dict['image']
154
+ # input_label = original_image_dict['label']
155
+ # input_id = original_image_dict['id']
156
+
157
+ with smoothgrad_col_3:
158
+ st.image(img)
159
+ smooth_head_3.write(f'Base image, predicted as {top_pred}')
160
+
161
+
162
+ images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)
163
+
164
+ with smoothgrad_col_1:
165
+ st.image(images[0])
166
+ smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')
167
+
168
+ with smoothgrad_col_2:
169
+ st.image(images[1])
170
+ smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')
171
+
172
+ with smoothgrad_col_4:
173
+ st.image(images[3])
174
+ smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')
175
+
176
+ with smoothgrad_col_5:
177
+ st.image(images[4])
178
+ smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
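+
+ # The manipulation behind regenerate_images lives in backend/disentangle_concepts.py, which is not
+ # part of this diff; conceptually it amounts to shifting the chosen latent along the concept
+ # vector. A minimal sketch of that idea, under the assumption that the separation vector is (or is
+ # normalized to) unit length:
+ #
+ # import numpy as np
+ #
+ # w = np.random.randn(1, 512)                   # latent of the chosen image (W space)
+ # v = np.random.randn(1, 512)
+ # v = v / np.linalg.norm(v)                     # separation (concept) vector
+ # lambdas = np.linspace(-3, 3, 5)               # range of change controlled by the lambda widget
+ # variants = [w + lam * v for lam in lambdas]   # latents to re-synthesize, one per output column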