ethanNeuralImage committed on
Commit 5238ef9
1 Parent(s): 47689a5

trying to get RIS working

app.py CHANGED
@@ -22,6 +22,10 @@ from argparse import Namespace
 
from mapper.styleclip_mapper import StyleCLIPMapper
 
+ import ris.spherical_kmeans as spherical_kmeans
+ from ris.blend import blend_latents
+ from ris.model import Generator as RIS_Generator
+
from PIL import Image
 
opts_args = ['--no_fine_mapper']
@@ -62,6 +66,10 @@ resize_amount = (256, 256) if hyperstyle_args.resize_outputs else (hyperstyle_ar
im2tensor_transforms = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
direction_calculator = load_direction_calculator(opts)
 
+ ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
+ ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location=lambda storage, loc: storage)
+ ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)
+
 
with gr.Blocks() as demo:
with gr.Row() as row:
@@ -70,6 +78,8 @@ with gr.Blocks() as demo:
align = gr.Checkbox(True, label='Align Image')
inverter_bools = gr.CheckboxGroup(["Hyperstyle", "E4E"], value=['Hyperstyle'], label='Inverter Choices')
n_hyperstyle_iterations = gr.Number(5, label='Number of Iterations For Hyperstyle', precision=0)
+ with gr.Box():
+ invert_bool = gr.Checkbox(False, label='Output Inverter Result')
with gr.Box():
mapper_bool = gr.Checkbox(True, label='Output Mapper Result')
with gr.Box() as mapper_opts:
@@ -82,14 +92,26 @@ with gr.Blocks() as demo:
target_text = gr.Text(value=mapper_descs['afro'], label='Target Text')
alpha = gr.Slider(minimum=-10.0, maximum=10.0, value=4.1, step=0.1, label="Alpha for Global Direction")
beta = gr.Slider(minimum=0.0, maximum=0.30, value=0.15, step=0.01, label="Beta for Global Direction")
+ with gr.Box():
+ ris_bool = gr.Checkbox(False, label='Output RIS Result')
+ with gr.Box(visible=False) as ris_opts:
+ ref_img = gr.Image(label='Refrence Image for Hair', type='filepath')
submit_button = gr.Button("Edit Image")
with gr.Column() as outputs:
with gr.Row() as hyperstyle_images:
+ output_hyperstyle_invert = gr.Image(type='pil', label="Hyperstyle Inverted", visible=False)
output_hyperstyle_mapper = gr.Image(type='pil', label="Hyperstyle Mapper")
output_hyperstyle_gd = gr.Image(type='pil', label="Hyperstyle Global Directions", visible=False)
+ output_hyperstyle_ris = gr.Image(type='pil', label='Hyperstyle RIS', visible=False)
+ with gr.Row() as hyperstyle_metrics:
+ output_hypersyle_metrics = gr.Text()
with gr.Row(visible=False) as e4e_images:
+ output_e4e_invert = gr.Image(type='pil', label="E4E Inverted", visible=False)
output_e4e_mapper = gr.Image(type='pil', label="E4E Mapper")
output_e4e_gd = gr.Image(type='pil', label="E4E Global Directions", visible=False)
+ output_e4e_ris = gr.Image(type='pil', label='E4E RIS', visible=False)
+ with gr.Row() as e4e_metrics:
+ output_e4e_metrics = gr.Text()
def n_iter_change(number):
if number < 0:
return 0
@@ -105,7 +127,11 @@ with gr.Blocks() as demo:
e4e_images: gr.update(visible=e4e_bool),
n_hyperstyle_iterations: gr.update(visible=hyperstyle_bool)
}
-
+ def outp_toggles(bool):
+ return {
+ output_hyperstyle_invert: gr.update(visible=bool),
+ output_e4e_invert: gr.update(visible=bool)
+ }
def mapper_toggles(bool):
return {
mapper_opts: gr.update(visible=bool),
@@ -118,12 +144,20 @@ with gr.Blocks() as demo:
output_hyperstyle_gd: gr.update(visible=bool),
output_e4e_gd: gr.update(visible=bool)
}
+ def ris_toggles(bool):
+ return {
+ ris_opts: gr.update(visible=bool),
+ output_hyperstyle_ris: gr.update(visible=bool),
+ output_e4e_ris: gr.update(visible=bool)
+ }
 
n_hyperstyle_iterations.change(n_iter_change, n_hyperstyle_iterations, n_hyperstyle_iterations)
mapper_choice.change(mapper_change, mapper_choice, [target_text])
inverter_bools.change(inverter_toggles, inverter_bools, [hyperstyle_images, e4e_images, n_hyperstyle_iterations])
+ invert_bool.change(outp_toggles, invert_bool, [output_hyperstyle_invert, output_e4e_invert])
mapper_bool.change(mapper_toggles, mapper_bool, [mapper_opts, output_hyperstyle_mapper, output_e4e_mapper])
gd_bool.change(gd_toggles, gd_bool, [gd_opts, output_hyperstyle_gd, output_e4e_gd])
+ ris_bool.change(ris_toggles, ris_bool, [ris_opts, output_hyperstyle_ris, output_e4e_ris])
def map_latent(mapper, inputs, stylespace=False, weight_deltas=None, strength=0.1):
w = inputs.to(device)
with torch.no_grad():
@@ -140,9 +174,10 @@ with gr.Blocks() as demo:
result_batch = (x_hat, w_hat)
return result_batch
def submit(
- src, align_img, inverter_bools, n_iterations,
+ src, align_img, inverter_bools, n_iterations, invert_bool,
mapper_bool, mapper_choice, mapper_alpha,
gd_bool, neutral_text, target_text, alpha, beta,
+ ris_bool, ref_img,
):
if device == 'cuda': torch.cuda.empty_cache()
opts.checkpoint_path = mapper_dict[mapper_choice]
@@ -166,9 +201,20 @@ with gr.Blocks() as demo:
opts.target_text = target_text
opts.alpha = alpha
opts.beta = beta
+
+ if ris_bool:
+ if align_img:
+ ref_input = align_face(ref_img, predictor)
+ else:
+ ref_input = Image.open(src).convert('RGB')
+ ref_input = im2tensor_transforms(ref_input).to(device)
 
if 'Hyperstyle' in inverter_bools:
hyperstyle_batch, hyperstyle_latents, hyperstyle_deltas, _ = run_inversion(input_img.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
+ if invert_bool:
+ invert_hyperstyle = tensor2im(hyperstyle_batch[0])
+ else:
+ invert_hyperstyle = None
if mapper_bool:
mapped_hyperstyle, _ = map_latent(mapper, hyperstyle_latents, stylespace=False, weight_deltas=hyperstyle_deltas, strength=mapper_alpha)
mapped_hyperstyle = tensor2im(mapped_hyperstyle[0])
@@ -181,13 +227,27 @@ with gr.Blocks() as demo:
else:
gd_hyperstyle = None
 
- hyperstyle_output = [mapped_hyperstyle,gd_hyperstyle]
+ if ris_bool:
+
+ ref_hyperstyle_batch, ref_hyperstyle_latents, ref_hyperstyle_deltas, _ = run_inversion(ref_input.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
+ blend_hyperstyle, blend_hyperstyle_latents = blend_latents(hyperstyle_latents, ref_hyperstyle_batch,
+ src_deltas=hyperstyle_deltas, ref_deltas=ref_hyperstyle_deltas,
+ generator=ris_gen, device=device)
+ ris_hyperstyle = tensor2im(blend_hyperstyle)
+ else:
+ ris_hyperstyle=None
+
+ hyperstyle_output = [invert_hyperstyle, mapped_hyperstyle,gd_hyperstyle, ris_hyperstyle]
else:
- hyperstyle_output = [None, None]
+ hyperstyle_output = [None, None, None, None]
output_imgs.extend(hyperstyle_output)
if 'E4E' in inverter_bools:
e4e_batch, e4e_latents = hyperstyle.w_invert(input_img.unsqueeze(0))
e4e_deltas = None
+ if invert_bool:
+ invert_e4e = tensor2im(e4e_batch[0])
+ else:
+ invert_e4e = None
if mapper_bool:
mapped_e4e, _ = map_latent(mapper, e4e_latents, stylespace=False, weight_deltas=e4e_deltas, strength=mapper_alpha)
mapped_e4e = tensor2im(mapped_e4e[0])
@@ -200,19 +260,31 @@ with gr.Blocks() as demo:
else:
gd_e4e = None
 
- e4e_output = [mapped_e4e, gd_e4e]
+ if ris_bool:
+ ref_e4e_batch, ref_e4e_latents, = hyperstyle.w_invert(ref_input.unsqueeze(0))
+ ref_e4e_deltas= None
+ blend_e4e, blend_e4e_latents = blend_latents(e4e_latents, ref_e4e_batch,
+ src_deltas=None, ref_deltas=None,
+ generator=ris_gen, device=device)
+ ris_e4e = tensor2im(blend_e4e)
+ else:
+ ris_e4e=None
+
+ e4e_output = [invert_e4e, mapped_e4e, gd_e4e, ris_e4e]
else:
- e4e_output = [None, None]
+ e4e_output = [None, None, None, None]
output_imgs.extend(e4e_output)
return output_imgs
submit_button.click(
submit,
[
- source, align, inverter_bools, n_hyperstyle_iterations,
+ source, align, inverter_bools, n_hyperstyle_iterations, invert_bool,
mapper_bool, mapper_choice, mapper_alpha,
gd_bool, neutral_text, target_text, alpha, beta,
+ ris_bool, ref_img
],
- [output_hyperstyle_mapper, output_hyperstyle_gd, output_e4e_mapper, output_e4e_gd]
+ [output_hyperstyle_invert, output_hyperstyle_mapper, output_hyperstyle_gd, output_hyperstyle_ris,
+ output_e4e_invert, output_e4e_mapper, output_e4e_gd, output_e4e_ris]
)
 
demo.launch()
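
Note: the RIS path added to submit() above reduces to inverting both the source and the hair-reference image, then handing the source latents and the reference inversion to blend_latents. Below is a minimal, hedged sketch of that flow as a standalone function; im2tensor_transforms, run_inversion, hyperstyle, hyperstyle_args, tensor2im and device are the names app.py already defines, and the file paths are assumptions taken from this commit.

# Hedged sketch of the RIS hair-transfer path wired into submit() above.
# im2tensor_transforms, run_inversion, hyperstyle, hyperstyle_args, tensor2im
# and device are assumed to be the helpers already defined in app.py.
import torch
from PIL import Image
from ris.blend import blend_latents
from ris.model import Generator as RIS_Generator

ris_gen = RIS_Generator(1024, 512, 8, channel_multiplier=2).to(device).eval()
ris_ckpt = torch.load('./pretrained_models/ris/stylegan2-ffhq-config-f.pt', map_location='cpu')
ris_gen.load_state_dict(ris_ckpt['g_ema'], strict=False)

def ris_edit(src_path, ref_path):
    src = im2tensor_transforms(Image.open(src_path).convert('RGB')).to(device)
    ref = im2tensor_transforms(Image.open(ref_path).convert('RGB')).to(device)
    # Hyperstyle inversion of both images: latents plus per-layer weight deltas.
    _, src_latents, src_deltas, _ = run_inversion(src.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
    ref_batch, ref_latents, ref_deltas, _ = run_inversion(ref.unsqueeze(0), hyperstyle, hyperstyle_args, return_intermediate_results=False)
    # Blend the reference hair into the source, mirroring the call in submit().
    blended, _ = blend_latents(src_latents, ref_batch, src_deltas=src_deltas, ref_deltas=ref_deltas, generator=ris_gen, device=device)
    return tensor2im(blended)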
pretrained_models/ris/catalog.pkl ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1835e4a20709c43ec1cfd47d17f55473a8dc14fa1b5418880bab04cb6b9f9b26
+ size 857089
pretrained_models/ris/stylegan2-ffhq-config-f.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bae494ef77e32a9cd1792a81a3c167692a0e64f6bcd8b06592ff42917e2ed46e
+ size 381462551
requirements.txt CHANGED
@@ -6,4 +6,6 @@ numpy
matplotlib
opencv-python
scipy
+ scikit-learn==0.22
+
git+https://github.com/openai/CLIP.git
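
Note: the exact scikit-learn==0.22 pin is presumably tied to pretrained_models/ris/catalog.pkl, since ris/blend.py unpickles a spherical k-means clusterer built on scikit-learn, and such pickles generally only load cleanly under the version they were created with. A small, hedged sanity check (path as added in this commit):

# Hedged sanity check: the pinned scikit-learn version and the pickled
# clusterer that ris/blend.py loads at import time. Unpickling requires the
# ris package (MiniBatchSphericalKMeans) to be importable.
import pickle
import sklearn

assert sklearn.__version__.startswith('0.22'), sklearn.__version__
with open('./pretrained_models/ris/catalog.pkl', 'rb') as f:
    clusterer = pickle.load(f)
print(type(clusterer))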
ris/__init__.py ADDED
File without changes
ris/blend.py ADDED
@@ -0,0 +1,131 @@
+ import imp
+ import torch
+ import pickle
+
+ from .util import *
+ from .spherical_kmeans import MiniBatchSphericalKMeans as sKmeans
+
+
+
+
+ truncation = 0.5
+ stop_idx = 11
+ n_clusters = 18
+
+ clusterer = pickle.load(open('./pretrained_models/ris/catalog.pkl', 'rb'))
+
+ labels2idx = {
+ 'nose': 0,
+ 'eyes': 1,
+ 'mouth': 2,
+ 'hair': 3,
+ 'background': 4,
+ 'cheek': 5,
+ 'neck': 6,
+ 'clothes': 7,
+ }
+
+ labels_map = {
+ 0: torch.tensor([7]),
+ 1: torch.tensor([1,6]),
+ 2: torch.tensor([4]),
+ 3: torch.tensor([0,3,5,8,10,15,16]),
+ 4: torch.tensor([11,13,14]),
+ 5: torch.tensor([9]),
+ 6: torch.tensor([17]),
+ 7: torch.tensor([2,12]),
+ }
+
+ lables2idx = dict((v,k) for k,v in labels2idx.items())
+ n_class = len(lables2idx)
+
+ segid_map = dict.fromkeys(labels_map[0].tolist(), 0)
+ segid_map.update(dict.fromkeys(labels_map[1].tolist(), 1))
+ segid_map.update(dict.fromkeys(labels_map[2].tolist(), 2))
+ segid_map.update(dict.fromkeys(labels_map[3].tolist(), 3))
+ segid_map.update(dict.fromkeys(labels_map[4].tolist(), 4))
+ segid_map.update(dict.fromkeys(labels_map[5].tolist(), 5))
+ segid_map.update(dict.fromkeys(labels_map[6].tolist(), 6))
+ segid_map.update(dict.fromkeys(labels_map[7].tolist(), 7))
+
+ torch.manual_seed(0)
+
+
+ # compute M given a style code.
+ @torch.no_grad()
+ def compute_M(w, generator, weights_deltas=None, device='cuda'):
+ M = []
+
+ # get segmentation
+ # _, outputs = generator(w, is_cluster=1)
+ _, outputs = generator(w, weights_deltas=weights_deltas)
+ cluster_layer = outputs[stop_idx][0]
+ activation = flatten_act(cluster_layer)
+ seg_mask = clusterer.predict(activation)
+ b,c,h,w = cluster_layer.size()
+
+ # create masks for each feature
+ all_seg_mask = []
+ seg_mask = torch.from_numpy(seg_mask).view(b,1,h,w,1).to(device)
+
+ for key in range(n_class):
+ # combine masks for all indices for a particular segmentation class
+ indices = labels_map[key].view(1,1,1,1,-1)
+ key_mask = (seg_mask == indices.to(device)).any(-1) #[b,1,h,w]
+ all_seg_mask.append(key_mask)
+
+ all_seg_mask = torch.stack(all_seg_mask, 1)
+
+ # go through each activation layer and compute M
+ for layer_idx in range(len(outputs)):
+ layer = outputs[layer_idx][1].to(device)
+ b,c,h,w = layer.size()
+ layer = F.instance_norm(layer)
+ layer = layer.pow(2)
+
+ # resize the segmentation masks to current activations' resolution
+ layer_seg_mask = F.interpolate(all_seg_mask.flatten(0,1).float(), align_corners=False,
+ size=(h,w), mode='bilinear').view(b,-1,1,h,w)
+
+ masked_layer = layer.unsqueeze(1) * layer_seg_mask # [b,k,c,h,w]
+ masked_layer = (masked_layer.sum([3,4])/ (h*w))#[b,k,c]
+
+ M.append(masked_layer.to(device))
+
+ M = torch.cat(M, -1) #[b, k, c]
+
+ # softmax to assign each channel to a particular segmentation class
+ M = F.softmax(M/.1, 1)
+ # simple thresholding
+ M = (M>.8).float()
+
+ # zero out torgb transfers, from https://arxiv.org/abs/2011.12799
+ for i in range(n_class):
+ part_M = style2list(M[:, i])
+ for j in range(len(part_M)):
+ if j in rgb_layer_idx:
+ part_M[j].zero_()
+ part_M = list2style(part_M)
+ M[:, i] = part_M
+
+ return M
+
+ def blend_latents (source_latent, ref_latent, generator, src_deltas=None, ref_deltas=None, device='cuda'):
+ source = generator.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
+ ref = generator.get_latent(ref_latent[0].unsqueeze(0), truncation=1, is_latent=True)
+ source_M = compute_M(source, generator, weights_deltas=src_deltas, device='cpu')
+ ref_M = compute_M(ref, generator, weights_deltas=ref_deltas, device='cpu')
+
+ blend_deltas = src_deltas
+
+ max_M = torch.max(source_M.expand_as(ref_M), ref_M)
+ max_M = add_pose(max_M, labels2idx)
+ idx = labels2idx['hair']
+
+ part_M = max_M[:, idx].to(device)
+ part_M_mask = style2list(part_M)
+
+ blend = style2list((add_direction(source, ref, part_M, 1.3)))
+ blend_out, _ = generator(blend, weights_deltas=blend_deltas)
+
+ return blend_out, blend
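
Note: compute_M above segments the generator's stop_idx feature maps with the pre-fit clusterer and returns a hard 0/1 tensor M of shape [batch, n_class, total_style_channels], one row per region in labels2idx; blend_latents then takes the element-wise maximum of the source and reference assignments and pushes the source style toward the reference only on the channels assigned to 'hair'. A hedged sketch of inspecting M directly; ris_gen is the generator loaded in app.py, and source_latent is assumed to be any inverted latent batch.

# Hedged sketch: the intermediate M that blend_latents builds internally.
# ris_gen is the RIS generator loaded in app.py; source_latent is assumed to
# be an inverted latent batch (e.g. from Hyperstyle or e4e).
from ris.blend import compute_M, labels2idx

source = ris_gen.get_latent(source_latent[0].unsqueeze(0), truncation=1, is_latent=True)
M = compute_M(source, ris_gen, weights_deltas=None, device='cpu')

hair_mask = M[:, labels2idx['hair']]  # 0/1 over all style channels
print(M.shape, int(hair_mask.sum()), 'style channels assigned to hair')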
ris/e4e_projection.py ADDED
@@ -0,0 +1,37 @@
+ import os
+ import sys
+ import numpy as np
+ from PIL import Image
+ import torch
+ import torchvision.transforms as transforms
+ from argparse import Namespace
+ from e4e.models.psp import pSp
+ from util import *
+
+
+ @ torch.no_grad()
+ def projection(img, name, generator, device='cuda'):
+ model_path = 'e4e_ffhq_encode.pt'
+ ensure_checkpoint_exists(model_path)
+ ckpt = torch.load(model_path, map_location='cpu')
+ opts = ckpt['opts']
+ opts['checkpoint_path'] = model_path
+ opts= Namespace(**opts)
+ net = pSp(opts, device).eval().to(device)
+
+ transform = transforms.Compose(
+ [
+ transforms.Resize(256),
+ transforms.CenterCrop(256),
+ transforms.ToTensor(),
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
+ ]
+ )
+
+ img = transform(img).unsqueeze(0).to(device)
+ images, w_plus = net(img, randomize_noise=False, return_latents=True)
+ result_file = {}
+ filename = './inversion_codes/' + name + '.pt'
+ result_file['latent'] = w_plus[0]
+ torch.save(result_file, filename)
+
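
Note: projection() encodes one PIL image with a pretrained e4e encoder and saves the W+ code to ./inversion_codes/<name>.pt; the generator argument is accepted but not used. A hedged usage sketch, assuming the image path, that e4e_ffhq_encode.pt is available where ensure_checkpoint_exists() looks for it, and that ./inversion_codes/ already exists:

# Hedged usage sketch for ris/e4e_projection.projection.
# The source image path, checkpoint location and output directory are assumptions.
import torch
from PIL import Image
from ris.e4e_projection import projection

img = Image.open('source.jpg').convert('RGB')
projection(img, 'source', generator=None, device='cuda')
w_plus = torch.load('./inversion_codes/source.pt')['latent']  # saved W+ code for one face
print(w_plus.shape)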
ris/legacy.py ADDED
@@ -0,0 +1,320 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ import click
10
+ import pickle
11
+ import re
12
+ import copy
13
+ import numpy as np
14
+ import torch
15
+ import dnnlib
16
+ from torch_utils import misc
17
+
18
+ #----------------------------------------------------------------------------
19
+
20
+ def load_network_pkl(f, force_fp16=False):
21
+ data = _LegacyUnpickler(f).load()
22
+
23
+ # Legacy TensorFlow pickle => convert.
24
+ if isinstance(data, tuple) and len(data) == 3 and all(isinstance(net, _TFNetworkStub) for net in data):
25
+ tf_G, tf_D, tf_Gs = data
26
+ G = convert_tf_generator(tf_G)
27
+ D = convert_tf_discriminator(tf_D)
28
+ G_ema = convert_tf_generator(tf_Gs)
29
+ data = dict(G=G, D=D, G_ema=G_ema)
30
+
31
+ # Add missing fields.
32
+ if 'training_set_kwargs' not in data:
33
+ data['training_set_kwargs'] = None
34
+ if 'augment_pipe' not in data:
35
+ data['augment_pipe'] = None
36
+
37
+ # Validate contents.
38
+ assert isinstance(data['G'], torch.nn.Module)
39
+ assert isinstance(data['D'], torch.nn.Module)
40
+ assert isinstance(data['G_ema'], torch.nn.Module)
41
+ assert isinstance(data['training_set_kwargs'], (dict, type(None)))
42
+ assert isinstance(data['augment_pipe'], (torch.nn.Module, type(None)))
43
+
44
+ # Force FP16.
45
+ if force_fp16:
46
+ for key in ['G', 'D', 'G_ema']:
47
+ old = data[key]
48
+ kwargs = copy.deepcopy(old.init_kwargs)
49
+ if key.startswith('G'):
50
+ kwargs.synthesis_kwargs = dnnlib.EasyDict(kwargs.get('synthesis_kwargs', {}))
51
+ kwargs.synthesis_kwargs.num_fp16_res = 4
52
+ kwargs.synthesis_kwargs.conv_clamp = 256
53
+ if key.startswith('D'):
54
+ kwargs.num_fp16_res = 4
55
+ kwargs.conv_clamp = 256
56
+ if kwargs != old.init_kwargs:
57
+ new = type(old)(**kwargs).eval().requires_grad_(False)
58
+ misc.copy_params_and_buffers(old, new, require_all=True)
59
+ data[key] = new
60
+ return data
61
+
62
+ #----------------------------------------------------------------------------
63
+
64
+ class _TFNetworkStub(dnnlib.EasyDict):
65
+ pass
66
+
67
+ class _LegacyUnpickler(pickle.Unpickler):
68
+ def find_class(self, module, name):
69
+ if module == 'dnnlib.tflib.network' and name == 'Network':
70
+ return _TFNetworkStub
71
+ return super().find_class(module, name)
72
+
73
+ #----------------------------------------------------------------------------
74
+
75
+ def _collect_tf_params(tf_net):
76
+ # pylint: disable=protected-access
77
+ tf_params = dict()
78
+ def recurse(prefix, tf_net):
79
+ for name, value in tf_net.variables:
80
+ tf_params[prefix + name] = value
81
+ for name, comp in tf_net.components.items():
82
+ recurse(prefix + name + '/', comp)
83
+ recurse('', tf_net)
84
+ return tf_params
85
+
86
+ #----------------------------------------------------------------------------
87
+
88
+ def _populate_module_params(module, *patterns):
89
+ for name, tensor in misc.named_params_and_buffers(module):
90
+ found = False
91
+ value = None
92
+ for pattern, value_fn in zip(patterns[0::2], patterns[1::2]):
93
+ match = re.fullmatch(pattern, name)
94
+ if match:
95
+ found = True
96
+ if value_fn is not None:
97
+ value = value_fn(*match.groups())
98
+ break
99
+ try:
100
+ assert found
101
+ if value is not None:
102
+ tensor.copy_(torch.from_numpy(np.array(value)))
103
+ except:
104
+ print(name, list(tensor.shape))
105
+ raise
106
+
107
+ #----------------------------------------------------------------------------
108
+
109
+ def convert_tf_generator(tf_G):
110
+ if tf_G.version < 4:
111
+ raise ValueError('TensorFlow pickle version too low')
112
+
113
+ # Collect kwargs.
114
+ tf_kwargs = tf_G.static_kwargs
115
+ known_kwargs = set()
116
+ def kwarg(tf_name, default=None, none=None):
117
+ known_kwargs.add(tf_name)
118
+ val = tf_kwargs.get(tf_name, default)
119
+ return val if val is not None else none
120
+
121
+ # Convert kwargs.
122
+ kwargs = dnnlib.EasyDict(
123
+ z_dim = kwarg('latent_size', 512),
124
+ c_dim = kwarg('label_size', 0),
125
+ w_dim = kwarg('dlatent_size', 512),
126
+ img_resolution = kwarg('resolution', 1024),
127
+ img_channels = kwarg('num_channels', 3),
128
+ mapping_kwargs = dnnlib.EasyDict(
129
+ num_layers = kwarg('mapping_layers', 8),
130
+ embed_features = kwarg('label_fmaps', None),
131
+ layer_features = kwarg('mapping_fmaps', None),
132
+ activation = kwarg('mapping_nonlinearity', 'lrelu'),
133
+ lr_multiplier = kwarg('mapping_lrmul', 0.01),
134
+ w_avg_beta = kwarg('w_avg_beta', 0.995, none=1),
135
+ ),
136
+ synthesis_kwargs = dnnlib.EasyDict(
137
+ channel_base = kwarg('fmap_base', 16384) * 2,
138
+ channel_max = kwarg('fmap_max', 512),
139
+ num_fp16_res = kwarg('num_fp16_res', 0),
140
+ conv_clamp = kwarg('conv_clamp', None),
141
+ architecture = kwarg('architecture', 'skip'),
142
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
143
+ use_noise = kwarg('use_noise', True),
144
+ activation = kwarg('nonlinearity', 'lrelu'),
145
+ ),
146
+ )
147
+
148
+ # Check for unknown kwargs.
149
+ kwarg('truncation_psi')
150
+ kwarg('truncation_cutoff')
151
+ kwarg('style_mixing_prob')
152
+ kwarg('structure')
153
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
154
+ if len(unknown_kwargs) > 0:
155
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
156
+
157
+ # Collect params.
158
+ tf_params = _collect_tf_params(tf_G)
159
+ for name, value in list(tf_params.items()):
160
+ match = re.fullmatch(r'ToRGB_lod(\d+)/(.*)', name)
161
+ if match:
162
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
163
+ tf_params[f'{r}x{r}/ToRGB/{match.group(2)}'] = value
164
+ kwargs.synthesis.kwargs.architecture = 'orig'
165
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
166
+
167
+ # Convert params.
168
+ from training import networks
169
+ G = networks.Generator(**kwargs).eval().requires_grad_(False)
170
+ # pylint: disable=unnecessary-lambda
171
+ _populate_module_params(G,
172
+ r'mapping\.w_avg', lambda: tf_params[f'dlatent_avg'],
173
+ r'mapping\.embed\.weight', lambda: tf_params[f'mapping/LabelEmbed/weight'].transpose(),
174
+ r'mapping\.embed\.bias', lambda: tf_params[f'mapping/LabelEmbed/bias'],
175
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'mapping/Dense{i}/weight'].transpose(),
176
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'mapping/Dense{i}/bias'],
177
+ r'synthesis\.b4\.const', lambda: tf_params[f'synthesis/4x4/Const/const'][0],
178
+ r'synthesis\.b4\.conv1\.weight', lambda: tf_params[f'synthesis/4x4/Conv/weight'].transpose(3, 2, 0, 1),
179
+ r'synthesis\.b4\.conv1\.bias', lambda: tf_params[f'synthesis/4x4/Conv/bias'],
180
+ r'synthesis\.b4\.conv1\.noise_const', lambda: tf_params[f'synthesis/noise0'][0, 0],
181
+ r'synthesis\.b4\.conv1\.noise_strength', lambda: tf_params[f'synthesis/4x4/Conv/noise_strength'],
182
+ r'synthesis\.b4\.conv1\.affine\.weight', lambda: tf_params[f'synthesis/4x4/Conv/mod_weight'].transpose(),
183
+ r'synthesis\.b4\.conv1\.affine\.bias', lambda: tf_params[f'synthesis/4x4/Conv/mod_bias'] + 1,
184
+ r'synthesis\.b(\d+)\.conv0\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
185
+ r'synthesis\.b(\d+)\.conv0\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/bias'],
186
+ r'synthesis\.b(\d+)\.conv0\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-5}'][0, 0],
187
+ r'synthesis\.b(\d+)\.conv0\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/noise_strength'],
188
+ r'synthesis\.b(\d+)\.conv0\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_weight'].transpose(),
189
+ r'synthesis\.b(\d+)\.conv0\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv0_up/mod_bias'] + 1,
190
+ r'synthesis\.b(\d+)\.conv1\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/weight'].transpose(3, 2, 0, 1),
191
+ r'synthesis\.b(\d+)\.conv1\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/bias'],
192
+ r'synthesis\.b(\d+)\.conv1\.noise_const', lambda r: tf_params[f'synthesis/noise{int(np.log2(int(r)))*2-4}'][0, 0],
193
+ r'synthesis\.b(\d+)\.conv1\.noise_strength', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/noise_strength'],
194
+ r'synthesis\.b(\d+)\.conv1\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_weight'].transpose(),
195
+ r'synthesis\.b(\d+)\.conv1\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/Conv1/mod_bias'] + 1,
196
+ r'synthesis\.b(\d+)\.torgb\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/weight'].transpose(3, 2, 0, 1),
197
+ r'synthesis\.b(\d+)\.torgb\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/bias'],
198
+ r'synthesis\.b(\d+)\.torgb\.affine\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_weight'].transpose(),
199
+ r'synthesis\.b(\d+)\.torgb\.affine\.bias', lambda r: tf_params[f'synthesis/{r}x{r}/ToRGB/mod_bias'] + 1,
200
+ r'synthesis\.b(\d+)\.skip\.weight', lambda r: tf_params[f'synthesis/{r}x{r}/Skip/weight'][::-1, ::-1].transpose(3, 2, 0, 1),
201
+ r'.*\.resample_filter', None,
202
+ )
203
+ return G
204
+
205
+ #----------------------------------------------------------------------------
206
+
207
+ def convert_tf_discriminator(tf_D):
208
+ if tf_D.version < 4:
209
+ raise ValueError('TensorFlow pickle version too low')
210
+
211
+ # Collect kwargs.
212
+ tf_kwargs = tf_D.static_kwargs
213
+ known_kwargs = set()
214
+ def kwarg(tf_name, default=None):
215
+ known_kwargs.add(tf_name)
216
+ return tf_kwargs.get(tf_name, default)
217
+
218
+ # Convert kwargs.
219
+ kwargs = dnnlib.EasyDict(
220
+ c_dim = kwarg('label_size', 0),
221
+ img_resolution = kwarg('resolution', 1024),
222
+ img_channels = kwarg('num_channels', 3),
223
+ architecture = kwarg('architecture', 'resnet'),
224
+ channel_base = kwarg('fmap_base', 16384) * 2,
225
+ channel_max = kwarg('fmap_max', 512),
226
+ num_fp16_res = kwarg('num_fp16_res', 0),
227
+ conv_clamp = kwarg('conv_clamp', None),
228
+ cmap_dim = kwarg('mapping_fmaps', None),
229
+ block_kwargs = dnnlib.EasyDict(
230
+ activation = kwarg('nonlinearity', 'lrelu'),
231
+ resample_filter = kwarg('resample_kernel', [1,3,3,1]),
232
+ freeze_layers = kwarg('freeze_layers', 0),
233
+ ),
234
+ mapping_kwargs = dnnlib.EasyDict(
235
+ num_layers = kwarg('mapping_layers', 0),
236
+ embed_features = kwarg('mapping_fmaps', None),
237
+ layer_features = kwarg('mapping_fmaps', None),
238
+ activation = kwarg('nonlinearity', 'lrelu'),
239
+ lr_multiplier = kwarg('mapping_lrmul', 0.1),
240
+ ),
241
+ epilogue_kwargs = dnnlib.EasyDict(
242
+ mbstd_group_size = kwarg('mbstd_group_size', None),
243
+ mbstd_num_channels = kwarg('mbstd_num_features', 1),
244
+ activation = kwarg('nonlinearity', 'lrelu'),
245
+ ),
246
+ )
247
+
248
+ # Check for unknown kwargs.
249
+ kwarg('structure')
250
+ unknown_kwargs = list(set(tf_kwargs.keys()) - known_kwargs)
251
+ if len(unknown_kwargs) > 0:
252
+ raise ValueError('Unknown TensorFlow kwarg', unknown_kwargs[0])
253
+
254
+ # Collect params.
255
+ tf_params = _collect_tf_params(tf_D)
256
+ for name, value in list(tf_params.items()):
257
+ match = re.fullmatch(r'FromRGB_lod(\d+)/(.*)', name)
258
+ if match:
259
+ r = kwargs.img_resolution // (2 ** int(match.group(1)))
260
+ tf_params[f'{r}x{r}/FromRGB/{match.group(2)}'] = value
261
+ kwargs.architecture = 'orig'
262
+ #for name, value in tf_params.items(): print(f'{name:<50s}{list(value.shape)}')
263
+
264
+ # Convert params.
265
+ from training import networks
266
+ D = networks.Discriminator(**kwargs).eval().requires_grad_(False)
267
+ # pylint: disable=unnecessary-lambda
268
+ _populate_module_params(D,
269
+ r'b(\d+)\.fromrgb\.weight', lambda r: tf_params[f'{r}x{r}/FromRGB/weight'].transpose(3, 2, 0, 1),
270
+ r'b(\d+)\.fromrgb\.bias', lambda r: tf_params[f'{r}x{r}/FromRGB/bias'],
271
+ r'b(\d+)\.conv(\d+)\.weight', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/weight'].transpose(3, 2, 0, 1),
272
+ r'b(\d+)\.conv(\d+)\.bias', lambda r, i: tf_params[f'{r}x{r}/Conv{i}{["","_down"][int(i)]}/bias'],
273
+ r'b(\d+)\.skip\.weight', lambda r: tf_params[f'{r}x{r}/Skip/weight'].transpose(3, 2, 0, 1),
274
+ r'mapping\.embed\.weight', lambda: tf_params[f'LabelEmbed/weight'].transpose(),
275
+ r'mapping\.embed\.bias', lambda: tf_params[f'LabelEmbed/bias'],
276
+ r'mapping\.fc(\d+)\.weight', lambda i: tf_params[f'Mapping{i}/weight'].transpose(),
277
+ r'mapping\.fc(\d+)\.bias', lambda i: tf_params[f'Mapping{i}/bias'],
278
+ r'b4\.conv\.weight', lambda: tf_params[f'4x4/Conv/weight'].transpose(3, 2, 0, 1),
279
+ r'b4\.conv\.bias', lambda: tf_params[f'4x4/Conv/bias'],
280
+ r'b4\.fc\.weight', lambda: tf_params[f'4x4/Dense0/weight'].transpose(),
281
+ r'b4\.fc\.bias', lambda: tf_params[f'4x4/Dense0/bias'],
282
+ r'b4\.out\.weight', lambda: tf_params[f'Output/weight'].transpose(),
283
+ r'b4\.out\.bias', lambda: tf_params[f'Output/bias'],
284
+ r'.*\.resample_filter', None,
285
+ )
286
+ return D
287
+
288
+ #----------------------------------------------------------------------------
289
+
290
+ @click.command()
291
+ @click.option('--source', help='Input pickle', required=True, metavar='PATH')
292
+ @click.option('--dest', help='Output pickle', required=True, metavar='PATH')
293
+ @click.option('--force-fp16', help='Force the networks to use FP16', type=bool, default=False, metavar='BOOL', show_default=True)
294
+ def convert_network_pickle(source, dest, force_fp16):
295
+ """Convert legacy network pickle into the native PyTorch format.
296
+
297
+ The tool is able to load the main network configurations exported using the TensorFlow version of StyleGAN2 or StyleGAN2-ADA.
298
+ It does not support e.g. StyleGAN2-ADA comparison methods, StyleGAN2 configs A-D, or StyleGAN1 networks.
299
+
300
+ Example:
301
+
302
+ \b
303
+ python legacy.py \\
304
+ --source=https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/stylegan2-cat-config-f.pkl \\
305
+ --dest=stylegan2-cat-config-f.pkl
306
+ """
307
+ print(f'Loading "{source}"...')
308
+ with dnnlib.util.open_url(source) as f:
309
+ data = load_network_pkl(f, force_fp16=force_fp16)
310
+ print(f'Saving "{dest}"...')
311
+ with open(dest, 'wb') as f:
312
+ pickle.dump(data, f)
313
+ print('Done.')
314
+
315
+ #----------------------------------------------------------------------------
316
+
317
+ if __name__ == "__main__":
318
+ convert_network_pickle() # pylint: disable=no-value-for-parameter
319
+
320
+ #----------------------------------------------------------------------------
ris/manipulator.py ADDED
@@ -0,0 +1,311 @@
1
+ import argparse
2
+ import copy
3
+ import os
4
+ import time
5
+ from tqdm import tqdm
6
+
7
+ import numpy as np
8
+ import PIL.Image
9
+ import torch
10
+
11
+ import clip
12
+ from wrapper import (FaceLandmarksDetector, Generator_wrapper,
13
+ VGGFeatExtractor, e4eEncoder, PivotTuning)
14
+ from projector import project
15
+
16
+ class Manipulator():
17
+ """Manipulator for style editing
18
+
19
+ in paper, use 100 image pairs to estimate the mean for alpha(magnitude of the perturbation) [-5, 5]
20
+
21
+ *** Args ***
22
+ G : Genertor wrapper for synthesis styles
23
+ device : torch.device
24
+ lst_alpha : magnitude of the perturbation
25
+ num_images : num images to process
26
+
27
+ *** Attributes ***
28
+ S : List[dict(str, torch.Tensor)] # length 2,000
29
+ styles : List[dict(str, torch.Tensor)] # length of num_images
30
+ (num_images, style)
31
+ lst_alpha : List[int]
32
+ boundary : (num_images, len_alpha)
33
+ edited_styles : List[styles]
34
+ edited_images : List[(num_images, 3, 1024, 1024)]
35
+ """
36
+ def __init__(
37
+ self,
38
+ G,
39
+ device,
40
+ lst_alpha=[0],
41
+ num_images=1,
42
+ start_ind=0,
43
+ face_preprocess=True,
44
+ dataset_name=''
45
+ ):
46
+ """Initialize
47
+ - use pre-saved generated latent/style from random Z
48
+ - to use projection, used method "set_real_img_projection"
49
+ """
50
+ assert start_ind + num_images < 2000
51
+ self.W = torch.load(f'tensor/W{dataset_name}.pt')
52
+ self.S = torch.load(f'tensor/S{dataset_name}.pt')
53
+ self.S_mean = torch.load(f'tensor/S_mean{dataset_name}.pt')
54
+ self.S_std = torch.load(f'tensor/S_std{dataset_name}.pt')
55
+
56
+ self.S = {layer: self.S[layer].to(device) for layer in G.style_layers}
57
+ self.styles = {layer: self.S[layer][start_ind:start_ind+num_images] for layer in G.style_layers}
58
+ self.latent = self.W[start_ind:start_ind+num_images]
59
+ self.latent = self.latent.to(device)
60
+ del self.W
61
+ del self.S
62
+
63
+ # S_mean, S_std for extracting global style direction
64
+ self.S_mean = {layer: self.S_mean[layer].to(device) for layer in G.style_layers}
65
+ self.S_std = {layer: self.S_std[layer].to(device) for layer in G.style_layers}
66
+
67
+ # setting
68
+ self.face_preprocess = face_preprocess
69
+ if face_preprocess:
70
+ self.landmarks_detector = FaceLandmarksDetector()
71
+ self.vgg16 = VGGFeatExtractor(device).module
72
+ self.W_projector_steps = 200
73
+ self.G = G
74
+ self.device = device
75
+ self.num_images = num_images
76
+ self.lst_alpha = lst_alpha
77
+ self.manipulate_layers = [layer for layer in G.style_layers if 'torgb' not in layer]
78
+
79
+ def set_alpha(self, lst_alpha):
80
+ """Setter for alpha
81
+ """
82
+ self.lst_alpha = lst_alpha
83
+
84
+ def set_real_img_projection(self, img, inv_mode='w', pti_mode=None):
85
+ """Set real img instead of pre-saved styles
86
+ Args :
87
+ - img : img directory or img file path to manipulate
88
+ - face aligned if self.face_preprocess == True
89
+ - set self.num_images
90
+ - inv_mode : inversion mode, setting self.latent, self.styles
91
+ - w : use W projector (projector.project)
92
+ - w+ : use e4e encoder (wrapper.e4eEncoder)
93
+ - pti_mode : pivot tuning inversion mode (wrapper.PivotTuning)
94
+ - None
95
+ - w : W latent pivot tuning
96
+ - s : S style pivot tuning
97
+ """
98
+ assert inv_mode in ['w', 'w+']
99
+ assert pti_mode in [None, 'w', 's']
100
+ allowed_extensions = ['jpg', 'JPG', 'jpeg', 'JPEG', 'png', 'PNG']
101
+
102
+ # img directory input
103
+ if os.path.isdir(img):
104
+ imgpaths = sorted(os.listdir(img))
105
+ imgpaths = [os.path.join(img, imgpath)
106
+ for imgpath in imgpaths
107
+ if imgpath.split('.')[-1] in allowed_extensions]
108
+ # img file path input
109
+ else:
110
+ imgpaths = [img]
111
+
112
+ self.num_images = len(imgpaths)
113
+ if inv_mode == 'w':
114
+ targets = list()
115
+ target_pils = list()
116
+ for imgpath in imgpaths:
117
+ if self.face_preprocess:
118
+ target_pil = self.landmarks_detector(imgpath)
119
+ else:
120
+ target_pil = PIL.Image.open(imgpath).convert('RGB')
121
+ target_pils.append(target_pil)
122
+ w, h = target_pil.size
123
+ s = min(w, h)
124
+ target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
125
+ target_pil = target_pil.resize((self.G.G.img_resolution, self.G.G.img_resolution),
126
+ PIL.Image.LANCZOS)
127
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
128
+ targets.append(torch.Tensor(target_uint8.transpose([2,0,1])).to(self.device))
129
+
130
+ self.latent = list()
131
+ for target in tqdm(targets, total=len(targets)):
132
+ projected_w_steps = project(
133
+ self.G.G,
134
+ self.vgg16,
135
+ target=target,
136
+ num_steps=self.W_projector_steps, # TODO get projector steps from configs
137
+ device=self.device,
138
+ verbose=False,
139
+ )
140
+ self.latent.append(projected_w_steps[-1])
141
+ self.latent = torch.stack(self.latent)
142
+ self.styles = self.G.mapping_stylespace(self.latent)
143
+
144
+ else: # inv_mode == 'w+'
145
+ # use e4e encoder
146
+ target_pils = list()
147
+ for imgpath in imgpaths:
148
+ if self.face_preprocess:
149
+ target_pil = self.landmarks_detector(imgpath)
150
+ else:
151
+ target_pil = PIL.Image.open(imgpath).convert('RGB')
152
+ target_pils.append(target_pil)
153
+
154
+ self.encoder = e4eEncoder(self.device)
155
+ self.latent = self.encoder(target_pils)
156
+ self.styles = self.G.mapping_stylespace(self.latent)
157
+
158
+ if pti_mode is not None: # w or s
159
+ # pivot tuning inversion
160
+ pti = PivotTuning(self.device, self.G.G, mode=pti_mode)
161
+ new_G = pti(self.latent, target_pils)
162
+ self.G.G = new_G
163
+
164
+ def manipulate(self, delta_s):
165
+ """Edit style by given delta_style
166
+ - use perturbation (delta s) * (alpha) as a boundary
167
+ """
168
+ styles = [copy.deepcopy(self.styles) for _ in range(len(self.lst_alpha))]
169
+
170
+ for (alpha, style) in zip(self.lst_alpha, styles):
171
+ for layer in self.G.style_layers:
172
+ perturbation = delta_s[layer] * alpha
173
+ style[layer] += perturbation
174
+ return styles
175
+
176
+ def manipulate_one_channel(self, layer, channel_ind:int):
177
+ """Edit style from given layer, channel index
178
+ - use mean value of pre-saved style
179
+ - use perturbation (pre-saved style std) * (alpha) as a boundary
180
+ """
181
+ assert layer in self.G.style_layers
182
+ assert 0 <= channel_ind < self.styles[layer].shape[1]
183
+ boundary = self.S_std[layer][channel_ind].item()
184
+ # apply self.S_mean value for given layer, channel_ind
185
+ for ind in range(self.num_images):
186
+ self.styles[layer][ind][channel_ind] = self.S_mean[layer][channel_ind]
187
+ styles = [copy.deepcopy(self.styles) for _ in range(len(self.lst_alpha))]
188
+
189
+ perturbation = (torch.Tensor(self.lst_alpha) * boundary).numpy().tolist()
190
+
191
+ # apply one channel manipulation
192
+ for img_ind in range(self.num_images):
193
+ for edit_ind, delta in enumerate(perturbation):
194
+ styles[edit_ind][layer][img_ind][channel_ind] += delta
195
+
196
+ return styles
197
+
198
+ def synthesis_from_styles(self, styles, start_ind, end_ind):
199
+ """Synthesis edited styles from styles, lst_alpha
200
+ """
201
+ styles_ = list()
202
+ for style in styles:
203
+ style_ = dict()
204
+ for layer in self.G.style_layers:
205
+ style_[layer] = style[layer][start_ind:end_ind].to(self.device)
206
+ styles_.append(style_)
207
+ print("synthesis_from_styles", type(style_))
208
+ imgs = [self.G.synthesis_from_stylespace(self.latent[start_ind:end_ind], style_).cpu()
209
+ for style_ in styles_]
210
+ return imgs
211
+
212
+
213
+ def extract_global_direction(G, device, lst_alpha, num_images, dataset_name=''):
214
+ """Extract global style direction in 100 images
215
+ """
216
+ assert len(lst_alpha) == 2
217
+ model, preprocess = clip.load("ViT-B/32", device=device)
218
+
219
+ # lindex in original tf version
220
+ manipulate_layers = [layer for layer in G.style_layers if 'torgb' not in layer]
221
+
222
+ # total channel: 6048 (1024 resolution)
223
+ resolution = G.G.img_resolution
224
+ latent = torch.randn([1,G.to_w_idx[f'G.synthesis.b{resolution}.torgb.affine']+1,512]).to(device) # 1024 -> 18, 512 -> 16, 256 -> 14
225
+ style = G.mapping_stylespace(latent)
226
+ cnt = 0
227
+ for layer in manipulate_layers:
228
+ cnt += style[layer].shape[1]
229
+ del latent
230
+ del style
231
+
232
+ # 1024 -> 6048 channels, 256 -> 4928 channels
233
+ print(f"total channels to manipulate: {cnt}")
234
+
235
+ manipulator = Manipulator(G, device, lst_alpha, num_images, face_preprocess=False, dataset_name=dataset_name)
236
+
237
+ all_feats = list()
238
+
239
+ for layer in manipulate_layers:
240
+ print(f'\nStyle manipulation in layer "{layer}"')
241
+ channel_num = manipulator.styles[layer].shape[1]
242
+
243
+ for channel_ind in tqdm(range(channel_num), total=channel_num):
244
+ styles = manipulator.manipulate_one_channel(layer, channel_ind)
245
+ # 2 * 100 images
246
+ batchsize = 10
247
+ nbatch = int(100 / batchsize)
248
+ feats = list()
249
+ for img_ind in range(0, nbatch): # batch size 10 * 2
250
+ start = img_ind*nbatch
251
+ end = img_ind*nbatch + batchsize
252
+ synth_imgs = manipulator.synthesis_from_styles(styles, start, end)
253
+ synth_imgs = [(synth_img.permute(0,2,3,1)*127.5+128).clamp(0,255).to(torch.uint8).numpy()
254
+ for synth_img in synth_imgs]
255
+ imgs = list()
256
+ for i in range(batchsize):
257
+ img0 = PIL.Image.fromarray(synth_imgs[0][i])
258
+ img1 = PIL.Image.fromarray(synth_imgs[1][i])
259
+ imgs.append(preprocess(img0).unsqueeze(0).to(device))
260
+ imgs.append(preprocess(img1).unsqueeze(0).to(device))
261
+ with torch.no_grad():
262
+ feat = model.encode_image(torch.cat(imgs))
263
+ feats.append(feat)
264
+ all_feats.append(torch.cat(feats).view([-1, 2, 512]).cpu())
265
+
266
+ all_feats = torch.stack(all_feats).numpy()
267
+
268
+ fs = all_feats
269
+ fs1=fs/np.linalg.norm(fs,axis=-1)[:,:,:,None]
270
+ fs2=fs1[:,:,1,:]-fs1[:,:,0,:] # 5*sigma - (-5)*sigma
271
+ fs3=fs2/np.linalg.norm(fs2,axis=-1)[:,:,None]
272
+ fs3=fs3.mean(axis=1)
273
+ fs3=fs3/np.linalg.norm(fs3,axis=-1)[:,None]
274
+
275
+ np.save(f'tensor/fs3{dataset_name}.npy', fs3) # global style direction
276
+
277
+
278
+ if __name__ == '__main__':
279
+ parser = argparse.ArgumentParser()
280
+
281
+ parser.add_argument('runtype', type=str, default='test')
282
+ parser.add_argument('--ckpt', type=str, default='pretrained/ffhq.pkl')
283
+ parser.add_argument('--face_preprocess', type=bool, default=True)
284
+ parser.add_argument('--dataset_name', type=str, default='')
285
+ args = parser.parse_args()
286
+
287
+ runtype = args.runtype
288
+ assert runtype in ['test', 'extract']
289
+
290
+ device = torch.device('cuda:0')
291
+ ckpt = args.ckpt
292
+ G = Generator(ckpt, device)
293
+
294
+ face_preprocess = args.face_preprocess
295
+ dataset_name = args.dataset_name
296
+
297
+ if runtype == 'test': # test manipulator
298
+ num_images = 100
299
+ lst_alpha = [-5, 0, 5]
300
+ layer = G.style_layers[6]
301
+ channel_ind = 501
302
+ manipulator = Manipulator(G, device, lst_alpha, num_images, face_preprocess=face_preprocess, dataset_name=dataset_name)
303
+ styles = manipulator.manipulate_one_channel(layer, channel_ind)
304
+ start_ind, end_ind= 0, 10
305
+ imgs = manipulator.synthesis_from_styles(styles, start_ind, end_ind)
306
+ print(len(imgs), imgs[0].shape)
307
+
308
+ elif runtype == 'extract': # extract global style direction from "tensor/S.pt"
309
+ num_images = 100
310
+ lst_alpha = [-5, 5]
311
+ extract_global_direction(G, device, lst_alpha, num_images, dataset_name=dataset_name)
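
Note: manipulator.py is a StyleCLIP-style global-direction toolkit: Manipulator.manipulate adds delta_s[layer] * alpha to every style layer for each alpha in lst_alpha, manipulate_one_channel perturbs a single style channel by its pre-saved standard deviation, and extract_global_direction builds the per-channel CLIP direction file fs3 from 100 image pairs. A hedged sketch of the single-channel edit loop, mirroring the 'test' branch above; G, the tensor/*.pt dumps and the wrapper/projector modules this file imports are assumptions.

# Hedged sketch of a single-channel style edit with Manipulator.
# G is the generator wrapper this module expects; the pre-saved tensor/*.pt
# files must exist. See the __main__ block above for the original usage.
import torch
from ris.manipulator import Manipulator

device = torch.device('cuda:0')
manip = Manipulator(G, device, lst_alpha=[-5, 0, 5], num_images=4, face_preprocess=False)
styles = manip.manipulate_one_channel(G.style_layers[6], channel_ind=501)
imgs = manip.synthesis_from_styles(styles, 0, 4)  # one image batch per alpha
print(len(imgs), imgs[0].shape)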
ris/model.py ADDED
@@ -0,0 +1,786 @@
1
+ import math
2
+ import random
3
+ import functools
4
+ import operator
5
+
6
+ import torch
7
+ import torchvision
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+ from torch.autograd import Function
11
+
12
+ from .op import FusedLeakyReLU, fused_leaky_relu, upfirdn2d
13
+
14
+
15
+ class PixelNorm(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+
19
+ def forward(self, input):
20
+ return input * torch.rsqrt(torch.mean(input ** 2, dim=1, keepdim=True) + 1e-8)
21
+
22
+ class To4d(nn.Module):
23
+ def __init__(self):
24
+ super().__init__()
25
+
26
+ def forward(self, input):
27
+ return input.view(*input.size(),1,1)
28
+
29
+ def make_kernel(k):
30
+ k = torch.tensor(k, dtype=torch.float32)
31
+
32
+ if k.ndim == 1:
33
+ k = k[None, :] * k[:, None]
34
+
35
+ k /= k.sum()
36
+
37
+ return k
38
+
39
+
40
+ class Upsample(nn.Module):
41
+ def __init__(self, kernel, factor=2):
42
+ super().__init__()
43
+
44
+ self.factor = factor
45
+ kernel = make_kernel(kernel) * (factor ** 2)
46
+ self.register_buffer('kernel', kernel)
47
+
48
+ p = kernel.shape[0] - factor
49
+
50
+ pad0 = (p + 1) // 2 + factor - 1
51
+ pad1 = p // 2
52
+
53
+ self.pad = (pad0, pad1)
54
+
55
+ def forward(self, input):
56
+ out = upfirdn2d(input, self.kernel, up=self.factor, down=1, pad=self.pad)
57
+
58
+ return out
59
+
60
+
61
+ class Downsample(nn.Module):
62
+ def __init__(self, kernel, factor=2):
63
+ super().__init__()
64
+
65
+ self.factor = factor
66
+ kernel = make_kernel(kernel)
67
+ self.register_buffer('kernel', kernel)
68
+
69
+ p = kernel.shape[0] - factor
70
+
71
+ pad0 = (p + 1) // 2
72
+ pad1 = p // 2
73
+
74
+ self.pad = (pad0, pad1)
75
+
76
+ def forward(self, input):
77
+ out = upfirdn2d(input, self.kernel, up=1, down=self.factor, pad=self.pad)
78
+
79
+ return out
80
+
81
+
82
+ class Blur(nn.Module):
83
+ def __init__(self, kernel, pad, upsample_factor=1):
84
+ super().__init__()
85
+
86
+ kernel = make_kernel(kernel)
87
+
88
+ if upsample_factor > 1:
89
+ kernel = kernel * (upsample_factor ** 2)
90
+
91
+ self.register_buffer('kernel', kernel)
92
+
93
+ self.pad = pad
94
+
95
+ def forward(self, input):
96
+ out = upfirdn2d(input, self.kernel, pad=self.pad)
97
+
98
+ return out
99
+
100
+
101
+ class EqualConv2d(nn.Module):
102
+ def __init__(
103
+ self, in_channel, out_channel, kernel_size, groups=1, stride=1, padding=0, bias=True, lr_mul=1
104
+ ):
105
+ super().__init__()
106
+
107
+ self.weight = nn.Parameter(
108
+ torch.randn(out_channel, in_channel//groups, kernel_size, kernel_size).div_(lr_mul)
109
+ )
110
+ self.scale = lr_mul / math.sqrt((in_channel//groups) * kernel_size ** 2)
111
+
112
+ self.stride = stride
113
+ self.padding = padding
114
+ self.groups = groups
115
+ self.lr_mul =lr_mul
116
+
117
+ if bias:
118
+ self.bias = nn.Parameter(torch.zeros(out_channel))
119
+
120
+ else:
121
+ self.bias = None
122
+
123
+ def forward(self, input):
124
+ bias = self.bias * self.lr_mul if self.bias is not None else None
125
+ out = F.conv2d(
126
+ input,
127
+ self.weight * self.scale,
128
+ bias=self.bias,
129
+ stride=self.stride,
130
+ padding=self.padding,
131
+ groups=self.groups
132
+ )
133
+
134
+ return out
135
+
136
+ def __repr__(self):
137
+ return (
138
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]},'
139
+ f' {self.weight.shape[2]}, stride={self.stride}, padding={self.padding})'
140
+ )
141
+
142
+
143
+ class EqualLinear(nn.Module):
144
+ def __init__(
145
+ self, in_dim, out_dim, bias=True, bias_init=0, lr_mul=1, activation=None
146
+ ):
147
+ super().__init__()
148
+
149
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim).div_(lr_mul))
150
+
151
+ if bias:
152
+ self.bias = nn.Parameter(torch.zeros(out_dim).fill_(bias_init))
153
+
154
+ else:
155
+ self.bias = None
156
+
157
+ self.activation = activation
158
+
159
+ self.scale = (1 / math.sqrt(in_dim)) * lr_mul
160
+ self.lr_mul = lr_mul
161
+
162
+ def forward(self, input):
163
+ if self.activation:
164
+ out = F.linear(input, self.weight * self.scale)
165
+ out = fused_leaky_relu(out, self.bias * self.lr_mul)
166
+
167
+ else:
168
+ out = F.linear(
169
+ input, self.weight * self.scale, bias=self.bias * self.lr_mul
170
+ )
171
+
172
+ return out
173
+
174
+ def __repr__(self):
175
+ return (
176
+ f'{self.__class__.__name__}({self.weight.shape[1]}, {self.weight.shape[0]})'
177
+ )
178
+
179
+
180
+ class ScaledLeakyReLU(nn.Module):
181
+ def __init__(self, negative_slope=0.2):
182
+ super().__init__()
183
+
184
+ self.negative_slope = negative_slope
185
+
186
+ def forward(self, input):
187
+ out = F.leaky_relu(input, negative_slope=self.negative_slope)
188
+
189
+ return out * math.sqrt(2)
190
+
191
+
192
+ class ModulatedConv2d(nn.Module):
193
+ def __init__(
194
+ self,
195
+ in_channel,
196
+ out_channel,
197
+ kernel_size,
198
+ style_dim,
199
+ demodulate=True,
200
+ upsample=False,
201
+ downsample=False,
202
+ blur_kernel=[1, 3, 3, 1],
203
+ ):
204
+ super().__init__()
205
+
206
+ self.eps = 1e-8
207
+ self.kernel_size = kernel_size
208
+ self.in_channel = in_channel
209
+ self.out_channel = out_channel
210
+ self.upsample = upsample
211
+ self.downsample = downsample
212
+
213
+ if upsample:
214
+ factor = 2
215
+ p = (len(blur_kernel) - factor) - (kernel_size - 1)
216
+ pad0 = (p + 1) // 2 + factor - 1
217
+ pad1 = p // 2 + 1
218
+
219
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1), upsample_factor=factor)
220
+
221
+ if downsample:
222
+ factor = 2
223
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
224
+ pad0 = (p + 1) // 2
225
+ pad1 = p // 2
226
+
227
+ self.blur = Blur(blur_kernel, pad=(pad0, pad1))
228
+
229
+ fan_in = in_channel * kernel_size ** 2
230
+ self.scale = 1 / math.sqrt(fan_in)
231
+ self.padding = kernel_size // 2
232
+
233
+ self.weight = nn.Parameter(
234
+ torch.randn(1, out_channel, in_channel, kernel_size, kernel_size)
235
+ )
236
+
237
+ self.modulation = EqualLinear(style_dim, in_channel, bias_init=1)
238
+
239
+ self.demodulate = demodulate
240
+
241
+ def __repr__(self):
242
+ return (
243
+ f'{self.__class__.__name__}({self.in_channel}, {self.out_channel}, {self.kernel_size}, '
244
+ f'upsample={self.upsample}, downsample={self.downsample})'
245
+ )
246
+ def get_latent(self, style):
247
+ style = self.modulation(style)
248
+ return style
249
+
250
+ def forward(self, input, style, weights_delta=None):
251
+ batch, in_channel, height, width = input.shape
252
+
253
+ # style = self.modulation(style).view(batch, 1, in_channel, 1, 1)
254
+ style = style.view(batch, 1, in_channel, 1, 1)
255
+
256
+
257
+ if weights_delta is None:
258
+ weight = self.scale * self.weight * style
259
+ else:
260
+ weight = self.scale * (self.weight * (1 + weights_delta) * style)
261
+
262
+
263
+ if self.demodulate:
264
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + 1e-8)
265
+ weight = weight * demod.view(batch, self.out_channel, 1, 1, 1)
266
+
267
+ weight = weight.view(
268
+ batch * self.out_channel, in_channel, self.kernel_size, self.kernel_size
269
+ )
270
+
271
+ if self.upsample:
272
+ input = input.view(1, batch * in_channel, height, width)
273
+ weight = weight.view(
274
+ batch, self.out_channel, in_channel, self.kernel_size, self.kernel_size
275
+ )
276
+ weight = weight.transpose(1, 2).reshape(
277
+ batch * in_channel, self.out_channel, self.kernel_size, self.kernel_size
278
+ )
279
+ out = F.conv_transpose2d(input, weight, padding=0, stride=2, groups=batch)
280
+ _, _, height, width = out.shape
281
+ out = out.view(batch, self.out_channel, height, width)
282
+ out = self.blur(out)
283
+
284
+ elif self.downsample:
285
+ input = self.blur(input)
286
+ _, _, height, width = input.shape
287
+ input = input.view(1, batch * in_channel, height, width)
288
+ out = F.conv2d(input, weight, padding=0, stride=2, groups=batch)
289
+ _, _, height, width = out.shape
290
+ out = out.view(batch, self.out_channel, height, width)
291
+
292
+ else:
293
+ input = input.view(1, batch * in_channel, height, width)
294
+ out = F.conv2d(input, weight, padding=self.padding, groups=batch)
295
+ _, _, height, width = out.shape
296
+ out = out.view(batch, self.out_channel, height, width)
297
+
298
+ return out
299
+
300
+
301
+ class NoiseInjection(nn.Module):
302
+ def __init__(self):
303
+ super().__init__()
304
+
305
+ self.weight = nn.Parameter(torch.zeros(1))
306
+
307
+ def forward(self, image, noise=None):
308
+ if noise is None:
309
+ batch, _, height, width = image.shape
310
+ noise = image.new_empty(batch, 1, height, width).normal_()
311
+
312
+ return image + self.weight * noise
313
+
314
+
315
+ class ConstantInput(nn.Module):
316
+ def __init__(self, channel, size=4):
317
+ super().__init__()
318
+
319
+ self.input = nn.Parameter(torch.randn(1, channel, size, size))
320
+
321
+ def forward(self, input):
322
+ batch = input.shape[0]
323
+ out = self.input.repeat(batch, 1, 1, 1)
324
+
325
+ return out
326
+
327
+
328
+ class StyledConv(nn.Module):
329
+ def __init__(
330
+ self,
331
+ in_channel,
332
+ out_channel,
333
+ kernel_size,
334
+ style_dim,
335
+ upsample=False,
336
+ blur_kernel=[1, 3, 3, 1],
337
+ demodulate=True,
338
+ ):
339
+ super().__init__()
340
+
341
+ self.conv = ModulatedConv2d(
342
+ in_channel,
343
+ out_channel,
344
+ kernel_size,
345
+ style_dim,
346
+ upsample=upsample,
347
+ blur_kernel=blur_kernel,
348
+ demodulate=demodulate,
349
+ )
350
+
351
+ self.noise = NoiseInjection()
352
+ # self.bias = nn.Parameter(torch.zeros(1, out_channel, 1, 1))
353
+ # self.activate = ScaledLeakyReLU(0.2)
354
+ self.activate = FusedLeakyReLU(out_channel)
355
+
356
+ def get_latent(self, style):
357
+ return self.conv.get_latent(style)
358
+ def forward(self, input, style, noise=None, weights_delta=None):
359
+ out_t = self.conv(input, style, weights_delta=weights_delta)
360
+ out = self.noise(out_t, noise=noise)
361
+ # out = out + self.bias
362
+ out = self.activate(out)
363
+
364
+ return out, out_t
365
+
366
+
367
+ class ToRGB(nn.Module):
368
+ def __init__(self, in_channel, style_dim, upsample=True, blur_kernel=[1, 3, 3, 1]):
369
+ super().__init__()
370
+
371
+ if upsample:
372
+ self.upsample = Upsample(blur_kernel)
373
+
374
+ self.conv = ModulatedConv2d(in_channel, 3, 1, style_dim, demodulate=False)
375
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
376
+
377
+ def get_latent(self, style):
378
+ return self.conv.get_latent(style)
379
+ def forward(self, input, style, skip=None, weights_delta=None):
380
+ out = self.conv(input, style, weights_delta)
381
+ out = out + self.bias
382
+
383
+ if skip is not None:
384
+ skip = self.upsample(skip)
385
+
386
+ out = out + skip
387
+
388
+ return out
389
+
390
+
391
+ class Generator(nn.Module):
392
+ def __init__(
393
+ self,
394
+ size,
395
+ style_dim,
396
+ n_mlp,
397
+ channel_multiplier=2,
398
+ blur_kernel=[1, 3, 3, 1],
399
+ lr_mlp=0.01,
400
+ ):
401
+ super().__init__()
402
+
403
+ self.size = size
404
+
405
+ self.style_dim = style_dim
406
+
407
+ layers = [PixelNorm()]
408
+
409
+ for i in range(n_mlp):
410
+ layers.append(
411
+ EqualLinear(
412
+ style_dim, style_dim, lr_mul=lr_mlp, activation='fused_lrelu'
413
+ )
414
+ )
415
+
416
+ self.style = nn.Sequential(*layers)
417
+
418
+ self.channels = {
419
+ 4: 512,
420
+ 8: 512,
421
+ 16: 512,
422
+ 32: 512,
423
+ 64: 256 * channel_multiplier,
424
+ 128: 128 * channel_multiplier,
425
+ 256: 64 * channel_multiplier,
426
+ 512: 32 * channel_multiplier,
427
+ 1024: 16 * channel_multiplier,
428
+ }
429
+
430
+ self.input = ConstantInput(self.channels[4])
431
+ self.conv1 = StyledConv(
432
+ self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
433
+ )
434
+ self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
435
+
436
+ self.log_size = int(math.log(size, 2))
437
+ self.num_layers = (self.log_size - 2) * 2 + 1
438
+
439
+ self.convs = nn.ModuleList()
440
+ self.upsamples = nn.ModuleList()
441
+ self.to_rgbs = nn.ModuleList()
442
+ self.noises = nn.Module()
443
+
444
+ in_channel = self.channels[4]
445
+
446
+ for layer_idx in range(self.num_layers):
447
+ res = (layer_idx + 5) // 2
448
+ shape = [1, 1, 2 ** res, 2 ** res]
449
+ self.noises.register_buffer(f'noise_{layer_idx}', torch.randn(*shape))
450
+
451
+ for i in range(3, self.log_size + 1):
452
+ out_channel = self.channels[2 ** i]
453
+
454
+ self.convs.append(
455
+ StyledConv(
456
+ in_channel,
457
+ out_channel,
458
+ 3,
459
+ style_dim,
460
+ upsample=True,
461
+ blur_kernel=blur_kernel,
462
+ )
463
+ )
464
+
465
+ self.convs.append(
466
+ StyledConv(
467
+ out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
468
+ )
469
+ )
470
+
471
+ self.to_rgbs.append(ToRGB(out_channel, style_dim))
472
+
473
+ in_channel = out_channel
474
+
475
+ self.n_latent = self.log_size * 2 - 2
476
+
477
+ def make_noise(self):
478
+ device = self.input.input.device
479
+
480
+ noises = [torch.randn(1, 1, 2 ** 2, 2 ** 2, device=device)]
481
+
482
+ for i in range(3, self.log_size + 1):
483
+ for _ in range(2):
484
+ noises.append(torch.randn(1, 1, 2 ** i, 2 ** i, device=device))
485
+
486
+ return noises
487
+
488
+ def mean_latent(self, n_latent):
489
+ latent_in = torch.randn( n_latent, self.style_dim, device=self.input.input.device)
490
+ latent = self.get_latent(latent_in)#.mean(0, keepdim=True)
491
+ latent = [latent[i].mean(0, keepdim=True) for i in range(len(latent))]
492
+
493
+ return latent
494
+
495
+ def get_w(self, input):
496
+ latent = self.style(input)
497
+ latent = fused_leaky_relu(latent, torch.zeros_like(latent).cuda(), 5.)
498
+ return latent
499
+
500
+ def get_latent(self, input, is_latent=False, truncation=1, mean_latent=None):
501
+ output = []
502
+ if not is_latent:
503
+ latent = self.style(input)
504
+ latent = latent.unsqueeze(1).repeat(1, self.n_latent, 1) #[B, 14, 512]
505
+ else:
506
+ latent = input
507
+ output.append(self.conv1.get_latent(latent[:, 0]))
508
+ output.append(self.to_rgb1.get_latent(latent[:, 1]))
509
+
510
+ i = 1
511
+ # print("Get latent dimensions:")
512
+ for conv1, conv2, to_rgb in zip(self.convs[::2], self.convs[1::2], self.to_rgbs):
513
+ # print(f'{i}: {conv1.get_latent(latent[:, i]).shape}')
514
+ # print(f'{i+1}: {conv2.get_latent(latent[:, i+1]).shape}')
515
+ # print(f'{i+2}: {to_rgb.get_latent(latent[:, i+2]).shape}')
516
+ # print("")
517
+ output.append(conv1.get_latent(latent[:, i]))
518
+ output.append(conv2.get_latent(latent[:, i+1]))
519
+ output.append(to_rgb.get_latent(latent[:, i+2]))
520
+ i += 2
521
+
522
+ # output = torch.cat(output, 1)
523
+
524
+ if truncation < 1 and mean_latent is not None:
525
+ output = [mean_latent[i] + truncation * (output[i] - mean_latent[i]) for i in range(len(output))]
526
+
527
+ return output
528
+
529
+ def forward(
530
+ self,
531
+ styles,
532
+ stop_idx=99,
533
+ is_cluster=False,
534
+ noise=None,
535
+ randomize_noise=False,
536
+ weights_deltas=None,
537
+ ):
538
+ total_convs = len(self.convs) + len(self.to_rgbs) +2
539
+ if weights_deltas is None:
540
+ weights_deltas = [None]* total_convs
541
+ if noise is None:
542
+ if randomize_noise:
543
+ noise = [None] * self.num_layers
544
+ else:
545
+ noise = [
546
+ getattr(self.noises, f'noise_{i}') for i in range(self.num_layers)
547
+ ]
548
+
549
+ outputs = []
550
+ idx_count = 0
551
+
552
+ latent = styles
553
+ out = self.input(latent[0])
554
+ outputs.append([out, out])
555
+ if idx_count == stop_idx:
556
+ return outputs
557
+
558
+ out, out_t = self.conv1(out, latent[idx_count], noise=noise[0],weights_delta=weights_deltas[0])
559
+ outputs.append([out_t, out])
560
+ idx_count += 1
561
+ if idx_count == stop_idx:
562
+ return outputs
563
+
564
+ skip = self.to_rgb1(out, latent[idx_count], weights_delta=weights_deltas[1])
565
+
566
+ i = 1
567
+ weight_idx = 2
568
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
569
+ self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
570
+ ):
571
+ outputs.append([out_t, out])
572
+ idx_count += 1
573
+ if idx_count == stop_idx:
574
+ return outputs
575
+
576
+ out, out_t = conv1(out, latent[idx_count], noise=noise1, weights_delta=weights_deltas[weight_idx])
577
+ outputs.append([out_t, out])
578
+ idx_count += 1
579
+ if idx_count == stop_idx:
580
+ return outputs
581
+
582
+ out, out_t = conv2(out, latent[idx_count], noise=noise2, weights_delta=weights_deltas[weight_idx+1])
583
+ outputs.append([out_t, out])
584
+ idx_count += 1
585
+ if idx_count == stop_idx:
586
+ return outputs
587
+
588
+ skip = to_rgb(out, latent[idx_count], skip, weights_delta=weights_deltas[weight_idx+2])
589
+
590
+ i += 2
591
+ weight_idx += 3
592
+ image = skip.clamp(-1,1)
593
+ return image, outputs
594
+
595
+
596
+ class ConvLayer(nn.Sequential):
597
+ def __init__(
598
+ self,
599
+ in_channel,
600
+ out_channel,
601
+ kernel_size,
602
+ groups=1,
603
+ downsample=False,
604
+ blur_kernel=[1, 3, 3, 1],
605
+ bias=True,
606
+ activate=True,
607
+ lr_mul=1,
608
+ ):
609
+ layers = []
610
+
611
+ if downsample:
612
+ factor = 2
613
+ p = (len(blur_kernel) - factor) + (kernel_size - 1)
614
+ pad0 = (p + 1) // 2
615
+ pad1 = p // 2
616
+
617
+ layers.append(Blur(blur_kernel, pad=(pad0, pad1)))
618
+
619
+ stride = 2
620
+ self.padding = 0
621
+
622
+ else:
623
+ stride = 1
624
+ self.padding = kernel_size // 2
625
+
626
+ layers.append(
627
+ EqualConv2d(
628
+ in_channel,
629
+ out_channel,
630
+ kernel_size,
631
+ groups=groups,
632
+ padding=self.padding,
633
+ stride=stride,
634
+ bias=bias and not activate,
635
+ lr_mul=lr_mul,
636
+ )
637
+ )
638
+
639
+ if activate:
640
+ if bias:
641
+ layers.append(FusedLeakyReLU(out_channel, lr_mul=lr_mul))
642
+
643
+ else:
644
+ layers.append(ScaledLeakyReLU(0.2))
645
+
646
+ super().__init__(*layers)
647
+
648
+
649
+
650
+ class ResBlock(nn.Module):
651
+ def __init__(self, in_channel, out_channel, blur_kernel=[1, 3, 3, 1]):
652
+ super().__init__()
653
+
654
+ self.conv1 = ConvLayer(in_channel, in_channel, 3)
655
+ self.conv2 = ConvLayer(in_channel, out_channel, 3, downsample=True)
656
+
657
+ self.skip = ConvLayer(
658
+ in_channel, out_channel, 1, downsample=True, activate=False, bias=False
659
+ )
660
+
661
+ def forward(self, input):
662
+ out = self.conv1(input)
663
+ out = self.conv2(out)
664
+
665
+ skip = self.skip(input)
666
+ out = (out + skip) / math.sqrt(2)
667
+
668
+ return out
669
+
670
+
671
+ class Discriminator(nn.Module):
672
+ def __init__(self, size, channel_multiplier=2, blur_kernel=[1, 3, 3, 1]):
673
+ super().__init__()
674
+
675
+ channels = {
676
+ 4: 512,
677
+ 8: 512,
678
+ 16: 512,
679
+ 32: 512,
680
+ 64: 256 * channel_multiplier,
681
+ 128: 128 * channel_multiplier,
682
+ 256: 64 * channel_multiplier,
683
+ 512: 32 * channel_multiplier,
684
+ 1024: 16 * channel_multiplier,
685
+ }
686
+
687
+ convs = [ConvLayer(3, channels[size], 1)]
688
+
689
+ log_size = int(math.log(size, 2))
690
+
691
+ in_channel = channels[size]
692
+
693
+ for i in range(log_size, 2, -1):
694
+ out_channel = channels[2 ** (i - 1)]
695
+
696
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
697
+
698
+ in_channel = out_channel
699
+
700
+ self.convs = nn.Sequential(*convs)
701
+
702
+ self.stddev_group = 4
703
+ self.stddev_feat = 1
704
+
705
+ self.final_conv = ConvLayer(in_channel + 1, channels[4], 3)
706
+ self.final_linear = nn.Sequential(
707
+ EqualLinear(channels[4] * 4 * 4, channels[4], activation='fused_lrelu'),
708
+ EqualLinear(channels[4], 1),
709
+ )
710
+
711
+ def forward(self, input):
712
+ out = self.convs(input)
713
+
714
+ batch, channel, height, width = out.shape
715
+ group = min(batch, self.stddev_group)
716
+ #group = batch
717
+ stddev = out.view(
718
+ group, -1, self.stddev_feat, channel // self.stddev_feat, height, width
719
+ )
720
+ stddev = torch.sqrt(stddev.var(0, unbiased=False) + 1e-8)
721
+ stddev = stddev.mean([2, 3, 4], keepdims=True).squeeze(2)
722
+ stddev = stddev.repeat(group, 1, height, width)
723
+ out = torch.cat([out, stddev], 1)
724
+
725
+ out = self.final_conv(out)
726
+
727
+ out = out.view(batch, -1)
728
+ out = self.final_linear(out)
729
+
730
+ return out
731
+
732
+ class VGGExtractor(torch.nn.Module):
733
+ def __init__(self, resize=False):
734
+ super(VGGExtractor, self).__init__()
735
+ vgg16 = torchvision.models.vgg16(pretrained=True).eval()
736
+ blocks = vgg16.features[:23]
737
+ for p in blocks.parameters():
738
+ p.requires_grad = False
739
+ self.blocks = blocks
740
+ self.transform = torch.nn.functional.interpolate
741
+ self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1,3,1,1))
742
+ self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1,3,1,1))
743
+ self.resize = resize
744
+
745
+ def forward(self, input):
746
+ if input.shape[1] != 3:
747
+ input = input.repeat(1, 3, 1, 1)
748
+ input = (input + 1) / 2
749
+ input = (input-self.mean) / self.std
750
+ if self.resize:
751
+ input = self.transform(input, mode='bilinear', size=(224, 224), align_corners=False)
752
+ return self.blocks(input)
753
+
754
+ class Encoder(nn.Module):
755
+ def __init__(self, size, groups, channel_multiplier=1, blur_kernel=[1, 3, 3, 1]):
756
+ '''
757
+ [16]: [14,15,16,17,18,19]
758
+ [8]: [8,9,10,11,12,13]
759
+ [4]: [0,1,2,3,4,5,6,7]
760
+ '''
761
+ super().__init__()
762
+ in_channel = 3
763
+ out_channel = 64
764
+
765
+ convs = nn.ModuleList()
766
+ for i in range(6):
767
+ convs.append(ResBlock(in_channel, out_channel, blur_kernel))
768
+ in_channel = out_channel
769
+ out_channel = min(1024, in_channel*2)
770
+
771
+ self.fc_high = nn.Sequential(nn.AdaptiveAvgPool2d(4),
772
+ nn.Flatten(),
773
+ EqualLinear(512*4*4, 4*512+3*256+2*128))
774
+ self.fc_mid = nn.Sequential(nn.AdaptiveAvgPool2d(4),
775
+ nn.Flatten(),
776
+ EqualLinear(1024*4*4, 512*6))
777
+ self.fc_low = nn.Sequential(nn.AdaptiveAvgPool2d(4),
778
+ nn.Flatten(),
779
+ EqualLinear(1024*4*4, 512*5))
780
+
781
+ def forward(self, input):
782
+ shared = self.convs(input)
783
+ local = self.local_fc(shared)
784
+ glob = self.global_fc(shared)
785
+ return local.view(local.size(0), -1), glob
786
+
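
The generator above is the RIS variant of StyleGAN2: it is driven by style-space codes rather than raw z/w vectors. get_latent pushes a latent through each layer's modulation to produce one code per conv/ToRGB layer, and forward consumes that list, optionally with per-layer weights_deltas (the HyperStyle offsets). A minimal smoke-test sketch, not part of the commit, assuming the full model file in this diff (PixelNorm, EqualLinear, Blur, etc. are defined in the portion not shown here):

    import torch

    g = Generator(size=256, style_dim=512, n_mlp=8).eval()   # small resolution keeps it light
    with torch.no_grad():
        z = torch.randn(2, 512)
        styles = g.get_latent(z)       # list of style-space codes, one per conv / ToRGB layer
        img, activations = g(styles)   # image clamped to [-1, 1], plus per-layer activations
    print(img.shape)                   # torch.Size([2, 3, 256, 256])
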
ris/op/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
ris/op/fused_act.py ADDED
@@ -0,0 +1,39 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ class FusedLeakyReLU(nn.Module):
12
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
13
+ super().__init__()
14
+
15
+ self.bias = nn.Parameter(torch.zeros(channel))
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, input):
20
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
+
22
+
23
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
24
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
25
+ if input.ndim == 3:
26
+ return (
27
+ F.leaky_relu(
28
+ input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
29
+ )
30
+ * scale
31
+ )
32
+ else:
33
+ return (
34
+ F.leaky_relu(
35
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
36
+ )
37
+ * scale
38
+ )
39
+
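
This fused_act.py is a pure-PyTorch fallback with no compiled CUDA extension: FusedLeakyReLU simply adds a per-channel bias, applies a leaky ReLU, and rescales by sqrt(2) to preserve activation variance. A small equivalence check, written for illustration only (not part of the commit):

    import math
    import torch
    import torch.nn.functional as F

    x = torch.randn(4, 8, 16, 16)
    act = FusedLeakyReLU(channel=8)

    out = act(x)
    ref = F.leaky_relu(x + act.bias.view(1, -1, 1, 1), negative_slope=0.2) * math.sqrt(2)
    print(torch.allclose(out, ref))    # True
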
ris/op/upfirdn2d.py ADDED
@@ -0,0 +1,60 @@
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
12
+ out = upfirdn2d_native(
13
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
14
+ )
15
+
16
+ return out
17
+
18
+
19
+ def upfirdn2d_native(
20
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
21
+ ):
22
+ _, channel, in_h, in_w = input.shape
23
+ input = input.reshape(-1, in_h, in_w, 1)
24
+
25
+ _, in_h, in_w, minor = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
29
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
30
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
31
+
32
+ out = F.pad(
33
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
34
+ )
35
+ out = out[
36
+ :,
37
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
38
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
39
+ :,
40
+ ]
41
+
42
+ out = out.permute(0, 3, 1, 2)
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ out = out.permute(0, 2, 3, 1)
55
+ out = out[:, ::down_y, ::down_x, :]
56
+
57
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
58
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
59
+
60
+ return out.view(-1, channel, out_h, out_w)
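
upfirdn2d_native pads, (optionally) upsamples, convolves with the flipped kernel, and (optionally) downsamples in one pass; per axis the output length is (in*up + pad0 + pad1 - k) // down + 1. A minimal sketch of a same-size blur with the [1, 3, 3, 1] kernel used throughout the generator; the pad values follow that formula, and the snippet is an illustration rather than code from the commit:

    import torch

    k = torch.tensor([1., 3., 3., 1.])
    kernel = k[None, :] * k[:, None]     # separable binomial kernel
    kernel = kernel / kernel.sum()       # normalized, as the Blur module does

    x = torch.randn(1, 3, 8, 8)
    out = upfirdn2d(x, kernel, up=1, down=1, pad=(2, 1))   # (8 + 2 + 1 - 4) // 1 + 1 = 8
    print(out.shape)                     # torch.Size([1, 3, 8, 8])
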
ris/projector.py ADDED
@@ -0,0 +1,213 @@
1
+ # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION is strictly prohibited.
8
+
9
+ """Project given image to the latent space of pretrained network pickle."""
10
+
11
+ import copy
12
+ import os
13
+ from time import perf_counter
14
+
15
+ import click
16
+ import imageio
17
+ import numpy as np
18
+ import PIL.Image
19
+ import torch
20
+ import torch.nn.functional as F
21
+
22
+ import dnnlib
23
+ import legacy
24
+
25
+ def project(
26
+ G,
27
+ vgg16,
28
+ target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
29
+ *,
30
+ num_steps = 1000,
31
+ w_avg_samples = 10000,
32
+ initial_learning_rate = 0.1,
33
+ initial_noise_factor = 0.05,
34
+ lr_rampdown_length = 0.25,
35
+ lr_rampup_length = 0.05,
36
+ noise_ramp_length = 0.75,
37
+ regularize_noise_weight = 1e5,
38
+ verbose = False,
39
+ device: torch.device
40
+ ):
41
+ assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)
42
+
43
+ def logprint(*args):
44
+ if verbose:
45
+ print(*args)
46
+
47
+ G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore
48
+
49
+ # Compute w stats.
50
+ logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
51
+ z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
52
+ w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
53
+ w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
54
+ w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
55
+ w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5
56
+
57
+ # Setup noise inputs.
58
+ noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name }
59
+
60
+ # Load VGG16 feature detector.
61
+ # url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
62
+ # with dnnlib.util.open_url(url) as f:
63
+ # vgg16 = torch.jit.load(f).eval().to(device)
64
+
65
+ # Features for target image.
66
+ target_images = target.unsqueeze(0).to(device).to(torch.float32)
67
+ if target_images.shape[2] > 256:
68
+ target_images = F.interpolate(target_images, size=(256, 256), mode='area')
69
+ target_features = vgg16(target_images, resize_images=False, return_lpips=True)
70
+
71
+ w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
72
+ w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
73
+ optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)
74
+
75
+ # Init noise.
76
+ for buf in noise_bufs.values():
77
+ buf[:] = torch.randn_like(buf)
78
+ buf.requires_grad = True
79
+
80
+ for step in range(num_steps):
81
+ # Learning rate schedule.
82
+ t = step / num_steps
83
+ w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
84
+ lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
85
+ lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
86
+ lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
87
+ lr = initial_learning_rate * lr_ramp
88
+ for param_group in optimizer.param_groups:
89
+ param_group['lr'] = lr
90
+
91
+ # Synth images from opt_w.
92
+ w_noise = torch.randn_like(w_opt) * w_noise_scale
93
+ ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
94
+ synth_images = G.synthesis(ws, noise_mode='const')
95
+
96
+ # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
97
+ synth_images = (synth_images + 1) * (255/2)
98
+ if synth_images.shape[2] > 256:
99
+ synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
100
+
101
+ # Features for synth images.
102
+ synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
103
+ dist = (target_features - synth_features).square().sum()
104
+
105
+ # Noise regularization.
106
+ reg_loss = 0.0
107
+ for v in noise_bufs.values():
108
+ noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
109
+ while True:
110
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
111
+ reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
112
+ if noise.shape[2] <= 8:
113
+ break
114
+ noise = F.avg_pool2d(noise, kernel_size=2)
115
+ loss = dist + reg_loss * regularize_noise_weight
116
+
117
+ # Step
118
+ optimizer.zero_grad(set_to_none=True)
119
+ loss.backward()
120
+ optimizer.step()
121
+ logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
122
+
123
+ # Save projected W for each optimization step.
124
+ w_out[step] = w_opt.detach()[0]
125
+
126
+ # Normalize noise.
127
+ with torch.no_grad():
128
+ for buf in noise_bufs.values():
129
+ buf -= buf.mean()
130
+ buf *= buf.square().mean().rsqrt()
131
+
132
+ return w_out.repeat([1, G.mapping.num_ws, 1])
133
+
134
+ #----------------------------------------------------------------------------
135
+
136
+ @click.command()
137
+ @click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
138
+ @click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
139
+ @click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
140
+ @click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
141
+ @click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
142
+ @click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
143
+ def run_projection(
144
+ network_pkl: str,
145
+ target_fname: str,
146
+ outdir: str,
147
+ save_video: bool,
148
+ seed: int,
149
+ num_steps: int
150
+ ):
151
+ """Project given image to the latent space of pretrained network pickle.
152
+
153
+ Examples:
154
+
155
+ \b
156
+ python projector.py --outdir=out --target=~/mytargetimg.png \\
157
+ --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
158
+ """
159
+ np.random.seed(seed)
160
+ torch.manual_seed(seed)
161
+
162
+ # Load networks.
163
+ print('Loading networks from "%s"...' % network_pkl)
164
+ device = torch.device('cuda')
165
+ with dnnlib.util.open_url(network_pkl) as fp:
166
+ G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
167
+
168
+ # Load target image.
169
+ target_pil = PIL.Image.open(target_fname).convert('RGB')
170
+ w, h = target_pil.size
171
+ s = min(w, h)
172
+ target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
173
+ target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
174
+ target_uint8 = np.array(target_pil, dtype=np.uint8)
175
+
176
+ # Optimize projection.
177
+ start_time = perf_counter()
178
+ # project() takes the VGG16/LPIPS feature extractor as its second argument, so load
+ # the checkpoint that the commented-out block inside project() points to.
+ url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+ with dnnlib.util.open_url(url) as f:
+ vgg16 = torch.jit.load(f).eval().to(device)
+ projected_w_steps = project(
179
+ G,
+ vgg16,
180
+ target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
181
+ num_steps=num_steps,
182
+ device=device,
183
+ verbose=True
184
+ )
185
+ print (f'Elapsed: {(perf_counter()-start_time):.1f} s')
186
+
187
+ # Render debug output: optional video and projected image and W vector.
188
+ os.makedirs(outdir, exist_ok=True)
189
+ if save_video:
190
+ video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
191
+ print (f'Saving optimization progress video "{outdir}/proj.mp4"')
192
+ for projected_w in projected_w_steps:
193
+ synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
194
+ synth_image = (synth_image + 1) * (255/2)
195
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
196
+ video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
197
+ video.close()
198
+
199
+ # Save final projected frame and W vector.
200
+ target_pil.save(f'{outdir}/target.png')
201
+ projected_w = projected_w_steps[-1]
202
+ synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
203
+ synth_image = (synth_image + 1) * (255/2)
204
+ synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
205
+ PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
206
+ np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())
207
+
208
+ #----------------------------------------------------------------------------
209
+
210
+ if __name__ == "__main__":
211
+ run_projection() # pylint: disable=no-value-for-parameter
212
+
213
+ #----------------------------------------------------------------------------
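
The projection loop above anneals the Adam learning rate with a short linear warm-up over lr_rampup_length, a cosine ramp-down over the final lr_rampdown_length fraction of steps, and a similarly decaying noise scale on w. Pulled out on its own, the schedule is just the following (an illustration mirroring the arithmetic inside project(), not code from the commit):

    import numpy as np

    def projector_lr(step, num_steps, base_lr=0.1, rampdown=0.25, rampup=0.05):
        t = step / num_steps
        ramp = min(1.0, (1.0 - t) / rampdown)    # cosine ramp-down at the end
        ramp = 0.5 - 0.5 * np.cos(ramp * np.pi)
        ramp = ramp * min(1.0, t / rampup)       # linear warm-up at the start
        return base_lr * ramp

    for step in (0, 50, 500, 950, 999):
        print(step, round(projector_lr(step, 1000), 5))
    # 0 -> 0.0, 50 -> 0.1, 500 -> 0.1, 950 -> ~0.0095, 999 -> ~0.0
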
ris/spherical_kmeans.py ADDED
@@ -0,0 +1,390 @@
1
+ import warnings
2
+ import numpy as np
3
+ from sklearn.preprocessing import normalize
4
+ from sklearn.utils.sparsefuncs_fast import assign_rows_csr
5
+ from sklearn.utils.validation import _check_sample_weight
6
+ from sklearn.utils import check_array, check_random_state
7
+ from sklearn.utils.extmath import row_norms
8
+ import scipy.sparse as sp
9
+ import sklearn.cluster  # needed by the sparse-input branch of _mini_batch_spherical_step below
+ from sklearn.cluster import MiniBatchKMeans
10
+ from sklearn.cluster.k_means_ import (
11
+ _init_centroids,
12
+ _labels_inertia,
13
+ _tolerance,
14
+ _mini_batch_step,
15
+ _mini_batch_convergence
16
+ )
17
+
18
+
19
+ def _check_normalize_sample_weight(sample_weight, X):
20
+ """Set sample_weight if None, and check for correct dtype"""
21
+
22
+ sample_weight_was_none = sample_weight is None
23
+
24
+ sample_weight = _check_sample_weight(sample_weight, X)
25
+
26
+ if not sample_weight_was_none:
27
+ # normalize the weights to sum up to n_samples
28
+ # an array of 1 (i.e. samples_weight is None) is already normalized
29
+ n_samples = len(sample_weight)
30
+ scale = n_samples / sample_weight.sum()
31
+ sample_weight *= scale
32
+ return sample_weight
33
+
34
+
35
+
36
+
37
+ def _mini_batch_spherical_step(X, sample_weight, x_squared_norms, centers, weight_sums,
38
+ old_center_buffer, compute_squared_diff,
39
+ distances, random_reassign=False,
40
+ random_state=None, reassignment_ratio=.01,
41
+ verbose=False):
42
+ """Incremental update of the centers for the Minibatch K-Means algorithm.
43
+ Parameters
44
+ ----------
45
+ X : array, shape (n_samples, n_features)
46
+ The original data array.
47
+ sample_weight : array-like, shape (n_samples,)
48
+ The weights for each observation in X.
49
+ x_squared_norms : array, shape (n_samples,)
50
+ Squared euclidean norm of each data point.
51
+ centers : array, shape (k, n_features)
52
+ The cluster centers. This array is MODIFIED IN PLACE
53
+ counts : array, shape (k,)
54
+ The vector in which we keep track of the numbers of elements in a
55
+ cluster. This array is MODIFIED IN PLACE
56
+ distances : array, dtype float, shape (n_samples), optional
57
+ If not None, should be a pre-allocated array that will be used to store
58
+ the distances of each sample to its closest center.
59
+ May not be None when random_reassign is True.
60
+ random_state : int, RandomState instance or None (default)
61
+ Determines random number generation for centroid initialization and to
62
+ pick new clusters amongst observations with uniform probability. Use
63
+ an int to make the randomness deterministic.
64
+ See :term:`Glossary <random_state>`.
65
+ random_reassign : boolean, optional
66
+ If True, centers with very low counts are randomly reassigned
67
+ to observations.
68
+ reassignment_ratio : float, optional
69
+ Control the fraction of the maximum number of counts for a
70
+ center to be reassigned. A higher value means that low count
71
+ centers are more likely to be reassigned, which means that the
72
+ model will take longer to converge, but should converge in a
73
+ better clustering.
74
+ verbose : bool, optional, default False
75
+ Controls the verbosity.
76
+ compute_squared_diff : bool
77
+ If set to False, the squared diff computation is skipped.
78
+ old_center_buffer : int
79
+ Copy of old centers for monitoring convergence.
80
+ Returns
81
+ -------
82
+ inertia : float
83
+ Sum of squared distances of samples to their closest cluster center.
84
+ squared_diff : numpy array, shape (n_clusters,)
85
+ Squared distances between previous and updated cluster centers.
86
+ """
87
+ # Perform label assignment to nearest centers
88
+ nearest_center, inertia = _labels_inertia(X, sample_weight,
89
+ x_squared_norms, centers,
90
+ distances=distances)
91
+
92
+ if random_reassign and reassignment_ratio > 0:
93
+ random_state = check_random_state(random_state)
94
+ # Reassign clusters that have very low weight
95
+ to_reassign = weight_sums < reassignment_ratio * weight_sums.max()
96
+ # pick at most .5 * batch_size samples as new centers
97
+ if to_reassign.sum() > .5 * X.shape[0]:
98
+ indices_dont_reassign = \
99
+ np.argsort(weight_sums)[int(.5 * X.shape[0]):]
100
+ to_reassign[indices_dont_reassign] = False
101
+ n_reassigns = to_reassign.sum()
102
+ if n_reassigns:
103
+ # Pick new clusters amongst observations with uniform probability
104
+ new_centers = random_state.choice(X.shape[0], replace=False,
105
+ size=n_reassigns)
106
+ if verbose:
107
+ print("[MiniBatchKMeans] Reassigning %i cluster centers."
108
+ % n_reassigns)
109
+
110
+ if sp.issparse(X) and not sp.issparse(centers):
111
+ assign_rows_csr(
112
+ X, new_centers.astype(np.intp, copy=False),
113
+ np.where(to_reassign)[0].astype(np.intp, copy=False),
114
+ centers)
115
+ else:
116
+ centers[to_reassign] = X[new_centers]
117
+ # reset counts of reassigned centers, but don't reset them too small
118
+ # to avoid instant reassignment. This is a pretty dirty hack as it
119
+ # also modifies the learning rates.
120
+ weight_sums[to_reassign] = np.min(weight_sums[~to_reassign])
121
+
122
+ # implementation for the sparse CSR representation completely written in
123
+ # cython
124
+ if sp.issparse(X):
125
+ return inertia, sklearn.cluster.k_means_._k_means._mini_batch_update_csr(
126
+ X, sample_weight, x_squared_norms, centers, weight_sums,
127
+ nearest_center, old_center_buffer, compute_squared_diff)
128
+
129
+ # dense variant in mostly numpy (not as memory efficient though)
130
+ k = centers.shape[0]
131
+ squared_diff = 0.0
132
+ for center_idx in range(k):
133
+ # find points from minibatch that are assigned to this center
134
+ center_mask = nearest_center == center_idx
135
+ wsum = sample_weight[center_mask].sum()
136
+
137
+ if wsum > 0:
138
+ if compute_squared_diff:
139
+ old_center_buffer[:] = centers[center_idx]
140
+
141
+ # inplace remove previous count scaling
142
+ centers[center_idx] *= weight_sums[center_idx]
143
+
144
+ # inplace sum with new points members of this cluster
145
+ centers[center_idx] += \
146
+ np.sum(X[center_mask] *
147
+ sample_weight[center_mask, np.newaxis], axis=0)
148
+
149
+ # unit-normalize for spherical k-means
150
+ centers[center_idx] = normalize(centers[center_idx, None])[:, 0]
151
+
152
+ # update the squared diff if necessary
153
+ if compute_squared_diff:
154
+ diff = centers[center_idx].ravel() - old_center_buffer.ravel()
155
+ squared_diff += np.dot(diff, diff)
156
+
157
+ return inertia, squared_diff
158
+
159
+
160
+ class MiniBatchSphericalKMeans(MiniBatchKMeans):
161
+
162
+ def fit(self, X, y=None, sample_weight=None):
163
+ """Compute the centroids on X by chunking it into mini-batches.
164
+ Parameters
165
+ ----------
166
+ X : array-like or sparse matrix, shape=(n_samples, n_features)
167
+ Training instances to cluster. It must be noted that the data
168
+ will be converted to C ordering, which will cause a memory copy
169
+ if the given data is not C-contiguous.
170
+ y : Ignored
171
+ Not used, present here for API consistency by convention.
172
+ sample_weight : array-like, shape (n_samples,), optional
173
+ The weights for each observation in X. If None, all observations
174
+ are assigned equal weight (default: None).
175
+ Returns
176
+ -------
177
+ self
178
+ """
179
+ random_state = check_random_state(self.random_state)
180
+ # unit-normalize for spherical k-means
181
+ X = normalize(X)
182
+ X = check_array(X, accept_sparse="csr", order='C',
183
+ dtype=[np.float64, np.float32])
184
+ n_samples, n_features = X.shape
185
+ if n_samples < self.n_clusters:
186
+ raise ValueError("n_samples=%d should be >= n_clusters=%d"
187
+ % (n_samples, self.n_clusters))
188
+
189
+ sample_weight = _check_normalize_sample_weight(sample_weight, X)
190
+
191
+ n_init = self.n_init
192
+ if hasattr(self.init, '__array__'):
193
+ self.init = np.ascontiguousarray(self.init, dtype=X.dtype)
194
+ if n_init != 1:
195
+ warnings.warn(
196
+ 'Explicit initial center position passed: '
197
+ 'performing only one init in MiniBatchKMeans instead of '
198
+ 'n_init=%d'
199
+ % self.n_init, RuntimeWarning, stacklevel=2)
200
+ n_init = 1
201
+
202
+ x_squared_norms = row_norms(X, squared=True)
203
+
204
+ if self.tol > 0.0:
205
+ tol = _tolerance(X, self.tol)
206
+
207
+ # using tol-based early stopping needs the allocation of a
208
+ # dedicated buffer beforehand, which can be expensive for high dim data:
209
+ # hence we allocate it outside of the main loop
210
+ old_center_buffer = np.zeros(n_features, dtype=X.dtype)
211
+ else:
212
+ tol = 0.0
213
+ # no need for the center buffer if tol-based early stopping is
214
+ # disabled
215
+ old_center_buffer = np.zeros(0, dtype=X.dtype)
216
+
217
+ distances = np.zeros(self.batch_size, dtype=X.dtype)
218
+ n_batches = int(np.ceil(float(n_samples) / self.batch_size))
219
+ n_iter = int(self.max_iter * n_batches)
220
+
221
+ init_size = self.init_size
222
+ if init_size is None:
223
+ init_size = 3 * self.batch_size
224
+ if init_size > n_samples:
225
+ init_size = n_samples
226
+ self.init_size_ = init_size
227
+
228
+ validation_indices = random_state.randint(0, n_samples, init_size)
229
+ X_valid = X[validation_indices]
230
+ sample_weight_valid = sample_weight[validation_indices]
231
+ x_squared_norms_valid = x_squared_norms[validation_indices]
232
+
233
+ # perform several inits with random sub-sets
234
+ best_inertia = None
235
+ for init_idx in range(n_init):
236
+ if self.verbose:
237
+ print("Init %d/%d with method: %s"
238
+ % (init_idx + 1, n_init, self.init))
239
+ weight_sums = np.zeros(self.n_clusters, dtype=sample_weight.dtype)
240
+
241
+ # TODO: once the `k_means` function works with sparse input we
242
+ # should refactor the following init to use it instead.
243
+
244
+ # Initialize the centers using only a fraction of the data as we
245
+ # expect n_samples to be very large when using MiniBatchKMeans
246
+ cluster_centers = _init_centroids(
247
+ X, self.n_clusters, self.init,
248
+ random_state=random_state,
249
+ x_squared_norms=x_squared_norms,
250
+ init_size=init_size)
251
+
252
+ cluster_centers = normalize(cluster_centers)
253
+
254
+ # Compute the label assignment on the init dataset
255
+ _mini_batch_step(
256
+ X_valid, sample_weight_valid,
257
+ x_squared_norms[validation_indices], cluster_centers,
258
+ weight_sums, old_center_buffer, False, distances=None,
259
+ verbose=self.verbose)
260
+
261
+ cluster_centers = normalize(cluster_centers)
262
+
263
+ # Keep only the best cluster centers across independent inits on
264
+ # the common validation set
265
+ _, inertia = _labels_inertia(X_valid, sample_weight_valid,
266
+ x_squared_norms_valid,
267
+ cluster_centers)
268
+ if self.verbose:
269
+ print("Inertia for init %d/%d: %f"
270
+ % (init_idx + 1, n_init, inertia))
271
+ if best_inertia is None or inertia < best_inertia:
272
+ self.cluster_centers_ = cluster_centers
273
+ self.counts_ = weight_sums
274
+ best_inertia = inertia
275
+
276
+ # Empty context to be used inplace by the convergence check routine
277
+ convergence_context = {}
278
+
279
+ # Perform the iterative optimization until the final convergence
280
+ # criterion
281
+ for iteration_idx in range(n_iter):
282
+ # Sample a minibatch from the full dataset
283
+ minibatch_indices = random_state.randint(
284
+ 0, n_samples, self.batch_size)
285
+
286
+ # Perform the actual update step on the minibatch data
287
+ self.cluster_centers_ = normalize(self.cluster_centers_)
288
+ batch_inertia, centers_squared_diff = _mini_batch_step(
289
+ X[minibatch_indices], sample_weight[minibatch_indices],
290
+ x_squared_norms[minibatch_indices],
291
+ self.cluster_centers_, self.counts_,
292
+ old_center_buffer, tol > 0.0, distances=distances,
293
+ # Here we randomly choose whether to perform
294
+ # random reassignment: the choice is done as a function
295
+ # of the iteration index, and the minimum number of
296
+ # counts, in order to force this reassignment to happen
297
+ # every once in a while
298
+ random_reassign=((iteration_idx + 1)
299
+ % (10 + int(self.counts_.min())) == 0),
300
+ random_state=random_state,
301
+ reassignment_ratio=self.reassignment_ratio,
302
+ verbose=self.verbose)
303
+ self.cluster_centers_ = normalize(self.cluster_centers_)
304
+
305
+ # Monitor convergence and do early stopping if necessary
306
+ if _mini_batch_convergence(
307
+ self, iteration_idx, n_iter, tol, n_samples,
308
+ centers_squared_diff, batch_inertia, convergence_context,
309
+ verbose=self.verbose):
310
+ break
311
+
312
+ self.n_iter_ = iteration_idx + 1
313
+
314
+ if self.compute_labels:
315
+ self.labels_, self.inertia_ = \
316
+ self._labels_inertia_minibatch(X, sample_weight)
317
+
318
+ return self
319
+
320
+ def partial_fit(self, X, y=None, sample_weight=None):
321
+ """Update k means estimate on a single mini-batch X.
322
+ Parameters
323
+ ----------
324
+ X : array-like of shape (n_samples, n_features)
325
+ Coordinates of the data points to cluster. It must be noted that
326
+ X will be copied if it is not C-contiguous.
327
+ y : Ignored
328
+ Not used, present here for API consistency by convention.
329
+ sample_weight : array-like, shape (n_samples,), optional
330
+ The weights for each observation in X. If None, all observations
331
+ are assigned equal weight (default: None).
332
+ Returns
333
+ -------
334
+ self
335
+ """
336
+
337
+ X = check_array(X, accept_sparse="csr", order="C",
338
+ dtype=[np.float64, np.float32])
339
+ n_samples, n_features = X.shape
340
+ if hasattr(self.init, '__array__'):
341
+ self.init = np.ascontiguousarray(self.init, dtype=X.dtype)
342
+
343
+ if n_samples == 0:
344
+ return self
345
+
346
+ # unit-normalize for spherical k-means
347
+ X = normalize(X)
348
+
349
+ sample_weight = _check_normalize_sample_weight(sample_weight, X)
350
+
351
+ x_squared_norms = row_norms(X, squared=True)
352
+ self.random_state_ = getattr(self, "random_state_",
353
+ check_random_state(self.random_state))
354
+ if (not hasattr(self, 'counts_')
355
+ or not hasattr(self, 'cluster_centers_')):
356
+ # this is the first call partial_fit on this object:
357
+ # initialize the cluster centers
358
+ self.cluster_centers_ = _init_centroids(
359
+ X, self.n_clusters, self.init,
360
+ random_state=self.random_state_,
361
+ x_squared_norms=x_squared_norms, init_size=self.init_size)
362
+
363
+ self.counts_ = np.zeros(self.n_clusters,
364
+ dtype=sample_weight.dtype)
365
+ random_reassign = False
366
+ distances = None
367
+ else:
368
+ # The lower the minimum count is, the more we do random
369
+ # reassignment, however, we don't want to do random
370
+ # reassignment too often, to allow for building up counts
371
+ random_reassign = self.random_state_.randint(
372
+ 10 * (1 + self.counts_.min())) == 0
373
+ distances = np.zeros(X.shape[0], dtype=X.dtype)
374
+
375
+ self.cluster_centers_ = normalize(self.cluster_centers_)
376
+
377
+ _mini_batch_spherical_step(X, sample_weight, x_squared_norms,
378
+ self.cluster_centers_, self.counts_,
379
+ np.zeros(0, dtype=X.dtype), 0,
380
+ random_reassign=random_reassign, distances=distances,
381
+ random_state=self.random_state_,
382
+ reassignment_ratio=self.reassignment_ratio,
383
+ verbose=self.verbose)
384
+ self.cluster_centers_ = normalize(self.cluster_centers_)
385
+
386
+ if self.compute_labels:
387
+ self.labels_, self.inertia_ = _labels_inertia(
388
+ X, sample_weight, x_squared_norms, self.cluster_centers_)
389
+
390
+ return self
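
MiniBatchSphericalKMeans is a drop-in subclass of scikit-learn's MiniBatchKMeans that L2-normalizes the samples and re-normalizes the centers at every step, so the clustering effectively uses cosine similarity on the unit sphere. Note that it imports private helpers from sklearn.cluster.k_means_, so it only runs against an older scikit-learn release that still ships that module. A hedged usage sketch; the feature matrix and cluster count here are made up for illustration:

    import numpy as np

    feats = np.random.randn(10000, 128).astype(np.float32)   # e.g. flattened activations

    km = MiniBatchSphericalKMeans(n_clusters=18, batch_size=256, random_state=0)
    km.fit(feats)                     # rows are unit-normalized internally
    labels = km.labels_               # cluster index per sample
    centers = km.cluster_centers_     # unit-norm centroids, shape (18, 128)
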
ris/util.py ADDED
@@ -0,0 +1,403 @@
1
+ from matplotlib import pyplot as plt
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import os
5
+ import cv2
6
+ import dlib
7
+ from PIL import Image
8
+ import numpy as np
9
+ import pandas as pd
10
+ import math
11
+ import scipy
12
+ import scipy.ndimage
13
+ import gc
14
+
15
+ # Number of style channels per StyleGAN layer
16
+ style2list_len = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512,
17
+ 512, 512, 512, 512, 512, 256, 256, 256, 128, 128,
18
+ 128, 64, 64, 64, 32, 32]
19
+
20
+ # Layer indices of ToRGB modules
21
+ rgb_layer_idx = [1,4,7,10,13,16,19,22,25]
22
+
23
+ google_drive_paths = {
24
+ "stylegan2-ffhq-config-f.pt": "https://drive.google.com/uc?id=1Yr7KuD959btpmcKGAUsbAk5rPjX2MytK",
25
+ "inversion_stats.npz": "https://drive.google.com/uc?id=1oE_mIKf-Vr7b3J04l2UjsSrxZiw-UuFg",
26
+ "model_ir_se50.pt": "https://drive.google.com/uc?id=1KW7bjndL3QG3sxBbZxreGHigcCCpsDgn",
27
+ "dlibshape_predictor_68_face_landmarks.dat": "https://drive.google.com/uc?id=11BDmNKS1zxSZxkgsEvQoKgFd8J264jKp",
28
+ "e4e_ffhq_encode.pt": "https://drive.google.com/uc?id=1o6ijA3PkcewZvwJJ73dJ0fxhndn0nnh7"
29
+ }
30
+
31
+
32
+ def ensure_checkpoint_exists(model_weights_filename):
33
+ if not os.path.isfile(model_weights_filename) and (
34
+ model_weights_filename in google_drive_paths
35
+ ):
36
+ gdrive_url = google_drive_paths[model_weights_filename]
37
+ try:
38
+ from gdown import download as drive_download
39
+
40
+ drive_download(gdrive_url, model_weights_filename, quiet=False)
41
+ except ModuleNotFoundError:
42
+ print(
43
+ "gdown module not found.",
44
+ "pip3 install gdown or, manually download the checkpoint file:",
45
+ gdrive_url
46
+ )
47
+
48
+ if not os.path.isfile(model_weights_filename) and (
49
+ model_weights_filename not in google_drive_paths
50
+ ):
51
+ print(
52
+ model_weights_filename,
53
+ " not found, you may need to manually download the model weights."
54
+ )
55
+
56
+ # given a list of filenames, load the inverted style code
57
+ @torch.no_grad()
58
+ def load_source(files, generator, device='cuda'):
59
+ sources = []
60
+
61
+ # for file in files:
62
+
63
+ source = torch.load(f'./inversion_codes/{files}.pt')['latent'].to(device)
64
+
65
+ if source.size(0) != 1:
66
+ source = source.unsqueeze(0)
67
+
68
+ if source.ndim == 3:
69
+ source = generator.get_latent(source, truncation=1, is_latent=True)
70
+ source = list2style(source)
71
+
72
+ sources.append(source)
73
+
74
+ sources = torch.cat(sources, 0)
75
+ if type(sources) is not list:
76
+ sources = style2list(sources)
77
+
78
+ return sources
79
+
80
+ '''
81
+ Given M, we zero out the first 2048 dimensions for every feature except pose and hair.
82
+ The reason is that the first 2048 channels mostly carry hair and pose information and rarely
83
+ anything related to the other classes.
84
+
85
+ '''
86
+ def remove_2048(M, labels2idx):
87
+ M_hair = M[:,labels2idx['hair']].clone()
88
+ # zero out first 2048 channels (4 style layers) for non hair and pose features
89
+ M[...,:2048] = 0
90
+ M[:,labels2idx['hair']] = M_hair
91
+ return M
92
+
93
+ # Compute pose M and append it as the last index of M
94
+ def add_pose(M, labels2idx):
95
+ M = remove_2048(M, labels2idx)
96
+ # Add pose to the very last index of M
97
+ pose = 1-M[:,labels2idx['hair']]
98
+ M = torch.cat([M, pose.view(-1,1,9088)], 1)
99
+ #zero out rest of the channels after 2048 as pose should not affect other features
100
+ M[:,-1, 2048:] = 0
101
+ return M
102
+
103
+
104
+ # add direction specified by q from source to reference, scaled by a
105
+ def add_direction(s, r, q, a):
106
+ if isinstance(s, list):
107
+ s = list2style(s)
108
+ if isinstance(r, list):
109
+ r = list2style(r)
110
+ if s.ndim == 1:
111
+ s = s.unsqueeze(0)
112
+ if r.ndim == 1:
113
+ r = r.unsqueeze(0)
114
+ if q.ndim == 1:
115
+ q = q.unsqueeze(0)
116
+ if len(s) != len(r):
117
+ if s.size(0)< r.size(0):
118
+ s = s.expand(r.size(0), -1)
119
+ else:
120
+ r = r.expand(s.size(0), -1)
121
+ q = q.float()
122
+
123
+ old_norm = (q*s).norm(2,dim=1, keepdim=True)+1e-8
124
+ new_dir = q*r
125
+ new_dir = new_dir/(new_dir.norm(2,dim=1, keepdim=True)+1e-8) * old_norm
126
+ return s -a*q*s + a*new_dir
127
+
128
+
129
+ # convert a style vector [B, 9088] into a suitable format (list) for our generator's input
130
+ def style2list(s):
131
+ output = []
132
+ count = 0
133
+ for size in style2list_len:
134
+ output.append(s[:, count:count+size])
135
+ count += size
136
+ return output
137
+
138
+ # convert the list back to a style vector
139
+ def list2style(s):
140
+ return torch.cat(s, 1)
141
+
142
+ # flatten spatial activations to vectors
143
+ def flatten_act(x):
144
+ b,c,h,w = x.size()
145
+ x = x.pow(2).permute(0,2,3,1).contiguous().view(-1, c) # [b*h*w, c]
146
+ return x.cpu().numpy()
147
+
148
+ def show(imgs, title=None):
149
+
150
+ plt.figure(figsize=(5 * len(imgs), 5))
151
+ if title is not None:
152
+ plt.suptitle(title + '\n', fontsize=24).set_y(1.05)
153
+
154
+ for i in range(len(imgs)):
155
+ plt.subplot(1, len(imgs), i + 1)
156
+ plt.imshow(imgs[i])
157
+ plt.axis('off')
158
+ plt.gca().set_axis_off()
159
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0,
160
+ hspace=0, wspace=0.02)
161
+ plt.savefig(title + '.png', bbox_inches='tight', pad_inches=0)
162
+ def part_grid(target_image, refernce_images, part_images, file_name, score=None):
163
+ def proc(img):
164
+ return (img * 255).permute(1, 2, 0).squeeze().cpu().numpy().astype('uint8')
165
+
166
+ rows, cols = len(part_images) + 1, len(refernce_images) + 1
167
+ fig = plt.figure(figsize=(cols*4, rows*4))
168
+ sz = target_image.shape[-1]
169
+
170
+ i = 1
171
+ plt.subplot(rows, cols, i)
172
+ plt.imshow(proc(target_image[0]))
173
+ plt.axis('off')
174
+ plt.gca().set_axis_off()
175
+ plt.title('Source', fontdict={'size': 26})
176
+
177
+ for img in refernce_images:
178
+ i += 1
179
+ plt.subplot(rows, cols, i)
180
+ plt.imshow(proc(img))
181
+ plt.axis('off')
182
+ plt.gca().set_axis_off()
183
+ plt.title('Reference', fontdict={'size': 26})
184
+
185
+ # plt.text(0, sz, 'Perceptual loss: {:.2f}'.format(score[i-2]), fontdict={'size': 25}, color='red')
186
+ for j, label in enumerate(part_images.keys()):
187
+ i += 1
188
+ plt.subplot(rows, cols, i)
189
+ plt.imshow(proc(target_image[0]) * 0 + 255)
190
+ # plt.text(sz // 2, sz // 2, label.capitalize(), fontdict={'size': 30})
191
+ if score is not None:
192
+ plt.text(0 , sz//6, f'ID: {score[0]:.2f}', fontdict={'size': 30})
193
+ plt.text(0 , sz//6*2, f'Face_LPIPS:{score[1]:.2f}', fontdict={'size': 30})
194
+ plt.text(0 , sz//6*3, f'Hair_LPIPS:{score[2]:.2f}', fontdict={'size': 30})
195
+ plt.text(0 , sz//6*4, f'Total_LPIPS:{score[3]:.2f}', fontdict={'size': 30})
196
+ plt.text(0 , sz//6*5, f'FACE_SSIM: {score[4]:.2f}', fontdict={'size': 30})
197
+ plt.text(0 , sz//6*6, f'Hair_SSIM: {score[5]:.2f}', fontdict={'size': 30})
198
+ plt.text(0 , sz//6*7, f'Total_SSIM: {score[6]:.2f}', fontdict={'size': 30})
199
+
200
+ plt.axis('off')
201
+ plt.gca().set_axis_off()
202
+
203
+ for img in part_images[label]:
204
+ i += 1
205
+ plt.subplot(rows, cols, i)
206
+ plt.imshow(proc(img))
207
+ plt.axis('off')
208
+ plt.gca().set_axis_off()
209
+
210
+ plt.tight_layout(pad=0, w_pad=0, h_pad=0)
211
+ plt.subplots_adjust(wspace=0, hspace=0)
212
+ ## Put 5 lines of text beside the image
213
+ # plt.text(0, sz+5, 'Perceptual loss: {:.2f}'.format(score[i-2]), fontdict={'size': 25}, color='red')
214
+
215
+ plt.savefig(file_name , bbox_inches='tight', pad_inches=0)
216
+ plt.close()
217
+ gc.collect()
218
+ return fig
219
+
220
+
221
+ def display_image(image, size=256, mode='nearest', unnorm=False, title=''):
222
+ # image is [3,h,w] or [1,3,h,w] tensor [0,1]
223
+ if image.is_cuda:
224
+ image = image.cpu()
225
+ if size is not None and image.size(-1) != size:
226
+ image = F.interpolate(image, size=(size,size), mode=mode)
227
+ if image.dim() == 4:
228
+ image = image[0]
229
+ image = ((image.clamp(-1,1)+1)/2).permute(1, 2, 0).detach().numpy()
230
+ plt.figure()
231
+ plt.title(title)
232
+ plt.axis('off')
233
+ plt.imshow(image)
234
+
235
+ def get_parsing_labels():
236
+ color = torch.FloatTensor([[0, 0, 0],
237
+ [128, 0, 0], [0, 128, 0], [128, 128, 0], [0, 0, 128], [128, 0, 128],
238
+ [0, 128, 128], [128, 128, 128], [64, 0, 0], [192, 0, 0], [64, 128, 0],
239
+ [192, 128, 0], [64, 0, 128], [192, 0, 128], [64, 128, 128], [192,128,128],
240
+ [0, 64, 0], [0, 0, 64], [128, 0, 192], [0, 192, 128], [64,128,192], [64,64,64]])
241
+ return (color/255 * 2)-1
242
+
243
+ def decode_segmap(seg):
244
+ seg = seg.float()
245
+ label_colors = get_parsing_labels()
246
+ r = seg.clone()
247
+ g = seg.clone()
248
+ b = seg.clone()
249
+
250
+ for l in range(label_colors.size(0)):
251
+ r[seg == l] = label_colors[l, 0]
252
+ g[seg == l] = label_colors[l, 1]
253
+ b[seg == l] = label_colors[l, 2]
254
+
255
+ output = torch.stack([r,g,b], 1)
256
+ return output
257
+
258
+ def remove_idx(act, i):
259
+ # act [N, 128]
260
+ return torch.cat([act[:i], act[i+1:]], 0)
261
+
262
+ def interpolate_style(s, t, q):
+     if isinstance(s, list):
+         s = list2style(s)
+     if isinstance(t, list):
+         t = list2style(t)
+     if s.ndim == 1:
+         s = s.unsqueeze(0)
+     if t.ndim == 1:
+         t = t.unsqueeze(0)
+     if q.ndim == 1:
+         q = q.unsqueeze(0)
+     if len(s) != len(t):
+         s = s.expand(t.size(0), -1)
+     q = q.float()
+
+     return (1 - q) * s + q * t
+
+ def index_layers(w, i):
+     return [w[j][[i]] for j in range(len(w))]
+
+
+ def normalize_im(x):
+     return (x.clamp(-1, 1) + 1) / 2
+
+ def l2(a, b):
+     return (a - b).pow(2).sum(1)
+
+ def cos_dist(a, b):
+     return -F.cosine_similarity(a, b, 1)
+
+ def downsample(x):
+     return F.interpolate(x, size=(256, 256), mode='bilinear')
+
+ def get_landmark(filepath, predictor):
+     """get landmark with dlib
+     :return: np.array shape=(68, 2)
+     """
+     detector = dlib.get_frontal_face_detector()
+
+     img = dlib.load_rgb_image(filepath)
+     dets = detector(img, 1)
+
+     for k, d in enumerate(dets):
+         shape = predictor(img, d)
+
+     t = list(shape.parts())
+     a = []
+     for tt in t:
+         a.append([tt.x, tt.y])
+     lm = np.array(a)
+     return lm
+
+ def align_face(filepath, predictor, output_size=512):
+     # def align_face(filepath, output_size=512):
+
+     """
+     :param filepath: str
+     :return: PIL Image
+     """
+     ensure_checkpoint_exists("dlibshape_predictor_68_face_landmarks.dat")
+     predictor = dlib.shape_predictor("dlibshape_predictor_68_face_landmarks.dat")
+     lm = get_landmark(filepath, predictor)
+
+     lm_chin = lm[0: 17]  # left-right
+     lm_eyebrow_left = lm[17: 22]  # left-right
+     lm_eyebrow_right = lm[22: 27]  # left-right
+     lm_nose = lm[27: 31]  # top-down
+     lm_nostrils = lm[31: 36]  # top-down
+     lm_eye_left = lm[36: 42]  # left-clockwise
+     lm_eye_right = lm[42: 48]  # left-clockwise
+     lm_mouth_outer = lm[48: 60]  # left-clockwise
+     lm_mouth_inner = lm[60: 68]  # left-clockwise
+
+     # Calculate auxiliary vectors.
+     eye_left = np.mean(lm_eye_left, axis=0)
+     eye_right = np.mean(lm_eye_right, axis=0)
+     eye_avg = (eye_left + eye_right) * 0.5
+     eye_to_eye = eye_right - eye_left
+     mouth_left = lm_mouth_outer[0]
+     mouth_right = lm_mouth_outer[6]
+     mouth_avg = (mouth_left + mouth_right) * 0.5
+     eye_to_mouth = mouth_avg - eye_avg
+
+     # Choose oriented crop rectangle.
+     x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
+     x /= np.hypot(*x)
+     x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
+     y = np.flipud(x) * [-1, 1]
+     c = eye_avg + eye_to_mouth * 0.1
+     quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
+     qsize = np.hypot(*x) * 2
+
+     # read image
+     img = Image.open(filepath)
+
+     transform_size = output_size
+     enable_padding = True
+
+     # Shrink.
+     shrink = int(np.floor(qsize / output_size * 0.5))
+     if shrink > 1:
+         rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+         img = img.resize(rsize, Image.ANTIALIAS)
+         quad /= shrink
+         qsize /= shrink
+
+     # Crop.
+     border = max(int(np.rint(qsize * 0.1)), 3)
+     crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+             int(np.ceil(max(quad[:, 1]))))
+     crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+             min(crop[3] + border, img.size[1]))
+     if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+         img = img.crop(crop)
+         quad -= crop[0:2]
+
+     # Pad.
+     pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+            int(np.ceil(max(quad[:, 1]))))
+     pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+            max(pad[3] - img.size[1] + border, 0))
+     if enable_padding and max(pad) > border - 4:
+         pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+         img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+         h, w, _ = img.shape
+         y, x, _ = np.ogrid[:h, :w, :1]
+         mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+                           1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+         blur = qsize * 0.02
+         img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+         img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+         img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+         quad += pad[:2]
+
+     # Transform.
+     img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(), Image.BILINEAR)
+     if output_size < transform_size:
+         img = img.resize((output_size, output_size), Image.ANTIALIAS)
+
+     # Return aligned image.
+     return img
+
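A quick illustration of what interpolate_style computes: a per-channel convex blend of a source style vector s and a target style vector t, where q can be a scalar weight or a per-channel mask. The tensors and values below are made up purely for illustration; only the blend formula comes from the code above.

    import torch

    s = torch.zeros(1, 6)   # hypothetical source style vector
    t = torch.ones(1, 6)    # hypothetical target style vector

    # q decides, per channel, how much of the target to take:
    # channels 0-2 keep the source, channels 3-5 take the target
    q = torch.tensor([0., 0., 0., 1., 1., 1.])

    blended = (1 - q) * s + q * t   # same formula as interpolate_style
    print(blended)                  # tensor([[0., 0., 0., 1., 1., 1.]])
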
ris/wrapper.py ADDED
@@ -0,0 +1,199 @@
+ import time
+ import shutil
+
+ import dlib
+ import numpy as np
+ import PIL.Image
+ import torch
+ from torchvision.transforms import transforms
+
+ import dnnlib
+ import legacy
+ from configs_gd import GENERATOR_CONFIGS
+ from dlib_utils.face_alignment import image_align
+ from dlib_utils.landmarks_detector import LandmarksDetector
+ from torch_utils.misc import copy_params_and_buffers
+
+ from pivot_tuning_inversion.utils.ImagesDataset import ImagesDataset, ImageLatentsDataset
+ from pivot_tuning_inversion.training.coaches.multi_id_coach import MultiIDCoach
+
+
+ class FaceLandmarksDetector:
+     """Dlib landmarks detector wrapper
+     """
+     def __init__(
+         self,
+         model_path='pretrained/shape_predictor_68_face_landmarks.dat',
+         tmp_dir='tmp'
+     ):
+
+         self.detector = LandmarksDetector(model_path)
+         self.timestamp = int(time.time())
+         self.tmp_src = f'{tmp_dir}/{self.timestamp}_src.png'
+         self.tmp_align = f'{tmp_dir}/{self.timestamp}_align.png'
+
+     def __call__(self, imgpath):
+         shutil.copy(imgpath, self.tmp_src)
+         try:
+             face_landmarks = list(self.detector.get_landmarks(self.tmp_src))[0]
+             assert isinstance(face_landmarks, list)
+             assert len(face_landmarks) == 68
+             image_align(self.tmp_src, self.tmp_align, face_landmarks)
+         except:
+             im = PIL.Image.open(self.tmp_src)
+             im.save(self.tmp_align)
+         return PIL.Image.open(self.tmp_align).convert('RGB')
+
+
+ class VGGFeatExtractor():
+     """VGG16 backbone wrapper
+     """
+     def __init__(self, device):
+         self.device = device
+         self.url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
+         with dnnlib.util.open_url(self.url) as f:
+             self.module = torch.jit.load(f).eval().to(device)
+
+     def __call__(self, img):  # PIL
+         img = self._preprocess(img, self.device)
+         feat = self.module(img)
+         return feat  # (1, 1000)
+
+     def _preprocess(self, img, device):
+         img = img.resize((256, 256), PIL.Image.LANCZOS)
+         img = np.array(img, dtype=np.uint8)
+         img = torch.tensor(img.transpose([2, 0, 1])).unsqueeze(dim=0)
+         return img.to(device)
+
+
+ class Generator_wrapper():
+     """StyleGAN2 generator wrapper
+     """
+     def __init__(self, ckpt, device):
+         with dnnlib.util.open_url(ckpt) as f:
+             old_G = legacy.load_network_pkl(f)['G_ema'].requires_grad_(False).to(device)
+         resolution = old_G.img_resolution
+         generator_config = GENERATOR_CONFIGS(resolution=resolution)
+         self.G_kwargs = generator_config.G_kwargs
+         self.common_kwargs = generator_config.common_kwargs
+
+         self.G = dnnlib.util.construct_class_by_name(**self.G_kwargs, **self.common_kwargs).eval().requires_grad_(False).to(device)
+         copy_params_and_buffers(old_G, self.G, require_all=False)
+         del old_G
+         G = self.G
+
+         self.style_layers = [
+             f'G.synthesis.b{feat_size}.{layer}.affine'
+             for feat_size in [pow(2, x) for x in range(2, int(np.log2(resolution)) + 1)]
+             for layer in ['conv0', 'conv1', 'torgb']]
+         del self.style_layers[0]
+         scope = locals()
+         self.to_stylespace = {layer: eval(layer, scope) for layer in self.style_layers}
+         w_idx_lst = generator_config.w_idx_lst
+         assert len(self.style_layers) == len(w_idx_lst)
+         self.to_w_idx = {self.style_layers[i]: w_idx_lst[i] for i in range(len(self.style_layers))}
+
+     def mapping(self, z, truncation_psi=0.7, truncation_cutoff=None, skip_w_avg_update=False):
+         '''random z -> latent w
+         '''
+         return self.G.mapping(
+             z,
+             None,
+             truncation_psi=truncation_psi,
+             truncation_cutoff=truncation_cutoff,
+             skip_w_avg_update=skip_w_avg_update
+         )
+
+     def mapping_stylespace(self, latent):
+         '''latent w -> style s
+         resolution | w_idx | # conv | # torgb | indices
+                  4 |     0 |      1 |       1 | 0-1
+                  8 |     1 |      2 |       1 | 1-3
+                 16 |     3 |      2 |       1 | 3-5
+                 32 |     5 |      2 |       1 | 5-7
+                 64 |     7 |      2 |       1 | 7-9
+                128 |     9 |      2 |       1 | 9-11
+                256 |    11 |      2 |       1 | 11-13  # for 256 resolution
+                512 |    13 |      2 |       1 | 13-15  # for 512 resolution
+               1024 |    15 |      2 |       1 | 15-17  # for 1024 resolution
+         '''
+         styles = dict()
+         for layer in self.style_layers:
+             module = self.to_stylespace.get(layer)
+             w_idx = self.to_w_idx.get(layer)
+             styles[layer] = module(latent.unbind(dim=1)[w_idx])
+         return styles
+
+     def synthesis_from_stylespace(self, latent, styles):
+         '''style s -> generated image
+         modulated conv2d, synthesis layer.weight, noise
+         forward after styles = affine(w)
+         '''
+         return self.G.synthesis(latent, styles=styles, noise_mode='const')
+
+     def synthesis(self, latent):
+         '''latent w -> generated image
+         '''
+         return self.G.synthesis(latent, noise_mode='const')
+
+
+ class e4eEncoder:
+     '''e4e Encoder
+     img paths -> latent w
+     '''
+     def __init__(self, device):
+         self.device = device
+
+     def __call__(self, target_pils):
+         dataset = ImagesDataset(
+             target_pils,
+             self.device,
+             transforms.Compose([
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+         )
+         dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
+
+         coach = MultiIDCoach(dataloader, device=self.device)
+         latents = list()
+         for fname, image in dataloader:
+             latents.append(coach.get_e4e_inversion(image))
+         latents = torch.cat(latents)
+         return latents
+
+
+ class PivotTuning:
+     '''pivot tuning inversion
+     latent, style -> latent, style
+
+     mode
+     - 'w' : use the latent (w) pivot
+     - 's' : use the style pivot
+     '''
+     def __init__(self, device, G, mode='w'):
+         assert mode in ['w', 's']
+         self.device = device
+         self.G = G
+         self.mode = mode
+         self.resolution = G.img_resolution
+
+     def __call__(self, latent, target_pils):
+         dataset = ImageLatentsDataset(
+             target_pils,
+             latent,
+             self.device,
+             transforms.Compose([
+                 transforms.ToTensor(),
+                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
+             self.resolution,
+         )
+         dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
+         coach = MultiIDCoach(
+             dataloader,
+             device=self.device,
+             generator=self.G,
+             mode=self.mode
+         )
+         # run coach by self.mode
+         new_G = coach.train_from_latent()
+         return new_G
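For orientation, a minimal sketch of how the generator wrapper above is meant to be driven: z is mapped to w, w is mapped to per-layer style vectors, and the styles are decoded back to an image. This assumes the RIS repository modules are importable and that a StyleGAN2-ADA pickle exists at the path shown; the checkpoint path and device string are placeholders, and z_dim comes from the underlying StyleGAN2-ADA generator.

    import torch
    from ris.wrapper import Generator_wrapper

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator = Generator_wrapper('pretrained/ffhq.pkl', device)   # hypothetical checkpoint path

    z = torch.randn([1, generator.G.z_dim], device=device)
    w = generator.mapping(z, truncation_psi=0.7)        # z -> w
    styles = generator.mapping_stylespace(w)            # w -> per-layer styles
    img = generator.synthesis_from_stylespace(w, styles)  # styles -> image (RIS-modified synthesis)
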
spherical_kmeans.py ADDED
@@ -0,0 +1,391 @@
+ import warnings
+ import numpy as np
+ from sklearn.preprocessing import normalize
+ from sklearn.utils.sparsefuncs_fast import assign_rows_csr
+ from sklearn.utils.validation import _check_sample_weight
+ from sklearn.utils import check_array, check_random_state
+ from sklearn.utils.extmath import row_norms
+ import scipy.sparse as sp
+ import sklearn.cluster  # makes the fully-qualified reference in the sparse branch below resolvable (needs an older scikit-learn where sklearn.cluster.k_means_ exists)
+ from sklearn.cluster import MiniBatchKMeans
+ from sklearn.cluster.k_means_ import (
+     _init_centroids,
+     _labels_inertia,
+     _tolerance,
+     _mini_batch_step,
+     _mini_batch_convergence
+ )
+
+
+ def _check_normalize_sample_weight(sample_weight, X):
+     """Set sample_weight if None, and check for correct dtype"""
+
+     sample_weight_was_none = sample_weight is None
+
+     sample_weight = _check_sample_weight(sample_weight, X)
+
+     if not sample_weight_was_none:
+         # normalize the weights to sum up to n_samples
+         # an array of 1 (i.e. samples_weight is None) is already normalized
+         n_samples = len(sample_weight)
+         scale = n_samples / sample_weight.sum()
+         sample_weight *= scale
+     return sample_weight
+
+
+ def _mini_batch_spherical_step(X, sample_weight, x_squared_norms, centers, weight_sums,
+                                old_center_buffer, compute_squared_diff,
+                                distances, random_reassign=False,
+                                random_state=None, reassignment_ratio=.01,
+                                verbose=False):
+     """Incremental update of the centers for the Minibatch K-Means algorithm.
+     Parameters
+     ----------
+     X : array, shape (n_samples, n_features)
+         The original data array.
+     sample_weight : array-like, shape (n_samples,)
+         The weights for each observation in X.
+     x_squared_norms : array, shape (n_samples,)
+         Squared euclidean norm of each data point.
+     centers : array, shape (k, n_features)
+         The cluster centers. This array is MODIFIED IN PLACE
+     counts : array, shape (k,)
+         The vector in which we keep track of the numbers of elements in a
+         cluster. This array is MODIFIED IN PLACE
+     distances : array, dtype float, shape (n_samples), optional
+         If not None, should be a pre-allocated array that will be used to store
+         the distances of each sample to its closest center.
+         May not be None when random_reassign is True.
+     random_state : int, RandomState instance or None (default)
+         Determines random number generation for centroid initialization and to
+         pick new clusters amongst observations with uniform probability. Use
+         an int to make the randomness deterministic.
+         See :term:`Glossary <random_state>`.
+     random_reassign : boolean, optional
+         If True, centers with very low counts are randomly reassigned
+         to observations.
+     reassignment_ratio : float, optional
+         Control the fraction of the maximum number of counts for a
+         center to be reassigned. A higher value means that low count
+         centers are more likely to be reassigned, which means that the
+         model will take longer to converge, but should converge in a
+         better clustering.
+     verbose : bool, optional, default False
+         Controls the verbosity.
+     compute_squared_diff : bool
+         If set to False, the squared diff computation is skipped.
+     old_center_buffer : int
+         Copy of old centers for monitoring convergence.
+     Returns
+     -------
+     inertia : float
+         Sum of squared distances of samples to their closest cluster center.
+     squared_diff : numpy array, shape (n_clusters,)
+         Squared distances between previous and updated cluster centers.
+     """
+     # Perform label assignment to nearest centers
+     nearest_center, inertia = _labels_inertia(X, sample_weight,
+                                               x_squared_norms, centers,
+                                               distances=distances)
+
+     if random_reassign and reassignment_ratio > 0:
+         random_state = check_random_state(random_state)
+         # Reassign clusters that have very low weight
+         to_reassign = weight_sums < reassignment_ratio * weight_sums.max()
+         # pick at most .5 * batch_size samples as new centers
+         if to_reassign.sum() > .5 * X.shape[0]:
+             indices_dont_reassign = \
+                 np.argsort(weight_sums)[int(.5 * X.shape[0]):]
+             to_reassign[indices_dont_reassign] = False
+         n_reassigns = to_reassign.sum()
+         if n_reassigns:
+             # Pick new clusters amongst observations with uniform probability
+             new_centers = random_state.choice(X.shape[0], replace=False,
+                                               size=n_reassigns)
+             if verbose:
+                 print("[MiniBatchKMeans] Reassigning %i cluster centers."
+                       % n_reassigns)
+
+             if sp.issparse(X) and not sp.issparse(centers):
+                 assign_rows_csr(
+                     X, new_centers.astype(np.intp, copy=False),
+                     np.where(to_reassign)[0].astype(np.intp, copy=False),
+                     centers)
+             else:
+                 centers[to_reassign] = X[new_centers]
+         # reset counts of reassigned centers, but don't reset them too small
+         # to avoid instant reassignment. This is a pretty dirty hack as it
+         # also modifies the learning rates.
+         weight_sums[to_reassign] = np.min(weight_sums[~to_reassign])
+
+     # implementation for the sparse CSR representation completely written in
+     # cython
+     if sp.issparse(X):
+         return inertia, sklearn.cluster.k_means_._k_means._mini_batch_update_csr(
+             X, sample_weight, x_squared_norms, centers, weight_sums,
+             nearest_center, old_center_buffer, compute_squared_diff)
+
+     # dense variant in mostly numpy (not as memory efficient though)
+     k = centers.shape[0]
+     squared_diff = 0.0
+     for center_idx in range(k):
+         # find points from minibatch that are assigned to this center
+         center_mask = nearest_center == center_idx
+         wsum = sample_weight[center_mask].sum()
+
+         if wsum > 0:
+             if compute_squared_diff:
+                 old_center_buffer[:] = centers[center_idx]
+
+             # inplace remove previous count scaling
+             centers[center_idx] *= weight_sums[center_idx]
+
+             # inplace sum with new points members of this cluster
+             centers[center_idx] += \
+                 np.sum(X[center_mask] *
+                        sample_weight[center_mask, np.newaxis], axis=0)
+
+             # unit-normalize for spherical k-means
+             centers[center_idx] = normalize(centers[center_idx, None])[:, 0]
+
+             # update the squared diff if necessary
+             if compute_squared_diff:
+                 diff = centers[center_idx].ravel() - old_center_buffer.ravel()
+                 squared_diff += np.dot(diff, diff)
+
+     return inertia, squared_diff
+
+
+ class MiniBatchSphericalKMeans(MiniBatchKMeans):
+
+     def fit(self, X, y=None, sample_weight=None):
+         """Compute the centroids on X by chunking it into mini-batches.
+         Parameters
+         ----------
+         X : array-like or sparse matrix, shape=(n_samples, n_features)
+             Training instances to cluster. It must be noted that the data
+             will be converted to C ordering, which will cause a memory copy
+             if the given data is not C-contiguous.
+         y : Ignored
+             Not used, present here for API consistency by convention.
+         sample_weight : array-like, shape (n_samples,), optional
+             The weights for each observation in X. If None, all observations
+             are assigned equal weight (default: None).
+         Returns
+         -------
+         self
+         """
+         random_state = check_random_state(self.random_state)
+         # unit-normalize for spherical k-means
+         X = normalize(X)
+         X = check_array(X, accept_sparse="csr", order='C',
+                         dtype=[np.float64, np.float32])
+         n_samples, n_features = X.shape
+         if n_samples < self.n_clusters:
+             raise ValueError("n_samples=%d should be >= n_clusters=%d"
+                              % (n_samples, self.n_clusters))
+
+         sample_weight = _check_normalize_sample_weight(sample_weight, X)
+
+         n_init = self.n_init
+         if hasattr(self.init, '__array__'):
+             self.init = np.ascontiguousarray(self.init, dtype=X.dtype)
+             if n_init != 1:
+                 warnings.warn(
+                     'Explicit initial center position passed: '
+                     'performing only one init in MiniBatchKMeans instead of '
+                     'n_init=%d'
+                     % self.n_init, RuntimeWarning, stacklevel=2)
+                 n_init = 1
+
+         x_squared_norms = row_norms(X, squared=True)
+
+         if self.tol > 0.0:
+             tol = _tolerance(X, self.tol)
+
+             # using tol-based early stopping needs the allocation of a
+             # dedicated buffer, which can be expensive for high dim data:
+             # hence we allocate it outside of the main loop
+             old_center_buffer = np.zeros(n_features, dtype=X.dtype)
+         else:
+             tol = 0.0
+             # no need for the center buffer if tol-based early stopping is
+             # disabled
+             old_center_buffer = np.zeros(0, dtype=X.dtype)
+
+         distances = np.zeros(self.batch_size, dtype=X.dtype)
+         n_batches = int(np.ceil(float(n_samples) / self.batch_size))
+         n_iter = int(self.max_iter * n_batches)
+
+         init_size = self.init_size
+         if init_size is None:
+             init_size = 3 * self.batch_size
+         if init_size > n_samples:
+             init_size = n_samples
+         self.init_size_ = init_size
+
+         validation_indices = random_state.randint(0, n_samples, init_size)
+         X_valid = X[validation_indices]
+         sample_weight_valid = sample_weight[validation_indices]
+         x_squared_norms_valid = x_squared_norms[validation_indices]
+
+         # perform several inits with random sub-sets
+         best_inertia = None
+         for init_idx in range(n_init):
+             if self.verbose:
+                 print("Init %d/%d with method: %s"
+                       % (init_idx + 1, n_init, self.init))
+             weight_sums = np.zeros(self.n_clusters, dtype=sample_weight.dtype)
+
+             # TODO: once the `k_means` function works with sparse input we
+             # should refactor the following init to use it instead.
+
+             # Initialize the centers using only a fraction of the data as we
+             # expect n_samples to be very large when using MiniBatchKMeans
+             cluster_centers = _init_centroids(
+                 X, self.n_clusters, self.init,
+                 random_state=random_state,
+                 x_squared_norms=x_squared_norms,
+                 init_size=init_size)
+
+             cluster_centers = normalize(cluster_centers)
+
+             # Compute the label assignment on the init dataset
+             _mini_batch_step(
+                 X_valid, sample_weight_valid,
+                 x_squared_norms[validation_indices], cluster_centers,
+                 weight_sums, old_center_buffer, False, distances=None,
+                 verbose=self.verbose)
+
+             cluster_centers = normalize(cluster_centers)
+
+             # Keep only the best cluster centers across independent inits on
+             # the common validation set
+             _, inertia = _labels_inertia(X_valid, sample_weight_valid,
+                                          x_squared_norms_valid,
+                                          cluster_centers)
+             if self.verbose:
+                 print("Inertia for init %d/%d: %f"
+                       % (init_idx + 1, n_init, inertia))
+             if best_inertia is None or inertia < best_inertia:
+                 self.cluster_centers_ = cluster_centers
+                 self.counts_ = weight_sums
+                 best_inertia = inertia
+
+         # Empty context to be used inplace by the convergence check routine
+         convergence_context = {}
+
+         # Perform the iterative optimization until the final convergence
+         # criterion
+         for iteration_idx in range(n_iter):
+             # Sample a minibatch from the full dataset
+             minibatch_indices = random_state.randint(
+                 0, n_samples, self.batch_size)
+
+             # Perform the actual update step on the minibatch data
+             self.cluster_centers_ = normalize(self.cluster_centers_)
+             batch_inertia, centers_squared_diff = _mini_batch_step(
+                 X[minibatch_indices], sample_weight[minibatch_indices],
+                 x_squared_norms[minibatch_indices],
+                 self.cluster_centers_, self.counts_,
+                 old_center_buffer, tol > 0.0, distances=distances,
+                 # Here we randomly choose whether to perform
+                 # random reassignment: the choice is done as a function
+                 # of the iteration index, and the minimum number of
+                 # counts, in order to force this reassignment to happen
+                 # every once in a while
+                 random_reassign=((iteration_idx + 1)
+                                  % (10 + int(self.counts_.min())) == 0),
+                 random_state=random_state,
+                 reassignment_ratio=self.reassignment_ratio,
+                 verbose=self.verbose)
+             self.cluster_centers_ = normalize(self.cluster_centers_)
+
+             # Monitor convergence and do early stopping if necessary
+             if _mini_batch_convergence(
+                     self, iteration_idx, n_iter, tol, n_samples,
+                     centers_squared_diff, batch_inertia, convergence_context,
+                     verbose=self.verbose):
+                 break
+
+         self.n_iter_ = iteration_idx + 1
+
+         if self.compute_labels:
+             self.labels_, self.inertia_ = \
+                 self._labels_inertia_minibatch(X, sample_weight)
+
+         return self
+
+     def partial_fit(self, X, y=None, sample_weight=None):
+         """Update k means estimate on a single mini-batch X.
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Coordinates of the data points to cluster. It must be noted that
+             X will be copied if it is not C-contiguous.
+         y : Ignored
+             Not used, present here for API consistency by convention.
+         sample_weight : array-like, shape (n_samples,), optional
+             The weights for each observation in X. If None, all observations
+             are assigned equal weight (default: None).
+         Returns
+         -------
+         self
+         """
+
+         X = check_array(X, accept_sparse="csr", order="C",
+                         dtype=[np.float64, np.float32])
+         n_samples, n_features = X.shape
+         if hasattr(self.init, '__array__'):
+             self.init = np.ascontiguousarray(self.init, dtype=X.dtype)
+
+         if n_samples == 0:
+             return self
+
+         # unit-normalize for spherical k-means
+         X = normalize(X)
+
+         sample_weight = _check_normalize_sample_weight(sample_weight, X)
+
+         x_squared_norms = row_norms(X, squared=True)
+         self.random_state_ = getattr(self, "random_state_",
+                                      check_random_state(self.random_state))
+         if (not hasattr(self, 'counts_')
+                 or not hasattr(self, 'cluster_centers_')):
+             # this is the first call partial_fit on this object:
+             # initialize the cluster centers
+             self.cluster_centers_ = _init_centroids(
+                 X, self.n_clusters, self.init,
+                 random_state=self.random_state_,
+                 x_squared_norms=x_squared_norms, init_size=self.init_size)
+
+             self.counts_ = np.zeros(self.n_clusters,
+                                     dtype=sample_weight.dtype)
+             random_reassign = False
+             distances = None
+         else:
+             # The lower the minimum count is, the more we do random
+             # reassignment, however, we don't want to do random
+             # reassignment too often, to allow for building up counts
+             random_reassign = self.random_state_.randint(
+                 10 * (1 + self.counts_.min())) == 0
+             distances = np.zeros(X.shape[0], dtype=X.dtype)
+
+         self.cluster_centers_ = normalize(self.cluster_centers_)
+
+         _mini_batch_spherical_step(X, sample_weight, x_squared_norms,
+                                    self.cluster_centers_, self.counts_,
+                                    np.zeros(0, dtype=X.dtype), 0,
+                                    random_reassign=random_reassign, distances=distances,
+                                    random_state=self.random_state_,
+                                    reassignment_ratio=self.reassignment_ratio,
+                                    verbose=self.verbose)
+         self.cluster_centers_ = normalize(self.cluster_centers_)
+
+         if self.compute_labels:
+             self.labels_, self.inertia_ = _labels_inertia(
+                 X, sample_weight, x_squared_norms, self.cluster_centers_)
+
+         return self
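A minimal usage sketch for the class above: it clusters unit-normalized vectors, so centers end up on the unit sphere and assignments effectively follow cosine similarity. The data and hyperparameters below are placeholders, and the module relies on sklearn's old private k_means_ API, so it only runs on an older scikit-learn release where sklearn.cluster.k_means_ is importable.

    import numpy as np
    from spherical_kmeans import MiniBatchSphericalKMeans

    # 1000 random 64-dim feature vectors; fit() normalizes them to the unit sphere
    X = np.random.RandomState(0).randn(1000, 64).astype(np.float32)

    km = MiniBatchSphericalKMeans(n_clusters=10, batch_size=100, random_state=0)
    km.fit(X)

    print(km.cluster_centers_.shape)   # (10, 64); each row has unit L2 norm
    print(km.labels_[:10])             # cluster assignment per sample
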