michaelriedl committed
Commit 8f71eda (1 parent: 4fe7b39)

Initial dump

Files changed (5)
  1. LightweightGANConfig.py +31 -0
  2. LightweightGANModel.py +29 -0
  3. config.json +24 -0
  4. deploy.py +385 -0
  5. pytorch_model.bin +3 -0
LightweightGANConfig.py ADDED
@@ -0,0 +1,31 @@
+ from transformers import PretrainedConfig
+
+
+ class LightweightGANConfig(PretrainedConfig):
+     model_type = "lightweight-gan"
+
+     def __init__(
+         self,
+         image_size=64,
+         latent_dim=256,
+         fmap_max=512,
+         fmap_inverse_coef=12,
+         transparent=False,
+         greyscale=False,
+         attn_res_layers=[32],
+         freq_chan_attn=False,
+         syncbatchnorm=False,
+         antialias=False,
+         **kwargs,
+     ):
+         self.image_size = image_size
+         self.latent_dim = latent_dim
+         self.fmap_max = fmap_max
+         self.fmap_inverse_coef = fmap_inverse_coef
+         self.transparent = transparent
+         self.greyscale = greyscale
+         self.attn_res_layers = attn_res_layers
+         self.freq_chan_attn = freq_chan_attn
+         self.syncbatchnorm = syncbatchnorm
+         self.antialias = antialias
+         super().__init__(**kwargs)
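
The config class just records the generator hyperparameters and registers the custom model_type with transformers. As a minimal usage sketch (not part of the commit), it can be round-tripped through the standard PretrainedConfig API; the output directory name below is a placeholder:

from LightweightGANConfig import LightweightGANConfig

# override a couple of defaults, then serialize and reload via the PretrainedConfig API
config = LightweightGANConfig(image_size=256, transparent=True)
config.save_pretrained("lightweight-gan-ckpt")  # writes config.json to a placeholder directory
reloaded = LightweightGANConfig.from_pretrained("lightweight-gan-ckpt")
print(reloaded.image_size, reloaded.transparent)  # 256 True
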
LightweightGANModel.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from transformers import PreTrainedModel
+ from .LightweightGANConfig import LightweightGANConfig
+ from .deploy import Generator
+
+
+ class LightweightGANModel(PreTrainedModel):
+     config_class = LightweightGANConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.model = Generator(
+             image_size=config.image_size,
+             latent_dim=config.latent_dim,
+             fmap_max=config.fmap_max,
+             fmap_inverse_coef=config.fmap_inverse_coef,
+             transparent=config.transparent,
+             greyscale=config.greyscale,
+             attn_res_layers=config.attn_res_layers,
+             freq_chan_attn=config.freq_chan_attn,
+             syncbatchnorm=config.syncbatchnorm,
+             antialias=config.antialias,
+         )
+
+     def forward(self, tensor):
+         return self.model(tensor)
+
+     def load_params(self, pt_file):
+         self.model.load_state_dict(torch.load(pt_file))
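
LightweightGANModel wraps the Generator from deploy.py behind the PreTrainedModel interface, so a forward pass maps a batch of latent vectors directly to images. A minimal sketch of building the model from a config and sampling it; the package name is a placeholder, since the relative imports above require these modules to live inside a package:

import torch

# placeholder package name so the relative imports in LightweightGANModel.py resolve
from lightweight_gan_repo.LightweightGANConfig import LightweightGANConfig
from lightweight_gan_repo.LightweightGANModel import LightweightGANModel

config = LightweightGANConfig(image_size=64, latent_dim=256)
model = LightweightGANModel(config).eval()

# Generator.forward reshapes (batch, latent_dim) to (batch, latent_dim, 1, 1)
# and upsamples to (batch, channels, image_size, image_size)
latents = torch.randn(4, config.latent_dim)
with torch.no_grad():
    images = model(latents)
print(images.shape)  # torch.Size([4, 3, 64, 64])

A Generator state dict saved with torch.save could then be attached with model.load_params(path), which forwards to load_state_dict.
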
config.json ADDED
@@ -0,0 +1,24 @@
+ {
+   "antialias": false,
+   "architectures": [
+     "LightweightGANModel"
+   ],
+   "attn_res_layers": [
+     32
+   ],
+   "auto_map": {
+     "AutoConfig": "LightweightGANConfig.LightweightGANConfig",
+     "AutoModel": "LightweightGANModel.LightweightGANModel"
+   },
+   "fmap_inverse_coef": 12,
+   "fmap_max": 512,
+   "freq_chan_attn": false,
+   "greyscale": false,
+   "image_size": 256,
+   "latent_dim": 256,
+   "model_type": "lightweight-gan",
+   "syncbatchnorm": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.31.0",
+   "transparent": true
+ }
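
The auto_map block ties AutoConfig and AutoModel to the two custom classes, so the checkpoint can be loaded with trust_remote_code=True. Note that config.json pins image_size to 256 and transparent to true, so this particular generator emits 4-channel (RGBA) images. A sketch, with a placeholder repository id:

import torch
from transformers import AutoModel

# "user/lightweight-gan" is a placeholder repo id; trust_remote_code lets transformers
# import LightweightGANConfig / LightweightGANModel through the auto_map entries
model = AutoModel.from_pretrained("user/lightweight-gan", trust_remote_code=True).eval()

latents = torch.randn(1, model.config.latent_dim)
with torch.no_grad():
    image = model(latents)  # shape (1, 4, 256, 256) with this config
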
deploy.py ADDED
@@ -0,0 +1,385 @@
+ import math
+ import torch
+ import torch.nn.functional as F
+ from math import log2
+ from torch import nn, einsum
+ from kornia.filters import filter2d
+ from einops import reduce, rearrange, repeat
+
+
+ def exists(val):
+     return val is not None
+
+
+ def is_power_of_two(val):
+     return log2(val).is_integer()
+
+
+ def default(val, d):
+     return val if exists(val) else d
+
+
+ def get_1d_dct(i, freq, L):
+     result = math.cos(math.pi * freq * (i + 0.5) / L) / math.sqrt(L)
+     return result * (1 if freq == 0 else math.sqrt(2))
+
+
+ def get_dct_weights(width, channel, fidx_u, fidx_v):
+     dct_weights = torch.zeros(1, channel, width, width)
+     c_part = channel // len(fidx_u)
+
+     for i, (u_x, v_y) in enumerate(zip(fidx_u, fidx_v)):
+         for x in range(width):
+             for y in range(width):
+                 coor_value = get_1d_dct(x, u_x, width) * get_1d_dct(y, v_y, width)
+                 dct_weights[:, i * c_part : (i + 1) * c_part, x, y] = coor_value
+
+     return dct_weights
+
+
+ class Blur(nn.Module):
+     def __init__(self):
+         super().__init__()
+         f = torch.Tensor([1, 2, 1])
+         self.register_buffer("f", f)
+
+     def forward(self, x):
+         f = self.f
+         f = f[None, None, :] * f[None, :, None]
+         return filter2d(x, f, normalized=True)
+
+
+ class ChanNorm(nn.Module):
+     def __init__(self, dim, eps=1e-5):
+         super().__init__()
+         self.eps = eps
+         self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
+         self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
+
+     def forward(self, x):
+         var = torch.var(x, dim=1, unbiased=False, keepdim=True)
+         mean = torch.mean(x, dim=1, keepdim=True)
+         return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
+
+
+ def Conv2dSame(dim_in, dim_out, kernel_size, bias=True):
+     pad_left = kernel_size // 2
+     pad_right = (pad_left - 1) if (kernel_size % 2) == 0 else pad_left
+
+     return nn.Sequential(
+         nn.ZeroPad2d((pad_left, pad_right, pad_left, pad_right)),
+         nn.Conv2d(dim_in, dim_out, kernel_size, bias=bias),
+     )
+
+
+ class DepthWiseConv2d(nn.Module):
+     def __init__(self, dim_in, dim_out, kernel_size, padding=0, stride=1, bias=True):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv2d(
+                 dim_in,
+                 dim_in,
+                 kernel_size=kernel_size,
+                 padding=padding,
+                 groups=dim_in,
+                 stride=stride,
+                 bias=bias,
+             ),
+             nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=bias),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class FCANet(nn.Module):
+     def __init__(self, *, chan_in, chan_out, reduction=4, width):
+         super().__init__()
+
+         freq_w, freq_h = ([0] * 8), list(
+             range(8)
+         )  # in paper, it seems 16 frequencies was ideal
+         dct_weights = get_dct_weights(
+             width, chan_in, [*freq_w, *freq_h], [*freq_h, *freq_w]
+         )
+         self.register_buffer("dct_weights", dct_weights)
+
+         chan_intermediate = max(3, chan_out // reduction)
+
+         self.net = nn.Sequential(
+             nn.Conv2d(chan_in, chan_intermediate, 1),
+             nn.LeakyReLU(0.1),
+             nn.Conv2d(chan_intermediate, chan_out, 1),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         x = reduce(
+             x * self.dct_weights, "b c (h h1) (w w1) -> b c h1 w1", "sum", h1=1, w1=1
+         )
+         return self.net(x)
+
+
+ class Generator(nn.Module):
+     def __init__(
+         self,
+         *,
+         image_size,
+         latent_dim=256,
+         fmap_max=512,
+         fmap_inverse_coef=12,
+         transparent=False,
+         greyscale=False,
+         attn_res_layers=[],
+         freq_chan_attn=False,
+         syncbatchnorm=False,
+         antialias=False,
+     ):
+         super().__init__()
+         resolution = log2(image_size)
+         assert is_power_of_two(image_size), "image size must be a power of 2"
+
+         # Set the normalization and blur; bind the blur module under a new local
+         # name, since reassigning "Blur" here would shadow the class and raise
+         # UnboundLocalError when antialias=True
+         norm_class = nn.SyncBatchNorm if syncbatchnorm else nn.BatchNorm2d
+         blur_class = Blur if antialias else nn.Identity
+
+         if transparent:
+             init_channel = 4
+         elif greyscale:
+             init_channel = 1
+         else:
+             init_channel = 3
+
+         self.latent_dim = latent_dim
+
+         fmap_max = default(fmap_max, latent_dim)
+
+         self.initial_conv = nn.Sequential(
+             nn.ConvTranspose2d(latent_dim, latent_dim * 2, 4),
+             norm_class(latent_dim * 2),
+             nn.GLU(dim=1),
+         )
+
+         num_layers = int(resolution) - 2
+         features = list(
+             map(lambda n: (n, 2 ** (fmap_inverse_coef - n)), range(2, num_layers + 2))
+         )
+         features = list(map(lambda n: (n[0], min(n[1], fmap_max)), features))
+         features = list(map(lambda n: 3 if n[0] >= 8 else n[1], features))
+         features = [latent_dim, *features]
+
+         in_out_features = list(zip(features[:-1], features[1:]))
+
+         self.res_layers = range(2, num_layers + 2)
+         self.layers = nn.ModuleList([])
+         self.res_to_feature_map = dict(zip(self.res_layers, in_out_features))
+
+         self.sle_map = ((3, 7), (4, 8), (5, 9), (6, 10))
+         self.sle_map = list(
+             filter(lambda t: t[0] <= resolution and t[1] <= resolution, self.sle_map)
+         )
+         self.sle_map = dict(self.sle_map)
+
+         self.num_layers_spatial_res = 1
+
+         for res, (chan_in, chan_out) in zip(self.res_layers, in_out_features):
+             image_width = 2**res
+
+             attn = None
+             if image_width in attn_res_layers:
+                 attn = PreNorm(chan_in, LinearAttention(chan_in))
+
+             sle = None
+             if res in self.sle_map:
+                 residual_layer = self.sle_map[res]
+                 sle_chan_out = self.res_to_feature_map[residual_layer - 1][-1]
+
+                 if freq_chan_attn:
+                     sle = FCANet(
+                         chan_in=chan_out, chan_out=sle_chan_out, width=2 ** (res + 1)
+                     )
+                 else:
+                     sle = GlobalContext(chan_in=chan_out, chan_out=sle_chan_out)
+
+             layer = nn.ModuleList(
+                 [
+                     nn.Sequential(
+                         PixelShuffleUpsample(chan_in),
+                         blur_class(),
+                         Conv2dSame(chan_in, chan_out * 2, 4),
+                         Noise(),
+                         norm_class(chan_out * 2),
+                         nn.GLU(dim=1),
+                     ),
+                     sle,
+                     attn,
+                 ]
+             )
+             self.layers.append(layer)
+
+         self.out_conv = nn.Conv2d(features[-1], init_channel, 3, padding=1)
+
+     def forward(self, x):
+         x = rearrange(x, "b c -> b c () ()")
+         x = self.initial_conv(x)
+         x = F.normalize(x, dim=1)
+
+         residuals = dict()
+
+         for res, (up, sle, attn) in zip(self.res_layers, self.layers):
+             if exists(attn):
+                 x = attn(x) + x
+
+             x = up(x)
+
+             if exists(sle):
+                 out_res = self.sle_map[res]
+                 residual = sle(x)
+                 residuals[out_res] = residual
+
+             next_res = res + 1
+             if next_res in residuals:
+                 x = x * residuals[next_res]
+
+         return self.out_conv(x)
+
+
+ class GlobalContext(nn.Module):
+     def __init__(self, *, chan_in, chan_out):
+         super().__init__()
+         self.to_k = nn.Conv2d(chan_in, 1, 1)
+         chan_intermediate = max(3, chan_out // 2)
+
+         self.net = nn.Sequential(
+             nn.Conv2d(chan_in, chan_intermediate, 1),
+             nn.LeakyReLU(0.1),
+             nn.Conv2d(chan_intermediate, chan_out, 1),
+             nn.Sigmoid(),
+         )
+
+     def forward(self, x):
+         context = self.to_k(x)
+         context = context.flatten(2).softmax(dim=-1)
+         out = einsum("b i n, b c n -> b c i", context, x.flatten(2))
+         out = out.unsqueeze(-1)
+         return self.net(out)
+
+
+ class LinearAttention(nn.Module):
+     def __init__(self, dim, dim_head=64, heads=8, kernel_size=3):
+         super().__init__()
+         self.scale = dim_head**-0.5
+         self.heads = heads
+         self.dim_head = dim_head
+         inner_dim = dim_head * heads
+
+         self.kernel_size = kernel_size
+         self.nonlin = nn.GELU()
+
+         self.to_lin_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
+         self.to_lin_kv = DepthWiseConv2d(dim, inner_dim * 2, 3, padding=1, bias=False)
+
+         self.to_q = nn.Conv2d(dim, inner_dim, 1, bias=False)
+         self.to_kv = nn.Conv2d(dim, inner_dim * 2, 1, bias=False)
+
+         self.to_out = nn.Conv2d(inner_dim * 2, dim, 1)
+
+     def forward(self, fmap):
+         h, x, y = self.heads, *fmap.shape[-2:]
+
+         # linear attention
+
+         lin_q, lin_k, lin_v = (
+             self.to_lin_q(fmap),
+             *self.to_lin_kv(fmap).chunk(2, dim=1),
+         )
+         lin_q, lin_k, lin_v = map(
+             lambda t: rearrange(t, "b (h c) x y -> (b h) (x y) c", h=h),
+             (lin_q, lin_k, lin_v),
+         )
+
+         lin_q = lin_q.softmax(dim=-1)
+         lin_k = lin_k.softmax(dim=-2)
+
+         lin_q = lin_q * self.scale
+
+         context = einsum("b n d, b n e -> b d e", lin_k, lin_v)
+         lin_out = einsum("b n d, b d e -> b n e", lin_q, context)
+         lin_out = rearrange(lin_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
+
+         # conv-like full attention
+
+         q, k, v = (self.to_q(fmap), *self.to_kv(fmap).chunk(2, dim=1))
+         q, k, v = map(
+             lambda t: rearrange(t, "b (h c) x y -> (b h) c x y", h=h), (q, k, v)
+         )
+
+         k = F.unfold(k, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
+         v = F.unfold(v, kernel_size=self.kernel_size, padding=self.kernel_size // 2)
+
+         k, v = map(
+             lambda t: rearrange(t, "b (d j) n -> b n j d", d=self.dim_head), (k, v)
+         )
+
+         q = rearrange(q, "b c ... -> b (...) c") * self.scale
+
+         sim = einsum("b i d, b i j d -> b i j", q, k)
+         sim = sim - sim.amax(dim=-1, keepdim=True).detach()
+
+         attn = sim.softmax(dim=-1)
+
+         full_out = einsum("b i j, b i j d -> b i d", attn, v)
+         full_out = rearrange(full_out, "(b h) (x y) d -> b (h d) x y", h=h, x=x, y=y)
+
+         # add outputs of linear attention + conv like full attention
+
+         lin_out = self.nonlin(lin_out)
+         out = torch.cat((lin_out, full_out), dim=1)
+         return self.to_out(out)
+
+
+ class Noise(nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.weight = nn.Parameter(torch.zeros(1))
+
+     def forward(self, x, noise=None):
+         b, _, h, w, device = *x.shape, x.device
+
+         if not exists(noise):
+             noise = torch.randn(b, 1, h, w, device=device)
+
+         return x + self.weight * noise
+
+
+ class PixelShuffleUpsample(nn.Module):
+     def __init__(self, dim, dim_out=None):
+         super().__init__()
+         dim_out = default(dim_out, dim)
+         conv = nn.Conv2d(dim, dim_out * 4, 1)
+
+         self.net = nn.Sequential(conv, nn.SiLU(), nn.PixelShuffle(2))
+
+         self.init_conv_(conv)
+
+     def init_conv_(self, conv):
+         o, i, h, w = conv.weight.shape
+         conv_weight = torch.empty(o // 4, i, h, w)
+         nn.init.kaiming_uniform_(conv_weight)
+         conv_weight = repeat(conv_weight, "o ... -> (o 4) ...")
+
+         conv.weight.data.copy_(conv_weight)
+         nn.init.zeros_(conv.bias.data)
+
+     def forward(self, x):
+         return self.net(x)
+
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.fn = fn
+         self.norm = ChanNorm(dim)
+
+     def forward(self, x):
+         return self.fn(self.norm(x))
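
Generator maps a latent vector to a 4x4 feature map via initial_conv, then each stage doubles the spatial resolution with PixelShuffleUpsample; GlobalContext (or FCANet when freq_chan_attn is set) provides skip-layer excitation gates between early and late stages, and LinearAttention is applied at the feature-map widths listed in attn_res_layers. A minimal sketch (not part of the commit) of driving the generator directly with the hyperparameters from config.json:

import torch
from math import log2

from deploy import Generator

# mirror config.json: 256x256 RGBA output with attention at the 32x32 stage
gen = Generator(
    image_size=256, latent_dim=256, transparent=True, attn_res_layers=[32]
).eval()

# log2(256) - 2 = 6 upsampling stages take the initial 4x4 map up to 256x256;
# sle_map pairs early stages with later stages for skip-layer excitation
print(len(gen.layers), int(log2(256)) - 2)  # 6 6
print(gen.sle_map)                          # {3: 7, 4: 8}

with torch.no_grad():
    out = gen(torch.randn(2, 256))
print(out.shape)  # torch.Size([2, 4, 256, 256])
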
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:4cb74d9d864a2aa6256ee59c7c1e8efb6ac2f0c73d61ac987e97a28f212ff09e
+ size 96248639