# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import importlib

import torch
import torch.nn as nn
import torch.nn.functional as F

from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
from imaginaire.model_utils.fs_vid2vid import get_fg_mask, pick_image
from imaginaire.utils.data import (get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.misc import get_nested_attr


class Discriminator(nn.Module):
    r"""Image and video discriminator constructor.

    Args:
        dis_cfg (obj): Discriminator part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        num_input_channels = get_paired_input_label_channel_number(data_cfg)
        if num_input_channels == 0:
            num_input_channels = getattr(data_cfg, 'label_channels', 1)
        num_img_channels = get_paired_input_image_channel_number(data_cfg)
        self.num_frames_D = data_cfg.num_frames_D
        self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)
        num_netD_input_channels = (num_input_channels + num_img_channels)
        self.use_few_shot = 'few_shot' in data_cfg.type
        if self.use_few_shot:
            num_netD_input_channels *= 2
        self.net_D = MultiPatchDiscriminator(dis_cfg.image,
                                             num_netD_input_channels)

        self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                add_dis_cfg = self.add_dis_cfg[name]
                num_ch = num_img_channels * (2 if self.use_few_shot else 1)
                setattr(self, 'net_D_' + name,
                        MultiPatchDiscriminator(add_dis_cfg, num_ch))

        # Temporal discriminator.
        self.num_netDT_input_channels = num_img_channels * self.num_frames_D
        for n in range(self.num_scales):
            setattr(self, 'net_DT%d' % n,
                    MultiPatchDiscriminator(dis_cfg.temporal,
                                            self.num_netDT_input_channels))

        self.has_fg = getattr(data_cfg, 'has_foreground', False)

    def forward(self, data, net_G_output, past_frames):
        r"""Discriminator forward.

        Args:
            data (dict): Input data.
            net_G_output (dict): Generator output.
            past_frames (list of tensors): Past real frames / generator
                outputs.

        Returns:
            (tuple):
              - output (dict): Discriminator output.
              - past_frames (list of tensors): New past frames by adding
                current outputs.
        """
        label, real_image = data['label'], data['image']
        # Only operate on the latest output frame.
        if label.dim() == 5:
            label = label[:, -1]
        if self.use_few_shot:
            # Pick only one reference image to concat with.
            ref_idx = net_G_output['ref_idx'] \
                if 'ref_idx' in net_G_output else 0
            ref_label = pick_image(data['ref_labels'], ref_idx)
            ref_image = pick_image(data['ref_images'], ref_idx)
            # Concat references with label map as discriminator input.
            label = torch.cat([label, ref_label, ref_image], dim=1)
        fake_image = net_G_output['fake_images']
        output = dict()

        # Individual frame loss.
        pred_real, pred_fake = self.discriminate_image(
            self.net_D, label, real_image, fake_image)
        output['indv'] = dict()
        output['indv']['pred_real'] = pred_real
        output['indv']['pred_fake'] = pred_fake

        if 'fake_raw_images' in net_G_output and \
                net_G_output['fake_raw_images'] is not None:
            # Raw generator output loss.
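            # Apply the foreground mask to both the real image and the raw
            # generator output so this GAN loss only covers the foreground
            # region.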
            fake_raw_image = net_G_output['fake_raw_images']
            fg_mask = get_fg_mask(data['label'], self.has_fg)
            pred_real, pred_fake = self.discriminate_image(
                self.net_D, label,
                real_image * fg_mask,
                fake_raw_image * fg_mask)
            output['raw'] = dict()
            output['raw']['pred_real'] = pred_real
            output['raw']['pred_fake'] = pred_fake

        # Additional GAN loss on specific regions.
        if self.add_dis_cfg is not None:
            for name in self.add_dis_cfg:
                # Crop corresponding regions in the image according to the
                # crop function.
                add_dis_cfg = self.add_dis_cfg[name]
                file, crop_func = add_dis_cfg.crop_func.split('::')
                file = importlib.import_module(file)
                crop_func = getattr(file, crop_func)
                real_crop = crop_func(self.data_cfg, real_image, label)
                fake_crop = crop_func(self.data_cfg, fake_image, label)
                if self.use_few_shot:
                    ref_crop = crop_func(self.data_cfg, ref_image, label)
                    if ref_crop is not None:
                        real_crop = torch.cat([real_crop, ref_crop], dim=1)
                        fake_crop = torch.cat([fake_crop, ref_crop], dim=1)

                # Feed the crops to the specific discriminator.
                if fake_crop is not None:
                    net_D = getattr(self, 'net_D_' + name)
                    pred_real, pred_fake = \
                        self.discriminate_image(net_D, None,
                                                real_crop, fake_crop)
                else:
                    pred_real = pred_fake = None
                output[name] = dict()
                output[name]['pred_real'] = pred_real
                output[name]['pred_fake'] = pred_fake

        # Temporal loss.
        past_frames, skipped_frames = \
            get_all_skipped_frames(past_frames, [real_image, fake_image],
                                   self.num_scales, self.num_frames_D)
        for scale in range(self.num_scales):
            real_image, fake_image = \
                [skipped_frame[scale] for skipped_frame in skipped_frames]
            pred_real, pred_fake = self.discriminate_video(real_image,
                                                           fake_image, scale)
            output['temporal_%d' % scale] = dict()
            output['temporal_%d' % scale]['pred_real'] = pred_real
            output['temporal_%d' % scale]['pred_fake'] = pred_fake

        return output, past_frames

    def discriminate_image(self, net_D, real_A, real_B, fake_B):
        r"""Discriminate individual images.

        Args:
            net_D (obj): Discriminator network.
            real_A (NxC1xHxW tensor): Input label map.
            real_B (NxC2xHxW tensor): Real image.
            fake_B (NxC2xHxW tensor): Fake image.

        Returns:
            (tuple):
              - pred_real (NxC3xH2xW2 tensor): Output of net_D for real
                images.
              - pred_fake (NxC3xH2xW2 tensor): Output of net_D for fake
                images.
        """
        if real_A is not None:
            real_AB = torch.cat([real_A, real_B], dim=1)
            fake_AB = torch.cat([real_A, fake_B], dim=1)
        else:
            real_AB, fake_AB = real_B, fake_B
        pred_real = net_D.forward(real_AB)
        pred_fake = net_D.forward(fake_AB)
        return pred_real, pred_fake

    def discriminate_video(self, real_B, fake_B, scale):
        r"""Discriminate a sequence of images.

        Args:
            real_B (NxTxCxHxW tensor): Real image sequence.
            fake_B (NxTxCxHxW tensor): Fake image sequence.
            scale (int): Temporal scale index.

        Returns:
            (tuple):
              - pred_real (NxC2xH2xW2 tensor): Output of net_D for real
                images.
              - pred_fake (NxC2xH2xW2 tensor): Output of net_D for fake
                images.
        """
        if real_B is None:
            return None, None
        net_DT = getattr(self, 'net_DT%d' % scale)
        height, width = real_B.shape[-2:]
        real_B = real_B.view(-1, self.num_netDT_input_channels, height, width)
        fake_B = fake_B.view(-1, self.num_netDT_input_channels, height, width)
        pred_real = net_DT.forward(real_B)
        pred_fake = net_DT.forward(fake_B)
        return pred_real, pred_fake


def get_all_skipped_frames(past_frames, new_frames, t_scales, tD):
    r"""Get temporally skipped frames from the input frames.

    Args:
        past_frames (list of tensors): Past real frames / generator outputs.
        new_frames (list of tensors): Current real frame / generated output.
        t_scales (int): Number of temporal scales.
        tD (int): Number of frames as input to the temporal discriminator.

    Returns:
        (tuple):
          - new_past_frames (list of tensors): Past + current frames.
          - skipped_frames (list of tensors): Temporally skipped frames
            using the given t_scales.
    """
    new_past_frames, skipped_frames = [], []
    for past_frame, new_frame in zip(past_frames, new_frames):
        skipped_frame = None
        if t_scales > 0:
            past_frame, skipped_frame = \
                get_skipped_frames(past_frame, new_frame.unsqueeze(1),
                                   t_scales, tD)
        new_past_frames.append(past_frame)
        skipped_frames.append(skipped_frame)
    return new_past_frames, skipped_frames


def get_skipped_frames(all_frames, frame, t_scales, tD):
    r"""Get temporally skipped frames from the input frames.

    Args:
        all_frames (NxTxCxHxW tensor): All past frames.
        frame (Nx1xCxHxW tensor): Current frame.
        t_scales (int): Number of temporal scales.
        tD (int): Number of frames as input to the temporal discriminator.

    Returns:
        (tuple):
          - all_frames (NxTxCxHxW tensor): Past + current frames.
          - skipped_frames (list of NxTxCxHxW tensors): Temporally skipped
            frames.
    """
    all_frames = torch.cat([all_frames.detach(), frame], dim=1) \
        if all_frames is not None else frame
    skipped_frames = [None] * t_scales
    for s in range(t_scales):
        # Temporal stride between selected neighboring frames
        # (e.g., 1, 3, 9, ...).
        t_step = tD ** s
        # Number of time steps spanned by the tD selected frames
        # (e.g., 2, 6, 18, ... for tD=3).
        t_span = t_step * (tD - 1)
        if all_frames.size(1) > t_span:
            skipped_frames[s] = \
                all_frames[:, -(t_span + 1)::t_step].contiguous()

    # Maximum number of past frames we need to keep track of.
    max_num_prev_frames = (tD ** (t_scales - 1)) * (tD - 1)
    # Remove past frames that are older than this number.
    if all_frames.size(1) > max_num_prev_frames:
        all_frames = all_frames[:, -max_num_prev_frames:]
    return all_frames, skipped_frames


class MultiPatchDiscriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Args:
        dis_cfg (obj): Discriminator part of the yaml config file.
        num_input_channels (int): Number of input channels.
    """

    def __init__(self, dis_cfg, num_input_channels):
        super(MultiPatchDiscriminator, self).__init__()
        kernel_size = getattr(dis_cfg, 'kernel_size', 4)
        num_filters = getattr(dis_cfg, 'num_filters', 64)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 3)
        num_layers = getattr(dis_cfg, 'num_layers', 3)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type',
                                   'spectral_norm')
        self.nets_discriminator = []
        for i in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size,
                num_input_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            self.add_module('discriminator_%d' % i, net_discriminator)
            self.nets_discriminator.append(net_discriminator)

    def forward(self, input_x):
        r"""Multi-resolution patch discriminator forward.

        Args:
            input_x (NxCxHxW tensor): Concatenation of images and semantic
                representations.

        Returns:
            (dict):
              - output (list): List of output tensors produced by individual
                patch discriminators.
              - features (list): List of lists of features produced by
                individual patch discriminators.
""" output_list = [] features_list = [] input_downsampled = input_x for name, net_discriminator in self.named_children(): if not name.startswith('discriminator_'): continue output, features = net_discriminator(input_downsampled) output_list.append(output) features_list.append(features) input_downsampled = F.interpolate( input_downsampled, scale_factor=0.5, mode='bilinear', align_corners=True, recompute_scale_factor=True) output_x = dict() output_x['output'] = output_list output_x['features'] = features_list return output_x