# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import importlib
import torch
import torch.nn as nn
import torch.nn.functional as F
from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
from imaginaire.model_utils.fs_vid2vid import get_fg_mask, pick_image
from imaginaire.utils.data import (get_paired_input_image_channel_number,
get_paired_input_label_channel_number)
from imaginaire.utils.misc import get_nested_attr
class Discriminator(nn.Module):
r"""Image and video discriminator constructor.
Args:
dis_cfg (obj): Discriminator part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
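
    The expected layout of ``dis_cfg`` below is inferred from the
    attribute lookups in this file; it is an illustrative sketch, not a
    shipped config:

        image: config of the per-frame MultiPatchDiscriminator.
        temporal: config of the temporal discriminators (num_scales, ...).
        additional_discriminators: optional region discriminators, each
            with a crop_func of the form '<module>::<function>'.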
"""
def __init__(self, dis_cfg, data_cfg):
super().__init__()
self.data_cfg = data_cfg
num_input_channels = get_paired_input_label_channel_number(data_cfg)
if num_input_channels == 0:
num_input_channels = getattr(data_cfg, 'label_channels', 1)
num_img_channels = get_paired_input_image_channel_number(data_cfg)
self.num_frames_D = data_cfg.num_frames_D
self.num_scales = get_nested_attr(dis_cfg, 'temporal.num_scales', 0)
num_netD_input_channels = (num_input_channels + num_img_channels)
self.use_few_shot = 'few_shot' in data_cfg.type
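        # In the few-shot setting a reference (label, image) pair is
        # concatenated to the input in forward(), doubling the number of
        # discriminator input channels.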
if self.use_few_shot:
num_netD_input_channels *= 2
self.net_D = MultiPatchDiscriminator(dis_cfg.image,
num_netD_input_channels)
self.add_dis_cfg = getattr(dis_cfg, 'additional_discriminators', None)
if self.add_dis_cfg is not None:
for name in self.add_dis_cfg:
add_dis_cfg = self.add_dis_cfg[name]
num_ch = num_img_channels * (2 if self.use_few_shot else 1)
setattr(self, 'net_D_' + name,
MultiPatchDiscriminator(add_dis_cfg, num_ch))
# Temporal discriminator.
self.num_netDT_input_channels = num_img_channels * self.num_frames_D
for n in range(self.num_scales):
setattr(self, 'net_DT%d' % n,
MultiPatchDiscriminator(dis_cfg.temporal,
self.num_netDT_input_channels))
self.has_fg = getattr(data_cfg, 'has_foreground', False)
def forward(self, data, net_G_output, past_frames):
r"""Discriminator forward.
Args:
data (dict): Input data.
net_G_output (dict): Generator output.
past_frames (list of tensors): Past real frames / generator outputs.
        Returns:
            (tuple):
              - output (dict): Discriminator outputs keyed by loss type
                ('indv', 'raw', the additional discriminator names, and
                'temporal_%d' for each temporal scale).
              - past_frames (list of tensors): New past frames obtained by
                appending the current outputs.
        """
label, real_image = data['label'], data['image']
# Only operate on the latest output frame.
if label.dim() == 5:
label = label[:, -1]
if self.use_few_shot:
# Pick only one reference image to concat with.
            ref_idx = net_G_output.get('ref_idx', 0)
ref_label = pick_image(data['ref_labels'], ref_idx)
ref_image = pick_image(data['ref_images'], ref_idx)
# Concat references with label map as discriminator input.
label = torch.cat([label, ref_label, ref_image], dim=1)
fake_image = net_G_output['fake_images']
output = dict()
# Individual frame loss.
        pred_real, pred_fake = self.discriminate_image(self.net_D, label,
                                                       real_image, fake_image)
output['indv'] = dict()
output['indv']['pred_real'] = pred_real
output['indv']['pred_fake'] = pred_fake
if 'fake_raw_images' in net_G_output and \
net_G_output['fake_raw_images'] is not None:
# Raw generator output loss.
fake_raw_image = net_G_output['fake_raw_images']
fg_mask = get_fg_mask(data['label'], self.has_fg)
            pred_real, pred_fake = self.discriminate_image(
                self.net_D, label,
                real_image * fg_mask,
                fake_raw_image * fg_mask)
output['raw'] = dict()
output['raw']['pred_real'] = pred_real
output['raw']['pred_fake'] = pred_fake
# Additional GAN loss on specific regions.
if self.add_dis_cfg is not None:
for name in self.add_dis_cfg:
# Crop corresponding regions in the image according to the
# crop function.
add_dis_cfg = self.add_dis_cfg[name]
                module_name, func_name = add_dis_cfg.crop_func.split('::')
                module = importlib.import_module(module_name)
                crop_func = getattr(module, func_name)
real_crop = crop_func(self.data_cfg, real_image, label)
fake_crop = crop_func(self.data_cfg, fake_image, label)
if self.use_few_shot:
ref_crop = crop_func(self.data_cfg, ref_image, label)
if ref_crop is not None:
real_crop = torch.cat([real_crop, ref_crop], dim=1)
fake_crop = torch.cat([fake_crop, ref_crop], dim=1)
# Feed the crops to specific discriminator.
if fake_crop is not None:
net_D = getattr(self, 'net_D_' + name)
                    pred_real, pred_fake = \
                        self.discriminate_image(net_D, None,
                                                real_crop, fake_crop)
else:
pred_real = pred_fake = None
output[name] = dict()
output[name]['pred_real'] = pred_real
output[name]['pred_fake'] = pred_fake
# Temporal loss.
past_frames, skipped_frames = \
get_all_skipped_frames(past_frames, [real_image, fake_image],
self.num_scales, self.num_frames_D)
for scale in range(self.num_scales):
real_image, fake_image = \
[skipped_frame[scale] for skipped_frame in skipped_frames]
pred_real, pred_fake = self.discriminate_video(real_image,
fake_image, scale)
output['temporal_%d' % scale] = dict()
output['temporal_%d' % scale]['pred_real'] = pred_real
output['temporal_%d' % scale]['pred_fake'] = pred_fake
return output, past_frames
    def discriminate_image(self, net_D, real_A, real_B, fake_B):
r"""Discriminate individual images.
Args:
net_D (obj): Discriminator network.
real_A (NxC1xHxW tensor): Input label map.
real_B (NxC2xHxW tensor): Real image.
fake_B (NxC2xHxW tensor): Fake image.
        Returns:
            (tuple):
              - pred_real (dict): Output of net_D for the real images.
              - pred_fake (dict): Output of net_D for the fake images.
        """
if real_A is not None:
real_AB = torch.cat([real_A, real_B], dim=1)
fake_AB = torch.cat([real_A, fake_B], dim=1)
else:
real_AB, fake_AB = real_B, fake_B
        pred_real = net_D(real_AB)
        pred_fake = net_D(fake_AB)
return pred_real, pred_fake
def discriminate_video(self, real_B, fake_B, scale):
r"""Discriminate a sequence of images.
Args:
            real_B (NxTxCxHxW tensor): Real image sequence.
            fake_B (NxTxCxHxW tensor): Fake image sequence.
            scale (int): Temporal scale index.
        Returns:
            (tuple):
              - pred_real (dict): Output of net_DT for the real sequence.
              - pred_fake (dict): Output of net_DT for the fake sequence.
        """
if real_B is None:
return None, None
net_DT = getattr(self, 'net_DT%d' % scale)
height, width = real_B.shape[-2:]
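        # Fold time into channels: an N x T x C x H x W window becomes
        # N x (T*C) x H x W so the 2D patch discriminator can score the
        # whole window at once.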
real_B = real_B.view(-1, self.num_netDT_input_channels, height, width)
fake_B = fake_B.view(-1, self.num_netDT_input_channels, height, width)
        pred_real = net_DT(real_B)
        pred_fake = net_DT(fake_B)
return pred_real, pred_fake
def get_all_skipped_frames(past_frames, new_frames, t_scales, tD):
r"""Get temporally skipped frames from the input frames.
Args:
past_frames (list of tensors): Past real frames / generator outputs.
new_frames (list of tensors): Current real frame / generated output.
        t_scales (int): Number of temporal scales.
tD (int): Number of frames as input to the temporal discriminator.
Returns:
(tuple):
- new_past_frames (list of tensors): Past + current frames.
- skipped_frames (list of tensors): Temporally skipped frames using
the given t_scales.
"""
new_past_frames, skipped_frames = [], []
for past_frame, new_frame in zip(past_frames, new_frames):
skipped_frame = None
if t_scales > 0:
past_frame, skipped_frame = \
get_skipped_frames(past_frame, new_frame.unsqueeze(1),
t_scales, tD)
new_past_frames.append(past_frame)
skipped_frames.append(skipped_frame)
return new_past_frames, skipped_frames
def get_skipped_frames(all_frames, frame, t_scales, tD):
r"""Get temporally skipped frames from the input frames.
Args:
all_frames (NxTxCxHxW tensor): All past frames.
frame (Nx1xCxHxW tensor): Current frame.
        t_scales (int): Number of temporal scales.
tD (int): Number of frames as input to the temporal discriminator.
Returns:
(tuple):
- all_frames (NxTxCxHxW tensor): Past + current frames.
- skipped_frames (list of NxTxCxHxW tensors): Temporally skipped
frames.
"""
all_frames = torch.cat([all_frames.detach(), frame], dim=1) \
if all_frames is not None else frame
skipped_frames = [None] * t_scales
for s in range(t_scales):
        # Stride between sampled frames at this scale (e.g. 1, 3, 9, ...).
        t_step = tD ** s
        # Number of frames the tD sampled frames span before skipping
        # (e.g. 2, 6, 18, ... for tD=3).
        t_span = t_step * (tD - 1)
if all_frames.size(1) > t_span:
skipped_frames[s] = all_frames[:, -(t_span+1)::t_step].contiguous()
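        # Example with tD=3, s=1: t_step=3 and t_span=6, so the slice
        # [-7::3] picks the frames at offsets -7, -4 and -1 from the end,
        # i.e. tD frames spaced t_step apart.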
# Maximum number of past frames we need to keep track of.
max_num_prev_frames = (tD ** (t_scales-1)) * (tD-1)
# Remove past frames that are older than this number.
    if all_frames.size(1) > max_num_prev_frames:
all_frames = all_frames[:, -max_num_prev_frames:]
return all_frames, skipped_frames
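# Illustrative behavior of get_skipped_frames (assumed values): with
# tD=3 and t_scales=2 the buffer keeps at most (3 ** 1) * (3 - 1) = 6
# past frames; skipped_frames[0] is then the last 3 consecutive frames,
# and once 7 frames have been seen skipped_frames[1] is 3 frames spaced
# 3 apart.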
class MultiPatchDiscriminator(nn.Module):
r"""Multi-resolution patch discriminator.
Args:
dis_cfg (obj): Discriminator part of the yaml config file.
num_input_channels (int): Number of input channels.
"""
def __init__(self, dis_cfg, num_input_channels):
        super().__init__()
kernel_size = getattr(dis_cfg, 'kernel_size', 4)
num_filters = getattr(dis_cfg, 'num_filters', 64)
max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
num_discriminators = getattr(dis_cfg, 'num_discriminators', 3)
num_layers = getattr(dis_cfg, 'num_layers', 3)
activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
weight_norm_type = getattr(dis_cfg, 'weight_norm_type',
'spectral_norm')
self.nets_discriminator = []
for i in range(num_discriminators):
net_discriminator = NLayerPatchDiscriminator(
kernel_size,
num_input_channels,
num_filters,
num_layers,
max_num_filters,
activation_norm_type,
weight_norm_type)
self.add_module('discriminator_%d' % i, net_discriminator)
self.nets_discriminator.append(net_discriminator)
def forward(self, input_x):
r"""Multi-resolution patch discriminator forward.
Args:
            input_x (NxCxHxW tensor): Concatenation of images and
                semantic representations.
Returns:
(dict):
- output (list): list of output tensors produced by individual
patch discriminators.
- features (list): list of lists of features produced by
individual patch discriminators.
"""
output_list = []
features_list = []
input_downsampled = input_x
        for net_discriminator in self.nets_discriminator:
output, features = net_discriminator(input_downsampled)
output_list.append(output)
features_list.append(features)
input_downsampled = F.interpolate(
input_downsampled, scale_factor=0.5, mode='bilinear',
align_corners=True, recompute_scale_factor=True)
output_x = dict()
output_x['output'] = output_list
output_x['features'] = features_list
return output_x
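

if __name__ == '__main__':
    # Minimal smoke test for MultiPatchDiscriminator; a sketch under
    # assumed config values, not settings from any shipped yaml file.
    from types import SimpleNamespace
    test_cfg = SimpleNamespace(kernel_size=4, num_filters=32,
                               max_num_filters=256, num_discriminators=2,
                               num_layers=3, activation_norm_type='none',
                               weight_norm_type='spectral_norm')
    net_D = MultiPatchDiscriminator(test_cfg, num_input_channels=6)
    result = net_D(torch.randn(2, 6, 128, 128))
    # Expect one patch-prediction tensor and one feature list per
    # resolution.
    assert len(result['output']) == 2
    assert len(result['features']) == 2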