Spaces:
Runtime error
Runtime error
File size: 4,676 Bytes
d7dbcdd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
import os
from argparse import Namespace
import numpy as np
import torch
import sys
sys.path.append(".")
sys.path.append("..")
from models.StyleGANControler import StyleGANControler
class demo():
    """Interactive wrapper around a ``StyleGANControler`` checkpoint.

    ``run`` samples an image and caches its latent codes and feature map;
    ``translate`` / ``zoom`` re-encode the latent layers listed in
    ``target_layers`` from a motion vector and decode the edited image;
    ``change_style`` overwrites latent layers 6 and up with a fresh sample;
    ``reset`` decodes the cached, unedited latents again.

    Every method that returns an image returns a float NumPy array laid out
    as (height, width, 3) with the channel axis reversed.
    """

    def __init__(self, checkpoint_path, truncation = 0.5, use_average_code_as_input = False):
        """Load the checkpoint, build the network and move it to CUDA.

        Args:
            checkpoint_path: Path of the checkpoint file. It must contain an
                ``'opts'`` dict, which is turned into the network options.
            truncation: Truncation value forwarded to the decoder whenever a
                new latent is sampled.
            use_average_code_as_input: If true, the encoder is fed the
                network's average latent and its output is applied to the
                current latents as an offset; otherwise the encoder is fed
                the current latents and its output replaces them.
        """
        self.truncation = truncation
        self.use_average_code_as_input = use_average_code_as_input
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code -- only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        opts = ckpt['opts']
        opts['checkpoint_path'] = checkpoint_path
        self.opts = Namespace(**opts)
        self.net = StyleGANControler(self.opts)
        self.net.eval()
        self.net.cuda()
        # Latent layers that translate()/zoom() are allowed to rewrite.
        self.target_layers = [0,1,2,3,4,5]
        # Populated by run(): latents of the current sample, the latents
        # after the most recent translate()/change_style(), and the feature
        # map returned by the decoder.
        self.w1 = None
        self.w1_after = None
        self.f1 = None

    @staticmethod
    def _to_image(x1):
        """Convert the first image of batch ``x1`` (N, C, H, W) to a NumPy array.

        Values are mapped with ``(v + 1) * 127.5`` and the channel axis is
        reversed (presumably [-1, 1] RGB -> [0, 255] BGR -- confirm with the
        consumer of the result).
        """
        return ((x1.detach()[0].permute(1,2,0)+1.)*127.5).cpu().numpy()[:,:,::-1]

    @staticmethod
    def _motion_vector(dxy, dz):
        """Return ``(direction, strength)`` for an in-plane drag ``dxy``.

        ``direction`` is a float32 array ``[dx, dy, dz]`` whose first two
        components are normalised to unit length; ``strength`` is the drag
        length divided by 10. A zero-length drag keeps ``dx = dy = 0`` instead
        of dividing by zero, which would hand NaNs to the encoder.
        """
        dxyz = np.array([dxy[0],dxy[1],dz], dtype=np.float32)
        dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
        if dxy_norm > 0:
            dxyz[:2] = dxyz[:2]/dxy_norm
        return dxyz, dxy_norm/10

    def _control_inputs(self, direction, sxsy, stop_points):
        """Build the encoder inputs ``(x, y)`` for one edit.

        ``x`` holds ``direction`` for the source point followed by a zero
        vector for each stop point; ``y`` holds the feature vectors read from
        the cached feature map (resized to 256x256) at the same points.
        Points are ``(x, y)`` pixel coordinates in that 256x256 grid.
        """
        x = torch.from_numpy(np.array([[direction]],dtype=np.float32)).cuda()
        f1 = torch.nn.functional.interpolate(self.f1, (256,256))
        y = f1[:,:,sxsy[1],sxsy[0]].unsqueeze(0)
        if len(stop_points)>0:
            x = torch.cat([x, torch.zeros(x.shape[0],len(stop_points),x.shape[2]).cuda()], dim=1)
            stop_feats = [f1[:,:,sp[1],sp[0]].unsqueeze(1) for sp in stop_points]
            y = torch.cat([y,torch.cat(stop_feats, dim=1)],dim=1)
        return x, y

    def _edit_latents(self, x, y, alpha):
        """Return a copy of ``self.w1`` with ``target_layers`` edited by the encoder."""
        w1 = self.w1.clone()
        if not self.use_average_code_as_input:
            w_hat = self.net.encoder(self.w1[:,self.target_layers].detach(), x.detach(), y.detach(), alpha=alpha)
            w1[:,self.target_layers] = w_hat
        else:
            avg = self.net.latent_avg.unsqueeze(0)[:,self.target_layers]
            w_hat = self.net.encoder(avg.detach(), x.detach(), y.detach(), alpha=alpha)
            # Apply the encoder's displacement of the average code as an
            # offset to the current latents.
            w1[:,self.target_layers] = self.w1[:,self.target_layers] + w_hat - avg
        return w1

    @torch.no_grad()
    def run(self):
        """Sample a new random image, cache its latents/features, return it."""
        z1 = torch.randn(1,512).to("cuda")
        x1, self.w1, self.f1 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_feature_map=True,return_latents=True,truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
        self.w1_after = self.w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    @torch.no_grad()
    def translate(self, dxy, sxsy=(0,0), stop_points=None, zoom_in=False, zoom_out=False):
        """Move the content at ``sxsy`` along ``dxy`` and return the new image.

        Args:
            dxy: In-plane drag ``(dx, dy)``; its length sets the edit strength.
            sxsy: ``(x, y)`` source point in the 256x256 feature grid.
            stop_points: Optional ``(x, y)`` points given a zero motion vector.
            zoom_in: Use -5 as the third motion component.
            zoom_out: Use +5 as the third motion component (wins over
                ``zoom_in`` when both are set).

        Also stores the edited latents in ``self.w1_after``.
        """
        stop_points = [] if stop_points is None else stop_points
        dz = -5. if zoom_in else 0.
        dz = 5. if zoom_out else dz
        dxyz, vec_num = self._motion_vector(dxy, dz)
        x, y = self._control_inputs(dxyz, sxsy, stop_points)
        w1 = self._edit_latents(x, y, vec_num)
        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        self.w1_after = w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    @torch.no_grad()
    def zoom(self, dz, sxsy=(0,0), stop_points=None):
        """Zoom around ``sxsy`` and return the new image.

        Args:
            dz: Signed zoom amount; ``abs(dz) / 5`` is the edit strength and
                only its sign enters the motion vector ``[1, 0, 100*sign]``.
            sxsy: ``(x, y)`` source point in the 256x256 feature grid.
            stop_points: Optional ``(x, y)`` points given a zero motion vector.
        """
        stop_points = [] if stop_points is None else stop_points
        vec_num = abs(dz)/5
        dz = 100*np.sign(dz)
        x, y = self._control_inputs([1.,0,dz], sxsy, stop_points)
        w1 = self._edit_latents(x, y, vec_num)
        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        # NOTE(review): unlike translate(), the edited latents are not stored
        # in self.w1_after, so a following change_style() discards the zoom.
        # Behaviour kept as-is -- confirm whether that is intended.
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    @torch.no_grad()
    def change_style(self):
        """Overwrite latent layers 6+ of ``w1_after`` with a new sample; return the image."""
        z1 = torch.randn(1,512).to("cuda")
        x1, w2 = self.net.decoder([z1],input_is_latent=False,randomize_noise=False,return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0])
        # Layer 0 of the new sample is written into every layer from 6 up.
        self.w1_after[:,6:] = w2.detach()[:,0]
        x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False)
        # NOTE(review): face_pool is not applied here (nor in reset()), unlike
        # run()/translate()/zoom() -- confirm the differing output is intended.
        return self._to_image(x1)

    @torch.no_grad()
    def reset(self):
        """Decode the unedited latents cached by ``run`` and return the image."""
        x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False)
        return self._to_image(x1)
|