Spaces:
Runtime error
Runtime error
import os | |
from argparse import Namespace | |
import numpy as np | |
import torch | |
import sys | |
sys.path.append(".") | |
sys.path.append("..") | |
from models.StyleGANControler import StyleGANControler | |
class demo():
    """Stateful helper around a StyleGANControler network loaded from a checkpoint.

    Typical call order: `run()` samples an image and caches its latents
    (`w1`) and feature map (`f1`); `translate()` / `zoom()` re-encode the
    latent layers listed in `target_layers`; `change_style()` overwrites the
    latent layers from index 6 onward; `reset()` re-renders the cached `w1`.

    Every rendering method returns the first batch item as a NumPy array in
    height x width x channel layout with values `(x + 1) * 127.5` and the
    channel axis reversed (presumably RGB -> BGR for OpenCV-style consumers;
    confirm against the caller). Values are not cast to an integer dtype.
    """

    # First latent layer index overwritten by `change_style`.
    _STYLE_START = 6

    def __init__(self, checkpoint_path, truncation=0.5, use_average_code_as_input=False, device=None):
        """Load the checkpoint and build the network in eval mode.

        Args:
            checkpoint_path: Path of the checkpoint file; it must contain an
                ``'opts'`` mapping. The path is also stored into those
                options under ``'checkpoint_path'``.
            truncation: Truncation value forwarded to the decoder whenever a
                new latent is sampled.
            use_average_code_as_input: If true, the encoder is fed the
                network's average latent instead of the current latent and
                its output is applied as an offset (see `_edit_latents`).
            device: Optional device (string or `torch.device`). Defaults to
                CUDA when available, otherwise CPU.
        """
        self.truncation = truncation
        self.use_average_code_as_input = use_average_code_as_input
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)

        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location='cpu')
        opts = ckpt['opts']
        opts['checkpoint_path'] = checkpoint_path
        self.opts = Namespace(**opts)

        self.net = StyleGANControler(self.opts)
        self.net.eval()
        self.net.to(self.device)

        self.target_layers = [0, 1, 2, 3, 4, 5]
        # State filled in by `run()`.
        self.w1 = None        # latents of the current base image
        self.w1_after = None  # latents after the most recent `translate()`
        self.f1 = None        # decoder feature map of the base image

    def _require_generated(self):
        """Raise a clear error when `run()` has not populated the cached state."""
        if self.w1 is None:
            raise RuntimeError("No image has been generated yet; call run() first.")

    def _sample_latent(self):
        """Draw one 512-dimensional standard-normal latent on `self.device`."""
        # Sampled on the CPU and then moved, as the original code did, so that
        # seeded runs keep drawing from the CPU random stream.
        return torch.randn(1, 512).to(self.device)

    @staticmethod
    def _to_image(x1):
        """Convert the first item of a decoder output batch to the returned array."""
        return ((x1.detach()[0].permute(1, 2, 0) + 1.) * 127.5).cpu().numpy()[:, :, ::-1]

    def _edit_latents(self, x, sxsy, stop_points, alpha):
        """Run the encoder for one edit and return the edited copy of `w1`.

        Args:
            x: Tensor of shape (1, 1, 3) holding the edit vector.
            sxsy: (x, y) pixel of the edited point on the 256x256 resized
                feature map.
            stop_points: Sequence of (x, y) pixels that get an all-zero edit
                vector (presumably points that should stay fixed).
            alpha: Value forwarded to the encoder as ``alpha``.
        """
        f1 = torch.nn.functional.interpolate(self.f1, (256, 256))
        # Feature vector at the edited pixel; note the (row, col) = (y, x) order.
        y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)

        if len(stop_points) > 0:
            # One zero edit vector and one feature vector per stop point.
            pad = torch.zeros(x.shape[0], len(stop_points), x.shape[2], device=self.device)
            x = torch.cat([x, pad], dim=1)
            anchors = [f1[:, :, sp[1], sp[0]].unsqueeze(1) for sp in stop_points]
            y = torch.cat([y, torch.cat(anchors, dim=1)], dim=1)

        layers = self.target_layers
        w1 = self.w1.clone()
        if not self.use_average_code_as_input:
            w_hat = self.net.encoder(self.w1[:, layers].detach(), x.detach(), y.detach(), alpha=alpha)
            w1[:, layers] = w_hat
        else:
            # Encode from the average latent and apply the change as an offset.
            avg = self.net.latent_avg.unsqueeze(0)[:, layers]
            w_hat = self.net.encoder(avg.detach(), x.detach(), y.detach(), alpha=alpha)
            w1[:, layers] = self.w1[:, layers] + w_hat - avg
        return w1

    def run(self):
        """Sample a new latent, render it, and cache its latents and feature map."""
        z1 = self._sample_latent()
        x1, self.w1, self.f1 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        self.w1_after = self.w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    def translate(self, dxy, sxsy=(0, 0), stop_points=(), zoom_in=False, zoom_out=False):
        """Render the base image edited by a drag of `dxy` starting at `sxsy`.

        The drag direction is normalised to unit length and its length / 10
        is passed to the encoder as ``alpha``. `zoom_in` / `zoom_out` set the
        third component of the edit vector to -5 / +5 (`zoom_out` wins when
        both are set). The edited latents are stored in `w1_after`.

        A zero-length drag keeps a zero direction and ``alpha == 0`` instead
        of dividing by zero and feeding NaN to the encoder.

        Raises:
            RuntimeError: If `run()` has not been called yet.
        """
        self._require_generated()
        dz = -5. if zoom_in else 0.
        dz = 5. if zoom_out else dz
        dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
        dxy_norm = np.linalg.norm(dxyz[:2], ord=2)
        if dxy_norm > 0:
            dxyz[:2] = dxyz[:2] / dxy_norm
        vec_num = dxy_norm / 10
        x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).to(self.device)

        w1 = self._edit_latents(x, sxsy, stop_points, vec_num)
        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        self.w1_after = w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    def zoom(self, dz, sxsy=(0, 0), stop_points=()):
        """Render the base image edited by a zoom of signed amount `dz` at `sxsy`.

        The edit vector is ``[1, 0, 100 * sign(dz)]`` and ``abs(dz) / 5`` is
        passed to the encoder as ``alpha``.

        Raises:
            RuntimeError: If `run()` has not been called yet.
        """
        self._require_generated()
        vec_num = abs(dz) / 5
        dz = 100 * np.sign(dz)
        x = torch.from_numpy(np.array([[[1., 0, dz]]], dtype=np.float32)).to(self.device)

        w1 = self._edit_latents(x, sxsy, stop_points, vec_num)
        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        # NOTE(review): unlike translate(), the edited latents are not stored
        # in w1_after, so a later change_style() ignores this zoom. Kept as in
        # the original; confirm whether that is intended.
        x1 = self.net.face_pool(x1)
        return self._to_image(x1)

    def change_style(self):
        """Overwrite latent layers 6+ of `w1_after` with a fresh sample and render.

        `w1_after` is modified in place.

        Raises:
            RuntimeError: If `run()` has not been called yet.
        """
        self._require_generated()
        z1 = self._sample_latent()
        # Only the latents of this pass are used; its image is discarded.
        _, w2 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        self.w1_after[:, self._STYLE_START:] = w2.detach()[:, 0]
        x1, _ = self.net.decoder([self.w1_after], input_is_latent=True, randomize_noise=False, return_latents=False)
        # NOTE(review): face_pool is not applied here, unlike run()/translate()/
        # zoom(), so the output size may differ. Kept as in the original.
        return self._to_image(x1)

    def reset(self):
        """Re-render the cached base latents `w1` without any edit.

        Raises:
            RuntimeError: If `run()` has not been called yet.
        """
        self._require_generated()
        x1, _ = self.net.decoder([self.w1], input_is_latent=True, randomize_noise=False, return_latents=False)
        # NOTE(review): w1_after is left untouched and face_pool is skipped,
        # as in the original; confirm whether reset should also restore
        # w1_after.
        return self._to_image(x1)