|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Project given image to the latent space of pretrained network pickle.""" |
|
|
|
import copy |
|
import os |
|
import numpy as np |
|
import torch |
|
import torch.nn.functional as F |
|
from tqdm import tqdm |
|
import dnnlib |
|
import PIL |
|
from camera_utils import LookAtPoseSampler |
|
def project(
    G,
    c,
    outdir,
    target: torch.Tensor,
    *,
    num_steps=1000,
    w_avg_samples=10000,
    initial_learning_rate=0.01,
    initial_noise_factor=0.05,
    lr_rampdown_length=0.25,
    lr_rampup_length=0.05,
    noise_ramp_length=0.75,
    regularize_noise_weight=1e5,
    verbose=False,
    device: torch.device,
    initial_w=None,
    image_log_step=100,
    w_name: str
):
    """Optimize a single W latent so that ``G`` reproduces ``target``.

    A deep copy of ``G`` is frozen and moved to ``device``; one latent vector
    (broadcast to every synthesis layer) and the generator's ``noise_const``
    buffers are then optimized with Adam against a VGG16 LPIPS feature
    distance plus a noise auto-correlation regularizer.

    Args:
        G: Generator. The code reads ``img_channels``, ``img_resolution``,
            ``z_dim``, ``rendering_kwargs``, ``mapping``, ``synthesis`` and
            ``backbone`` from it. The caller's instance is not modified.
        c: Conditioning passed unchanged to ``G.synthesis`` on every step.
        outdir: Base output directory; previews go to ``{outdir}/{w_name}_w``.
        target: Image tensor of shape
            ``(G.img_channels, G.img_resolution, G.img_resolution)``.
            Presumably in the 0..255 range, since synthesized images are
            rescaled to that range before comparison -- TODO confirm.
        num_steps: Number of optimization steps.
        w_avg_samples: Number of random latents used to estimate the W mean
            and standard deviation.
        initial_learning_rate: Peak learning rate of the ramped schedule.
        initial_noise_factor: Initial scale (relative to the W stddev) of the
            Gaussian noise added to the latent.
        lr_rampdown_length: Fraction of the run over which the learning rate
            is cosine-ramped down at the end.
        lr_rampup_length: Fraction of the run over which the learning rate is
            linearly ramped up at the start.
        noise_ramp_length: Fraction of the run after which the latent noise
            has decayed to zero.
        regularize_noise_weight: Multiplier on the noise regularization term.
        verbose: When true, print the distance and loss after every step.
        device: Device on which the generator copy, VGG16 and latent live.
        initial_w: Optional starting latent; the estimated W mean is used
            when this is ``None``.
        image_log_step: A preview PNG is written every this many steps.
        w_name: Name used for the output sub-directory.

    Returns:
        The optimized latent repeated ``G.backbone.mapping.num_ws`` times
        along dimension 1. It is derived from a tensor that requires grad and
        is not detached.

    Raises:
        AssertionError: If ``target`` does not have the expected shape.
    """
    # A bare `import PIL` does not guarantee that the `Image` submodule is
    # loaded, so import it explicitly where it is needed.
    from PIL import Image

    outdir = f'{outdir}/{w_name}_w'
    os.makedirs(outdir, exist_ok=True)
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    # Work on a private, frozen float32 copy so the caller's generator
    # (including its noise buffers) is left untouched.
    G = copy.deepcopy(G).eval().requires_grad_(False).to(device).float()

    # The statistics are recomputed on every call; nothing is cached on disk.
    w_avg, w_std = _estimate_w_stats(G, w_avg_samples, device)
    start_w = initial_w if initial_w is not None else w_avg

    # Per-layer constant noise maps that are optimized alongside the latent.
    noise_bufs = {name: buf for (name, buf) in G.backbone.synthesis.named_buffers() if 'noise_const' in name}

    # Load the VGG16 feature detector used for the perceptual distance.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features of the target image, downsampled to at most 256x256.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if target_images.shape[2] > 256:
        target_images = F.interpolate(target_images, size=(256, 256), mode='area')
    target_features = vgg16(target_images, resize_images=False, return_lpips=True)

    # Fresh leaf tensor; works for both array and tensor inputs without the
    # copy-construct warning that torch.tensor() emits for tensors.
    w_opt = torch.as_tensor(start_w, dtype=torch.float32, device=device).detach().clone().requires_grad_(True)
    print('w_opt shape: ', w_opt.shape)

    # The learning rate is overwritten by the schedule before every step.
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999),
                                 lr=initial_learning_rate)

    # Start from random noise maps and make them trainable.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    num_ws = G.backbone.mapping.num_ws

    for step in tqdm(range(num_steps), position=0, leave=True):
        # Schedules for the latent noise scale and the learning rate.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr = _ramped_learning_rate(t, initial_learning_rate, lr_rampdown_length, lr_rampup_length)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synthesize an image from the noise-perturbed latent.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, num_ws, 1])
        synth_images = G.synthesis(ws, c, noise_mode='const')['image']

        # Periodically dump a preview of the current reconstruction.
        if step % image_log_step == 0:
            with torch.no_grad():
                vis_img = (synth_images.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
                Image.fromarray(vis_img[0].cpu().numpy(), 'RGB').save(f'{outdir}/{step}.png')

        # Rescale from [-1, 1] to [0, 255] and downsample for VGG16.
        synth_images = (synth_images + 1) * (255 / 2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Perceptual distance between target and synthesized features.
        synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
        dist = (target_features - synth_features).square().sum()

        reg_loss = _noise_regularization(noise_bufs.values())
        loss = dist + reg_loss * regularize_noise_weight

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

        # Only format (and thereby synchronize on) the losses when needed.
        if verbose:
            print(f'step {step + 1:>4d}/{num_steps}: dist {float(dist):<4.2f} loss {float(loss):<5.2f}')

        # Re-normalize every noise map to zero mean and unit variance.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_opt.repeat([1, num_ws, 1])


def _estimate_w_stats(G, num_samples, device):
    """Return ``(w_avg, w_std)`` estimated from ``num_samples`` seeded latents."""
    print(f'Computing W midpoint and stddev using {num_samples} samples...')
    z_samples = np.random.RandomState(123).randn(num_samples, G.z_dim)

    # One fixed camera label (flattened 4x4 pose + 3x3 intrinsics) shared by
    # every sample, built from the generator's average camera settings.
    camera_lookat_point = torch.tensor(G.rendering_kwargs['avg_camera_pivot'], device=device)
    cam2world_pose = LookAtPoseSampler.sample(3.14 / 2, 3.14 / 2, camera_lookat_point,
                                              radius=G.rendering_kwargs['avg_camera_radius'], device=device)
    focal_length = 4.2647
    intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device)
    c_samples = torch.cat([cam2world_pose.reshape(-1, 16), intrinsics.reshape(-1, 9)], 1)
    c_samples = c_samples.repeat(num_samples, 1)

    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), c_samples)
    # Keep only the first of the per-layer copies of each latent.
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)
    w_avg = np.mean(w_samples, axis=0, keepdims=True)
    w_std = (np.sum((w_samples - w_avg) ** 2) / num_samples) ** 0.5
    return w_avg, w_std


def _ramped_learning_rate(t, base_lr, rampdown_length, rampup_length):
    """Return the learning rate at progress ``t``: linear ramp-up, cosine ramp-down."""
    lr_ramp = min(1.0, (1.0 - t) / rampdown_length)
    lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
    lr_ramp = lr_ramp * min(1.0, t / rampup_length)
    return base_lr * lr_ramp


def _noise_regularization(noise_maps):
    """Return the summed squared shift-by-one auto-correlation of the noise maps over a 2x pooling pyramid."""
    reg_loss = 0.0
    for v in noise_maps:
        noise = v[None, None, :, :]
        while True:
            reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
            reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
            if noise.shape[2] <= 8:
                break
            noise = F.avg_pool2d(noise, kernel_size=2)
    return reg_loss
|
|