|
import torch |
|
from argparse import Namespace |
|
from torchvision.transforms import transforms |
|
|
|
from configs import paths_config |
|
from models.e4e.psp import pSp |
|
from scripts.latent_creators.base_latent_creator import BaseLatentCreator |
|
from utils.log_utils import log_image_from_w |
|
|
|
|
|
class E4ELatentCreator(BaseLatentCreator):
    """Latent creator that inverts images with a pretrained e4e (pSp) encoder.

    The encoder checkpoint path comes from ``paths_config.e4e``. The network is
    switched to eval mode and moved to the default CUDA device.
    """

    def __init__(self, use_wandb=False):
        """Build the preprocessing pipeline and load the e4e encoder.

        Args:
            use_wandb: Forwarded to ``BaseLatentCreator``. It is read back as
                ``self.use_wandb`` in ``run_projection``; when true, the
                inverted latent is also logged.
        """
        # Resize to 256x256, convert to a tensor, then normalize each of the
        # three channels with mean=std=0.5, i.e. x -> 2*x - 1 (maps [0, 1]
        # inputs, as produced by ToTensor on PIL images, to [-1, 1]).
        self.e4e_inversion_pre_process = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

        # The pipeline is passed to the base class, so it must exist first.
        super().__init__('e4e', self.e4e_inversion_pre_process, use_wandb=use_wandb)

        e4e_model_path = paths_config.e4e
        # Load on CPU only to read the stored training options ('opts').
        ckpt = torch.load(e4e_model_path, map_location='cpu')
        opts = ckpt['opts']
        # Drop the rest of the checkpoint (weights included) before building
        # the network, so it is not held in host memory alongside pSp's own copy.
        del ckpt

        # Force single-image batches and point pSp at the same checkpoint file
        # (presumably pSp loads its weights from 'checkpoint_path'; verify in pSp).
        opts['batch_size'] = 1
        opts['checkpoint_path'] = e4e_model_path
        opts = Namespace(**opts)

        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.cuda()

    def run_projection(self, fname, image):
        """Invert ``image`` with the e4e encoder and return its latent code.

        Args:
            fname: Unused here; presumably kept to match the
                ``BaseLatentCreator`` calling convention.
            image: Input batch, presumably already preprocessed with
                ``e4e_inversion_pre_process`` and on the CUDA device
                (TODO confirm against the base class).

        Returns:
            The latent code: the second value returned by the pSp forward
            called with ``return_latents=True``. It is computed without
            autograd tracking.
        """
        # Pure inference. nn.Module parameters require grad by default, so
        # without no_grad the encoder forward (and the optional wandb render)
        # would record an autograd graph. The returned latent would then keep
        # all saved intermediate activations alive.
        with torch.no_grad():
            _, e4e_image_latent = self.e4e_inversion_net(
                image,
                randomize_noise=False,
                return_latents=True,
                resize=False,
                input_code=False,
            )

            if self.use_wandb:
                log_image_from_w(e4e_image_latent, self.old_G, 'First e4e inversion')

        return e4e_image_latent
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the creator with default settings (wandb
    # logging off) and call create_latents(), inherited from BaseLatentCreator.
    E4ELatentCreator().create_latents()
|
|