# NOTE: stray "Spaces: Sleeping" status text (Hugging Face Spaces UI residue
# captured during export) removed here — it was not valid Python.
# Invert an input image into e4e/StyleGAN W+ latent space and save the latents.
# Imports grouped stdlib / third-party / repo-local; duplicate `import os` removed.
import argparse
import os
import sys
import time
from argparse import Namespace

import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

# Make repo-local packages importable when running from the repo root or a subdir.
sys.path.append(".")
sys.path.append("..")
from utils.common import tensor2im
from models.psp import pSp  # we use the pSp framework to load the e4e encoder.
# Encoder configuration: invert face images with the FFHQ-trained e4e encoder.
experiment_type = 'ffhq_encode'

# Command-line interface: a single optional argument, the image to invert.
parser = argparse.ArgumentParser()
parser.add_argument('--input_image', type=str, default="", help='input image path')
args = parser.parse_args()
opts = vars(args)
print(opts)  # echo the parsed options for traceability
image_path = args.input_image
def get_download_model_command(file_id, file_name):
    """Build a wget shell command that downloads a Google-Drive-hosted model.

    Args:
        file_id: Google Drive file id of the checkpoint.
        file_name: file name to save the download under.

    Returns:
        A shell command string that fetches the file (following Drive's
        download-confirmation redirect) into ``encoder4editing/saves``.

    Side effect: creates the ``encoder4editing/saves`` directory if missing.
    """
    save_path = "encoder4editing/saves"
    # exist_ok avoids the check-then-create race of the original code
    # (and the original's unused os.getcwd() call has been dropped).
    os.makedirs(save_path, exist_ok=True)
    url = r"""wget --load-cookies /tmp/cookies.txt "https://docs.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://docs.google.com/uc?export=download&id={FILE_ID}' -O- | sed -rn 's/.*confirm=([0-9A-Za-z_]+).*/\1\n/p')&id={FILE_ID}" -O {SAVE_PATH}/{FILE_NAME} && rm -rf /tmp/cookies.txt""".format(FILE_ID=file_id, FILE_NAME=file_name, SAVE_PATH=save_path)
    return url
# (domain, Google Drive id, checkpoint filename) for each released e4e encoder.
_CHECKPOINTS = [
    ("ffhq_encode", "1cUv_reLE6k3604or78EranS7XzuVMWeO", "e4e_ffhq_encode.pt"),
    ("cars_encode", "17faPqBce2m1AQeLCLHUVXaDfxMRU2QcV", "e4e_cars_encode.pt"),
    ("horse_encode", "1TkLLnuX86B_BMo2ocYD0kX9kWh53rUVX", "e4e_horse_encode.pt"),
    ("church_encode", "1-L0ZdnQLwtdy6-A_Ccgq5uNJGTqE7qBa", "e4e_church_encode.pt"),
]
MODEL_PATHS = {domain: {"id": gid, "name": fname} for domain, gid, fname in _CHECKPOINTS}

# Resolve the checkpoint for the active experiment and the command that fetches it.
path = MODEL_PATHS[experiment_type]
download_command = get_download_model_command(file_id=path["id"], file_name=path["name"])
# Per-experiment defaults: checkpoint location and a sample input image.
EXPERIMENT_DATA_ARGS = {}
for _domain, _ckpt, _img in [
    ("ffhq_encode", "encoder4editing/e4e_ffhq_encode.pt", "input_img.jpg"),
    ("cars_encode", "pretrained_models/e4e_cars_encode.pt", "car_img.jpg"),
    ("horse_encode", "pretrained_models/e4e_horse_encode.pt", "horse_img.jpg"),
    ("church_encode", "pretrained_models/e4e_church_encode.pt", "church_img.jpg"),
]:
    EXPERIMENT_DATA_ARGS[_domain] = {
        "model_path": _ckpt,
        "image_path": "notebooks/images/" + _img,
    }
# Setup required image transformations.
EXPERIMENT_ARGS = EXPERIMENT_DATA_ARGS[experiment_type]
# Cars crops are non-square (192x256); every other domain uses 256x256.
if experiment_type == 'cars_encode':
    _resize_hw = (192, 256)   # (height, width) as torchvision Resize expects
    resize_dims = (256, 192)  # (width, height) as PIL Image.resize expects
else:
    _resize_hw = (256, 256)
    resize_dims = (256, 256)
EXPERIMENT_ARGS['transform'] = transforms.Compose([
    transforms.Resize(_resize_hw),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# Load the e4e checkpoint and rebuild its training options, pointing
# checkpoint_path at the file we just loaded so pSp restores the weights.
model_path = EXPERIMENT_ARGS['model_path']
ckpt = torch.load(model_path, map_location='cpu')
opts = ckpt['opts']
opts['checkpoint_path'] = model_path
opts = Namespace(**opts)

net = pSp(opts)
net.eval()  # inference only: freeze dropout/batch-norm behavior
net.cuda()
print('Model successfully loaded!')

# Read the image to invert, normalized to 3-channel RGB.
original_image = Image.open(image_path).convert("RGB")
def run_alignment(image_path):
    """Return the face in ``image_path`` aligned to the FFHQ crop.

    Uses dlib's 68-point landmark predictor (checkpoint expected at
    ``encoder4editing/shape_predictor_68_face_landmarks.dat``).
    """
    # Imported lazily: dlib is only required on the face-alignment path.
    import dlib
    from utils.alignment import align_face

    predictor = dlib.shape_predictor("encoder4editing/shape_predictor_68_face_landmarks.dat")
    aligned = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned.size))
    return aligned
# Faces must be aligned to the FFHQ crop before encoding; other domains use
# the input image as-is.
if experiment_type == "ffhq_encode":
    input_image = run_alignment(image_path)
else:
    input_image = original_image
# BUG FIX: PIL's Image.resize returns a NEW image; the original code called
# it and discarded the result, making the line a no-op. Rebind the result.
input_image = input_image.resize(resize_dims)
img_transforms = EXPERIMENT_ARGS['transform']
transformed_image = img_transforms(input_image)
def display_alongside_source_image(result_image, source_image):
    """Return a PIL image with source (left) and result (right) side by side.

    Both inputs are resized to the module-level ``resize_dims`` before being
    concatenated along the width axis.
    """
    left = np.array(source_image.resize(resize_dims))
    right = np.array(result_image.resize(resize_dims))
    return Image.fromarray(np.concatenate([left, right], axis=1))
def run_on_batch(inputs, net):
    """Encode a batch of images with the e4e net on the GPU.

    Returns ``(images, latents)``: the reconstructions and their W+ latents.
    For the cars domain the square output is cropped to rows 32:224, the
    region that actually contains the car.
    """
    batch = inputs.to("cuda").float()
    images, latents = net(batch, randomize_noise=False, return_latents=True)
    if experiment_type == 'cars_encode':
        images = images[:, :, 32:224, :]
    return images, latents
# Run the inversion with autograd disabled (inference only), timing it.
with torch.no_grad():
    tic = time.time()
    images, latents = run_on_batch(transformed_image.unsqueeze(0), net)
    result_image, latent = images[0], latents[0]
    toc = time.time()
    print('Inference took {:.4f} seconds.'.format(toc - tic))

# Build the side-by-side comparison.
# NOTE(review): in the source notebook this was displayed inline; as a script
# the return value is discarded — save it (e.g. ``.save(path)``) if the visual
# check is wanted.
display_alongside_source_image(tensor2im(result_image), input_image)

# Persist the full batch of W+ latents for downstream editing.
# (Dropped the original's pointless f-string prefix on a literal path.)
np.savez('encoder4editing/projected_w.npz', w=latents.cpu().numpy())