"""Encode a directory of images into latent codes and optionally decode them.

The script loads a checkpoint with ``setup_model``, runs ``net.encoder`` over
every image yielded by an ``InferenceDataset``, stores the resulting latents in
``<save_dir>/latents.pt`` and, unless ``--latents_only`` is given, feeds them
through ``net.decoder`` and writes the images to ``<save_dir>/inversions``.
"""
import argparse
import functools
import os
import sys

import dlib
import numpy as np
import torch
from PIL import Image
from torch.utils.data import DataLoader

# The project packages below are resolved relative to the working directory
# (or its parent), so the search path must be extended before importing them.
sys.path.append(".")
sys.path.append("..")

from configs import data_configs, paths_config
from datasets.inference_dataset import InferenceDataset
from utils.alignment import align_face
from utils.common import tensor2im
from utils.model_utils import setup_model

# Defined at module level (not only under the ``__main__`` guard) so that
# ``main`` and ``get_all_latents`` also work when this file is imported.
device = "cuda"


def _resolve_save_dir(args, opts):
    """Return the output directory, applying the documented default.

    ``--save_dir`` is documented to default to the images directory; when that
    is not given either, the dataset's ``test_source_root`` is used, mirroring
    how ``setup_data_loader`` picks the images path.
    """
    if args.save_dir is not None:
        return args.save_dir
    if args.images_dir is not None:
        return args.images_dir
    return data_configs.DATASETS[opts.dataset_type]['test_source_root']


def main(args):
    """Run the full inversion pipeline for the parsed command-line ``args``.

    Latents are loaded from ``<save_dir>/latents.pt`` when that file exists,
    otherwise they are computed and saved there. Inversion images are then
    generated unless ``args.latents_only`` is set.
    """
    net, opts = setup_model(args.ckpt, device)
    is_cars = 'cars_' in opts.dataset_type
    generator = net.decoder
    generator.eval()
    args, data_loader = setup_data_loader(args, opts)

    # Apply the save_dir default and make sure the directory exists before
    # anything is written into it.
    args.save_dir = _resolve_save_dir(args, opts)
    os.makedirs(args.save_dir, exist_ok=True)

    # Check if latents exist
    latents_file_path = os.path.join(args.save_dir, 'latents.pt')
    if os.path.exists(latents_file_path):
        # map_location remaps the stored tensor onto the current device even
        # if it was saved from a different one.
        latent_codes = torch.load(latents_file_path, map_location=device)
    else:
        latent_codes = get_all_latents(net, data_loader, args.n_sample, is_cars=is_cars)
        torch.save(latent_codes, latents_file_path)

    if not args.latents_only:
        generate_inversions(args, generator, latent_codes, is_cars=is_cars)


def setup_data_loader(args, opts):
    """Build the inference ``DataLoader`` and fill in ``args.n_sample``.

    Images are read from ``args.images_dir`` or, when that is ``None``, from
    the dataset's configured ``test_source_root``. If ``args.n_sample`` is
    ``None`` it is set to the dataset length.

    Returns:
        The (mutated) ``args`` and the data loader.
    """
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
    print(f"images path: {images_path}")

    align_function = run_alignment if args.align else None
    test_dataset = InferenceDataset(root=images_path,
                                    transform=transforms_dict['transform_test'],
                                    preprocess=align_function,
                                    opts=opts)

    # NOTE: drop_last=True means a trailing partial batch is skipped, so fewer
    # than len(test_dataset) latents may be produced when the batch size does
    # not divide the dataset length. generate_inversions accounts for this.
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=2,
                             drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, data_loader


def get_latents(net, x, is_cars=False):
    """Encode the batch ``x`` with ``net.encoder`` and return the codes.

    When ``net.opts.start_from_latent_avg`` is set, ``net.latent_avg`` is added
    to the encoder output (using only its first entry along dim 1 when the
    codes are 2-D). For cars, codes with 18 entries along dim 1 are trimmed to
    the first 16.
    """
    codes = net.encoder(x)
    if net.opts.start_from_latent_avg:
        if codes.ndim == 2:
            codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
        else:
            codes = codes + net.latent_avg.repeat(codes.shape[0], 1, 1)
    if codes.shape[1] == 18 and is_cars:
        codes = codes[:, :16, :]
    return codes


def get_all_latents(net, data_loader, n_images=None, is_cars=False):
    """Encode batches from ``data_loader`` and concatenate the latents.

    Args:
        net: Model exposing ``encoder``, ``opts`` and ``latent_avg``.
        data_loader: Iterable yielding image batches.
        n_images: Optional cap on the number of latents returned. ``None``
            encodes everything the loader yields.
        is_cars: Forwarded to ``get_latents``.

    Returns:
        A tensor with at most ``n_images`` latents along dim 0.

    Raises:
        RuntimeError: If no batch was encoded (empty loader or a non-positive
            ``n_images``).
    """
    all_latents = []
    n_encoded = 0
    with torch.no_grad():
        for batch in data_loader:
            # Stop as soon as the cap is reached; the previous `>` comparison
            # encoded one batch too many.
            if n_images is not None and n_encoded >= n_images:
                break
            inputs = batch.to(device).float()
            latents = get_latents(net, inputs, is_cars)
            all_latents.append(latents)
            n_encoded += len(latents)

    if not all_latents:
        raise RuntimeError(
            "No images were encoded: the data loader yielded no batches "
            "or n_images is not positive."
        )

    stacked = torch.cat(all_latents)
    if n_images is not None:
        # The last batch may overshoot the cap when the batch size > 1.
        stacked = stacked[:n_images]
    return stacked


def save_image(img, save_dir, idx):
    """Convert ``img`` with ``tensor2im`` and save it as ``<idx:05d>.jpg``."""
    result = tensor2im(img)
    im_save_path = os.path.join(save_dir, f"{idx:05d}.jpg")
    Image.fromarray(np.array(result)).save(im_save_path)


@torch.no_grad()
def generate_inversions(args, g, latent_codes, is_cars):
    """Decode latents with generator ``g`` and save them under ``inversions``.

    At most ``args.n_sample`` images are written (all of them when it is
    ``None``), numbered from 1. For cars, rows 64:448 of each generated image
    are kept.
    """
    print('Saving inversion images')
    inversions_directory_path = os.path.join(args.save_dir, 'inversions')
    os.makedirs(inversions_directory_path, exist_ok=True)

    # Never index past the available latents: the loader drops a trailing
    # partial batch and a cached latents file may hold fewer entries.
    n_available = len(latent_codes)
    n_total = n_available if args.n_sample is None else min(args.n_sample, n_available)
    if args.n_sample is not None and n_total < args.n_sample:
        print(f'Only {n_available} latent codes available; requested {args.n_sample}')

    for i in range(n_total):
        imgs, _ = g([latent_codes[i].unsqueeze(0)],
                    input_is_latent=True,
                    randomize_noise=False,
                    return_latents=True)
        if is_cars:
            imgs = imgs[:, :, 64:448, :]
        save_image(imgs[0], inversions_directory_path, i + 1)


@functools.lru_cache(maxsize=1)
def _load_shape_predictor():
    """Load the dlib shape predictor once per process and reuse it."""
    return dlib.shape_predictor(paths_config.model_paths['shape_predictor'])


def run_alignment(image_path):
    """Align the face in ``image_path`` with ``align_face`` and return it."""
    # The predictor is cached so it is not reloaded from disk for every image.
    predictor = _load_shape_predictor()
    aligned_image = align_face(filepath=image_path, predictor=predictor)
    print("Aligned image has shape: {}".format(aligned_image.size))
    return aligned_image


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default=None,
                        help="The directory of the images to be inverted")
    parser.add_argument("--save_dir", type=str, default=None,
                        help="The directory to save the latent codes and inversion images. "
                             "(default: images_dir)")
    parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
    parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
    parser.add_argument("--latents_only", action="store_true", help="infer only the latent codes of the directory")
    parser.add_argument("--align", action="store_true", help="align face images before inference")
    parser.add_argument("ckpt", metavar="CHECKPOINT", help="path to generator checkpoint")

    args = parser.parse_args()
    main(args)