# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import warnings

import torch
import torch.nn as nn

from imaginaire.layers import Conv2dBlock, Res2dBlock
from imaginaire.third_party.upfirdn2d import BlurDownsample


class ResDiscriminator(nn.Module):
    r"""Global residual discriminator.

    The network is a first convolution, followed by ``num_layers`` residual
    blocks (each followed by a 2x downsampling step), followed by an
    aggregation layer and a linear classifier that produces one logit per
    sample.

    Args:
        image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        max_num_filters (int): Maximum num. of filters in a layer.
        first_kernel_size (int): Kernel size in the first layer.
        num_layers (int): Num. of layers in discriminator.
        padding_mode (str): Padding mode.
        activation_norm_type (str): Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        aggregation (str): Method to aggregate features across different
            locations in the final layer. ``'conv'``, or ``'pool'``.
        order (str): Order of operations in the residual link.
        anti_aliased (bool): If ``True``, uses anti-aliased pooling.

    Raises:
        ValueError: If ``aggregation`` is neither ``'conv'`` nor ``'pool'``.
    """

    def __init__(self,
                 image_channels=3,
                 num_filters=64,
                 max_num_filters=512,
                 first_kernel_size=1,
                 num_layers=4,
                 padding_mode='zeros',
                 activation_norm_type='',
                 weight_norm_type='',
                 aggregation='conv',
                 order='pre_act',
                 anti_aliased=False,
                 **kwargs):
        super().__init__()
        # 'type' and 'patch_wise' are tolerated silently (presumably passed
        # through from a shared discriminator config); anything else is
        # reported because it has no effect here.
        for key in kwargs:
            if key != 'type' and key != 'patch_wise':
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity='leakyrelu')

        # 'Same'-style padding for odd kernel sizes.
        first_padding = (first_kernel_size - 1) // 2
        model = [Conv2dBlock(image_channels, num_filters,
                             first_kernel_size, 1, first_padding,
                             **conv_params)]
        for _ in range(num_layers):
            # Double the channel count each stage, capped at max_num_filters.
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model.append(Res2dBlock(num_filters_prev, num_filters,
                                    order=order, **conv_params))
            if anti_aliased:
                model.append(BlurDownsample())
            else:
                model.append(nn.AvgPool2d(2, stride=2))

        if aggregation == 'pool':
            # Collapses any spatial size to 1x1.
            model += [torch.nn.AdaptiveAvgPool2d(1)]
        elif aggregation == 'conv':
            # A 4x4 kernel with no padding yields a 1x1 map only when the
            # incoming feature map is 4x4; for other sizes the flattened
            # features will not match the classifier's input size.
            # NOTE(review): unlike the other layers, this conv does not use
            # conv_params (padding mode / norms) — confirm this is intended.
            model += [Conv2dBlock(num_filters, num_filters, 4, 1, 0,
                                  nonlinearity='leakyrelu')]
        else:
            # The previous message referenced a non-existent attribute
            # (self.aggregation) and had no placeholder, so it crashed with
            # AttributeError/TypeError instead of raising this ValueError.
            raise ValueError(
                'The aggregation mode %s is not recognized' % (aggregation,))

        self.model = nn.Sequential(*model)
        # Expects the aggregated features to flatten to exactly num_filters
        # values per sample (i.e. a 1x1 spatial map).
        self.classifier = nn.Linear(num_filters, 1)

    def forward(self, images):
        r"""Global residual discriminator forward.

        Args:
            images (tensor) : Input images. The batch size is read from
                dimension 0.
        Returns:
            (tuple):
              - outputs (tensor): Output of the discriminator, produced by a
                single-unit linear layer applied to the flattened features.
              - features (tensor): Output of the convolutional trunk
                (after aggregation), before flattening.
              - images (tensor): Input images, returned unchanged.
        """
        batch_size = images.size(0)
        features = self.model(images)
        # Flatten per sample; this only matches the classifier's
        # in_features when the aggregated map is 1x1 spatially.
        outputs = self.classifier(features.view(batch_size, -1))
        return outputs, features, images