import os from argparse import Namespace import numpy as np import torch from models.StyleGANControler import StyleGANControler class Model: def __init__( self, checkpoint_path, truncation=0.5, use_average_code_as_input=False ): self.truncation = truncation self.use_average_code_as_input = use_average_code_as_input ckpt = torch.load(checkpoint_path, map_location="cpu") opts = ckpt["opts"] opts["checkpoint_path"] = checkpoint_path self.opts = Namespace(**ckpt["opts"]) self.net = StyleGANControler(self.opts) self.net.eval() self.net.cuda() self.target_layers = [0, 1, 2, 3, 4, 5] def random_sample(self): z1 = torch.randn(1, 512).to("cuda") x1, w1, f1 = self.net.decoder( [z1], input_is_latent=False, randomize_noise=False, return_feature_map=True, return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0], ) w1_initial = w1.clone() x1 = self.net.face_pool(x1) image = ( ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1] ) return ( image, { "w1": w1.cpu().detach().numpy(), "w1_initial": w1_initial.cpu().detach().numpy(), }, ) # return latent vector along with the image def latents_to_tensor(self, latents): w1 = latents["w1"] w1_initial = latents["w1_initial"] w1 = torch.tensor(w1).to("cuda") w1_initial = torch.tensor(w1_initial).to("cuda") x1, w1, f1 = self.net.decoder( [w1], input_is_latent=True, randomize_noise=False, return_feature_map=True, return_latents=True, ) x1, w1_initial, f1 = self.net.decoder( [w1_initial], input_is_latent=True, randomize_noise=False, return_feature_map=True, return_latents=True, ) return (w1, w1_initial, f1) def transform( self, latents, dz, dxy, sxsy=[0, 0], stop_points=[], zoom_in=False, zoom_out=False, ): w1, w1_initial, f1 = self.latents_to_tensor(latents) w1 = w1_initial.clone() dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32) dxy_norm = np.linalg.norm(dxyz[:2], ord=2) epsilon = 1e-8 dxy_norm = dxy_norm + epsilon dxyz[:2] = dxyz[:2] / dxy_norm vec_num = dxy_norm / 10 x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda() f1 = torch.nn.functional.interpolate(f1, (256, 256)) y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0) if len(stop_points) > 0: x = torch.cat( [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()], dim=1 ) tmp = [] for sp in stop_points: tmp.append(f1[:, :, sp[1], sp[0]].unsqueeze(1)) y = torch.cat([y, torch.cat(tmp, dim=1)], dim=1) if not self.use_average_code_as_input: w_hat = self.net.encoder( w1[:, self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num, ) w1 = w1.clone() w1[:, self.target_layers] = w_hat else: w_hat = self.net.encoder( self.net.latent_avg.unsqueeze(0)[:, self.target_layers].detach(), x.detach(), y.detach(), alpha=vec_num, ) w1 = w1.clone() w1[:, self.target_layers] = ( w1.clone()[:, self.target_layers] + w_hat - self.net.latent_avg.unsqueeze(0)[:, self.target_layers] ) x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False) x1 = self.net.face_pool(x1) result = ( ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1] ) return ( result, { "w1": w1.cpu().detach().numpy(), "w1_initial": w1_initial.cpu().detach().numpy(), }, ) def change_style(self, latents): w1, w1_initial, f1 = self.latents_to_tensor(latents) w1 = w1_initial.clone() z1 = torch.randn(1, 512).to("cuda") x1, w2 = self.net.decoder( [z1], input_is_latent=False, randomize_noise=False, return_latents=True, truncation=self.truncation, truncation_latent=self.net.latent_avg[0], ) w1[:, 6:] = w2.detach()[:, 0] x1, w1_new = self.net.decoder( [w1], input_is_latent=True, randomize_noise=False, return_latents=True, ) result = ( ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1] ) return ( result, { "w1": w1_new.cpu().detach().numpy(), "w1_initial": w1_initial.cpu().detach().numpy(), }, ) def reset(self, latents): w1, w1_initial, f1 = self.latents_to_tensor(latents) x1, w1_new, f1 = self.net.decoder( [w1_initial], input_is_latent=True, randomize_noise=False, return_feature_map=True, return_latents=True, ) result = ( ((x1.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1] ) return ( result, { "w1": w1_new.cpu().detach().numpy(), "w1_initial": w1_new.cpu().detach().numpy(), }, )