# Hosting-page residue (not part of the program): QiyuWu's picture
# Upload 100 files
# 1fd7780 verified
import argparse
import torch
import os
from Project.configs import data_configs
from Project.datasets.inference_dataset import InferenceDataset
from torch.utils.data import DataLoader
from Project.utils.model_utils import setup_model
def main(args, device):
    """Encode the input images to latent codes and save them to disk.

    Loads the encoder with ``setup_model``, builds the loader with
    ``setup_data_loader``, gathers the codes with ``get_all_latents`` and
    writes them to ``<args.save_dir>/latents.pt`` using ``torch.save``.

    Args:
        args: Parsed command-line namespace. This function reads ``ckpt``,
            ``save_dir`` and ``n_sample``; ``setup_data_loader`` reads more.
        device: Device handle forwarded to ``setup_model`` and
            ``get_all_latents``.
    """
    net, opts, latent_avg = setup_model(args.ckpt, device)
    is_cars = 'cars_' in opts.dataset_type
    args, data_loader = setup_data_loader(args, opts)

    # torch.save does not create missing parent directories; create the
    # output directory up front so a long inference run cannot fail at the
    # very last step. An empty save_dir means "current directory".
    if args.save_dir:
        os.makedirs(args.save_dir, exist_ok=True)
    latents_file_path = os.path.join(args.save_dir, 'latents.pt')

    latent_codes = get_all_latents(net, device, data_loader, latent_avg,
                                   args.n_sample, is_cars=is_cars)
    torch.save(latent_codes, latents_file_path)
def setup_data_loader(args, opts):
    """Build the inference dataset and the loader that iterates over it.

    The dataset configuration is looked up in ``data_configs.DATASETS`` by
    ``opts.dataset_type``. Images come from ``args.images_dir`` when it is
    set, otherwise from the configuration's ``'test_source_root'`` entry.

    Args:
        args: Namespace providing ``images_dir``, ``batch`` and ``n_sample``.
            ``n_sample`` is filled in with the dataset length when it is
            ``None`` (the namespace is modified in place).
        opts: Options object passed to the transforms factory and to
            ``InferenceDataset``.

    Returns:
        tuple: ``(args, data_loader)``.
    """
    dataset_cfg = data_configs.DATASETS[opts.dataset_type]
    transforms = dataset_cfg['transforms'](opts).get_transforms()

    images_path = args.images_dir
    if images_path is None:
        images_path = dataset_cfg['test_source_root']
    print(f"images path: {images_path}")

    test_dataset = InferenceDataset(
        root=images_path,
        transform=transforms['transform_test'],
        preprocess=None,  # no alignment function is applied
        opts=opts,
    )

    # Order is kept (shuffle=False); drop_last=True discards a final
    # incomplete batch.
    loader_kwargs = dict(
        batch_size=args.batch,
        shuffle=False,
        num_workers=0,
        drop_last=True,
    )
    data_loader = DataLoader(test_dataset, **loader_kwargs)

    print(f'dataset length: {len(test_dataset)}')

    if args.n_sample is None:
        args.n_sample = len(test_dataset)
    return args, data_loader
def get_latents(net, x, latent_avg, is_cars=False):
    """Encode one batch and offset the result by the average latent.

    Args:
        net: Object exposing ``get_inputs()`` and ``run(...)``; presumably an
            ONNX Runtime inference session -- confirm against ``setup_model``.
        x: Batch tensor, converted to a NumPy array with ``to_numpy``.
        latent_avg: Tensor added to the encoder output.
        is_cars: Accepted for signature compatibility; not used here.

    Returns:
        torch.Tensor: first output of ``net.run`` plus the average latent.
    """
    feed = {net.get_inputs()[0].name: to_numpy(x)}
    raw_outputs = net.run(None, feed)
    codes = torch.from_numpy(raw_outputs[0])

    # Tile the average latent once per item in the batch.
    offset = latent_avg.repeat(codes.shape[0], 1, 1)
    if codes.ndim == 2:
        # 2-D codes: keep only index 0 along the second axis of the offset.
        offset = offset[:, 0, :]
    return codes + offset
def get_all_latents(net, device, data_loader, latent_avg, n_images=None, is_cars=False):
    """Encode batches from *data_loader* and concatenate the latent codes.

    Args:
        net: Encoder handle forwarded to ``get_latents``.
        device: Kept for signature compatibility; not used here.
        data_loader: Iterable yielding batch tensors.
        latent_avg: Average latent forwarded to ``get_latents``.
        n_images: Maximum number of latent codes to return. ``None`` means
            every batch is processed. Previously this argument was ignored.
        is_cars: Forwarded to ``get_latents``.

    Returns:
        torch.Tensor: the per-batch codes concatenated along dimension 0,
        trimmed to at most ``n_images`` rows when a limit is given.

    Raises:
        RuntimeError: if no batch was processed (empty loader or a limit
            that is not positive).
    """
    all_latents = []
    collected = 0
    with torch.no_grad():
        for batch in data_loader:
            # Stop as soon as the requested number of codes is available.
            if n_images is not None and collected >= n_images:
                break
            inputs = batch.float()
            print(inputs.shape)
            latents = get_latents(net, inputs, latent_avg, is_cars)
            all_latents.append(latents)
            collected += len(latents)

    if not all_latents:
        # torch.cat would raise the same exception type with a vaguer message.
        raise RuntimeError("no latent codes were produced: the data loader "
                           "yielded no batches or n_images is not positive")

    stacked = torch.cat(all_latents)
    # The last batch may overshoot the limit; trim to exactly n_images.
    return stacked if n_images is None else stacked[:n_images]
#@torch.no_grad()
#def generate_inversions(args, g, latent_codes, is_cars):
# print('Saving inversion images')
# inversions_directory_path = os.path.join(args.save_dir, 'inversions')
# os.makedirs(inversions_directory_path, exist_ok=True)
# for i in range(min(args.n_sample, len(latent_codes))):
# imgs, _ = g([latent_codes[i].unsqueeze(0)], input_is_latent=True, randomize_noise=False, return_latents=True)
# if is_cars:
# imgs = imgs[:, :, 64:448, :]
# save_image(imgs[0], inversions_directory_path, i + 1)
#def run_alignment(image_path):
# predictor = dlib.shape_predictor(paths_config.model_paths['shape_predictor'])
# aligned_image = align_face(filepath=image_path, predictor=predictor)
# print("Aligned image has shape: {}".format(aligned_image.size))
# return aligned_image
def to_numpy(tensor):
    """Return the data of *tensor* as a NumPy array on the CPU.

    The tensor is detached first because ``Tensor.numpy()`` raises a
    ``RuntimeError`` for tensors that require grad; for tensors that do not,
    ``detach()`` changes nothing about the resulting values.
    """
    return tensor.detach().cpu().numpy()
def inference(argv=None):
    """Parse command-line options and run latent inference on the CPU.

    Args:
        argv: Optional list of argument strings. With ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is the previous behaviour;
            passing a list allows calling this function programmatically.
    """
    device = "cpu"
    parser = argparse.ArgumentParser(description="Inference")
    parser.add_argument("--images_dir", type=str, default='static/img_aligned',
                        help="The directory of the images to be inverted")
    # The help text used to have an unbalanced parenthesis and named the
    # wrong default.
    parser.add_argument("--save_dir", type=str, default='static/latents',
                        help="The directory to save the latent codes (default: static/latents)")
    parser.add_argument("--batch", type=int, default=1, help="batch size for the generator")
    parser.add_argument("--n_sample", type=int, default=None, help="number of the samples to infer.")
    # NOTE(review): store_true combined with default=True means this flag can
    # never be False, and it is not read by the code visible in this file.
    parser.add_argument("--latents_only", action="store_true", default=True,
                        help="infer only the latent codes of the directory")
    # NOTE(review): --align is parsed, but the code visible in this file never
    # reads it (the data loader is built without an alignment function).
    parser.add_argument("--align", action="store_true", default=False,
                        help="align face images before inference")
    parser.add_argument("--ckpt", default='Project/pretrained_models/e4e_ffhq_encode.pt',
                        help="path to generator checkpoint")
    args = parser.parse_args(argv)
    main(args, device)
# Script entry point: run inference only when the file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    inference()