JoJoGAN / e4e_projection.py
sanjayw's picture
Duplicate from akhaliq/JoJoGAN
779b255
import os
import sys
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from argparse import Namespace
from e4e.models.psp import pSp
from util import *
@ torch.no_grad()
def projection(img, name, device='cuda'):
model_path = 'e4e_ffhq_encode.pt'
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
opts['checkpoint_path'] = model_path
opts= Namespace(**opts)
net = pSp(opts, device).eval().to(device)
transform = transforms.Compose(
[
transforms.Resize(256),
transforms.CenterCrop(256),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
)
img = transform(img).unsqueeze(0).to(device)
images, w_plus = net(img, randomize_noise=False, return_latents=True)
result_file = {}
result_file['latent'] = w_plus[0]
torch.save(result_file, name)
return w_plus[0]