File size: 1,589 Bytes
2d7efb8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import torch
from argparse import Namespace
from torchvision.transforms import transforms

from configs import paths_config
from models.e4e.psp import pSp
from scripts.latent_creators.base_latent_creator import BaseLatentCreator
from utils.log_utils import log_image_from_w


class E4ELatentCreator(BaseLatentCreator):
    """Latent creator that projects images with a pretrained e4e (pSp) encoder.

    The encoder checkpoint is read from ``paths_config.e4e``, switched to
    evaluation mode and moved to the GPU when the object is constructed.
    """

    def __init__(self, use_wandb=False):
        """Build the input preprocessing pipeline and load the encoder.

        Args:
            use_wandb: Forwarded to the base class constructor. Presumably
                stored there as ``self.use_wandb``, which ``run_projection``
                consults before logging -- confirm against the base class.
        """
        # Encoder input pipeline: resize to 256x256, convert to a tensor, then
        # normalise every channel with mean 0.5 and std 0.5.
        pre_process_steps = [
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
        self.e4e_inversion_pre_process = transforms.Compose(pre_process_steps)

        super().__init__('e4e', self.e4e_inversion_pre_process, use_wandb=use_wandb)

        checkpoint_path = paths_config.e4e
        # NOTE(review): torch.load unpickles the file, so the checkpoint should
        # come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')

        # Reuse the options saved inside the checkpoint, forcing a batch size
        # of one and pointing the model at the checkpoint it was read from.
        encoder_options = {
            **checkpoint['opts'],
            'batch_size': 1,
            'checkpoint_path': checkpoint_path,
        }

        encoder = pSp(Namespace(**encoder_options))
        encoder.eval()
        # NOTE(review): the device is hard-coded to CUDA here; a CPU-only
        # machine will fail at this call.
        self.e4e_inversion_net = encoder.cuda()

    def run_projection(self, fname, image):
        """Encode ``image`` with the e4e network and return its latent code.

        Args:
            fname: Name associated with the image; not used by this method.
            image: Input passed straight to the encoder. Presumably already
                preprocessed and on the encoder's device -- verify against
                the caller.

        Returns:
            The second element of the encoder's output pair (the latents
            requested through ``return_latents=True``).
        """
        encoder_kwargs = {
            'randomize_noise': False,
            'return_latents': True,
            'resize': False,
            'input_code': False,
        }
        # The encoder returns a pair; only the latent code is kept.
        _, latent_code = self.e4e_inversion_net(image, **encoder_kwargs)

        if not self.use_wandb:
            return latent_code

        # self.old_G is not defined in this class; presumably the base class
        # provides it -- TODO confirm.
        log_image_from_w(latent_code, self.old_G, 'First e4e inversion')
        return latent_code


if __name__ == '__main__':
    # Script entry point: build the e4e latent creator and run its
    # create_latents() routine.
    E4ELatentCreator().create_latents()