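"""GAN inversion utilities: latent-space projection (inverse_image) and
Pivotal Tuning Inversion (PTI) for StyleGAN-style generators."""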
import math
import os
from draggan.viz import renderer
import torch
from torch import optim
from torch.nn import functional as F
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import dataclasses
import draggan.dnnlib as dnnlib
from .lpips import util


def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    # Cosine learning-rate schedule: linear warm-up until t = rampup,
    # flat at initial_lr, then cosine decay to 0 over the final rampdown fraction.
    lr_ramp = min(1, (1 - t) / rampdown)
    lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi)
    lr_ramp = lr_ramp * min(1, t / rampup)
    return initial_lr * lr_ramp


def make_image(tensor):
    # Convert a batch of [-1, 1] float images to uint8 HWC numpy arrays.
    return (
        tensor.detach()
        .clamp_(min=-1, max=1)
        .add(1)
        .div_(2)
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


@dataclasses.dataclass
class InverseConfig:
    lr_warmup: float = 0.05
    lr_decay: float = 0.25
    lr: float = 0.1
    noise: float = 0.05
    noise_decay: float = 0.75
    step: int = 1000
    noise_regularize: float = 1e5
    mse: float = 0.1


def inverse_image(
    g_ema,
    image,
    percept,
    image_size=256,
    w_plus=False,
    config=InverseConfig(),
    device='cuda:0'
):
    args = config
    n_mean_latent = 10000
    resize = min(image_size, 256)

    # Preprocess the target into a [-1, 1] tensor of shape [1, 3, resize, resize].
    if not torch.is_tensor(image):
        transform = transforms.Compose(
            [
                transforms.Resize(resize),
                transforms.CenterCrop(resize),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(image)
    else:
        img = transforms.functional.resize(image, resize)
        transform = transforms.Compose(
            [
                transforms.CenterCrop(resize),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        img = transform(img)

    imgs = torch.stack([img], 0).to(device)
    with torch.no_grad():
        # Estimate the mean and spread of W by mapping random z samples.
        noise_sample = torch.randn(n_mean_latent, g_ema.z_dim, device=device)
        w_samples = g_ema.mapping(noise_sample, None)
        w_samples = w_samples[:, :1, :]
        w_avg = w_samples.mean(0)
        w_std = ((w_samples - w_avg).pow(2).sum() / n_mean_latent) ** 0.5

    # Optimize the per-layer noise buffers alongside the latent.
    noises = {name: buf for (name, buf) in g_ema.synthesis.named_buffers() if 'noise_const' in name}
    for noise in noises.values():
        noise[:] = torch.randn_like(noise)  # re-initialize in place so the buffer itself is optimized
        noise.requires_grad = True

    w_opt = w_avg.detach().clone()
    if w_plus:
        w_opt = w_opt.repeat(1, g_ema.mapping.num_ws, 1)
    w_opt.requires_grad = True

    optimizer = optim.Adam([w_opt] + list(noises.values()), lr=args.lr)
    pbar = tqdm(range(args.step))
    latent_path = []

    for i in pbar:
        t = i / args.step
        lr = get_lr(t, args.lr, rampdown=args.lr_decay, rampup=args.lr_warmup)
        optimizer.param_groups[0]["lr"] = lr

        # Perturb the latent with decaying noise to escape poor local minima.
        noise_strength = w_std * args.noise * max(0, 1 - t / args.noise_decay) ** 2
        w_noise = torch.randn_like(w_opt) * noise_strength
        if w_plus:
            ws = w_opt + w_noise
        else:
            ws = (w_opt + w_noise).repeat([1, g_ema.mapping.num_ws, 1])
        img_gen = g_ema.synthesis(ws, noise_mode='const', force_fp32=True)

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        if img_gen.shape[2] > 256:
            img_gen = F.interpolate(img_gen, size=(256, 256), mode='area')

        p_loss = percept(img_gen, imgs).sum()

        # Noise regularization: penalize spatial correlation at every scale.
        reg_loss = 0.0
        for v in noises.values():
            noise = v[None, None, :, :]  # must be [1, 1, H, W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)

        mse_loss = F.mse_loss(img_gen, imgs)
        loss = p_loss + args.noise_regularize * reg_loss + args.mse * mse_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Normalize the noise buffers to zero mean and unit variance.
        with torch.no_grad():
            for buf in noises.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

        if (i + 1) % 100 == 0:
            latent_path.append(w_opt.detach().clone())

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {reg_loss:.4f};"
                f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}"
            )
        )
    # Synthesize the final reconstruction from the last recorded pivot.
    if w_plus:
        ws = latent_path[-1]
    else:
        ws = latent_path[-1].repeat([1, g_ema.mapping.num_ws, 1])
    img_gen = g_ema.synthesis(ws, noise_mode='const')

    result = {
        "latent": latent_path[-1],
        "sample": img_gen,
        "real": imgs,
    }

    return result

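# Example usage of inverse_image (a minimal sketch; `load_generator` and
# `PerceptualLoss` are hypothetical stand-ins for however the surrounding repo
# builds the StyleGAN generator and the LPIPS loss -- neither is defined here):
#
#   from PIL import Image
#   g_ema = load_generator('stylegan2-ffhq.pkl').eval().to('cuda:0')  # hypothetical loader
#   percept = PerceptualLoss()                                        # hypothetical LPIPS wrapper
#   target = Image.open('face.png').convert('RGB')
#   result = inverse_image(g_ema, target, percept, image_size=g_ema.img_resolution)
#   w = result['latent']                     # [1, 1, 512], or [1, num_ws, 512] with w_plus=True
#   recon = make_image(result['sample'])[0]  # uint8 HWC reconstruction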


def toggle_grad(model, flag=True):
    for p in model.parameters():
        p.requires_grad = flag


class PTI:
    def __init__(self, G, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4):
        self.g_ema = G
        self.l2_lambda = l2_lambda
        self.max_pti_step = max_pti_step
        self.pti_lr = pti_lr
        self.percept = percept

    def calc_loss(self, percept, generated_image, real_image):
        # Combined objective: LPIPS perceptual distance plus weighted pixel MSE.
        mse_loss = F.mse_loss(generated_image, real_image)
        p_loss = percept(generated_image, real_image).sum()
        loss = p_loss + self.l2_lambda * mse_loss
        return loss
    def train(self, img, w_plus=False):
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Preprocess the target exactly as in inverse_image, at full generator resolution.
        if not torch.is_tensor(img):
            transform = transforms.Compose(
                [
                    transforms.Resize(self.g_ema.img_resolution),
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to(device).unsqueeze(0)
        else:
            img = transforms.functional.resize(img, self.g_ema.img_resolution)
            transform = transforms.Compose(
                [
                    transforms.CenterCrop(self.g_ema.img_resolution),
                    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
                ]
            )
            real_img = transform(img).to(device).unsqueeze(0)

        # Stage 1: invert the image to a pivot latent with the generator frozen.
        inversed_result = inverse_image(self.g_ema, img, self.percept, self.g_ema.img_resolution, w_plus, device=device)
        w_pivot = inversed_result['latent']
        if w_plus:
            ws = w_pivot
        else:
            ws = w_pivot.repeat([1, self.g_ema.mapping.num_ws, 1])

        # Stage 2: fine-tune the generator weights around the fixed pivot.
        toggle_grad(self.g_ema, True)
        optimizer = torch.optim.Adam(self.g_ema.parameters(), lr=self.pti_lr)
        print('start PTI')
        pbar = tqdm(range(self.max_pti_step))
        for i in pbar:
            t = i / self.max_pti_step
            lr = get_lr(t, self.pti_lr)
            optimizer.param_groups[0]["lr"] = lr

            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
            loss = self.calc_loss(self.percept, generated_image, real_img)
            pbar.set_description(f"loss: {loss.item():.4f}")

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            generated_image = self.g_ema.synthesis(ws, noise_mode='const')
        return generated_image, ws

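# Example PTI workflow (a sketch under the same assumptions as above: `g_ema` is
# a loaded StyleGAN generator and `percept` an LPIPS-style loss supplied by the
# surrounding repo):
#
#   pti = PTI(g_ema, percept, l2_lambda=1, max_pti_step=400, pti_lr=3e-4)
#   recon, ws = pti.train(Image.open('face.png').convert('RGB'), w_plus=True)
#   # `recon` is the fine-tuned generator's reconstruction; `ws` is the pivot
#   # latent, ready for latent-space editing (e.g. DragGAN point dragging).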