import os |
from tqdm import tqdm |
import torch |
from torch.utils.data import DataLoader |
from logger import Logger, Visualizer |
import numpy as np |
import imageio |
from sync_batchnorm import DataParallelWithCallback |
def reconstruction(config, generator, kp_detector, checkpoint, log_dir, dataset):
    """Reconstruct each video from its first frame and report the mean L1 error.

    For every video, frame 0 is the source image and every frame (including
    frame 0) is used in turn as the driving frame. The generator's prediction
    for each driving frame is compared with the real frame using a mean
    absolute error, and the mean over all frames of all videos is printed.

    Files written under ``<log_dir>/reconstruction``:
      * ``png/<name>.png`` -- all predicted frames concatenated along axis 1.
      * ``<name><format>`` -- an animation of the per-frame visualizations,
        where ``format`` is ``config['reconstruction_params']['format']``.

    Args:
        config: Config dict. Reads ``config['reconstruction_params']`` (keys
            ``'num_videos'`` -- max number of videos, or None for all -- and
            ``'format'``) and ``config['visualizer_params']`` (forwarded as
            keyword arguments to ``Visualizer``).
        generator: Network called as
            ``generator(source, kp_source=..., kp_driving=...)``; must return a
            dict containing a ``'prediction'`` tensor.
        kp_detector: Keypoint network applied to single frames.
        checkpoint: Checkpoint loaded into both networks via
            ``Logger.load_cpk``. Required.
        log_dir: Base log directory.
        dataset: Dataset whose items are dicts with ``'video'`` and ``'name'``.

    Raises:
        AttributeError: If ``checkpoint`` is None.
    """
    if checkpoint is None:
        raise AttributeError("Checkpoint should be specified for mode='reconstruction'.")

    png_dir = os.path.join(log_dir, 'reconstruction/png')
    log_dir = os.path.join(log_dir, 'reconstruction')
    Logger.load_cpk(checkpoint, generator=generator, kp_detector=kp_detector)

    dataloader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(log_dir, exist_ok=True)
    os.makedirs(png_dir, exist_ok=True)

    recon_params = config['reconstruction_params']
    num_videos = recon_params['num_videos']
    # Built once instead of once per frame; its config does not change in the loop.
    visualizer = Visualizer(**config['visualizer_params'])

    if torch.cuda.is_available():
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()

    loss_list = []
    # Wrap the dataloader (not the enumerate object) so tqdm can show a total.
    for it, x in enumerate(tqdm(dataloader)):
        # Process at most ``num_videos`` videos (``it`` is 0-based).
        if num_videos is not None and it >= num_videos:
            break

        with torch.no_grad():
            predictions = []
            visualizations = []
            if torch.cuda.is_available():
                x['video'] = x['video'].cuda()
            video = x['video']

            # Frames are indexed on axis 2 (assumed batch, channels, frames,
            # H, W -- TODO confirm). Frame 0 is the fixed source for the video.
            source = video[:, :, 0]
            kp_source = kp_detector(source)

            for frame_idx in range(video.shape[2]):
                driving = video[:, :, frame_idx]
                kp_driving = kp_detector(driving)
                out = generator(source, kp_source=kp_source, kp_driving=kp_driving)
                out['kp_source'] = kp_source
                out['kp_driving'] = kp_driving
                # Remove this intermediate before visualizing (presumably so the
                # Visualizer does not render it -- TODO confirm). pop() tolerates
                # generators that never produce the key.
                out.pop('sparse_deformed', None)

                # Move channels last and drop the batch dim (batch_size is 1).
                predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
                visualizations.append(visualizer.visualize(source=source, driving=driving, out=out))
                loss_list.append(torch.abs(out['prediction'] - driving).mean().cpu().numpy())

            name = x['name'][0]
            predictions = np.concatenate(predictions, axis=1)
            imageio.imsave(os.path.join(png_dir, name + '.png'), (255 * predictions).astype(np.uint8))
            imageio.mimsave(os.path.join(log_dir, name + recon_params['format']), visualizations)

    print("Reconstruction loss: %s" % np.mean(loss_list))