File size: 4,676 Bytes
d7dbcdd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
from argparse import Namespace
import numpy as np
import torch
import sys

sys.path.append(".")
sys.path.append("..")

from models.StyleGANControler import StyleGANControler

class demo():
	
	def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False):
		self.truncation = truncation
		self.use_average_code_as_input = use_average_code_as_input
		ckpt = torch.load(checkpoint_path, map_location='cpu')
		opts = ckpt['opts']
		opts['checkpoint_path'] = checkpoint_path
		self.opts = Namespace(**ckpt['opts'])
		
		self.net = StyleGANControler(self.opts)
		self.net.eval()
		self.net.cuda()
		self.target_layers = [0,1,2,3,4,5]
		
		self.w1 = None
		self.w1_after = None
		self.f1 = None

	def run(self):
		z1 = torch.randn(1,512).to("cuda")
		x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
		self.w1_after = self.w1.clone()
		x1 = self.net.face_pool(x1)
		result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
		return result
	
	def translate(self, dxy, sxsy=[0,0], stop_points=[], zoom_in=False, zoom_out=False):
		dz = -5. if zoom_in else 0.
		dz = 5. if zoom_out else dz
			
		dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32)
		dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
		dxyz[:2] = dxyz[:2]/dxy_norm
		vec_num = dxy_norm/10
		
		x = torch.from_numpy(np.array([[dxyz]],dtype=np.float32)).cuda()
		f1 = torch.nn.functional.interpolate(self.f1, (256,256))
		y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)

		if len(stop_points)>0:
			x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
			tmp = []
			for sp in stop_points:
				tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1))
			y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1)

		if not self.use_average_code_as_input:
			w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
			w1 = self.w1.clone()
			w1[:,self.target_layers] = w_hat
		else:
			w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
			w1 = self.w1.clone()
			w1[:,self.target_layers]  = self.w1.clone()[:,self.target_layers]  + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers]

		x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
		
		self.w1_after = w1.clone()
		x1 = self.net.face_pool(x1)
		result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
		return result
	
	def zoom(self, dz, sxsy=[0,0], stop_points=[]):
		vec_num = abs(dz)/5
		dz = 100*np.sign(dz)
		x = torch.from_numpy(np.array([[[1.,0,dz]]],dtype=np.float32)).cuda()
		f1 = torch.nn.functional.interpolate(self.f1, (256,256))
		y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)

		if len(stop_points)>0:
			x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
			tmp = []
			for sp in stop_points:
				tmp.append(f1[:,:,sp[1],sp[0]].unsqueeze(1))
			y = torch.cat([y,torch.cat(tmp, dim=1)],dim=1)
			
		if not self.use_average_code_as_input:
			w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
			w1 = self.w1.clone()
			w1[:,self.target_layers] = w_hat
		else:
			w_hat = self.net.encoder(self.net.latent_avg.unsqueeze(0)[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num)
			w1 = self.w1.clone()
			w1[:,self.target_layers]  = self.w1.clone()[:,self.target_layers]  + w_hat - self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
		
		
		x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
		
		x1 = self.net.face_pool(x1)
		result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
		return result
	
	def change_style(self):
		z1 = torch.randn(1,512).to("cuda")
		x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
		self.w1_after[:,6:] = w2.detach()[:,0]
		x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False)
		result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
		return result
	
	def reset(self):
		x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False)
		result = ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]
		return result