import os
import argparse
from argparse import Namespace
import time
import sys
import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms
sys.path.append(".")
sys.path.append("..")
from utils.common import tensor2im
from models.psp import pSp # we use the pSp framework to load the e4e encoder.
experiment_type = 'ffhq_encode'
parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, default="", help='input image path')
args = parser.parse_args()
opts = vars(args)
print(opts)
image_path = opts["input_image"]
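# Example invocation (the script filename is illustrative):
#   python infer.py --input_image notebooks/images/input_img.jpg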
def get_download_model_command(file_id, file_name):
    """Return a wget command that downloads the requested model from Google Drive into encoder4editing/saves."""
    save_path = "encoder4editing/saves"
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
    return url
MODEL_PATHS = {
    "ffhq_encode": {"id": "1cUv_reLE6k3604or78EranS7XzuVMWeO", "name": "e4e_ffhq_encode.pt"},
    "cars_encode": {"id": "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "name": "e4e_cars_encode.pt"},
    "horse_encode": {"id": "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "name": "e4e_horse_encode.pt"},
    "church_encode": {"id": "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "name": "e4e_church_encode.pt"}
}
path = MODEL_PATHS[experiment_type]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
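# Note: the command above is only constructed, not executed here. A hypothetical
# way to actually fetch the checkpoint (assumes wget is available and the Google
# Drive file is still reachable):
# if not os.path.exists(os.path.join("encoder4editing/saves", path["name"])):
#     os.system(download_command)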
EXPERIMENT_DATA_ARGS = {
    "ffhq_encode": {
        "model_path": "encoder4editing/e4e_ffhq_encode.pt",
        "image_path": "notebooks/images/input_img.jpg"
    },
    "cars_encode": {
        "model_path": "pretrained_models/e4e_cars_encode.pt",
        "image_path": "notebooks/images/car_img.jpg"
    },
    "horse_encode": {
        "model_path": "pretrained_models/e4e_horse_encode.pt",
        "image_path": "notebooks/images/horse_img.jpg"
    },
    "church_encode": {
        "model_path": "pretrained_models/e4e_church_encode.pt",
        "image_path": "notebooks/images/church_img.jpg"
    }
}
# Set up the required image transformations
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
if experiment_type == 'cars_encode':
    EXPERIMENT_ARGS['transform'] = transforms.Compose([
        transforms.Resize((192, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    resize_dims = (256, 192)
else:
    EXPERIMENT_ARGS['transform'] = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
    resize_dims = (256, 256)
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
# update the training options
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)
net = pSp(opts)
net.eval()
net.cuda()
print('Model successfully loaded!')
original_image = Image.open(image_path)
original_image = original_image.convert("RGB")
def run_alignment(image_path):
    import dlib
    from utils.alignment import align_face
    predictor = dlib.shape_predictor("encoder4editing/shape_predictor_68_face_landmarks.dat")
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image
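# Note: run_alignment expects dlib's 68-point landmark model at the path above.
# It is not bundled with the repo; it can be downloaded and unpacked from
# http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2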
if experiment_type == "ffhq_encode":
    input_image = run_alignment(image_path)
else:
    input_image = original_image
# PIL's resize returns a new image rather than resizing in place
input_image = input_image.resize(resize_dims)
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(input_image)
def display_alongside_source_image(result_image, source_image):
    res = np.concatenate([np.array(source_image.resize(resize_dims)),
                          np.array(result_image.resize(resize_dims))], axis=1)
    return Image.fromarray(res)
def run_on_batch(inputs, net):
    images, latents = net(inputs.to("cuda").float(), randomize_noise=False, return_latents=True)
    if experiment_type == 'cars_encode':
        # crop out the vertical padding of the cars outputs (keep the central 192 rows)
        images = images[:, :, 32:224, :]
    return images, latents
with torch.no_grad():
    tic = time.time()
    images, latents = run_on_batch(transformed_image.unsqueeze(0), net)
    result_image, latent = images[0], latents[0]
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))
# Save the inversion side by side with the source image (the original notebook
# displayed this composite inline; the output path here is illustrative)
coupled_image = display_alongside_source_image(tensor2im(result_image), input_image)
coupled_image.save("encoder4editing/inversion_coupled.jpg")
np.savez('encoder4editing/projected_w.npz', w=latents.cpu().numpy())
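# Optional follow-up (a minimal sketch, not executed): the saved latents can be
# reloaded and pushed back through the StyleGAN2 decoder wrapped by pSp to
# regenerate the inversion. Assumes `net` is still in scope:
# w = torch.from_numpy(np.load('encoder4editing/projected_w.npz')['w']).cuda()
# with torch.no_grad():
#     imgs, _ = net.decoder([w], input_is_latent=True, randomize_noise=False)
# tensor2im(imgs[0]).save('encoder4editing/decoded_from_w.jpg')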