""" | |
Approach: "StyleMC: Multi-Channel Based Fast Text-Guided Image Generation and Manipulation" | |
Original source code: | |
https://github.com/autonomousvision/stylegan_xl/blob/f9be58e98110bd946fcdadef2aac8345466faaf3/run_stylemc.py# | |
Modified by Håkon Hukkelås | |
""" | |
import os
import re
from pathlib import Path
from timeit import default_timer as timer
from typing import List

import click
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import tqdm
from torchvision.transforms.functional import resize, normalize

import clip
import tops
from dp2 import utils
from dp2.infer import build_trained_generator
#----------------------------------------------------------------------------


class AverageMeter(object):
    """Computes and stores the average and current value."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
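
# Example with hypothetical values: tracking a running loss average.
#   meter = AverageMeter('loss', ':.3f')
#   meter.update(0.5)
#   meter.update(0.3)
#   print(meter)  # -> "loss 0.300 (0.400)"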

class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print('\t'.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'

def save_image(img, path):
    # Map images from [-1, 1] to uint8 [0, 255] and save the first image in the batch.
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(path)

def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))
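
# Example: unravel a flat index into a multi-dimensional coordinate, mirroring
# np.unravel_index.
#   unravel_index(5, (2, 3))  # -> (1, 2)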

def num_range(s: str) -> List[int]:
    """Accept either a comma-separated list of numbers 'a,b,c' or a range 'a-c' and return it as a list of ints."""
    range_re = re.compile(r'^(\d+)-(\d+)$')
    m = range_re.match(s)
    if m:
        return list(range(int(m.group(1)), int(m.group(2)) + 1))
    vals = s.split(',')
    return [int(x) for x in vals]
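
# Example: both accepted input formats yield plain integer lists.
#   num_range('2-5')    # -> [2, 3, 4, 5]
#   num_range('1,7,9')  # -> [1, 7, 9]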

#----------------------------------------------------------------------------

def spherical_dist_loss(x, y):
    x = F.normalize(x, dim=-1)
    y = F.normalize(y, dim=-1)
    return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
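
# For unit vectors with angle theta between them, ||x - y|| = 2*sin(theta/2),
# so the expression above equals theta^2 / 2: a squared geodesic (great-circle)
# distance on the unit sphere, rather than a plain cosine distance.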

def prompts_dist_loss(x, targets, loss):
    if len(targets) == 1:  # Keeps results consistent with the previous method for single-objective guidance
        return loss(x, targets[0])
    distances = [loss(x, target) for target in targets]
    return torch.stack(distances, dim=-1).sum(dim=-1)

def embed_text(model, prompt, device='cuda'):
    # Encode a text prompt with CLIP; mirrors the inlined encoding in find_direction.
    return model.encode_text(clip.tokenize(prompt).to(device)).float()
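
# Example (assumed usage): building guidance targets for multiple prompts.
#   clip_model = clip.load("ViT-B/16", device="cuda")[0]
#   targets = [embed_text(clip_model, p) for p in ("a face", "a smiling face")]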

#----------------------------------------------------------------------------

def generate_edit(
    G,
    dl,
    direction,
    edit_strength,
    path,
):
    for it, batch in enumerate(dl):
        batch["embedding"] = None
        # Truncation value 0 collapses w to the style-net centre, giving a neutral base.
        styles = get_styles(None, G, batch, truncation_value=0)
        imgs = []
        # Symmetric edit magnitudes from -edit_strength to +edit_strength.
        grad_changes = [x * edit_strength for x in [0, 0.25, 0.5, 0.75, 1]]
        grad_changes = [*[-x for x in grad_changes][::-1], *grad_changes]
        batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batch.items()}
        for i, grad_change in enumerate(grad_changes):
            s = styles + direction * grad_change
            img = G(**batch, s=iter(s))["img"]
            img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255)
            imgs.append(img[0].to(torch.uint8).cpu().numpy())
        PIL.Image.fromarray(np.concatenate(imgs, axis=1), 'RGB').save(path + f'{it}.png')
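
# Each saved strip contains ten horizontally concatenated renderings of the
# same input, swept from -edit_strength to +edit_strength along the learned
# direction, which makes the effect of the edit easy to inspect visually.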

def get_styles(seed, G: torch.nn.Module, batch, truncation_value=1):
    all_styles = []
    if seed is None:
        # A scale of 0 yields an all-zero latent; callers pass seed=None together
        # with truncation_value=0, where w collapses to the centre below anyway.
        z = np.random.normal(0, 0, size=(1, G.z_channels))
    else:
        z = np.random.RandomState(seed=seed).normal(0, 1, size=(1, G.z_channels))
    z_idx = np.random.RandomState(seed=seed).randint(0, len(G.style_net.w_centers))
    w_c = G.style_net.w_centers[z_idx].to(tops.get_device()).view(1, -1)
    w = G.style_net(torch.from_numpy(z).to(tops.get_device()))
    # Truncation: interpolate from the w centre towards the sampled w.
    w = w_c.to(w.dtype).lerp(w, truncation_value)
    if hasattr(G, "get_comod_y"):
        w = G.get_comod_y(batch, w)
    for block in G.modules():
        if not hasattr(block, "affine") or not hasattr(block.affine, "weight"):
            continue
        gamma0 = block.affine(w)
        if hasattr(block, "affine_beta"):
            beta0 = block.affine_beta(w)
            gamma0 = torch.cat((gamma0, beta0), dim=1)
        all_styles.append(gamma0)
    # Zero-pad every per-layer style vector to the widest layer and stack them.
    max_ch = max(s.shape[-1] for s in all_styles)
    all_styles = [F.pad(s, (0, max_ch - s.shape[-1]), "constant", 0) for s in all_styles]
    all_styles = torch.cat(all_styles)
    return all_styles
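
# get_styles maps a latent seed into StyleMC's style space "S": the outputs of
# every per-layer affine (and affine_beta, where present), padded to a common
# width. A single tensor in this space is what find_direction optimizes.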

def get_and_cache_direction(output_dir: Path, dl_val, G, text_prompt):
    cache_path = output_dir.joinpath(
        "stylemc_cache", text_prompt.replace(" ", "_") + ".torch")
    if cache_path.is_file():
        print("Loaded cache from:", cache_path)
        return torch.load(cache_path)
    direction = find_direction(G, text_prompt, None, dl_val=iter(dl_val))
    cache_path.parent.mkdir(exist_ok=True, parents=True)
    torch.save(direction, cache_path)
    return direction

def find_direction(
    G,
    text_prompt,
    batches,
    # layers,
    n_iterations=128 * 8,
    batch_size=8,
    dl_val=None
):
    time_start = timer()
    clip_model = clip.load("ViT-B/16", device=tops.get_device())[0]
    target = [clip_model.encode_text(clip.tokenize(text_prompt).to(tops.get_device())).float()]
    if dl_val is not None:
        first_batch = next(dl_val)
    else:
        first_batch = batches[0]
    first_batch["embedding"] = None if "embedding" not in first_batch else first_batch["embedding"]
    s = get_styles(0, G, first_batch)

    # stats trackers
    cos_sim_track = AverageMeter('cos_sim', ':.4f')
    norm_track = AverageMeter('norm', ':.4f')
    n_iterations = n_iterations // batch_size
    progress = ProgressMeter(n_iterations, [cos_sim_track, norm_track])

    # initialize the styles direction
    direction = torch.zeros(s.shape, device=tops.get_device())
    direction.requires_grad_()
    utils.set_requires_grad(G, False)
    direction_tracker = torch.zeros_like(direction)
    opt = torch.optim.AdamW([direction], lr=0.05, betas=(0., 0.999), weight_decay=0.25)

    grads = []
    for seed_idx in tqdm.trange(n_iterations):
        # forward pass through the synthesis network with the shifted styles
        if seed_idx == 0:
            batch = first_batch
        elif dl_val is not None:
            batch = next(dl_val)
            batch["embedding"] = None if "embedding" not in batch else batch["embedding"]
        else:
            batch = {k: tops.to_cuda(v) if v is not None else v for k, v in batches[seed_idx].items()}
        styles = get_styles(seed_idx, G, batch) + direction
        img = G(**batch, s=iter(styles))["img"]
        batch = {k: v.cpu() if v is not None else v for k, v in batch.items()}

        # CLIP loss: map to [0, 1], normalize to CLIP's input statistics, compare embeddings
        img = (img + 1) / 2
        img = normalize(img, mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
        img = resize(img, (224, 224))
        embeds = clip_model.encode_image(img)
        cos_sim = prompts_dist_loss(embeds, target, spherical_dist_loss)
        cos_sim.backward(retain_graph=True)

        # track stats
        cos_sim_track.update(cos_sim.item())
        norm_track.update(torch.norm(direction).item())

        if not (seed_idx % batch_size):
            # zeroing out gradients for non-optimized layers
            # layers_zeroed = torch.tensor([x for x in range(G.num_ws) if not x in layers])
            # direction.grad[:, layers_zeroed] = 0
            opt.step()
            grads.append(direction.grad.clone())
            direction.grad.data.zero_()

            # keep track of gradient sign flips over time
            if seed_idx > 3:
                direction_tracker[grads[-2] * grads[-1] < 0] += 1

            # print stats
            progress.display(seed_idx)

    # throw out fluctuating channels
    direction = direction.detach()
    direction[direction_tracker > n_iterations / 4] = 0
    print(direction)
    print(f"Time for direction search: {timer() - time_start:.2f} s")
    return direction
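
# StyleMC in brief: a single global direction in S-space is optimized with
# AdamW (the strong weight decay keeps it sparse) to minimize the squared
# spherical distance between CLIP embeddings of the generated images and of
# the text prompt; channels whose gradient sign flips too often across update
# steps are treated as unstable and zeroed out at the end.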

# CLI options reconstructed from the function signature below; the defaults
# for --edit_strength and --outdir are assumptions.
@click.command()
@click.argument("config_path")
@click.argument("text_prompt")
@click.option("--edit_strength", default=5.0, type=float)
@click.option("--outdir", default="stylemc_results", type=str)
# @click.option('--layers', type=num_range, help='Restrict the style space to a range of layers. We recommend not optimizing the critically sampled layers (the last 3).', required=True)
def stylemc(
    config_path,
    # layers: List[int],
    text_prompt: str,
    edit_strength: float,
    outdir: str,
):
    cfg = utils.load_config(config_path)
    G = build_trained_generator(cfg)
    cfg.train.batch_size = 1
    n_iterations = 256
    dl_val = tops.config.instantiate(cfg.data.val.loader)
    direction = find_direction(G, text_prompt, None, n_iterations=n_iterations, dl_val=iter(dl_val))
    text_prompt = text_prompt.replace(" ", "_")
    # Save one edit strip per validation batch under outdir, prefixed by the prompt.
    os.makedirs(outdir, exist_ok=True)
    output_path = os.path.join(outdir, text_prompt)
    generate_edit(G, iter(dl_val), direction, edit_strength, output_path)


if __name__ == "__main__":
    stylemc()
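
# Example invocation (hypothetical config path and prompt):
#   python run_stylemc.py configs/my_generator.py "a face with a beard" \
#       --edit_strength 5 --outdir stylemc_results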