# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
"""Project given image to the latent space of pretrained network pickle."""
import copy
import os
from time import perf_counter
import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import dnnlib
import legacy
def project(
G,
target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
*,
num_steps=1000,
w_avg_samples=10000,
initial_learning_rate=0.1,
initial_noise_factor=0.05,
lr_rampdown_length=0.25,
lr_rampup_length=0.05,
noise_ramp_length=0.75,
regularize_noise_weight=1e5,
verbose=False,
device: torch.device,
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device)  # type: ignore

    # Compute w stats.
logprint(f"Computing W midpoint and stddev using {w_avg_samples} samples...")
z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C]
w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C]
w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5  # scalar spread of W; scales the latent noise

    # Setup noise inputs.
noise_bufs = {
name: buf
for (name, buf) in G.synthesis.named_buffers()
if "noise_const" in name
}
    # Load VGG16 feature detector (TorchScript module that can return LPIPS features).
url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt"
with dnnlib.util.open_url(url) as f:
vgg16 = torch.jit.load(f).eval().to(device)
# Features for target image.
target_images = target.unsqueeze(0).to(device).to(torch.float32)
if target_images.shape[2] > 256:
target_images = F.interpolate(target_images, size=(256, 256), mode="area")
target_features = vgg16(target_images, resize_images=False, return_lpips=True)
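
    # Optimize a single W vector starting from the average; w_out records the
    # per-step trajectory so intermediate projections can be rendered later.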
w_opt = torch.tensor(
w_avg, dtype=torch.float32, device=device, requires_grad=True
) # pylint: disable=not-callable
w_out = torch.zeros(
[num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device
)
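
    # Jointly optimize the latent vector and all per-layer noise buffers.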
optimizer = torch.optim.Adam(
[w_opt] + list(noise_bufs.values()),
betas=(0.9, 0.999),
lr=initial_learning_rate,
)
    # Init noise buffers to unit Gaussians and mark them trainable.
for buf in noise_bufs.values():
buf[:] = torch.randn_like(buf)
buf.requires_grad = True
for step in range(num_steps):
        # Learning rate and noise-injection schedule.
t = step / num_steps
w_noise_scale = (
w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
)
lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
lr = initial_learning_rate * lr_ramp
for param_group in optimizer.param_groups:
param_group["lr"] = lr
# Synth images from opt_w.
w_noise = torch.randn_like(w_opt) * w_noise_scale
ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
synth_images = G.synthesis(ws, noise_mode="const")
        # Scale from [-1, 1] to [0, 255], then downsample to 256x256 if larger
        # than that; VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255 / 2)
if synth_images.shape[2] > 256:
synth_images = F.interpolate(synth_images, size=(256, 256), mode="area")
# Features for synth images.
synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
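        # LPIPS-style distance: squared error between the VGG16 feature embeddings.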
dist = (target_features - synth_features).square().sum()
# Noise regularization.
reg_loss = 0.0
for v in noise_bufs.values():
noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
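            # Penalize spatial autocorrelation of the noise at multiple scales,
            # average-pooling the map down to 8x8.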
while True:
reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
if noise.shape[2] <= 8:
break
noise = F.avg_pool2d(noise, kernel_size=2)
loss = dist + reg_loss * regularize_noise_weight
# Step
optimizer.zero_grad(set_to_none=True)
loss.backward()
optimizer.step()
logprint(
f"step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}"
)
# Save projected W for each optimization step.
w_out[step] = w_opt.detach()[0]
        # Renormalize noise buffers to zero mean, unit variance.
with torch.no_grad():
for buf in noise_bufs.values():
buf -= buf.mean()
buf *= buf.square().mean().rsqrt()
    return w_out.repeat([1, G.mapping.num_ws, 1])


# ----------------------------------------------------------------------------

@click.command()
@click.option("--network", "network_pkl", help="Network pickle filename", required=True)
@click.option(
"--target",
"target_fname",
help="Target image file to project to",
required=True,
metavar="FILE",
)
@click.option(
"--num-steps",
help="Number of optimization steps",
type=int,
default=1000,
show_default=True,
)
@click.option("--seed", help="Random seed", type=int, default=303, show_default=True)
@click.option(
"--save-video",
help="Save an mp4 video of optimization progress",
type=bool,
default=True,
show_default=True,
)
@click.option(
"--outdir", help="Where to save the output images", required=True, metavar="DIR"
)
def run_projection(
network_pkl: str,
target_fname: str,
outdir: str,
save_video: bool,
seed: int,
num_steps: int,
):
"""Project given image to the latent space of pretrained network pickle.
Examples:
\b
python projector.py --outdir=out --target=~/mytargetimg.png \\
--network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
"""
np.random.seed(seed)
torch.manual_seed(seed)
# Load networks.
print('Loading networks from "%s"...' % network_pkl)
device = torch.device("cuda")
with dnnlib.util.open_url(network_pkl) as fp:
G = (
legacy.load_network_pkl(fp)["G_ema"].requires_grad_(False).to(device)
) # type: ignore
# Load target image.
target_pil = PIL.Image.open(target_fname).convert("RGB")
w, h = target_pil.size
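    # Center-crop to a square, then resize to the generator's resolution.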
s = min(w, h)
target_pil = target_pil.crop(
((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)
)
target_pil = target_pil.resize(
(G.img_resolution, G.img_resolution), PIL.Image.LANCZOS
)
target_uint8 = np.array(target_pil, dtype=np.uint8)
# Optimize projection.
start_time = perf_counter()
projected_w_steps = project(
G,
target=torch.tensor(
target_uint8.transpose([2, 0, 1]), device=device
), # pylint: disable=not-callable
num_steps=num_steps,
device=device,
verbose=True,
)
print(f"Elapsed: {(perf_counter()-start_time):.1f} s")
    # Render debug outputs: optional video, final projected image, and W vector.
os.makedirs(outdir, exist_ok=True)
if save_video:
video = imageio.get_writer(
f"{outdir}/proj.mp4", mode="I", fps=10, codec="libx264", bitrate="16M"
)
print(f'Saving optimization progress video "{outdir}/proj.mp4"')
for projected_w in projected_w_steps:
synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode="const")
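            # Map synthesis output from [-1, 1] to uint8 [0, 255] for the frame.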
synth_image = (synth_image + 1) * (255 / 2)
synth_image = (
synth_image.permute(0, 2, 3, 1)
.clamp(0, 255)
.to(torch.uint8)[0]
.cpu()
.numpy()
)
video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
video.close()
# Save final projected frame and W vector.
target_pil.save(f"{outdir}/target.png")
projected_w = projected_w_steps[-1]
synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode="const")
synth_image = (synth_image + 1) * (255 / 2)
synth_image = (
synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
)
PIL.Image.fromarray(synth_image, "RGB").save(f"{outdir}/proj.png")
np.savez(f"{outdir}/projected_w.npz", w=projected_w.unsqueeze(0).cpu().numpy())
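    # The saved W can be reloaded and re-rendered later, e.g.:
    #   ws = torch.tensor(np.load(f"{outdir}/projected_w.npz")["w"], device=device)
    #   img = G.synthesis(ws, noise_mode="const")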

# ----------------------------------------------------------------------------

if __name__ == "__main__":
    run_projection()  # pylint: disable=no-value-for-parameter

# ----------------------------------------------------------------------------