created global cleaned file
Files changed:
- .gitignore +2 -1
- DisentanglementBase.py +482 -0
- backend/color_annotations.py +97 -0
- backend/disentangle_concepts.py +221 -13
- backend/networks_stylegan3.py +515 -0
- data/stylegan3.webp +3 -0
- data/textile_annotated_files/final_sim_seeds0000-100000.csv +3 -0
- data/textile_annotated_files/hsv_info.csv +3 -0
- data/textile_annotated_files/seeds0000-100000.pkl +3 -0
- data/textile_annotated_files/seeds0000-100000_S.pkl +3 -0
- data/textile_annotated_files/top_three_colours.csv +3 -0
- data/textile_model_files/network-snapshot-005000.pkl +3 -0
- pages/{4_Vase_Qualities_Comparison.py → 4_Vase_Qualities_Comparison copy.py} +0 -0
- pages/5_Textiles_Disentanglement.py +178 -0
.gitignore
CHANGED
@@ -184,4 +184,5 @@ dmypy.json
cython_debug/

data/images/
-tmp/
+tmp/
+figures/
DisentanglementBase.py
ADDED
@@ -0,0 +1,482 @@
import numpy as np
import pandas as pd

from sklearn.svm import SVC
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from tqdm import tqdm
import random
from os.path import join
import os
import pickle

import torch

import matplotlib.pyplot as plt
import PIL
from PIL import Image, ImageColor

import sys
sys.path.append('backend')
from color_annotations import extract_color
from networks_stylegan3 import *
sys.path.append('.')

import dnnlib
import legacy

class DisentanglementBase:
    def __init__(self, repo_folder, model, annotations, df, space, colors_list, compute_s):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        print('Using device', self.device)
        self.repo_folder = repo_folder
        self.model = model.to(self.device)
        self.annotations = annotations
        self.df = df
        self.space = space

        self.layers = ['input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512',
                       'L4_52_512', 'L5_84_512', 'L6_84_512', 'L7_148_512', 'L8_148_512',
                       'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128',
                       'L13_256_128', 'L14_256_3']
        self.layers_shapes = [4, 512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 362, 256, 181, 128, 128]
        self.decoding_layers = 16
        self.colors_list = colors_list

        self.to_hsv()
        if compute_s:
            self.get_s_space()

    def to_hsv(self):
        """
        The to_hsv function takes the top three colours of each image, converts them to HSV values,
        and adds these values as new columns to the dataframe.

        :param self: Allow the function to access the dataframe
        :return: The dataframe with the new columns added
        :doc-author: Trelent
        """
        print('Adding HSV encoding')
        self.df['H1'] = self.df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
        self.df['H2'] = self.df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
        self.df['H3'] = self.df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])

        self.df['S1'] = self.df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
        self.df['S2'] = self.df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
        self.df['S3'] = self.df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])

        self.df['V1'] = self.df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
        self.df['V2'] = self.df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
        self.df['V3'] = self.df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])

    def get_s_space(self):
        """
        The get_s_space function takes the w_vectors from the annotations dictionary and uses them to generate s_vectors.
        The S space is a space of vectors obtained by passing each w vector through the affine layer of every block of the model.
        This allows us to see how much information about a particular class is contained in different layers.

        :param self: Bind the method to a class
        :return: A list of lists of s vectors
        :doc-author: Trelent
        """
        print('Getting S space from W')
        ss = []
        for w in tqdm(self.annotations['w_vectors']):
            w_torch = torch.from_numpy(w).to(self.device)
            W = w_torch.expand((16, -1)).unsqueeze(0)
            s = []
            for i, layer in enumerate(self.layers):
                s.append(getattr(self.model.synthesis, layer).affine(W[0, i].unsqueeze(0)).numpy())

            ss.append(s)
        self.annotations['s_vectors'] = ss
        annotations_file = join(self.repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl')
        print('Storing s for future use here:', annotations_file)
        with open(annotations_file, 'wb') as f:
            pickle.dump(self.annotations, f)

    def get_encoded_latent(self):
        # ... (existing code for getX)
        if self.space.lower() == 'w':
            X = np.array(self.annotations['w_vectors']).reshape((len(self.annotations['w_vectors']), 512))
        elif self.space.lower() == 'z':
            X = np.array(self.annotations['z_vectors']).reshape((len(self.annotations['z_vectors']), 512))
        elif self.space.lower() == 's':
            concat_v = []
            for i in range(len(self.annotations['w_vectors'])):
                concat_v.append(np.concatenate(self.annotations['s_vectors'][i], axis=1))
            X = np.array(concat_v)
            X = X[:, 0, :]
        else:
            raise Exception("Sorry, option not available, select among Z, W, S")

        print('Shape embedding:', X.shape)
        return X

    def get_train_val(self, var='H1', cat=True):
        X = self.get_encoded_latent()
        y = np.array(self.df[var].values)
        if cat:
            # Bin the 0-255 hue values into 12 equal categories, one per colour name.
            y_cat = pd.cut(y,
                           bins=[x*256/12 if x<12 else 256 for x in range(13)],
                           labels=self.colors_list
                           ).fillna('Warm Pink Red')
            x_train, x_val, y_train, y_val = train_test_split(X, y_cat, test_size=0.2)
        else:
            x_train, x_val, y_train, y_val = train_test_split(X, y, test_size=0.2)
        return x_train, x_val, y_train, y_val

    def InterFaceGAN_separation_vector(self, method='LR', C=0.1):
        """
        Method from InterFaceGAN.
        The get_separation_space function takes in a type_bin, annotations, and df.
        It then samples 100 of the most representative abstracts for that type_bin and 100 of the least representative abstracts for that type_bin.
        It then trains an SVM or logistic regression model on these 200 samples to find a separation space between them.
        The function returns this separation space as well as how many nodes are important in this separation space.

        :param type_bin: Select the type of abstracts to be used for training
        :param annotations: Access the z_vectors
        :param df: Get the abstracts that are used for training
        :param samples: Determine how many samples to take from the top and bottom of the distribution
        :param method: Specify the classifier to use
        :param C: Control the regularization strength
        :return: The weights of the linear classifier
        :doc-author: Trelent
        """
        x_train, x_val, y_train, y_val = self.get_train_val()

        if method == 'SVM':
            svc = SVC(gamma='auto', kernel='linear', random_state=0, C=C)
            svc.fit(x_train, y_train)
            print('Val performance SVM', np.round(svc.score(x_val, y_val), 2))
            return svc.coef_ / np.linalg.norm(svc.coef_)
        elif method == 'LR':
            clf = LogisticRegression(random_state=0, C=C)
            clf.fit(x_train, y_train)
            print('Val performance logistic regression', np.round(clf.score(x_val, y_val), 2))
            return clf.coef_ / np.linalg.norm(clf.coef_)

    def get_original_position_latent(self, positive_idxs, negative_idxs):
        # ... (existing code for get_original_pos)
        separation_vectors = []
        for i in range(len(self.colors_list)):
            if self.space.lower() == 's':
                current_idx = 0
                vectors = []
                for j, (leng, layer) in enumerate(zip(self.layers_shapes, self.layers)):
                    arr = np.zeros(leng)
                    for positive_idx in positive_idxs[i]:
                        if positive_idx >= current_idx and positive_idx < current_idx + leng:
                            arr[positive_idx - current_idx] = 1
                    for negative_idx in negative_idxs[i]:
                        if negative_idx >= current_idx and negative_idx < current_idx + leng:
                            arr[negative_idx - current_idx] = 1
                    arr = arr / (np.linalg.norm(arr) + 0.000001)
                    vectors.append(arr)
                    current_idx += leng
            elif self.space.lower() == 'z' or self.space.lower() == 'w':
                vectors = np.zeros(512)
                vectors[positive_idxs[i]] = 1
                vectors[negative_idxs[i]] = -1
                vectors = vectors / (np.linalg.norm(vectors) + 0.000001)
            else:
                raise Exception("""This space is not allowed in this function,
                                select among Z, W, S""")
            separation_vectors.append(vectors)

        return separation_vectors

    def StyleSpace_separation_vector(self, sign=True, num_factors=20, cutout=0.25):
        """ Formula from StyleSpace Analysis """
        x_train, x_val, y_train, y_val = self.get_train_val()

        positive_idxs = []
        negative_idxs = []
        for color in self.colors_list:
            # Relevance of each channel: mean/std of the class samples after standardising with the global statistics.
            x_col = x_train[np.where(y_train == color)]
            mp = np.mean(x_train, axis=0)
            sp = np.std(x_train, axis=0)
            de = (x_col - mp) / sp
            meu = np.mean(de, axis=0)
            seu = np.std(de, axis=0)
            if sign:
                thetau = meu / seu
                positive_idx = np.argsort(thetau)[-num_factors//2:]
                negative_idx = np.argsort(thetau)[:num_factors//2]

            else:
                thetau = np.abs(meu) / seu
                positive_idx = np.argsort(thetau)[-num_factors:]
                negative_idx = []

            if cutout:
                beyond_cutout = np.where(np.abs(thetau) > cutout)
                positive_idx = np.intersect1d(positive_idx, beyond_cutout)
                negative_idx = np.intersect1d(negative_idx, beyond_cutout)

                if len(positive_idx) == 0 and len(negative_idx) == 0:
                    print('No values found above the current cutout', cutout, 'for color', color, '.\n Disentangled vector will be all zeros.')

            positive_idxs.append(positive_idx)
            negative_idxs.append(negative_idx)

        separation_vectors = self.get_original_position_latent(positive_idxs, negative_idxs)
        return separation_vectors

    def GANSpace_separation_vectors(self, num_components):
        x_train, x_val, y_train, y_val = self.get_train_val()
        if self.space.lower() == 'w':
            pca = PCA(n_components=num_components)

            dims_pca = pca.fit_transform(x_train.T)
            dims_pca /= np.linalg.norm(dims_pca, axis=0)

            return dims_pca

        else:
            raise Exception("""This space is not allowed in this function,
                            only W""")

    def generate_images(self, seed, separation_vector=None, lambd=0):
        """
        The generate_original_image function takes in a latent vector and the model,
        and returns an image generated from that latent vector.

        :param z: Generate the image
        :param model: Generate the image
        :return: A pil image
        :doc-author: Trelent
        """
        G = self.model.to(self.device) # type: ignore
        # Labels.
        label = torch.zeros([1, G.c_dim], device=self.device)
        if self.space.lower() == 'z':
            vec = self.annotations['z_vectors'][seed]
            Z = torch.from_numpy(vec.copy()).to(self.device)
            if separation_vector is not None:
                change = torch.from_numpy(separation_vector.copy()).unsqueeze(0).to(self.device)
                Z = torch.add(Z, change, alpha=lambd)
            img = G(Z, label, truncation_psi=1, noise_mode='const')
        elif self.space.lower() == 'w':
            vec = self.annotations['w_vectors'][seed]
            W = torch.from_numpy(np.repeat(vec, self.decoding_layers, axis=0)
                                 .reshape(1, self.decoding_layers, vec.shape[1]).copy()).to(self.device)
            if separation_vector is not None:
                change = torch.from_numpy(separation_vector.copy()).unsqueeze(0).to(self.device)
                W = torch.add(W, change, alpha=lambd)
            img = G.synthesis(W, noise_mode='const')
        else:
            raise Exception("""This space is not allowed in this function,
                            select either W or Z or use generate_flexible_images""")

        img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

    def forward_from_style(self, x, styles, layer):
        dtype = torch.float16 if (getattr(self.model.synthesis, layer).use_fp16 and self.device=='cuda') else torch.float32

        if getattr(self.model.synthesis, layer).is_torgb:
            weight_gain = 1 / np.sqrt(getattr(self.model.synthesis, layer).in_channels * (getattr(self.model.synthesis, layer).conv_kernel ** 2))
            styles = styles * weight_gain

        input_gain = getattr(self.model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)

        # Execute modulated conv2d.
        x = modulated_conv2d(x=x.to(dtype), w=getattr(self.model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),
                             padding=getattr(self.model.synthesis, layer).conv_kernel-1,
                             demodulate=(not getattr(self.model.synthesis, layer).is_torgb),
                             input_gain=input_gain.to(dtype))

        # Execute bias, filtered leaky ReLU, and clamping.
        gain = 1 if getattr(self.model.synthesis, layer).is_torgb else np.sqrt(2)
        slope = 1 if getattr(self.model.synthesis, layer).is_torgb else 0.2

        x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(self.model.synthesis, layer).up_filter, fd=getattr(self.model.synthesis, layer).down_filter,
                                          b=getattr(self.model.synthesis, layer).bias.to(x.dtype),
                                          up=getattr(self.model.synthesis, layer).up_factor, down=getattr(self.model.synthesis, layer).down_factor,
                                          padding=getattr(self.model.synthesis, layer).padding,
                                          gain=gain, slope=slope, clamp=getattr(self.model.synthesis, layer).conv_clamp)
        return x

    def generate_flexible_images(self, seed, separation_vector=None, lambd=0):
        if self.space.lower() != 's':
            raise Exception("""This space is not allowed in this function,
                            select S or use generate_images""")

        vec = self.annotations['w_vectors'][seed]
        w_torch = torch.from_numpy(vec).to(self.device)
        W = w_torch.expand((self.decoding_layers, -1)).unsqueeze(0)
        x = self.model.synthesis.input(W[0,0].unsqueeze(0))
        for i, layer in enumerate(self.layers[1:]):
            style = getattr(self.model.synthesis, layer).affine(W[0, i].unsqueeze(0))
            if separation_vector is not None:
                change = torch.from_numpy(separation_vector[i+1].copy()).unsqueeze(0).to(self.device)
                style = torch.add(style, change, alpha=lambd)
            x = self.forward_from_style(x, style, layer)

        if self.model.synthesis.output_scale != 1:
            x = x * self.model.synthesis.output_scale

        img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
        img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

        return img

    def generate_changes(self, seed, separation_vector, min_epsilon=-3, max_epsilon=3, count=5, savefig=True, feature=None, method=None):
        """
        The regenerate_images function takes a model, z, and decision_boundary as input. It then
        constructs an inverse rotation/translation matrix and passes it to the generator. The generator
        expects this matrix as an inverse to avoid potentially failing numerical operations in the network.
        The function then generates images using G(z_0, label) where z_0 is a linear combination of z and the decision boundary.

        :param model: Pass in the model to be used for image generation
        :param z: Generate the starting point of the line
        :param decision_boundary: Generate images along the direction of the decision boundary
        :param min_epsilon: Set the minimum value of lambda
        :param max_epsilon: Set the maximum distance from the original image to generate
        :param count: Determine the number of images that are generated
        :return: A list of images and a list of lambdas
        :doc-author: Trelent
        """

        os.makedirs(join(self.repo_folder, 'figures'), exist_ok=True)
        lambdas = np.linspace(min_epsilon, max_epsilon, count)
        images = []
        # Generate images.
        for _, lambd in enumerate(tqdm(lambdas)):
            if self.space.lower() == 's':
                images.append(self.generate_flexible_images(seed, separation_vector=separation_vector, lambd=lambd))
            elif self.space.lower() in ['z', 'w']:
                images.append(self.generate_images(seed, separation_vector=separation_vector, lambd=lambd))

        if savefig:
            print('Generating image for color', feature)
            fig, axs = plt.subplots(1, len(images), figsize=(90,20))
            title = 'Disentanglement method: ' + method + ', on feature: ' + feature + ' on space: ' + self.space + ', image seed: ' + str(seed)
            name = '_'.join([method, feature, self.space, str(seed), str(lambdas[-1])])
            fig.suptitle(title, fontsize=20)

            for i, (image, lambd) in enumerate(zip(images, lambdas)):
                axs[i].imshow(image)
                axs[i].set_title(np.round(lambd, 2))
            plt.tight_layout()
            plt.savefig(join(self.repo_folder, 'figures', name+'.jpg'))
        return images, lambdas

    def get_verification_score(self, separation_vector, feature_id, samples=10, lambd=1, savefig=False, feature=None, method=None):
        # Count how often moving along the separation vector pushes the dominant hue into the target bin,
        # ignoring seeds whose original image already falls inside the bin.
        items = random.sample(range(100000), samples)
        hue_low = feature_id * 256 / 12
        hue_high = (feature_id + 1) * 256 / 12

        matches = 0

        for seed in tqdm(items):
            images, lambdas = self.generate_changes(seed, separation_vector, min_epsilon=-lambd, max_epsilon=lambd, count=3, savefig=savefig, feature=feature, method=method)
            colors_negative = extract_color(images[0], 5, 1, None)
            h0, s0, v0 = ImageColor.getcolor(colors_negative[0], 'HSV')

            colors_orig = extract_color(images[1], 5, 1, None)
            h1, s1, v1 = ImageColor.getcolor(colors_orig[0], 'HSV')

            colors_positive = extract_color(images[2], 5, 1, None)
            h2, s2, v2 = ImageColor.getcolor(colors_positive[0], 'HSV')

            if h1 > hue_low and h1 < hue_high:
                samples -= 1
            else:
                if (h0 > hue_low and h0 < hue_high) or (h2 > hue_low and h2 < hue_high):
                    matches += 1

        return np.round(matches / samples, 2)


def main():
    repo_folder = '.'
    annotations_file = join(repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl')
    with open(annotations_file, 'rb') as f:
        annotations = pickle.load(f)

    df_file = join(repo_folder, 'data/textile_annotated_files/top_three_colours.csv')
    df = pd.read_csv(df_file).fillna('#000000')

    model_file = join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')
    with dnnlib.util.open_url(model_file) as f:
        model = legacy.load_network_pkl(f)['G_ema'] # type: ignore

    colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',
                   'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',
                   'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']

    scores = []
    kwargs = {'CL method': ['LR', 'SVM'], 'C': [0.1, 1], 'sign': [True, False], 'num_factors': [1, 10, 20, 50],
              'cutout': [None, 0.2], 'max_lambda': [6, 10, 1], 'samples': 50, 'lambda_verif': [1, 3, 6]}

    for space in ['w', 'z', 's']:
        print('Launching experiment with space:', space)
        disentanglemnet_exp = DisentanglementBase(repo_folder, model, annotations, df, space=space, colors_list=colors_list, compute_s=False)

        for method in ['StyleSpace', 'InterFaceGAN', 'GANSpace']:
            if space != 's' and method == 'InterFaceGAN':
                print('Now obtaining separation vector using InterFaceGAN')
                for met in kwargs['CL method']:
                    for c in kwargs['C']:
                        separation_vectors = disentanglemnet_exp.InterFaceGAN_separation_vector(method=met, C=c)
                        for i, color in enumerate(colors_list):
                            print('Generating images with variations')
                            seed = random.randint(0, 100000)
                            for eps in kwargs['max_lambda']:
                                disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-eps, max_epsilon=eps, savefig=True, feature=color, method=method)

                            print('Finally obtaining verification score')
                            for verif in kwargs['lambda_verif']:
                                score = disentanglemnet_exp.get_verification_score(separation_vectors[i], i, samples=kwargs['samples'], lambd=verif, savefig=True, feature=color, method=method)
                                print('Score for method', method, 'on space', space, 'for color', color, ':', score)

                                scores.append([space, method, color, score, 'classification method:' + met + ', regularization: ' + str(c) + ', verification lambda:' + str(verif)])

            elif method == 'StyleSpace':
                print('Now obtaining separation vector using StyleSpace')
                for sign in kwargs['sign']:
                    for num_factors in kwargs['num_factors']:
                        for cutout in kwargs['cutout']:
                            separation_vectors = disentanglemnet_exp.StyleSpace_separation_vector(sign=sign, num_factors=num_factors, cutout=cutout)
                            for i, color in enumerate(colors_list):
                                print('Generating images with variations')
                                seed = random.randint(0, 100000)
                                for eps in kwargs['max_lambda']:
                                    disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-eps, max_epsilon=eps, savefig=True, feature=color, method=method)

                                print('Finally obtaining verification score')
                                for verif in kwargs['lambda_verif']:
                                    score = disentanglemnet_exp.get_verification_score(separation_vectors[i], i, samples=kwargs['samples'], lambd=verif, savefig=True, feature=color, method=method)
                                    print('Score for method', method, 'on space', space, 'for color', color, ':', score)

                                    scores.append([space, method, color, score, 'using sign:' + str(sign) + ', number of factors: ' + str(num_factors) + ', using cutout: ' + str(cutout) + ', verification lambda:' + str(verif)])

            if space == 'w' and method == 'GANSpace':
                print('Now obtaining separation vector using GANSpace')
                separation_vectors = disentanglemnet_exp.GANSpace_separation_vectors(100)
                for i in range(100):
                    print('Generating images with variations')
                    seed = random.randint(0, 100000)
                    for eps in kwargs['max_lambda']:
                        disentanglemnet_exp.generate_changes(seed, separation_vectors[i], min_epsilon=-eps, max_epsilon=eps, savefig=True, feature=color, method=method)

                    score = None
                    scores.append([space, method, color, score, '100'])
            else:
                print('Skipping', method, 'on space', space)
                continue

    score_df = pd.DataFrame(scores, columns=['space', 'method', 'color', 'score', 'kwargs'])
    print(score_df)
    score_df.to_csv(join(repo_folder, 'data/scores.csv'))

if __name__ == "__main__":
    main()
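A minimal usage sketch (not part of the commit): it assumes the annotation pickle, colour CSV and model snapshot referenced in main() are present, that the StyleGAN3 dnnlib/legacy/torch_utils modules are importable, and that the class can be imported from the repo root as shown. The seed and colour choices are illustrative only.

# Sketch: build the experiment in W space and push one seed along a colour direction.
import pickle
import pandas as pd
from os.path import join

import dnnlib
import legacy
from DisentanglementBase import DisentanglementBase   # assumed import path (file at repo root)

repo_folder = '.'
with open(join(repo_folder, 'data/textile_annotated_files/seeds0000-100000_S.pkl'), 'rb') as f:
    annotations = pickle.load(f)
df = pd.read_csv(join(repo_folder, 'data/textile_annotated_files/top_three_colours.csv')).fillna('#000000')
with dnnlib.util.open_url(join(repo_folder, 'data/textile_model_files/network-snapshot-005000.pkl')) as f:
    model = legacy.load_network_pkl(f)['G_ema']

colors_list = ['Warm Pink Red', 'Red Orange', 'Orange Yellow', 'Gold Yellow', 'Chartreuse Green',
               'Kelly Green', 'Green Blue Seafoam', 'Blue Green Cyan',
               'Warm Blue', 'Indigo Blue Purple', 'Purple Magenta', 'Magenta Pink']

exp = DisentanglementBase(repo_folder, model, annotations, df, space='w',
                          colors_list=colors_list, compute_s=False)
# One separation vector per colour; pick the 'Warm Blue' one and sweep lambda.
vectors = exp.StyleSpace_separation_vector(sign=True, num_factors=20, cutout=0.25)
images, lambdas = exp.generate_changes(seed=42,   # illustrative seed
                                       separation_vector=vectors[colors_list.index('Warm Blue')],
                                       min_epsilon=-6, max_epsilon=6, count=5,
                                       savefig=False, feature='Warm Blue', method='StyleSpace')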
backend/color_annotations.py
ADDED
@@ -0,0 +1,97 @@
#!/usr/bin/env python

"""Extract color features from the generated textile images."""

import os
#os.environ["TOKENIZERS_PARALLELISM"] = "false"
from tqdm import tqdm
#from transformers import pipeline
import numpy as np
import pandas as pd
import time

import click
from PIL import Image
import math
import pickle
from glob import glob

import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.image as mpimg
import cv2
import extcolors

from colormap import rgb2hex
from PIL import Image
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

def color_to_df(input):
    colors_pre_list = str(input).replace('([(','').split(', (')[0:-1]
    df_rgb = [i.split('), ')[0] + ')' for i in colors_pre_list]
    df_percent = [i.split('), ')[1].replace(')','') for i in colors_pre_list]

    #convert RGB to HEX code
    df_color_up = [rgb2hex(int(i.split(", ")[0].replace("(","")),
                           int(i.split(", ")[1]),
                           int(i.split(", ")[2].replace(")",""))) for i in df_rgb]

    df = pd.DataFrame(zip(df_color_up, df_percent), columns = ['c_code','occurence'])
    return df

def extract_color(input_image, tolerance, zoom, outpath, save=None):
    colors_x = extcolors.extract_from_image(input_image, tolerance = tolerance, limit = 13)
    df_color = color_to_df(colors_x)

    #annotate text
    list_color = list(df_color['c_code'])
    list_precent = [int(i) for i in list(df_color['occurence'])]
    text_c = [c + ' ' + str(round(p*100/sum(list_precent),1)) +'%' for c, p in zip(list_color, list_precent)]
    colors = list(df_color['c_code'])
    # Keep the three most frequent non-black colours.
    if '#000000' in colors:
        colors.remove('#000000')
    return colors[:3]


@click.command()
@click.option('--genimages_dir', help='Where the output images are saved', type=str, required=True, metavar='DIR')

def annotate_textile_images(
    genimages_dir: str,

):
    """Produce annotations for the generated images.
    \b
    #
    python annotate_textiles.py --genimages_dir /home/ludosc/data/stylegan-10000-textile-upscale
    """
    colours = []
    pickle_files = glob(genimages_dir + '/imgs0000*.pkl')
    for pickle_file in pickle_files:
        print('Using pickle file: ', pickle_file)
        with open(pickle_file, 'rb') as f:
            info = pickle.load(f)

        listlen = len(info['fname'])
        os.makedirs('/data/ludosc/colour_palettes/', exist_ok=True)
        for i, im in enumerate(tqdm(info['fname'])):
            try:
                top_cols = extract_color(im, 12, 5, '/data/ludosc/colour_palettes/' + im.split('/')[-1])
                colours.append([im] + top_cols)
            except Exception as e:
                print(e)
            if i % 1000 == 0:
                df = pd.DataFrame(colours, columns=['fname', 'top1col', 'top2col', 'top3col'])
                print(df.head())
                df.to_csv(genimages_dir + f'/top_three_colours.csv', index=False)

        df = pd.DataFrame(colours, columns=['fname', 'top1col', 'top2col', 'top3col'])
        print(df.head())
        df.to_csv(genimages_dir + f'/final_sim_{os.path.basename(pickle_file.split(".")[0])}.csv', index=False)

#----------------------------------------------------------------------------

if __name__ == "__main__":
    annotate_textile_images() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
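A short illustrative sketch (not part of the commit) of how extract_color is called elsewhere in the repo: it assumes the backend folder is on the import path and uses a hypothetical image file name; the printed hex values are whatever extcolors finds in that image.

# Illustrative only: top three non-black colours of one generated textile.
import sys
sys.path.append('backend')
from PIL import Image
from color_annotations import extract_color

img = Image.open('data/images/seed0042.png').convert('RGB')   # hypothetical path
top_three = extract_color(img, 5, 1, None)    # tolerance=5, zoom=1, no output path
print(top_three)                              # e.g. a list of three '#rrggbb' strings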
backend/disentangle_concepts.py
CHANGED
@@ -5,6 +5,13 @@ from sklearn.model_selection import train_test_split
 import torch
 from umap import UMAP
 import PIL
+from tqdm import tqdm
+import random
+from PIL import Image, ImageColor
+
+from .color_annotations import extract_color
+
+

 def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=0.1, latent_space='Z'):
     """
@@ -65,7 +72,7 @@ def get_separation_space(type_bin, annotations, df, samples=200, method='LR', C=
     return clf.coef_ / np.linalg.norm(clf.coef_), imp_features, imp_nodes, np.round(clf.score(x_val, y_val),2)


-def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z', layers=None):
+def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z', layers=None, number=3):
     """
     The regenerate_images function takes a model, z, and decision_boundary as input. It then
     constructs an inverse rotation/translation matrix and passes it to the generator. The generator
@@ -92,19 +99,18 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
     z = torch.from_numpy(z.copy()).to(device)
     decision_boundary = torch.from_numpy(decision_boundary.copy()).to(device)

+    repetitions = 16 if number == 3 else 14
     lambdas = np.linspace(min_epsilon, max_epsilon, count)
     images = []
     # Generate images.
-    for _, lambda_ in enumerate(lambdas):
+    for _, lambda_ in enumerate(tqdm(lambdas)):
         z_0 = z + lambda_ * decision_boundary
         if latent_space == 'Z':
             W_0 = G.mapping(z_0, label, truncation_psi=1).to(torch.float32)
             W = G.mapping(z, label, truncation_psi=1).to(torch.float32)
-            print(W.dtype)
         else:
-            W_0 = z_0.expand((
-            W = z.expand((
-            print(W.dtype)
+            W_0 = z_0.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)
+            W = z.expand((repetitions, -1)).unsqueeze(0).to(torch.float32)

         if layers:
             W_f = torch.empty_like(W).copy_(W).to(torch.float32)
@@ -117,14 +123,14 @@ def regenerate_images(model, z, decision_boundary, min_epsilon=-3, max_epsilon=3
         images.append(PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB'))

     return images, lambdas
-
-
+
+
 def generate_joint_effect(model, z, decision_boundaries, min_epsilon=-3, max_epsilon=3, count=5, latent_space='Z'):
     decision_boundary_joint = np.sum(decision_boundaries, axis=0)
     print(decision_boundary_joint.shape)
     return regenerate_images(model, z, decision_boundary_joint, min_epsilon=min_epsilon, max_epsilon=max_epsilon, count=count, latent_space=latent_space)

-def generate_original_image(z, model, latent_space='Z'):
+def generate_original_image(z, model, latent_space='Z', number=3):
     """
     The generate_original_image function takes in a latent vector and the model,
     and returns an image generated from that latent vector.
@@ -135,6 +141,8 @@ def generate_original_image(z, model, latent_space='Z'):
     :return: A pil image
     :doc-author: Trelent
     """
+    repetitions = 16 if number == 3 else 14
+
     device = torch.device('cpu')
     G = model.to(device) # type: ignore
     # Labels.
@@ -143,10 +151,10 @@ def generate_original_image(z, model, latent_space='Z'):
         z = torch.from_numpy(z.copy()).to(device)
         img = G(z, label, truncation_psi=1, noise_mode='const')
     else:
-        W = torch.from_numpy(np.repeat(z,
+        W = torch.from_numpy(np.repeat(z, repetitions, axis=0).reshape(1, repetitions, z.shape[1]).copy()).to(device)
         print(W.shape)
         img = G.synthesis(W, noise_mode='const')
-
+
     img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
     return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

@@ -188,8 +196,36 @@ def get_concepts_vectors(concepts, annotations, df, samples=100, method='LR', C=
     return vectors, nodes_in_common, performances


+def get_verification_score(color_id, decision_boundary, model, annotations, samples=100, latent_space='W'):
+    listlen = len(annotations['fname'])
+    items = random.sample(range(listlen), samples)
+    hue_low = color_id * 256 / 12
+    hue_high = (color_id + 1) * 256 / 12
+    hue_mean = (hue_low + hue_high) / 2
+    print(int(hue_low), int(hue_high), int(hue_mean))
+    distances = []
+    distances_orig = []
+    for iterator in tqdm(items):
+        if latent_space == 'Z':
+            z = annotations['z_vectors'][iterator]
+        else:
+            z = annotations['w_vectors'][iterator]
+
+        images, lambdas = regenerate_images(model, z, decision_boundary, min_epsilon=0, max_epsilon=1, count=2, latent_space=latent_space)
+        colors_orig = extract_color(images[0], 5, 1, None)
+        h_old, s_old, v_old = ImageColor.getcolor(colors_orig[0], 'HSV')
+        colors_new = extract_color(images[1], 5, 1, None)
+        h_new, s_new, v_new = ImageColor.getcolor(colors_new[0], 'HSV')
+        print(h_old, h_new)
+        distance = np.abs(hue_mean - h_new)
+        distances.append(distance)
+        distance_orig = np.abs(hue_mean - h_old)
+        distances_orig.append(distance_orig)
+
+    return np.round(np.mean(np.array(distances)), 4), np.round(np.mean(np.array(distances_orig)), 4)
+

-def 
+def get_verification_score_clip(concept, decision_boundary, model, annotations, samples=100, latent_space='Z'):
     import open_clip
     import os
     import random
@@ -243,5 +279,177 @@ def get_verification_score(concept, decision_boundary, model, annotations, sampl
     return np.round(np.mean(np.array(changes)), 4)


-
+
+def tohsv(df):
+    df['H1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
+    df['H2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
+    df['H3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[0])
+
+    df['S1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
+    df['S2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
+    df['S3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[1])
+
+    df['V1'] = df['top1col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
+    df['V2'] = df['top2col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
+    df['V3'] = df['top3col'].map(lambda x: ImageColor.getcolor(x, 'HSV')[2])
+    return df
+
+
+def rest_from_style(x, styles, layer):
+    dtype = torch.float16 if (getattr(model.synthesis, layer).use_fp16 and device=='cuda') else torch.float32
+    if getattr(model.synthesis, layer).is_torgb:
+        print(layer, getattr(model.synthesis, layer).is_torgb)
+        weight_gain = 1 / np.sqrt(getattr(model.synthesis, layer).in_channels * (getattr(model.synthesis, layer).conv_kernel ** 2))
+        styles = styles * weight_gain
+    input_gain = getattr(model.synthesis, layer).magnitude_ema.rsqrt().to(dtype)
+    # Execute modulated conv2d.
+    x = modulated_conv2d(x=x.to(dtype), w=getattr(model.synthesis, layer).weight.to(dtype), s=styles.to(dtype),
+                         padding=getattr(model.synthesis, layer).conv_kernel-1, demodulate=(not getattr(model.synthesis, layer).is_torgb), input_gain=input_gain.to(dtype))
+    # Execute bias, filtered leaky ReLU, and clamping.
+    gain = 1 if getattr(model.synthesis, layer).is_torgb else np.sqrt(2)
+    slope = 1 if getattr(model.synthesis, layer).is_torgb else 0.2
+    x = filtered_lrelu.filtered_lrelu(x=x, fu=getattr(model.synthesis, layer).up_filter, fd=getattr(model.synthesis, layer).down_filter,
+                                      b=getattr(model.synthesis, layer).bias.to(x.dtype),
+                                      up=getattr(model.synthesis, layer).up_factor, down=getattr(model.synthesis, layer).down_factor,
+                                      padding=getattr(model.synthesis, layer).padding,
+                                      gain=gain, slope=slope, clamp=getattr(model.synthesis, layer).conv_clamp)
+    return x
+
+
+def getS(w):
+    w_torch = torch.from_numpy(w).to('cpu')
+    W = w_torch.expand((16, -1)).unsqueeze(0)
+    s = []
+    s.append(model.synthesis.input.affine(W[0, 0].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L0_36_512.affine(W[0, 1].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L1_36_512.affine(W[0, 2].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L2_36_512.affine(W[0, 3].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L3_52_512.affine(W[0, 4].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L4_52_512.affine(W[0, 5].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L5_84_512.affine(W[0, 6].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L6_84_512.affine(W[0, 7].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L7_148_512.affine(W[0, 8].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L8_148_512.affine(W[0, 9].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L9_148_362.affine(W[0, 10].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L10_276_256.affine(W[0, 11].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L11_276_181.affine(W[0, 12].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L12_276_128.affine(W[0, 13].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L13_256_128.affine(W[0, 14].unsqueeze(0)).numpy())
+    s.append(model.synthesis.L14_256_3.affine(W[0, 15].unsqueeze(0)).numpy())
+    return s
+
+def detect_attribute_specific_channels(positives, all, sign=False):
+    """ Formula from StyleSpace Analysis """
+    mp = np.mean(all, axis=0)
+    sp = np.std(all, axis=0)
+    de = (positives - mp) / sp
+    meu = np.mean(de, axis=0)
+    seu = np.std(de, axis=0)
+    if sign:
+        thetau = meu / seu
+    else:
+        thetau = np.abs(meu) / seu
+    return thetau
+
+def all_variance_based_disentanglements(labels, x, y, k=10, sign=False, cutout=0.28):
+    seps = []
+    sorted_vals = []
+    for lbl in labels:
+        positives = x[np.where(y == lbl)]
+        variations = detect_attribute_specific_channels(positives, x, sign=sign)
+        if sign:
+            argsorted_vars_pos = np.argsort(variations)[-k//2:]
+            # print(argsorted_vars_pos)
+            argsorted_vars_neg = np.argsort(variations)[:k//2]
+            if cutout:
+                beyond_cutout = np.where(np.abs(variations) > cutout)
+                # print(beyond_cutout)
+                argsorted_vars_pos_int = np.intersect1d(argsorted_vars_pos, beyond_cutout)
+                argsorted_vars_neg_int = np.intersect1d(argsorted_vars_neg, beyond_cutout)
+                # print(argsorted_vars_pos)
+                if len(argsorted_vars_neg_int) > 0:
+                    argsorted_vars_neg = np.array(argsorted_vars_neg_int)
+                if len(argsorted_vars_pos_int) > 0:
+                    argsorted_vars_pos = np.array(argsorted_vars_pos_int)
+
+        else:
+            argsorted_vars = np.argsort(variations)[-k:]
+
+        sorted_vals.append(np.sort(variations))
+        separation_vector_onehot /= np.linalg.norm(separation_vector_onehot)
+        seps.append(separation_vector_onehot)
+    return seps, sorted_vals
+
+def generate_flexible_images(w, change_vectors, lambdas=1, device='cpu'):
+    w_torch = torch.from_numpy(w).to('cpu')
+    if len(change_vectors) != 17:
+        w_torch = w_torch + lambdas * change_vectors[0]
+    W = w_torch.expand((16, -1)).unsqueeze(0)
+
+    x = model.synthesis.input(W[0,0].unsqueeze(0))
+    for i, layer in enumerate(layers):
+        if i < 2:
+            continue
+        style = getattr(model.synthesis, layer).affine(W[0, i-1].unsqueeze(0))
+        if len(change_vectors) != 17:
+            change = torch.from_numpy(change_vectors[i].copy()).unsqueeze(0).to(device)
+            style = torch.add(style, change, alpha=lambdas)
+        x = rest_from_style(x, style, layer)
+
+    if model.synthesis.output_scale != 1:
+        x = x * model.synthesis.output_scale
+
+    img = (x.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
+    img = PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')
+
+    return img
+
+def get_original_pos(top_positions, bottom_positions=None, space='s', sign=True,
+                     shapes=[[512, 4, 512, 512, 512, 512, 512, 512, 512,
+                              512, 512, 512, 362, 256, 181, 128, 128]],
+                     layers=['w', 'input', 'L0_36_512', 'L1_36_512', 'L2_36_512', 'L3_52_512',
+                             'L4_52_512', 'L5_84_512', 'L6_84_512', 'L7_148_512', 'L8_148_512',
+                             'L9_148_362', 'L10_276_256', 'L11_276_181', 'L12_276_128',
+                             'L13_256_128', 'L14_256_3'], ):
+    if space == 's':
+        current_idx = 0
+        vectors = []
+        for i, (leng, layer) in enumerate(zip(shapes, layers)):
+            arr = np.zeros(leng)
+            for top_position in top_positions:
+                if top_position >= current_idx and top_position < current_idx + leng:
+                    arr[top_position - current_idx] = 1
+            for bottom_position in bottom_positions:
+                if sign:
+                    if bottom_position >= current_idx and bottom_position < current_idx + leng:
+                        arr[bottom_position - current_idx] = 1
+            arr = arr / (np.linalg.norm(arr) + 0.000001)
+            vectors.append(arr)
+            current_idx += leng
+    else:
+        if sign:
+            vectors = np.zeros(512)
+            vectors[top_positions] = 1
+            vectors[bottom_positions] = -1
+        else:
+            vectors = np.zeros(512)
+            vectors[top_positions] = 1
+    return vectors
+
+def getX(annotations, space='s'):
+    if space == 'x':
+        X = np.array(annotations['w_vectors']).reshape((len(annotations['w_vectors']), 512))
+    elif space == 's':
+        concat_v = []
+        for i in range(len(annotations['w_vectors'])):
+            concat_v.append(np.concatenate([annotations['w_vectors'][i]] + annotations['s_vectors'][i], axis=1))
+
+        X = np.array(concat_v)
+        X = X[:, 0, :]
+        print(X.shape)
+
+    return X
+
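For reference, a tiny sketch (not part of the commit) of the hue-bin arithmetic the new get_verification_score relies on: the 0-255 hue circle is split into twelve equal bins, one per colour in the palette, and the score compares the dominant hue of the shifted image against the bin centre. The colour index below is illustrative.

# Illustrative only: hue-bin bounds for one colour index (0-11) on PIL's 0-255 hue scale.
color_id = 8                           # e.g. 'Warm Blue' in the 12-colour list
hue_low = color_id * 256 / 12          # 170.67
hue_high = (color_id + 1) * 256 / 12   # 192.0
hue_mean = (hue_low + hue_high) / 2    # 181.33
# The returned pair is the mean |hue_mean - dominant hue| for the shifted images
# and the same distance for the original images, so a drop indicates the edit works.
print(hue_low, hue_high, hue_mean)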
backend/networks_stylegan3.py
ADDED
@@ -0,0 +1,515 @@
1 |
+
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION is strictly prohibited.
|
8 |
+
|
9 |
+
"""Generator architecture from the paper
|
10 |
+
"Alias-Free Generative Adversarial Networks"."""
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import scipy.signal
|
14 |
+
import scipy.optimize
|
15 |
+
import torch
|
16 |
+
from torch_utils import misc
|
17 |
+
from torch_utils import persistence
|
18 |
+
from torch_utils.ops import conv2d_gradfix
|
19 |
+
from torch_utils.ops import filtered_lrelu
|
20 |
+
from torch_utils.ops import bias_act
|
21 |
+
|
22 |
+
#----------------------------------------------------------------------------
|
23 |
+
|
24 |
+
@misc.profiled_function
|
25 |
+
def modulated_conv2d(
|
26 |
+
x, # Input tensor: [batch_size, in_channels, in_height, in_width]
|
27 |
+
w, # Weight tensor: [out_channels, in_channels, kernel_height, kernel_width]
|
28 |
+
s, # Style tensor: [batch_size, in_channels]
|
29 |
+
demodulate = True, # Apply weight demodulation?
|
30 |
+
padding = 0, # Padding: int or [padH, padW]
|
31 |
+
input_gain = None, # Optional scale factors for the input channels: [], [in_channels], or [batch_size, in_channels]
|
32 |
+
):
|
33 |
+
with misc.suppress_tracer_warnings(): # this value will be treated as a constant
|
34 |
+
batch_size = int(x.shape[0])
|
35 |
+
out_channels, in_channels, kh, kw = w.shape
|
36 |
+
misc.assert_shape(w, [out_channels, in_channels, kh, kw]) # [OIkk]
|
37 |
+
misc.assert_shape(x, [batch_size, in_channels, None, None]) # [NIHW]
|
38 |
+
misc.assert_shape(s, [batch_size, in_channels]) # [NI]
|
39 |
+
|
40 |
+
# Pre-normalize inputs.
|
41 |
+
if demodulate:
|
42 |
+
w = w * w.square().mean([1,2,3], keepdim=True).rsqrt()
|
43 |
+
s = s * s.square().mean().rsqrt()
|
44 |
+
|
45 |
+
# Modulate weights.
|
46 |
+
w = w.unsqueeze(0) # [NOIkk]
|
47 |
+
w = w * s.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
|
48 |
+
|
49 |
+
# Demodulate weights.
|
50 |
+
if demodulate:
|
51 |
+
dcoefs = (w.square().sum(dim=[2,3,4]) + 1e-8).rsqrt() # [NO]
|
52 |
+
w = w * dcoefs.unsqueeze(2).unsqueeze(3).unsqueeze(4) # [NOIkk]
|
53 |
+
|
54 |
+
# Apply input scaling.
|
55 |
+
if input_gain is not None:
|
56 |
+
input_gain = input_gain.expand(batch_size, in_channels) # [NI]
|
57 |
+
w = w * input_gain.unsqueeze(1).unsqueeze(3).unsqueeze(4) # [NOIkk]
|
58 |
+
|
59 |
+
# Execute as one fused op using grouped convolution.
|
60 |
+
x = x.reshape(1, -1, *x.shape[2:])
|
61 |
+
w = w.reshape(-1, in_channels, kh, kw)
|
62 |
+
x = conv2d_gradfix.conv2d(input=x, weight=w.to(x.dtype), padding=padding, groups=batch_size)
|
63 |
+
x = x.reshape(batch_size, -1, *x.shape[2:])
|
64 |
+
return x
|
65 |
+
|
66 |
+
#----------------------------------------------------------------------------
|
67 |
+
|
68 |
+
@persistence.persistent_class
|
69 |
+
class FullyConnectedLayer(torch.nn.Module):
|
70 |
+
def __init__(self,
|
71 |
+
in_features, # Number of input features.
|
72 |
+
out_features, # Number of output features.
|
73 |
+
activation = 'linear', # Activation function: 'relu', 'lrelu', etc.
|
74 |
+
bias = True, # Apply additive bias before the activation function?
|
75 |
+
lr_multiplier = 1, # Learning rate multiplier.
|
76 |
+
weight_init = 1, # Initial standard deviation of the weight tensor.
|
77 |
+
bias_init = 0, # Initial value of the additive bias.
|
78 |
+
):
|
79 |
+
super().__init__()
|
80 |
+
self.in_features = in_features
|
81 |
+
self.out_features = out_features
|
82 |
+
self.activation = activation
|
83 |
+
self.weight = torch.nn.Parameter(torch.randn([out_features, in_features]) * (weight_init / lr_multiplier))
|
84 |
+
bias_init = np.broadcast_to(np.asarray(bias_init, dtype=np.float32), [out_features])
|
85 |
+
self.bias = torch.nn.Parameter(torch.from_numpy(bias_init / lr_multiplier)) if bias else None
|
86 |
+
self.weight_gain = lr_multiplier / np.sqrt(in_features)
|
87 |
+
self.bias_gain = lr_multiplier
|
88 |
+
|
89 |
+
def forward(self, x):
|
90 |
+
w = self.weight.to(x.dtype) * self.weight_gain
|
91 |
+
b = self.bias
|
92 |
+
if b is not None:
|
93 |
+
b = b.to(x.dtype)
|
94 |
+
if self.bias_gain != 1:
|
95 |
+
b = b * self.bias_gain
|
96 |
+
if self.activation == 'linear' and b is not None:
|
97 |
+
x = torch.addmm(b.unsqueeze(0), x, w.t())
|
98 |
+
else:
|
99 |
+
x = x.matmul(w.t())
|
100 |
+
x = bias_act.bias_act(x, b, act=self.activation)
|
101 |
+
return x
|
102 |
+
|
103 |
+
def extra_repr(self):
|
104 |
+
return f'in_features={self.in_features:d}, out_features={self.out_features:d}, activation={self.activation:s}'
|
105 |
+
|
106 |
+
#----------------------------------------------------------------------------
|
107 |
+
|
108 |
+
@persistence.persistent_class
|
109 |
+
class MappingNetwork(torch.nn.Module):
|
110 |
+
def __init__(self,
|
111 |
+
z_dim, # Input latent (Z) dimensionality.
|
112 |
+
c_dim, # Conditioning label (C) dimensionality, 0 = no labels.
|
113 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
114 |
+
num_ws, # Number of intermediate latents to output.
|
115 |
+
num_layers = 2, # Number of mapping layers.
|
116 |
+
lr_multiplier = 0.01, # Learning rate multiplier for the mapping layers.
|
117 |
+
w_avg_beta = 0.998, # Decay for tracking the moving average of W during training.
|
118 |
+
):
|
119 |
+
super().__init__()
|
120 |
+
self.z_dim = z_dim
|
121 |
+
self.c_dim = c_dim
|
122 |
+
self.w_dim = w_dim
|
123 |
+
self.num_ws = num_ws
|
124 |
+
self.num_layers = num_layers
|
125 |
+
self.w_avg_beta = w_avg_beta
|
126 |
+
|
127 |
+
# Construct layers.
|
128 |
+
self.embed = FullyConnectedLayer(self.c_dim, self.w_dim) if self.c_dim > 0 else None
|
129 |
+
features = [self.z_dim + (self.w_dim if self.c_dim > 0 else 0)] + [self.w_dim] * self.num_layers
|
130 |
+
for idx, in_features, out_features in zip(range(num_layers), features[:-1], features[1:]):
|
131 |
+
layer = FullyConnectedLayer(in_features, out_features, activation='lrelu', lr_multiplier=lr_multiplier)
|
132 |
+
setattr(self, f'fc{idx}', layer)
|
133 |
+
self.register_buffer('w_avg', torch.zeros([w_dim]))
|
134 |
+
|
135 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
|
136 |
+
misc.assert_shape(z, [None, self.z_dim])
|
137 |
+
if truncation_cutoff is None:
|
138 |
+
truncation_cutoff = self.num_ws
|
139 |
+
|
140 |
+
# Embed, normalize, and concatenate inputs.
|
141 |
+
x = z.to(torch.float32)
|
142 |
+
x = x * (x.square().mean(1, keepdim=True) + 1e-8).rsqrt()
|
143 |
+
if self.c_dim > 0:
|
144 |
+
misc.assert_shape(c, [None, self.c_dim])
|
145 |
+
y = self.embed(c.to(torch.float32))
|
146 |
+
y = y * (y.square().mean(1, keepdim=True) + 1e-8).rsqrt()
|
147 |
+
x = torch.cat([x, y], dim=1) if x is not None else y
|
148 |
+
|
149 |
+
# Execute layers.
|
150 |
+
for idx in range(self.num_layers):
|
151 |
+
x = getattr(self, f'fc{idx}')(x)
|
152 |
+
|
153 |
+
# Update moving average of W.
|
154 |
+
if update_emas:
|
155 |
+
self.w_avg.copy_(x.detach().mean(dim=0).lerp(self.w_avg, self.w_avg_beta))
|
156 |
+
|
157 |
+
# Broadcast and apply truncation.
|
158 |
+
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
|
159 |
+
if truncation_psi != 1:
|
160 |
+
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
|
161 |
+
return x
|
162 |
+
|
163 |
+
def extra_repr(self):
|
164 |
+
return f'z_dim={self.z_dim:d}, c_dim={self.c_dim:d}, w_dim={self.w_dim:d}, num_ws={self.num_ws:d}'
|
165 |
+
|
166 |
+
#----------------------------------------------------------------------------
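The truncation applied at the end of MappingNetwork.forward is a lerp of the first truncation_cutoff per-layer latents toward the tracked average w_avg. A short sketch with illustrative tensors (num_ws = 16 matches num_layers + 2 in the synthesis network below):

import torch

num_ws, w_dim = 16, 512
w_avg = torch.zeros(w_dim)                    # tracked moving average of W
ws = torch.randn(4, num_ws, w_dim)            # mapped latents, broadcast per layer

truncation_psi, truncation_cutoff = 0.7, 8
# psi = 1 leaves the latents untouched; psi = 0 collapses them onto w_avg
# (more average-looking images, less diversity).
ws[:, :truncation_cutoff] = w_avg.lerp(ws[:, :truncation_cutoff], truncation_psi)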
|
167 |
+
|
168 |
+
@persistence.persistent_class
|
169 |
+
class SynthesisInput(torch.nn.Module):
|
170 |
+
def __init__(self,
|
171 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
172 |
+
channels, # Number of output channels.
|
173 |
+
size, # Output spatial size: int or [width, height].
|
174 |
+
sampling_rate, # Output sampling rate.
|
175 |
+
bandwidth, # Output bandwidth.
|
176 |
+
):
|
177 |
+
super().__init__()
|
178 |
+
self.w_dim = w_dim
|
179 |
+
self.channels = channels
|
180 |
+
self.size = np.broadcast_to(np.asarray(size), [2])
|
181 |
+
self.sampling_rate = sampling_rate
|
182 |
+
self.bandwidth = bandwidth
|
183 |
+
|
184 |
+
# Draw random frequencies from uniform 2D disc.
|
185 |
+
freqs = torch.randn([self.channels, 2])
|
186 |
+
radii = freqs.square().sum(dim=1, keepdim=True).sqrt()
|
187 |
+
freqs /= radii * radii.square().exp().pow(0.25)
|
188 |
+
freqs *= bandwidth
|
189 |
+
phases = torch.rand([self.channels]) - 0.5
|
190 |
+
|
191 |
+
# Setup parameters and buffers.
|
192 |
+
self.weight = torch.nn.Parameter(torch.randn([self.channels, self.channels]))
|
193 |
+
self.affine = FullyConnectedLayer(w_dim, 4, weight_init=0, bias_init=[1,0,0,0])
|
194 |
+
self.register_buffer('transform', torch.eye(3, 3)) # User-specified inverse transform wrt. resulting image.
|
195 |
+
self.register_buffer('freqs', freqs)
|
196 |
+
self.register_buffer('phases', phases)
|
197 |
+
|
198 |
+
def forward(self, w):
|
199 |
+
# Introduce batch dimension.
|
200 |
+
transforms = self.transform.unsqueeze(0) # [batch, row, col]
|
201 |
+
freqs = self.freqs.unsqueeze(0) # [batch, channel, xy]
|
202 |
+
phases = self.phases.unsqueeze(0) # [batch, channel]
|
203 |
+
|
204 |
+
# Apply learned transformation.
|
205 |
+
t = self.affine(w) # t = (r_c, r_s, t_x, t_y)
|
206 |
+
t = t / t[:, :2].norm(dim=1, keepdim=True) # t' = (r'_c, r'_s, t'_x, t'_y)
|
207 |
+
m_r = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse rotation wrt. resulting image.
|
208 |
+
m_r[:, 0, 0] = t[:, 0] # r'_c
|
209 |
+
m_r[:, 0, 1] = -t[:, 1] # r'_s
|
210 |
+
m_r[:, 1, 0] = t[:, 1] # r'_s
|
211 |
+
m_r[:, 1, 1] = t[:, 0] # r'_c
|
212 |
+
m_t = torch.eye(3, device=w.device).unsqueeze(0).repeat([w.shape[0], 1, 1]) # Inverse translation wrt. resulting image.
|
213 |
+
m_t[:, 0, 2] = -t[:, 2] # t'_x
|
214 |
+
m_t[:, 1, 2] = -t[:, 3] # t'_y
|
215 |
+
transforms = m_r @ m_t @ transforms # First rotate resulting image, then translate, and finally apply user-specified transform.
|
216 |
+
|
217 |
+
# Transform frequencies.
|
218 |
+
phases = phases + (freqs @ transforms[:, :2, 2:]).squeeze(2)
|
219 |
+
freqs = freqs @ transforms[:, :2, :2]
|
220 |
+
|
221 |
+
# Dampen out-of-band frequencies that may occur due to the user-specified transform.
|
222 |
+
amplitudes = (1 - (freqs.norm(dim=2) - self.bandwidth) / (self.sampling_rate / 2 - self.bandwidth)).clamp(0, 1)
|
223 |
+
|
224 |
+
# Construct sampling grid.
|
225 |
+
theta = torch.eye(2, 3, device=w.device)
|
226 |
+
theta[0, 0] = 0.5 * self.size[0] / self.sampling_rate
|
227 |
+
theta[1, 1] = 0.5 * self.size[1] / self.sampling_rate
|
228 |
+
grids = torch.nn.functional.affine_grid(theta.unsqueeze(0), [1, 1, self.size[1], self.size[0]], align_corners=False)
|
229 |
+
|
230 |
+
# Compute Fourier features.
|
231 |
+
x = (grids.unsqueeze(3) @ freqs.permute(0, 2, 1).unsqueeze(1).unsqueeze(2)).squeeze(3) # [batch, height, width, channel]
|
232 |
+
x = x + phases.unsqueeze(1).unsqueeze(2)
|
233 |
+
x = torch.sin(x * (np.pi * 2))
|
234 |
+
x = x * amplitudes.unsqueeze(1).unsqueeze(2)
|
235 |
+
|
236 |
+
# Apply trainable mapping.
|
237 |
+
weight = self.weight / np.sqrt(self.channels)
|
238 |
+
x = x @ weight.t()
|
239 |
+
|
240 |
+
# Ensure correct shape.
|
241 |
+
x = x.permute(0, 3, 1, 2) # [batch, channel, height, width]
|
242 |
+
misc.assert_shape(x, [w.shape[0], self.channels, int(self.size[1]), int(self.size[0])])
|
243 |
+
return x
|
244 |
+
|
245 |
+
def extra_repr(self):
|
246 |
+
return '\n'.join([
|
247 |
+
f'w_dim={self.w_dim:d}, channels={self.channels:d}, size={list(self.size)},',
|
248 |
+
f'sampling_rate={self.sampling_rate:g}, bandwidth={self.bandwidth:g}'])
|
249 |
+
|
250 |
+
#----------------------------------------------------------------------------
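SynthesisInput builds its starting feature map from Fourier features: a coordinate grid is projected onto per-channel 2D frequencies, shifted by per-channel phases, and passed through a sine. A toy reconstruction with illustrative sizes (the committed module additionally applies the learned affine transform, amplitude damping, and a trainable channel mixing):

import numpy as np
import torch

channels, size, sampling_rate, bandwidth = 4, 36, 16, 2
freqs = torch.randn(channels, 2)
freqs = freqs / freqs.norm(dim=1, keepdim=True) * bandwidth   # fixed |f| for simplicity
phases = torch.rand(channels) - 0.5

coords = (torch.arange(size) - size / 2 + 0.5) / sampling_rate
gy, gx = torch.meshgrid(coords, coords, indexing='ij')
grid = torch.stack([gx, gy], dim=-1)                          # [H, W, 2]

feat = torch.sin(2 * np.pi * (grid @ freqs.t() + phases))     # [H, W, C]
feat = feat.permute(2, 0, 1).unsqueeze(0)                     # [1, C, H, W]
print(feat.shape)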
|
251 |
+
|
252 |
+
@persistence.persistent_class
|
253 |
+
class SynthesisLayer(torch.nn.Module):
|
254 |
+
def __init__(self,
|
255 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
256 |
+
is_torgb, # Is this the final ToRGB layer?
|
257 |
+
is_critically_sampled, # Does this layer use critical sampling?
|
258 |
+
use_fp16, # Does this layer use FP16?
|
259 |
+
|
260 |
+
# Input & output specifications.
|
261 |
+
in_channels, # Number of input channels.
|
262 |
+
out_channels, # Number of output channels.
|
263 |
+
in_size, # Input spatial size: int or [width, height].
|
264 |
+
out_size, # Output spatial size: int or [width, height].
|
265 |
+
in_sampling_rate, # Input sampling rate (s).
|
266 |
+
out_sampling_rate, # Output sampling rate (s).
|
267 |
+
in_cutoff, # Input cutoff frequency (f_c).
|
268 |
+
out_cutoff, # Output cutoff frequency (f_c).
|
269 |
+
in_half_width, # Input transition band half-width (f_h).
|
270 |
+
out_half_width, # Output transition band half-width (f_h).
|
271 |
+
|
272 |
+
# Hyperparameters.
|
273 |
+
conv_kernel = 3, # Convolution kernel size. Ignored for the final ToRGB layer.
|
274 |
+
filter_size = 6, # Low-pass filter size relative to the lower resolution when up/downsampling.
|
275 |
+
lrelu_upsampling = 2, # Relative sampling rate for leaky ReLU. Ignored for the final ToRGB layer.
|
276 |
+
use_radial_filters = False, # Use radially symmetric downsampling filter? Ignored for critically sampled layers.
|
277 |
+
conv_clamp = 256, # Clamp the output to [-X, +X], None = disable clamping.
|
278 |
+
magnitude_ema_beta = 0.999, # Decay rate for the moving average of input magnitudes.
|
279 |
+
):
|
280 |
+
super().__init__()
|
281 |
+
self.w_dim = w_dim
|
282 |
+
self.is_torgb = is_torgb
|
283 |
+
self.is_critically_sampled = is_critically_sampled
|
284 |
+
self.use_fp16 = use_fp16
|
285 |
+
self.in_channels = in_channels
|
286 |
+
self.out_channels = out_channels
|
287 |
+
self.in_size = np.broadcast_to(np.asarray(in_size), [2])
|
288 |
+
self.out_size = np.broadcast_to(np.asarray(out_size), [2])
|
289 |
+
self.in_sampling_rate = in_sampling_rate
|
290 |
+
self.out_sampling_rate = out_sampling_rate
|
291 |
+
self.tmp_sampling_rate = max(in_sampling_rate, out_sampling_rate) * (1 if is_torgb else lrelu_upsampling)
|
292 |
+
self.in_cutoff = in_cutoff
|
293 |
+
self.out_cutoff = out_cutoff
|
294 |
+
self.in_half_width = in_half_width
|
295 |
+
self.out_half_width = out_half_width
|
296 |
+
self.conv_kernel = 1 if is_torgb else conv_kernel
|
297 |
+
self.conv_clamp = conv_clamp
|
298 |
+
self.magnitude_ema_beta = magnitude_ema_beta
|
299 |
+
|
300 |
+
# Setup parameters and buffers.
|
301 |
+
self.affine = FullyConnectedLayer(self.w_dim, self.in_channels, bias_init=1)
|
302 |
+
self.weight = torch.nn.Parameter(torch.randn([self.out_channels, self.in_channels, self.conv_kernel, self.conv_kernel]))
|
303 |
+
self.bias = torch.nn.Parameter(torch.zeros([self.out_channels]))
|
304 |
+
self.register_buffer('magnitude_ema', torch.ones([]))
|
305 |
+
|
306 |
+
# Design upsampling filter.
|
307 |
+
self.up_factor = int(np.rint(self.tmp_sampling_rate / self.in_sampling_rate))
|
308 |
+
assert self.in_sampling_rate * self.up_factor == self.tmp_sampling_rate
|
309 |
+
self.up_taps = filter_size * self.up_factor if self.up_factor > 1 and not self.is_torgb else 1
|
310 |
+
self.register_buffer('up_filter', self.design_lowpass_filter(
|
311 |
+
numtaps=self.up_taps, cutoff=self.in_cutoff, width=self.in_half_width*2, fs=self.tmp_sampling_rate))
|
312 |
+
|
313 |
+
# Design downsampling filter.
|
314 |
+
self.down_factor = int(np.rint(self.tmp_sampling_rate / self.out_sampling_rate))
|
315 |
+
assert self.out_sampling_rate * self.down_factor == self.tmp_sampling_rate
|
316 |
+
self.down_taps = filter_size * self.down_factor if self.down_factor > 1 and not self.is_torgb else 1
|
317 |
+
self.down_radial = use_radial_filters and not self.is_critically_sampled
|
318 |
+
self.register_buffer('down_filter', self.design_lowpass_filter(
|
319 |
+
numtaps=self.down_taps, cutoff=self.out_cutoff, width=self.out_half_width*2, fs=self.tmp_sampling_rate, radial=self.down_radial))
|
320 |
+
|
321 |
+
# Compute padding.
|
322 |
+
pad_total = (self.out_size - 1) * self.down_factor + 1 # Desired output size before downsampling.
|
323 |
+
pad_total -= (self.in_size + self.conv_kernel - 1) * self.up_factor # Input size after upsampling.
|
324 |
+
pad_total += self.up_taps + self.down_taps - 2 # Size reduction caused by the filters.
|
325 |
+
pad_lo = (pad_total + self.up_factor) // 2 # Shift sample locations according to the symmetric interpretation (Appendix C.3).
|
326 |
+
pad_hi = pad_total - pad_lo
|
327 |
+
self.padding = [int(pad_lo[0]), int(pad_hi[0]), int(pad_lo[1]), int(pad_hi[1])]
|
328 |
+
|
329 |
+
def forward(self, x, w, noise_mode='random', force_fp32=False, update_emas=False):
|
330 |
+
assert noise_mode in ['random', 'const', 'none'] # unused
|
331 |
+
misc.assert_shape(x, [None, self.in_channels, int(self.in_size[1]), int(self.in_size[0])])
|
332 |
+
misc.assert_shape(w, [x.shape[0], self.w_dim])
|
333 |
+
|
334 |
+
# Track input magnitude.
|
335 |
+
if update_emas:
|
336 |
+
with torch.autograd.profiler.record_function('update_magnitude_ema'):
|
337 |
+
magnitude_cur = x.detach().to(torch.float32).square().mean()
|
338 |
+
self.magnitude_ema.copy_(magnitude_cur.lerp(self.magnitude_ema, self.magnitude_ema_beta))
|
339 |
+
input_gain = self.magnitude_ema.rsqrt()
|
340 |
+
|
341 |
+
# Execute affine layer.
|
342 |
+
styles = self.affine(w)
|
343 |
+
if self.is_torgb:
|
344 |
+
weight_gain = 1 / np.sqrt(self.in_channels * (self.conv_kernel ** 2))
|
345 |
+
styles = styles * weight_gain
|
346 |
+
|
347 |
+
# Execute modulated conv2d.
|
348 |
+
dtype = torch.float16 if (self.use_fp16 and not force_fp32 and x.device.type == 'cuda') else torch.float32
|
349 |
+
x = modulated_conv2d(x=x.to(dtype), w=self.weight, s=styles,
|
350 |
+
padding=self.conv_kernel-1, demodulate=(not self.is_torgb), input_gain=input_gain)
|
351 |
+
|
352 |
+
# Execute bias, filtered leaky ReLU, and clamping.
|
353 |
+
gain = 1 if self.is_torgb else np.sqrt(2)
|
354 |
+
slope = 1 if self.is_torgb else 0.2
|
355 |
+
x = filtered_lrelu.filtered_lrelu(x=x, fu=self.up_filter, fd=self.down_filter, b=self.bias.to(x.dtype),
|
356 |
+
up=self.up_factor, down=self.down_factor, padding=self.padding, gain=gain, slope=slope, clamp=self.conv_clamp)
|
357 |
+
|
358 |
+
# Ensure correct shape and dtype.
|
359 |
+
misc.assert_shape(x, [None, self.out_channels, int(self.out_size[1]), int(self.out_size[0])])
|
360 |
+
assert x.dtype == dtype
|
361 |
+
return x
|
362 |
+
|
363 |
+
@staticmethod
|
364 |
+
def design_lowpass_filter(numtaps, cutoff, width, fs, radial=False):
|
365 |
+
assert numtaps >= 1
|
366 |
+
|
367 |
+
# Identity filter.
|
368 |
+
if numtaps == 1:
|
369 |
+
return None
|
370 |
+
|
371 |
+
# Separable Kaiser low-pass filter.
|
372 |
+
if not radial:
|
373 |
+
f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=width, fs=fs)
|
374 |
+
return torch.as_tensor(f, dtype=torch.float32)
|
375 |
+
|
376 |
+
# Radially symmetric jinc-based filter.
|
377 |
+
x = (np.arange(numtaps) - (numtaps - 1) / 2) / fs
|
378 |
+
r = np.hypot(*np.meshgrid(x, x))
|
379 |
+
f = scipy.special.j1(2 * cutoff * (np.pi * r)) / (np.pi * r)
|
380 |
+
beta = scipy.signal.kaiser_beta(scipy.signal.kaiser_atten(numtaps, width / (fs / 2)))
|
381 |
+
w = np.kaiser(numtaps, beta)
|
382 |
+
f *= np.outer(w, w)
|
383 |
+
f /= np.sum(f)
|
384 |
+
return torch.as_tensor(f, dtype=torch.float32)
|
385 |
+
|
386 |
+
def extra_repr(self):
|
387 |
+
return '\n'.join([
|
388 |
+
f'w_dim={self.w_dim:d}, is_torgb={self.is_torgb},',
|
389 |
+
f'is_critically_sampled={self.is_critically_sampled}, use_fp16={self.use_fp16},',
|
390 |
+
f'in_sampling_rate={self.in_sampling_rate:g}, out_sampling_rate={self.out_sampling_rate:g},',
|
391 |
+
f'in_cutoff={self.in_cutoff:g}, out_cutoff={self.out_cutoff:g},',
|
392 |
+
f'in_half_width={self.in_half_width:g}, out_half_width={self.out_half_width:g},',
|
393 |
+
f'in_size={list(self.in_size)}, out_size={list(self.out_size)},',
|
394 |
+
f'in_channels={self.in_channels:d}, out_channels={self.out_channels:d}'])
|
395 |
+
|
396 |
+
#----------------------------------------------------------------------------
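SynthesisLayer designs its anti-aliasing filters with a windowed-sinc (Kaiser) design via scipy.signal.firwin. An illustrative call, not tied to any specific layer: a non-critical layer with input sampling rate 16 and temporary rate 64 has up_factor 4 and up_taps = filter_size * up_factor = 24; the cutoff and transition width below are made-up values.

import scipy.signal
import torch

numtaps, cutoff, half_width, fs = 24, 8.0, 4.0, 64.0
f = scipy.signal.firwin(numtaps=numtaps, cutoff=cutoff, width=half_width * 2, fs=fs)
up_filter = torch.as_tensor(f, dtype=torch.float32)

# In the non-radial case filtered_lrelu applies the 1D taps separably; the
# equivalent 2D kernel is the outer product of the taps with themselves.
kernel_2d = torch.outer(up_filter, up_filter)
print(up_filter.shape, round(float(kernel_2d.sum()), 4))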
|
397 |
+
|
398 |
+
@persistence.persistent_class
|
399 |
+
class SynthesisNetwork(torch.nn.Module):
|
400 |
+
def __init__(self,
|
401 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
402 |
+
img_resolution, # Output image resolution.
|
403 |
+
img_channels, # Number of color channels.
|
404 |
+
channel_base = 32768, # Overall multiplier for the number of channels.
|
405 |
+
channel_max = 512, # Maximum number of channels in any layer.
|
406 |
+
num_layers = 14, # Total number of layers, excluding Fourier features and ToRGB.
|
407 |
+
num_critical = 2, # Number of critically sampled layers at the end.
|
408 |
+
first_cutoff = 2, # Cutoff frequency of the first layer (f_{c,0}).
|
409 |
+
first_stopband = 2**2.1, # Minimum stopband of the first layer (f_{t,0}).
|
410 |
+
last_stopband_rel = 2**0.3, # Minimum stopband of the last layer, expressed relative to the cutoff.
|
411 |
+
margin_size = 10, # Number of additional pixels outside the image.
|
412 |
+
output_scale = 0.25, # Scale factor for the output image.
|
413 |
+
num_fp16_res = 4, # Use FP16 for the N highest resolutions.
|
414 |
+
**layer_kwargs, # Arguments for SynthesisLayer.
|
415 |
+
):
|
416 |
+
super().__init__()
|
417 |
+
self.w_dim = w_dim
|
418 |
+
self.num_ws = num_layers + 2
|
419 |
+
self.img_resolution = img_resolution
|
420 |
+
self.img_channels = img_channels
|
421 |
+
self.num_layers = num_layers
|
422 |
+
self.num_critical = num_critical
|
423 |
+
self.margin_size = margin_size
|
424 |
+
self.output_scale = output_scale
|
425 |
+
self.num_fp16_res = num_fp16_res
|
426 |
+
|
427 |
+
# Geometric progression of layer cutoffs and min. stopbands.
|
428 |
+
last_cutoff = self.img_resolution / 2 # f_{c,N}
|
429 |
+
last_stopband = last_cutoff * last_stopband_rel # f_{t,N}
|
430 |
+
exponents = np.minimum(np.arange(self.num_layers + 1) / (self.num_layers - self.num_critical), 1)
|
431 |
+
cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents # f_c[i]
|
432 |
+
stopbands = first_stopband * (last_stopband / first_stopband) ** exponents # f_t[i]
|
433 |
+
|
434 |
+
# Compute remaining layer parameters.
|
435 |
+
sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, self.img_resolution)))) # s[i]
|
436 |
+
half_widths = np.maximum(stopbands, sampling_rates / 2) - cutoffs # f_h[i]
|
437 |
+
sizes = sampling_rates + self.margin_size * 2
|
438 |
+
sizes[-2:] = self.img_resolution
|
439 |
+
channels = np.rint(np.minimum((channel_base / 2) / cutoffs, channel_max))
|
440 |
+
channels[-1] = self.img_channels
|
441 |
+
|
442 |
+
# Construct layers.
|
443 |
+
self.input = SynthesisInput(
|
444 |
+
w_dim=self.w_dim, channels=int(channels[0]), size=int(sizes[0]),
|
445 |
+
sampling_rate=sampling_rates[0], bandwidth=cutoffs[0])
|
446 |
+
self.layer_names = []
|
447 |
+
for idx in range(self.num_layers + 1):
|
448 |
+
prev = max(idx - 1, 0)
|
449 |
+
is_torgb = (idx == self.num_layers)
|
450 |
+
is_critically_sampled = (idx >= self.num_layers - self.num_critical)
|
451 |
+
use_fp16 = (sampling_rates[idx] * (2 ** self.num_fp16_res) > self.img_resolution)
|
452 |
+
layer = SynthesisLayer(
|
453 |
+
w_dim=self.w_dim, is_torgb=is_torgb, is_critically_sampled=is_critically_sampled, use_fp16=use_fp16,
|
454 |
+
in_channels=int(channels[prev]), out_channels=int(channels[idx]),
|
455 |
+
in_size=int(sizes[prev]), out_size=int(sizes[idx]),
|
456 |
+
in_sampling_rate=int(sampling_rates[prev]), out_sampling_rate=int(sampling_rates[idx]),
|
457 |
+
in_cutoff=cutoffs[prev], out_cutoff=cutoffs[idx],
|
458 |
+
in_half_width=half_widths[prev], out_half_width=half_widths[idx],
|
459 |
+
**layer_kwargs)
|
460 |
+
name = f'L{idx}_{layer.out_size[0]}_{layer.out_channels}'
|
461 |
+
setattr(self, name, layer)
|
462 |
+
self.layer_names.append(name)
|
463 |
+
|
464 |
+
def forward(self, ws, **layer_kwargs):
|
465 |
+
misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
|
466 |
+
ws = ws.to(torch.float32).unbind(dim=1)
|
467 |
+
|
468 |
+
# Execute layers.
|
469 |
+
x = self.input(ws[0])
|
470 |
+
for name, w in zip(self.layer_names, ws[1:]):
|
471 |
+
x = getattr(self, name)(x, w, **layer_kwargs)
|
472 |
+
if self.output_scale != 1:
|
473 |
+
x = x * self.output_scale
|
474 |
+
|
475 |
+
# Ensure correct shape and dtype.
|
476 |
+
misc.assert_shape(x, [None, self.img_channels, self.img_resolution, self.img_resolution])
|
477 |
+
x = x.to(torch.float32)
|
478 |
+
return x
|
479 |
+
|
480 |
+
def extra_repr(self):
|
481 |
+
return '\n'.join([
|
482 |
+
f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
|
483 |
+
f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
|
484 |
+
f'num_layers={self.num_layers:d}, num_critical={self.num_critical:d},',
|
485 |
+
f'margin_size={self.margin_size:d}, num_fp16_res={self.num_fp16_res:d}'])
|
486 |
+
|
487 |
+
#----------------------------------------------------------------------------
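The constructor above derives every layer's cutoff, stopband, and internal sampling rate from a geometric progression. Reproducing those few lines standalone (defaults from the signature; img_resolution = 256 is an assumption for illustration) makes the schedule easy to inspect:

import numpy as np

num_layers, num_critical = 14, 2
first_cutoff, first_stopband, last_stopband_rel = 2, 2 ** 2.1, 2 ** 0.3
img_resolution = 256                                   # assumed value

last_cutoff = img_resolution / 2                       # f_{c,N}
last_stopband = last_cutoff * last_stopband_rel        # f_{t,N}
exponents = np.minimum(np.arange(num_layers + 1) / (num_layers - num_critical), 1)
cutoffs = first_cutoff * (last_cutoff / first_cutoff) ** exponents
stopbands = first_stopband * (last_stopband / first_stopband) ** exponents
sampling_rates = np.exp2(np.ceil(np.log2(np.minimum(stopbands * 2, img_resolution))))

print(np.round(cutoffs, 1))          # per-layer cutoffs f_c[i]
print(sampling_rates.astype(int))    # per-layer internal resolutions s[i]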
|
488 |
+
|
489 |
+
@persistence.persistent_class
|
490 |
+
class Generator(torch.nn.Module):
|
491 |
+
def __init__(self,
|
492 |
+
z_dim, # Input latent (Z) dimensionality.
|
493 |
+
c_dim, # Conditioning label (C) dimensionality.
|
494 |
+
w_dim, # Intermediate latent (W) dimensionality.
|
495 |
+
img_resolution, # Output resolution.
|
496 |
+
img_channels, # Number of output color channels.
|
497 |
+
mapping_kwargs = {}, # Arguments for MappingNetwork.
|
498 |
+
**synthesis_kwargs, # Arguments for SynthesisNetwork.
|
499 |
+
):
|
500 |
+
super().__init__()
|
501 |
+
self.z_dim = z_dim
|
502 |
+
self.c_dim = c_dim
|
503 |
+
self.w_dim = w_dim
|
504 |
+
self.img_resolution = img_resolution
|
505 |
+
self.img_channels = img_channels
|
506 |
+
self.synthesis = SynthesisNetwork(w_dim=w_dim, img_resolution=img_resolution, img_channels=img_channels, **synthesis_kwargs)
|
507 |
+
self.num_ws = self.synthesis.num_ws
|
508 |
+
self.mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs)
|
509 |
+
|
510 |
+
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
|
511 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
|
512 |
+
img = self.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
|
513 |
+
return img
|
514 |
+
|
515 |
+
#----------------------------------------------------------------------------
|
data/stylegan3.webp
ADDED
Git LFS Details

data/textile_annotated_files/final_sim_seeds0000-100000.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2dde6f1168825424eed5aa328c6442873df35637e74786f9c9af956c9c0a97ed
size 7886477
data/textile_annotated_files/hsv_info.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:36016190017d38cde71f35c267df9f6b0ab40d74ce17e022195e96d32d2f2f71
size 1112635
data/textile_annotated_files/seeds0000-100000.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2dd321307f3d332193630f823a3e0db79d533156dfbc0d446eab0d5c212b1360
size 630151183
data/textile_annotated_files/seeds0000-100000_S.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8763836ea1142f6f2e3d36b7fe92bcf9a4549e9ef8e0a83a02b4772d64e95d54
size 3178623075
data/textile_annotated_files/top_three_colours.csv
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2dde6f1168825424eed5aa328c6442873df35637e74786f9c9af956c9c0a97ed
size 7886477
data/textile_model_files/network-snapshot-005000.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:717bdd11871c0383d6e28c54b2d61cd485ef236dd1de34d3194323c843b11b62
size 343479704
pages/{4_Vase_Qualities_Comparison.py → 4_Vase_Qualities_Comparison copy.py}
RENAMED
File without changes
pages/5_Textiles_Disentanglement.py
ADDED
@@ -0,0 +1,178 @@
import streamlit as st
import pickle
import pandas as pd
import numpy as np
import random
import torch

from matplotlib.backends.backend_agg import RendererAgg

from backend.disentangle_concepts import *
import torch_utils
import dnnlib
import legacy

_lock = RendererAgg.lock


st.set_page_config(layout='wide')
BACKGROUND_COLOR = '#bcd0e7'
SECONDARY_COLOR = '#bce7db'


st.title('Disentanglement studies on the Textile Dataset')
st.markdown(
    """
    This is a demo of the Disentanglement studies on the [Oxford Vases Dataset](https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/).
    """,
    unsafe_allow_html=False,)

annotations_file = './data/vase_annotated_files/seeds0000-20000.pkl'
with open(annotations_file, 'rb') as f:
    annotations = pickle.load(f)


if 'image_id' not in st.session_state:
    st.session_state.image_id = 0
if 'concept_ids' not in st.session_state:
    st.session_state.concept_ids = ['AMPHORA']
if 'space_id' not in st.session_state:
    st.session_state.space_id = 'W'

# def on_change_random_input():
#     st.session_state.image_id = st.session_state.image_id

# ----------------------------- INPUT ----------------------------------
st.header('Input')
input_col_1, input_col_2, input_col_3 = st.columns(3)
# --------------------------- INPUT column 1 ---------------------------
with input_col_1:
    with st.form('text_form'):

        # image_id = st.number_input('Image ID: ', format='%d', step=1)
        st.write('**Choose two options to disentangle**')
        type_col = st.selectbox('Concept category:', tuple(['Provenance', 'Shape Name', 'Fabric', 'Technique']))

        ann_df = pd.read_csv(f'./data/vase_annotated_files/sim_{type_col}_seeds0000-20000.csv')
        labels = list(ann_df.columns)
        labels.remove('ID')
        labels.remove('Unnamed: 0')

        concept_ids = st.multiselect('Concepts:', tuple(labels), max_selections=2, default=[labels[2], labels[3]])

        st.write('**Choose a latent space to disentangle**')
        space_id = st.selectbox('Space:', tuple(['W', 'Z']))

        choose_text_button = st.form_submit_button('Choose the defined concept and space to disentangle')

    if choose_text_button:
        concept_ids = list(concept_ids)
        st.session_state.concept_ids = concept_ids
        space_id = str(space_id)
        st.session_state.space_id = space_id
        # st.write(image_id, st.session_state.image_id)

# ---------------------------- SET UP OUTPUT ------------------------------
epsilon_container = st.empty()
st.header('Output')
st.subheader('Concept vector')

# perform attack container
# header_col_1, header_col_2, header_col_3, header_col_4, header_col_5 = st.columns([1,1,1,1,1])
# output_col_1, output_col_2, output_col_3, output_col_4, output_col_5 = st.columns([1,1,1,1,1])
header_col_1, header_col_2 = st.columns([5,1])
output_col_1, output_col_2 = st.columns([5,1])

st.subheader('Derivations along the concept vector')

# prediction error container
error_container = st.empty()
smoothgrad_header_container = st.empty()

# smoothgrad container
smooth_head_1, smooth_head_2, smooth_head_3, smooth_head_4, smooth_head_5 = st.columns([1,1,1,1,1])
smoothgrad_col_1, smoothgrad_col_2, smoothgrad_col_3, smoothgrad_col_4, smoothgrad_col_5 = st.columns([1,1,1,1,1])

# ---------------------------- DISPLAY COL 1 ROW 1 ------------------------------
with output_col_1:
    separation_vector, number_important_features, imp_nodes, performance = get_separation_space(concept_ids, annotations, ann_df, latent_space=st.session_state.space_id, samples=150)
    # st.write(f'Class ID {input_id} - {input_label}: {pred_prob*100:.3f}% confidence')
    st.write('Concept vector', separation_vector)
    header_col_1.write(f'Concept {st.session_state.concept_ids} - Space {st.session_state.space_id} - Number of relevant nodes: {number_important_features} - Val classification performance: {performance}')# - Nodes {",".join(list(imp_nodes))}')

# ----------------------------- INPUT column 2 & 3 ----------------------------
with input_col_2:
    with st.form('image_form'):

        # image_id = st.number_input('Image ID: ', format='%d', step=1)
        st.write('**Choose or generate a random image to test the disentanglement**')
        chosen_image_id_input = st.empty()
        image_id = chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)

        choose_image_button = st.form_submit_button('Choose the defined image')
        random_id = st.form_submit_button('Generate a random image')

        if random_id:
            image_id = random.randint(0, 20000)
            st.session_state.image_id = image_id
            chosen_image_id_input.number_input('Image ID:', format='%d', step=1, value=st.session_state.image_id)

    if choose_image_button:
        image_id = int(image_id)
        st.session_state.image_id = int(image_id)
        # st.write(image_id, st.session_state.image_id)

with input_col_3:
    with st.form('Variate along the disentangled concept'):
        st.write('**Set range of change**')
        chosen_epsilon_input = st.empty()
        epsilon = chosen_epsilon_input.number_input('Lambda:', min_value=1, step=1)
        epsilon_button = st.form_submit_button('Choose the defined lambda')
        st.write('**Select hierarchical levels to manipulate**')
        layers = st.multiselect('Layers:', tuple(range(14)))
        if len(layers) == 0:
            layers = None
        print(layers)
        layers_button = st.form_submit_button('Choose the defined layers')


# ---------------------------- DISPLAY COL 2 ROW 1 ------------------------------

#model = torch.load('./data/model_files/pytorch_model.bin', map_location=torch.device('cpu'))
with dnnlib.util.open_url('./data/vase_model_files/network-snapshot-003800.pkl') as f:
    model = legacy.load_network_pkl(f)['G_ema'].to('cpu') # type: ignore

if st.session_state.space_id == 'Z':
    original_image_vec = annotations['z_vectors'][st.session_state.image_id]
else:
    original_image_vec = annotations['w_vectors'][st.session_state.image_id]

img = generate_original_image(original_image_vec, model, latent_space=st.session_state.space_id)

top_pred = ann_df.loc[st.session_state.image_id, labels].astype(float).idxmax()
# input_image = original_image_dict['image']
# input_label = original_image_dict['label']
# input_id = original_image_dict['id']

with smoothgrad_col_3:
    st.image(img)
    smooth_head_3.write(f'Base image, predicted as {top_pred}')


images, lambdas = regenerate_images(model, original_image_vec, separation_vector, min_epsilon=-(int(epsilon)), max_epsilon=int(epsilon), latent_space=st.session_state.space_id, layers=layers)

with smoothgrad_col_1:
    st.image(images[0])
    smooth_head_1.write(f'Change of {np.round(lambdas[0], 2)}')

with smoothgrad_col_2:
    st.image(images[1])
    smooth_head_2.write(f'Change of {np.round(lambdas[1], 2)}')

with smoothgrad_col_4:
    st.image(images[3])
    smooth_head_4.write(f'Change of {np.round(lambdas[3], 2)}')

with smoothgrad_col_5:
    st.image(images[4])
    smooth_head_5.write(f'Change of {np.round(lambdas[4], 2)}')
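The page above varies the chosen latent along the learned separation vector through regenerate_images, imported from backend.disentangle_concepts. Conceptually that variation is just w + lambda * v; a heavily simplified sketch of the idea, not the repo's implementation:

import numpy as np

def vary_along_direction(w, direction, min_epsilon=-5, max_epsilon=5, count=5):
    # Walk a latent along a unit-normalised concept direction.
    direction = direction / np.linalg.norm(direction)
    lambdas = np.linspace(min_epsilon, max_epsilon, count)
    return [w + lam * direction for lam in lambdas], lambdas

w = np.random.randn(512)      # a W-space latent, as stored in the annotations pickle
v = np.random.randn(512)      # stand-in for a separation vector from get_separation_space
variants, lambdas = vary_along_direction(w, v)
print(len(variants), np.round(lambdas, 2))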