File size: 2,015 Bytes
e90f2c5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c472db
 
e90f2c5
1c472db
e90f2c5
 
 
 
 
1c472db
e90f2c5
 
 
 
 
 
 
1c472db
e90f2c5
1c472db
 
e90f2c5
88c803f
e90f2c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from argparse import Namespace
import sys

sys.path.append(".")
sys.path.append("..")
sys.path.append("./encoder4editing")

from PIL import Image
import torch
import torchvision.transforms as transforms
import dlib
from utils.alignment import align_face


from utils.common import tensor2im
from models.psp import pSp  # we use the pSp framework to load the e4e encoder.
# Which experiment configuration is active; only FFHQ face encoding is defined here.
experiment_type = 'ffhq_encode'

# Per-experiment assets: encoder checkpoint plus a sample input image.
EXPERIMENT_DATA_ARGS = {
    "ffhq_encode": {
        "model_path": "encoder4editing/e4e_ffhq_encode.pt",
        "image_path": "notebooks/images/input_img.jpg",
    },
}

# Active configuration, extended in place with the preprocessing pipeline.
# NOTE: EXPERIMENT_ARGS aliases the dict inside EXPERIMENT_DATA_ARGS, so the
# 'transform' key is added to that shared entry as well.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
_preprocess_steps = [
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
]
EXPERIMENT_ARGS['transform'] = transforms.Compose(_preprocess_steps)

class psp_encoder:
    """Wrap a pretrained e4e/pSp encoder plus a dlib face aligner.

    Maps a face photograph to its latent code in the encoder's latent
    space (returned as a numpy array by :meth:`get_w`).
    """

    def __init__(self, model_path: str, shape_predictor_path: str):
        """Load the e4e checkpoint and the dlib shape predictor.

        Args:
            model_path: path to the e4e/pSp checkpoint (.pt file).
            shape_predictor_path: path to the dlib facial-landmark
                predictor used for face alignment.
        """
        self.ckpt = torch.load(model_path, map_location="cpu")
        self.opts = self.ckpt["opts"]
        # Point the checkpoint path at the file we just loaded so pSp
        # restores its own weights from the right place.
        self.opts["checkpoint_path"] = model_path
        self.opts = Namespace(**self.opts)
        # Choose the device BEFORE moving the network: the original called
        # .cuda() unconditionally, which crashes on CPU-only machines even
        # though self.device was selected conditionally.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = pSp(self.opts)
        self.net.eval()
        self.net.to(self.device)
        self.shape_predictor = dlib.shape_predictor(shape_predictor_path)

    def get_w(self, image_path):
        """Encode the face in ``image_path`` and return its latent code.

        Args:
            image_path: path to an image containing a detectable face.

        Returns:
            numpy array of latents produced by the encoder (batch dim 1).
        """
        # Align/crop the face with dlib landmarks before encoding.
        input_image = align_face(filepath=image_path, predictor=self.shape_predictor)
        # BUGFIX: PIL's Image.resize returns a NEW image; the original
        # discarded the result, so the resize was a no-op. (The transform
        # below resizes to 256x256 anyway, so this only pre-shrinks work.)
        input_image = input_image.resize((256, 256))
        transformed_image = EXPERIMENT_ARGS["transform"](input_image)

        with torch.no_grad():
            _, latents = self.net(
                transformed_image.unsqueeze(0).to(self.device).float(),
                randomize_noise=False,
                return_latents=True,
            )
        return latents.cpu().numpy()