Spaces:
Build error
Build error
import os | |
import sys | |
import numpy as np | |
from PIL import Image | |
import torch | |
import torchvision.transforms as transforms | |
from argparse import Namespace | |
from e4e.models.psp import pSp | |
from util import * | |
def projection(img, name, device='cuda'): | |
model_path = 'e4e_ffhq_encode.pt' | |
ckpt = torch.load(model_path, map_location='cpu') | |
opts = ckpt['opts'] | |
opts['checkpoint_path'] = model_path | |
opts= Namespace(**opts) | |
net = pSp(opts, device).eval().to(device) | |
transform = transforms.Compose( | |
[ | |
transforms.Resize(256), | |
transforms.CenterCrop(256), | |
transforms.ToTensor(), | |
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), | |
] | |
) | |
img = transform(img).unsqueeze(0).to(device) | |
images, w_plus = net(img, randomize_noise=False, return_latents=True) | |
result_file = {} | |
result_file['latent'] = w_plus[0] | |
torch.save(result_file, name) | |
return w_plus[0] | |