# Spaces: Runtime error
# (Web-page status text captured along with this file; not part of the module.)
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved
import functools
import warnings

import numpy as np
import torch
import torch.nn as nn

from imaginaire.layers import Conv2dBlock
from imaginaire.utils.data import (get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print
class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Reads its hyper-parameters from ``dis_cfg`` (falling back to the
    defaults listed below) and wraps a ``MultiResPatchDiscriminator`` whose
    input channel count is the number of image channels plus the number of
    label channels reported for ``data_cfg``.

    Config fields read from ``dis_cfg`` (default in parentheses):
    ``kernel_size`` (3), ``num_filters`` (128), ``max_num_filters`` (512),
    ``num_discriminators`` (2), ``num_layers`` (5),
    ``activation_norm_type`` ('none'), ``weight_norm_type`` ('spectral').

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config
            file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super().__init__()
        print('Multi-resolution patch discriminator initialization.')
        # We assume the first datum is the ground truth image.
        image_channels = get_paired_input_image_channel_number(data_cfg)
        # Calculate number of channels in the input label.
        num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)

        # The discriminator sees labels and images stacked along channels.
        num_input_channels = image_channels + num_labels
        self.model = MultiResPatchDiscriminator(num_discriminators,
                                                kernel_size,
                                                num_input_channels,
                                                num_filters,
                                                num_layers,
                                                max_num_filters,
                                                activation_norm_type,
                                                weight_norm_type)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')

    @staticmethod
    def _with_label(data, images):
        r"""Prepend ``data['label']`` to ``images`` along dim 1, if present."""
        if 'label' in data:
            return torch.cat((data['label'], images), 1)
        return images

    def forward(self, data, net_G_output, real=True):
        r"""Discriminator forward.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor) : Ground truth images. Only
                read when ``real`` is ``True``.
              - label (N x C2 x H x W tensor) : Semantic representations.
                Optional; when present it is concatenated in front of the
                images along the channel dimension.
              Other keys are ignored.
            net_G_output (dict):
              - fake_images (N x C1 x H x W tensor) : Fake images.
            real (bool): If ``True``, also classifies real images. Otherwise
                it only classifies generated images to save computation
                during the generator update.
        Returns:
            (dict):
              - fake_outputs (list): list of output tensors produced by
                individual patch discriminators for fake images.
              - fake_features (list): list of lists of features produced by
                individual patch discriminators for fake images.
              - real_outputs (list): list of output tensors produced by
                individual patch discriminators for real images. Present
                only when ``real`` is ``True``.
              - real_features (list): list of lists of features produced by
                individual patch discriminators for real images. Present
                only when ``real`` is ``True``.
        """
        output_x = dict()

        fake_input_x = self._with_label(data, net_G_output['fake_images'])
        # Call the module itself (not ``.forward``) so that any hooks
        # registered on ``self.model`` are honoured.
        output_x['fake_outputs'], output_x['fake_features'], _ = \
            self.model(fake_input_x)

        if real:
            real_input_x = self._with_label(data, data['images'])
            output_x['real_outputs'], output_x['real_features'], _ = \
                self.model(real_input_x)
        return output_x
class MultiResPatchDiscriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Holds ``num_discriminators`` independently constructed
    ``NLayerPatchDiscriminator`` instances. The first one sees the input as
    given; each following one sees the previous input bilinearly
    downsampled by a factor of 0.5.

    Args:
        num_discriminators (int): Num. of discriminators (one per scale).
        kernel_size (int): Convolution kernel size.
        num_image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
        **kwargs: Ignored. A warning is issued for every key other than
            ``type`` and ``patch_wise``.
    """

    def __init__(self,
                 num_discriminators=3,
                 kernel_size=3,
                 num_image_channels=3,
                 num_filters=64,
                 num_layers=4,
                 max_num_filters=512,
                 activation_norm_type='',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key not in ('type', 'patch_wise'):
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

        # One separately parameterised patch discriminator per scale.
        self.discriminators = nn.ModuleList([
            NLayerPatchDiscriminator(
                kernel_size,
                num_image_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            for _ in range(num_discriminators)
        ])
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')

    def forward(self, input_x):
        r"""Multi-resolution patch discriminator forward.

        Args:
            input_x (tensor) : Input images.
        Returns:
            (tuple):
              - output_list (list): list of output tensors produced by
                individual patch discriminators.
              - features_list (list): list of lists of features produced by
                individual patch discriminators.
              - input_list (list): list of downsampled input images, one per
                discriminator (the first entry is ``input_x`` itself).
        """
        input_list = []
        output_list = []
        features_list = []
        input_downsampled = input_x
        last_index = len(self.discriminators) - 1
        for index, net_discriminator in enumerate(self.discriminators):
            input_list.append(input_downsampled)
            output, features = net_discriminator(input_downsampled)
            output_list.append(output)
            features_list.append(features)
            # Only downsample when another scale will consume the result;
            # a trailing interpolation is wasted work and fails once the
            # spatial size would drop to zero.
            if index < last_index:
                input_downsampled = nn.functional.interpolate(
                    input_downsampled, scale_factor=0.5, mode='bilinear',
                    align_corners=True, recompute_scale_factor=True)
        return output_list, features_list, input_list
class WeightSharedMultiResPatchDiscriminator(nn.Module):
    r"""Multi-resolution patch discriminator with shared weights.

    A single ``NLayerPatchDiscriminator`` is applied ``num_discriminators``
    times: first to the input as given, then to successive bilinear
    downsamplings of it by a factor of 0.5.

    Args:
        num_discriminators (int): Num. of discriminators (one per scale).
        kernel_size (int): Convolution kernel size.
        num_image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
        **kwargs: Ignored. A warning is issued for every key other than
            ``type`` and ``patch_wise``.
    """

    def __init__(self,
                 num_discriminators=3,
                 kernel_size=3,
                 num_image_channels=3,
                 num_filters=64,
                 num_layers=4,
                 max_num_filters=512,
                 activation_norm_type='',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key not in ('type', 'patch_wise'):
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))
        self.num_discriminators = num_discriminators
        # The same network (hence the same weights) is reused at every scale.
        self.discriminator = NLayerPatchDiscriminator(
            kernel_size,
            num_image_channels,
            num_filters,
            num_layers,
            max_num_filters,
            activation_norm_type,
            weight_norm_type)
        print('Done with the Weight-Shared Multi-resolution patch '
              'discriminator initialization.')

    def forward(self, input_x):
        r"""Multi-resolution patch discriminator forward.

        Args:
            input_x (tensor) : Input images.
        Returns:
            (tuple):
              - output_list (list): list of output tensors produced by
                the shared patch discriminator, one per scale.
              - features_list (list): list of lists of features produced by
                the shared patch discriminator, one per scale.
              - input_list (list): list of downsampled input images, one per
                scale (the first entry is ``input_x`` itself).
        """
        input_list = []
        output_list = []
        features_list = []
        input_downsampled = input_x
        last_index = self.num_discriminators - 1
        for index in range(self.num_discriminators):
            input_list.append(input_downsampled)
            output, features = self.discriminator(input_downsampled)
            output_list.append(output)
            features_list.append(features)
            # Only downsample when another scale will consume the result;
            # a trailing interpolation is wasted work and fails once the
            # spatial size would drop to zero.
            if index < last_index:
                # NOTE(review): MultiResPatchDiscriminator additionally
                # passes recompute_scale_factor=True here - confirm whether
                # the difference is intentional.
                input_downsampled = nn.functional.interpolate(
                    input_downsampled, scale_factor=0.5, mode='bilinear',
                    align_corners=True)
        return output_list, features_list, input_list
class NLayerPatchDiscriminator(nn.Module):
    r"""Patch Discriminator constructor.

    Builds ``num_layers + 2`` blocks, registered as attributes ``layer0``
    ... ``layer{num_layers + 1}``: one stride-2 input block, ``num_layers``
    blocks that double the filter count (capped at ``max_num_filters``,
    stride 2 except for the last one, which uses stride 1), and a final
    block with a single output channel.

    Args:
        kernel_size (int): Convolution kernel size.
        num_input_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
    """

    def __init__(self,
                 kernel_size,
                 num_input_channels,
                 num_filters,
                 num_layers,
                 max_num_filters,
                 activation_norm_type,
                 weight_norm_type):
        super().__init__()
        self.num_layers = num_layers
        padding = int(np.floor((kernel_size - 1.0) / 2))
        nonlinearity = 'leakyrelu'
        base_conv2d_block = functools.partial(
            Conv2dBlock,
            kernel_size=kernel_size,
            padding=padding,
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            nonlinearity=nonlinearity,
            order='CNA')

        # Input block.
        blocks = [base_conv2d_block(num_input_channels, num_filters,
                                    stride=2)]
        # Intermediate blocks: double the width up to the cap; only the
        # last of them keeps the spatial resolution (stride 1).
        for n in range(num_layers):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            stride = 2 if n < (num_layers - 1) else 1
            blocks.append(base_conv2d_block(num_filters_prev, num_filters,
                                            stride=stride))
        # Output block with one channel and no nonlinearity/activation norm
        # arguments.
        # NOTE(review): the positional 3 and 1 are presumably kernel_size
        # and stride; the kernel is fixed at 3 while ``padding`` is derived
        # from ``kernel_size`` - confirm this is intended for
        # kernel_size != 3.
        blocks.append(Conv2dBlock(num_filters, 1,
                                  3, 1,
                                  padding,
                                  weight_norm_type=weight_norm_type))

        # Each block is wrapped in its own nn.Sequential and registered as
        # ``layer<n>`` so the attribute (and parameter) names stay stable.
        for n, block in enumerate(blocks):
            setattr(self, 'layer' + str(n), nn.Sequential(block))

    def forward(self, input_x):
        r"""Patch Discriminator forward.

        Args:
            input_x (N x C x H1 x W1 tensor): Concatenation of images and
                semantic representations.
        Returns:
            (tuple):
              - output (N x 1 x H2 x W2 tensor): Discriminator output value.
                Before the sigmoid when using NSGAN.
              - features (list): lists of tensors of the intermediate
                activations (outputs of every layer except the last).
        """
        features = []
        x = input_x
        for n in range(self.num_layers + 2):
            x = getattr(self, 'layer' + str(n))(x)
            features.append(x)
        # The last activation is the prediction, not a feature.
        output = features.pop()
        return output, features