# JoJoGAN / e4e_projection.py
# Author: Ahsen Khaliq
# Commit 612374d: "Update e4e_projection.py"
# (Hosting-page residue: raw / history / blame links, "No virus" scan badge,
# file size 970 bytes.)
import os
import sys
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
from argparse import Namespace
from e4e.models.psp import pSp
from util import *
# Module-level setup: build the e4e encoder once, at import time.
device = 'cpu'
model_path = 'models/e4e_ffhq_encode.pt'

# Deserialize the checkpoint onto the CPU.
ckpt = torch.load(model_path, map_location='cpu')

# The options dict stored in the checkpoint is updated in place so that its
# 'checkpoint_path' entry points at the file that was just loaded, then
# exposed as an attribute-style namespace for the pSp constructor.
ckpt['opts']['checkpoint_path'] = model_path
opts = Namespace(**ckpt['opts'])

# Construct the network, switch it to eval mode and place it on `device`.
net = pSp(opts, device)
net = net.eval()
net = net.to(device)
@torch.no_grad()
def projection(img, name, device='cuda'):
    """Encode an image with the module-level e4e network and save the latent.

    The image is resized to 256, center-cropped to 256x256, converted to a
    tensor, normalized with mean 0.5 / std 0.5 per channel and passed through
    ``net`` with ``randomize_noise=False``.

    Args:
        img: Image accepted by the torchvision transform pipeline below
            (presumably a 3-channel PIL image -- confirm against callers).
        name: Destination handed to ``torch.save`` (file path or file-like).
        device: Device on which the resulting latent is saved and returned.

    Returns:
        The first entry of the latents returned by ``net``, moved to
        ``device``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(256),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    # `net` is placed on a fixed device at import time. Feeding it an input on
    # a different device (e.g. the 'cuda' default while the model sits on the
    # CPU) raises a device-mismatch error, so the forward pass runs on the
    # encoder's own device and only the result is moved to `device`.
    # NOTE(review): assumes `net` exposes at least one parameter.
    net_device = next(net.parameters()).device
    batch = transform(img).unsqueeze(0).to(net_device)

    # The network also returns generated images; only the latents are needed.
    _, w_plus = net(batch, randomize_noise=False, return_latents=True)

    # No-op when the latent already lives on `device`.
    latent = w_plus[0].to(device)

    torch.save({'latent': latent}, name)
    return latent