radames's picture
radames HF staff
add inversion
d9778ff
import os
from argparse import Namespace
import numpy as np
import torch
from models.StyleGANControler import StyleGANControler
class Model:
    """Interactive editing wrapper around a ``StyleGANControler`` checkpoint.

    Offers random sampling, drag-based latent transformation, style swapping
    and reset. Every public editing method takes and/or returns the latent
    state as a dict ``{"w1": ndarray, "w1_initial": ndarray}``, so a caller
    can hold that state between calls.

    Inference methods run under ``torch.no_grad()``. Nothing here
    back-propagates, so recording an autograd graph on every forward pass
    would only waste GPU memory.
    """

    def __init__(
        self, checkpoint_path, truncation=0.5, use_average_code_as_input=False
    ):
        """Load the checkpoint, build the network and move it to CUDA.

        Args:
            checkpoint_path: Path given to ``torch.load``. The checkpoint must
                contain an ``"opts"`` dict. ``checkpoint_path`` is written into
                that dict before it is turned into a ``Namespace`` for
                ``StyleGANControler``.
            truncation: Truncation value forwarded to the decoder whenever a
                new ``z`` is sampled.
            use_average_code_as_input: If true, ``transform`` feeds the
                average latent to the encoder and applies the result as an
                offset. Otherwise it feeds the current latent directly.
        """
        self.truncation = truncation
        self.use_average_code_as_input = use_average_code_as_input
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        opts = ckpt["opts"]
        opts["checkpoint_path"] = checkpoint_path
        self.opts = Namespace(**opts)
        self.net = StyleGANControler(self.opts)
        self.net.eval()
        self.net.cuda()
        # Latent layers rewritten by the encoder in ``transform``
        # (``change_style`` replaces layers 6 onward instead).
        self.target_layers = [0, 1, 2, 3, 4, 5]

    @staticmethod
    def _to_image(x):
        """Convert the first image of a decoder batch to an HWC numpy array.

        Values are mapped with ``(x + 1) * 127.5``, which assumes the decoder
        emits values in [-1, 1]. The result is NOT clipped or cast to uint8.
        The channel axis is reversed (presumably RGB -> BGR for the caller).
        """
        return (
            ((x.detach()[0].permute(1, 2, 0) + 1.0) * 127.5).cpu().numpy()[:, :, ::-1]
        )

    @staticmethod
    def _pack_latents(w1, w1_initial):
        """Return the latent-state dict with both tensors copied to numpy."""
        return {
            "w1": w1.cpu().detach().numpy(),
            "w1_initial": w1_initial.cpu().detach().numpy(),
        }

    @staticmethod
    def _drag_vector(dxy, dz):
        """Build the encoder's motion vector from a drag.

        Args:
            dxy: ``(dx, dy)`` drag displacement.
            dz: Third component, passed through without normalization.

        Returns:
            ``(dxyz, vec_num)``. ``dxyz`` is a float32 array ``[dx, dy, dz]``
            whose xy part is scaled to unit length. A 1e-8 epsilon keeps a
            zero drag from dividing by zero. ``vec_num`` is that
            epsilon-padded xy length divided by 10.
        """
        dxyz = np.array([dxy[0], dxy[1], dz], dtype=np.float32)
        dxy_norm = np.linalg.norm(dxyz[:2], ord=2) + 1e-8
        dxyz[:2] = dxyz[:2] / dxy_norm
        return dxyz, dxy_norm / 10

    @torch.no_grad()
    def random_sample(self):
        """Decode a random image from a fresh ``z ~ N(0, I)`` of size (1, 512).

        Returns:
            ``(image, latents)``. ``image`` is the face-pooled output as an HWC
            numpy array (see ``_to_image``). ``latents`` holds the sampled
            latent under both ``"w1"`` and ``"w1_initial"``.
        """
        z1 = torch.randn(1, 512).to("cuda")
        x1, w1, _ = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        w1_initial = w1.clone()
        x1 = self.net.face_pool(x1)
        return self._to_image(x1), self._pack_latents(w1, w1_initial)

    @torch.no_grad()
    def latents_to_tensor(self, latents):
        """Move a latent-state dict to CUDA and re-decode both latents.

        Args:
            latents: Dict with ``"w1"`` and ``"w1_initial"`` arrays, as
                returned by the other methods of this class.

        Returns:
            ``(w1, w1_initial, f1)``: the latents returned by the decoder for
            each input, and the feature map produced while decoding
            ``w1_initial``.
        """
        w1 = torch.tensor(latents["w1"]).to("cuda")
        w1_initial = torch.tensor(latents["w1_initial"]).to("cuda")
        # NOTE(review): only the returned latent of this first decode is kept.
        # Its image and feature map are discarded.
        _, w1, _ = self.net.decoder(
            [w1],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        _, w1_initial, f1 = self.net.decoder(
            [w1_initial],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        return (w1, w1_initial, f1)

    @torch.no_grad()
    def transform(
        self,
        latents,
        dz,
        dxy,
        sxsy=None,
        stop_points=None,
        zoom_in=False,
        zoom_out=False,
    ):
        """Apply a drag edit to ``latents["w1_initial"]`` and decode the result.

        The edit always starts from the initial latent. ``latents["w1"]`` is
        decoded by ``latents_to_tensor`` but not otherwise used.

        Args:
            latents: Latent-state dict (``"w1"`` / ``"w1_initial"``).
            dz: Third component of the motion vector (not normalized).
            dxy: ``(dx, dy)`` drag. Its direction is normalized and its length
                / 10 is passed to the encoder as ``alpha``.
            sxsy: ``(x, y)`` drag start, used to index a 256x256 interpolated
                feature map, so each coordinate must be < 256 (negative values
                index from the end). Defaults to ``[0, 0]``.
            stop_points: Sequence of ``(x, y)`` points in the same 256x256
                space. Each adds a zero motion vector paired with the feature
                at that point (presumably pinning it). Defaults to none.
            zoom_in, zoom_out: Accepted for interface compatibility; unused.

        Returns:
            ``(image, latents)`` with the face-pooled edited image and a dict
            holding the edited ``"w1"`` and the unchanged ``"w1_initial"``.
        """
        # None sentinels instead of shared mutable default lists.
        if sxsy is None:
            sxsy = [0, 0]
        if stop_points is None:
            stop_points = []

        _, w1_initial, f1 = self.latents_to_tensor(latents)
        # Fresh copy, so the in-place layer writes below never touch w1_initial.
        w1 = w1_initial.clone()

        dxyz, vec_num = self._drag_vector(dxy, dz)
        # Motion vectors, shape (1, 1, 3).
        x = torch.from_numpy(np.array([[dxyz]], dtype=np.float32)).cuda()

        f1 = torch.nn.functional.interpolate(f1, (256, 256))
        y = f1[:, :, sxsy[1], sxsy[0]].unsqueeze(0)
        if len(stop_points) > 0:
            x = torch.cat(
                [x, torch.zeros(x.shape[0], len(stop_points), x.shape[2]).cuda()],
                dim=1,
            )
            anchors = [f1[:, :, sp[1], sp[0]].unsqueeze(1) for sp in stop_points]
            y = torch.cat([y, torch.cat(anchors, dim=1)], dim=1)

        if not self.use_average_code_as_input:
            w_hat = self.net.encoder(
                w1[:, self.target_layers].detach(),
                x.detach(),
                y.detach(),
                alpha=vec_num,
            )
            w1[:, self.target_layers] = w_hat
        else:
            w_avg = self.net.latent_avg.unsqueeze(0)[:, self.target_layers]
            w_hat = self.net.encoder(
                w_avg.detach(),
                x.detach(),
                y.detach(),
                alpha=vec_num,
            )
            # Apply the encoder's change relative to the average latent as an
            # offset on the current layers (list indexing yields a copy).
            w1[:, self.target_layers] = w1[:, self.target_layers] + w_hat - w_avg

        x1, _ = self.net.decoder([w1], input_is_latent=True, randomize_noise=False)
        x1 = self.net.face_pool(x1)
        return self._to_image(x1), self._pack_latents(w1, w1_initial)

    @torch.no_grad()
    def change_style(self, latents):
        """Swap latent layers 6+ of the initial latent for a new random style.

        A new ``z`` is sampled and decoded with truncation. The first-layer
        latent of that sample is broadcast into layers ``6:`` of a copy of
        ``w1_initial``.

        Returns:
            ``(image, latents)``. ``"w1"`` is the decoder-returned latent of the
            restyled code; ``"w1_initial"`` is unchanged.
        """
        _, w1_initial, _ = self.latents_to_tensor(latents)
        w1 = w1_initial.clone()
        z1 = torch.randn(1, 512).to("cuda")
        _, w2 = self.net.decoder(
            [z1],
            input_is_latent=False,
            randomize_noise=False,
            return_latents=True,
            truncation=self.truncation,
            truncation_latent=self.net.latent_avg[0],
        )
        w1[:, 6:] = w2.detach()[:, 0]
        x1, w1_new = self.net.decoder(
            [w1],
            input_is_latent=True,
            randomize_noise=False,
            return_latents=True,
        )
        # NOTE(review): unlike random_sample/transform, face_pool is not
        # applied here, so the image size may differ. Confirm this is intended.
        return self._to_image(x1), self._pack_latents(w1_new, w1_initial)

    @torch.no_grad()
    def reset(self, latents):
        """Decode ``latents["w1_initial"]`` and make it the current state.

        Returns:
            ``(image, latents)`` where both ``"w1"`` and ``"w1_initial"`` are the
            decoder-returned initial latent.
        """
        _, w1_initial, _ = self.latents_to_tensor(latents)
        x1, w1_new, _ = self.net.decoder(
            [w1_initial],
            input_is_latent=True,
            randomize_noise=False,
            return_feature_map=True,
            return_latents=True,
        )
        # NOTE(review): no face_pool here either (see change_style).
        return self._to_image(x1), self._pack_latents(w1_new, w1_new)