geninhu committed
Commit
d6c219d
1 Parent(s): 545bc3c

Upload models.py

Files changed (1)
  1. models.py +245 -0
models.py ADDED
@@ -0,0 +1,245 @@
+import torch
+import torch.nn as nn
+
+from typing import Any, Tuple, Union
+
+from utils import (
+    ImageType,
+    crop_image_part,
+)
+
+from layers import (
+    SpectralConv2d,
+    InitLayer,
+    SLEBlock,
+    UpsampleBlockT1,
+    UpsampleBlockT2,
+    DownsampleBlockT1,
+    DownsampleBlockT2,
+    Decoder,
+)
+
+from huggan.pytorch.huggan_mixin import HugGANModelHubMixin
+
+
+class Generator(nn.Module, HugGANModelHubMixin):
+
+    def __init__(self, in_channels: int,
+                       out_channels: int):
+        super().__init__()
+
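+        # Feature-map channel count at each spatial resolution, 4x4 through 1024x1024.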
+        self._channels = {
+            4:    1024,
+            8:    512,
+            16:   256,
+            32:   128,
+            64:   128,
+            128:  64,
+            256:  32,
+            512:  16,
+            1024: 8,
+        }
+
+        self._init = InitLayer(
+            in_channels=in_channels,
+            out_channels=self._channels[4],
+        )
+
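+        # Upsampling chain from 4x4 to 1024x1024, alternating the two upsample block variants.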
+        self._upsample_8    = UpsampleBlockT2(in_channels=self._channels[4],   out_channels=self._channels[8])
+        self._upsample_16   = UpsampleBlockT1(in_channels=self._channels[8],   out_channels=self._channels[16])
+        self._upsample_32   = UpsampleBlockT2(in_channels=self._channels[16],  out_channels=self._channels[32])
+        self._upsample_64   = UpsampleBlockT1(in_channels=self._channels[32],  out_channels=self._channels[64])
+        self._upsample_128  = UpsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[128])
+        self._upsample_256  = UpsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[256])
+        self._upsample_512  = UpsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[512])
+        self._upsample_1024 = UpsampleBlockT1(in_channels=self._channels[512], out_channels=self._channels[1024])
+
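+        # Skip-layer excitation: low-resolution feature maps gate the channels of
+        # much higher-resolution ones (4 -> 64, 8 -> 128, 16 -> 256, 32 -> 512).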
+        self._sle_64  = SLEBlock(in_channels=self._channels[4],  out_channels=self._channels[64])
+        self._sle_128 = SLEBlock(in_channels=self._channels[8],  out_channels=self._channels[128])
+        self._sle_256 = SLEBlock(in_channels=self._channels[16], out_channels=self._channels[256])
+        self._sle_512 = SLEBlock(in_channels=self._channels[32], out_channels=self._channels[512])
+
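+        # Two image heads: an auxiliary 128x128 output and the final 1024x1024 output.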
+        self._out_128 = nn.Sequential(
+            SpectralConv2d(
+                in_channels=self._channels[128],
+                out_channels=out_channels,
+                kernel_size=1,
+                stride=1,
+                padding='same',
+                bias=False,
+            ),
+            nn.Tanh(),
+        )
+
+        self._out_1024 = nn.Sequential(
+            SpectralConv2d(
+                in_channels=self._channels[1024],
+                out_channels=out_channels,
+                kernel_size=3,
+                stride=1,
+                padding='same',
+                bias=False,
+            ),
+            nn.Tanh(),
+        )
+
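+    # Maps a latent tensor to a pair of images: the full-resolution output and
+    # the 128x128 auxiliary one consumed by the discriminator's small track.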
+    def forward(self, input: torch.Tensor) -> \
+        Tuple[torch.Tensor, torch.Tensor]:
+        size_4  = self._init(input)
+        size_8  = self._upsample_8(size_4)
+        size_16 = self._upsample_16(size_8)
+        size_32 = self._upsample_32(size_16)
+
+        size_64  = self._sle_64 (size_4,  self._upsample_64 (size_32))
+        size_128 = self._sle_128(size_8,  self._upsample_128(size_64))
+        size_256 = self._sle_256(size_16, self._upsample_256(size_128))
+        size_512 = self._sle_512(size_32, self._upsample_512(size_256))
+
+        size_1024 = self._upsample_1024(size_512)
+
+        out_128  = self._out_128(size_128)
+        out_1024 = self._out_1024(size_1024)
+        return out_1024, out_128
+
+
+class Discriminator(nn.Module, HugGANModelHubMixin):
+
+    def __init__(self, in_channels: int):
+        super().__init__()
+
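+        # Same resolution-to-channels map as the generator.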
+        self._channels = {
+            4:    1024,
+            8:    512,
+            16:   256,
+            32:   128,
+            64:   128,
+            128:  64,
+            256:  32,
+            512:  16,
+            1024: 8,
+        }
+
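+        # Stem: two stride-2 spectral convolutions that begin downsampling the 1024x1024 input.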
+        self._init = nn.Sequential(
+            SpectralConv2d(
+                in_channels=in_channels,
+                out_channels=self._channels[1024],
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.LeakyReLU(negative_slope=0.2),
+            SpectralConv2d(
+                in_channels=self._channels[1024],
+                out_channels=self._channels[512],
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self._channels[512]),
+            nn.LeakyReLU(negative_slope=0.2),
+        )
+
+        self._downsample_256 = DownsampleBlockT2(in_channels=self._channels[512], out_channels=self._channels[256])
+        self._downsample_128 = DownsampleBlockT2(in_channels=self._channels[256], out_channels=self._channels[128])
+        self._downsample_64  = DownsampleBlockT2(in_channels=self._channels[128], out_channels=self._channels[64])
+        self._downsample_32  = DownsampleBlockT2(in_channels=self._channels[64],  out_channels=self._channels[32])
+        self._downsample_16  = DownsampleBlockT2(in_channels=self._channels[32],  out_channels=self._channels[16])
+
+        self._sle_64 = SLEBlock(in_channels=self._channels[512], out_channels=self._channels[64])
+        self._sle_32 = SLEBlock(in_channels=self._channels[256], out_channels=self._channels[32])
+        self._sle_16 = SLEBlock(in_channels=self._channels[128], out_channels=self._channels[16])
+
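+        # Shallow, separate track for the generator's 128x128 auxiliary images.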
+        self._small_track = nn.Sequential(
+            SpectralConv2d(
+                in_channels=in_channels,
+                out_channels=self._channels[256],
+                kernel_size=4,
+                stride=2,
+                padding=1,
+                bias=False,
+            ),
+            nn.LeakyReLU(negative_slope=0.2),
+            DownsampleBlockT1(in_channels=self._channels[256], out_channels=self._channels[128]),
+            DownsampleBlockT1(in_channels=self._channels[128], out_channels=self._channels[64]),
+            DownsampleBlockT1(in_channels=self._channels[64],  out_channels=self._channels[32]),
+        )
+
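+        # Heads that reduce each track to realness logits; forward flattens and
+        # concatenates them into a single feature vector.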
+        self._features_large = nn.Sequential(
+            SpectralConv2d(
+                in_channels=self._channels[16],
+                out_channels=self._channels[8],
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=False,
+            ),
+            nn.BatchNorm2d(num_features=self._channels[8]),
+            nn.LeakyReLU(negative_slope=0.2),
+            SpectralConv2d(
+                in_channels=self._channels[8],
+                out_channels=1,
+                kernel_size=4,
+                stride=1,
+                padding=0,
+                bias=False,
+            ),
+        )
+
+        self._features_small = nn.Sequential(
+            SpectralConv2d(
+                in_channels=self._channels[32],
+                out_channels=1,
+                kernel_size=4,
+                stride=1,
+                padding=0,
+                bias=False,
+            ),
+        )
+
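+        # Decoders for the self-supervised reconstruction losses, applied to real images only.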
+        self._decoder_large = Decoder(in_channels=self._channels[16], out_channels=3)
+        self._decoder_small = Decoder(in_channels=self._channels[32], out_channels=3)
+        self._decoder_piece = Decoder(in_channels=self._channels[32], out_channels=3)
+
+    def forward(self, images_1024: torch.Tensor,
+                      images_128: torch.Tensor,
+                      image_type: ImageType) -> \
+        Union[
+            torch.Tensor,
+            Tuple[torch.Tensor, Tuple[Any, Any, Any]]
+        ]:
+        # large track
+
+        down_512 = self._init(images_1024)
+        down_256 = self._downsample_256(down_512)
+        down_128 = self._downsample_128(down_256)
+
+        down_64 = self._downsample_64(down_128)
+        down_64 = self._sle_64(down_512, down_64)
+
+        down_32 = self._downsample_32(down_64)
+        down_32 = self._sle_32(down_256, down_32)
+
+        down_16 = self._downsample_16(down_32)
+        down_16 = self._sle_16(down_128, down_16)
+
+        # small track
+
+        down_small = self._small_track(images_128)
+
+        # features
+
+        features_large = self._features_large(down_16).view(-1)
+        features_small = self._features_small(down_small).view(-1)
+        features = torch.cat([features_large, features_small], dim=0)
+
+        # decoder
+
+        if image_type != ImageType.FAKE:
+            dec_large = self._decoder_large(down_16)
+            dec_small = self._decoder_small(down_small)
+            dec_piece = self._decoder_piece(crop_image_part(down_32, image_type))
+            return features, (dec_large, dec_small, dec_piece)
+
+        return features
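For context, a minimal sketch of how the two modules fit together. The latent size of 256 and the (N, 256, 1, 1) latent shape are assumptions, since InitLayer's expected input is defined in layers.py rather than in this file:

import torch

from models import Generator, Discriminator
from utils import ImageType

generator     = Generator(in_channels=256, out_channels=3)  # 256 = assumed latent size
discriminator = Discriminator(in_channels=3)

latent = torch.randn(4, 256, 1, 1)           # assumed latent map shape for InitLayer
images_1024, images_128 = generator(latent)  # 1024x1024 and 128x128 RGB outputs

# For generated images only the realness features come back; for real images the
# discriminator also returns the (dec_large, dec_small, dec_piece) reconstructions.
features = discriminator(images_1024, images_128, ImageType.FAKE)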