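"""Inference script: encode aligned face images into e4e latent codes.

Loads the encoder checkpoint via setup_model, runs every image found in
--images_dir through it, and saves the resulting latent codes to
<save_dir>/latents.pt. The generator/inversion step is left commented out
below and is not executed here.
"""
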
import argparse
import os

import torch
from torch.utils.data import DataLoader

from Project.configs import data_configs
from Project.datasets.inference_dataset import InferenceDataset
from Project.utils.model_utils import setup_model


def main(args, device):
    # net exposes an onnxruntime-style interface (get_inputs/run, see get_latents);
    # opts are the options stored in the checkpoint, latent_avg the average latent.
    net, opts, latent_avg = setup_model(args.ckpt, device)
    is_cars = 'cars_' in opts.dataset_type
    args, data_loader = setup_data_loader(args, opts)

    # Encode every image and save the latent codes to <save_dir>/latents.pt.
    os.makedirs(args.save_dir, exist_ok=True)
    latents_file_path = os.path.join(args.save_dir, 'latents.pt')
    latent_codes = get_all_latents(net, device, data_loader, latent_avg, args.n_sample, is_cars=is_cars)
    torch.save(latent_codes, latents_file_path)



def setup_data_loader(args, opts):
    dataset_args = data_configs.DATASETS[opts.dataset_type]
    transforms_dict = dataset_args['transforms'](opts).get_transforms()
    images_path = args.images_dir if args.images_dir is not None else dataset_args['test_source_root']
    print(f"images path: {images_path}")
    # Face alignment is not applied here: --align is not consulted and the
    # run_alignment helper below is commented out, so images in images_dir
    # are assumed to be pre-aligned.
    align_function = None
    test_dataset = InferenceDataset(root=images_path,
                                    transform=transforms_dict['transform_test'],
                                    preprocess=align_function,
                                    opts=opts)

    # drop_last=True skips a final partial batch when --batch > 1; with the
    # default batch size of 1 every image is encoded.
    data_loader = DataLoader(test_dataset,
                             batch_size=args.batch,
                             shuffle=False,
                             num_workers=0,
                             drop_last=True)

    print(f'dataset length: {len(test_dataset)}')

    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, data_loader


def get_latents(net, x, latent_avg, is_cars=False):
    # Run the encoder; the predicted codes are offsets from the average latent,
    # so latent_avg is added back to obtain the final codes.
    ort_inputs = {net.get_inputs()[0].name: to_numpy(x)}
    codes = torch.from_numpy(net.run(None, ort_inputs)[0])
    if codes.ndim == 2:
        codes = codes + latent_avg.repeat(codes.shape[0], 1, 1)[:, 0, :]
    else:
        codes = codes + latent_avg.repeat(codes.shape[0], 1, 1)
    # NOTE: is_cars is accepted for signature symmetry but not used here.
    return codes


def get_all_latents(net, device, data_loader, latent_avg, n_images=None, is_cars=False):
    # Encode the dataset batch by batch, stopping once n_images samples have
    # been processed (n_images=None encodes everything).
    all_latents = []
    n_done = 0
    with torch.no_grad():
        for batch in data_loader:
            if n_images is not None and n_done >= n_images:
                break
            inputs = batch.float().to(device)
            latents = get_latents(net, inputs, latent_avg, is_cars)
            all_latents.append(latents)
            n_done += len(latents)
    return torch.cat(all_latents)

#@torch.no_grad()
#def generate_inversions(args, g, latent_codes, is_cars):
#    print('Saving inversion images')
#    inversions_directory_path = os.path.join(args.save_dir, 'inversions')
#    os.makedirs(inversions_directory_path, exist_ok=True)
#    for i in range(min(args.n_sample, len(latent_codes))):
#        imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
#        if is_cars:
#            imgs = imgs[:, :, 64:448, :]
#        save_image(imgs[0], inversions_directory_path, i + 1)


#def run_alignment(image_path):
#    predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
#    aligned_image = align_face(filepath=image_path, predictor=predictor)
#    print("Aligned image has shape: {}".format(aligned_image.size))
#    return aligned_image

def to_numpy(tensor):
    return tensor.cpu().numpy()

def inference():
    device = "cpu"
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default='static/img_aligned',
                        help="The directory of the images to be inverted")
    parser.add_argument("--save_dir", type=str, default='static/latents',
                        help="The directory in which to save the latent codes")
    parser.add_argument("--batch", type=int, default=1, help="batch size for the encoder")
    parser.add_argument("--n_sample", type=int, default=None, help="number of samples to infer")
    # NOTE: with default=True this flag is always on; it is not consulted in main().
    parser.add_argument("--latents_only", action="store_true", default=True,
                        help="infer only the latent codes of the directory")
    parser.add_argument("--align", action="store_true", default=False,
                        help="align face images before inference")
    parser.add_argument("--ckpt", default='Project/pretrained_models/e4e_ffhq_encode.pt',
                        help="path to the encoder checkpoint")

    args = parser.parse_args()
    main(args, device)

if __name__ == "__main__":
    inference()
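
# Example invocation (this file's name is not shown here, so "inference.py"
# below is only a placeholder):
#   python inference.py --images_dir static/img_aligned --save_dir static/latents \
#       --ckpt Project/pretrained_models/e4e_ffhq_encode.pt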