# PTI: scripts/latent_creators/e4e_latent_creator.py
# Last change: commit 2d7efb8 ("simp") by ucalyptus
import torch
from argparse import Namespace
from torchvision.transforms import transforms
from configs import paths_config
from models.e4e.psp import pSp
from scripts.latent_creators.base_latent_creator import BaseLatentCreator
from utils.log_utils import log_image_from_w
class E4ELatentCreator(BaseLatentCreator):
    """Latent creator that inverts images into latent codes with a pretrained e4e encoder.

    Images are resized to 256x256, converted to tensors and normalized from
    [0, 1] to [-1, 1] before being fed to the encoder. The encoder checkpoint
    path is taken from ``paths_config.e4e``.
    """

    def __init__(self, use_wandb=False):
        """Build the preprocessing pipeline and load the e4e inversion network.

        Args:
            use_wandb: Forwarded to ``BaseLatentCreator``. ``run_projection``
                checks ``self.use_wandb`` (presumably set by the base class from
                this flag) to decide whether to log a rendering of each latent.
        """
        # Resize to 256x256 and map ToTensor's [0, 1] range to [-1, 1].
        self.e4e_inversion_pre_process = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])
        super().__init__('e4e', self.e4e_inversion_pre_process, use_wandb=use_wandb)

        e4e_model_path = paths_config.e4e
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        ckpt = torch.load(e4e_model_path, map_location='cpu')
        # The checkpoint carries its options as a dict; reuse them to rebuild
        # the network, overriding batch size and pointing it at this checkpoint
        # (presumably pSp loads its weights from ``checkpoint_path``).
        opts = ckpt['opts']
        opts['batch_size'] = 1
        opts['checkpoint_path'] = e4e_model_path
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.cuda()
        # The encoder is used for inference only: freeze its parameters so no
        # autograd state is tracked for them.
        self.e4e_inversion_net.requires_grad_(False)

    def run_projection(self, fname, image):
        """Invert ``image`` with the e4e encoder and return its latent.

        Args:
            fname: Not used by this implementation; part of the projection
                interface this method is called through.
            image: Preprocessed image tensor. It must live on the GPU because
                the network was moved there with ``.cuda()``
                (assumed to be a batch of one image -- TODO confirm).

        Returns:
            The latent produced by the e4e network when called with
            ``return_latents=True``.
        """
        # Pure inference: do not build an autograd graph through the encoder.
        with torch.no_grad():
            _, e4e_image_latent = self.e4e_inversion_net(
                image,
                randomize_noise=False,
                return_latents=True,
                resize=False,
                input_code=False,
            )
        if self.use_wandb:
            log_image_from_w(e4e_image_latent, self.old_G, 'First e4e inversion')
        return e4e_image_latent
if __name__ == '__main__':
    # Script entry point: construct the e4e creator and run its create_latents pipeline.
    E4ELatentCreator().create_latents()