Spaces:
Build error
Build error
PKUWilliamYang
commited on
Commit
•
7f6643a
1
Parent(s):
01ad5b5
Upload latent_optimization.py
Browse files- latent_optimization.py +107 -0
latent_optimization.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import models.stylegan2.lpips as lpips
|
2 |
+
from torch import autograd, optim
|
3 |
+
from torchvision import transforms, utils
|
4 |
+
from tqdm import tqdm
|
5 |
+
import torch
|
6 |
+
from scripts.align_all_parallel import align_face
|
7 |
+
from utils.inference_utils import noise_regularize, noise_normalize_, get_lr, latent_noise, visualize
|
8 |
+
|
9 |
+
def latent_optimization(frame, pspex, landmarkpredictor, step=500, device='cuda'):
|
10 |
+
percept = lpips.PerceptualLoss(
|
11 |
+
model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
|
12 |
+
)
|
13 |
+
|
14 |
+
transform = transforms.Compose([
|
15 |
+
transforms.ToTensor(),
|
16 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5],std=[0.5,0.5,0.5]),
|
17 |
+
])
|
18 |
+
|
19 |
+
with torch.no_grad():
|
20 |
+
|
21 |
+
noise_sample = torch.randn(1000, 512, device=device)
|
22 |
+
latent_out = pspex.decoder.style(noise_sample)
|
23 |
+
latent_mean = latent_out.mean(0)
|
24 |
+
latent_std = ((latent_out - latent_mean).pow(2).sum() / 1000) ** 0.5
|
25 |
+
|
26 |
+
y = transform(frame).unsqueeze(dim=0).to(device)
|
27 |
+
I_ = align_face(frame, landmarkpredictor)
|
28 |
+
I_ = transform(I_).unsqueeze(dim=0).to(device)
|
29 |
+
wplus = pspex.encoder(I_) + pspex.latent_avg.unsqueeze(0)
|
30 |
+
_, f = pspex.encoder(y, return_feat=True)
|
31 |
+
latent_in = wplus.detach().clone()
|
32 |
+
feat = [f[0].detach().clone(), f[1].detach().clone()]
|
33 |
+
|
34 |
+
|
35 |
+
|
36 |
+
# wplus and f to optimize
|
37 |
+
latent_in.requires_grad = True
|
38 |
+
feat[0].requires_grad = True
|
39 |
+
feat[1].requires_grad = True
|
40 |
+
|
41 |
+
noises_single = pspex.decoder.make_noise()
|
42 |
+
basic_height, basic_width = int(y.shape[2]*32/256), int(y.shape[3]*32/256)
|
43 |
+
noises = []
|
44 |
+
for noise in noises_single:
|
45 |
+
noises.append(noise.new_empty(y.shape[0], 1, max(basic_height, int(y.shape[2]*noise.shape[2]/256)),
|
46 |
+
max(basic_width, int(y.shape[3]*noise.shape[2]/256))).normal_())
|
47 |
+
for noise in noises:
|
48 |
+
noise.requires_grad = True
|
49 |
+
|
50 |
+
init_lr=0.05
|
51 |
+
optimizer = optim.Adam(feat + noises, lr=init_lr)
|
52 |
+
optimizer2 = optim.Adam([latent_in], lr=init_lr)
|
53 |
+
noise_weight = 0.05 * 0.2
|
54 |
+
|
55 |
+
pbar = tqdm(range(step))
|
56 |
+
latent_path = []
|
57 |
+
|
58 |
+
for i in pbar:
|
59 |
+
t = i / step
|
60 |
+
lr = get_lr(t, init_lr)
|
61 |
+
optimizer.param_groups[0]["lr"] = lr
|
62 |
+
optimizer2.param_groups[0]["lr"] = get_lr(t, init_lr)
|
63 |
+
|
64 |
+
noise_strength = latent_std * noise_weight * max(0, 1 - t / 0.75) ** 2
|
65 |
+
latent_n = latent_noise(latent_in, noise_strength.item())
|
66 |
+
|
67 |
+
y_hat, _ = pspex.decoder([latent_n], input_is_latent=True, randomize_noise=False,
|
68 |
+
first_layer_feature=feat, noise=noises)
|
69 |
+
|
70 |
+
|
71 |
+
batch, channel, height, width = y_hat.shape
|
72 |
+
|
73 |
+
if height > y.shape[2]:
|
74 |
+
factor = height // y.shape[2]
|
75 |
+
|
76 |
+
y_hat = y_hat.reshape(
|
77 |
+
batch, channel, height // factor, factor, width // factor, factor
|
78 |
+
)
|
79 |
+
y_hat = y_hat.mean([3, 5])
|
80 |
+
|
81 |
+
p_loss = percept(y_hat, y).sum()
|
82 |
+
n_loss = noise_regularize(noises) * 1e3
|
83 |
+
|
84 |
+
loss = p_loss + n_loss
|
85 |
+
|
86 |
+
optimizer.zero_grad()
|
87 |
+
optimizer2.zero_grad()
|
88 |
+
loss.backward()
|
89 |
+
optimizer.step()
|
90 |
+
optimizer2.step()
|
91 |
+
|
92 |
+
noise_normalize_(noises)
|
93 |
+
|
94 |
+
''' for visualization
|
95 |
+
if (i + 1) % 100 == 0 or i == 0:
|
96 |
+
viz = torch.cat((y_hat,y,y_hat-y), dim=3)
|
97 |
+
visualize(torch.clamp(viz[0].cpu(),-1,1), 60)
|
98 |
+
'''
|
99 |
+
|
100 |
+
pbar.set_description(
|
101 |
+
(
|
102 |
+
f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
|
103 |
+
f" lr: {lr:.4f}"
|
104 |
+
)
|
105 |
+
)
|
106 |
+
|
107 |
+
return latent_n, feat, noises, wplus, f
|