Spaces:
Runtime error
Runtime error
File size: 12,677 Bytes
f670afc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os
import imageio
import numpy as np
import torch
from tqdm import tqdm
from imaginaire.model_utils.fs_vid2vid import (concat_frames, get_fg_mask,
pre_process_densepose,
random_roll)
from imaginaire.model_utils.pix2pixHD import get_optimizer_with_params
from imaginaire.trainers.vid2vid import Trainer as vid2vidTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import to_cuda
from imaginaire.utils.visualization import tensor2flow, tensor2im
class Trainer(vid2vidTrainer):
    r"""Few-shot vid2vid trainer.

    Extends the vid2vid trainer with few-shot specific data handling
    (reference "few_shot" images/labels) and test-time finetuning on the
    reference frames.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)

    def _start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration: pre-process and move to GPU.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current number of iteration.
        Returns:
            (dict): Pre-processed data moved to CUDA.
        """
        data = self.pre_process(data)
        return to_cuda(data)

    def pre_process(self, data):
        r"""Do any data pre-processing here.

        For pose datasets with DensePose label maps, both the driving
        ('label') and the reference ('few_shot_label') maps are processed.

        Args:
            data (dict): Data used for the current iteration.
        Returns:
            data (dict): Pre-processed data.
        """
        data_cfg = self.cfg.data
        if hasattr(data_cfg, 'for_pose_dataset') and \
                ('pose_maps-densepose' in data_cfg.input_labels):
            pose_cfg = data_cfg.for_pose_dataset
            data['label'] = pre_process_densepose(pose_cfg, data['label'],
                                                  self.is_inference)
            data['few_shot_label'] = pre_process_densepose(
                pose_cfg, data['few_shot_label'], self.is_inference)
        return data

    def get_test_output_images(self, data):
        r"""Get the visualization output of test function.

        Args:
            data (dict): Training data at the current iteration.
        Returns:
            vis_images (list): [reference image, last driving label,
            last real image, generated image], each as a numpy image.
        """
        vis_images = [
            tensor2im(data['few_shot_images'][:, 0]),
            self.visualize_label(data['label'][:, -1]),
            tensor2im(data['images'][:, -1]),
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at current time frame given the sequence of data.

        Args:
            data (dict): Training data for current iteration.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for previous frame.
            t (int): Current time.
        Returns:
            data_t (dict): Data (label/image/reference/previous frames)
            for time step ``t``.
        """
        label = data['label'][:, t] if 'label' in data else None
        image = data['images'][:, t]
        if data_prev is not None:
            # Keep a rolling window of the last (num_frames_G - 1) labels /
            # generated frames to condition the generator on.
            nG = self.cfg.data.num_frames_G
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], nG - 1)
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), nG - 1)
        else:
            prev_labels = prev_images = None
        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['ref_labels'] = data['few_shot_label'] if 'few_shot_label' \
            in data else None
        data_t['ref_images'] = data['few_shot_images']
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
        return data_t

    def post_process(self, data, net_G_output):
        r"""Do any postprocessing of the data / output here.

        Masks out the background of the raw generated images when a
        foreground mask is available.

        Args:
            data (dict): Training data at the current iteration.
            net_G_output (dict): Output of the generator.
        Returns:
            (tuple): Possibly modified (data, net_G_output).
        """
        if self.has_fg:
            fg_mask = get_fg_mask(data['label'], self.has_fg)
            if net_G_output['fake_raw_images'] is not None:
                net_G_output['fake_raw_images'] = \
                    net_G_output['fake_raw_images'] * fg_mask
        return data, net_G_output

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on the specified sequence.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """
        self.reset()
        test_data_loader.dataset.set_sequence_length(0)
        # Set the inference sequences.
        test_data_loader.dataset.set_inference_sequence_idx(
            inference_args.driving_seq_index,
            inference_args.few_shot_seq_index,
            inference_args.few_shot_frame_index)

        video = []
        # Initialized lazily on the first batch; stays None if the loader
        # is empty so we don't hit a NameError below.
        video_path = None
        for idx, data in enumerate(tqdm(test_data_loader)):
            key = data['key']['images'][0][0]
            filename = key.split('/')[-1]

            # Create output dir for this sequence.
            if idx == 0:
                seq_name = '%03d' % inference_args.driving_seq_index
                output_dir = os.path.join(root_output_dir, seq_name)
                os.makedirs(output_dir, exist_ok=True)
                video_path = output_dir

            # Get output and save images.
            data['img_name'] = filename
            data = self.start_of_iteration(data, current_iteration=-1)
            output = self.test_single(data, output_dir, inference_args)
            video.append(output)

        # Save output as mp4 (skipped when no frames were produced).
        if video_path is not None:
            imageio.mimsave(video_path + '.mp4', video, fps=15)

    def save_image(self, path, data):
        r"""Save the output images to path.

        Note when the generate_raw_output is FALSE. Then,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, we will plot the flow
        visualization twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.averaged_model.eval()
        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, last_net_G_output, _ = self.gen_frames(data)
            if self.cfg.trainer.model_average_config.enabled:
                first_net_G_output_avg, last_net_G_output_avg, _ = \
                    self.gen_frames(data, use_model_average=True)

        def get_images(data, net_G_output, return_first_frame=True,
                       for_model_average=False):
            r"""Get the output images to save.

            Args:
                data (dict): Training data for current iteration.
                net_G_output (dict): Generator output.
                return_first_frame (bool): Return output for first frame in
                    the sequence.
                for_model_average (bool): For model average output.
            Return:
                vis_images (list of numpy arrays): Visualization images.
            """
            frame_idx = 0 if return_first_frame else -1
            # First frame is warped from the reference image (index 0);
            # later frames are warped from the previous frame (index 1).
            warped_idx = 0 if return_first_frame else 1
            vis_images = []
            if not for_model_average:
                vis_images += [
                    tensor2im(data['few_shot_images'][:, frame_idx]),
                    self.visualize_label(data['label'][:, frame_idx]),
                    tensor2im(data['images'][:, frame_idx])
                ]
            vis_images += [
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if not for_model_average:
                vis_images += [
                    tensor2im(net_G_output['warped_images'][warped_idx]),
                    tensor2flow(net_G_output['fake_flow_maps'][warped_idx]),
                    tensor2im(net_G_output['fake_occlusion_masks'][warped_idx],
                              normalize=False)
                ]
            return vis_images

        if is_master():
            vis_images_first = get_images(data, first_net_G_output)
            if self.cfg.trainer.model_average_config.enabled:
                vis_images_first += get_images(data, first_net_G_output_avg,
                                               for_model_average=True)
            if self.sequence_length > 1:
                vis_images_last = get_images(data, last_net_G_output,
                                             return_first_frame=False)
                if self.cfg.trainer.model_average_config.enabled:
                    vis_images_last += get_images(data, last_net_G_output_avg,
                                                  return_first_frame=False,
                                                  for_model_average=True)

                # If generating a video, the first row of each batch will be
                # the first generated frame and the flow/mask for warping the
                # reference image, and the second row will be the last
                # generated frame and the flow/mask for warping the previous
                # frame. If using model average, the frames generated by model
                # average will be at the rightmost columns.
                vis_images = [[np.vstack((im_first, im_last))
                               for im_first, im_last in
                               zip(imgs_first, imgs_last)]
                              for imgs_first, imgs_last in zip(vis_images_first,
                                                               vis_images_last)
                              if imgs_first is not None]
            else:
                vis_images = vis_images_first

        image_grid = np.hstack([np.vstack(im) for im in vis_images
                                if im is not None])
        print('Save output images to {}'.format(path))
        os.makedirs(os.path.dirname(path), exist_ok=True)
        imageio.imwrite(path, image_grid)

    def finetune(self, data, inference_args):
        r"""Finetune the model for a few iterations on the inference data.

        Args:
            data (dict): Inference data containing the few-shot reference
                labels/images used as finetuning targets.
            inference_args (obj): May carry ``finetune_iter`` (int, default
                100), the number of finetuning iterations.
        """
        # Get the list of params to finetune.
        self.net_G, self.net_D, self.opt_G, self.opt_D = \
            get_optimizer_with_params(self.cfg, self.net_G, self.net_D,
                                      param_names_start_with=[
                                          'weight_generator.fc', 'conv_img',
                                          'up'])
        # Shallow copy so the caller's dict is not mutated.
        data_finetune = dict(data)
        ref_labels = data_finetune['few_shot_label']
        ref_images = data_finetune['few_shot_images']

        # Number of iterations to finetune.
        iterations = getattr(inference_args, 'finetune_iter', 100)
        # Guard against ZeroDivisionError when iterations < 10.
        log_interval = max(1, iterations // 10)
        for it in range(1, iterations + 1):
            # Randomly set one of the reference images as target.
            idx = np.random.randint(ref_labels.size(1))
            tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx]
            # Randomly shift and flip the target image.
            tgt_label, tgt_image = random_roll([tgt_label, tgt_image])
            data_finetune['label'] = tgt_label.unsqueeze(1)
            data_finetune['images'] = tgt_image.unsqueeze(1)

            self.gen_update(data_finetune)
            self.dis_update(data_finetune)
            if (it % log_interval) == 0:
                print(it)

        self.has_finetuned = True
|