# StyleGAN-NADA / e4e / scripts / inference.py
# Commit 0d2ed80: "Added e4e code" (rinong)
import argparse
import torch
import numpy as np
import sys
import os
import dlib
sys.path.append(".")
sys.path.append("..")
from configs import data_configs, paths_config
from datasets.inference_dataset import InferenceDataset
from torch.utils.data import DataLoader
from utils.model_utils import setup_model
from utils.common import tensor2im
from utils.alignment import align_face
from PIL import Image
def main(args):
    """Invert a directory of images into latent codes and optionally render them.

    Loads the model from ``args.ckpt``, builds the inference data loader, then
    either reuses ``<save_dir>/latents.pt`` if it already exists or computes the
    latent codes and saves them there. Unless ``args.latents_only`` is set, the
    inversions are rendered with the decoder afterwards.

    Args:
        args: Parsed command-line namespace. Reads ``ckpt``, ``save_dir``,
            ``images_dir``, ``n_sample`` and ``latents_only``; ``save_dir`` is
            filled in when it was left unset.

    Raises:
        ValueError: If neither ``save_dir`` nor ``images_dir`` was provided, so
            there is no directory to write the results to.

    Note:
        Relies on the module-level ``device`` global being defined by the caller
        (the ``__main__`` section does this).
    """
    # The CLI help promises that save_dir defaults to images_dir; previously a
    # missing --save_dir crashed in os.path.join(None, ...).
    if args.save_dir is None:
        if args.images_dir is None:
            raise ValueError(
                "No output directory: pass --save_dir (or --images_dir, which is used as the default)."
            )
        args.save_dir = args.images_dir
    # torch.save does not create missing parent directories.
    os.makedirs(args.save_dir, exist_ok=True)

    net, opts = setup_model(args.ckpt, device)
    is_cars = 'cars_' in opts.dataset_type
    generator = net.decoder
    generator.eval()
    args, data_loader = setup_data_loader(args, opts)

    # Reuse cached latents when present. NOTE: the cache is not checked against
    # the current image set or n_sample.
    latents_file_path = os.path.join(args.save_dir, 'latents.pt')
    if os.path.exists(latents_file_path):
        # map_location lets latents saved on another device (e.g. a GPU that is
        # not available here) be loaded directly onto the active device.
        latent_codes = torch.load(latents_file_path, map_location=device)
    else:
        latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
        torch.save(latent_codes, latents_file_path)

    if not args.latents_only:
        generate_inversions(args, generator, latent_codes, is_cars=is_cars)
def setup_data_loader(args, opts):
    """Build the inference dataset and its ``DataLoader``.

    Args:
        args: Parsed command-line namespace. Reads ``images_dir``, ``align``,
            ``batch`` and ``n_sample``; ``n_sample`` is updated in place.
        opts: Model options; ``opts.dataset_type`` selects the entry of
            ``data_configs.DATASETS`` to use.

    Returns:
        A ``(args, data_loader)`` tuple. ``args.n_sample`` is set to the dataset
        length when it was ``None`` and is otherwise capped at the dataset
        length.
    """
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()

    # Fall back to the dataset's configured test root when no directory is given.
    if args.images_dir is not None:
        images_path = args.images_dir
    else:
        images_path = dataset_args['test_source_root']
    print(f"images path: {images_path}")

    align_function = run_alignment if args.align else None
    test_dataset = InferenceDataset(root=images_path,
                                    transform=transforms_dict['transform_test'],
                                    preprocess=align_function,
                                    opts=opts)
    # drop_last must be False for inference: otherwise the images of a trailing
    # partial batch are silently skipped while n_sample still counts them,
    # which over-indexes the latent codes later on.
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=False)

    dataset_size = len(test_dataset)
    print(f'dataset length: {dataset_size}')

    if args.n_sample is None:
        args.n_sample = dataset_size
    else:
        # Never ask for more samples than there are images.
        args.n_sample = min(args.n_sample, dataset_size)

    return args, data_loader
def get_latents(net, x, is_cars=False):
    """Encode ``x`` into latent codes.

    Args:
        net: Object exposing ``encoder`` (callable), ``opts.start_from_latent_avg``
            and ``latent_avg``.
        x: Input batch passed straight to ``net.encoder``.
        is_cars: When true and the codes have 18 entries along dim 1, only the
            first 16 are kept.

    Returns:
        The encoder output, offset by ``net.latent_avg`` when
        ``net.opts.start_from_latent_avg`` is truthy, and trimmed for cars.
    """
    latents = net.encoder(x)

    if net.opts.start_from_latent_avg:
        # One copy of the average latent per sample in the batch.
        avg = net.latent_avg.repeat(latents.shape[0], 1, 1)
        if latents.ndim == 2:
            # 2-D codes: only the first slice along dim 1 of the average is added.
            avg = avg[:, 0, :]
        latents = latents + avg

    # NOTE(review): 18 -> 16 presumably matches the cars generator's number of
    # style inputs -- confirm against the model definition.
    trim_for_cars = latents.shape[1] == 18 and is_cars
    return latents[:, :16, :] if trim_for_cars else latents
def get_all_latents(net, data_loader, n_images=None, is_cars=False):
    """Encode batches from ``data_loader`` and return their latent codes.

    Args:
        net: Model passed through to ``get_latents``.
        data_loader: Iterable yielding input batches (tensors).
        n_images: Maximum number of latent codes to return; ``None`` means
            process every batch.
        is_cars: Forwarded to ``get_latents``.

    Returns:
        A single tensor with the latent codes concatenated along dim 0, holding
        at most ``n_images`` entries when ``n_images`` is given.

    Raises:
        RuntimeError: If no batch was encoded (empty loader or ``n_images <= 0``).

    Note:
        Relies on the module-level ``device`` global.
    """
    all_latents = []
    n_encoded = 0
    with torch.no_grad():
        for batch in data_loader:
            # ">=" (not ">"): stop as soon as enough codes exist instead of
            # encoding one extra batch.
            if n_images is not None and n_encoded >= n_images:
                break
            inputs = batch.to(device).float()
            latents = get_latents(net, inputs, is_cars)
            all_latents.append(latents)
            n_encoded += len(latents)

    if not all_latents:
        raise RuntimeError(
            "No latent codes were computed: the data loader yielded no batches or n_images <= 0."
        )

    latent_codes = torch.cat(all_latents)
    if n_images is not None:
        # The last batch may overshoot the requested count; trim the surplus.
        latent_codes = latent_codes[:n_images]
    return latent_codes
def save_image(img, save_dir, idx):
    """Convert ``img`` with ``tensor2im`` and save it as ``<save_dir>/<idx>.jpg``.

    Args:
        img: Image tensor accepted by ``tensor2im``.
        save_dir: Directory the file is written to.
        idx: Integer used for the file name, zero-padded to five digits.
    """
    converted = tensor2im(img)
    file_name = f"{idx:05d}.jpg"
    target_path = os.path.join(save_dir, file_name)
    # Round-trip through a NumPy array before handing the pixels to PIL.
    pixels = np.array(converted)
    Image.fromarray(pixels).save(target_path)
@torch.no_grad()
def generate_inversions(args, g, latent_codes, is_cars):
    """Render each latent code with the generator and save the result as a JPEG.

    Images are written to ``<args.save_dir>/inversions`` and numbered from 1.

    Args:
        args: Namespace providing ``save_dir`` and ``n_sample``.
        g: Generator called as ``g([latent], input_is_latent=True, ...)``; it is
            expected to return an ``(images, latents)`` pair.
        latent_codes: Indexable collection of latent codes, one per image.
        is_cars: When true, rows 64..447 of the generated image are kept.
    """
    print('Saving inversion images')
    inversions_directory_path = os.path.join(args.save_dir, 'inversions')
    os.makedirs(inversions_directory_path, exist_ok=True)

    # Never index past the available codes: n_sample can exceed them when the
    # loader dropped a partial batch, when a cached latents file is shorter
    # than the dataset, or when the user asked for too many samples.
    n_available = len(latent_codes)
    n_requested = n_available if args.n_sample is None else args.n_sample
    n_images = min(n_requested, n_available)
    if n_images < n_requested:
        print(f'Only {n_available} latent codes available; requested {n_requested}')

    for i in range(n_images):
        imgs, _ = g([latent_codes[i].unsqueeze(0)],
                    input_is_latent=True,
                    randomize_noise=False,
                    return_latents=True)
        if is_cars:
            # NOTE(review): presumably crops a 512-row output to the 384-row
            # car aspect ratio -- confirm against the generator's output size.
            imgs = imgs[:, :, 64:448, :]
        save_image(imgs[0], inversions_directory_path, i + 1)
# Loaded dlib shape predictors, keyed by model path, so the model file is read
# at most once per process instead of once per image.
_SHAPE_PREDICTORS = {}


def _get_shape_predictor(model_path):
    """Return the dlib shape predictor for ``model_path``, loading it on first use."""
    if model_path not in _SHAPE_PREDICTORS:
        _SHAPE_PREDICTORS[model_path] = dlib.shape_predictor(model_path)
    return _SHAPE_PREDICTORS[model_path]


def run_alignment(image_path):
    """Align the face in the image at ``image_path``.

    Args:
        image_path: Path of the image file, forwarded to ``align_face``.

    Returns:
        Whatever ``align_face`` returns (an object exposing ``.size``).
    """
    predictor = _get_shape_predictor(paths_config.model_paths['shape_predictor'])
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image
if __name__ == "__main__":
    # `device` is read as a module-level global by main() and get_all_latents().
    # Fall back to the CPU instead of failing outright on hosts without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default=None,
                        help="The directory of the images to be inverted")
    parser.add_argument("--save_dir", type=str, default=None,
                        help="The directory to save the latent codes and inversion images. (default: images_dir)")
    parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
    parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
    parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
    parser.add_argument("--align", action="store_true", help="align face images before inference")
    parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")

    args = parser.parse_args()
    main(args)