venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os
import imageio
import numpy as np
import torch
from tqdm import tqdm
from imaginaire.model_utils.fs_vid2vid import (concat_frames, get_fg_mask,
pre_process_densepose,
random_roll)
from imaginaire.model_utils.pix2pixHD import get_optimizer_with_params
from imaginaire.trainers.vid2vid import Trainer as vid2vidTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import to_cuda
from imaginaire.utils.visualization import tensor2flow, tensor2im
class Trainer(vid2vidTrainer):
r"""Initialize vid2vid trainer.
Args:
cfg (obj): Global configuration.
net_G (obj): Generator network.
net_D (obj): Discriminator network.
opt_G (obj): Optimizer for the generator network.
opt_D (obj): Optimizer for the discriminator network.
sch_G (obj): Scheduler for the generator optimizer.
sch_D (obj): Scheduler for the discriminator optimizer.
train_data_loader (obj): Train data loader.
val_data_loader (obj): Validation data loader.
"""
def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
train_data_loader, val_data_loader):
super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
opt_D, sch_G, sch_D,
train_data_loader, val_data_loader)
def _start_of_iteration(self, data, current_iteration):
r"""Things to do before an iteration.
Args:
data (dict): Data used for the current iteration.
current_iteration (int): Current number of iteration.
"""
data = self.pre_process(data)
return to_cuda(data)
def pre_process(self, data):
r"""Do any data pre-processing here.
Args:
data (dict): Data used for the current iteration.
"""
data_cfg = self.cfg.data
if hasattr(data_cfg, 'for_pose_dataset') and \
('pose_maps-densepose' in data_cfg.input_labels):
pose_cfg = data_cfg.for_pose_dataset
data['label'] = pre_process_densepose(pose_cfg, data['label'],
self.is_inference)
data['few_shot_label'] = pre_process_densepose(
pose_cfg, data['few_shot_label'], self.is_inference)
return data
def get_test_output_images(self, data):
r"""Get the visualization output of test function.
Args:
data (dict): Training data at the current iteration.
"""
vis_images = [
tensor2im(data['few_shot_images'][:, 0]),
self.visualize_label(data['label'][:, -1]),
tensor2im(data['images'][:, -1]),
tensor2im(self.net_G_output['fake_images']),
]
return vis_images
def get_data_t(self, data, net_G_output, data_prev, t):
r"""Get data at current time frame given the sequence of data.
Args:
data (dict): Training data for current iteration.
net_G_output (dict): Output of the generator (for previous frame).
data_prev (dict): Data for previous frame.
t (int): Current time.
"""
label = data['label'][:, t] if 'label' in data else None
image = data['images'][:, t]
if data_prev is not None:
nG = self.cfg.data.num_frames_G
prev_labels = concat_frames(data_prev['prev_labels'],
data_prev['label'], nG - 1)
prev_images = concat_frames(
data_prev['prev_images'],
net_G_output['fake_images'].detach(), nG - 1)
else:
prev_labels = prev_images = None
data_t = dict()
data_t['label'] = label
data_t['image'] = image
data_t['ref_labels'] = data['few_shot_label'] if 'few_shot_label' \
in data else None
data_t['ref_images'] = data['few_shot_images']
data_t['prev_labels'] = prev_labels
data_t['prev_images'] = prev_images
data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
# if 'landmarks_xy' in data:
# data_t['landmarks_xy'] = data['landmarks_xy'][:, t]
# data_t['ref_landmarks_xy'] = data['few_shot_landmarks_xy']
return data_t
def post_process(self, data, net_G_output):
r"""Do any postprocessing of the data / output here.
Args:
data (dict): Training data at the current iteration.
net_G_output (dict): Output of the generator.
"""
if self.has_fg:
fg_mask = get_fg_mask(data['label'], self.has_fg)
if net_G_output['fake_raw_images'] is not None:
net_G_output['fake_raw_images'] = \
net_G_output['fake_raw_images'] * fg_mask
return data, net_G_output
def test(self, test_data_loader, root_output_dir, inference_args):
r"""Run inference on the specified sequence.
Args:
test_data_loader (object): Test data loader.
root_output_dir (str): Location to dump outputs.
inference_args (optional): Optional args.
"""
self.reset()
test_data_loader.dataset.set_sequence_length(0)
# Set the inference sequences.
test_data_loader.dataset.set_inference_sequence_idx(
inference_args.driving_seq_index,
inference_args.few_shot_seq_index,
inference_args.few_shot_frame_index)
video = []
for idx, data in enumerate(tqdm(test_data_loader)):
key = data['key']['images'][0][0]
filename = key.split('/')[-1]
# Create output dir for this sequence.
if idx == 0:
seq_name = '%03d' % inference_args.driving_seq_index
output_dir = os.path.join(root_output_dir, seq_name)
os.makedirs(output_dir, exist_ok=True)
video_path = output_dir
# Get output and save images.
data['img_name'] = filename
data = self.start_of_iteration(data, current_iteration=-1)
output = self.test_single(data, output_dir, inference_args)
video.append(output)
# Save output as mp4.
imageio.mimsave(video_path + '.mp4', video, fps=15)
def save_image(self, path, data):
r"""Save the output images to path.
Note when the generate_raw_output is FALSE. Then,
first_net_G_output['fake_raw_images'] is None and will not be displayed.
In model average mode, we will plot the flow visualization twice.
Args:
path (str): Save path.
data (dict): Training data for current iteration.
"""
self.net_G.eval()
if self.cfg.trainer.model_average_config.enabled:
self.net_G.module.averaged_model.eval()
self.net_G_output = None
with torch.no_grad():
first_net_G_output, last_net_G_output, _ = self.gen_frames(data)
if self.cfg.trainer.model_average_config.enabled:
first_net_G_output_avg, last_net_G_output_avg, _ = \
self.gen_frames(data, use_model_average=True)
def get_images(data, net_G_output, return_first_frame=True,
for_model_average=False):
r"""Get the ourput images to save.
Args:
data (dict): Training data for current iteration.
net_G_output (dict): Generator output.
return_first_frame (bool): Return output for first frame in the
sequence.
for_model_average (bool): For model average output.
Return:
vis_images (list of numpy arrays): Visualization images.
"""
frame_idx = 0 if return_first_frame else -1
warped_idx = 0 if return_first_frame else 1
vis_images = []
if not for_model_average:
vis_images += [
tensor2im(data['few_shot_images'][:, frame_idx]),
self.visualize_label(data['label'][:, frame_idx]),
tensor2im(data['images'][:, frame_idx])
]
vis_images += [
tensor2im(net_G_output['fake_images']),
tensor2im(net_G_output['fake_raw_images'])]
if not for_model_average:
vis_images += [
tensor2im(net_G_output['warped_images'][warped_idx]),
tensor2flow(net_G_output['fake_flow_maps'][warped_idx]),
tensor2im(net_G_output['fake_occlusion_masks'][warped_idx],
normalize=False)
]
return vis_images
if is_master():
vis_images_first = get_images(data, first_net_G_output)
if self.cfg.trainer.model_average_config.enabled:
vis_images_first += get_images(data, first_net_G_output_avg,
for_model_average=True)
if self.sequence_length > 1:
vis_images_last = get_images(data, last_net_G_output,
return_first_frame=False)
if self.cfg.trainer.model_average_config.enabled:
vis_images_last += get_images(data, last_net_G_output_avg,
return_first_frame=False,
for_model_average=True)
# If generating a video, the first row of each batch will be
# the first generated frame and the flow/mask for warping the
# reference image, and the second row will be the last
# generated frame and the flow/mask for warping the previous
# frame. If using model average, the frames generated by model
# average will be at the rightmost columns.
vis_images = [[np.vstack((im_first, im_last))
for im_first, im_last in
zip(imgs_first, imgs_last)]
for imgs_first, imgs_last in zip(vis_images_first,
vis_images_last)
if imgs_first is not None]
else:
vis_images = vis_images_first
image_grid = np.hstack([np.vstack(im) for im in vis_images
if im is not None])
print('Save output images to {}'.format(path))
os.makedirs(os.path.dirname(path), exist_ok=True)
imageio.imwrite(path, image_grid)
def finetune(self, data, inference_args):
r"""Finetune the model for a few iterations on the inference data."""
# Get the list of params to finetune.
self.net_G, self.net_D, self.opt_G, self.opt_D = \
get_optimizer_with_params(self.cfg, self.net_G, self.net_D,
param_names_start_with=[
'weight_generator.fc', 'conv_img',
'up'])
data_finetune = {k: v for k, v in data.items()}
ref_labels = data_finetune['few_shot_label']
ref_images = data_finetune['few_shot_images']
# Number of iterations to finetune.
iterations = getattr(inference_args, 'finetune_iter', 100)
for it in range(1, iterations + 1):
# Randomly set one of the reference images as target.
idx = np.random.randint(ref_labels.size(1))
tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx]
# Randomly shift and flip the target image.
tgt_label, tgt_image = random_roll([tgt_label, tgt_image])
data_finetune['label'] = tgt_label.unsqueeze(1)
data_finetune['images'] = tgt_image.unsqueeze(1)
self.gen_update(data_finetune)
self.dis_update(data_finetune)
if (it % (iterations // 10)) == 0:
print(it)
self.has_finetuned = True