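"""Inference script for a trained StyleCLIP mapper.

Loads the mapper weights from a checkpoint, runs the mapper over a set of
test latents, decodes the edited latents with the StyleGAN generator, and
writes the resulting images and edited latents to
`<exp_dir>/inference_results`.
"""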
import os
import sys
import time
from argparse import Namespace

import numpy as np
import torch
import torchvision
from torch.utils.data import DataLoader
from tqdm import tqdm

sys.path.append(".")
sys.path.append("..")

from Project.mapper.datasets.latents_dataset import LatentsDataset, StyleSpaceLatentsDataset
from Project.mapper.options.test_options import TestOptions
from Project.mapper.styleclip_mapper import StyleCLIPMapper
from Project.mapper.training.train_utils import convert_s_tensor_to_list
def run(test_opts):
    out_path_results = os.path.join(test_opts.exp_dir, 'inference_results')
    os.makedirs(out_path_results, exist_ok=True)

    # Update the options saved in the checkpoint with the test-time options.
    ckpt = torch.load(test_opts.checkpoint_path, map_location='cpu')
    opts = ckpt['opts']
    opts.update(vars(test_opts))
    opts = Namespace(**opts)

    # Run on GPU when one is available, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = StyleCLIPMapper(opts)
    net.eval()
    net.to(device)

    test_latents = torch.load(opts.latents_test_path)
    if opts.work_in_stylespace:
        dataset = StyleSpaceLatentsDataset(latents=[l.cpu() for l in test_latents], opts=opts)
    else:
        dataset = LatentsDataset(latents=test_latents, opts=opts)
    dataloader = DataLoader(dataset,
                            batch_size=opts.test_batch_size,
                            shuffle=False,
                            num_workers=int(opts.test_workers),
                            drop_last=True)

    if opts.n_images is None:
        opts.n_images = len(dataset)

    global_i = 0
    global_time = []
    for input_batch in tqdm(dataloader):
        if global_i >= opts.n_images:
            break
        with torch.no_grad():
            if opts.work_in_stylespace:
                # StyleSpace batches are split back into the per-layer list
                # of style tensors that the mapper expects.
                input_cuda = convert_s_tensor_to_list(input_batch)
                input_cuda = [c.to(device) for c in input_cuda]
            else:
                input_cuda = input_batch.to(device)

            tic = time.time()
            result_batch = run_on_batch(input_cuda, net, opts.couple_outputs, opts.work_in_stylespace)
            toc = time.time()
            global_time.append(toc - tic)

        for i in range(opts.test_batch_size):
            im_path = str(global_i).zfill(5)
            if test_opts.couple_outputs:
                # Save the original and the edited image side by side.
                couple_output = torch.cat([result_batch[2][i].unsqueeze(0), result_batch[0][i].unsqueeze(0)])
                torchvision.utils.save_image(couple_output, os.path.join(out_path_results, f"{im_path}.jpg"),
                                             normalize=True, value_range=(-1, 1))
            else:
                torchvision.utils.save_image(result_batch[0][i], os.path.join(out_path_results, f"{im_path}.jpg"),
                                             normalize=True, value_range=(-1, 1))
            torch.save(result_batch[1][i].detach().cpu(), os.path.join(out_path_results, f"latent_{im_path}.pt"))
            global_i += 1
    stats_path = os.path.join(opts.exp_dir, 'stats.txt')
    result_str = 'Runtime {:.4f}+-{:.4f}'.format(np.mean(global_time), np.std(global_time))
    print(result_str)
    with open(stats_path, 'w') as f:
        f.write(result_str)
def run_on_batch(inputs, net, couple_outputs=False, stylespace=False):
    w = inputs
    with torch.no_grad():
        if stylespace:
            # The mapper predicts a residual per style layer; the edit is
            # applied with a fixed step size of 0.1.
            delta = net.mapper(w)
            w_hat = [c + 0.1 * delta_c for (c, delta_c) in zip(w, delta)]
            x_hat, _, w_hat = net.decoder([w_hat], input_is_latent=True, return_latents=True,
                                          randomize_noise=False, truncation=1, input_is_stylespace=True)
        else:
            # Apply the mapper's offset directly in W+ space.
            w_hat = w + 0.1 * net.mapper(w)
            x_hat, w_hat, _ = net.decoder([w_hat], input_is_latent=True, return_latents=True,
                                          randomize_noise=False, truncation=1)
        result_batch = (x_hat, w_hat)
        if couple_outputs:
            # Also decode the unedited latents so the caller can save
            # input/output pairs.
            x, _ = net.decoder([w], input_is_latent=True, randomize_noise=False, truncation=1,
                               input_is_stylespace=stylespace)
            result_batch = (x_hat, w_hat, x)
    return result_batch
if __name__ == '__main__':
    test_opts = TestOptions().parse()
    run(test_opts)
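# Illustrative invocation (assumed script name and paths; the accepted flags
# are defined in Project/mapper/options/test_options.py and may differ):
#
#   python inference.py \
#       --exp_dir results/edit_run \
#       --checkpoint_path checkpoints/mapper.pt \
#       --latents_test_path data/test_latents.pt \
#       --test_batch_size 2 \
#       --test_workers 2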