# venite's picture
# initial
# f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved
import functools
import warnings
import numpy as np
import torch
import torch.nn as nn
from imaginaire.layers import Conv2dBlock
from imaginaire.utils.data import (get_paired_input_image_channel_number,
get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print
class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Reads its hyper-parameters from ``dis_cfg`` (falling back to defaults)
    and wraps a :class:`MultiResPatchDiscriminator` stored in ``self.model``.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config
            file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator. Every option is optional in the config;
        # the second argument of getattr is the fallback value.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)

        # The discriminator input is the label map concatenated with the
        # image along the channel axis (see forward()).
        num_input_channels = image_channels + num_labels
        self.model = MultiResPatchDiscriminator(num_discriminators,
                                                kernel_size,
                                                num_input_channels,
                                                num_filters,
                                                num_layers,
                                                max_num_filters,
                                                activation_norm_type,
                                                weight_norm_type)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')

    def forward(self, data, net_G_output, real=True):
        r"""Classify generated images and, optionally, real images.

        Args:
            data (dict): Input batch. Keys read here:
              - images (N x C1 x H x W tensor) : Ground truth images. Only
                read when ``real`` is ``True``.
              - label (N x C2 x H x W tensor) : Semantic representations.
                Optional; when present it is concatenated in front of the
                image along the channel dimension.
            net_G_output (dict):
              - fake_images (N x C1 x H x W tensor) : Fake images.
            real (bool): If ``True``, also classifies real images. Otherwise
                it only classifies generated images to save computation
                during the generator update.
        Returns:
            (dict):
              - fake_outputs (list): list of output tensors produced by
                individual patch discriminators for fake images.
              - fake_features (list): list of lists of features produced by
                individual patch discriminators for fake images.
              - real_outputs (list): list of output tensors produced by
                individual patch discriminators for real images. Only
                present when ``real`` is ``True``.
              - real_features (list): list of lists of features produced by
                individual patch discriminators for real images. Only
                present when ``real`` is ``True``.
        """
        output_x = dict()
        has_label = 'label' in data

        if has_label:
            fake_input_x = torch.cat(
                (data['label'], net_G_output['fake_images']), 1)
        else:
            fake_input_x = net_G_output['fake_images']
        # Call the module itself rather than ``.forward`` so that hooks
        # registered on ``self.model`` are executed.
        output_x['fake_outputs'], output_x['fake_features'], _ = \
            self.model(fake_input_x)

        if real:
            if has_label:
                real_input_x = torch.cat(
                    (data['label'], data['images']), 1)
            else:
                real_input_x = data['images']
            output_x['real_outputs'], output_x['real_features'], _ = \
                self.model(real_input_x)
        return output_x
class MultiResPatchDiscriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Holds one independent :class:`NLayerPatchDiscriminator` per scale; each
    successive discriminator sees the input downsampled by a factor of 2.

    Args:
        num_discriminators (int): Num. of discriminators (one per scale).
        kernel_size (int): Convolution kernel size.
        num_image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
        **kwargs: Ignored. A warning is issued for every key other than
            ``type`` and ``patch_wise``.
    """

    def __init__(self,
                 num_discriminators=3,
                 kernel_size=3,
                 num_image_channels=3,
                 num_filters=64,
                 num_layers=4,
                 max_num_filters=512,
                 activation_norm_type='',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key not in ('type', 'patch_wise'):
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

        self.discriminators = nn.ModuleList()
        for _ in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size,
                num_image_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            self.discriminators.append(net_discriminator)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')

    def forward(self, input_x):
        r"""Multi-resolution patch discriminator forward.

        Args:
            input_x (tensor) : Input images.
        Returns:
            (tuple):
              - output_list (list): list of output tensors produced by
                individual patch discriminators.
              - features_list (list): list of lists of features produced by
                individual patch discriminators.
              - input_list (list): list of downsampled input images, one
                per discriminator (the first entry is ``input_x`` itself).
        """
        input_list = []
        output_list = []
        features_list = []
        input_downsampled = input_x
        last_scale = len(self.discriminators) - 1
        for scale, net_discriminator in enumerate(self.discriminators):
            input_list.append(input_downsampled)
            output, features = net_discriminator(input_downsampled)
            output_list.append(output)
            features_list.append(features)
            # Only downsample when another discriminator will consume the
            # result; a trailing downsample would be discarded and would
            # request an empty output once the coarsest scale is 1 pixel.
            if scale < last_scale:
                input_downsampled = nn.functional.interpolate(
                    input_downsampled, scale_factor=0.5, mode='bilinear',
                    align_corners=True, recompute_scale_factor=True)
        return output_list, features_list, input_list
class WeightSharedMultiResPatchDiscriminator(nn.Module):
    r"""Multi-resolution patch discriminator with shared weights.

    A single :class:`NLayerPatchDiscriminator` is applied at every scale;
    each successive scale is the input downsampled by a factor of 2.

    Args:
        num_discriminators (int): Num. of discriminators (one per scale).
        kernel_size (int): Convolution kernel size.
        num_image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
        **kwargs: Ignored. A warning is issued for every key other than
            ``type`` and ``patch_wise``.
    """

    def __init__(self,
                 num_discriminators=3,
                 kernel_size=3,
                 num_image_channels=3,
                 num_filters=64,
                 num_layers=4,
                 max_num_filters=512,
                 activation_norm_type='',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key not in ('type', 'patch_wise'):
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

        # Number of scales the shared discriminator is evaluated at.
        self.num_discriminators = num_discriminators
        self.discriminator = NLayerPatchDiscriminator(
            kernel_size,
            num_image_channels,
            num_filters,
            num_layers,
            max_num_filters,
            activation_norm_type,
            weight_norm_type)
        print('Done with the Weight-Shared Multi-resolution patch '
              'discriminator initialization.')

    def forward(self, input_x):
        r"""Multi-resolution patch discriminator forward.

        Args:
            input_x (tensor) : Input images.
        Returns:
            (tuple):
              - output_list (list): list of output tensors produced by
                the shared patch discriminator, one per scale.
              - features_list (list): list of lists of features produced by
                the shared patch discriminator, one list per scale.
              - input_list (list): list of downsampled input images, one
                per scale (the first entry is ``input_x`` itself).
        """
        input_list = []
        output_list = []
        features_list = []
        input_downsampled = input_x
        last_scale = self.num_discriminators - 1
        for scale in range(self.num_discriminators):
            input_list.append(input_downsampled)
            output, features = self.discriminator(input_downsampled)
            output_list.append(output)
            features_list.append(features)
            # Only downsample when another scale will consume the result;
            # a trailing downsample would be discarded and would request an
            # empty output once the coarsest scale is 1 pixel.
            if scale < last_scale:
                # Same interpolate arguments as
                # MultiResPatchDiscriminator.forward.
                input_downsampled = nn.functional.interpolate(
                    input_downsampled, scale_factor=0.5, mode='bilinear',
                    align_corners=True, recompute_scale_factor=True)
        return output_list, features_list, input_list
class NLayerPatchDiscriminator(nn.Module):
    r"""Patch Discriminator constructor.

    Builds ``num_layers + 2`` convolution blocks and registers them as the
    attributes ``layer0`` ... ``layer{num_layers + 1}``.

    Args:
        kernel_size (int): Convolution kernel size.
        num_input_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
    """

    def __init__(self,
                 kernel_size,
                 num_input_channels,
                 num_filters,
                 num_layers,
                 max_num_filters,
                 activation_norm_type,
                 weight_norm_type):
        super(NLayerPatchDiscriminator, self).__init__()
        self.num_layers = num_layers
        padding = int(np.floor((kernel_size - 1.0) / 2))

        # Factory for the conv blocks that share every setting except the
        # channel counts and the stride.
        make_block = functools.partial(
            Conv2dBlock,
            kernel_size=kernel_size,
            padding=padding,
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            nonlinearity='leakyrelu',
            order='CNA')

        # Stem block: input channels -> base filter count, stride 2.
        blocks = [make_block(num_input_channels, num_filters, stride=2)]

        # Intermediate blocks: the filter count doubles each time, capped
        # at max_num_filters; only the last of them uses stride 1.
        in_filters = num_filters
        for idx in range(num_layers):
            out_filters = min(in_filters * 2, max_num_filters)
            stride = 1 if idx >= num_layers - 1 else 2
            blocks.append(make_block(in_filters, out_filters, stride=stride))
            in_filters = out_filters

        # Output block with a single output channel. Positional arguments
        # are presumably (in, out, kernel_size, stride, padding) -- confirm
        # against Conv2dBlock.
        blocks.append(Conv2dBlock(in_filters, 1,
                                  3, 1,
                                  padding,
                                  weight_norm_type=weight_norm_type))

        # Each block lives in its own nn.Sequential under 'layer<idx>'.
        for idx, block in enumerate(blocks):
            setattr(self, 'layer' + str(idx), nn.Sequential(block))

    def forward(self, input_x):
        r"""Patch Discriminator forward.

        Args:
            input_x (N x C x H1 x W1 tensor): Concatenation of images and
                semantic representations.
        Returns:
            (tuple):
              - output (N x 1 x H2 x W2 tensor): Discriminator output value.
                Before the sigmoid when using NSGAN.
              - features (list): lists of tensors of the intermediate
                activations, i.e. the result of every layer except the
                last one, in layer order.
        """
        activations = []
        current = input_x
        for idx in range(self.num_layers + 2):
            current = getattr(self, 'layer' + str(idx))(current)
            activations.append(current)
        # The last activation is the prediction; everything before it is
        # returned as intermediate features.
        output = activations[-1]
        features = activations[:-1]
        return output, features