initial
stylemc.py +295 -0
stylemc.py
ADDED
@@ -0,0 +1,295 @@
"""
Approach: "StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation"
Original source code:
https://github.com/autonomousvision/stylegan_xl/blob/f9be58e98110bd946fcdadef2aac8345466faaf3/run_stylemc.py#
Modified by Håkon Hukkelås
"""
import os
from pathlib import Path
import tqdm
import re
import click
from dp2 import utils
import tops
from typing import List, Optional
import PIL.Image
import imageio
from timeit import default_timer as timer

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import resize, normalize
from dp2.infer import build_trained_generator
import clip

#----------------------------------------------------------------------------

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'


def save_image(img, path):
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(path)


def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))


def num_range(s: str) -> List[int]:
    '''Accept either a comma separated list of numbers 'a,b,c' or a range 'a-c' and return as a list of ints.'''
    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2)) + 1))
    vals = s.split(',')
    return [int(x) for x in vals]


#----------------------------------------------------------------------------

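# Note on the two loss helpers below: for unit-normalized embeddings x and y,
# ||x - y|| = 2*sin(theta/2) where theta is the angle between them, so
# arcsin(||x - y|| / 2) recovers theta/2 and spherical_dist_loss returns
# 2*(theta/2)**2 = theta**2 / 2, i.e. half the squared geodesic distance on the
# unit sphere. prompts_dist_loss sums this distance over the text targets when
# more than one prompt is given.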
def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)


def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1:  # Keeps consistent results vs previous method for single objective guidance
        return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)


def embed_text(model, prompt, device='cuda'):
    # Embed a text prompt with CLIP (same call as used in find_direction).
    return model.encode_text(clip.tokenize(prompt).to(device)).float()


#----------------------------------------------------------------------------

@torch.no_grad()
@torch.cuda.amp.autocast()
def generate_edit(
    G,
    dl,
    direction,
    edit_strength,
    path,
):
    for it, batch in enumerate(dl):
        batch["embedding"] = None
        styles = get_styles(None, G, batch, truncation_value=0)
        imgs = []
        # Edit strengths from -edit_strength to +edit_strength, mirrored around 0
        grad_changes = [_ * edit_strength for _ in [0, 0.25, 0.5, 0.75, 1]]
        grad_changes = [*[-x for x in grad_changes][::-1], *grad_changes]
        batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batch.items()}
        for i, grad_change in enumerate(grad_changes):
            s = styles + direction * grad_change

            img = G(**batch, s=iter(s))["img"]
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
            imgs.append(img[0].to(torch.uint8).cpu().numpy())
        PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(path + f'{it}.png')


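# get_styles collects the generator's style-space activations: it maps z -> w
# through the mapping network, truncates w toward a cached w-center, then records
# the output of every per-layer affine (and affine_beta, when present). The
# per-layer style vectors are zero-padded to a common channel count so they stack
# into a single tensor that a direction can be added to.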
@torch.no_grad()
def get_styles(seed, G: torch.nn.Module, batch, truncation_value=1):
    all_styles = []
    if seed is None:
        # scale=0 yields the all-zero latent vector
        z = np.random.normal(0, 0, size=(1, G.z_channels))
    else:
        z = np.random.RandomState(seed=seed).normal(0, 1, size=(1, G.z_channels))
    z_idx = np.random.RandomState(seed=seed).randint(0, len(G.style_net.w_centers))
    w_c = G.style_net.w_centers[z_idx].to(tops.get_device()).view(1, -1)
    w = G.style_net(torch.from_numpy(z).to(tops.get_device()))

    w = w_c.to(w.dtype).lerp(w, truncation_value)
    if hasattr(G, "get_comod_y"):
        w = G.get_comod_y(batch, w)
    for block in G.modules():
        if not hasattr(block, "affine") or not hasattr(block.affine, "weight"):
            continue
        gamma0 = block.affine(w)
        if hasattr(block, "affine_beta"):
            beta0 = block.affine_beta(w)
            gamma0 = torch.cat((gamma0, beta0), dim=1)
        all_styles.append(gamma0)
    max_ch = max(s.shape[-1] for s in all_styles)
    all_styles = [F.pad(s, (0, max_ch - s.shape[-1]), "constant", 0) for s in all_styles]
    all_styles = torch.cat(all_styles)
    return all_styles

def get_and_cache_direction(output_dir: Path, dl_val, G, text_prompt):
    cache_path = output_dir.joinpath(
        "stylemc_cache", text_prompt.replace(" ", "_") + ".torch")
    if cache_path.is_file():
        print("Loaded cache from:", cache_path)
        return torch.load(cache_path)
    direction = find_direction(G, text_prompt, None, dl_val=iter(dl_val))
    cache_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(direction, cache_path)
    return direction

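# Core of StyleMC: optimize one global style-space direction with AdamW so that
# images generated from styles + direction move toward the text prompt in CLIP
# space (spherical distance loss). Gradients accumulate over batch_size forward
# passes per optimizer step; channels whose gradient sign keeps flipping are
# counted in direction_tracker and zeroed out at the end as noise.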
@torch.cuda.amp.autocast()
def find_direction(
    G,
    text_prompt,
    batches,
    #layers,
    n_iterations=128*8,
    batch_size=8,
    dl_val=None
):
    time_start = timer()

    clip_model = clip.load("ViT-B/16", device=tops.get_device())[0]

    target = [clip_model.encode_text(clip.tokenize(text_prompt).to(tops.get_device())).float()]
    all_styles = []
    if dl_val is not None:
        first_batch = next(dl_val)
    else:
        first_batch = batches[0]
    first_batch["embedding"] = None if "embedding" not in first_batch else first_batch["embedding"]
    s = get_styles(0, G, first_batch)
    # stats tracker
    cos_sim_track = AverageMeter('cos_sim', ':.4f')
    norm_track = AverageMeter('norm', ':.4f')
    n_iterations = n_iterations // batch_size
    progress = ProgressMeter(n_iterations, [cos_sim_track, norm_track])

    # initialize styles direction
    direction = torch.zeros(s.shape, device=tops.get_device())
    direction.requires_grad_()
    utils.set_requires_grad(G, False)
    direction_tracker = torch.zeros_like(direction)
    opt = torch.optim.AdamW([direction], lr=0.05, betas=(0., 0.999), weight_decay=0.25)

    grads = []
    for seed_idx in tqdm.trange(n_iterations):
        # forward pass through synthesis network with new styles
        if seed_idx == 0:
            batch = first_batch
        elif dl_val is not None:
            batch = next(dl_val)
            batch["embedding"] = None if "embedding" not in batch else batch["embedding"]
        else:
            batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batches[seed_idx].items()}
        styles = get_styles(seed_idx, G, batch) + direction
        img = G(**batch, s=iter(styles))["img"]
        batch = {k: v.cpu() if v is not None else v for k, v in batch.items()}
        # clip loss
        img = (img + 1) / 2
        img = normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        img = resize(img, (224, 224))
        embeds = clip_model.encode_image(img)
        cos_sim = prompts_dist_loss(embeds, target, spherical_dist_loss)
        cos_sim.backward(retain_graph=True)

        # track stats
        cos_sim_track.update(cos_sim.item())
        norm_track.update(torch.norm(direction).item())

        if not (seed_idx % batch_size):

            # zeroing out gradients for non-optimized layers
            #layers_zeroed = torch.tensor([x for x in range(G.num_ws) if not x in layers])
            #direction.grad[:, layers_zeroed] = 0

            opt.step()
            grads.append(direction.grad.clone())
            direction.grad.data.zero_()

            # keep track of gradients over time
            if seed_idx > 3:
                direction_tracker[grads[-2] * grads[-1] < 0] += 1

            # plot stats
            progress.display(seed_idx)

    # throw out fluctuating channels
    direction = direction.detach()
    direction[direction_tracker > n_iterations / 4] = 0
    print(direction)
    print(f"Time for direction search: {timer() - time_start:.2f} s")
    return direction



@click.command()
@click.argument("config_path")
@click.argument("input_path")
@click.argument("output_path")
#@click.option('--layers', type=num_range, help='Restrict the style space to a range of layers. We recommend not to optimize the critically sampled layers (last 3).', required=True)
@click.option('--text-prompt', help='Text', type=str, required=True)
@click.option('--edit-strength', help='Strength of edit', type=float, required=True)
@click.option('--outdir', help='Where to save the output images', type=str, required=True)
def stylemc(
    config_path,
    input_path,
    output_path,
    #layers: List[int],
    text_prompt: str,
    edit_strength: float,
    outdir: str,
):
    cfg = utils.load_config(config_path)
    G = build_trained_generator(cfg)
    cfg.train.batch_size = 1
    n_iterations = 256
    dl_val = tops.config.instantiate(cfg.data.val.loader)

    direction = find_direction(G, text_prompt, None, n_iterations=n_iterations, dl_val=iter(dl_val))

    text_prompt = text_prompt.replace(" ", "_")
    # generate_edit expects an iterable of batches; use the validation loader
    generate_edit(G, iter(dl_val), direction, edit_strength, output_path)


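# Example invocation (hypothetical paths; config_path must point at a dp2 config,
# and output_path is used as a filename prefix for the saved image strips):
#   python stylemc.py configs/my_config.py data/in out/edit_ \
#       --text-prompt "a person with a beard" --edit-strength 2.0 --outdir out/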
if __name__ == "__main__":
    stylemc()