File size: 3,119 Bytes
6672bfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
Copyright (C) 2019 NVIDIA Corporation.  All rights reserved.
Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from spade.normalizer import SPADE

class SPADEGenerator(nn.Module):
  """Generator that decodes a segmentation map into a 3-channel, tanh-squashed image.

  The segmentation map is first resized to a small latent grid of size
  (sh, sw), projected by a 3x3 convolution, and then passed through a stack
  of SPADEResnetBlock layers interleaved with five 2x upsampling steps.
  Every block also receives the original segmentation map.

  `opt` is indexed with the keys 'crop_size', 'aspect_ratio' and 'label_nc'
  here, and is forwarded unchanged to every SPADEResnetBlock.
  """

  def __init__(self, opt):
    super().__init__()

    # Number of generator filters in the last block; wider layers are multiples of it.
    base = 64
    widest = 16 * base

    latent_size = self.compute_latent_vector_size(opt['crop_size'], opt['aspect_ratio'])
    self.sw, self.sh = latent_size

    # Projects the resized segmentation map to the widest feature size.
    self.fc = nn.Conv2d(opt['label_nc'], widest, 3, padding=1)

    self.head_0 = SPADEResnetBlock(opt, widest, widest)

    self.G_middle_0 = SPADEResnetBlock(opt, widest, widest)
    self.G_middle_1 = SPADEResnetBlock(opt, widest, widest)

    # Decoder blocks: the channel count is halved at every stage (16x -> 1x base).
    self.up_0 = SPADEResnetBlock(opt, widest, 8 * base)
    self.up_1 = SPADEResnetBlock(opt, 8 * base, 4 * base)
    self.up_2 = SPADEResnetBlock(opt, 4 * base, 2 * base)
    self.up_3 = SPADEResnetBlock(opt, 2 * base, base)

    # Final projection to a 3-channel output.
    self.conv_img = nn.Conv2d(base, 3, 3, padding=1)

    self.up = nn.Upsample(scale_factor=2)

  def compute_latent_vector_size(self, crop_size, aspect_ratio):
    """Return (sw, sh), the width and height of the latent grid.

    The width is crop_size floor-divided by 2**5 (one factor of two per
    upsampling step in forward); the height is the width divided by
    aspect_ratio, rounded with the builtin round().
    """
    num_up_layers = 5

    width = crop_size // (2 ** num_up_layers)
    return width, round(width / aspect_ratio)

  def forward(self, seg):
    """Generate an image from the segmentation map `seg`.

    `seg` is resized to (sh, sw) and fed to a Conv2d with opt['label_nc']
    input channels, so it is presumably shaped (N, label_nc, H, W) -- TODO
    confirm against the caller. Returns the tanh of the final convolution.
    """
    # Shrink the segmentation map to the latent grid and project it.
    latent = F.interpolate(seg, size=(self.sh, self.sw))
    x = self.head_0(self.fc(latent), seg)

    # One upsampling step, then the two middle blocks at that resolution.
    x = self.up(x)
    for block in (self.G_middle_0, self.G_middle_1):
      x = block(x, seg)

    # Four more "upsample, then refine" stages.
    for block in (self.up_0, self.up_1, self.up_2, self.up_3):
      x = block(self.up(x), seg)

    activated = F.leaky_relu(x, 2e-1)
    return torch.tanh(self.conv_img(activated))

import torch.nn.utils.spectral_norm as spectral_norm

# label_nc: the number of channels of the input semantic map (equivalently, the
# number of input label classes); this is the input dimension of SPADE.
class SPADEResnetBlock(nn.Module):
  """Residual block whose normalization layers are SPADE modules.

  The main path is two rounds of SPADE -> leaky ReLU -> 3x3 spectral-normalized
  convolution, going from `fin` to min(fin, fout) to `fout` channels. The
  skip path is the identity when fin == fout; otherwise it is a SPADE layer
  followed by a bias-free 1x1 spectral-normalized convolution.

  `opt` is only forwarded to the SPADE constructors.
  """

  def __init__(self, opt, fin, fout):
    super().__init__()

    # A projection on the skip path is only needed when the channel count changes.
    self.learned_shortcut = (fin != fout)
    hidden = min(fin, fout)

    # Convolutions, all wrapped in spectral normalization.
    self.conv_0 = spectral_norm(nn.Conv2d(fin, hidden, kernel_size=3, padding=1))
    self.conv_1 = spectral_norm(nn.Conv2d(hidden, fout, kernel_size=3, padding=1))
    if self.learned_shortcut:
      self.conv_s = spectral_norm(nn.Conv2d(fin, fout, kernel_size=1, bias=False))

    # Normalization layers, one in front of each convolution.
    self.norm_0 = SPADE(opt, fin)
    self.norm_1 = SPADE(opt, hidden)
    if self.learned_shortcut:
      self.norm_s = SPADE(opt, fin)

  def forward(self, x, seg):
    """Return shortcut(x, seg) plus the two-convolution residual of `x`.

    Unlike a plain resnet block, this one also takes `seg`, the semantic
    segmentation map, and hands it to every normalization layer.
    """
    skip = self.shortcut(x, seg)

    residual = self.norm_0(x, seg)
    residual = self.conv_0(self.relu(residual))
    residual = self.norm_1(residual, seg)
    residual = self.conv_1(self.relu(residual))

    return skip + residual

  def shortcut(self, x, seg):
    """Return the skip-path tensor: `x` itself, or its normalized 1x1 projection."""
    if not self.learned_shortcut:
      return x
    return self.conv_s(self.norm_s(x, seg))

  def relu(self, x):
    """Apply a leaky ReLU with negative slope 0.2."""
    return F.leaky_relu(x, negative_slope=2e-1)