geninhu committed on
Commit
545bc3c
1 Parent(s): 6c130bf

Upload layers.py

Browse files
Files changed (1) hide show
  1. layers.py +272 -0
layers.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.modules.batchnorm import BatchNorm2d
5
+ from torch.nn.utils import spectral_norm
6
+
7
+
8
class SpectralConv2d(nn.Module):
    """2-D convolution with spectral normalization applied to its weight.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.Conv2d``; the resulting layer is then passed through
    ``torch.nn.utils.spectral_norm``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        conv = nn.Conv2d(*args, **kwargs)
        # Kept under ``_conv`` so parameter names in the state dict are stable.
        self._conv = spectral_norm(conv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``input`` through the spectrally normalized convolution."""
        output = self._conv(input)
        return output
18
+
19
+
20
class SpectralConvTranspose2d(nn.Module):
    """2-D transposed convolution with spectral normalization on its weight.

    Every positional and keyword argument is forwarded unchanged to
    ``nn.ConvTranspose2d``; the resulting layer is then passed through
    ``torch.nn.utils.spectral_norm``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        deconv = nn.ConvTranspose2d(*args, **kwargs)
        # Kept under ``_conv`` so parameter names in the state dict are stable.
        self._conv = spectral_norm(deconv)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Run ``input`` through the spectrally normalized transposed convolution."""
        output = self._conv(input)
        return output
30
+
31
+
32
class Noise(nn.Module):
    """Add Gaussian noise, scaled by one learnable scalar, to a 4-D input.

    A single-channel noise map of shape ``(batch, 1, height, width)`` is
    sampled on every forward call and broadcast across the channel dimension.
    The scale starts at zero, so a freshly constructed layer returns its
    input unchanged.
    """

    def __init__(self):
        super().__init__()
        self._weight = nn.Parameter(
            torch.zeros(1),
            requires_grad=True,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` plus weighted noise.

        Args:
            input: 4-D tensor; its shape is unpacked as
                ``(batch, channels, height, width)``.

        Returns:
            A tensor with the same shape as ``input``.
        """
        batch_size, _, height, width = input.shape
        # Sample in the input's floating dtype so that reduced-precision
        # (half / bfloat16) or double inputs are not silently promoted or
        # mixed with float32 noise. Non-floating inputs keep the default
        # dtype, as torch.randn cannot sample integer tensors.
        noise_dtype = input.dtype if input.is_floating_point() else None
        noise = torch.randn(
            batch_size, 1, height, width,
            device=input.device,
            dtype=noise_dtype,
        )
        return self._weight * noise + input
45
+
46
+
47
class InitLayer(nn.Module):
    """Initial block: spectral transposed conv, batch norm, then GLU.

    The transposed convolution (kernel 4, stride 1, no padding, no bias) and
    the batch norm operate on ``out_channels * 2`` channels; ``nn.GLU(dim=1)``
    halves the channel dimension, leaving ``out_channels`` channels.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced after the GLU.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        gated_channels = out_channels * 2

        deconv = SpectralConvTranspose2d(
            in_channels=in_channels,
            out_channels=gated_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=gated_channels)
        gate = nn.GLU(dim=1)

        self._layers = nn.Sequential(deconv, norm, gate)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``input`` and return the result."""
        output = self._layers(input)
        return output
68
+
69
+
70
class SLEBlock(nn.Module):
    """Channel-wise gating of one feature map by another.

    ``low_dim`` is adaptively average-pooled to 4x4, reduced by a 4x4
    spectral convolution, passed through SiLU, a 1x1 spectral convolution
    and a sigmoid. The result multiplies ``high_dim``.

    Args:
        in_channels: Number of channels of ``low_dim``.
        out_channels: Number of channels of the produced gate.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        pool = nn.AdaptiveAvgPool2d(output_size=4)
        squeeze = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=1,
            padding=0,
            bias=False,
        )
        activation = nn.SiLU()
        excite = SpectralConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        gate = nn.Sigmoid()

        self._layers = nn.Sequential(pool, squeeze, activation, excite, gate)

    def forward(self, low_dim: torch.Tensor,
                high_dim: torch.Tensor) -> torch.Tensor:
        """Return ``high_dim`` multiplied by the gate computed from ``low_dim``."""
        gate = self._layers(low_dim)
        return high_dim * gate
101
+
102
+
103
class UpsampleBlockT1(nn.Module):
    """Upsampling block: 2x nearest upsample, spectral 3x3 conv, BN, GLU.

    The convolution and batch norm operate on ``out_channels * 2`` channels;
    ``nn.GLU(dim=1)`` halves the channel dimension, leaving ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced after the GLU.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        gated_channels = out_channels * 2

        upsample = nn.Upsample(scale_factor=2, mode='nearest')
        conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=gated_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=gated_channels)
        gate = nn.GLU(dim=1)

        self._layers = nn.Sequential(upsample, conv, norm, gate)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``input`` and return the result."""
        output = self._layers(input)
        return output
125
+
126
+
127
class UpsampleBlockT2(nn.Module):
    """Upsampling block with two noisy conv stages.

    Layer order: 2x nearest upsample, then twice the sequence
    spectral 3x3 conv -> ``Noise`` -> batch norm -> GLU. Each convolution
    emits ``out_channels * 2`` channels and each ``nn.GLU(dim=1)`` halves
    them, so both stages end with ``out_channels`` channels.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced after each GLU.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        self._layers = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            SpectralConv2d(
                in_channels=in_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            # Referenced through ``nn`` like every other batch norm in this
            # file; it is the same class as the bare ``BatchNorm2d`` import.
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
            SpectralConv2d(
                in_channels=out_channels,
                out_channels=out_channels * 2,
                kernel_size=3,
                stride=1,
                padding='same',
                bias=False,
            ),
            Noise(),
            nn.BatchNorm2d(num_features=out_channels * 2),
            nn.GLU(dim=1),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``input`` and return the result."""
        return self._layers(input)
161
+
162
+
163
class DownsampleBlockT1(nn.Module):
    """Downsampling block: strided spectral conv, batch norm, LeakyReLU.

    The convolution uses kernel 4, stride 2, padding 1 and no bias; the
    LeakyReLU uses a negative slope of 0.2.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by the block.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        norm = nn.BatchNorm2d(num_features=out_channels)
        activation = nn.LeakyReLU(negative_slope=0.2)

        self._layers = nn.Sequential(conv, norm, activation)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``input`` and return the result."""
        output = self._layers(input)
        return output
184
+
185
+
186
class DownsampleBlockT2(nn.Module):
    """Two-branch downsampling block whose branch outputs are averaged.

    Branch 1 (``_layers_1``): strided 4x4 spectral conv -> BN -> LeakyReLU ->
    3x3 spectral conv -> BN -> LeakyReLU.
    Branch 2 (``_layers_2``): 2x2 average pool -> 1x1 spectral conv -> BN ->
    LeakyReLU.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by each branch.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        # Branch 1: convolutional path.
        strided_conv = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=4,
            stride=2,
            padding=1,
            bias=False,
        )
        strided_norm = nn.BatchNorm2d(num_features=out_channels)
        strided_act = nn.LeakyReLU(negative_slope=0.2)
        refine_conv = SpectralConv2d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )
        refine_norm = nn.BatchNorm2d(num_features=out_channels)
        refine_act = nn.LeakyReLU(negative_slope=0.2)
        self._layers_1 = nn.Sequential(
            strided_conv, strided_norm, strided_act,
            refine_conv, refine_norm, refine_act,
        )

        # Branch 2: pooled path with a 1x1 projection.
        pool = nn.AvgPool2d(
            kernel_size=2,
            stride=2,
        )
        project = SpectralConv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        project_norm = nn.BatchNorm2d(num_features=out_channels)
        project_act = nn.LeakyReLU(negative_slope=0.2)
        self._layers_2 = nn.Sequential(pool, project, project_norm, project_act)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return the mean of the two branch outputs for ``input``."""
        conv_path = self._layers_1(input)
        pool_path = self._layers_2(input)
        return (conv_path + pool_path) / 2
236
+
237
+
238
class Decoder(nn.Module):
    """Decoder: pool to 8x8, four ``UpsampleBlockT1`` stages, conv, tanh.

    The input is adaptively average-pooled to 8x8, passed through four
    upsampling blocks whose widths come from the ``_channels`` table
    (entries 16, 32, 64 and 128), then projected to ``out_channels`` by a
    3x3 spectral convolution followed by ``nn.Tanh``.

    Args:
        in_channels: Number of channels of the incoming feature map.
        out_channels: Number of channels of the decoded output.
    """

    def __init__(self, in_channels: int,
                 out_channels: int):
        super().__init__()

        # Channel width lookup table; only keys 16-128 are read in this class.
        self._channels = {
            16: 128,
            32: 64,
            64: 64,
            128: 32,
            256: 16,
            512: 8,
            1024: 4,
        }

        pool = nn.AdaptiveAvgPool2d(output_size=8)

        # Consecutive (input width, output width) pairs for the four stages.
        widths = [in_channels] + [self._channels[key] for key in (16, 32, 64, 128)]
        stages = [
            UpsampleBlockT1(in_channels=width_in, out_channels=width_out)
            for width_in, width_out in zip(widths[:-1], widths[1:])
        ]

        head = SpectralConv2d(
            in_channels=widths[-1],
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding='same',
            bias=False,
        )

        self._layers = nn.Sequential(pool, *stages, head, nn.Tanh())

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Decode ``input`` and return the tanh-activated result."""
        output = self._layers(input)
        return output