# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import warnings

import torch
import torch.nn as nn

from imaginaire.layers import Conv2dBlock, Res2dBlock
from imaginaire.third_party.upfirdn2d import BlurDownsample


class ResDiscriminator(nn.Module):
    r"""Global residual discriminator.

    Args:
        image_channels (int): Num. of channels in the real/fake image.
        num_filters (int): Num. of base filters in a layer.
        max_num_filters (int): Maximum num. of filters in a layer.
        first_kernel_size (int): Kernel size in the first layer.
        num_layers (int): Num. of layers in discriminator.
        padding_mode (str): Padding mode.
        activation_norm_type (str): Type of activation normalization.
            ``'none'``, ``'instance'``, ``'batch'``, ``'sync_batch'``.
        weight_norm_type (str): Type of weight normalization.
            ``'none'``, ``'spectral'``, or ``'weight'``.
        aggregation (str): Method to aggregate features across different
            locations in the final layer. ``'conv'`` or ``'pool'``.
        order (str): Order of operations in the residual link.
        anti_aliased (bool): If ``True``, uses anti-aliased pooling.
    """

    def __init__(self,
                 image_channels=3,
                 num_filters=64,
                 max_num_filters=512,
                 first_kernel_size=1,
                 num_layers=4,
                 padding_mode='zeros',
                 activation_norm_type='',
                 weight_norm_type='',
                 aggregation='conv',
                 order='pre_act',
                 anti_aliased=False,
                 **kwargs):
        super().__init__()
        for key in kwargs:
            if key != 'type' and key != 'patch_wise':
                warnings.warn(
                    "Discriminator argument {} is not used".format(key))

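        # Convolution settings shared by every block in the discriminator.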
        conv_params = dict(padding_mode=padding_mode,
                           activation_norm_type=activation_norm_type,
                           weight_norm_type=weight_norm_type,
                           nonlinearity='leakyrelu')

        first_padding = (first_kernel_size - 1) // 2
        model = [Conv2dBlock(image_channels, num_filters,
                             first_kernel_size, 1, first_padding,
                             **conv_params)]
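        # Each stage doubles the channel count (capped at max_num_filters)
        # and halves the spatial resolution.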
        for _ in range(num_layers):
            num_filters_prev = num_filters
            num_filters = min(num_filters * 2, max_num_filters)
            model.append(Res2dBlock(num_filters_prev, num_filters, order=order,
                                    **conv_params))
            if anti_aliased:
                # Blur before subsampling to reduce aliasing artifacts.
                model.append(BlurDownsample())
            else:
                model.append(nn.AvgPool2d(2, stride=2))
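        # Aggregate spatial features into a single vector per image: either
        # global average pooling, or a 4x4 valid convolution (which expects
        # the final feature map to be 4x4).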
        if aggregation == 'pool':
            model += [nn.AdaptiveAvgPool2d(1)]
        elif aggregation == 'conv':
            model += [Conv2dBlock(num_filters, num_filters, 4, 1, 0,
                                  nonlinearity='leakyrelu')]
        else:
            raise ValueError(
                'The aggregation mode %s is not recognized' % aggregation)
        self.model = nn.Sequential(*model)
        self.classifier = nn.Linear(num_filters, 1)

    def forward(self, images):
        r"""Global residual discriminator forward.

        Args:
            images (tensor): Input images.
        Returns:
            (tuple):
              - outputs (tensor): Output of the discriminator.
              - features (tensor): Intermediate features of the discriminator.
              - images (tensor): Input images.
        """
        batch_size = images.size(0)
        features = self.model(images)
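        # Flatten per-image features to a vector for the linear classifier.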
        outputs = self.classifier(features.view(batch_size, -1))
        return outputs, features, images
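

if __name__ == '__main__':
    # Minimal usage sketch; assumes the imaginaire package is installed so
    # the Conv2dBlock/Res2dBlock imports above resolve. With the defaults
    # (num_layers=4, aggregation='conv'), a 64x64 input is downsampled to a
    # 4x4 map that the final 4x4 convolution collapses to 1x1.
    net = ResDiscriminator()
    images = torch.randn(2, 3, 64, 64)
    outputs, features, _ = net(images)
    print(outputs.shape)   # torch.Size([2, 1])
    print(features.shape)  # torch.Size([2, 512, 1, 1])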