import torch from argparse import Namespace from torchvision.transforms import transforms from configs import paths_config from models.e4e.psp import pSp from scripts.latent_creators.base_latent_creator import BaseLatentCreator from utils.log_utils import log_image_from_w class E4ELatentCreator(BaseLatentCreator): def __init__(self, use_wandb=False): self.e4e_inversion_pre_process = transforms.Compose([ transforms.Resize((256, 256)), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]) super().__init__('e4e', self.e4e_inversion_pre_process, use_wandb=use_wandb) e4e_model_path = paths_config.e4e ckpt = torch.load(e4e_model_path, map_location='cpu') opts = ckpt['opts'] opts['batch_size'] = 1 opts['checkpoint_path'] = e4e_model_path opts = Namespace(**opts) self.e4e_inversion_net = pSp(opts) self.e4e_inversion_net.eval() self.e4e_inversion_net = self.e4e_inversion_net.cuda() def run_projection(self, fname, image): _, e4e_image_latent = self.e4e_inversion_net(image, randomize_noise=False, return_latents=True, resize=False, input_code=False) if self.use_wandb: log_image_from_w(e4e_image_latent, self.old_G, 'First e4e inversion') return e4e_image_latent if __name__ == '__main__': e4e_latent_creator = E4ELatentCreator() e4e_latent_creator.create_latents()