haakohu committed on
Commit
a9f29d9
1 Parent(s): 5d756f1
Files changed (1)
  1. stylemc.py +295 -0
stylemc.py ADDED
"""
Approach: "StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation"
Original source code:
https://github.com/autonomousvision/stylegan_xl/blob/f9be58e98110bd946fcdadef2aac8345466faaf3/run_stylemc.py#
Modified by Håkon Hukkelås
"""
import re
from pathlib import Path
from timeit import default_timer as timer
from typing import List

import click
import clip
import numpy as np
import PIL.Image
import tops
import torch
import torch.nn.functional as F
import tqdm
from torchvision.transforms.functional import resize, normalize

from dp2 import utils
from dp2.infer import build_trained_generator

#----------------------------------------------------------------------------

class AverageMeter:
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
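
# Usage sketch (illustrative, not part of the original script):
#   meter = AverageMeter('loss', ':.3f')
#   meter.update(0.5); meter.update(0.3)
#   str(meter)  # -> 'loss 0.300 (0.400)'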


class ProgressMeter:
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def save_image(img, path):
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(path)


def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))
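
# Example (illustrative): unravel_index(5, (2, 3)) -> (1, 2), i.e. the
# row/column of flat index 5 in a 2x3 grid.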


def num_range(s: str) -> List[int]:
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2)) + 1))
    vals = s.split(',')
    return [int(x) for x in vals]
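
# Examples (illustrative): num_range('1,3,5') -> [1, 3, 5]; num_range('2-5') -> [2, 3, 4, 5].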


#----------------------------------------------------------------------------


def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1:  # Keeps consistent results vs previous method for single objective guidance
        return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)
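
# Note: for unit-norm embeddings separated by angle theta, ||x - y|| = 2*sin(theta/2),
# so spherical_dist_loss(x, y) = 2*(theta/2)^2 = theta^2/2, i.e. half the squared
# geodesic (great-circle) distance on the unit sphere.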


def embed_text(model, prompt, device='cuda'):
    # Minimal implementation of this previously empty stub, mirroring the CLIP
    # text-encoding call used in find_direction below.
    return model.encode_text(clip.tokenize(prompt).to(device)).float()


#----------------------------------------------------------------------------


@torch.no_grad()
@torch.cuda.amp.autocast()
def generate_edit(
    G,
    dl,
    direction,
    edit_strength,
    path,
):
    for it, batch in enumerate(dl):
        batch["embedding"] = None
        styles = get_styles(None, G, batch, truncation_value=0)
        imgs = []
        # Symmetric sweep of edit strengths from -edit_strength to +edit_strength
        grad_changes = [c * edit_strength for c in [0, 0.25, 0.5, 0.75, 1]]
        grad_changes = [*[-x for x in grad_changes][::-1], *grad_changes]
        batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batch.items()}
        for grad_change in grad_changes:
            s = styles + direction * grad_change
            img = G(**batch, s=iter(s))["img"]
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
            imgs.append(img[0].to(torch.uint8).cpu().numpy())
        PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(path + f'{it}.png')


@torch.no_grad()
def get_styles(seed, G: torch.nn.Module, batch, truncation_value=1):
    all_styles = []
    if seed is None:
        # std=0 gives a deterministic (all-zero) latent; callers combine this with
        # truncation_value=0, in which case w collapses to the selected w-center below.
        z = np.random.normal(0, 0, size=(1, G.z_channels))
    else:
        z = np.random.RandomState(seed=seed).normal(0, 1, size=(1, G.z_channels))
    z_idx = np.random.RandomState(seed=seed).randint(0, len(G.style_net.w_centers))
    w_c = G.style_net.w_centers[z_idx].to(tops.get_device()).view(1, -1)
    w = G.style_net(torch.from_numpy(z).to(tops.get_device()))
    w = w_c.to(w.dtype).lerp(w, truncation_value)
    if hasattr(G, "get_comod_y"):
        w = G.get_comod_y(batch, w)
    # Collect the style-space activations (per-block affine outputs) of all synthesis blocks.
    for block in G.modules():
        if not hasattr(block, "affine") or not hasattr(block.affine, "weight"):
            continue
        gamma0 = block.affine(w)
        if hasattr(block, "affine_beta"):
            beta0 = block.affine_beta(w)
            gamma0 = torch.cat((gamma0, beta0), dim=1)
        all_styles.append(gamma0)
    # Zero-pad every per-block style vector to the widest one so they stack into a single tensor.
    max_ch = max(s.shape[-1] for s in all_styles)
    all_styles = [F.pad(s, (0, max_ch - s.shape[-1]), "constant", 0) for s in all_styles]
    all_styles = torch.cat(all_styles)
    return all_styles


def get_and_cache_direction(output_dir: Path, dl_val, G, text_prompt):
    cache_path = output_dir.joinpath(
        "stylemc_cache", text_prompt.replace(" ", "_") + ".torch")
    if cache_path.is_file():
        print("Loaded cache from:", cache_path)
        return torch.load(cache_path)
    direction = find_direction(G, text_prompt, None, dl_val=iter(dl_val))
    cache_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(direction, cache_path)
    return direction
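
# Usage sketch (hypothetical paths and prompt; G and dl_val built as in stylemc() below):
#   direction = get_and_cache_direction(Path("outputs"), dl_val, G, "a person with glasses")
#   generate_edit(G, iter(dl_val), direction, edit_strength=2.0, path="outputs/edit_")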


@torch.cuda.amp.autocast()
def find_direction(
    G,
    text_prompt,
    batches,
    # layers,
    n_iterations=128*8,
    batch_size=8,
    dl_val=None
):
    time_start = timer()
    clip_model = clip.load("ViT-B/16", device=tops.get_device())[0]
    target = [clip_model.encode_text(clip.tokenize(text_prompt).to(tops.get_device())).float()]
    if dl_val is not None:
        first_batch = next(dl_val)
    else:
        first_batch = batches[0]
    first_batch["embedding"] = None if "embedding" not in first_batch else first_batch["embedding"]
    s = get_styles(0, G, first_batch)

    # stats trackers
    cos_sim_track = AverageMeter('cos_sim', ':.4f')
    norm_track = AverageMeter('norm', ':.4f')
    n_iterations = n_iterations // batch_size
    progress = ProgressMeter(n_iterations, [cos_sim_track, norm_track])

    # initialize the style-space direction
    direction = torch.zeros(s.shape, device=tops.get_device())
    direction.requires_grad_()
    utils.set_requires_grad(G, False)
    direction_tracker = torch.zeros_like(direction)
    opt = torch.optim.AdamW([direction], lr=0.05, betas=(0., 0.999), weight_decay=0.25)

    grads = []
    for seed_idx in tqdm.trange(n_iterations):
        # forward pass through the synthesis network with the shifted styles
        if seed_idx == 0:
            batch = first_batch
        elif dl_val is not None:
            batch = next(dl_val)
            batch["embedding"] = None if "embedding" not in batch else batch["embedding"]
        else:
            batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batches[seed_idx].items()}
        styles = get_styles(seed_idx, G, batch) + direction
        img = G(**batch, s=iter(styles))["img"]
        batch = {k: v.cpu() if v is not None else v for k, v in batch.items()}

        # CLIP loss: embed the generated image and compare it to the text embedding
        img = (img + 1) / 2
        img = normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        img = resize(img, (224, 224))
        embeds = clip_model.encode_image(img)
        cos_sim = prompts_dist_loss(embeds, target, spherical_dist_loss)
        cos_sim.backward(retain_graph=True)

        # track stats
        cos_sim_track.update(cos_sim.item())
        norm_track.update(torch.norm(direction).item())

        if not (seed_idx % batch_size):
            # zeroing out gradients for non-optimized layers
            # layers_zeroed = torch.tensor([x for x in range(G.num_ws) if x not in layers])
            # direction.grad[:, layers_zeroed] = 0
            opt.step()
            grads.append(direction.grad.clone())
            direction.grad.data.zero_()

            # count channels whose gradient sign keeps flipping between steps
            if seed_idx > 3:
                direction_tracker[grads[-2] * grads[-1] < 0] += 1

            # plot stats
            progress.display(seed_idx)

    # throw out fluctuating channels
    direction = direction.detach()
    direction[direction_tracker > n_iterations / 4] = 0
    print(direction)
    print(f"Time for direction search: {timer() - time_start:.2f} s")
    return direction


@click.command()
@click.argument("config_path")
@click.argument("input_path")
@click.argument("output_path")
#@click.option('--layers', type=num_range, help='Restrict the style space to a range of layers. We recommend not optimizing the critically sampled layers (last 3).', required=True)
@click.option('--text-prompt', help='Text', type=str, required=True)
@click.option('--edit-strength', help='Strength of edit', type=float, required=True)
@click.option('--outdir', help='Where to save the output images', type=str, required=True)
def stylemc(
    config_path,
    input_path,
    output_path,
    # layers: List[int],
    text_prompt: str,
    edit_strength: float,
    outdir: str,
):
    cfg = utils.load_config(config_path)
    G = build_trained_generator(cfg)
    cfg.train.batch_size = 1
    n_iterations = 256
    dl_val = tops.config.instantiate(cfg.data.val.loader)

    direction = find_direction(G, text_prompt, None, n_iterations=n_iterations, dl_val=iter(dl_val))

    text_prompt = text_prompt.replace(" ", "_")
    # generate_edit iterates a dataloader, so a fresh iterator over the validation
    # loader is passed here (input_path is currently unused).
    generate_edit(G, iter(dl_val), direction, edit_strength, output_path)


if __name__ == "__main__":
    stylemc()
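
# Example invocation (hypothetical config and paths, adjust to your setup;
# output_path is used as a filename prefix for the saved image strips):
#   python stylemc.py configs/my_config.py data/input outputs/edit_ \
#       --text-prompt "a person with a beard" --edit-strength 2.0 --outdir outputs/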