PKUWilliamYang committed
Commit 7f6643a
1 Parent(s): 01ad5b5

Upload latent_optimization.py

Files changed (1)
  1. latent_optimization.py +107 -0
latent_optimization.py ADDED
@@ -0,0 +1,107 @@
import models.stylegan2.lpips as lpips
from torch import autograd, optim
from torchvision import transforms, utils
from tqdm import tqdm
import torch
from scripts.align_all_parallel import align_face
from utils.inference_utils import noise_regularize, noise_normalize_, get_lr, latent_noise, visualize


def latent_optimization(frame, pspex, landmarkpredictor, step=500, device='cuda'):
    # LPIPS perceptual loss used to compare the reconstruction with the input frame.
    percept = lpips.PerceptualLoss(
        model="net-lin", net="vgg", use_gpu=device.startswith("cuda")
    )

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    with torch.no_grad():
        # Estimate the mean and standard deviation of the W space from random samples.
        noise_sample = torch.randn(1000, 512, device=device)
        latent_out = pspex.decoder.style(noise_sample)
        latent_mean = latent_out.mean(0)
        latent_std = ((latent_out - latent_mean).pow(2).sum() / 1000) ** 0.5

        # Encode the aligned face to W+ and extract encoder features for the full frame.
        y = transform(frame).unsqueeze(dim=0).to(device)
        I_ = align_face(frame, landmarkpredictor)
        I_ = transform(I_).unsqueeze(dim=0).to(device)
        wplus = pspex.encoder(I_) + pspex.latent_avg.unsqueeze(0)
        _, f = pspex.encoder(y, return_feat=True)
        latent_in = wplus.detach().clone()
        feat = [f[0].detach().clone(), f[1].detach().clone()]

    # wplus and f to optimize
    latent_in.requires_grad = True
    feat[0].requires_grad = True
    feat[1].requires_grad = True

    # Per-layer noise maps sized to match the input frame resolution.
    noises_single = pspex.decoder.make_noise()
    basic_height, basic_width = int(y.shape[2] * 32 / 256), int(y.shape[3] * 32 / 256)
    noises = []
    for noise in noises_single:
        noises.append(noise.new_empty(y.shape[0], 1, max(basic_height, int(y.shape[2] * noise.shape[2] / 256)),
                                      max(basic_width, int(y.shape[3] * noise.shape[2] / 256))).normal_())
    for noise in noises:
        noise.requires_grad = True

    # Two optimizers: one for the encoder features and noise maps, one for the W+ latent.
    init_lr = 0.05
    optimizer = optim.Adam(feat + noises, lr=init_lr)
    optimizer2 = optim.Adam([latent_in], lr=init_lr)
    noise_weight = 0.05 * 0.2

    pbar = tqdm(range(step))
    latent_path = []

    for i in pbar:
        t = i / step
        lr = get_lr(t, init_lr)
        optimizer.param_groups[0]["lr"] = lr
        optimizer2.param_groups[0]["lr"] = get_lr(t, init_lr)

        # Perturb the latent with noise that is annealed to zero over the first 75% of steps.
        noise_strength = latent_std * noise_weight * max(0, 1 - t / 0.75) ** 2
        latent_n = latent_noise(latent_in, noise_strength.item())

        y_hat, _ = pspex.decoder([latent_n], input_is_latent=True, randomize_noise=False,
                                 first_layer_feature=feat, noise=noises)

        batch, channel, height, width = y_hat.shape

        # Downsample the generator output to the target resolution by average pooling.
        if height > y.shape[2]:
            factor = height // y.shape[2]
            y_hat = y_hat.reshape(
                batch, channel, height // factor, factor, width // factor, factor
            )
            y_hat = y_hat.mean([3, 5])

        # Perceptual reconstruction loss plus noise regularization.
        p_loss = percept(y_hat, y).sum()
        n_loss = noise_regularize(noises) * 1e3

        loss = p_loss + n_loss

        optimizer.zero_grad()
        optimizer2.zero_grad()
        loss.backward()
        optimizer.step()
        optimizer2.step()

        noise_normalize_(noises)

        ''' for visualization
        if (i + 1) % 100 == 0 or i == 0:
            viz = torch.cat((y_hat, y, y_hat - y), dim=3)
            visualize(torch.clamp(viz[0].cpu(), -1, 1), 60)
        '''

        pbar.set_description(
            (
                f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};"
                f" lr: {lr:.4f}"
            )
        )

    return latent_n, feat, noises, wplus, f
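
For context, a minimal usage sketch of how this routine could be invoked. The load_pspex_model helper, the checkpoint paths, and the landmark-model filename below are illustrative assumptions, not part of this commit; only the latent_optimization call itself comes from the uploaded file.

# Minimal usage sketch. load_pspex_model and the file paths are hypothetical
# placeholders; the model object must expose .encoder, .decoder and .latent_avg
# as used inside latent_optimization.
import cv2
import dlib

from latent_optimization import latent_optimization

device = 'cuda'

# Hypothetical helper that loads the pSp-style encoder/decoder wrapper.
pspex = load_pspex_model('checkpoint/encoder.pt', device=device)

# dlib 68-point landmark model consumed by align_face for face alignment.
landmarkpredictor = dlib.shape_predictor('checkpoint/shape_predictor_68_face_landmarks.dat')

# One video frame as an RGB array.
frame = cv2.cvtColor(cv2.imread('input_frame.jpg'), cv2.COLOR_BGR2RGB)

latent, feat, noises, wplus, f = latent_optimization(
    frame, pspex, landmarkpredictor, step=500, device=device
)
# latent, feat and noises can then be passed back to pspex.decoder to
# re-render the inverted frame.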