# sat3density/imaginaire/discriminators/multires_patch_pano.py
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# Copyright (C) 2020 NVIDIA Corporation. All rights reserved
import functools
import warnings
import numpy as np
import torch
import torch.nn as nn
from imaginaire.layers import Conv2dBlock
from imaginaire.utils.data import (get_paired_input_image_channel_number,
get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print
from model.sample import Equirectangular
class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config
            file.
    """
    def __init__(self, dis_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        # Number of channels in the real/fake images fed to the
        # discriminator (plus the condition channels when conditioned).
        num_input_channels = getattr(dis_cfg, 'input_channels', 3)
        # Build the discriminator from the config, with defaults.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)
        # When truthy, the discriminator input is the image concatenated
        # with the generator inputs along the channel dimension.
        self.condition = getattr(dis_cfg, 'condition', None)
        self.model = MultiResPatchDiscriminator(num_discriminators,
                                                kernel_size,
                                                num_input_channels,
                                                num_filters,
                                                num_layers,
                                                max_num_filters,
                                                activation_norm_type,
                                                weight_norm_type)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')

    def forward(self, data, net_G_output, real=True):
        r"""Multi-resolution patch discriminator forward.

        Args:
            data (N x C1 x H x W tensor) : Ground truth images.
            net_G_output (dict):
                fake_images (N x C1 x H x W tensor) : Fake images.
            real (bool): If ``True``, also classifies real images. Otherwise it
                only classifies generated images to save computation during the
                generator update.
        Returns:
            (dict):
              - fake_outputs (list): list of output tensors produced by
                individual patch discriminators for fake images.
              - fake_features (list): list of lists of features produced by
                individual patch discriminators for fake images.
              - real_outputs (list, only when ``real``): same for real images.
              - real_features (list, only when ``real``): same for real images.
        """
        output_x = dict()
        if self.condition:
            # Conditional GAN: classify (fake image, condition) pairs.
            fake_input_x = torch.cat(
                [net_G_output['pred'], net_G_output['generator_inputs']], dim=1)
        else:
            fake_input_x = net_G_output['pred']
        output_x['fake_outputs'], output_x['fake_features'], _ = \
            self.model.forward(fake_input_x)
        if real:
            if self.condition:
                # BUG FIX: this branch previously concatenated the *fake*
                # images (net_G_output['pred']) with the condition, so the
                # conditional discriminator never saw real images. Mirror the
                # fake branch using the ground-truth `data` instead.
                real_input_x = torch.cat(
                    [data, net_G_output['generator_inputs']], dim=1)
            else:
                real_input_x = data
            output_x['real_outputs'], output_x['real_features'], _ = \
                self.model.forward(real_input_x)
        return output_x
class MultiResPatchDiscriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Runs several independent N-layer patch discriminators, each on a
    differently-transformed version of the input (see ``forward`` for the
    exact per-scale schedule).

    Args:
        num_discriminators (int): Num. of discriminators (one per scale).
        kernel_size (int): Convolution kernel size.
        num_image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
    """
    def __init__(self,
                 num_discriminators=3,
                 kernel_size=3,
                 num_image_channels=3,
                 num_filters=64,
                 num_layers=4,
                 max_num_filters=512,
                 activation_norm_type='',
                 weight_norm_type='',
                 **kwargs):
        super().__init__()
        # 'type' and 'patch_wise' are consumed elsewhere in the config
        # pipeline; warn about any other unrecognized argument.
        for key in kwargs:
            if key != 'type' and key != 'patch_wise':
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))
        # All scales share the same architecture but have independent weights.
        self.discriminators = nn.ModuleList()
        for i in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size,
                num_image_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            self.discriminators.append(net_discriminator)
        print('Done with the Multi-resolution patch '
              'discriminator initialization.')
        # Equirectangular sampler — presumably produces a grid that extracts
        # 128x128 perspective views (FoV 100°, elevation in [-40°, 40°]) from
        # a panorama; see model.sample.Equirectangular. TODO confirm.
        self.e = Equirectangular(theta=[-40., 40.],width = 128, height = 128,FovX = 100)
    def forward(self, input_x):
        r"""Multi-resolution patch discriminator forward.

        Per-scale input schedule:
          scale 0: the input bilinearly downsampled by 2x;
          scale 1: the *original* input resampled through the equirectangular
                   grid (scaled by 0.99, presumably to keep values strictly
                   inside the valid range — TODO confirm intent);
          scale 2: the scale-1 input downsampled by a further 2x;
          any later scale reuses the scale-2 input unchanged.

        Args:
            input_x (tensor) : Input images.
        Returns:
            (tuple):
              - output_list (list): list of output tensors produced by
                individual patch discriminators.
              - features_list (list): list of lists of features produced by
                individual patch discriminators.
              - input_list (list): list of the per-scale input images.
        """
        input_list = []
        output_list = []
        features_list = []
        # First scale sees the input at half resolution.
        input_N = nn.functional.interpolate(
            input_x, scale_factor=0.5, mode='bilinear',
            align_corners=True, recompute_scale_factor=True)
        # Sampling grid for grid_sample; computed once per forward.
        equ= self.e(input_x)
        for i, net_discriminator in enumerate(self.discriminators):
            input_list.append(input_N)
            output, features = net_discriminator(input_N)
            output_list.append(output)
            features_list.append(features)
            # Prepare the NEXT scale's input (note: built from input_x for
            # i == 0, from the previous input_N for i == 1).
            if i == 0:
                input_N = torch.nn.functional.grid_sample(input_x, equ.float(), align_corners = True)*0.99
            elif i == 1:
                input_N = nn.functional.interpolate(
                    input_N, scale_factor=0.5, mode='bilinear',
                    align_corners=True, recompute_scale_factor=True)
        return output_list, features_list, input_list
class NLayerPatchDiscriminator(nn.Module):
    r"""Patch Discriminator constructor.

    A stack of stride-2 Conv2dBlocks followed by a 1-channel projection,
    producing a patch-wise realness score map.

    Args:
        kernel_size (int): Convolution kernel size.
        num_input_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        num_layers (int): Num. of layers for the patch discriminator.
        max_num_filters (int): Maximum num. of filters in a layer.
        activation_norm_type (str): batch_norm/instance_norm/none/....
        weight_norm_type (str): none/spectral_norm/weight_norm
    """
    def __init__(self,
                 kernel_size,
                 num_input_channels,
                 num_filters,
                 num_layers,
                 max_num_filters,
                 activation_norm_type,
                 weight_norm_type):
        super(NLayerPatchDiscriminator, self).__init__()
        self.num_layers = num_layers
        # 'same'-style padding for the given (odd) kernel size.
        pad = (kernel_size - 1) // 2
        conv_block = functools.partial(
            Conv2dBlock,
            kernel_size=kernel_size,
            padding=pad,
            weight_norm_type=weight_norm_type,
            activation_norm_type=activation_norm_type,
            nonlinearity='leakyrelu',
            order='CNA')
        # Stem: image channels -> base filters, downsampling by 2.
        blocks = [[conv_block(num_input_channels, num_filters, stride=2)]]
        in_ch = num_filters
        for depth in range(num_layers):
            out_ch = min(in_ch * 2, max_num_filters)
            # Keep downsampling except in the final intermediate layer.
            down = 2 if depth < num_layers - 1 else 1
            blocks.append([conv_block(in_ch, out_ch, stride=down)])
            in_ch = out_ch
        # Head: 3x3 projection to a single-channel score map
        # (no activation norm / nonlinearity).
        blocks.append([Conv2dBlock(in_ch, 1,
                                   3, 1,
                                   pad,
                                   weight_norm_type=weight_norm_type)])
        # Register as attributes 'layer0'..'layerN' so state_dict keys match
        # the original layout.
        for depth, block in enumerate(blocks):
            setattr(self, 'layer' + str(depth), nn.Sequential(*block))

    def forward(self, input_x):
        r"""Patch Discriminator forward.

        Args:
            input_x (N x C x H1 x W2 tensor): Concatenation of images and
                semantic representations.
        Returns:
            (tuple):
              - output (N x 1 x H2 x W2 tensor): Discriminator output value.
                Before the sigmoid when using NSGAN.
              - features (list): lists of tensors of the intermediate
                activations.
        """
        activations = [input_x]
        # num_layers intermediate layers + stem + head.
        for depth in range(self.num_layers + 2):
            stage = getattr(self, 'layer' + str(depth))
            activations.append(stage(activations[-1]))
        # Final entry is the score map; everything between the raw input and
        # the score map is returned as feature-matching activations.
        return activations[-1], activations[1:-1]