import os
from argparse import Namespace
import numpy as np
import torch
import sys

# Make the project packages importable when this file is run from its own
# directory or from one level below the repository root.
sys.path.append(".")
sys.path.append("..")

from models.StyleGANControler import StyleGANControler


class demo():
    """Interactive wrapper around a StyleGANControler checkpoint.

    Typical use: call ``run()`` to sample an image, then ``translate()``,
    ``zoom()``, ``change_style()`` or ``reset()`` to edit / re-render it.
    Every rendering method returns a float NumPy array laid out as
    (height, width, channel) with the channel axis reversed.

    All tensors are placed on the CUDA device, so a GPU is required.
    """

    def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False):
        """Load the checkpoint and build the network in eval mode on the GPU.

        Args:
            checkpoint_path: Path of a checkpoint file containing an
                ``'opts'`` dictionary.
            truncation: Truncation value forwarded to the decoder whenever a
                new latent is sampled.
            use_average_code_as_input: If true, the encoder is fed the
                network's average latent instead of the current latent and
                its output is applied as an offset (see ``_edited_latent``).
        """
        self.truncation = truncation
        self.use_average_code_as_input = use_average_code_as_input

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        opts = ckpt['opts']
        opts['checkpoint_path'] = checkpoint_path
        self.opts = Namespace(**opts)

        self.net = StyleGANControler(self.opts)
        self.net.eval()
        self.net.cuda()

        # Indices along dim 1 of the latent that the encoder rewrites.
        self.target_layers = [0,1,2,3,4,5]

        # State filled in by run(): latent, latest edited latent, feature map.
        self.w1 = None
        self.w1_after = None
        self.f1 = None

    @staticmethod
    def _to_image(x1):
        """Convert the first image of a decoder batch to a NumPy array.

        The tensor is permuted to (H, W, C), mapped with ``(v + 1) * 127.5``
        and the channel axis is reversed.
        """
        # presumably the decoder emits values in [-1, 1] and the channel flip
        # is RGB -> BGR for an OpenCV-style consumer; verify against caller.
        return ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]

    def _require_run(self):
        """Raise a clear error when an edit is requested before ``run()``."""
        if self.w1 is None or self.w1_after is None or self.f1 is None:
            raise RuntimeError("demo.run() must be called before editing or re-rendering")

    def _edited_latent(self, x, sxsy, stop_points, alpha):
        """Run the encoder for one edit and return the modified latent.

        Args:
            x: Tensor holding the motion vector for the dragged point
                (one entry along dim 1).
            sxsy: (x, y) pixel position of the dragged point on the 256x256
                resized feature map.
            stop_points: Sequence of (x, y) positions that receive a zero
                motion vector.
            alpha: Step magnitude forwarded to the encoder.

        Returns:
            A clone of ``self.w1`` whose ``target_layers`` have been replaced
            by the edited codes. ``self.w1`` itself is not modified.
        """
        f1 = torch.nn.functional.interpolate(self.f1, (256,256))
        # Feature vector under the dragged point (note the y-then-x indexing).
        y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)

        # len() rather than truthiness so array-like stop_points still work.
        if len(stop_points)>0:
            # One zero motion vector per stop point, appended after the drag.
            x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
            anchors = [f1[:,:,sp[1],sp[0]].unsqueeze(1) for sp in stop_points]
            y = torch.cat([y,torch.cat(anchors, dim=1)],dim=1)

        if not self.use_average_code_as_input:
            w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=alpha)
            w1 = self.w1.clone()
            w1[:,self.target_layers] = w_hat
        else:
            avg = self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
            w_hat = self.net.encoder(avg.detach(), x.detach(), y.detach(), alpha=alpha)
            w1 = self.w1.clone()
            # Apply the displacement the encoder produced for the average
            # code as an offset on the current codes.
            w1[:,self.target_layers] = self.w1.clone()[:,self.target_layers] + w_hat - avg
        return w1

    @torch.no_grad()
    def run(self):
        """Sample a new random latent, cache its state and return the image.

        Stores the latent in ``self.w1``, a copy in ``self.w1_after`` and the
        decoder feature map in ``self.f1``.
        """
        z1 = torch.randn(1,512).to("cuda")
        x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
        self.w1_after = self.w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    @torch.no_grad()
    def translate(self, dxy, sxsy=(0,0), stop_points=(), zoom_in=False, zoom_out=False):
        """Drag the point ``sxsy`` by ``dxy`` and return the re-rendered image.

        Args:
            dxy: (dx, dy) drag vector. Its direction is normalised and its
                length divided by 10 is passed to the encoder as ``alpha``.
            sxsy: (x, y) position of the dragged point (256x256 grid).
            stop_points: Sequence of (x, y) positions to hold fixed.
            zoom_in: Sets the third motion component to -5.
            zoom_out: Sets the third motion component to 5; wins over
                ``zoom_in`` when both are set.

        Returns:
            The edited image. The edited latent is kept in ``self.w1_after``.

        Raises:
            RuntimeError: If ``run()`` has not been called yet.
        """
        self._require_run()

        dz = -5. if zoom_in else 0.
        dz = 5. if zoom_out else dz
        dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32)
        dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
        # A zero-length drag has no direction; dividing would fill the motion
        # vector with NaN and corrupt the encoder input.
        if dxy_norm > 0:
            dxyz[:2] = dxyz[:2]/dxy_norm
        vec_num = dxy_norm/10

        x = torch.from_numpy(np.array([[dxyz]],dtype=np.float32)).cuda()
        w1 = self._edited_latent(x, sxsy, stop_points, vec_num)

        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        self.w1_after = w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    @torch.no_grad()
    def zoom(self, dz, sxsy=(0,0), stop_points=()):
        """Zoom around the point ``sxsy`` and return the re-rendered image.

        Args:
            dz: Signed zoom amount. ``abs(dz) / 5`` is passed to the encoder
                as ``alpha``; only the sign enters the motion vector.
            sxsy: (x, y) position of the zoom centre (256x256 grid).
            stop_points: Sequence of (x, y) positions to hold fixed.

        Raises:
            RuntimeError: If ``run()`` has not been called yet.
        """
        self._require_run()

        vec_num = abs(dz)/5
        dz = 100*np.sign(dz)
        x = torch.from_numpy(np.array([[[1.,0,dz]]],dtype=np.float32)).cuda()
        w1 = self._edited_latent(x, sxsy, stop_points, vec_num)

        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        # NOTE(review): unlike translate(), the edited latent is not stored in
        # self.w1_after here - confirm whether that is intentional.
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    @torch.no_grad()
    def change_style(self):
        """Overwrite latent layers 6 and above with a new sample and render.

        ``self.w1_after`` is modified in place: every layer from index 6 on
        receives layer 0 of a freshly sampled latent.

        Raises:
            RuntimeError: If ``run()`` has not been called yet.
        """
        self._require_run()

        z1 = torch.randn(1,512).to("cuda")
        x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
        self.w1_after[:,6:] = w2.detach()[:,0]
        x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False)
        # NOTE(review): face_pool is not applied here although run(),
        # translate() and zoom() apply it - confirm the intended output size.
        return self._to_image(x1)

    @torch.no_grad()
    def reset(self):
        """Re-render the latent stored by ``run()`` and return the image.

        ``self.w1_after`` is left untouched.

        Raises:
            RuntimeError: If ``run()`` has not been called yet.
        """
        self._require_run()

        x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False)
        # NOTE(review): face_pool is not applied here either - see change_style().
        return self._to_image(x1)