# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import imageio
import numpy as np
import torch
from tqdm import tqdm

from imaginaire.model_utils.fs_vid2vid import (concat_frames, get_fg_mask,
                                               pre_process_densepose,
                                               random_roll)
from imaginaire.model_utils.pix2pixHD import get_optimizer_with_params
from imaginaire.trainers.vid2vid import Trainer as vid2vidTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import to_cuda
from imaginaire.utils.visualization import tensor2flow, tensor2im


class Trainer(vid2vidTrainer):
    r"""Initialize the few-shot vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)

    def _start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        """
        data = self.pre_process(data)
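        # Move all tensors in the (possibly nested) data dict onto the GPU.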
        return to_cuda(data)

    def pre_process(self, data):
        r"""Do any data pre-processing here.

        Args:
            data (dict): Data used for the current iteration.
        """
        data_cfg = self.cfg.data
        if hasattr(data_cfg, 'for_pose_dataset') and \
                ('pose_maps-densepose' in data_cfg.input_labels):
            pose_cfg = data_cfg.for_pose_dataset
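            # Apply the same DensePose label preprocessing to both the driving
            # labels and the few-shot reference labels so they stay consistent.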
            data['label'] = pre_process_densepose(pose_cfg, data['label'],
                                                  self.is_inference)
            data['few_shot_label'] = pre_process_densepose(
                pose_cfg, data['few_shot_label'], self.is_inference)
        return data

    def get_test_output_images(self, data):
        r"""Get the visualization output of the test function.

        Args:
            data (dict): Training data at the current iteration.
        """
        vis_images = [
            tensor2im(data['few_shot_images'][:, 0]),
            self.visualize_label(data['label'][:, -1]),
            tensor2im(data['images'][:, -1]),
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at the current time frame given the sequence of data.

        Args:
            data (dict): Training data for the current iteration.
            net_G_output (dict): Output of the generator (for the previous
                frame).
            data_prev (dict): Data for the previous frame.
            t (int): Current time.
        """
        label = data['label'][:, t] if 'label' in data else None
        image = data['images'][:, t]

        if data_prev is not None:
            nG = self.cfg.data.num_frames_G
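            # Keep a sliding window of the last (nG - 1) labels and generated
            # frames; these condition the generator at the next time step.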
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], nG - 1)
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), nG - 1)
        else:
            prev_labels = prev_images = None

        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['ref_labels'] = data['few_shot_label'] if 'few_shot_label' \
            in data else None
        data_t['ref_images'] = data['few_shot_images']
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
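        # The ground-truth previous frame is only available after the first
        # time step.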
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None

        # if 'landmarks_xy' in data:
        #     data_t['landmarks_xy'] = data['landmarks_xy'][:, t]
        #     data_t['ref_landmarks_xy'] = data['few_shot_landmarks_xy']
        return data_t

    def post_process(self, data, net_G_output):
        r"""Do any postprocessing of the data / output here.

        Args:
            data (dict): Training data at the current iteration.
            net_G_output (dict): Output of the generator.
        """
        if self.has_fg:
            fg_mask = get_fg_mask(data['label'], self.has_fg)
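            # Zero out background pixels of the raw generator output using the
            # foreground mask derived from the labels.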
            if net_G_output['fake_raw_images'] is not None:
                net_G_output['fake_raw_images'] = \
                    net_G_output['fake_raw_images'] * fg_mask
        return data, net_G_output

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on the specified sequence.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """
        self.reset()
        test_data_loader.dataset.set_sequence_length(0)
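        # With these settings, each item yielded by the loader is treated as
        # one frame of the driving sequence, generated and saved in turn.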
        # Set the inference sequences.
        test_data_loader.dataset.set_inference_sequence_idx(
            inference_args.driving_seq_index,
            inference_args.few_shot_seq_index,
            inference_args.few_shot_frame_index)

        video = []
        for idx, data in enumerate(tqdm(test_data_loader)):
            key = data['key']['images'][0][0]
            filename = key.split('/')[-1]

            # Create output dir for this sequence.
            if idx == 0:
                seq_name = '%03d' % inference_args.driving_seq_index
                output_dir = os.path.join(root_output_dir, seq_name)
                os.makedirs(output_dir, exist_ok=True)
                video_path = output_dir

            # Get output and save images.
            data['img_name'] = filename
            data = self.start_of_iteration(data, current_iteration=-1)
            output = self.test_single(data, output_dir, inference_args)
            video.append(output)

        # Save output as mp4.
        imageio.mimsave(video_path + '.mp4', video, fps=15)

    def save_image(self, path, data):
        r"""Save the output images to path.

        Note that when generate_raw_output is False,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, we will plot the flow visualization
        twice.

        Args:
            path (str): Save path.
            data (dict): Training data for the current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.averaged_model.eval()

        self.net_G_output = None
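        # Run the generator over the full sequence and keep the outputs for
        # the first and last frames.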
        with torch.no_grad():
            first_net_G_output, last_net_G_output, _ = self.gen_frames(data)
            if self.cfg.trainer.model_average_config.enabled:
                first_net_G_output_avg, last_net_G_output_avg, _ = \
                    self.gen_frames(data, use_model_average=True)

        def get_images(data, net_G_output, return_first_frame=True,
                       for_model_average=False):
            r"""Get the output images to save.

            Args:
                data (dict): Training data for the current iteration.
                net_G_output (dict): Generator output.
                return_first_frame (bool): Return output for the first frame
                    in the sequence.
                for_model_average (bool): For model average output.
            Return:
                vis_images (list of numpy arrays): Visualization images.
            """
            frame_idx = 0 if return_first_frame else -1
            warped_idx = 0 if return_first_frame else 1
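            # The first frame is warped from the reference image (index 0);
            # subsequent frames are warped from the previous frame (index 1).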
            vis_images = []
            if not for_model_average:
                vis_images += [
                    tensor2im(data['few_shot_images'][:, frame_idx]),
                    self.visualize_label(data['label'][:, frame_idx]),
                    tensor2im(data['images'][:, frame_idx])
                ]
            vis_images += [
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if not for_model_average:
                vis_images += [
                    tensor2im(net_G_output['warped_images'][warped_idx]),
                    tensor2flow(net_G_output['fake_flow_maps'][warped_idx]),
                    tensor2im(net_G_output['fake_occlusion_masks'][warped_idx],
                              normalize=False)
                ]
            return vis_images

        if is_master():
            vis_images_first = get_images(data, first_net_G_output)
            if self.cfg.trainer.model_average_config.enabled:
                vis_images_first += get_images(data, first_net_G_output_avg,
                                               for_model_average=True)
            if self.sequence_length > 1:
                vis_images_last = get_images(data, last_net_G_output,
                                             return_first_frame=False)
                if self.cfg.trainer.model_average_config.enabled:
                    vis_images_last += get_images(data, last_net_G_output_avg,
                                                  return_first_frame=False,
                                                  for_model_average=True)

                # If generating a video, the first row of each batch will be
                # the first generated frame and the flow/mask for warping the
                # reference image, and the second row will be the last
                # generated frame and the flow/mask for warping the previous
                # frame. If using model average, the frames generated by model
                # average will be at the rightmost columns.
                vis_images = [[np.vstack((im_first, im_last))
                               for im_first, im_last in
                               zip(imgs_first, imgs_last)]
                              for imgs_first, imgs_last in
                              zip(vis_images_first, vis_images_last)
                              if imgs_first is not None]
            else:
                vis_images = vis_images_first

            image_grid = np.hstack([np.vstack(im) for im in vis_images
                                    if im is not None])
            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

    def finetune(self, data, inference_args):
        r"""Finetune the model for a few iterations on the inference data."""
        # Get the list of params to finetune.
        self.net_G, self.net_D, self.opt_G, self.opt_D = \
            get_optimizer_with_params(self.cfg, self.net_G, self.net_D,
                                      param_names_start_with=[
                                          'weight_generator.fc', 'conv_img',
                                          'up'])
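        # Only parameters whose names start with the prefixes above are
        # handed to the new optimizers; the remaining weights are left
        # untouched during finetuning.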
        data_finetune = {k: v for k, v in data.items()}
        ref_labels = data_finetune['few_shot_label']
        ref_images = data_finetune['few_shot_images']

        # Number of iterations to finetune.
        iterations = getattr(inference_args, 'finetune_iter', 100)
        for it in range(1, iterations + 1):
            # Randomly set one of the reference images as target.
            idx = np.random.randint(ref_labels.size(1))
            tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx]
            # Randomly shift and flip the target image.
            tgt_label, tgt_image = random_roll([tgt_label, tgt_image])
            data_finetune['label'] = tgt_label.unsqueeze(1)
            data_finetune['images'] = tgt_image.unsqueeze(1)
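            # The augmented reference frame becomes a length-1 training
            # "sequence" (unsqueeze adds the time dimension).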
            self.gen_update(data_finetune)
            self.dis_update(data_finetune)
            # Print progress roughly every 10% of the run; the max() guards
            # against iterations < 10, where (iterations // 10) would be 0.
            if it % max(iterations // 10, 1) == 0:
                print(it)

        self.has_finetuned = True
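

# A minimal usage sketch (hypothetical variable names; in practice the
# imaginaire inference scripts construct cfg, the networks, optimizers,
# schedulers and data loaders before calling into this trainer):
#
#     trainer = Trainer(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
#                       train_loader, val_loader)
#     if getattr(inference_args, 'finetune', False):
#         trainer.finetune(data, inference_args)  # adapt to the few-shot refs
#     trainer.test(test_loader, 'output/fs_vid2vid', inference_args)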