# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os
import time

import imageio
import numpy as np
import torch
from tqdm import tqdm

from imaginaire.losses import MaskedL1Loss
from imaginaire.model_utils.fs_vid2vid import concat_frames, resample
from imaginaire.trainers.vid2vid import Trainer as Vid2VidTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import split_labels, to_cuda
from imaginaire.utils.visualization import tensor2flow, tensor2im


class Trainer(Vid2VidTrainer):
    r"""Initialize world consistent vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super().__init__(cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                         train_data_loader, val_data_loader)
        # First time step at which unprojections are handed to the generator
        # (see get_data_t); defaults to 0 when the config has no 'start_from'.
        self.guidance_start_after = getattr(cfg.gen.guidance, 'start_from', 0)
        self.train_data_loader = train_data_loader

    def _define_custom_losses(self):
        r"""All other custom losses are defined here."""
        # Setup the guidance loss.
        self.criteria['Guidance'] = MaskedL1Loss(normalize_over_valid=True)
        self.weights['Guidance'] = self.cfg.trainer.loss_weight.guidance

    def start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        Returns:
            data (dict): Input data after ``to_cuda``, with the
            'unprojections' entry re-attached untouched.
        """
        self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped'])

        # Keep unprojections on cpu to prevent unnecessary transfer.
        unprojections = data.pop('unprojections')
        data = to_cuda(data)
        data['unprojections'] = unprojections

        self.current_iteration = current_iteration
        if not self.is_inference:
            self.net_D.train()
        self.net_G.train()
        self.start_iteration_time = time.time()
        return data

    def reset(self):
        r"""Reset the trainer (for inference) at the beginning of a sequence.
        """
        # Inference time.
        self.net_G_module.reset_renderer(is_flipped_input=False)
        self.net_G_output = self.data_prev = None
        self.t = 0

        # Reset whichever generator will actually be used, if it supports it.
        if getattr(self, 'test_in_model_average_mode', False):
            target = self.net_G.module.averaged_model
        else:
            target = self.net_G.module
        if hasattr(target, 'reset'):
            target.reset()

    def create_sequence_output_dir(self, output_dir, key):
        r"""Create output subdir for this sequence.

        Args:
            output_dir (str): Root output dir.
            key (str): LMDB key which contains sequence name and file name.
        Returns:
            output_dir (str): Output subdir for this sequence.
            seq_name (str): Name of this sequence.
        """
        # Everything but the last '/'-separated component is the sequence.
        seq_dir = '/'.join(key.split('/')[:-1])
        output_dir = os.path.join(output_dir, seq_dir)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir + '/all', exist_ok=True)
        os.makedirs(output_dir + '/fake', exist_ok=True)
        seq_name = seq_dir.replace('/', '-')
        return output_dir, seq_name

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """
        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))

            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence.
            video = []
            # Re-bound per sequence so that an empty sequence can neither
            # raise NameError nor reuse the previous sequence's paths.
            output_dir = video_path = None
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output, and save all vis to all/.
                data['img_name'] = filename
                data = to_cuda(data)
                output = self.test_single(
                    data, output_dir=output_dir + '/all')

                # Dump just the fake image here.
                fake = tensor2im(output['fake_images'])[0]
                video.append(fake)
                imageio.imsave(output_dir + '/fake/%s.jpg' % (filename), fake)

            # Save as mp4 (nothing to write if the sequence had no frames).
            if video_path is None:
                continue
            imageio.mimsave(video_path + '.mp4', video, fps=15)

    def test_single(self, data, output_dir=None, save_fake_only=False):
        r"""The inference function. If output_dir exists, also save the
        output image.

        Args:
            data (dict): Training data at the current iteration.
            output_dir (str): Save image directory.
            save_fake_only (bool): Only save the fake output image.
        Returns:
            (dict): Generator output for this frame.
        """
        if self.is_inference and self.cfg.trainer.model_average_config.enabled:
            test_in_model_average_mode = True
        else:
            test_in_model_average_mode = getattr(
                self, 'test_in_model_average_mode', False)

        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
        if self.sequence_length > 1:
            self.data_prev = data_t

        # Generator forward.
        # Reset renderer if first time step.
        if self.t == 0:
            self.net_G_module.reset_renderer(
                is_flipped_input=data['is_flipped'])
        with torch.no_grad():
            if test_in_model_average_mode:
                net_G = self.net_G.module.averaged_model
            else:
                net_G = self.net_G
            self.net_G_output = net_G(data_t)

        if output_dir is not None:
            if save_fake_only:
                image_grid = tensor2im(self.net_G_output['fake_images'])[0]
            else:
                vis_images = self.get_test_output_images(data)
                image_grid = np.hstack([np.vstack(im) for im in
                                        vis_images if im is not None])
            if 'img_name' in data:
                save_name = data['img_name'].split('.')[0] + '.jpg'
            else:
                save_name = '%04d.jpg' % self.t
            output_filename = os.path.join(output_dir, save_name)
            os.makedirs(output_dir, exist_ok=True)
            imageio.imwrite(output_filename, image_grid)

        # Advance the time step for every processed frame, whether or not an
        # image was written, so the renderer is only reset on the first one.
        self.t += 1

        return self.net_G_output

    def get_test_output_images(self, data):
        r"""Get the visualization output of test function.

        Args:
            data (dict): Training data at the current iteration.
        Returns:
            (list): Label visualizations, ground truth image, guidance image,
            guidance mask and generated image, in that order.
        """
        # Visualize labels.
        label_lengths = self.val_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels = []
        for key, value in labels.items():
            if key == 'seg_maps':
                vis_labels.append(self.visualize_label(value[:, -1]))
            else:
                vis_labels.append(tensor2im(value[:, -1]))

        # Get gt image.
        im = tensor2im(data['images'][:, -1])

        # Get guidance image and masks.
        guidance = self.net_G_output['guidance_images_and_masks']
        if guidance is not None:
            guidance_image = tensor2im(guidance[:, :3])
            guidance_mask = tensor2im(guidance[:, 3:4], normalize=False)
        else:
            # No guidance available: show blank placeholders instead.
            guidance_image = [np.zeros_like(item) for item in im]
            guidance_mask = [np.zeros_like(item) for item in im]

        # Create output.
        vis_images = [
            *vis_labels,
            im,
            guidance_image,
            guidance_mask,
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def gen_frames(self, data, use_model_average=False):
        r"""Generate a sequence of frames given a sequence of data.

        Args:
            data (dict): Training data at the current iteration.
            use_model_average (bool): Whether to use model average
                for update or not.
        Returns:
            (tuple):
              - first_net_G_output (dict): Generator output at t == 0.
              - net_G_output (dict): Generator output at the last time step.
              - all_info (dict): Per-frame 'inputs' and 'outputs' lists.
        """
        net_G_output = None  # Previous generator output.
        data_prev = None  # Previous data.
        if use_model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G

        # Iterate through the length of sequence.
        self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped'])

        all_info = {'inputs': [], 'outputs': []}
        for t in range(self.sequence_length):
            # Get the data at the current time frame.
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Generator forward.
            with torch.no_grad():
                net_G_output = net_G(data_t)

            # Do any postprocessing if necessary.
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            if t == 0:
                # Get the output at beginning of sequence for visualization.
                first_net_G_output = net_G_output

            all_info['inputs'].append(data_t)
            all_info['outputs'].append(net_G_output)

        return first_net_G_output, net_G_output, all_info

    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""All other custom generator losses go here.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        # Compute guidance loss.
        guidance = net_G_output['guidance_images_and_masks']
        if guidance is not None:
            # First three channels are the guidance image, the rest the mask.
            guidance_image = guidance[:, :3]
            guidance_mask = guidance[:, 3:]
            self.gen_losses['Guidance'] = self.criteria['Guidance'](
                net_G_output['fake_images'], guidance_image, guidance_mask)
        else:
            self.gen_losses['Guidance'] = self.Tensor(1).fill_(0)

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at current time frame given the sequence of data.

        Args:
            data (dict): Training data for current iteration.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for previous frame.
            t (int): Current time.
        Returns:
            data_t (dict): Data for the current time frame, with keys
            'label', 'image', 'prev_labels', 'prev_images',
            'real_prev_image' and 'unprojection'.
        """
        label = data['label'][:, t]
        image = data['images'][:, t]

        # Get keypoint mapping.
        unprojection = None
        if t >= self.guidance_start_after and 'unprojections' in data:
            try:
                # Remove unwanted padding: the first entry of the last row is
                # read as the number of rows to keep.
                unprojection = {}
                for key, value in data['unprojections'].items():
                    value = value[0, t].cpu().numpy()
                    length = value[-1][0]
                    unprojection[key] = value[:length]
            except Exception:
                # Best effort: keep whatever was extracted so far. Unlike a
                # bare except, this lets KeyboardInterrupt/SystemExit through.
                pass

        if data_prev is not None:
            # Concat previous labels/fake images to the ones before.
            num_frames_G = self.cfg.data.num_frames_G
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], num_frames_G - 1)
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), num_frames_G - 1)
        else:
            prev_labels = prev_images = None

        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
        data_t['unprojection'] = unprojection
        return data_t

    def save_image(self, path, data):
        r"""Save the output images to path.
        Note when the generate_raw_output is FALSE. Then,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, we will plot the flow visualization
        twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        use_avg = self.cfg.trainer.model_average_config.enabled

        self.net_G.eval()
        if use_avg:
            self.net_G.module.averaged_model.eval()

        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if use_avg:
                # gen_frames returns three values; the per-frame info of the
                # averaged model is not needed here.
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels.
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if 'seg_maps' in key:
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                normalize = self.train_data_loader.dataset.normalize[key]
                vis_labels_start.append(
                    tensor2im(value[:, -1], normalize=normalize))
                vis_labels_end.append(
                    tensor2im(value[:, 0], normalize=normalize))

        if is_master():
            # Row for the last frame of the sequence.
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if use_avg:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])]

            if self.sequence_length > 1:
                guidance = net_G_output['guidance_images_and_masks']
                if guidance is not None:
                    guidance_image = tensor2im(guidance[:, :3])
                    guidance_mask = tensor2im(guidance[:, 3:4],
                                              normalize=False)
                else:
                    im = tensor2im(data['images'][:, -1])
                    guidance_image = [np.zeros_like(item) for item in im]
                    guidance_mask = [np.zeros_like(item) for item in im]
                vis_images += [guidance_image, guidance_mask]

                # Row for the first frame; its guidance slots are blank.
                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images']),
                    [np.zeros_like(item) for item in guidance_image],
                    [np.zeros_like(item) for item in guidance_mask]
                ]
                if use_avg:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])]

                if self.use_flow:
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if use_avg:
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(net_G_output_avg['fake_occlusion_masks'],
                                      normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])]

                # Stack the first-frame row on top of the last-frame row.
                vis_images = [[np.vstack((im_first, im))
                               for im_first, im in zip(imgs_first, imgs)]
                              for imgs_first, imgs in zip(vis_images_first,
                                                          vis_images)
                              if imgs is not None]

            image_grid = np.hstack([np.vstack(im) for im in
                                    vis_images if im is not None])

            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all outputs for dumping into video.
            if self.sequence_length > 1:
                output_images, output_guidance = [], []
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])
                    if item['guidance_images_and_masks'] is not None:
                        output_guidance.append(tensor2im(
                            item['guidance_images_and_masks'][:, :3])[0])
                    else:
                        output_guidance.append(
                            np.zeros_like(output_images[-1]))

                video_stem = os.path.splitext(path)[0]
                imageio.mimwrite(video_stem + '.mp4',
                                 output_images, fps=2, macro_block_size=None)
                imageio.mimwrite(video_stem + '_guidance.mp4',
                                 output_guidance, fps=2, macro_block_size=None)

        self.net_G.float()

    def _compute_fid(self):
        r"""Compute fid. Ignore for faster training."""
        return None

    def load_checkpoint(self, cfg, checkpoint_path, resume=None,
                        load_sch=True):
        r"""Create the single image model, then load the checkpoint through
        the parent trainer.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint.
            resume (optional): Forwarded unchanged to the parent
                ``load_checkpoint``.
            load_sch (bool): Forwarded unchanged to the parent
                ``load_checkpoint``.
        Returns:
            Whatever the parent ``load_checkpoint`` returns.
        """
        # Create the single image model. Its weights are only loaded when a
        # train data loader is present.
        load_single_image_model_weights = self.train_data_loader is not None
        self.net_G.module._init_single_image_model(
            load_weights=load_single_image_model_weights)

        # Call the original super function.
        return super().load_checkpoint(cfg, checkpoint_path, resume, load_sch)