stylemc-demo / psp_wrapper.py
adirik's picture
update app
88c803f
raw
history blame
2.02 kB
from argparse import Namespace
import sys
sys.path.append(".")
sys.path.append("..")
sys.path.append("./encoder4editing")
from PIL import Image
import torch
import torchvision.transforms as transforms
import dlib
from utils.alignment import align_face
from utils.common import tensor2im
from models.psp import pSp # we use the pSp framework to load the e4e encoder.
# Name of the experiment whose settings are selected below.
experiment_type = "ffhq_encode"

# Per-experiment data settings, keyed by experiment name.
EXPERIMENT_DATA_ARGS = {
    "ffhq_encode": {
        "model_path": "encoder4editing/e4e_ffhq_encode.pt",
        "image_path": "notebooks/images/input_img.jpg",
    },
}

# Settings of the selected experiment. This is the same dict object that is
# stored inside EXPERIMENT_DATA_ARGS, so the "transform" entry added below is
# visible through both names.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]

# Setup required image transformations: resize to 256x256, convert to a
# tensor, then normalize each of the three channels with mean 0.5 / std 0.5.
EXPERIMENT_ARGS["transform"] = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)
class psp_encoder:
    """Encode face images into latent codes with an encoder loaded via pSp.

    The constructor loads a checkpoint, rebuilds the network from the options
    stored inside it, and prepares a dlib shape predictor used to align faces
    before encoding.
    """

    def __init__(self, model_path: str, shape_predictor_path: str):
        """Load the checkpoint, build the network and the landmark predictor.

        Args:
            model_path: Path to the checkpoint file. The checkpoint must
                contain an ``"opts"`` entry holding the training options as
                a mapping.
            shape_predictor_path: Path to the dlib shape-predictor model.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        self.ckpt = torch.load(model_path, map_location="cpu")
        self.opts = self.ckpt["opts"]
        # update the training options
        self.opts["checkpoint_path"] = model_path
        self.opts = Namespace(**self.opts)

        # Resolve the device before moving the network so that the CPU
        # fallback is actually honoured; an unconditional .cuda() call would
        # fail on a machine without CUDA.
        # NOTE(review): the options stored in the checkpoint may carry their
        # own device setting that pSp consults - confirm before relying on
        # CPU-only execution.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.net = pSp(self.opts)
        self.net.eval()
        self.net.to(self.device)
        self.shape_predictor = dlib.shape_predictor(shape_predictor_path)

    def get_w(self, image_path):
        """Align the face found in ``image_path`` and return its latents.

        Args:
            image_path: Path of the image file to encode.

        Returns:
            The latents returned by the network, moved to the CPU and
            converted to a NumPy array.
        """
        # Fail fast with PIL's own error when the path is missing or the file
        # cannot be decoded as an image. The decoded image itself is not
        # needed (alignment re-reads the file from its path), so the handle
        # is released immediately.
        with Image.open(image_path) as probe:
            probe.convert("RGB")

        input_image = align_face(filepath=image_path, predictor=self.shape_predictor)

        # No explicit PIL resize here: Image.resize returns a new image and
        # does not modify in place, and the transform pipeline below already
        # resizes its input to 256x256.
        img_transforms = EXPERIMENT_ARGS["transform"]
        transformed_image = img_transforms(input_image)

        with torch.no_grad():
            # Add a batch dimension and run on the same device as the network.
            batch = transformed_image.unsqueeze(0).to(self.device).float()
            _, latents = self.net(batch, randomize_noise=False, return_latents=True)
        return latents.cpu().numpy()