# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES.  All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
import torch.nn as nn

from imaginaire.discriminators.fpse import FPSEDiscriminator
from imaginaire.discriminators.multires_patch import NLayerPatchDiscriminator
from imaginaire.utils.data import (get_paired_input_image_channel_number,
                                   get_paired_input_label_channel_number)
from imaginaire.utils.distributed import master_only_print as print


class Discriminator(nn.Module):
    r"""Multi-resolution patch discriminator.

    Args:
        dis_cfg (obj): Discriminator definition part of the yaml config file.
        data_cfg (obj): Data definition part of the yaml config file.
    """

    def __init__(self, dis_cfg, data_cfg):
        super(Discriminator, self).__init__()
        print('Multi-resolution patch discriminator initialization.')
        image_channels = getattr(dis_cfg, 'image_channels', None)
        if image_channels is None:
            image_channels = get_paired_input_image_channel_number(data_cfg)
        num_labels = getattr(dis_cfg, 'num_labels', None)
        if num_labels is None:
            # Calculate number of channels in the input label when not specified.
            num_labels = get_paired_input_label_channel_number(data_cfg)

        # Build the discriminator.
        kernel_size = getattr(dis_cfg, 'kernel_size', 3)
        num_filters = getattr(dis_cfg, 'num_filters', 128)
        max_num_filters = getattr(dis_cfg, 'max_num_filters', 512)
        num_discriminators = getattr(dis_cfg, 'num_discriminators', 2)
        num_layers = getattr(dis_cfg, 'num_layers', 5)
        activation_norm_type = getattr(dis_cfg, 'activation_norm_type', 'none')
        weight_norm_type = getattr(dis_cfg, 'weight_norm_type', 'spectral')
        print('\tBase filter number: %d' % num_filters)
        print('\tNumber of discriminators: %d' % num_discriminators)
        print('\tNumber of layers in a discriminator: %d' % num_layers)
        print('\tWeight norm type: %s' % weight_norm_type)
        num_input_channels = image_channels + num_labels
        self.discriminators = nn.ModuleList()
        for i in range(num_discriminators):
            net_discriminator = NLayerPatchDiscriminator(
                kernel_size,
                num_input_channels,
                num_filters,
                num_layers,
                max_num_filters,
                activation_norm_type,
                weight_norm_type)
            self.discriminators.append(net_discriminator)
        print('Done with the Multi-resolution patch discriminator initialization.')
        self.use_fpse = getattr(dis_cfg, 'use_fpse', True)
        if self.use_fpse:
            fpse_kernel_size = getattr(dis_cfg, 'fpse_kernel_size', 3)
            fpse_activation_norm_type = getattr(dis_cfg,
                                                'fpse_activation_norm_type',
                                                'none')
            self.fpse_discriminator = FPSEDiscriminator(
                image_channels,
                num_labels,
                num_filters,
                fpse_kernel_size,
                weight_norm_type,
                fpse_activation_norm_type)

    def _single_forward(self, input_label, input_image):
        r"""Compute discriminator outputs and intermediate features from the
        input images and semantic labels.

        Args:
            input_label (N x C2 x H x W tensor): Semantic representations.
            input_image (N x C1 x H x W tensor): Input images.
        Returns:
            (tuple): Output list and feature list collected from the
            individual (FPSE and multi-resolution patch) discriminators.
        """
        input_x = torch.cat(
            (input_label, input_image), 1)
        output_list = []
        features_list = []
        if self.use_fpse:
            pred2, pred3, pred4 = self.fpse_discriminator(input_image, input_label)
            output_list = [pred2, pred3, pred4]
        input_downsampled = input_x
        for net_discriminator in self.discriminators:
            output, features = net_discriminator(input_downsampled)
            output_list.append(output)
            features_list.append(features)
            input_downsampled = nn.functional.interpolate(
                input_downsampled, scale_factor=0.5, mode='bilinear',
                align_corners=True)
        return output_list, features_list

    def forward(self, data, net_G_output):
        r"""SPADE discriminator forward.

        Args:
            data (dict):
              - images (N x C1 x H x W tensor): Ground truth images.
              - label (N x C2 x H x W tensor): Semantic representations.
              - z (N x style_dims tensor): Gaussian random noise.
            net_G_output (dict):
              - fake_images (N x C1 x H x W tensor): Fake images.
        Returns:
            (dict):
              - real_outputs (list): list of output tensors produced by
                individual patch discriminators for real images.
              - real_features (list): list of lists of features produced by
                individual patch discriminators for real images.
              - fake_outputs (list): list of output tensors produced by
                individual patch discriminators for fake images.
              - fake_features (list): list of lists of features produced by
                individual patch discriminators for fake images.
        """
        output_x = dict()
        output_x['real_outputs'], output_x['real_features'] = \
            self._single_forward(data['label'], data['images'])
        output_x['fake_outputs'], output_x['fake_features'] = \
            self._single_forward(data['label'], net_G_output['fake_images'])
        return output_x
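

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original module): a minimal,
# self-contained example of how this Discriminator might be exercised. The
# config values and tensor shapes below are illustrative assumptions;
# `image_channels` and `num_labels` are set explicitly so the `data_cfg`
# helpers above are not consulted, and `use_fpse` is disabled to keep the
# example small.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from types import SimpleNamespace

    dis_cfg = SimpleNamespace(image_channels=3, num_labels=35,
                              num_filters=64, use_fpse=False)
    data_cfg = SimpleNamespace()  # Unused because channel counts are given.

    discriminator = Discriminator(dis_cfg, data_cfg)
    data = {'images': torch.randn(1, 3, 256, 256),
            'label': torch.randn(1, 35, 256, 256)}
    net_G_output = {'fake_images': torch.randn(1, 3, 256, 256)}

    with torch.no_grad():
        output = discriminator(data, net_G_output)
    # One patch output per discriminator scale, for real and fake inputs.
    print([o.shape for o in output['real_outputs']])
    print([o.shape for o in output['fake_outputs']])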