ArantxaCasanova committed on
Commit a00ee36
1 Parent(s): 03f1c62

First model version

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
Files changed (50)
  1. BigGAN_PyTorch/BigGAN.py +711 -0
  2. BigGAN_PyTorch/BigGANdeep.py +734 -0
  3. BigGAN_PyTorch/LICENSE +21 -0
  4. BigGAN_PyTorch/README.md +144 -0
  5. BigGAN_PyTorch/TFHub/README.md +14 -0
  6. BigGAN_PyTorch/TFHub/biggan_v1.py +441 -0
  7. BigGAN_PyTorch/TFHub/converter.py +558 -0
  8. BigGAN_PyTorch/animal_hash.py +2652 -0
  9. BigGAN_PyTorch/config_files/COCO_Stuff/BigGAN/unconditional_biggan_res128.json +44 -0
  10. BigGAN_PyTorch/config_files/COCO_Stuff/BigGAN/unconditional_biggan_res256.json +44 -0
  11. BigGAN_PyTorch/config_files/COCO_Stuff/IC-GAN/icgan_res128_ddp.json +51 -0
  12. BigGAN_PyTorch/config_files/COCO_Stuff/IC-GAN/icgan_res256_ddp.json +51 -0
  13. BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res128.json +48 -0
  14. BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res256.json +48 -0
  15. BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res64.json +48 -0
  16. BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res128.json +56 -0
  17. BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res256.json +56 -0
  18. BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res64.json +56 -0
  19. BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res128.json +40 -0
  20. BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res256_half_cap.json +40 -0
  21. BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res64.json +40 -0
  22. BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res128.json +48 -0
  23. BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res256.json +47 -0
  24. BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res256_halfcap.json +47 -0
  25. BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res64.json +48 -0
  26. BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res128.json +48 -0
  27. BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res256.json +47 -0
  28. BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res256_halfcap.json +48 -0
  29. BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res64.json +48 -0
  30. BigGAN_PyTorch/diffaugment_utils.py +119 -0
  31. BigGAN_PyTorch/imagenet_lt/ImageNet_LT_train.txt +0 -0
  32. BigGAN_PyTorch/imagenet_lt/ImageNet_LT_val.txt +0 -0
  33. BigGAN_PyTorch/imgs/D Singular Values.png +0 -0
  34. BigGAN_PyTorch/imgs/DeepSamples.png +0 -0
  35. BigGAN_PyTorch/imgs/DogBall.png +0 -0
  36. BigGAN_PyTorch/imgs/G Singular Values.png +0 -0
  37. BigGAN_PyTorch/imgs/IS_FID.png +0 -0
  38. BigGAN_PyTorch/imgs/Losses.png +0 -0
  39. BigGAN_PyTorch/imgs/header_image.jpg +0 -0
  40. BigGAN_PyTorch/imgs/interp_sample.jpg +0 -0
  41. BigGAN_PyTorch/layers.py +616 -0
  42. BigGAN_PyTorch/logs/BigGAN_ch96_bs256x8.jsonl +68 -0
  43. BigGAN_PyTorch/logs/compare_IS.m +97 -0
  44. BigGAN_PyTorch/logs/metalog.txt +3 -0
  45. BigGAN_PyTorch/logs/process_inception_log.m +27 -0
  46. BigGAN_PyTorch/logs/process_training.m +117 -0
  47. BigGAN_PyTorch/losses.py +43 -0
  48. BigGAN_PyTorch/make_hdf5.py +193 -0
  49. BigGAN_PyTorch/run.py +75 -0
  50. BigGAN_PyTorch/scripts/launch_BigGAN_bs256x8.sh +26 -0
BigGAN_PyTorch/BigGAN.py ADDED
@@ -0,0 +1,711 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License

import numpy as np
import math
import functools
import os

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F

# from torch.nn import Parameter as P
import sys

sys.path.insert(1, os.path.join(sys.path[0], ".."))
import BigGAN_PyTorch.layers as layers

# from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d
from BigGAN_PyTorch.diffaugment_utils import DiffAugment

# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.
def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[512] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2, 1]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1, 1]],
        "upsample": [True] * 7,
        "resolution": [8, 16, 32, 64, 128, 256, 512],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 10)
        },
    }
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }

    return arch
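
# For example, G_arch(ch=64, attention="64")[128] gives five generator stages
# with in_channels [1024, 1024, 512, 256, 128], out_channels [1024, 512, 256,
# 128, 64], and attention enabled only at the 64x64 stage.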


class Generator(nn.Module):
    def __init__(
        self,
        G_ch=64,
        dim_z=128,
        bottom_width=4,
        resolution=128,
        G_kernel_size=3,
        G_attn="64",
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        G_shared=True,
        shared_dim=0,
        hier=False,
        cross_replica=False,
        mybn=False,
        G_activation=nn.ReLU(inplace=False),
        G_lr=5e-5,
        G_B1=0.0,
        G_B2=0.999,
        adam_eps=1e-8,
        BN_eps=1e-5,
        SN_eps=1e-12,
        G_mixed_precision=False,
        G_fp16=False,
        G_init="ortho",
        skip_init=False,
        no_optim=False,
        G_param="SN",
        norm_style="bn",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        G_shared_feat=True,
        shared_dim_feat=2048,
        **kwargs
    ):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Use embeddings for instance features?
        self.G_shared_feat = G_shared_feat
        self.shared_dim_feat = shared_dim_feat
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # If using hierarchical latents, adjust z
        if self.hier:
            # Number of places z slots into
            self.num_slots = len(self.arch["in_channels"]) + 1
            self.z_chunk_size = self.dim_z // self.num_slots
            # Recalculate latent dimensionality for even splitting into chunks
            self.dim_z = self.z_chunk_size * self.num_slots
        else:
            self.num_slots = 1
            self.z_chunk_size = 0

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        if not class_cond and not instance_cond:
            input_sz_bn = self.n_classes
        else:
            input_sz_bn = self.z_chunk_size
            if class_cond:
                input_sz_bn += self.shared_dim
            if instance_cond:
                input_sz_bn += self.shared_dim_feat
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=input_sz_bn,
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else layers.identity()
        )
        self.shared_feat = (
            self.which_linear(2048, self.shared_dim_feat)
            if G_shared_feat
            else layers.identity()
        )
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z // self.num_slots,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    layers.GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            else None
                        ),
                    )
                ]
            ]

            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in G at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            layers.bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 3),
        )

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim or not embedded_optimizer:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print("Param count for G's initialized parameters: %d" % self.param_count)

    # Get conditionings

    def get_condition_embeddings(self, cl=None, feat=None):
        c_embed = []
        if cl is not None:
            c_embed.append(self.shared(cl))
        if feat is not None:
            c_embed.append(self.shared_feat(feat))
        if len(c_embed) > 0:
            c_embed = torch.cat(c_embed, dim=-1)
        return c_embed
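
    # Note: with both class_cond and instance_cond active, the returned vector
    # is [shared(cl) | shared_feat(feat)], i.e. shared_dim + shared_dim_feat
    # wide, which (plus a z chunk when hier is set) matches the input_size
    # handed to self.which_bn above.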

    # Note on this forward function: we pass in a y vector which has
    # already been passed through G.shared to enable easy class-wise
    # interpolation later. If we passed in the one-hot and then ran it through
    # G.shared in this forward function, it would be harder to handle.
    def forward(self, z, label=None, feats=None):
        y = self.get_condition_embeddings(label, feats)
        # If hierarchical, concatenate zs and ys
        if self.hier:
            zs = torch.split(z, self.z_chunk_size, 1)
            z = zs[0]
            ys = [torch.cat([y, item], 1) for item in zs[1:]]
        else:
            ys = [y] * len(self.blocks)

        # First linear layer
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)

        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, ys[index])

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[256] = {
        "in_channels": [3] + [ch * item for item in [1, 2, 4, 8, 8, 16]],
        "out_channels": [item * ch for item in [1, 2, 4, 8, 8, 16, 16]],
        "downsample": [True] * 6 + [False],
        "resolution": [128, 64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[128] = {
        "in_channels": [3] + [ch * item for item in [1, 2, 4, 8, 16]],
        "out_channels": [item * ch for item in [1, 2, 4, 8, 16, 16]],
        "downsample": [True] * 5 + [False],
        "resolution": [64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[64] = {
        "in_channels": [3] + [ch * item for item in [1, 2, 4, 8]],
        "out_channels": [item * ch for item in [1, 2, 4, 8, 16]],
        "downsample": [True] * 4 + [False],
        "resolution": [32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 7)
        },
    }
    arch[32] = {
        "in_channels": [3] + [item * ch for item in [4, 4, 4]],
        "out_channels": [item * ch for item in [4, 4, 4, 4]],
        "downsample": [True, True, False, False],
        "resolution": [16, 16, 16, 16],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 6)
        },
    }
    return arch


class Discriminator(nn.Module):
    def __init__(
        self,
        D_ch=64,
        D_wide=True,
        resolution=128,
        D_kernel_size=3,
        D_attn="64",
        n_classes=1000,
        num_D_SVs=1,
        num_D_SV_itrs=1,
        D_activation=nn.ReLU(inplace=False),
        D_lr=2e-4,
        D_B1=0.0,
        D_B2=0.999,
        adam_eps=1e-8,
        SN_eps=1e-12,
        output_dim=1,
        D_mixed_precision=False,
        D_fp16=False,
        D_init="ortho",
        skip_init=False,
        D_param="SN",
        class_cond=True,
        embedded_optimizer=True,
        instance_cond=False,
        instance_sz=2048,
        **kwargs
    ):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_embedding = functools.partial(
                layers.SNEmbedding,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
        # Prepare model
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    layers.DBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=(index > 0),
                        downsample=(
                            nn.AvgPool2d(2) if self.arch["downsample"][index] else None
                        ),
                    )
                ]
            ]
            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in D at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch["out_channels"][-1], output_dim)
        # Embedding for projection discrimination
        if class_cond and instance_cond:
            self.linear_feat = self.which_linear(
                instance_sz, self.arch["out_channels"][-1] // 2
            )
            self.embed = self.which_embedding(
                self.n_classes, self.arch["out_channels"][-1] // 2
            )
        elif class_cond:
            # Embedding for projection discrimination
            self.embed = self.which_embedding(
                self.n_classes, self.arch["out_channels"][-1]
            )
        elif instance_cond:
            self.linear_feat = self.which_linear(
                instance_sz, self.arch["out_channels"][-1]
            )

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        if embedded_optimizer:
            self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
            if D_mixed_precision:
                print("Using fp16 adam in D...")
                import utils

                self.optim = utils.Adam16(
                    params=self.parameters(),
                    lr=self.lr,
                    betas=(self.B1, self.B2),
                    weight_decay=0,
                    eps=self.adam_eps,
                )
            else:
                self.optim = optim.Adam(
                    params=self.parameters(),
                    lr=self.lr,
                    betas=(self.B1, self.B2),
                    weight_decay=0,
                    eps=self.adam_eps,
                )
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None, feat=None):
        # Stick x into h for cleaner for loops without flow control
        h = x
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
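        # Projection discrimination (Miyato & Koyama, 2018): each conditional
        # branch below adds an inner product between the pooled features h and
        # a learned embedding of the condition to the unconditional logit.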
        # Condition on both class and instance features
        if y is not None and feat is not None:
            out = out + torch.sum(
                torch.cat([self.embed(y), self.linear_feat(feat)], dim=-1) * h,
                1,
                keepdim=True,
            )
        # Condition on class only
        elif y is not None:
            # Get projection of final featureset onto class vectors and add to evidence
            out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        # Condition on instance features only
        elif feat is not None:
            out = out + torch.sum(self.linear_feat(feat) * h, 1, keepdim=True)
        return out


# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    def __init__(self, G, D, optimizer_G=None, optimizer_D=None):
        super(G_D, self).__init__()
        self.G = G
        self.D = D
        self.optimizer_G = optimizer_G
        self.optimizer_D = optimizer_D

    def forward(
        self,
        z,
        gy,
        feats_g=None,
        x=None,
        dy=None,
        feats=None,
        train_G=False,
        return_G_z=False,
        split_D=False,
        policy=False,
        DA=False,
    ):
        # If training G, enable grad tape
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise
            G_z = self.G(z, gy, feats_g)
            # Cast as necessary
            # if self.G.fp16 and not self.D.fp16:
            #     G_z = G_z.float()
            # if self.D.fp16 and not self.G.fp16:
            #     G_z = G_z.half()
            # Split_D means to run D once with real data and once with fake,
            # rather than concatenating along the batch dimension.
            if split_D:
                D_fake = self.D(G_z, gy, feats_g)
                if x is not None:
                    D_real = self.D(x, dy, feats)
                    return D_fake, D_real
                else:
                    if return_G_z:
                        return D_fake, G_z
                    else:
                        return D_fake
            # If real data is provided, concatenate it with the Generator's output
            # along the batch dimension for improved efficiency.
            else:
                D_input = torch.cat([G_z, x], 0) if x is not None else G_z
                D_class = torch.cat([gy, dy], 0) if dy is not None else gy
                if feats_g is not None:
                    D_feats = (
                        torch.cat([feats_g, feats], 0) if feats is not None else feats_g
                    )
                else:
                    D_feats = None
                if DA:
                    D_input = DiffAugment(D_input, policy=policy)
                # Get Discriminator output
                D_out = self.D(D_input, D_class, D_feats)
                if x is not None:
                    return torch.split(D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
                else:
                    if return_G_z:
                        return D_out, G_z
                    else:
                        return D_out
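
# Usage sketch (illustrative, not part of the committed file): building a
# 64x64 generator conditioned on a class label plus a 2048-d instance feature
# vector, as in IC-GAN:
#
#   G = Generator(resolution=64, class_cond=True, instance_cond=True, no_optim=True)
#   z = torch.randn(4, G.dim_z)
#   y = torch.randint(0, 1000, (4,))
#   feats = torch.randn(4, 2048)  # stand-in for real instance features
#   imgs = G(z, label=y, feats=feats)  # -> (4, 3, 64, 64) in [-1, 1]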
BigGAN_PyTorch/BigGANdeep.py ADDED
@@ -0,0 +1,734 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# All contributions by Andy Brock:
# Copyright (c) 2019 Andy Brock
#
# MIT License
import numpy as np
import math
import functools

import torch
import torch.nn as nn
from torch.nn import init
import torch.optim as optim
import torch.nn.functional as F
from torch.nn import Parameter as P

import layers
from sync_batchnorm import SynchronizedBatchNorm2d as SyncBatchNorm2d

# BigGAN-deep: uses a different resblock and pattern


# Architectures for G
# Attention is passed in in the format '32_64' to mean applying an attention
# block at both resolution 32x32 and 64x64. Just '64' will apply at 64x64.

# Channel ratio is the ratio of a block's input channels to its bottleneck
# (hidden) channels.
class GBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        which_conv=nn.Conv2d,
        which_bn=layers.bn,
        activation=None,
        upsample=None,
        channel_ratio=4,
    ):
        super(GBlock, self).__init__()

        self.in_channels, self.out_channels = in_channels, out_channels
        self.hidden_channels = self.in_channels // channel_ratio
        self.which_conv, self.which_bn = which_conv, which_bn
        self.activation = activation
        # Conv layers
        self.conv1 = self.which_conv(
            self.in_channels, self.hidden_channels, kernel_size=1, padding=0
        )
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(
            self.hidden_channels, self.out_channels, kernel_size=1, padding=0
        )
        # Batchnorm layers
        self.bn1 = self.which_bn(self.in_channels)
        self.bn2 = self.which_bn(self.hidden_channels)
        self.bn3 = self.which_bn(self.hidden_channels)
        self.bn4 = self.which_bn(self.hidden_channels)
        # upsample layers
        self.upsample = upsample

    def forward(self, x, y):
        # Project down to channel ratio
        h = self.conv1(self.activation(self.bn1(x, y)))
        # Apply next BN-ReLU
        h = self.activation(self.bn2(h, y))
        # Drop channels in x if necessary
        if self.in_channels != self.out_channels:
            x = x[:, : self.out_channels]
        # Upsample both h and x at this point
        if self.upsample:
            h = self.upsample(h)
            x = self.upsample(x)
        # 3x3 convs
        h = self.conv2(h)
        h = self.conv3(self.activation(self.bn3(h, y)))
        # Final 1x1 conv
        h = self.conv4(self.activation(self.bn4(h, y)))
        return h + x
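
    # Note the shortcut above: when out_channels < in_channels, x is truncated
    # to its first out_channels channels rather than projected with a 1x1 conv;
    # this is the BigGAN-deep variant of the skip connection.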


def G_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[256] = {
        "in_channels": [ch * item for item in [16, 16, 8, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 8, 4, 2, 1]],
        "upsample": [True] * 6,
        "resolution": [8, 16, 32, 64, 128, 256],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 9)
        },
    }
    arch[128] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4, 2]],
        "out_channels": [ch * item for item in [16, 8, 4, 2, 1]],
        "upsample": [True] * 5,
        "resolution": [8, 16, 32, 64, 128],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 8)
        },
    }
    arch[64] = {
        "in_channels": [ch * item for item in [16, 16, 8, 4]],
        "out_channels": [ch * item for item in [16, 8, 4, 2]],
        "upsample": [True] * 4,
        "resolution": [8, 16, 32, 64],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 7)
        },
    }
    arch[32] = {
        "in_channels": [ch * item for item in [4, 4, 4]],
        "out_channels": [ch * item for item in [4, 4, 4]],
        "upsample": [True] * 3,
        "resolution": [8, 16, 32],
        "attention": {
            2 ** i: (2 ** i in [int(item) for item in attention.split("_")])
            for i in range(3, 6)
        },
    }

    return arch


class Generator(nn.Module):
    def __init__(
        self,
        G_ch=64,
        G_depth=2,
        dim_z=128,
        bottom_width=4,
        resolution=128,
        G_kernel_size=3,
        G_attn="64",
        n_classes=1000,
        num_G_SVs=1,
        num_G_SV_itrs=1,
        G_shared=True,
        shared_dim=0,
        hier=False,
        cross_replica=False,
        mybn=False,
        G_activation=nn.ReLU(inplace=False),
        G_lr=5e-5,
        G_B1=0.0,
        G_B2=0.999,
        adam_eps=1e-8,
        BN_eps=1e-5,
        SN_eps=1e-12,
        G_mixed_precision=False,
        G_fp16=False,
        G_init="ortho",
        skip_init=False,
        no_optim=False,
        G_param="SN",
        norm_style="bn",
        **kwargs
    ):
        super(Generator, self).__init__()
        # Channel width multiplier
        self.ch = G_ch
        # Number of resblocks per stage
        self.G_depth = G_depth
        # Dimensionality of the latent space
        self.dim_z = dim_z
        # The initial spatial dimensions
        self.bottom_width = bottom_width
        # Resolution of the output
        self.resolution = resolution
        # Kernel size?
        self.kernel_size = G_kernel_size
        # Attention?
        self.attention = G_attn
        # number of classes, for use in categorical conditional generation
        self.n_classes = n_classes
        # Use shared embeddings?
        self.G_shared = G_shared
        # Dimensionality of the shared embedding? Unused if not using G_shared
        self.shared_dim = shared_dim if shared_dim > 0 else dim_z
        # Hierarchical latent space?
        self.hier = hier
        # Cross replica batchnorm?
        self.cross_replica = cross_replica
        # Use my batchnorm?
        self.mybn = mybn
        # nonlinearity for residual blocks
        self.activation = G_activation
        # Initialization style
        self.init = G_init
        # Parameterization style
        self.G_param = G_param
        # Normalization style
        self.norm_style = norm_style
        # Epsilon for BatchNorm?
        self.BN_eps = BN_eps
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # fp16?
        self.fp16 = G_fp16
        # Architecture dict
        self.arch = G_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        if self.G_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_G_SVs,
                num_itrs=num_G_SV_itrs,
                eps=self.SN_eps,
            )
        else:
            self.which_conv = functools.partial(nn.Conv2d, kernel_size=3, padding=1)
            self.which_linear = nn.Linear

        # We use a non-spectral-normed embedding here regardless;
        # For some reason applying SN to G's embedding seems to randomly cripple G
        self.which_embedding = nn.Embedding
        bn_linear = (
            functools.partial(self.which_linear, bias=False)
            if self.G_shared
            else self.which_embedding
        )
        self.which_bn = functools.partial(
            layers.ccbn,
            which_linear=bn_linear,
            cross_replica=self.cross_replica,
            mybn=self.mybn,
            input_size=(
                self.shared_dim + self.dim_z if self.G_shared else self.n_classes
            ),
            norm_style=self.norm_style,
            eps=self.BN_eps,
        )

        # Prepare model
        # If not using shared embeddings, self.shared is just a passthrough
        self.shared = (
            self.which_embedding(n_classes, self.shared_dim)
            if G_shared
            else layers.identity()
        )
        # First linear layer
        self.linear = self.which_linear(
            self.dim_z + self.shared_dim,
            self.arch["in_channels"][0] * (self.bottom_width ** 2),
        )

        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        # while the inner loop is over a given block
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    GBlock(
                        in_channels=self.arch["in_channels"][index],
                        out_channels=self.arch["in_channels"][index]
                        if g_index == 0
                        else self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        which_bn=self.which_bn,
                        activation=self.activation,
                        upsample=(
                            functools.partial(F.interpolate, scale_factor=2)
                            if self.arch["upsample"][index]
                            and g_index == (self.G_depth - 1)
                            else None
                        ),
                    )
                ]
                for g_index in range(self.G_depth)
            ]

            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in G at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]

        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])

        # output layer: batchnorm-relu-conv.
        # Consider using a non-spectral conv here
        self.output_layer = nn.Sequential(
            layers.bn(
                self.arch["out_channels"][-1],
                cross_replica=self.cross_replica,
                mybn=self.mybn,
            ),
            self.activation,
            self.which_conv(self.arch["out_channels"][-1], 3),
        )

        # Initialize weights. Optionally skip init for testing.
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        # If this is an EMA copy, no need for an optim, so just return now
        if no_optim:
            return
        self.lr, self.B1, self.B2, self.adam_eps = G_lr, G_B1, G_B2, adam_eps
        if G_mixed_precision:
            print("Using fp16 adam in G...")
            import utils

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )

        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print("Param count for G's initialized parameters: %d" % self.param_count)

    # Note on this forward function: we pass in a y vector which has
    # already been passed through G.shared to enable easy class-wise
    # interpolation later. If we passed in the one-hot and then ran it through
    # G.shared in this forward function, it would be harder to handle.
    # NOTE: The z vs y dichotomy here is for compatibility with not-y
    def forward(self, z, y):
        # If hierarchical, concatenate zs and ys
        if self.hier:
            z = torch.cat([y, z], 1)
            y = z
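            # In BigGAN-deep the full concatenated [y | z] vector conditions
            # every block; there is no per-block chunking of z as in the base
            # BigGAN above.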
        # First linear layer
        h = self.linear(z)
        # Reshape
        h = h.view(h.size(0), -1, self.bottom_width, self.bottom_width)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            # Second inner loop in case block has multiple layers
            for block in blocklist:
                h = block(h, y)

        # Apply batchnorm-relu-conv-tanh at output
        return torch.tanh(self.output_layer(h))


class DBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        which_conv=layers.SNConv2d,
        wide=True,
        preactivation=True,
        activation=None,
        downsample=None,
        channel_ratio=4,
    ):
        super(DBlock, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
        self.hidden_channels = self.out_channels // channel_ratio
        self.which_conv = which_conv
        self.preactivation = preactivation
        self.activation = activation
        self.downsample = downsample

        # Conv layers
        self.conv1 = self.which_conv(
            self.in_channels, self.hidden_channels, kernel_size=1, padding=0
        )
        self.conv2 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv3 = self.which_conv(self.hidden_channels, self.hidden_channels)
        self.conv4 = self.which_conv(
            self.hidden_channels, self.out_channels, kernel_size=1, padding=0
        )

        self.learnable_sc = True if (in_channels != out_channels) else False
        if self.learnable_sc:
            self.conv_sc = self.which_conv(
                in_channels, out_channels - in_channels, kernel_size=1, padding=0
            )

    def shortcut(self, x):
        if self.downsample:
            x = self.downsample(x)
        if self.learnable_sc:
            x = torch.cat([x, self.conv_sc(x)], 1)
        return x

    def forward(self, x):
        # 1x1 bottleneck conv
        h = self.conv1(F.relu(x))
        # 3x3 convs
        h = self.conv2(self.activation(h))
        h = self.conv3(self.activation(h))
        # relu before downsample
        h = self.activation(h)
        # downsample
        if self.downsample:
            h = self.downsample(h)
        # final 1x1 conv
        h = self.conv4(h)
        return h + self.shortcut(x)


# Discriminator architecture, same paradigm as G's above
def D_arch(ch=64, attention="64", ksize="333333", dilation="111111"):
    arch = {}
    arch[256] = {
        "in_channels": [item * ch for item in [1, 2, 4, 8, 8, 16]],
        "out_channels": [item * ch for item in [2, 4, 8, 8, 16, 16]],
        "downsample": [True] * 6 + [False],
        "resolution": [128, 64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[128] = {
        "in_channels": [item * ch for item in [1, 2, 4, 8, 16]],
        "out_channels": [item * ch for item in [2, 4, 8, 16, 16]],
        "downsample": [True] * 5 + [False],
        "resolution": [64, 32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 8)
        },
    }
    arch[64] = {
        "in_channels": [item * ch for item in [1, 2, 4, 8]],
        "out_channels": [item * ch for item in [2, 4, 8, 16]],
        "downsample": [True] * 4 + [False],
        "resolution": [32, 16, 8, 4, 4],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 7)
        },
    }
    arch[32] = {
        "in_channels": [item * ch for item in [4, 4, 4]],
        "out_channels": [item * ch for item in [4, 4, 4]],
        "downsample": [True, True, False, False],
        "resolution": [16, 16, 16, 16],
        "attention": {
            2 ** i: 2 ** i in [int(item) for item in attention.split("_")]
            for i in range(2, 6)
        },
    }
    return arch


class Discriminator(nn.Module):
    def __init__(
        self,
        D_ch=64,
        D_wide=True,
        D_depth=2,
        resolution=128,
        D_kernel_size=3,
        D_attn="64",
        n_classes=1000,
        num_D_SVs=1,
        num_D_SV_itrs=1,
        D_activation=nn.ReLU(inplace=False),
        D_lr=2e-4,
        D_B1=0.0,
        D_B2=0.999,
        adam_eps=1e-8,
        SN_eps=1e-12,
        output_dim=1,
        D_mixed_precision=False,
        D_fp16=False,
        D_init="ortho",
        skip_init=False,
        D_param="SN",
        **kwargs
    ):
        super(Discriminator, self).__init__()
        # Width multiplier
        self.ch = D_ch
        # Use Wide D as in BigGAN and SA-GAN or skinny D as in SN-GAN?
        self.D_wide = D_wide
        # How many resblocks per stage?
        self.D_depth = D_depth
        # Resolution
        self.resolution = resolution
        # Kernel size
        self.kernel_size = D_kernel_size
        # Attention?
        self.attention = D_attn
        # Number of classes
        self.n_classes = n_classes
        # Activation
        self.activation = D_activation
        # Initialization style
        self.init = D_init
        # Parameterization style
        self.D_param = D_param
        # Epsilon for Spectral Norm?
        self.SN_eps = SN_eps
        # Fp16?
        self.fp16 = D_fp16
        # Architecture
        self.arch = D_arch(self.ch, self.attention)[resolution]

        # Which convs, batchnorms, and linear layers to use
        # No option to turn off SN in D right now
        if self.D_param == "SN":
            self.which_conv = functools.partial(
                layers.SNConv2d,
                kernel_size=3,
                padding=1,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_linear = functools.partial(
                layers.SNLinear,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )
            self.which_embedding = functools.partial(
                layers.SNEmbedding,
                num_svs=num_D_SVs,
                num_itrs=num_D_SV_itrs,
                eps=self.SN_eps,
            )

        # Prepare model
        # Stem convolution
        self.input_conv = self.which_conv(3, self.arch["in_channels"][0])
        # self.blocks is a doubly-nested list of modules, the outer loop intended
        # to be over blocks at a given resolution (resblocks and/or self-attention)
        self.blocks = []
        for index in range(len(self.arch["out_channels"])):
            self.blocks += [
                [
                    DBlock(
                        in_channels=self.arch["in_channels"][index]
                        if d_index == 0
                        else self.arch["out_channels"][index],
                        out_channels=self.arch["out_channels"][index],
                        which_conv=self.which_conv,
                        wide=self.D_wide,
                        activation=self.activation,
                        preactivation=True,
                        downsample=(
                            nn.AvgPool2d(2)
                            if self.arch["downsample"][index] and d_index == 0
                            else None
                        ),
                    )
                    for d_index in range(self.D_depth)
                ]
            ]
            # If attention on this block, attach it to the end
            if self.arch["attention"][self.arch["resolution"][index]]:
                print(
                    "Adding attention layer in D at resolution %d"
                    % self.arch["resolution"][index]
                )
                self.blocks[-1] += [
                    layers.Attention(self.arch["out_channels"][index], self.which_conv)
                ]
        # Turn self.blocks into a ModuleList so that it's all properly registered.
        self.blocks = nn.ModuleList([nn.ModuleList(block) for block in self.blocks])
        # Linear output layer. The output dimension is typically 1, but may be
        # larger if we're e.g. turning this into a VAE with an inference output
        self.linear = self.which_linear(self.arch["out_channels"][-1], output_dim)
        # Embedding for projection discrimination
        self.embed = self.which_embedding(self.n_classes, self.arch["out_channels"][-1])

        # Initialize weights
        if not skip_init:
            self.init_weights()

        # Set up optimizer
        self.lr, self.B1, self.B2, self.adam_eps = D_lr, D_B1, D_B2, adam_eps
        if D_mixed_precision:
            print("Using fp16 adam in D...")
            import utils

            self.optim = utils.Adam16(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        else:
            self.optim = optim.Adam(
                params=self.parameters(),
                lr=self.lr,
                betas=(self.B1, self.B2),
                weight_decay=0,
                eps=self.adam_eps,
            )
        # LR scheduling, left here for forward compatibility
        # self.lr_sched = {'itr' : 0}# if self.progressive else {}
        # self.j = 0

    # Initialize
    def init_weights(self):
        self.param_count = 0
        for module in self.modules():
            if (
                isinstance(module, nn.Conv2d)
                or isinstance(module, nn.Linear)
                or isinstance(module, nn.Embedding)
            ):
                if self.init == "ortho":
                    init.orthogonal_(module.weight)
                elif self.init == "N02":
                    init.normal_(module.weight, 0, 0.02)
                elif self.init in ["glorot", "xavier"]:
                    init.xavier_uniform_(module.weight)
                else:
                    print("Init style not recognized...")
                self.param_count += sum(
                    [p.data.nelement() for p in module.parameters()]
                )
        print("Param count for D's initialized parameters: %d" % self.param_count)

    def forward(self, x, y=None):
        # Run input conv
        h = self.input_conv(x)
        # Loop over blocks
        for index, blocklist in enumerate(self.blocks):
            for block in blocklist:
                h = block(h)
        # Apply global sum pooling as in SN-GAN
        h = torch.sum(self.activation(h), [2, 3])
        # Get initial class-unconditional output
        out = self.linear(h)
        # Get projection of final featureset onto class vectors and add to evidence
        out = out + torch.sum(self.embed(y) * h, 1, keepdim=True)
        return out


# Parallelized G_D to minimize cross-gpu communication
# Without this, Generator outputs would get all-gathered and then rebroadcast.
class G_D(nn.Module):
    def __init__(self, G, D):
        super(G_D, self).__init__()
        self.G = G
        self.D = D

    def forward(
        self, z, gy, x=None, dy=None, train_G=False, return_G_z=False, split_D=False
    ):
        # If training G, enable grad tape
        with torch.set_grad_enabled(train_G):
            # Get Generator output given noise
            G_z = self.G(z, self.G.shared(gy))
            # Cast as necessary
            if self.G.fp16 and not self.D.fp16:
                G_z = G_z.float()
            if self.D.fp16 and not self.G.fp16:
                G_z = G_z.half()
            # Split_D means to run D once with real data and once with fake,
            # rather than concatenating along the batch dimension.
            if split_D:
                D_fake = self.D(G_z, gy)
                if x is not None:
                    D_real = self.D(x, dy)
                    return D_fake, D_real
                else:
                    if return_G_z:
                        return D_fake, G_z
                    else:
                        return D_fake
            # If real data is provided, concatenate it with the Generator's output
            # along the batch dimension for improved efficiency.
            else:
                D_input = torch.cat([G_z, x], 0) if x is not None else G_z
                D_class = torch.cat([gy, dy], 0) if dy is not None else gy
                # Get Discriminator output
                D_out = self.D(D_input, D_class)
                if x is not None:
                    return torch.split(D_out, [G_z.shape[0], x.shape[0]])  # D_fake, D_real
                else:
                    if return_G_z:
                        return D_out, G_z
                    else:
                        return D_out
BigGAN_PyTorch/LICENSE ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Andy Brock

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
BigGAN_PyTorch/README.md ADDED
@@ -0,0 +1,144 @@
# BigGAN-PyTorch
The author's officially unofficial PyTorch BigGAN implementation.

![Dogball? Dogball!](imgs/header_image.jpg?raw=true "Dogball? Dogball!")

This repo contains code for 4-8 GPU training of BigGANs from [Large Scale GAN Training for High Fidelity Natural Image Synthesis](https://arxiv.org/abs/1809.11096) by Andrew Brock, Jeff Donahue, and Karen Simonyan.

This code is by Andy Brock and Alex Andonian.

## How To Use This Code
You will need:

- [PyTorch](https://PyTorch.org/), version 1.0.1
- tqdm, numpy, scipy, and h5py
- The ImageNet training set

First, you may optionally prepare a pre-processed HDF5 version of your target dataset for faster I/O. Following this (or not), you'll need the Inception moments needed to calculate FID. These can both be done by modifying and running

```sh
sh scripts/utils/prepare_data.sh
```

which by default assumes your ImageNet training set is downloaded into the root folder `data` in this directory, and will prepare the cached HDF5 at 128x128 pixel resolution.

In the scripts folder, there are multiple bash scripts which will train BigGANs with different batch sizes. This code assumes you do not have access to a full TPU pod, and accordingly
spoofs mega-batches by using gradient accumulation (averaging grads over multiple minibatches, and only taking an optimizer step after N accumulations). By default, the `launch_BigGAN_bs256x8.sh` script trains a
full-sized BigGAN model with a batch size of 256 and 8 gradient accumulations, for a total batch size of 2048. On 8xV100 with full-precision training (no Tensor cores), this script takes 15 days to train to 150k iterations.
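
Schematically, gradient accumulation just sums gradients over several minibatches before taking each optimizer step. A minimal PyTorch sketch of the idea (toy model and data, not the repo's actual training loop):

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
num_accumulations = 8  # e.g. bs256 x 8 accumulations -> effective batch of 2048

opt.zero_grad()
for _ in range(num_accumulations):
    x, y = torch.randn(256, 10), torch.randn(256, 1)  # one minibatch
    loss = nn.functional.mse_loss(model(x), y) / num_accumulations
    loss.backward()  # gradients sum into .grad across the window
opt.step()  # one optimizer step per effective mega-batch
```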
29
+
30
+ You will first need to figure out the maximum batch size your setup can support. The pre-trained models provided here were trained on 8xV100 (16GB VRAM each) which can support slightly more than the BS256 used by default.
31
+ Once you've determined this, you should modify the script so that the batch size times the number of gradient accumulations is equal to your desired total batch size (BigGAN defaults to 2048).
32
+
33
+ Note also that this script uses the `--load_in_mem` arg, which loads the entire (~64GB) I128.hdf5 file into RAM for faster data loading. If you don't have enough RAM to support this (probably 96GB+), remove this argument.
34
+
35
+
36
+ ## Metrics and Sampling
37
+ ![I believe I can fly!](imgs/interp_sample.jpg?raw=true "I believe I can fly!")
38
+
39
+ During training, this script will output logs with training metrics and test metrics, will save multiple copies (2 most recent and 5 highest-scoring) of the model weights/optimizer params, and will produce samples and interpolations every time it saves weights.
40
+ The logs folder contains scripts to process these logs and plot the results using MATLAB (sorry not sorry).
41
+
42
+ After training, one can use `sample.py` to produce additional samples and interpolations, test with different truncation values, batch sizes, number of standing stat accumulations, etc. See the `sample_BigGAN_bs256x8.sh` script for an example.
43
+
44
+ By default, everything is saved to weights/samples/logs/data folders which are assumed to be in the same folder as this repo.
45
+ You can point all of these to a different base folder using the `--base_root` argument, or pick specific locations for each of these with their respective arguments (e.g. `--logs_root`).
46
+
47
+ We include scripts to run BigGAN-deep, but we have not fully trained a model using them, so consider them untested. Additionally, we include scripts to run a model on CIFAR, and to run SA-GAN (with EMA) and SN-GAN on ImageNet. The SA-GAN code assumes you have 4xTitanX (or equivalent in terms of GPU RAM) and will run with a batch size of 128 and 2 gradient accumulations.
48
+
49
+ ## An Important Note on Inception Metrics
50
+ This repo uses the PyTorch in-built inception network to calculate IS and FID.
51
+ These scores are different from the scores you would get using the official TF inception code, and are only for monitoring purposes!
52
+ Run `sample.py` on your model with the `--sample_npz` argument, then run `inception_tf13` to calculate the actual TensorFlow IS. Note that you will need TensorFlow 1.3 or earlier installed, as TF 1.4+ breaks the original IS code.
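+
+ For reference, IS is exp(E_x[KL(p(y|x) || p(y))]); a minimal NumPy version given a matrix of classifier probabilities (the official code additionally averages this over 10 splits):
+
+ ```python
+ import numpy as np
+
+ def inception_score(probs, eps=1e-16):
+     # probs: (N, num_classes) softmax outputs from an Inception-style classifier.
+     p_y = probs.mean(axis=0, keepdims=True)       # marginal p(y)
+     kl = (probs * (np.log(probs + eps) - np.log(p_y + eps))).sum(axis=1)
+     return float(np.exp(kl.mean()))               # exp(E_x KL(p(y|x) || p(y)))
+ ```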
53
+
54
+ ## Pretrained models
55
+ ![PyTorch Inception Score and FID](imgs/IS_FID.png)
56
+ We include two pretrained model checkpoints (with G, D, the EMA copy of G, the optimizers, and the state dict):
57
+ - The main checkpoint is for a BigGAN trained on ImageNet at 128x128, using BS256 and 8 gradient accumulations, taken just before collapse, with a TF Inception Score of 97.35 +/- 1.79: [LINK](https://drive.google.com/open?id=1nAle7FCVFZdix2--ks0r5JBkFnKw8ctW)
58
+ - An earlier checkpoint of the first model (100k G iters), at high performance but well before collapse, which may be easier to fine-tune: [LINK](https://drive.google.com/open?id=1dmZrcVJUAWkPBGza_XgswSuT-UODXZcO)
59
+
60
+
61
+
62
+ Pretrained models for Places-365 coming soon.
63
+
64
+ This repo also contains scripts for porting the original TFHub BigGAN Generator weights to PyTorch. See the scripts in the TFHub folder for more details.
65
+
66
+ ## Fine-tuning, Using Your Own Dataset, or Making New Training Functions
67
+ ![That's deep, man](imgs/DeepSamples.png?raw=true "Deep Samples")
68
+
69
+ If you wish to resume interrupted training or fine-tune a pre-trained model, run the same launch script but with the `--resume` argument added.
70
+ Experiment names are automatically generated from the configuration, but can be overridden using the `--experiment_name` arg (for example, if you wish to fine-tune a model using modified optimizer settings).
71
+
72
+ To prep your own dataset, you will need to add it to datasets.py and modify the convenience dicts in utils.py (dset_dict, imsize_dict, root_dict, nclass_dict, classes_per_sheet_dict) to have the appropriate metadata for your dataset.
73
+ Repeat the process in prepare_data.sh (optionally produce an HDF5 preprocessed copy, and calculate the Inception Moments for FID).
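+
+ For example, registering a hypothetical 10-class, 64x64 dataset might look roughly like this (the dict names are from `utils.py`; the dataset name and values are placeholders for illustration):
+
+ ```python
+ # Hypothetical entries in utils.py for a new dataset called 'MyData'.
+ dset_dict.update({'MyData': dset.MyData})      # loader class added in datasets.py
+ imsize_dict.update({'MyData': 64})             # native resolution
+ root_dict.update({'MyData': 'my_data'})        # folder name under the data root
+ nclass_dict.update({'MyData': 10})             # number of classes
+ classes_per_sheet_dict.update({'MyData': 10})  # classes per sample sheet
+ ```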
74
+
75
+ By default, the training script will save the top 5 best checkpoints as measured by Inception Score.
76
+ For datasets other than ImageNet, Inception Score can be a very poor measure of quality, so you will likely want to use `--which_best FID` instead.
77
+
78
+ To use your own training function (e.g. train a BigVAE): either modify train_fns.GAN_training_function or add a new train fn and add it after the `if config['which_train_fn'] == 'GAN':` line in `train.py`.
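+
+ Schematically, a new training mode is just another branch of that dispatch; a sketch, where `BigVAE_training_function` is a hypothetical function you would write yourself:
+
+ ```python
+ # Illustrative dispatch only -- the real branch lives in train.py.
+ if config['which_train_fn'] == 'GAN':
+     train = train_fns.GAN_training_function(G, D, GD, z_, y_, ema, state_dict, config)
+ elif config['which_train_fn'] == 'BigVAE':  # hypothetical new mode
+     train = BigVAE_training_function(...)   # your new train fn goes here
+ else:
+     train = train_fns.dummy_training_function()
+ ```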
79
+
80
+
81
+ ## Neat Stuff
82
+ - We include the full training and metrics logs [here](https://drive.google.com/open?id=1ZhY9Mg2b_S4QwxNmt57aXJ9FOC3ZN1qb) for reference. I've found that one of the hardest things about re-implementing a paper can be checking if the logs line up early in training,
83
+ especially if training takes multiple weeks. Hopefully these will be helpful for future work.
84
+ - We include an accelerated FID calculation: the original scipy version can require upwards of 10 minutes to compute the matrix sqrt, whereas this version uses an accelerated PyTorch implementation to compute it in under a second (see the sketch after this list).
85
+ - We include an accelerated, low-memory ortho reg implementation (also sketched after this list).
86
+ - By default, we only compute the top singular value (the spectral norm), but this code supports computing more SVs through the `--num_G_SVs` argument.
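+
+ For the fast FID point above, the usual trick (the approach taken by the linked matrix-sqrt code) is a Newton-Schulz iteration; a minimal single-matrix sketch, assuming a PSD input:
+
+ ```python
+ import torch
+
+ def sqrt_newton_schulz(A, num_iters=50):
+     # Approximate the square root of a PSD matrix A via Newton-Schulz iteration.
+     dim = A.size(0)
+     norm_a = A.norm()
+     Y = A / norm_a                      # normalize so the iteration converges
+     I = torch.eye(dim, dtype=A.dtype, device=A.device)
+     Z = torch.eye(dim, dtype=A.dtype, device=A.device)
+     for _ in range(num_iters):
+         T = 0.5 * (3.0 * I - Z.mm(Y))
+         Y = Y.mm(T)
+         Z = T.mm(Z)
+     return Y * norm_a.sqrt()            # un-normalize
+ ```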
87
+
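+ And the direct-gradient ortho reg works roughly like this (a sketch of the idea, not this repo's exact code): instead of adding a loss term, the scaled gradient of the off-diagonal Gram penalty is written straight onto each weight's `.grad`:
+
+ ```python
+ import torch
+
+ def ortho_reg_grads(model, strength=1e-4, blacklist=()):
+     # Add the gradient of the off-diagonal Gram penalty ||(W W^T) * (1 - I)||
+     # directly onto each weight's grad, skipping 1-D params and blacklisted ones.
+     with torch.no_grad():
+         for param in model.parameters():
+             if len(param.shape) < 2 or any(param is p for p in blacklist):
+                 continue
+             w = param.view(param.shape[0], -1)
+             mask = 1.0 - torch.eye(w.shape[0], device=w.device)
+             grad = 2 * torch.mm(torch.mm(w, w.t()) * mask, w)
+             param.grad += strength * grad.view(param.shape)
+ ```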
88
+ ## Key Differences Between This Code And The Original BigGAN
89
+ - We use the optimizer settings from SA-GAN (G_lr=1e-4, D_lr=4e-4, num_D_steps=1, as opposed to BigGAN's G_lr=5e-5, D_lr=2e-4, num_D_steps=2).
90
+ While slightly less performant, this was the first corner we cut to bring training times down.
91
+ - By default, we do not use Cross-Replica BatchNorm (AKA Synced BatchNorm).
92
+ The two variants we tried (a custom, naive one and the one included in this repo) have slightly different gradients (albeit identical forward passes) from the built-in BatchNorm, which appear to be sufficient to cripple training.
93
+ - Gradient accumulation means that we update the SV estimates and the BN statistics 8 times more frequently. This means that the BN stats are much closer to standing stats, and that the singular value estimates tend to be more accurate.
94
+ Because of this, we measure metrics by default with G in test mode (using the BatchNorm running stat estimates instead of computing standing stats as in the paper). We do still support standing stats (see the sample.sh scripts).
95
+ This could also conceivably result in gradients from the earlier accumulations being stale, but in practice this does not appear to be a problem.
96
+ - The currently provided pretrained models were not trained with orthogonal regularization. Training without ortho reg seems to increase the probability that models will not be amenable to truncation,
97
+ but it looks like this particular model got a winning ticket. Regardless, we provide two highly optimized (fast and minimal memory consumption) ortho reg implementations which directly compute the ortho reg. gradients.
98
+
99
+ ## A Note On The Design Of This Repo
100
+ This code is designed from the ground up to serve as an extensible, hackable base for further research code.
101
+ We've put a lot of thought into making sure the abstractions are the *right* thickness for research--not so thick as to be impenetrable, but not so thin as to be useless.
102
+ The key idea is that if you want to experiment with a SOTA setup and make some modification (try out your own new loss function, architecture, self-attention block, etc) you should be able to easily do so just by dropping your code in one or two places, without having to worry about the rest of the codebase.
103
+ Things like the use of self.which_conv and functools.partial in the BigGAN.py model definition were put together with this in mind, as was the design of the Spectral Norm class inheritance.
104
+
105
+ With that said, this is a somewhat large codebase for a single project. While we tried to be thorough with the comments, if there's something you think could be more clear, better written, or better refactored, please feel free to raise an issue or a pull request.
106
+
107
+ ## Feature Requests
108
+ Want to work on or improve this code? There are a couple of things this repo would benefit from, but which don't yet work.
109
+
110
+ - Synchronized BatchNorm (AKA Cross-Replica BatchNorm). We tried out two variants of this, but for some unknown reason it crippled training each time.
111
+ We have not tried the [apex](https://github.com/NVIDIA/apex) SyncBN as my school's servers are on ancient NVIDIA drivers that don't support it--apex would probably be a good place to start.
112
+ - Mixed precision training and making use of Tensor cores. This repo includes a naive mixed-precision Adam implementation which works early in training but leads to early collapse, and doesn't do anything to activate Tensor cores (it just reduces memory consumption).
113
+ As above, integrating [apex](https://github.com/NVIDIA/apex) into this code and employing its mixed-precision training techniques to take advantage of Tensor cores and reduce memory consumption could yield substantial speed gains.
114
+
115
+ ## Misc Notes
116
+ See [this gist](https://gist.github.com/yrevar/942d3a0ac09ec9e5eb3a) for ImageNet labels.
117
+
118
+ If you use this code, please cite
119
+ ```text
120
+ @inproceedings{
121
+ brock2018large,
122
+ title={Large Scale {GAN} Training for High Fidelity Natural Image Synthesis},
123
+ author={Andrew Brock and Jeff Donahue and Karen Simonyan},
124
+ booktitle={International Conference on Learning Representations},
125
+ year={2019},
126
+ url={https://openreview.net/forum?id=B1xsqj09Fm},
127
+ }
128
+ ```
129
+
130
+ ## Acknowledgments
131
+ Thanks to Google for the generous cloud credit donations.
132
+
133
+ [SyncBN](https://github.com/vacancy/Synchronized-BatchNorm-PyTorch) by Jiayuan Mao and Tete Xiao.
134
+
135
+ [Progress bar](https://github.com/Lasagne/Recipes/tree/master/papers/densenet) originally from Jan Schlüter.
136
+
137
+ Test metrics logger from [VoxNet](https://github.com/dimatura/voxnet).
138
+
139
+ PyTorch [implementation of cov](https://discuss.PyTorch.org/t/covariance-and-gradient-support/16217/2) from Modar M. Alfadly.
140
+
141
+ PyTorch [fast Matrix Sqrt](https://github.com/msubhransu/matrix-sqrt) for FID from Tsung-Yu Lin and Subhransu Maji.
142
+
143
+ TensorFlow Inception Score code from [OpenAI's Improved-GAN](https://github.com/openai/improved-gan).
144
+
BigGAN_PyTorch/TFHub/README.md ADDED
@@ -0,0 +1,14 @@
1
+ # BigGAN-PyTorch TFHub converter
2
+ This dir contains scripts for taking the [pre-trained generator weights from TFHub](https://tfhub.dev/s?q=biggan) and porting them to BigGAN-PyTorch.
3
+
4
+ In addition to the base libraries for BigGAN-PyTorch, to run this code you will need:
5
+
6
+ - TensorFlow
7
+ - TFHub
8
+ - parse
9
+
10
+ Note that this code is presently only set up to run the ported models without truncation--you'll need to accumulate standing stats at each truncation level yourself if you wish to employ truncation (a sketch follows).
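+
+ Accumulating standing stats amounts to re-estimating the BatchNorm statistics in train mode at your chosen truncation; a rough sketch, where `sample_z` and `sample_y` are placeholder samplers drawing at the truncation you intend to sample with:
+
+ ```python
+ import torch
+
+ def accumulate_standing_stats(G, sample_z, sample_y, num_accumulations=16):
+     # Re-estimate BN running stats at a fixed truncation level.
+     G.train()
+     for module in G.modules():
+         if isinstance(module, torch.nn.BatchNorm2d):
+             module.reset_running_stats()
+             module.momentum = None        # cumulative moving average
+     with torch.no_grad():
+         for _ in range(num_accumulations):
+             G(sample_z(), sample_y())     # BN stats update even under no_grad
+     G.eval()
+ ```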
11
+
12
+ To port the 128x128 model from TFHub, produce a pretrained weights .pth file, and generate samples using all your GPUs, run
13
+
14
+ `python converter.py -r 128 --generate_samples --parallel`
BigGAN_PyTorch/TFHub/biggan_v1.py ADDED
@@ -0,0 +1,441 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # All contributions by Andy Brock:
5
+ # Copyright (c) 2019 Andy Brock
6
+ #
7
+ # MIT License
8
+ #
9
+ # BigGAN V1:
10
+ # This is now deprecated code used for porting the TFHub modules to pytorch,
11
+ # included here for reference only.
12
+ import numpy as np
13
+ import torch
14
+ from scipy.stats import truncnorm
15
+ from torch import nn
16
+ from torch.nn import Parameter
17
+ from torch.nn import functional as F
18
+
19
+
20
+ def l2normalize(v, eps=1e-4):
21
+ return v / (v.norm() + eps)
22
+
23
+
24
+ def truncated_z_sample(batch_size, z_dim, truncation=0.5, seed=None):
25
+ state = None if seed is None else np.random.RandomState(seed)
26
+ values = truncnorm.rvs(-2, 2, size=(batch_size, z_dim), random_state=state)
27
+ return truncation * values
28
+
29
+
30
+ def denorm(x):
31
+ out = (x + 1) / 2
32
+ return out.clamp_(0, 1)
33
+
34
+
35
+ class SpectralNorm(nn.Module):
36
+ def __init__(self, module, name="weight", power_iterations=1):
37
+ super(SpectralNorm, self).__init__()
38
+ self.module = module
39
+ self.name = name
40
+ self.power_iterations = power_iterations
41
+ if not self._made_params():
42
+ self._make_params()
43
+
44
+ def _update_u_v(self):
45
+ u = getattr(self.module, self.name + "_u")
46
+ v = getattr(self.module, self.name + "_v")
47
+ w = getattr(self.module, self.name + "_bar")
48
+
49
+ height = w.data.shape[0]
50
+ _w = w.view(height, -1)
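+ # Power iteration: alternately project u and v through the reshaped weight
+ # matrix _w to estimate its top singular value sigma.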
51
+ for _ in range(self.power_iterations):
52
+ v = l2normalize(torch.matmul(_w.t(), u))
53
+ u = l2normalize(torch.matmul(_w, v))
54
+
55
+ sigma = u.dot((_w).mv(v))
56
+ setattr(self.module, self.name, w / sigma.expand_as(w))
57
+
58
+ def _made_params(self):
59
+ try:
60
+ getattr(self.module, self.name + "_u")
61
+ getattr(self.module, self.name + "_v")
62
+ getattr(self.module, self.name + "_bar")
63
+ return True
64
+ except AttributeError:
65
+ return False
66
+
67
+ def _make_params(self):
68
+ w = getattr(self.module, self.name)
69
+
70
+ height = w.data.shape[0]
71
+ width = w.view(height, -1).data.shape[1]
72
+
73
+ u = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
74
+ v = Parameter(w.data.new(height).normal_(0, 1), requires_grad=False)
75
+ u.data = l2normalize(u.data)
76
+ v.data = l2normalize(v.data)
77
+ w_bar = Parameter(w.data)
78
+
79
+ del self.module._parameters[self.name]
80
+ self.module.register_parameter(self.name + "_u", u)
81
+ self.module.register_parameter(self.name + "_v", v)
82
+ self.module.register_parameter(self.name + "_bar", w_bar)
83
+
84
+ def forward(self, *args):
85
+ self._update_u_v()
86
+ return self.module.forward(*args)
87
+
88
+
89
+ class SelfAttention(nn.Module):
90
+ """ Self Attention Layer"""
91
+
92
+ def __init__(self, in_dim, activation=F.relu):
93
+ super().__init__()
94
+ self.channel_in = in_dim
95
+ self.activation = activation
96
+
97
+ self.theta = SpectralNorm(
98
+ nn.Conv2d(
99
+ in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False
100
+ )
101
+ )
102
+ self.phi = SpectralNorm(
103
+ nn.Conv2d(
104
+ in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1, bias=False
105
+ )
106
+ )
107
+ self.pool = nn.MaxPool2d(2, 2)
108
+ self.g = SpectralNorm(
109
+ nn.Conv2d(
110
+ in_channels=in_dim, out_channels=in_dim // 2, kernel_size=1, bias=False
111
+ )
112
+ )
113
+ self.o_conv = SpectralNorm(
114
+ nn.Conv2d(
115
+ in_channels=in_dim // 2, out_channels=in_dim, kernel_size=1, bias=False
116
+ )
117
+ )
118
+ self.gamma = nn.Parameter(torch.zeros(1))
119
+
120
+ self.softmax = nn.Softmax(dim=-1)
121
+
122
+ def forward(self, x):
123
+ m_batchsize, C, width, height = x.size()
124
+ N = height * width
125
+
126
+ theta = self.theta(x)
127
+ phi = self.phi(x)
128
+ phi = self.pool(phi)
129
+ phi = phi.view(m_batchsize, -1, N // 4)
130
+ theta = theta.view(m_batchsize, -1, N)
131
+ theta = theta.permute(0, 2, 1)
132
+ attention = self.softmax(torch.bmm(theta, phi))
133
+ g = self.pool(self.g(x)).view(m_batchsize, -1, N // 4)
134
+ attn_g = torch.bmm(g, attention.permute(0, 2, 1)).view(
135
+ m_batchsize, -1, width, height
136
+ )
137
+ out = self.o_conv(attn_g)
138
+ return self.gamma * out + x
139
+
140
+
141
+ class ConditionalBatchNorm2d(nn.Module):
142
+ def __init__(self, num_features, num_classes, eps=1e-4, momentum=0.1):
143
+ super().__init__()
144
+ self.num_features = num_features
145
+ self.bn = nn.BatchNorm2d(num_features, affine=False, eps=eps, momentum=momentum)
146
+ self.gamma_embed = SpectralNorm(
147
+ nn.Linear(num_classes, num_features, bias=False)
148
+ )
149
+ self.beta_embed = SpectralNorm(nn.Linear(num_classes, num_features, bias=False))
150
+
151
+ def forward(self, x, y):
152
+ out = self.bn(x)
153
+ gamma = self.gamma_embed(y) + 1
154
+ beta = self.beta_embed(y)
155
+ out = gamma.view(-1, self.num_features, 1, 1) * out + beta.view(
156
+ -1, self.num_features, 1, 1
157
+ )
158
+ return out
159
+
160
+
161
+ class GBlock(nn.Module):
162
+ def __init__(
163
+ self,
164
+ in_channel,
165
+ out_channel,
166
+ kernel_size=[3, 3],
167
+ padding=1,
168
+ stride=1,
169
+ n_class=None,
170
+ bn=True,
171
+ activation=F.relu,
172
+ upsample=True,
173
+ downsample=False,
174
+ z_dim=148,
175
+ ):
176
+ super().__init__()
177
+
178
+ self.conv0 = SpectralNorm(
179
+ nn.Conv2d(
180
+ in_channel,
181
+ out_channel,
182
+ kernel_size,
183
+ stride,
184
+ padding,
185
+ bias=True,
186
+ )
187
+ )
188
+ self.conv1 = SpectralNorm(
189
+ nn.Conv2d(
190
+ out_channel,
191
+ out_channel,
192
+ kernel_size,
193
+ stride,
194
+ padding,
195
+ bias=True,
196
+ )
197
+ )
198
+
199
+ self.skip_proj = False
200
+ if in_channel != out_channel or upsample or downsample:
201
+ self.conv_sc = SpectralNorm(nn.Conv2d(in_channel, out_channel, 1, 1, 0))
202
+ self.skip_proj = True
203
+
204
+ self.upsample = upsample
205
+ self.downsample = downsample
206
+ self.activation = activation
207
+ self.bn = bn
208
+ if bn:
209
+ self.HyperBN = ConditionalBatchNorm2d(in_channel, z_dim)
210
+ self.HyperBN_1 = ConditionalBatchNorm2d(out_channel, z_dim)
211
+
212
+ def forward(self, input, condition=None):
213
+ out = input
214
+
215
+ if self.bn:
216
+ out = self.HyperBN(out, condition)
217
+ out = self.activation(out)
218
+ if self.upsample:
219
+ out = F.interpolate(out, scale_factor=2)
220
+ out = self.conv0(out)
221
+ if self.bn:
222
+ out = self.HyperBN_1(out, condition)
223
+ out = self.activation(out)
224
+ out = self.conv1(out)
225
+
226
+ if self.downsample:
227
+ out = F.avg_pool2d(out, 2)
228
+
229
+ if self.skip_proj:
230
+ skip = input
231
+ if self.upsample:
232
+ skip = F.interpolate(skip, scale_factor=2)
233
+ skip = self.conv_sc(skip)
234
+ if self.downsample:
235
+ skip = F.avg_pool2d(skip, 2)
236
+ else:
237
+ skip = input
238
+ return out + skip
239
+
240
+
241
+ class Generator128(nn.Module):
242
+ def __init__(self, code_dim=120, n_class=1000, chn=96, debug=False):
243
+ super().__init__()
244
+
245
+ self.linear = nn.Linear(n_class, 128, bias=False)
246
+
247
+ if debug:
248
+ chn = 8
249
+
250
+ self.first_view = 16 * chn
251
+
252
+ self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))
253
+
254
+ z_dim = code_dim + 28
255
+
256
+ self.GBlock = nn.ModuleList(
257
+ [
258
+ GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
259
+ GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
260
+ GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
261
+ GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
262
+ GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
263
+ ]
264
+ )
265
+
266
+ self.sa_id = 4
267
+ self.num_split = len(self.GBlock) + 1
268
+ self.attention = SelfAttention(2 * chn)
269
+ self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
270
+ self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
271
+
272
+ def forward(self, input, class_id):
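+ # Hierarchical latent: z is chunked into num_split pieces; chunk 0 feeds the
+ # first linear layer, and each remaining chunk is concatenated with the class
+ # embedding to condition the corresponding GBlock's BatchNorm.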
273
+ codes = torch.chunk(input, self.num_split, 1)
274
+ class_emb = self.linear(class_id) # 128
275
+
276
+ out = self.G_linear(codes[0])
277
+ out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
278
+ for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
279
+ if i == self.sa_id:
280
+ out = self.attention(out)
281
+ condition = torch.cat([code, class_emb], 1)
282
+ out = GBlock(out, condition)
283
+
284
+ out = self.ScaledCrossReplicaBN(out)
285
+ out = F.relu(out)
286
+ out = self.colorize(out)
287
+ return torch.tanh(out)
288
+
289
+
290
+ class Generator256(nn.Module):
291
+ def __init__(self, code_dim=140, n_class=1000, chn=96, debug=False):
292
+ super().__init__()
293
+
294
+ self.linear = nn.Linear(n_class, 128, bias=False)
295
+
296
+ if debug:
297
+ chn = 8
298
+
299
+ self.first_view = 16 * chn
300
+
301
+ self.G_linear = SpectralNorm(nn.Linear(20, 4 * 4 * 16 * chn))
302
+
303
+ self.GBlock = nn.ModuleList(
304
+ [
305
+ GBlock(16 * chn, 16 * chn, n_class=n_class),
306
+ GBlock(16 * chn, 8 * chn, n_class=n_class),
307
+ GBlock(8 * chn, 8 * chn, n_class=n_class),
308
+ GBlock(8 * chn, 4 * chn, n_class=n_class),
309
+ GBlock(4 * chn, 2 * chn, n_class=n_class),
310
+ GBlock(2 * chn, 1 * chn, n_class=n_class),
311
+ ]
312
+ )
313
+
314
+ self.sa_id = 5
315
+ self.num_split = len(self.GBlock) + 1
316
+ self.attention = SelfAttention(2 * chn)
317
+ self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn, eps=1e-4)
318
+ self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
319
+
320
+ def forward(self, input, class_id):
321
+ codes = torch.chunk(input, self.num_split, 1)
322
+ class_emb = self.linear(class_id) # 128
323
+
324
+ out = self.G_linear(codes[0])
325
+ out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
326
+ for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
327
+ if i == self.sa_id:
328
+ out = self.attention(out)
329
+ condition = torch.cat([code, class_emb], 1)
330
+ out = GBlock(out, condition)
331
+
332
+ out = self.ScaledCrossReplicaBN(out)
333
+ out = F.relu(out)
334
+ out = self.colorize(out)
335
+ return torch.tanh(out)
336
+
337
+
338
+ class Generator512(nn.Module):
339
+ def __init__(self, code_dim=128, n_class=1000, chn=96, debug=False):
340
+ super().__init__()
341
+
342
+ self.linear = nn.Linear(n_class, 128, bias=False)
343
+
344
+ if debug:
345
+ chn = 8
346
+
347
+ self.first_view = 16 * chn
348
+
349
+ self.G_linear = SpectralNorm(nn.Linear(16, 4 * 4 * 16 * chn))
350
+
351
+ z_dim = code_dim + 16
352
+
353
+ self.GBlock = nn.ModuleList(
354
+ [
355
+ GBlock(16 * chn, 16 * chn, n_class=n_class, z_dim=z_dim),
356
+ GBlock(16 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
357
+ GBlock(8 * chn, 8 * chn, n_class=n_class, z_dim=z_dim),
358
+ GBlock(8 * chn, 4 * chn, n_class=n_class, z_dim=z_dim),
359
+ GBlock(4 * chn, 2 * chn, n_class=n_class, z_dim=z_dim),
360
+ GBlock(2 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
361
+ GBlock(1 * chn, 1 * chn, n_class=n_class, z_dim=z_dim),
362
+ ]
363
+ )
364
+
365
+ self.sa_id = 4
366
+ self.num_split = len(self.GBlock) + 1
367
+ self.attention = SelfAttention(4 * chn)
368
+ self.ScaledCrossReplicaBN = nn.BatchNorm2d(1 * chn)
369
+ self.colorize = SpectralNorm(nn.Conv2d(1 * chn, 3, [3, 3], padding=1))
370
+
371
+ def forward(self, input, class_id):
372
+ codes = torch.chunk(input, self.num_split, 1)
373
+ class_emb = self.linear(class_id) # 128
374
+
375
+ out = self.G_linear(codes[0])
376
+ out = out.view(-1, 4, 4, self.first_view).permute(0, 3, 1, 2)
377
+ for i, (code, GBlock) in enumerate(zip(codes[1:], self.GBlock)):
378
+ if i == self.sa_id:
379
+ out = self.attention(out)
380
+ condition = torch.cat([code, class_emb], 1)
381
+ out = GBlock(out, condition)
382
+
383
+ out = self.ScaledCrossReplicaBN(out)
384
+ out = F.relu(out)
385
+ out = self.colorize(out)
386
+ return torch.tanh(out)
387
+
388
+
389
+ class Discriminator(nn.Module):
390
+ def __init__(self, n_class=1000, chn=96, debug=False):
391
+ super().__init__()
392
+
393
+ def conv(in_channel, out_channel, downsample=True):
394
+ return GBlock(
395
+ in_channel, out_channel, bn=False, upsample=False, downsample=downsample
396
+ )
397
+
398
+ if debug:
399
+ chn = 8
400
+ self.debug = debug
401
+
402
+ self.pre_conv = nn.Sequential(
403
+ SpectralNorm(nn.Conv2d(3, 1 * chn, 3, padding=1)),
404
+ nn.ReLU(),
405
+ SpectralNorm(nn.Conv2d(1 * chn, 1 * chn, 3, padding=1)),
406
+ nn.AvgPool2d(2),
407
+ )
408
+ self.pre_skip = SpectralNorm(nn.Conv2d(3, 1 * chn, 1))
409
+
410
+ self.conv = nn.Sequential(
411
+ conv(1 * chn, 1 * chn, downsample=True),
412
+ conv(1 * chn, 2 * chn, downsample=True),
413
+ SelfAttention(2 * chn),
414
+ conv(2 * chn, 2 * chn, downsample=True),
415
+ conv(2 * chn, 4 * chn, downsample=True),
416
+ conv(4 * chn, 8 * chn, downsample=True),
417
+ conv(8 * chn, 8 * chn, downsample=True),
418
+ conv(8 * chn, 16 * chn, downsample=True),
419
+ conv(16 * chn, 16 * chn, downsample=False),
420
+ )
421
+
422
+ self.linear = SpectralNorm(nn.Linear(16 * chn, 1))
423
+
424
+ self.embed = nn.Embedding(n_class, 16 * chn)
425
+ self.embed.weight.data.uniform_(-0.1, 0.1)
426
+ self.embed = SpectralNorm(self.embed)
427
+
428
+ def forward(self, input, class_id):
429
+
430
+ out = self.pre_conv(input)
431
+ out += self.pre_skip(F.avg_pool2d(input, 2))
432
+ out = self.conv(out)
433
+ out = F.relu(out)
434
+ out = out.view(out.size(0), out.size(1), -1)
435
+ out = out.sum(2)
436
+ out_linear = self.linear(out).squeeze(1)
437
+ embed = self.embed(class_id)
438
+
439
+ prod = (out * embed).sum(1)
440
+
441
+ return out_linear + prod
BigGAN_PyTorch/TFHub/converter.py ADDED
@@ -0,0 +1,558 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # All contributions by Andy Brock:
5
+ # Copyright (c) 2019 Andy Brock
6
+ #
7
+ # MIT License
8
+
9
+ """Utilities for converting TFHub BigGAN generator weights to PyTorch.
10
+
11
+ Recommended usage:
12
+
13
+ To convert all BigGAN variants and generate test samples, use:
14
+
15
+ ```bash
16
+ CUDA_VISIBLE_DEVICES=0 python converter.py --generate_samples
17
+ ```
18
+
19
+ See `parse_args` for additional options.
20
+ """
21
+
22
+ import argparse
23
+ import os
24
+ import sys
25
+
26
+ import h5py
27
+ import torch
28
+ import torch.nn as nn
29
+ from torchvision.utils import save_image
30
+ import tensorflow as tf
31
+ import tensorflow_hub as hub
32
+ import parse
33
+
34
+ # import reference biggan from this folder
35
+ import biggan_v1 as biggan_for_conversion
36
+
37
+ # Import model from main folder
38
+ sys.path.append("..")
39
+ import BigGAN
40
+
41
+
42
+ DEVICE = "cuda"
43
+ HDF5_TMPL = "biggan-{}.h5"
44
+ PTH_TMPL = "biggan-{}.pth"
45
+ MODULE_PATH_TMPL = "https://tfhub.dev/deepmind/biggan-{}/2"
46
+ Z_DIMS = {128: 120, 256: 140, 512: 128}
47
+ RESOLUTIONS = list(Z_DIMS)
48
+
49
+
50
+ def dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=False):
51
+ """Loads TFHub weights and saves them to intermediate HDF5 file.
52
+
53
+ Args:
54
+ module_path ([Path-like]): Path to TFHub module.
55
+ hdf5_path ([Path-like]): Path to output HDF5 file.
56
+
57
+ Returns:
58
+ [h5py.File]: Loaded hdf5 file containing module weights.
59
+ """
60
+ if os.path.exists(hdf5_path) and (not redownload):
61
+ print("Loading BigGAN hdf5 file from:", hdf5_path)
62
+ return h5py.File(hdf5_path, "r")
63
+
64
+ print("Loading BigGAN module from:", module_path)
65
+ tf.reset_default_graph()
66
+ hub.Module(module_path)
67
+ print("Loaded BigGAN module from:", module_path)
68
+
69
+ initializer = tf.global_variables_initializer()
70
+ sess = tf.Session()
71
+ sess.run(initializer)
72
+
73
+ print("Saving BigGAN weights to :", hdf5_path)
74
+ h5f = h5py.File(hdf5_path, "w")
75
+ for var in tf.global_variables():
76
+ val = sess.run(var)
77
+ h5f.create_dataset(var.name, data=val)
78
+ print(f"Saving {var.name} with shape {val.shape}")
79
+ h5f.close()
80
+ return h5py.File(hdf5_path, "r")
81
+
82
+
83
+ class TFHub2Pytorch(object):
84
+
85
+ TF_ROOT = "module"
86
+
87
+ NUM_GBLOCK = {128: 5, 256: 6, 512: 7}
88
+
89
+ w = "w"
90
+ b = "b"
91
+ u = "u0"
92
+ v = "u1"
93
+ gamma = "gamma"
94
+ beta = "beta"
95
+
96
+ def __init__(
97
+ self, state_dict, tf_weights, resolution=256, load_ema=True, verbose=False
98
+ ):
99
+ self.state_dict = state_dict
100
+ self.tf_weights = tf_weights
101
+ self.resolution = resolution
102
+ self.verbose = verbose
103
+ if load_ema:
104
+ for name in ["w", "b", "gamma", "beta"]:
105
+ setattr(self, name, getattr(self, name) + "/ema_b999900")
106
+
107
+ def load(self):
108
+ self.load_generator()
109
+ return self.state_dict
110
+
111
+ def load_generator(self):
112
+ GENERATOR_ROOT = os.path.join(self.TF_ROOT, "Generator")
113
+
114
+ for i in range(self.NUM_GBLOCK[self.resolution]):
115
+ name_tf = os.path.join(GENERATOR_ROOT, "GBlock")
116
+ name_tf += f"_{i}" if i != 0 else ""
117
+ self.load_GBlock(f"GBlock.{i}.", name_tf)
118
+
119
+ self.load_attention("attention.", os.path.join(GENERATOR_ROOT, "attention"))
120
+ self.load_linear("linear", os.path.join(self.TF_ROOT, "linear"), bias=False)
121
+ self.load_snlinear("G_linear", os.path.join(GENERATOR_ROOT, "G_Z", "G_linear"))
122
+ self.load_colorize("colorize", os.path.join(GENERATOR_ROOT, "conv_2d"))
123
+ self.load_ScaledCrossReplicaBNs(
124
+ "ScaledCrossReplicaBN", os.path.join(GENERATOR_ROOT, "ScaledCrossReplicaBN")
125
+ )
126
+
127
+ def load_linear(self, name_pth, name_tf, bias=True):
128
+ self.state_dict[name_pth + ".weight"] = self.load_tf_tensor(
129
+ name_tf, self.w
130
+ ).permute(1, 0)
131
+ if bias:
132
+ self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(name_tf, self.b)
133
+
134
+ def load_snlinear(self, name_pth, name_tf, bias=True):
135
+ self.state_dict[name_pth + ".module.weight_u"] = self.load_tf_tensor(
136
+ name_tf, self.u
137
+ ).squeeze()
138
+ self.state_dict[name_pth + ".module.weight_v"] = self.load_tf_tensor(
139
+ name_tf, self.v
140
+ ).squeeze()
141
+ self.state_dict[name_pth + ".module.weight_bar"] = self.load_tf_tensor(
142
+ name_tf, self.w
143
+ ).permute(1, 0)
144
+ if bias:
145
+ self.state_dict[name_pth + ".module.bias"] = self.load_tf_tensor(
146
+ name_tf, self.b
147
+ )
148
+
149
+ def load_colorize(self, name_pth, name_tf):
150
+ self.load_snconv(name_pth, name_tf)
151
+
152
+ def load_GBlock(self, name_pth, name_tf):
153
+ self.load_convs(name_pth, name_tf)
154
+ self.load_HyperBNs(name_pth, name_tf)
155
+
156
+ def load_convs(self, name_pth, name_tf):
157
+ self.load_snconv(name_pth + "conv0", os.path.join(name_tf, "conv0"))
158
+ self.load_snconv(name_pth + "conv1", os.path.join(name_tf, "conv1"))
159
+ self.load_snconv(name_pth + "conv_sc", os.path.join(name_tf, "conv_sc"))
160
+
161
+ def load_snconv(self, name_pth, name_tf, bias=True):
162
+ if self.verbose:
163
+ print(f"loading: {name_pth} from {name_tf}")
164
+ self.state_dict[name_pth + ".module.weight_u"] = self.load_tf_tensor(
165
+ name_tf, self.u
166
+ ).squeeze()
167
+ self.state_dict[name_pth + ".module.weight_v"] = self.load_tf_tensor(
168
+ name_tf, self.v
169
+ ).squeeze()
170
+ self.state_dict[name_pth + ".module.weight_bar"] = self.load_tf_tensor(
171
+ name_tf, self.w
172
+ ).permute(3, 2, 0, 1)
173
+ if bias:
174
+ self.state_dict[name_pth + ".module.bias"] = self.load_tf_tensor(
175
+ name_tf, self.b
176
+ ).squeeze()
177
+
178
+ def load_conv(self, name_pth, name_tf, bias=True):
179
+
180
+ self.state_dict[name_pth + ".weight_u"] = self.load_tf_tensor(
181
+ name_tf, self.u
182
+ ).squeeze()
183
+ self.state_dict[name_pth + ".weight_v"] = self.load_tf_tensor(
184
+ name_tf, self.v
185
+ ).squeeze()
186
+ self.state_dict[name_pth + ".weight_bar"] = self.load_tf_tensor(
187
+ name_tf, self.w
188
+ ).permute(3, 2, 0, 1)
189
+ if bias:
190
+ self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(name_tf, self.b)
191
+
192
+ def load_HyperBNs(self, name_pth, name_tf):
193
+ self.load_HyperBN(name_pth + "HyperBN", os.path.join(name_tf, "HyperBN"))
194
+ self.load_HyperBN(name_pth + "HyperBN_1", os.path.join(name_tf, "HyperBN_1"))
195
+
196
+ def load_ScaledCrossReplicaBNs(self, name_pth, name_tf):
197
+ self.state_dict[name_pth + ".bias"] = self.load_tf_tensor(
198
+ name_tf, self.beta
199
+ ).squeeze()
200
+ self.state_dict[name_pth + ".weight"] = self.load_tf_tensor(
201
+ name_tf, self.gamma
202
+ ).squeeze()
203
+ self.state_dict[name_pth + ".running_mean"] = self.load_tf_tensor(
204
+ name_tf + "bn", "accumulated_mean"
205
+ )
206
+ self.state_dict[name_pth + ".running_var"] = self.load_tf_tensor(
207
+ name_tf + "bn", "accumulated_var"
208
+ )
209
+ self.state_dict[name_pth + ".num_batches_tracked"] = torch.tensor(
210
+ self.tf_weights[os.path.join(name_tf + "bn", "accumulation_counter:0")][()],
211
+ dtype=torch.float32,
212
+ )
213
+
214
+ def load_HyperBN(self, name_pth, name_tf):
215
+ if self.verbose:
216
+ print(f"loading: {name_pth} from {name_tf}")
217
+ beta = name_pth + ".beta_embed.module"
218
+ gamma = name_pth + ".gamma_embed.module"
219
+ self.state_dict[beta + ".weight_u"] = self.load_tf_tensor(
220
+ os.path.join(name_tf, "beta"), self.u
221
+ ).squeeze()
222
+ self.state_dict[gamma + ".weight_u"] = self.load_tf_tensor(
223
+ os.path.join(name_tf, "gamma"), self.u
224
+ ).squeeze()
225
+ self.state_dict[beta + ".weight_v"] = self.load_tf_tensor(
226
+ os.path.join(name_tf, "beta"), self.v
227
+ ).squeeze()
228
+ self.state_dict[gamma + ".weight_v"] = self.load_tf_tensor(
229
+ os.path.join(name_tf, "gamma"), self.v
230
+ ).squeeze()
231
+ self.state_dict[beta + ".weight_bar"] = self.load_tf_tensor(
232
+ os.path.join(name_tf, "beta"), self.w
233
+ ).permute(1, 0)
234
+ self.state_dict[gamma + ".weight_bar"] = self.load_tf_tensor(
235
+ os.path.join(name_tf, "gamma"), self.w
236
+ ).permute(1, 0)
237
+
238
+ cr_bn_name = name_tf.replace("HyperBN", "CrossReplicaBN")
239
+ self.state_dict[name_pth + ".bn.running_mean"] = self.load_tf_tensor(
240
+ cr_bn_name, "accumulated_mean"
241
+ )
242
+ self.state_dict[name_pth + ".bn.running_var"] = self.load_tf_tensor(
243
+ cr_bn_name, "accumulated_var"
244
+ )
245
+ self.state_dict[name_pth + ".bn.num_batches_tracked"] = torch.tensor(
246
+ self.tf_weights[os.path.join(cr_bn_name, "accumulation_counter:0")][()],
247
+ dtype=torch.float32,
248
+ )
249
+
250
+ def load_attention(self, name_pth, name_tf):
251
+
252
+ self.load_snconv(name_pth + "theta", os.path.join(name_tf, "theta"), bias=False)
253
+ self.load_snconv(name_pth + "phi", os.path.join(name_tf, "phi"), bias=False)
254
+ self.load_snconv(name_pth + "g", os.path.join(name_tf, "g"), bias=False)
255
+ self.load_snconv(
256
+ name_pth + "o_conv", os.path.join(name_tf, "o_conv"), bias=False
257
+ )
258
+ self.state_dict[name_pth + "gamma"] = self.load_tf_tensor(name_tf, self.gamma)
259
+
260
+ def load_tf_tensor(self, prefix, var, device="0"):
261
+ name = os.path.join(prefix, var) + f":{device}"
262
+ return torch.from_numpy(self.tf_weights[name][:])
263
+
264
+
265
+ # Convert from v1: this function maps weight names from the v1 (TFHub-ported)
+ # module layout onto this repo's BigGAN layout.
266
+ def convert_from_v1(hub_dict, resolution=128):
267
+ weightname_dict = {"weight_u": "u0", "weight_bar": "weight", "bias": "bias"}
268
+ convnum_dict = {"conv0": "conv1", "conv1": "conv2", "conv_sc": "conv_sc"}
269
+ attention_blocknum = {128: 3, 256: 4, 512: 3}[resolution]
270
+ hub2me = {
271
+ "linear.weight": "shared.weight", # This is actually the shared weight
272
+ # Linear stuff
273
+ "G_linear.module.weight_bar": "linear.weight",
274
+ "G_linear.module.bias": "linear.bias",
275
+ "G_linear.module.weight_u": "linear.u0",
276
+ # output layer stuff
277
+ "ScaledCrossReplicaBN.weight": "output_layer.0.gain",
278
+ "ScaledCrossReplicaBN.bias": "output_layer.0.bias",
279
+ "ScaledCrossReplicaBN.running_mean": "output_layer.0.stored_mean",
280
+ "ScaledCrossReplicaBN.running_var": "output_layer.0.stored_var",
281
+ "colorize.module.weight_bar": "output_layer.2.weight",
282
+ "colorize.module.bias": "output_layer.2.bias",
283
+ "colorize.module.weight_u": "output_layer.2.u0",
284
+ # Attention stuff
285
+ "attention.gamma": "blocks.%d.1.gamma" % attention_blocknum,
286
+ "attention.theta.module.weight_u": "blocks.%d.1.theta.u0" % attention_blocknum,
287
+ "attention.theta.module.weight_bar": "blocks.%d.1.theta.weight"
288
+ % attention_blocknum,
289
+ "attention.phi.module.weight_u": "blocks.%d.1.phi.u0" % attention_blocknum,
290
+ "attention.phi.module.weight_bar": "blocks.%d.1.phi.weight"
291
+ % attention_blocknum,
292
+ "attention.g.module.weight_u": "blocks.%d.1.g.u0" % attention_blocknum,
293
+ "attention.g.module.weight_bar": "blocks.%d.1.g.weight" % attention_blocknum,
294
+ "attention.o_conv.module.weight_u": "blocks.%d.1.o.u0" % attention_blocknum,
295
+ "attention.o_conv.module.weight_bar": "blocks.%d.1.o.weight"
296
+ % attention_blocknum,
297
+ }
298
+
299
+ # Loop over the hub dict and build the hub2me map
300
+ for name in hub_dict.keys():
301
+ if "GBlock" in name:
302
+ if "HyperBN" not in name: # it's a conv
303
+ out = parse.parse("GBlock.{:d}.{}.module.{}", name)
304
+ blocknum, convnum, weightname = out
305
+ if weightname not in weightname_dict:
306
+ continue # else hyperBN in
307
+ out_name = "blocks.%d.0.%s.%s" % (
308
+ blocknum,
309
+ convnum_dict[convnum],
310
+ weightname_dict[weightname],
311
+ ) # Increment conv number by 1
312
+ else: # hyperbn not conv
313
+ BNnum = 2 if "HyperBN_1" in name else 1
314
+ if "embed" in name:
315
+ out = parse.parse("GBlock.{:d}.{}.module.{}", name)
316
+ blocknum, gamma_or_beta, weightname = out
317
+ if weightname not in weightname_dict: # Ignore weight_v
318
+ continue
319
+ out_name = "blocks.%d.0.bn%d.%s.%s" % (
320
+ blocknum,
321
+ BNnum,
322
+ "gain" if "gamma" in gamma_or_beta else "bias",
323
+ weightname_dict[weightname],
324
+ )
325
+ else:
326
+ out = parse.parse("GBlock.{:d}.{}.bn.{}", name)
327
+ blocknum, dummy, mean_or_var = out
328
+ if "num_batches_tracked" in mean_or_var:
329
+ continue
330
+ out_name = "blocks.%d.0.bn%d.%s" % (
331
+ blocknum,
332
+ BNnum,
333
+ "stored_mean" if "mean" in mean_or_var else "stored_var",
334
+ )
335
+ hub2me[name] = out_name
336
+
337
+ # Invert the hub2me map
338
+ me2hub = {hub2me[item]: item for item in hub2me}
339
+ new_dict = {}
340
+ dimz_dict = {128: 20, 256: 20, 512: 16}
341
+ for item in me2hub:
342
+ # Swap input dim ordering on batchnorm bois to account for my arbitrary change of ordering when concatenating Ys and Zs
343
+ if (
344
+ ("bn" in item and "weight" in item)
345
+ and ("gain" in item or "bias" in item)
346
+ and ("output_layer" not in item)
347
+ ):
348
+ new_dict[item] = torch.cat(
349
+ [
350
+ hub_dict[me2hub[item]][:, -128:],
351
+ hub_dict[me2hub[item]][:, : dimz_dict[resolution]],
352
+ ],
353
+ 1,
354
+ )
355
+ # Reshape the first linear weight, bias, and u0
356
+ elif item == "linear.weight":
357
+ new_dict[item] = (
358
+ hub_dict[me2hub[item]]
359
+ .contiguous()
360
+ .view(4, 4, 96 * 16, -1)
361
+ .permute(2, 0, 1, 3)
362
+ .contiguous()
363
+ .view(-1, dimz_dict[resolution])
364
+ )
365
+ elif item == "linear.bias":
366
+ new_dict[item] = (
367
+ hub_dict[me2hub[item]]
368
+ .view(4, 4, 96 * 16)
369
+ .permute(2, 0, 1)
370
+ .contiguous()
371
+ .view(-1)
372
+ )
373
+ elif item == "linear.u0":
374
+ new_dict[item] = (
375
+ hub_dict[me2hub[item]]
376
+ .view(4, 4, 96 * 16)
377
+ .permute(2, 0, 1)
378
+ .contiguous()
379
+ .view(1, -1)
380
+ )
381
+ elif (
382
+ me2hub[item] == "linear.weight"
383
+ ): # THIS IS THE SHARED WEIGHT NOT THE FIRST LINEAR LAYER
384
+ # Transpose shared weight so that it's an embedding
385
+ new_dict[item] = hub_dict[me2hub[item]].t()
386
+ elif "weight_u" in me2hub[item]: # Unsqueeze u0s
387
+ new_dict[item] = hub_dict[me2hub[item]].unsqueeze(0)
388
+ else:
389
+ new_dict[item] = hub_dict[me2hub[item]]
390
+ return new_dict
391
+
392
+
393
+ def get_config(resolution):
394
+ attn_dict = {128: "64", 256: "128", 512: "64"}
395
+ dim_z_dict = {128: 120, 256: 140, 512: 128}
396
+ config = {
397
+ "G_param": "SN",
398
+ "D_param": "SN",
399
+ "G_ch": 96,
400
+ "D_ch": 96,
401
+ "D_wide": True,
402
+ "G_shared": True,
403
+ "shared_dim": 128,
404
+ "dim_z": dim_z_dict[resolution],
405
+ "hier": True,
406
+ "cross_replica": False,
407
+ "mybn": False,
408
+ "G_activation": nn.ReLU(inplace=True),
409
+ "G_attn": attn_dict[resolution],
410
+ "norm_style": "bn",
411
+ "G_init": "ortho",
412
+ "skip_init": True,
413
+ "no_optim": True,
414
+ "G_fp16": False,
415
+ "G_mixed_precision": False,
416
+ "accumulate_stats": False,
417
+ "num_standing_accumulations": 16,
418
+ "G_eval_mode": True,
419
+ "BN_eps": 1e-04,
420
+ "SN_eps": 1e-04,
421
+ "num_G_SVs": 1,
422
+ "num_G_SV_itrs": 1,
423
+ "resolution": resolution,
424
+ "n_classes": 1000,
425
+ }
426
+ return config
427
+
428
+
429
+ def convert_biggan(
430
+ resolution, weight_dir, redownload=False, no_ema=False, verbose=False
431
+ ):
432
+ module_path = MODULE_PATH_TMPL.format(resolution)
433
+ hdf5_path = os.path.join(weight_dir, HDF5_TMPL.format(resolution))
434
+ pth_path = os.path.join(weight_dir, PTH_TMPL.format(resolution))
435
+
436
+ tf_weights = dump_tfhub_to_hdf5(module_path, hdf5_path, redownload=redownload)
437
+ G_temp = getattr(biggan_for_conversion, f"Generator{resolution}")()
438
+ state_dict_temp = G_temp.state_dict()
439
+
440
+ converter = TFHub2Pytorch(
441
+ state_dict_temp,
442
+ tf_weights,
443
+ resolution=resolution,
444
+ load_ema=(not no_ema),
445
+ verbose=verbose,
446
+ )
447
+ state_dict_v1 = converter.load()
448
+ state_dict = convert_from_v1(state_dict_v1, resolution)
449
+ # Get the config, build the model
450
+ config = get_config(resolution)
451
+ G = BigGAN.Generator(**config)
452
+ G.load_state_dict(state_dict, strict=False) # Ignore missing sv0 entries
453
+ torch.save(state_dict, pth_path)
454
+
455
+ # output_location ='pretrained_weights/TFHub-PyTorch-128.pth'
456
+
457
+ return G
458
+
459
+
460
+ def generate_sample(G, z_dim, batch_size, filename, parallel=False):
461
+
462
+ G.eval()
463
+ G.to(DEVICE)
464
+ with torch.no_grad():
465
+ z = torch.randn(batch_size, G.dim_z).to(DEVICE)
466
+ y = torch.randint(
467
+ low=0,
468
+ high=1000,
469
+ size=(batch_size,),
470
+ device=DEVICE,
471
+ dtype=torch.int64,
472
+ requires_grad=False,
473
+ )
474
+ if parallel:
475
+ images = nn.parallel.data_parallel(G, (z, G.shared(y)))
476
+ else:
477
+ images = G(z, G.shared(y))
478
+ save_image(images, filename, scale_each=True, normalize=True)
479
+
480
+
481
+ def parse_args():
482
+ usage = "Parser for conversion script."
483
+ parser = argparse.ArgumentParser(description=usage)
484
+ parser.add_argument(
485
+ "--resolution",
486
+ "-r",
487
+ type=int,
488
+ default=None,
489
+ choices=[128, 256, 512],
490
+ help="Resolution of TFHub module to convert. Converts all resolutions if None.",
491
+ )
492
+ parser.add_argument(
493
+ "--redownload",
494
+ action="store_true",
495
+ default=False,
496
+ help="Redownload weights and overwrite current hdf5 file, if present.",
497
+ )
498
+ parser.add_argument("--weights_dir", type=str, default="pretrained_weights")
499
+ parser.add_argument("--samples_dir", type=str, default="pretrained_samples")
500
+ parser.add_argument(
501
+ "--no_ema", action="store_true", default=False, help="Do not load ema weights."
502
+ )
503
+ parser.add_argument(
504
+ "--verbose", action="store_true", default=False, help="Additionally logging."
505
+ )
506
+ parser.add_argument(
507
+ "--generate_samples",
508
+ action="store_true",
509
+ default=False,
510
+ help="Generate test sample with pretrained model.",
511
+ )
512
+ parser.add_argument(
513
+ "--batch_size", type=int, default=64, help="Batch size used for test sample."
514
+ )
515
+ parser.add_argument(
516
+ "--parallel", action="store_true", default=False, help="Parallelize G?"
517
+ )
518
+ args = parser.parse_args()
519
+ return args
520
+
521
+
522
+ if __name__ == "__main__":
523
+
524
+ args = parse_args()
525
+ os.makedirs(args.weights_dir, exist_ok=True)
526
+ os.makedirs(args.samples_dir, exist_ok=True)
527
+
528
+ if args.resolution is not None:
529
+ G = convert_biggan(
530
+ args.resolution,
531
+ args.weights_dir,
532
+ redownload=args.redownload,
533
+ no_ema=args.no_ema,
534
+ verbose=args.verbose,
535
+ )
536
+ if args.generate_samples:
537
+ filename = os.path.join(
538
+ args.samples_dir, f"biggan{args.resolution}_samples.jpg"
539
+ )
540
+ print("Generating samples...")
541
+ generate_sample(
542
+ G, Z_DIMS[args.resolution], args.batch_size, filename, args.parallel
543
+ )
544
+ else:
545
+ for res in RESOLUTIONS:
546
+ G = convert_biggan(
547
+ res,
548
+ args.weights_dir,
549
+ redownload=args.redownload,
550
+ no_ema=args.no_ema,
551
+ verbose=args.verbose,
552
+ )
553
+ if args.generate_samples:
554
+ filename = os.path.join(args.samples_dir, f"biggan{res}_samples.jpg")
555
+ print("Generating samples...")
556
+ generate_sample(
557
+ G, Z_DIMS[res], args.batch_size, filename, args.parallel
558
+ )
BigGAN_PyTorch/animal_hash.py ADDED
@@ -0,0 +1,2652 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # All contributions by Andy Brock:
5
+ # Copyright (c) 2019 Andy Brock
6
+ #
7
+ # MIT License
8
+
9
+ c = [
10
+ "Aardvark",
11
+ "Abyssinian",
12
+ "Affenpinscher",
13
+ "Akbash",
14
+ "Akita",
15
+ "Albatross",
16
+ "Alligator",
17
+ "Alpaca",
18
+ "Angelfish",
19
+ "Ant",
20
+ "Anteater",
21
+ "Antelope",
22
+ "Ape",
23
+ "Armadillo",
24
+ "Ass",
25
+ "Avocet",
26
+ "Axolotl",
27
+ "Baboon",
28
+ "Badger",
29
+ "Balinese",
30
+ "Bandicoot",
31
+ "Barb",
32
+ "Barnacle",
33
+ "Barracuda",
34
+ "Bat",
35
+ "Beagle",
36
+ "Bear",
37
+ "Beaver",
38
+ "Bee",
39
+ "Beetle",
40
+ "Binturong",
41
+ "Bird",
42
+ "Birman",
43
+ "Bison",
44
+ "Bloodhound",
45
+ "Boar",
46
+ "Bobcat",
47
+ "Bombay",
48
+ "Bongo",
49
+ "Bonobo",
50
+ "Booby",
51
+ "Budgerigar",
52
+ "Buffalo",
53
+ "Bulldog",
54
+ "Bullfrog",
55
+ "Burmese",
56
+ "Butterfly",
57
+ "Caiman",
58
+ "Camel",
59
+ "Capybara",
60
+ "Caracal",
61
+ "Caribou",
62
+ "Cassowary",
63
+ "Cat",
64
+ "Caterpillar",
65
+ "Catfish",
66
+ "Cattle",
67
+ "Centipede",
68
+ "Chameleon",
69
+ "Chamois",
70
+ "Cheetah",
71
+ "Chicken",
72
+ "Chihuahua",
73
+ "Chimpanzee",
74
+ "Chinchilla",
75
+ "Chinook",
76
+ "Chipmunk",
77
+ "Chough",
78
+ "Cichlid",
79
+ "Clam",
80
+ "Coati",
81
+ "Cobra",
82
+ "Cockroach",
83
+ "Cod",
84
+ "Collie",
85
+ "Coral",
86
+ "Cormorant",
87
+ "Cougar",
88
+ "Cow",
89
+ "Coyote",
90
+ "Crab",
91
+ "Crane",
92
+ "Crocodile",
93
+ "Crow",
94
+ "Curlew",
95
+ "Cuscus",
96
+ "Cuttlefish",
97
+ "Dachshund",
98
+ "Dalmatian",
99
+ "Deer",
100
+ "Dhole",
101
+ "Dingo",
102
+ "Dinosaur",
103
+ "Discus",
104
+ "Dodo",
105
+ "Dog",
106
+ "Dogball",
107
+ "Dogfish",
108
+ "Dolphin",
109
+ "Donkey",
110
+ "Dormouse",
111
+ "Dove",
112
+ "Dragonfly",
113
+ "Drever",
114
+ "Duck",
115
+ "Dugong",
116
+ "Dunker",
117
+ "Dunlin",
118
+ "Eagle",
119
+ "Earwig",
120
+ "Echidna",
121
+ "Eel",
122
+ "Eland",
123
+ "Elephant",
124
+ "ElephantSeal",
125
+ "Elk",
126
+ "Emu",
127
+ "Falcon",
128
+ "Ferret",
129
+ "Finch",
130
+ "Fish",
131
+ "Flamingo",
132
+ "Flounder",
133
+ "Fly",
134
+ "Fossa",
135
+ "Fox",
136
+ "Frigatebird",
137
+ "Frog",
138
+ "Galago",
139
+ "Gar",
140
+ "Gaur",
141
+ "Gazelle",
142
+ "Gecko",
143
+ "Gerbil",
144
+ "Gharial",
145
+ "GiantPanda",
146
+ "Gibbon",
147
+ "Giraffe",
148
+ "Gnat",
149
+ "Gnu",
150
+ "Goat",
151
+ "Goldfinch",
152
+ "Goldfish",
153
+ "Goose",
154
+ "Gopher",
155
+ "Gorilla",
156
+ "Goshawk",
157
+ "Grasshopper",
158
+ "Greyhound",
159
+ "Grouse",
160
+ "Guanaco",
161
+ "GuineaFowl",
162
+ "GuineaPig",
163
+ "Gull",
164
+ "Guppy",
165
+ "Hamster",
166
+ "Hare",
167
+ "Harrier",
168
+ "Havanese",
169
+ "Hawk",
170
+ "Hedgehog",
171
+ "Heron",
172
+ "Herring",
173
+ "Himalayan",
174
+ "Hippopotamus",
175
+ "Hornet",
176
+ "Horse",
177
+ "Human",
178
+ "Hummingbird",
179
+ "Hyena",
180
+ "Ibis",
181
+ "Iguana",
182
+ "Impala",
183
+ "Indri",
184
+ "Insect",
185
+ "Jackal",
186
+ "Jaguar",
187
+ "Javanese",
188
+ "Jay",
189
+ "Jellyfish",
190
+ "Kakapo",
191
+ "Kangaroo",
192
+ "Kingfisher",
193
+ "Kiwi",
194
+ "Koala",
195
+ "KomodoDragon",
196
+ "Kouprey",
197
+ "Kudu",
198
+ "Labradoodle",
199
+ "Ladybird",
200
+ "Lapwing",
201
+ "Lark",
202
+ "Lemming",
203
+ "Lemur",
204
+ "Leopard",
205
+ "Liger",
206
+ "Lion",
207
+ "Lionfish",
208
+ "Lizard",
209
+ "Llama",
210
+ "Lobster",
211
+ "Locust",
212
+ "Loris",
213
+ "Louse",
214
+ "Lynx",
215
+ "Lyrebird",
216
+ "Macaw",
217
+ "Magpie",
218
+ "Mallard",
219
+ "Maltese",
220
+ "Manatee",
221
+ "Mandrill",
222
+ "Markhor",
223
+ "Marten",
224
+ "Mastiff",
225
+ "Mayfly",
226
+ "Meerkat",
227
+ "Millipede",
228
+ "Mink",
229
+ "Mole",
230
+ "Molly",
231
+ "Mongoose",
232
+ "Mongrel",
233
+ "Monkey",
234
+ "Moorhen",
235
+ "Moose",
236
+ "Mosquito",
237
+ "Moth",
238
+ "Mouse",
239
+ "Mule",
240
+ "Narwhal",
241
+ "Neanderthal",
242
+ "Newfoundland",
243
+ "Newt",
244
+ "Nightingale",
245
+ "Numbat",
246
+ "Ocelot",
247
+ "Octopus",
248
+ "Okapi",
249
+ "Olm",
250
+ "Opossum",
251
+ "Orang-utan",
252
+ "Oryx",
253
+ "Ostrich",
254
+ "Otter",
255
+ "Owl",
256
+ "Ox",
257
+ "Oyster",
258
+ "Pademelon",
259
+ "Panther",
260
+ "Parrot",
261
+ "Partridge",
262
+ "Peacock",
263
+ "Peafowl",
264
+ "Pekingese",
265
+ "Pelican",
266
+ "Penguin",
267
+ "Persian",
268
+ "Pheasant",
269
+ "Pig",
270
+ "Pigeon",
271
+ "Pika",
272
+ "Pike",
273
+ "Piranha",
274
+ "Platypus",
275
+ "Pointer",
276
+ "Pony",
277
+ "Poodle",
278
+ "Porcupine",
279
+ "Porpoise",
280
+ "Possum",
281
+ "PrairieDog",
282
+ "Prawn",
283
+ "Puffin",
284
+ "Pug",
285
+ "Puma",
286
+ "Quail",
287
+ "Quelea",
288
+ "Quetzal",
289
+ "Quokka",
290
+ "Quoll",
291
+ "Rabbit",
292
+ "Raccoon",
293
+ "Ragdoll",
294
+ "Rail",
295
+ "Ram",
296
+ "Rat",
297
+ "Rattlesnake",
298
+ "Raven",
299
+ "RedDeer",
300
+ "RedPanda",
301
+ "Reindeer",
302
+ "Rhinoceros",
303
+ "Robin",
304
+ "Rook",
305
+ "Rottweiler",
306
+ "Ruff",
307
+ "Salamander",
308
+ "Salmon",
309
+ "SandDollar",
310
+ "Sandpiper",
311
+ "Saola",
312
+ "Sardine",
313
+ "Scorpion",
314
+ "SeaLion",
315
+ "SeaUrchin",
316
+ "Seahorse",
317
+ "Seal",
318
+ "Serval",
319
+ "Shark",
320
+ "Sheep",
321
+ "Shrew",
322
+ "Shrimp",
323
+ "Siamese",
324
+ "Siberian",
325
+ "Skunk",
326
+ "Sloth",
327
+ "Snail",
328
+ "Snake",
329
+ "Snowshoe",
330
+ "Somali",
331
+ "Sparrow",
332
+ "Spider",
333
+ "Sponge",
334
+ "Squid",
335
+ "Squirrel",
336
+ "Starfish",
337
+ "Starling",
338
+ "Stingray",
339
+ "Stinkbug",
340
+ "Stoat",
341
+ "Stork",
342
+ "Swallow",
343
+ "Swan",
344
+ "Tang",
345
+ "Tapir",
346
+ "Tarsier",
347
+ "Termite",
348
+ "Tetra",
349
+ "Tiffany",
350
+ "Tiger",
351
+ "Toad",
352
+ "Tortoise",
353
+ "Toucan",
354
+ "Tropicbird",
355
+ "Trout",
356
+ "Tuatara",
357
+ "Turkey",
358
+ "Turtle",
359
+ "Uakari",
360
+ "Uguisu",
361
+ "Umbrellabird",
362
+ "Viper",
363
+ "Vulture",
364
+ "Wallaby",
365
+ "Walrus",
366
+ "Warthog",
367
+ "Wasp",
368
+ "WaterBuffalo",
369
+ "Weasel",
370
+ "Whale",
371
+ "Whippet",
372
+ "Wildebeest",
373
+ "Wolf",
374
+ "Wolverine",
375
+ "Wombat",
376
+ "Woodcock",
377
+ "Woodlouse",
378
+ "Woodpecker",
379
+ "Worm",
380
+ "Wrasse",
381
+ "Wren",
382
+ "Yak",
383
+ "Zebra",
384
+ "Zebu",
385
+ "Zonkey",
386
+ ]
387
+ a = [
+ "able", "above", "absent", "absolute", "abstract", "abundant", "academic", "acceptable", "accepted", "accessible",
+ "accurate", "accused", "active", "actual", "acute", "added", "additional", "adequate", "adjacent", "administrative",
+ "adorable", "advanced", "adverse", "advisory", "aesthetic", "afraid", "african", "aggregate", "aggressive", "agreeable",
+ "agreed", "agricultural", "alert", "alive", "alleged", "allied", "alone", "alright", "alternative", "amateur",
+ "amazing", "ambitious", "american", "amused", "ancient", "angry", "annoyed", "annual", "anonymous", "anxious",
+ "appalling", "apparent", "applicable", "appropriate", "arab", "arbitrary", "architectural", "armed", "arrogant", "artificial",
+ "artistic", "ashamed", "asian", "asleep", "assistant", "associated", "atomic", "attractive", "australian", "automatic",
+ "autonomous", "available", "average", "awake", "aware", "awful", "awkward", "back", "bad", "balanced",
+ "bare", "basic", "beautiful", "beneficial", "better", "bewildered", "big", "binding", "biological", "bitter",
+ "bizarre", "black", "blank", "blind", "blonde", "bloody", "blue", "blushing", "boiling", "bold",
+ "bored", "boring", "bottom", "brainy", "brave", "breakable", "breezy", "brief", "bright", "brilliant",
+ "british", "broad", "broken", "brown", "bumpy", "burning", "busy", "calm", "canadian", "capable",
+ "capitalist", "careful", "casual", "catholic", "causal", "cautious", "central", "certain", "changing", "characteristic",
+ "charming", "cheap", "cheerful", "chemical", "chief", "chilly", "chinese", "chosen", "christian", "chronic",
+ "chubby", "circular", "civic", "civil", "civilian", "classic", "classical", "clean", "clear", "clever",
+ "clinical", "close", "closed", "cloudy", "clumsy", "coastal", "cognitive", "coherent", "cold", "collective",
+ "colonial", "colorful", "colossal", "coloured", "colourful", "combative", "combined", "comfortable", "coming", "commercial",
+ "common", "communist", "compact", "comparable", "comparative", "compatible", "competent", "competitive", "complete", "complex",
+ "complicated", "comprehensive", "compulsory", "conceptual", "concerned", "concrete", "condemned", "confident", "confidential", "confused",
+ "conscious", "conservation", "conservative", "considerable", "consistent", "constant", "constitutional", "contemporary", "content", "continental",
+ "continued", "continuing", "continuous", "controlled", "controversial", "convenient", "conventional", "convinced", "convincing", "cooing",
+ "cool", "cooperative", "corporate", "correct", "corresponding", "costly", "courageous", "crazy", "creative", "creepy",
+ "criminal", "critical", "crooked", "crowded", "crucial", "crude", "cruel", "cuddly", "cultural", "curious",
+ "curly", "current", "curved", "cute", "daily", "damaged", "damp", "dangerous", "dark", "dead",
+ "deaf", "deafening", "dear", "decent", "decisive", "deep", "defeated", "defensive", "defiant", "definite",
+ "deliberate", "delicate", "delicious", "delighted", "delightful", "democratic", "dependent", "depressed", "desirable", "desperate",
+ "detailed", "determined", "developed", "developing", "devoted", "different", "difficult", "digital", "diplomatic", "direct",
+ "dirty", "disabled", "disappointed", "disastrous", "disciplinary", "disgusted", "distant", "distinct", "distinctive", "distinguished",
+ "disturbed", "disturbing", "diverse", "divine", "dizzy", "domestic", "dominant", "double", "doubtful", "drab",
+ "dramatic", "dreadful", "driving", "drunk", "dry", "dual", "due", "dull", "dusty", "dutch",
+ "dying", "dynamic", "eager", "early", "eastern", "easy", "economic", "educational", "eerie", "effective",
+ "efficient", "elaborate", "elated", "elderly", "eldest", "electoral", "electric", "electrical", "electronic", "elegant",
+ "eligible", "embarrassed", "embarrassing", "emotional", "empirical", "empty", "enchanting", "encouraging", "endless", "energetic",
+ "english", "enormous", "enthusiastic", "entire", "entitled", "envious", "environmental", "equal", "equivalent", "essential",
+ "established", "estimated", "ethical", "ethnic", "european", "eventual", "everyday", "evident", "evil", "evolutionary",
+ "exact", "excellent", "exceptional", "excess", "excessive", "excited", "exciting", "exclusive", "existing", "exotic",
+ "expected", "expensive", "experienced", "experimental", "explicit", "extended", "extensive", "external", "extra", "extraordinary",
+ "extreme", "exuberant", "faint", "fair", "faithful", "familiar", "famous", "fancy", "fantastic", "far",
+ "fascinating", "fashionable", "fast", "fat", "fatal", "favourable", "favourite", "federal", "fellow", "female",
+ "feminist", "few", "fierce", "filthy", "final", "financial", "fine", "firm", "fiscal", "fit",
+ "fixed", "flaky", "flat", "flexible", "fluffy", "fluttering", "flying", "following", "fond", "foolish",
+ "foreign", "formal", "formidable", "forthcoming", "fortunate", "forward", "fragile", "frail", "frantic", "free",
+ "french", "frequent", "fresh", "friendly", "frightened", "front", "frozen", "fucking", "full", "full-time",
+ "fun", "functional", "fundamental", "funny", "furious", "future", "fuzzy", "gastric", "gay", "general",
+ "generous", "genetic", "gentle", "genuine", "geographical", "german", "giant", "gigantic", "given", "glad",
+ "glamorous", "gleaming", "global", "glorious", "golden", "good", "gorgeous", "gothic", "governing", "graceful",
+ "gradual", "grand", "grateful", "greasy", "great", "greek", "green", "grey", "grieving", "grim",
+ "gross", "grotesque", "growing", "grubby", "grumpy", "guilty", "handicapped", "handsome", "happy", "hard",
+ "harsh", "head", "healthy", "heavy", "helpful", "helpless", "hidden", "high", "high-pitched", "hilarious",
+ "hissing", "historic", "historical", "hollow", "holy", "homeless", "homely", "hon", "honest", "horizontal",
+ "horrible", "hostile", "hot", "huge", "human", "hungry", "hurt", "hushed", "husky", "icy",
+ "ideal", "identical", "ideological", "ill", "illegal", "imaginative", "immediate", "immense", "imperial", "implicit",
+ "important", "impossible", "impressed", "impressive", "improved", "inadequate", "inappropriate", "inc", "inclined", "increased",
+ "increasing", "incredible", "independent", "indian", "indirect", "individual", "industrial", "inevitable", "influential", "informal",
+ "inherent", "initial", "injured", "inland", "inner", "innocent", "innovative", "inquisitive", "instant", "institutional",
+ "insufficient", "intact", "integral", "integrated", "intellectual", "intelligent", "intense", "intensive", "interested", "interesting",
+ "interim", "interior", "intermediate", "internal", "international", "intimate", "invisible", "involved", "iraqi", "irish",
+ "irrelevant", "islamic", "isolated", "israeli", "italian", "itchy", "japanese", "jealous", "jewish", "jittery",
+ "joint", "jolly", "joyous", "judicial", "juicy", "junior", "just", "keen", "key", "kind",
+ "known", "korean", "labour", "large", "large-scale", "late", "latin", "lazy", "leading", "left",
+ "legal", "legislative", "legitimate", "lengthy", "lesser", "level", "lexical", "liable", "liberal", "light",
+ "like", "likely", "limited", "linear", "linguistic", "liquid", "literary", "little", "live", "lively",
+ "living", "local", "logical", "lonely", "long", "long-term", "loose", "lost", "loud", "lovely",
+ "low", "loyal", "ltd", "lucky", "mad", "magenta", "magic", "magnetic", "magnificent", "main",
+ "major", "male", "mammoth", "managerial", "managing", "manual", "many", "marginal", "marine", "marked",
+ "married", "marvellous", "marxist", "mass", "massive", "mathematical", "mature", "maximum", "mean", "meaningful",
+ "mechanical", "medical", "medieval", "melodic", "melted", "mental", "mere", "metropolitan", "mid", "middle",
+ "middle-class", "mighty", "mild", "military", "miniature", "minimal", "minimum", "ministerial", "minor", "miserable",
+ "misleading", "missing", "misty", "mixed", "moaning", "mobile", "moderate", "modern", "modest", "molecular",
+ "monetary", "monthly", "moral", "motionless", "muddy", "multiple", "mushy", "musical", "mute", "mutual",
+ "mysterious", "naked", "narrow", "nasty", "national", "native", "natural", "naughty", "naval", "near",
+ "nearby", "neat", "necessary", "negative", "neighbouring", "nervous", "net", "neutral", "new", "nice",
+ "nineteenth-century", "noble", "noisy", "normal", "northern", "nosy", "notable", "novel", "nuclear", "numerous",
+ "nursing", "nutritious", "nutty", "obedient", "objective", "obliged", "obnoxious", "obvious", "occasional", "occupational",
+ "odd", "official", "ok", "okay", "old", "old-fashioned", "olympic", "only", "open", "operational",
+ "opposite", "optimistic", "oral", "orange", "ordinary", "organic", "organisational", "original", "orthodox", "other",
+ "outdoor", "outer", "outrageous", "outside", "outstanding", "overall", "overseas", "overwhelming", "painful", "pale",
+ "palestinian", "panicky", "parallel", "parental", "parliamentary", "part-time", "partial", "particular", "passing", "passive",
+ "past", "patient", "payable", "peaceful", "peculiar", "perfect", "permanent", "persistent", "personal", "petite",
+ "philosophical", "physical", "pink", "plain", "planned", "plastic", "pleasant", "pleased", "poised", "polish",
+ "polite", "political", "poor", "popular", "positive", "possible", "post-war", "potential", "powerful", "practical",
+ "precious", "precise", "preferred", "pregnant", "preliminary", "premier", "prepared", "present", "presidential", "pretty",
+ "previous", "prickly", "primary", "prime", "primitive", "principal", "printed", "prior", "private", "probable",
+ "productive", "professional", "profitable", "profound", "progressive", "prominent", "promising", "proper", "proposed", "prospective",
+ "protective", "protestant", "proud", "provincial", "psychiatric", "psychological", "public", "puny", "pure", "purple",
+ "purring", "puzzled", "quaint", "qualified", "quick", "quickest", "quiet", "racial", "radical", "rainy",
+ "random", "rapid", "rare", "raspy", "rational", "ratty", "raw", "ready", "real", "realistic",
+ "rear", "reasonable", "recent", "red", "reduced", "redundant", "regional", "registered", "regular", "regulatory",
+ "related", "relative", "relaxed", "relevant", "reliable", "relieved", "religious", "reluctant", "remaining", "remarkable",
+ "remote", "renewed", "representative", "repulsive", "required", "resident", "residential", "resonant", "respectable", "respective",
+ "responsible", "resulting", "retail", "retired", "revolutionary", "rich", "ridiculous", "right", "rigid", "ripe",
+ "rising", "rival", "roasted", "robust", "rolling", "roman", "romantic", "rotten", "rough", "round",
+ "royal", "rubber", "rude", "ruling", "running", "rural", "russian", "sacred", "sad", "safe",
+ "salty", "satisfactory", "satisfied", "scared", "scary", "scattered", "scientific", "scornful", "scottish", "scrawny",
+ "screeching", "secondary", "secret", "secure", "select", "selected", "selective", "selfish", "semantic", "senior",
+ "sensible", "sensitive", "separate", "serious", "severe", "sexual", "shaggy", "shaky", "shallow", "shared",
+ "sharp", "sheer", "shiny", "shivering", "shocked", "short", "short-term", "shrill", "shy", "sick",
+ "significant", "silent", "silky", "silly", "similar", "simple", "single", "skilled", "skinny", "sleepy",
+ "slight", "slim", "slimy", "slippery", "slow", "small", "smart", "smiling", "smoggy", "smooth",
+ "so-called", "social", "socialist", "soft", "solar", "sole", "solid", "sophisticated", "sore", "sorry",
+ "sound", "sour", "southern", "soviet", "spanish", "spare", "sparkling", "spatial", "special", "specific",
+ "specified", "spectacular", "spicy", "spiritual", "splendid", "spontaneous", "sporting", "spotless", "spotty", "square",
+ "squealing", "stable", "stale", "standard", "static", "statistical", "statutory", "steady", "steep", "sticky",
+ "stiff", "still", "stingy", "stormy", "straight", "straightforward", "strange", "strategic", "strict", "striking",
+ "striped", "strong", "structural", "stuck", "stupid", "subjective", "subsequent", "substantial", "subtle", "successful",
+ "successive", "sudden", "sufficient", "suitable", "sunny", "super", "superb", "superior", "supporting", "supposed",
+ "supreme", "sure", "surprised", "surprising", "surrounding", "surviving", "suspicious", "sweet", "swift", "swiss",
+ "symbolic", "sympathetic", "systematic", "tall", "tame", "tan", "tart", "tasteless", "tasty", "technical",
+ "technological", "teenage", "temporary", "tender", "tense", "terrible", "territorial", "testy", "then", "theoretical",
+ "thick", "thin", "thirsty", "thorough", "thoughtful", "thoughtless", "thundering", "tight", "tiny", "tired",
+ "top", "tory", "total", "tough", "toxic", "traditional", "tragic", "tremendous", "tricky", "tropical",
+ "troubled", "turkish", "typical", "ugliest", "ugly", "ultimate", "unable", "unacceptable", "unaware", "uncertain",
+ "unchanged", "uncomfortable", "unconscious", "underground", "underlying", "unemployed", "uneven", "unexpected", "unfair", "unfortunate",
+ "unhappy", "uniform", "uninterested", "unique", "united", "universal", "unknown", "unlikely", "unnecessary", "unpleasant",
+ "unsightly", "unusual", "unwilling", "upper", "upset", "uptight", "urban", "urgent", "used", "useful",
+ "useless", "usual", "vague", "valid", "valuable", "variable", "varied", "various", "varying", "vast",
+ "verbal", "vertical", "very", "victorian", "victorious", "video-taped", "violent", "visible", "visiting", "visual",
+ "vital", "vivacious", "vivid", "vocational", "voiceless", "voluntary", "vulnerable", "wandering", "warm", "wasteful",
+ "watery", "weak", "wealthy", "weary", "wee", "weekly", "weird", "welcome", "well", "well-known",
+ "welsh", "western", "wet", "whispering", "white", "whole", "wicked", "wide", "wide-eyed", "widespread",
+ "wild", "willing", "wise", "witty", "wonderful", "wooden", "working", "working-class", "worldwide", "worried",
+ "worrying", "worthwhile", "worthy", "written", "wrong", "yellow", "young", "yummy", "zany", "zealous",
+ ]
+ b = [
+ "abiding", "accelerating", "accepting", "accomplishing", "achieving", "acquiring", "acteding", "activating", "adapting", "adding",
+ "addressing", "administering", "admiring", "admiting", "adopting", "advising", "affording", "agreeing", "alerting", "alighting",
+ "allowing", "altereding", "amusing", "analyzing", "announcing", "annoying", "answering", "anticipating", "apologizing", "appearing",
+ "applauding", "applieding", "appointing", "appraising", "appreciating", "approving", "arbitrating", "arguing", "arising", "arranging",
+ "arresting", "arriving", "ascertaining", "asking", "assembling", "assessing", "assisting", "assuring", "attaching", "attacking",
+ "attaining", "attempting", "attending", "attracting", "auditeding", "avoiding", "awaking", "backing", "baking", "balancing",
+ "baning", "banging", "baring", "bating", "bathing", "battling", "bing", "beaming", "bearing", "beating",
+ "becoming", "beging", "begining", "behaving", "beholding", "belonging", "bending", "beseting", "beting", "biding",
+ "binding", "biting", "bleaching", "bleeding", "blessing", "blinding", "blinking", "bloting", "blowing", "blushing",
+ "boasting", "boiling", "bolting", "bombing", "booking", "boring", "borrowing", "bouncing", "bowing", "boxing",
+ "braking", "branching", "breaking", "breathing", "breeding", "briefing", "bringing", "broadcasting", "bruising", "brushing",
+ "bubbling", "budgeting", "building", "bumping", "burning", "bursting", "burying", "busting", "buying", "buzing",
+ "calculating", "calling", "camping", "caring", "carrying", "carving", "casting", "cataloging", "catching", "causing",
+ "challenging", "changing", "charging", "charting", "chasing", "cheating", "checking", "cheering", "chewing", "choking",
+ "choosing", "choping", "claiming", "claping", "clarifying", "classifying", "cleaning", "clearing", "clinging", "cliping",
+ "closing", "clothing", "coaching", "coiling", "collecting", "coloring", "combing", "coming", "commanding", "communicating",
+ "comparing", "competing", "compiling", "complaining", "completing", "composing", "computing", "conceiving", "concentrating", "conceptualizing",
+ "concerning", "concluding", "conducting", "confessing", "confronting", "confusing", "connecting", "conserving", "considering", "consisting",
+ "consolidating", "constructing", "consulting", "containing", "continuing", "contracting", "controling", "converting", "coordinating", "copying",
+ "correcting", "correlating", "costing", "coughing", "counseling", "counting", "covering", "cracking", "crashing", "crawling",
+ "creating", "creeping", "critiquing", "crossing", "crushing", "crying", "curing", "curling", "curving", "cuting",
+ "cycling", "daming", "damaging", "dancing", "daring", "dealing", "decaying", "deceiving", "deciding", "decorating",
+ "defining", "delaying", "delegating", "delighting", "delivering", "demonstrating", "depending", "describing", "deserting", "deserving",
+ "designing", "destroying", "detailing", "detecting", "determining", "developing", "devising", "diagnosing", "diging", "directing",
+ "disagreing", "disappearing", "disapproving", "disarming", "discovering", "disliking", "dispensing", "displaying", "disproving", "dissecting",
+ "distributing", "diving", "diverting", "dividing", "doing", "doubling", "doubting", "drafting", "draging", "draining",
+ "dramatizing", "drawing", "dreaming", "dressing", "drinking", "driping", "driving", "dropping", "drowning", "druming",
+ "drying", "dusting", "dwelling", "earning", "eating", "editeding", "educating", "eliminating", "embarrassing", "employing",
+ "emptying", "enacteding", "encouraging", "ending", "enduring", "enforcing", "engineering", "enhancing", "enjoying", "enlisting",
+ "ensuring", "entering", "entertaining", "escaping", "establishing", "estimating", "evaluating", "examining", "exceeding", "exciting",
+ "excusing", "executing", "exercising", "exhibiting", "existing", "expanding", "expecting", "expediting", "experimenting", "explaining",
+ "exploding", "expressing", "extending", "extracting", "facing", "facilitating", "fading", "failing", "fancying", "fastening",
+ "faxing", "fearing", "feeding", "feeling", "fencing", "fetching", "fighting", "filing", "filling", "filming",
+ "finalizing", "financing", "finding", "firing", "fiting", "fixing", "flaping", "flashing", "fleing", "flinging",
+ "floating", "flooding", "flowing", "flowering", "flying", "folding", "following", "fooling", "forbiding", "forcing",
+ "forecasting", "foregoing", "foreseing", "foretelling", "forgeting", "forgiving", "forming", "formulating", "forsaking", "framing",
+ "freezing", "frightening", "frying", "gathering", "gazing", "generating", "geting", "giving", "glowing", "gluing",
+ "going", "governing", "grabing", "graduating", "grating", "greasing", "greeting", "grinning", "grinding", "griping",
+ "groaning", "growing", "guaranteeing", "guarding", "guessing", "guiding", "hammering", "handing", "handling", "handwriting",
+ "hanging", "happening", "harassing", "harming", "hating", "haunting", "heading", "healing", "heaping", "hearing",
+ "heating", "helping", "hiding", "hitting", "holding", "hooking", "hoping", "hopping", "hovering", "hugging",
+ "hmuming", "hunting", "hurrying", "hurting", "hypothesizing", "identifying", "ignoring", "illustrating", "imagining", "implementing",
+ "impressing", "improving", "improvising", "including", "increasing", "inducing", "influencing", "informing", "initiating", "injecting",
+ "injuring", "inlaying", "innovating", "inputing", "inspecting", "inspiring", "installing", "instituting", "instructing", "insuring",
+ "integrating", "intending", "intensifying", "interesting", "interfering", "interlaying", "interpreting", "interrupting", "interviewing", "introducing",
+ "inventing", "inventorying", "investigating", "inviting", "irritating", "itching", "jailing", "jamming", "jogging", "joining",
+ "joking", "judging", "juggling", "jumping", "justifying", "keeping", "kepting", "kicking", "killing", "kissing",
+ "kneeling", "kniting", "knocking", "knotting", "knowing", "labeling", "landing", "lasting", "laughing", "launching",
+ "laying", "leading", "leaning", "leaping", "learning", "leaving", "lecturing", "leding", "lending", "leting",
+ "leveling", "licensing", "licking", "lying", "lifteding", "lighting", "lightening", "liking", "listing", "listening",
+ "living", "loading", "locating", "locking", "loging", "longing", "looking", "losing", "loving", "maintaining",
+ "making", "maning", "managing", "manipulating", "manufacturing", "mapping", "marching", "marking", "marketing", "marrying",
+ "matching", "mating", "mattering", "meaning", "measuring", "meddling", "mediating", "meeting", "melting", "melting",
+ "memorizing", "mending", "mentoring", "milking", "mining", "misleading", "missing", "misspelling", "mistaking", "misunderstanding",
+ "mixing", "moaning", "modeling", "modifying", "monitoring", "mooring", "motivating", "mourning", "moving", "mowing",
+ "muddling", "muging", "multiplying", "murdering", "nailing", "naming", "navigating", "needing", "negotiating", "nesting",
+ "noding", "nominating", "normalizing", "noting", "noticing", "numbering", "obeying", "objecting", "observing", "obtaining",
+ "occuring", "offending", "offering", "officiating", "opening", "operating", "ordering", "organizing", "orienteding", "originating",
+ "overcoming", "overdoing", "overdrawing", "overflowing", "overhearing", "overtaking", "overthrowing", "owing", "owning", "packing",
+ "paddling", "painting", "parking", "parting", "participating", "passing", "pasting", "pating", "pausing", "paying",
+ "pecking", "pedaling", "peeling", "peeping", "perceiving", "perfecting", "performing", "permiting", "persuading", "phoning",
+ "photographing", "picking", "piloting", "pinching", "pining", "pinpointing", "pioneering", "placing", "planing", "planting",
+ "playing", "pleading", "pleasing", "plugging", "pointing", "poking", "polishing", "poping", "possessing", "posting",
+ "pouring", "practicing", "praiseding", "praying", "preaching", "preceding", "predicting", "prefering", "preparing", "prescribing",
+ "presenting", "preserving", "preseting", "presiding", "pressing", "pretending", "preventing", "pricking", "printing", "processing",
+ "procuring", "producing", "professing", "programing", "progressing", "projecting", "promising", "promoting", "proofreading", "proposing",
+ "protecting", "proving", "providing", "publicizing", "pulling", "pumping", "punching", "puncturing", "punishing", "purchasing",
+ "pushing", "puting", "qualifying", "questioning", "queuing", "quiting", "racing", "radiating", "raining", "raising",
+ "ranking", "rating", "reaching", "reading", "realigning", "realizing", "reasoning", "receiving", "recognizing", "recommending",
+ "reconciling", "recording", "recruiting", "reducing", "referring", "reflecting", "refusing", "regreting", "regulating", "rehabilitating",
+ "reigning", "reinforcing", "rejecting", "rejoicing", "relating", "relaxing", "releasing", "relying", "remaining", "remembering",
+ "reminding", "removing", "rendering", "reorganizing", "repairing", "repeating", "replacing", "replying", "reporting", "representing",
+ "reproducing", "requesting", "rescuing", "researching", "resolving", "responding", "restoreding", "restructuring", "retiring", "retrieving",
+ "returning", "reviewing", "revising", "rhyming", "riding", "riding", "ringing", "rinsing", "rising", "risking",
+ "robing", "rocking", "rolling", "roting", "rubing", "ruining", "ruling", "runing", "rushing", "sacking",
+ "sailing", "satisfying", "saving", "sawing", "saying", "scaring", "scattering", "scheduling", "scolding", "scorching",
+ "scraping", "scratching", "screaming", "screwing", "scribbling", "scrubing", "sealing", "searching", "securing", "seing",
+ "seeking", "selecting", "selling", "sending", "sensing", "separating", "serving", "servicing", "seting", "settling",
+ "sewing", "shading", "shaking", "shaping", "sharing", "shaving", "shearing", "sheding", "sheltering", "shining",
+ "shivering", "shocking", "shoing", "shooting", "shoping", "showing", "shrinking", "shruging", "shuting", "sighing",
+ "signing", "signaling", "simplifying", "sining", "singing", "sinking", "siping", "siting", "sketching", "skiing",
+ "skiping", "slaping", "slaying", "sleeping", "sliding", "slinging", "slinking", "sliping", "sliting", "slowing",
+ "smashing", "smelling", "smiling", "smiting", "smoking", "snatching", "sneaking", "sneezing", "sniffing", "snoring",
+ "snowing", "soaking", "solving", "soothing", "soothsaying", "sorting", "sounding", "sowing", "sparing", "sparking",
+ "sparkling", "speaking", "specifying", "speeding", "spelling", "spending", "spilling", "spining", "spiting", "spliting",
+ "spoiling", "spoting", "spraying", "spreading", "springing", "sprouting", "squashing", "squeaking", "squealing", "squeezing",
+ "staining", "stamping", "standing", "staring", "starting", "staying", "stealing", "steering", "stepping", "sticking",
+ "stimulating", "stinging", "stinking", "stirring", "stitching", "stoping", "storing", "straping", "streamlining", "strengthening",
+ "stretching", "striding", "striking", "stringing", "stripping", "striving", "stroking", "structuring", "studying", "stuffing",
+ "subleting", "subtracting", "succeeding", "sucking", "suffering", "suggesting", "suiting", "summarizing", "supervising", "supplying",
+ "supporting", "supposing", "surprising", "surrounding", "suspecting", "suspending", "swearing", "sweating", "sweeping", "swelling",
+ "swimming", "swinging", "switching", "symbolizing", "synthesizing", "systemizing", "tabulating", "taking", "talking", "taming",
+ "taping", "targeting", "tasting", "teaching", "tearing", "teasing", "telephoning", "telling", "tempting", "terrifying",
+ "testing", "thanking", "thawing", "thinking", "thriving", "throwing", "thrusting", "ticking", "tickling", "tying",
+ "timing", "tiping", "tiring", "touching", "touring", "towing", "tracing", "trading", "training", "transcribing",
+ "transfering", "transforming", "translating", "transporting", "traping", "traveling", "treading", "treating", "trembling", "tricking",
+ "triping", "troting", "troubling", "troubleshooting", "trusting", "trying", "tuging", "tumbling", "turning", "tutoring",
+ "twisting", "typing", "undergoing", "understanding", "undertaking", "undressing", "unfastening", "unifying", "uniting", "unlocking",
+ "unpacking", "untidying", "updating", "upgrading", "upholding", "upseting", "using", "utilizing", "vanishing", "verbalizing",
+ "verifying", "vexing", "visiting", "wailing", "waiting", "waking", "walking", "wandering", "wanting", "warming",
+ "warning", "washing", "wasting", "watching", "watering", "waving", "wearing", "weaving", "wedding", "weeping",
+ "weighing", "welcoming", "wending", "weting", "whining", "whiping", "whirling", "whispering", "whistling", "wining",
+ "winding", "winking", "wiping", "wishing", "withdrawing", "withholding", "withstanding", "wobbling", "wondering", "working",
+ "worrying", "wrapping", "wrecking", "wrestling", "wriggling", "wringing", "writing", "x-raying", "yawning", "yelling",
+ "zipping", "zooming",
+ ]
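Taken together, the lists above (animals, adjectives `a`, gerunds `b`) let the training code turn an experiment's hash into a memorable run name. A minimal sketch of the idea, assuming the animal list is exposed as `animal_hash.c` (its assignment sits above this hunk) and using Python's built-in hash purely for illustration:

    import animal_hash  # the module added by this diff

    def hashname(name):
        # Fold one integer hash into three list indices:
        # adjective + gerund + animal.
        h = abs(hash(name))
        i = h % len(animal_hash.a); h //= len(animal_hash.a)
        j = h % len(animal_hash.b); h //= len(animal_hash.b)
        k = h % len(animal_hash.c)
        return animal_hash.a[i] + animal_hash.b[j] + animal_hash.c[k]

    # e.g. hashname("BigGAN_ch96_bs256x8") -> something like "briefwhiningOcelot"

Note that Python's built-in hash is salted per process, so a reproducible variant would hash the string with hashlib instead.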
BigGAN_PyTorch/config_files/COCO_Stuff/BigGAN/unconditional_biggan_res128.json ADDED
@@ -0,0 +1,44 @@
+ {
+ "experiment_name": "unconditional_biggan_class_cond_res128_COCO",
+ "which_dataset": "coco",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 4,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 2,
+
+ "constant_conditioning": true,
+ "class_cond": true,
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 64,
+ "D_lr": 4e-4,
+ "G_lr": 1e-4,
+ "G_ch": 48,
+ "D_ch": 48,
+
+ "load_weights": ""
+
+ }
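This first config illustrates the pattern shared by every JSON file under config_files/: a flat dictionary of hyperparameter overrides for the training script. A minimal sketch of how such a file can be merged over parser defaults (the load_config helper below is illustrative, not the repo's exact loader):

    import json

    def load_config(path, defaults):
        # Values from the JSON file take precedence over the defaults;
        # unknown keys are rejected so typos surface early.
        with open(path) as f:
            overrides = json.load(f)
        unknown = set(overrides) - set(defaults)
        if unknown:
            raise KeyError(f"unrecognized config keys: {sorted(unknown)}")
        return {**defaults, **overrides}

Keys such as G_ch/D_ch (channel multipliers), num_D_steps, and the ema flags map directly onto BigGAN-PyTorch training arguments.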
BigGAN_PyTorch/config_files/COCO_Stuff/BigGAN/unconditional_biggan_res256.json ADDED
@@ -0,0 +1,44 @@
+ {
+ "experiment_name": "unconditional_biggan_class_cond_res256_COCO",
+ "which_dataset": "coco",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 2,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 2,
+
+ "constant_conditioning": true,
+ "class_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 48,
+ "D_ch": 48,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/COCO_Stuff/IC-GAN/icgan_res128_ddp.json ADDED
@@ -0,0 +1,51 @@
+ {
+ "experiment_name": "icgan_res128_COCO",
+ "which_dataset": "coco",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 4,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+ "feature_augmentation": true,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": false,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 5,
+ "feature_extractor": "selfsupervised",
+
+ "batch_size": 64,
+ "D_lr": 4e-4,
+ "G_lr": 1e-4,
+ "G_ch": 48,
+ "D_ch": 48,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/COCO_Stuff/IC-GAN/icgan_res256_ddp.json ADDED
@@ -0,0 +1,51 @@
+ {
+ "experiment_name": "icgan_res256_COCO",
+ "which_dataset": "coco",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 2,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+ "feature_augmentation": true,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": false,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 5,
+ "feature_extractor": "selfsupervised",
+
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 48,
+ "D_ch": 48,
+
+ "load_weights": ""
+
+ }
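One way to read the DDP fields in these COCO-Stuff configs: if "batch_size" is per GPU (the usual convention in DDP training scripts, though the files do not state it), the overall batch the optimizer sees follows from the node, GPU, and accumulation counts. For the res256 config above:

    # Assuming per-GPU semantics for "batch_size" (an assumption, not
    # stated in the config itself):
    effective_batch = 16 * 8 * 2 * 1  # batch_size * gpus/node * nodes * num_G_accumulations
    print(effective_batch)  # 256

which matches the res128 config's 64 x 4 x 1 x 1 = 256.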
BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res128.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "biggan_imagenet_lt_class_cond_res128",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 2,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 10,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 2,
+
+ "class_cond": true,
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 64,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "longtail": true,
+ "longtail_gen": true,
+ "use_balanced_sampler": false,
+ "custom_distrib_gen": false,
+ "longtail_temperature": 1,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res256.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "biggan_imagenet_lt_class_cond_res256",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 10,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 2,
+
+ "class_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "longtail": true,
+ "longtail_gen": true,
+ "use_balanced_sampler": false,
+ "custom_distrib_gen": false,
+ "longtail_temperature": 1,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet-LT/BigGAN/biggan_res64.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "biggan_imagenet_lt_class_cond_res64",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 1,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "hier": true,
+ "resolution": 64,
+ "G_attn": "32",
+ "D_attn": "32",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 128,
+ "D_lr": 1e-3,
+ "G_lr": 1e-5,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "longtail": true,
+ "longtail_gen": true,
+ "use_balanced_sampler": false,
+ "custom_distrib_gen": false,
+ "longtail_temperature": 1,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res128.json ADDED
@@ -0,0 +1,56 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res128",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 2,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 10,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 2,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "which_knn_balance": "instance_balance",
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 5,
+ "feature_extractor": "classification",
+
+ "batch_size": 64,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "longtail": true,
+ "longtail_gen": true,
+ "use_balanced_sampler": false,
+ "custom_distrib_gen": false,
+ "longtail_temperature": 1,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res256.json ADDED
@@ -0,0 +1,56 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res256",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 10,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 2,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "which_knn_balance": "instance_balance",
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 5,
+ "feature_extractor": "classification",
+
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "longtail": true,
+ "longtail_gen": true,
+ "use_balanced_sampler": false,
+ "custom_distrib_gen": false,
+ "longtail_temperature": 1,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet-LT/cc_IC-GAN/cc_icgan_res64.json ADDED
@@ -0,0 +1,56 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res64",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 1,
+ "hflips": true,
+ "DA": true,
+ "DiffAugment": "translation",
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "which_knn_balance": "instance_balance",
+ "hier": true,
+ "resolution": 64,
+ "G_attn": "32",
+ "D_attn": "32",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 5,
+ "feature_extractor": "classification",
+
+ "batch_size": 128,
+ "D_lr": 1e-3,
+ "G_lr": 1e-5,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "longtail": true,
+ "longtail_gen": true,
+ "use_balanced_sampler": false,
+ "custom_distrib_gen": false,
+ "longtail_temperature": 1,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res128.json ADDED
@@ -0,0 +1,40 @@
+ {
+ "experiment_name": "biggan_imagenet_res128",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+
+ "test_every": 5,
+ "save_every": 2,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 64,
+ "D_lr": 4e-4,
+ "G_lr": 1e-4,
+ "G_ch": 96,
+ "D_ch": 96,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res256_half_cap.json ADDED
@@ -0,0 +1,40 @@
+ {
+ "experiment_name": "biggan_class_cond_res256_half_cap_noflips",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": false,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 4,
+ "num_D_accumulations": 4,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 16,
+ "D_lr": 4e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet/BigGAN/biggan_res64.json ADDED
@@ -0,0 +1,40 @@
+ {
+ "experiment_name": "biggan_imagenet_res64",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 1,
+ "hflips": true,
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "hier": true,
+ "resolution": 64,
+ "G_attn": "32",
+ "D_attn": "32",
+ "shared_dim": 128,
+ "G_shared": true,
+ "batch_size": 256,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res128.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "icgan_biggan_imagenet_res128",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "feature_augmentation": true,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": false,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "selfsupervised",
+
+ "batch_size": 64,
+ "D_lr": 1e-4,
+ "G_lr": 4e-5,
+ "G_ch": 96,
+ "D_ch": 96,
+
+ "load_weights": ""
+
+ }
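IC-GAN conditions on instance features and trains each conditioning instance against its `k_nn` nearest neighbors (50 here), extracted with a self-supervised feature extractor. A minimal sketch of precomputing cosine-similarity neighbor sets; the repository builds these in its data-preparation code, so this is an illustration of the idea rather than its exact implementation:

import torch

def knn_indices(features, k=50):
    # features: (N, D) tensor of instance features from a feature extractor.
    feats = torch.nn.functional.normalize(features, dim=1)
    sims = feats @ feats.t()            # (N, N) cosine similarities
    # Each instance's k nearest neighbors (itself appears at rank 0).
    return sims.topk(k, dim=1).indices  # (N, k)

neighbors = knn_indices(torch.randn(1000, 512), k=50)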
BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res256.json ADDED
@@ -0,0 +1,47 @@
+ {
+ "experiment_name": "icgan_biggan_imagenet_res256",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "feature_augmentation": false,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 4,
+ "num_D_accumulations": 4,
+ "num_D_steps": 1,
+
+ "class_cond": false,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "selfsupervised",
+
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 4e-5,
+ "G_ch": 96,
+ "D_ch": 96,
+
+ "load_weights": ""
+ }
BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res256_halfcap.json ADDED
@@ -0,0 +1,47 @@
+ {
+ "experiment_name": "icgan_biggan_imagenet_res256_halfcap",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "feature_augmentation": true,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 4,
+ "num_D_accumulations": 4,
+ "num_D_steps": 2,
+
+ "class_cond": false,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "selfsupervised",
+
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "load_weights": ""
+ }
BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res64.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "icgan_biggan_imagenet_res64",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 1,
+ "hflips": true,
+ "feature_augmentation": true,
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": false,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 64,
+ "G_attn": "32",
+ "D_attn": "32",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "selfsupervised",
+
+ "batch_size": 256,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res128.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res128",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "feature_augmentation": true,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 128,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "classification",
+
+ "batch_size": 64,
+ "D_lr": 1e-4,
+ "G_lr": 4e-5,
+ "G_ch": 96,
+ "D_ch": 96,
+
+ "load_weights": ""
+
+ }
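Class-conditional IC-GAN sets both `class_cond` and `instance_cond`, so G receives a class embedding (`shared_dim`: 128) alongside an instance-feature embedding (`shared_dim_feat`: 512). A rough sketch of how two shared embeddings could be projected and concatenated into one conditioning vector; the wiring, the 2048-d feature assumption (ResNet-style features), and the module itself are illustrative guesses, not the repository's BigGAN.py:

import torch
import torch.nn as nn

class CondEmbedding(nn.Module):
    # Projects a class id and an instance feature vector into the
    # conditioning vector concatenated with each noise chunk.
    def __init__(self, n_classes=1000, feat_dim=2048,
                 shared_dim=128, shared_dim_feat=512):
        super().__init__()
        self.class_embed = nn.Embedding(n_classes, shared_dim)
        self.feat_embed = nn.Linear(feat_dim, shared_dim_feat)

    def forward(self, y, h):
        # y: (B,) class ids; h: (B, feat_dim) instance features
        return torch.cat([self.class_embed(y), self.feat_embed(h)], dim=1)

cond = CondEmbedding()(torch.zeros(4, dtype=torch.long), torch.randn(4, 2048))
print(cond.shape)  # torch.Size([4, 640])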
BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res256.json ADDED
@@ -0,0 +1,47 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res256",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "feature_augmentation": false,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 4,
+ "num_D_accumulations": 4,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "classification",
+
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 4e-5,
+ "G_ch": 96,
+ "D_ch": 96,
+
+ "load_weights": ""
+ }
BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res256_halfcap.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res256_halfcap",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 4,
+ "n_gpus_per_node": 8,
+ "hflips": true,
+ "feature_augmentation": true,
+
+ "test_every": 5,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 4,
+ "num_D_accumulations": 4,
+ "num_D_steps": 2,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 256,
+ "G_attn": "64",
+ "D_attn": "64",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "classification",
+
+ "batch_size": 16,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "load_weights": ""
+
+ }
BigGAN_PyTorch/config_files/ImageNet/cc_IC-GAN/cc_icgan_res64.json ADDED
@@ -0,0 +1,48 @@
+ {
+ "experiment_name": "cc_icgan_biggan_imagenet_res64",
+ "run_setup": "local_debug",
+ "deterministic_run": true,
+ "num_workers": 10,
+
+ "ddp_train": true,
+ "n_nodes": 1,
+ "n_gpus_per_node": 1,
+ "hflips": true,
+ "feature_augmentation": true,
+
+ "test_every": 1,
+ "save_every": 1,
+ "num_epochs": 3000,
+ "es_patience": 50,
+ "shuffle": true,
+
+ "G_eval_mode": true,
+ "ema": true,
+ "use_ema": true,
+ "num_G_accumulations": 1,
+ "num_D_accumulations": 1,
+ "num_D_steps": 1,
+
+ "class_cond": true,
+ "instance_cond": true,
+ "hier": true,
+ "resolution": 64,
+ "G_attn": "32",
+ "D_attn": "32",
+ "shared_dim": 128,
+ "shared_dim_feat": 512,
+ "G_shared": true,
+ "G_shared_feat": true,
+
+ "k_nn": 50,
+ "feature_extractor": "classification",
+
+ "batch_size": 256,
+ "D_lr": 1e-4,
+ "G_lr": 1e-4,
+ "G_ch": 64,
+ "D_ch": 64,
+
+ "load_weights": ""
+
+ }
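All of these JSON files share one flat key-value schema. A minimal sketch of loading one and overriding a set of defaults (the trainer's actual plumbing lives in run.py and may differ; the loader below is an assumption for illustration):

import json
from argparse import Namespace

def load_config(path, defaults=None):
    # Flat JSON -> Namespace, with JSON values overriding any defaults.
    opts = dict(defaults or {})
    with open(path) as f:
        opts.update(json.load(f))
    return Namespace(**opts)

config = load_config(
    "BigGAN_PyTorch/config_files/ImageNet/IC-GAN/icgan_res64.json")
print(config.resolution, config.k_nn)  # 64 50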
BigGAN_PyTorch/diffaugment_utils.py ADDED
@@ -0,0 +1,119 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # Copyright (c) 2020, Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han
+ # All rights reserved.
+ #
+ # Redistribution and use in source and binary forms, with or without
+ # modification, are permitted provided that the following conditions are met:
+
+ # * Redistributions of source code must retain the above copyright notice, this
+ #   list of conditions and the following disclaimer.
+ #
+ # * Redistributions in binary form must reproduce the above copyright notice,
+ #   this list of conditions and the following disclaimer in the documentation
+ #   and/or other materials provided with the distribution.
+ #
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+ import torch
+ import torch.nn.functional as F
+
+
+ def DiffAugment(x, policy="", channels_first=True):
+     if policy:
+         if not channels_first:
+             x = x.permute(0, 3, 1, 2)
+         for p in policy.split(","):
+             for f in AUGMENT_FNS[p]:
+                 x = f(x)
+         if not channels_first:
+             x = x.permute(0, 2, 3, 1)
+         x = x.contiguous()
+     return x
+
+
+ def rand_brightness(x):
+     x = x + (torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) - 0.5)
+     return x
+
+
+ def rand_saturation(x):
+     x_mean = x.mean(dim=1, keepdim=True)
+     x = (x - x_mean) * (
+         torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) * 2
+     ) + x_mean
+     return x
+
+
+ def rand_contrast(x):
+     x_mean = x.mean(dim=[1, 2, 3], keepdim=True)
+     x = (x - x_mean) * (
+         torch.rand(x.size(0), 1, 1, 1, dtype=x.dtype, device=x.device) + 0.5
+     ) + x_mean
+     return x
+
+
+ def rand_translation(x, ratio=0.125):
+     shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+     translation_x = torch.randint(
+         -shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device
+     )
+     translation_y = torch.randint(
+         -shift_y, shift_y + 1, size=[x.size(0), 1, 1], device=x.device
+     )
+     grid_batch, grid_x, grid_y = torch.meshgrid(
+         torch.arange(x.size(0), dtype=torch.long, device=x.device),
+         torch.arange(x.size(2), dtype=torch.long, device=x.device),
+         torch.arange(x.size(3), dtype=torch.long, device=x.device),
+     )
+     grid_x = torch.clamp(grid_x + translation_x + 1, 0, x.size(2) + 1)
+     grid_y = torch.clamp(grid_y + translation_y + 1, 0, x.size(3) + 1)
+     x_pad = F.pad(x, [1, 1, 1, 1, 0, 0, 0, 0])
+     x = (
+         x_pad.permute(0, 2, 3, 1)
+         .contiguous()[grid_batch, grid_x, grid_y]
+         .permute(0, 3, 1, 2)
+     )
+     return x
+
+
+ def rand_cutout(x, ratio=0.5):
+     cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
+     offset_x = torch.randint(
+         0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device
+     )
+     offset_y = torch.randint(
+         0, x.size(3) + (1 - cutout_size[1] % 2), size=[x.size(0), 1, 1], device=x.device
+     )
+     grid_batch, grid_x, grid_y = torch.meshgrid(
+         torch.arange(x.size(0), dtype=torch.long, device=x.device),
+         torch.arange(cutout_size[0], dtype=torch.long, device=x.device),
+         torch.arange(cutout_size[1], dtype=torch.long, device=x.device),
+     )
+     grid_x = torch.clamp(
+         grid_x + offset_x - cutout_size[0] // 2, min=0, max=x.size(2) - 1
+     )
+     grid_y = torch.clamp(
+         grid_y + offset_y - cutout_size[1] // 2, min=0, max=x.size(3) - 1
+     )
+     mask = torch.ones(x.size(0), x.size(2), x.size(3), dtype=x.dtype, device=x.device)
+     mask[grid_batch, grid_x, grid_y] = 0
+     x = x * mask.unsqueeze(1)
+     return x
+
+
+ AUGMENT_FNS = {
+     "color": [rand_brightness, rand_saturation, rand_contrast],
+     "translation": [rand_translation],
+     "cutout": [rand_cutout],
+ }
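For reference, a minimal usage sketch of the module above: the configs select it with `"DA": true` and `"DiffAugment": "translation"`, and the same (independently re-drawn) augmentation policy is applied to real and fake batches before they reach D. The import assumes the repository root is on PYTHONPATH; the stand-in discriminator is illustrative only:

import torch
import torch.nn as nn
from BigGAN_PyTorch.diffaugment_utils import DiffAugment

D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 64 * 64, 1))  # stand-in D
policy = "translation"  # matches the "DiffAugment" key in the configs above

x_real = torch.randn(8, 3, 64, 64)
x_fake = torch.randn(8, 3, 64, 64)

# Augment both batches before the discriminator sees them:
d_real = D(DiffAugment(x_real, policy=policy))
d_fake = D(DiffAugment(x_fake, policy=policy))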
BigGAN_PyTorch/imagenet_lt/ImageNet_LT_train.txt ADDED
The diff for this file is too large to render. See raw diff
BigGAN_PyTorch/imagenet_lt/ImageNet_LT_val.txt ADDED
The diff for this file is too large to render. See raw diff
BigGAN_PyTorch/imgs/D Singular Values.png ADDED
BigGAN_PyTorch/imgs/DeepSamples.png ADDED
BigGAN_PyTorch/imgs/DogBall.png ADDED
BigGAN_PyTorch/imgs/G Singular Values.png ADDED
BigGAN_PyTorch/imgs/IS_FID.png ADDED
BigGAN_PyTorch/imgs/Losses.png ADDED
BigGAN_PyTorch/imgs/header_image.jpg ADDED
BigGAN_PyTorch/imgs/interp_sample.jpg ADDED
BigGAN_PyTorch/layers.py ADDED
@@ -0,0 +1,616 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # All contributions by Andy Brock:
+ # Copyright (c) 2019 Andy Brock
+ #
+ # MIT License
+ """ Layers
+     This file contains various layers for the BigGAN models.
+ """
+ import os
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from torch.nn import init
+ import torch.optim as optim
+ import torch.nn.functional as F
+ from torch.nn import Parameter as P
+
+ import sys
+
+ sys.path.insert(1, os.path.join(sys.path[0], ".."))
+ from BigGAN_PyTorch.sync_batchnorm import SynchronizedBatchNorm2d as SyncBN2d
+
+
+ # Projection of x onto y
+ def proj(x, y):
+     return torch.mm(y, x.t()) * y / torch.mm(y, y.t())
+
+
+ # Orthogonalize x wrt list of vectors ys
+ def gram_schmidt(x, ys):
+     for y in ys:
+         x = x - proj(x, y)
+     return x
+
+
+ # Apply num_itrs steps of the power method to estimate top N singular values.
+ def power_iteration(W, u_, update=True, eps=1e-12):
+     # Lists holding singular vectors and values
+     us, vs, svs = [], [], []
+     for i, u in enumerate(u_):
+         # Run one step of the power iteration
+         with torch.no_grad():
+             v = torch.matmul(u, W)
+             # Run Gram-Schmidt to subtract components of all other singular vectors
+             v = F.normalize(gram_schmidt(v, vs), eps=eps)
+             # Add to the list
+             vs += [v]
+             # Update the other singular vector
+             u = torch.matmul(v, W.t())
+             # Run Gram-Schmidt to subtract components of all other singular vectors
+             u = F.normalize(gram_schmidt(u, us), eps=eps)
+             # Add to the list
+             us += [u]
+             if update:
+                 u_[i][:] = u
+         # Compute this singular value and add it to the list
+         svs += [torch.squeeze(torch.matmul(torch.matmul(v, W.t()), u.t()))]
+         # svs += [torch.sum(F.linear(u, W.transpose(0, 1)) * v)]
+     return svs, us, vs
+
+
+ # Convenience passthrough function
+ class identity(nn.Module):
+     def forward(self, input):
+         return input
+
+
+ # Spectral normalization base class
+ class SN(object):
+     def __init__(self, num_svs, num_itrs, num_outputs, transpose=False, eps=1e-12):
+         # Number of power iterations per step
+         self.num_itrs = num_itrs
+         # Number of singular values
+         self.num_svs = num_svs
+         # Transposed?
+         self.transpose = transpose
+         # Epsilon value for avoiding divide-by-0
+         self.eps = eps
+         # Register a singular vector for each sv
+         for i in range(self.num_svs):
+             self.register_buffer("u%d" % i, torch.randn(1, num_outputs))
+             self.register_buffer("sv%d" % i, torch.ones(1))
+
+     # Singular vectors (u side)
+     @property
+     def u(self):
+         return [getattr(self, "u%d" % i) for i in range(self.num_svs)]
+
+     # Singular values;
+     # note that these buffers are just for logging and are not used in training.
+     @property
+     def sv(self):
+         return [getattr(self, "sv%d" % i) for i in range(self.num_svs)]
+
+     # Compute the spectrally-normalized weight
+     def W_(self):
+         W_mat = self.weight.view(self.weight.size(0), -1)
+         if self.transpose:
+             W_mat = W_mat.t()
+         # Apply num_itrs power iterations
+         for _ in range(self.num_itrs):
+             svs, us, vs = power_iteration(
+                 W_mat, self.u, update=self.training, eps=self.eps
+             )
+         # Update the svs
+         if self.training:
+             with torch.no_grad():  # Make sure to do this in a no_grad() context or you'll get memory leaks!
+                 for i, sv in enumerate(svs):
+                     self.sv[i][:] = sv
+         return self.weight / svs[0]
+
+
+ # 2D Conv layer with spectral norm
+ class SNConv2d(nn.Conv2d, SN):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         dilation=1,
+         groups=1,
+         bias=True,
+         num_svs=1,
+         num_itrs=1,
+         eps=1e-12,
+     ):
+         nn.Conv2d.__init__(
+             self,
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             padding,
+             dilation,
+             groups,
+             bias,
+         )
+         SN.__init__(self, num_svs, num_itrs, out_channels, eps=eps)
+
+     def forward(self, x):
+         return F.conv2d(
+             x,
+             self.W_(),
+             self.bias,
+             self.stride,
+             self.padding,
+             self.dilation,
+             self.groups,
+         )
+
+
+ # Linear layer with spectral norm
+ class SNLinear(nn.Linear, SN):
+     def __init__(
+         self, in_features, out_features, bias=True, num_svs=1, num_itrs=1, eps=1e-12
+     ):
+         nn.Linear.__init__(self, in_features, out_features, bias)
+         SN.__init__(self, num_svs, num_itrs, out_features, eps=eps)
+
+     def forward(self, x):
+         return F.linear(x, self.W_(), self.bias)
+
+
+ # Embedding layer with spectral norm
+ # We use num_embeddings as the dim instead of embedding_dim here
+ # for convenience sake
+ class SNEmbedding(nn.Embedding, SN):
+     def __init__(
+         self,
+         num_embeddings,
+         embedding_dim,
+         padding_idx=None,
+         max_norm=None,
+         norm_type=2,
+         scale_grad_by_freq=False,
+         sparse=False,
+         _weight=None,
+         num_svs=1,
+         num_itrs=1,
+         eps=1e-12,
+     ):
+         nn.Embedding.__init__(
+             self,
+             num_embeddings,
+             embedding_dim,
+             padding_idx,
+             max_norm,
+             norm_type,
+             scale_grad_by_freq,
+             sparse,
+             _weight,
+         )
+         SN.__init__(self, num_svs, num_itrs, num_embeddings, eps=eps)
+
+     def forward(self, x):
+         return F.embedding(x, self.W_())
+
+
+ # A non-local block as used in SA-GAN
+ # Note that the implementation as described in the paper is largely incorrect;
+ # refer to the released code for the actual implementation.
+ class Attention(nn.Module):
+     def __init__(self, ch, which_conv=SNConv2d, name="attention"):
+         super(Attention, self).__init__()
+         # Channel multiplier
+         self.ch = ch
+         self.which_conv = which_conv
+         self.theta = self.which_conv(
+             self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
+         )
+         self.phi = self.which_conv(
+             self.ch, self.ch // 8, kernel_size=1, padding=0, bias=False
+         )
+         self.g = self.which_conv(
+             self.ch, self.ch // 2, kernel_size=1, padding=0, bias=False
+         )
+         self.o = self.which_conv(
+             self.ch // 2, self.ch, kernel_size=1, padding=0, bias=False
+         )
+         # Learnable gain parameter
+         self.gamma = P(torch.tensor(0.0), requires_grad=True)
+
+     def forward(self, x, y=None):
+         # Apply convs
+         theta = self.theta(x)
+         phi = F.max_pool2d(self.phi(x), [2, 2])
+         g = F.max_pool2d(self.g(x), [2, 2])
+         # Perform reshapes
+         theta = theta.view(-1, self.ch // 8, x.shape[2] * x.shape[3])
+         phi = phi.view(-1, self.ch // 8, x.shape[2] * x.shape[3] // 4)
+         g = g.view(-1, self.ch // 2, x.shape[2] * x.shape[3] // 4)
+         # Matmul and softmax to get attention maps
+         beta = F.softmax(torch.bmm(theta.transpose(1, 2), phi), -1)
+         # Attention map times g path
+         o = self.o(
+             torch.bmm(g, beta.transpose(1, 2)).view(
+                 -1, self.ch // 2, x.shape[2], x.shape[3]
+             )
+         )
+         return self.gamma * o + x
+
+
+ # Fused batchnorm op
+ def fused_bn(x, mean, var, gain=None, bias=None, eps=1e-5):
+     # Apply scale and shift--if gain and bias are provided, fuse them here
+     # Prepare scale
+     scale = torch.rsqrt(var + eps)
+     # If a gain is provided, use it
+     if gain is not None:
+         scale = scale * gain
+     # Prepare shift
+     shift = mean * scale
+     # If bias is provided, use it
+     if bias is not None:
+         shift = shift - bias
+     return x * scale - shift
+     # return ((x - mean) / ((var + eps) ** 0.5)) * gain + bias # The unfused way.
+
+
+ # Manual BN
+ # Calculate means and variances using mean-of-squares minus mean-squared
+ def manual_bn(x, gain=None, bias=None, return_mean_var=False, eps=1e-5):
+     # Cast x to float32 if necessary
+     float_x = x.float()
+     # Calculate expected value of x (m) and expected value of x**2 (m2)
+     # Mean of x
+     m = torch.mean(float_x, [0, 2, 3], keepdim=True)
+     # Mean of x squared
+     m2 = torch.mean(float_x ** 2, [0, 2, 3], keepdim=True)
+     # Calculate variance as mean of squared minus mean squared.
+     var = m2 - m ** 2
+     # Cast back to float 16 if necessary
+     var = var.type(x.type())
+     m = m.type(x.type())
+     # Return mean and variance for updating stored mean/var if requested
+     if return_mean_var:
+         return fused_bn(x, m, var, gain, bias, eps), m.squeeze(), var.squeeze()
+     else:
+         return fused_bn(x, m, var, gain, bias, eps)
+
+
+ # My batchnorm, supports standing stats
+ class myBN(nn.Module):
+     def __init__(self, num_channels, eps=1e-5, momentum=0.1):
+         super(myBN, self).__init__()
+         # momentum for updating running stats
+         self.momentum = momentum
+         # epsilon to avoid dividing by 0
+         self.eps = eps
+         # Momentum
+         self.momentum = momentum
+         # Register buffers
+         self.register_buffer("stored_mean", torch.zeros(num_channels))
+         self.register_buffer("stored_var", torch.ones(num_channels))
+         self.register_buffer("accumulation_counter", torch.zeros(1))
+         # Accumulate running means and vars
+         self.accumulate_standing = False
+
+     # reset standing stats
+     def reset_stats(self):
+         self.stored_mean[:] = 0
+         self.stored_var[:] = 0
+         self.accumulation_counter[:] = 0
+
+     def forward(self, x, gain, bias):
+         if self.training:
+             out, mean, var = manual_bn(
+                 x, gain, bias, return_mean_var=True, eps=self.eps
+             )
+             # If accumulating standing stats, increment them
+             if self.accumulate_standing:
+                 self.stored_mean[:] = self.stored_mean + mean.data
+                 self.stored_var[:] = self.stored_var + var.data
+                 self.accumulation_counter += 1.0
+             # If not accumulating standing stats, take running averages
+             else:
+                 self.stored_mean[:] = (
+                     self.stored_mean * (1 - self.momentum) + mean * self.momentum
+                 )
+                 self.stored_var[:] = (
+                     self.stored_var * (1 - self.momentum) + var * self.momentum
+                 )
+             return out
+         # If not in training mode, use the stored statistics
+         else:
+             mean = self.stored_mean.view(1, -1, 1, 1)
+             var = self.stored_var.view(1, -1, 1, 1)
+             # If using standing stats, divide them by the accumulation counter
+             if self.accumulate_standing:
+                 mean = mean / self.accumulation_counter
+                 var = var / self.accumulation_counter
+             return fused_bn(x, mean, var, gain, bias, self.eps)
+
+
+ # Simple function to handle groupnorm norm stylization
+ def groupnorm(x, norm_style):
+     # If number of channels specified in norm_style:
+     if "ch" in norm_style:
+         ch = int(norm_style.split("_")[-1])
+         groups = max(int(x.shape[1]) // ch, 1)
+     # If number of groups specified in norm style
+     elif "grp" in norm_style:
+         groups = int(norm_style.split("_")[-1])
+     # If neither, default to groups = 16
+     else:
+         groups = 16
+     return F.group_norm(x, groups)
+
+
+ # Class-conditional bn
+ # output size is the number of channels, input size is for the linear layers
+ # Andy's Note: this class feels messy but I'm not really sure how to clean it up
+ # Suggestions welcome! (By which I mean, refactor this and make a pull request
+ # if you want to make this more readable/usable).
+ class ccbn(nn.Module):
+     def __init__(
+         self,
+         output_size,
+         input_size,
+         which_linear,
+         eps=1e-5,
+         momentum=0.1,
+         cross_replica=False,
+         mybn=False,
+         norm_style="bn",
+     ):
+         super(ccbn, self).__init__()
+         self.output_size, self.input_size = output_size, input_size
+         # Prepare gain and bias layers
+         self.gain = which_linear(input_size, output_size)
+         self.bias = which_linear(input_size, output_size)
+         # epsilon to avoid dividing by 0
+         self.eps = eps
+         # Momentum
+         self.momentum = momentum
+         # Use cross-replica batchnorm?
+         self.cross_replica = cross_replica
+         # Use my batchnorm?
+         self.mybn = mybn
+         # Norm style?
+         self.norm_style = norm_style
+
+         if self.cross_replica:
+             # self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
+             self.bn = nn.BatchNorm2d(
+                 output_size, eps=self.eps, momentum=self.momentum, affine=False
+             )
+         elif self.mybn:
+             self.bn = myBN(output_size, self.eps, self.momentum)
+         elif self.norm_style in ["bn", "in"]:
+             self.register_buffer("stored_mean", torch.zeros(output_size))
+             self.register_buffer("stored_var", torch.ones(output_size))
+
+     def forward(self, x, y):
+         # Calculate class-conditional gains and biases
+         gain = (1 + self.gain(y)).view(y.size(0), -1, 1, 1)
+         bias = self.bias(y).view(y.size(0), -1, 1, 1)
+
+         # If using my batchnorm
+         if self.cross_replica:
+             out = self.bn(x)
+             out = out * gain + bias
+             return out
+         elif self.mybn:
+             return self.bn(x, gain=gain, bias=bias)
+         else:
+             if self.norm_style == "bn":
+                 out = F.batch_norm(
+                     x,
+                     self.stored_mean,
+                     self.stored_var,
+                     None,
+                     None,
+                     self.training,
+                     0.1,
+                     self.eps,
+                 )
+             elif self.norm_style == "in":
+                 out = F.instance_norm(
+                     x,
+                     self.stored_mean,
+                     self.stored_var,
+                     None,
+                     None,
+                     self.training,
+                     0.1,
+                     self.eps,
+                 )
+             elif self.norm_style == "gn":
+                 out = groupnorm(x, self.norm_style)
+             elif self.norm_style == "nonorm":
+                 out = x
+             return out * gain + bias
+
+     def extra_repr(self):
+         s = "out: {output_size}, in: {input_size},"
+         s += " cross_replica={cross_replica}"
+         return s.format(**self.__dict__)
+
+
+ # Normal, non-class-conditional BN
+ class bn(nn.Module):
+     def __init__(
+         self,
+         output_size,
+         eps=1e-5,
+         momentum=0.1,
+         cross_replica=False,
+         mybn=False,
+         **kwargs
+     ):
+         super(bn, self).__init__()
+         self.output_size = output_size
+
+         # epsilon to avoid dividing by 0
+         self.eps = eps
+         # Momentum
+         self.momentum = momentum
+         # Use cross-replica batchnorm?
+         self.cross_replica = cross_replica
+         # Use my batchnorm?
+         self.mybn = mybn
+
+         if self.cross_replica:
+             # self.bn = SyncBN2d(output_size, eps=self.eps, momentum=self.momentum, affine=False)
+             self.bn = nn.BatchNorm2d(
+                 output_size, eps=self.eps, momentum=self.momentum, affine=True
+             )
+         elif mybn:
+             # Prepare gain and bias layers
+             self.bn = myBN(output_size, self.eps, self.momentum)
+         # Register buffers if neither of the above
+         else:
+             self.register_buffer("stored_mean", torch.zeros(output_size))
+             self.register_buffer("stored_var", torch.ones(output_size))
+
+         if not self.cross_replica:
+             self.gain = P(torch.ones(output_size), requires_grad=True)
+             self.bias = P(torch.zeros(output_size), requires_grad=True)
+
+     def forward(self, x, y=None):
+         if self.cross_replica:
+             out = self.bn(x)
+             return out
+         elif self.mybn:
+             gain = self.gain.view(1, -1, 1, 1)
+             bias = self.bias.view(1, -1, 1, 1)
+             return self.bn(x, gain=gain, bias=bias)
+         else:
+             return F.batch_norm(
+                 x,
+                 self.stored_mean,
+                 self.stored_var,
+                 self.gain,
+                 self.bias,
+                 self.training,
+                 self.momentum,
+                 self.eps,
+             )
+
+
+ # Generator blocks
+ # Note that this class assumes the kernel size and padding (and any other
+ # settings) have been selected in the main generator module and passed in
+ # through the which_conv arg. Similar rules apply with which_bn (the input
+ # size [which is actually the number of channels of the conditional info] must
+ # be preselected)
+ class GBlock(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         which_conv=nn.Conv2d,
+         which_bn=bn,
+         activation=None,
+         upsample=None,
+     ):
+         super(GBlock, self).__init__()
+
+         self.in_channels, self.out_channels = in_channels, out_channels
+         self.which_conv, self.which_bn = which_conv, which_bn
+         self.activation = activation
+         self.upsample = upsample
+         # Conv layers
+         self.conv1 = self.which_conv(self.in_channels, self.out_channels)
+         self.conv2 = self.which_conv(self.out_channels, self.out_channels)
+         self.learnable_sc = in_channels != out_channels or upsample
+         if self.learnable_sc:
+             self.conv_sc = self.which_conv(
+                 in_channels, out_channels, kernel_size=1, padding=0
+             )
+         # Batchnorm layers
+         self.bn1 = self.which_bn(in_channels)
+         self.bn2 = self.which_bn(out_channels)
+         # upsample layers
+         self.upsample = upsample
+
+     def forward(self, x, y):
+         h = self.activation(self.bn1(x, y))
+         if self.upsample:
+             h = self.upsample(h)
+             x = self.upsample(x)
+         h = self.conv1(h)
+         h = self.activation(self.bn2(h, y))
+         h = self.conv2(h)
+         if self.learnable_sc:
+             x = self.conv_sc(x)
+         return h + x
+
+
+ # Residual block for the discriminator
+ class DBlock(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         which_conv=SNConv2d,
+         wide=True,
+         preactivation=False,
+         activation=None,
+         downsample=None,
+     ):
+         super(DBlock, self).__init__()
+         self.in_channels, self.out_channels = in_channels, out_channels
+         # If using wide D (as in SA-GAN and BigGAN), change the channel pattern
+         self.hidden_channels = self.out_channels if wide else self.in_channels
+         self.which_conv = which_conv
+         self.preactivation = preactivation
+         self.activation = activation
+         self.downsample = downsample
+
+         # Conv layers
+         self.conv1 = self.which_conv(self.in_channels, self.hidden_channels)
+         self.conv2 = self.which_conv(self.hidden_channels, self.out_channels)
+         self.learnable_sc = (
+             True if (in_channels != out_channels) or downsample else False
+         )
+         if self.learnable_sc:
+             self.conv_sc = self.which_conv(
+                 in_channels, out_channels, kernel_size=1, padding=0
+             )
+
+     def shortcut(self, x):
+         if self.preactivation:
+             if self.learnable_sc:
+                 x = self.conv_sc(x)
+             if self.downsample:
+                 x = self.downsample(x)
+         else:
+             if self.downsample:
+                 x = self.downsample(x)
+             if self.learnable_sc:
+                 x = self.conv_sc(x)
+         return x
+
+     def forward(self, x):
+         if self.preactivation:
+             # h = self.activation(x) # NOT TODAY SATAN
+             # Andy's note: This line *must* be an out-of-place ReLU or it
+             #              will negatively affect the shortcut connection.
+             h = F.relu(x)
+         else:
+             h = x
+         h = self.conv1(h)
+         h = self.conv2(self.activation(h))
+         if self.downsample:
+             h = self.downsample(h)
+
+         return h + self.shortcut(x)
+
+
+ # dogball
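A quick sanity check of the spectral-norm machinery above: in training mode each forward pass runs one power iteration, so after enough calls the tracked singular value should approach the true spectral norm of the weight matrix. The import assumes the repository root is on PYTHONPATH:

import torch
from BigGAN_PyTorch.layers import SNLinear

lin = SNLinear(64, 32)  # training mode by default
x = torch.randn(8, 64)
for _ in range(100):    # each call runs one power iteration and updates sv0
    _ = lin(x)

print(lin.sv[0].item())                                     # estimated sigma_0
print(torch.linalg.matrix_norm(lin.weight, ord=2).item())   # exact sigma_0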
BigGAN_PyTorch/logs/BigGAN_ch96_bs256x8.jsonl ADDED
@@ -0,0 +1,68 @@
+ {"itr": 2000, "IS_mean": 2.806771755218506, "IS_std": 0.019480662420392036, "FID": 173.76484159711126, "_stamp": 1551403232.0425167}
+ {"itr": 4000, "IS_mean": 4.962374687194824, "IS_std": 0.07276841998100281, "FID": 113.86730514283107, "_stamp": 1551422228.743057}
+ {"itr": 6000, "IS_mean": 6.939817905426025, "IS_std": 0.11417163163423538, "FID": 101.63548498447199, "_stamp": 1551457139.3400874}
+ {"itr": 8000, "IS_mean": 8.142985343933105, "IS_std": 0.11931543797254562, "FID": 92.0014385772705, "_stamp": 1551476217.2409613}
+ {"itr": 10000, "IS_mean": 10.355518341064453, "IS_std": 0.09094739705324173, "FID": 83.58068997965364, "_stamp": 1551494854.2419689}
+ {"itr": 12000, "IS_mean": 11.288347244262695, "IS_std": 0.14952820539474487, "FID": 80.98066299357106, "_stamp": 1551513232.5049698}
+ {"itr": 14000, "IS_mean": 11.755794525146484, "IS_std": 0.17969024181365967, "FID": 76.80603924280956, "_stamp": 1551531425.150371}
+ {"itr": 18000, "IS_mean": 13.65534496307373, "IS_std": 0.11151058971881866, "FID": 65.95736694335938, "_stamp": 1551588271.9177916}
+ {"itr": 20000, "IS_mean": 14.817827224731445, "IS_std": 0.23588882386684418, "FID": 61.32061767578125, "_stamp": 1551606713.6567464}
+ {"itr": 22000, "IS_mean": 17.16551399230957, "IS_std": 0.19506946206092834, "FID": 53.387969970703125, "_stamp": 1551624876.6513028}
+ {"itr": 24000, "IS_mean": 19.60654067993164, "IS_std": 0.5591856837272644, "FID": 46.5386962890625, "_stamp": 1551642822.6126688}
+ {"itr": 26000, "IS_mean": 21.74416732788086, "IS_std": 0.2850531041622162, "FID": 41.595001220703125, "_stamp": 1551663522.6019194}
+ {"itr": 28000, "IS_mean": 23.923612594604492, "IS_std": 0.41587772965431213, "FID": 37.894744873046875, "_stamp": 1551681794.6567173}
+ {"itr": 30000, "IS_mean": 25.569377899169922, "IS_std": 0.3333457112312317, "FID": 35.49310302734375, "_stamp": 1551699773.7080302}
+ {"itr": 32000, "IS_mean": 26.867944717407227, "IS_std": 0.5968036651611328, "FID": 33.4849853515625, "_stamp": 1551717623.887933}
+ {"itr": 34000, "IS_mean": 28.719074249267578, "IS_std": 0.5698027014732361, "FID": 31.375518798828125, "_stamp": 1551735411.1578612}
+ {"itr": 36000, "IS_mean": 30.587574005126953, "IS_std": 0.5044271349906921, "FID": 29.432281494140625, "_stamp": 1551783380.6357439}
+ {"itr": 38000, "IS_mean": 32.08299255371094, "IS_std": 0.49342143535614014, "FID": 28.099456787109375, "_stamp": 1551801179.6495197}
+ {"itr": 40000, "IS_mean": 34.24657440185547, "IS_std": 0.7709177732467651, "FID": 26.53802490234375, "_stamp": 1551818775.171794}
+ {"itr": 42000, "IS_mean": 35.891212463378906, "IS_std": 0.7036871314048767, "FID": 25.03021240234375, "_stamp": 1551836329.6873965}
+ {"itr": 44000, "IS_mean": 38.184898376464844, "IS_std": 0.32996198534965515, "FID": 23.4940185546875, "_stamp": 1551897864.911537}
+ {"itr": 46000, "IS_mean": 40.239479064941406, "IS_std": 0.7761151194572449, "FID": 22.53167724609375, "_stamp": 1551915406.4840703}
+ {"itr": 48000, "IS_mean": 41.46656036376953, "IS_std": 1.1031498908996582, "FID": 21.5338134765625, "_stamp": 1551932899.6074848}
+ {"itr": 50000, "IS_mean": 43.31670379638672, "IS_std": 0.7796809077262878, "FID": 20.53253173828125, "_stamp": 1551950390.345334}
+ {"itr": 52000, "IS_mean": 45.1517333984375, "IS_std": 1.2925242185592651, "FID": 19.656646728515625, "_stamp": 1551967838.1501615}
+ {"itr": 54000, "IS_mean": 47.638771057128906, "IS_std": 1.0689665079116821, "FID": 18.898162841796875, "_stamp": 1552044534.5349634}
+ {"itr": 56000, "IS_mean": 48.87520217895508, "IS_std": 1.1317559480667114, "FID": 18.1248779296875, "_stamp": 1552061763.3080354}
+ {"itr": 58000, "IS_mean": 49.40987014770508, "IS_std": 1.1866596937179565, "FID": 17.751922607421875, "_stamp": 1552078939.9828825}
+ {"itr": 60000, "IS_mean": 51.051334381103516, "IS_std": 1.2281248569488525, "FID": 17.19964599609375, "_stamp": 1552096167.889482}
+ {"itr": 62000, "IS_mean": 52.0235481262207, "IS_std": 0.5391153693199158, "FID": 16.62115478515625, "_stamp": 1552113417.9520617}
+ {"itr": 64000, "IS_mean": 53.868492126464844, "IS_std": 1.327082633972168, "FID": 16.237335205078125, "_stamp": 1552142961.09602}
+ {"itr": 66000, "IS_mean": 54.978721618652344, "IS_std": 0.9502049088478088, "FID": 15.81170654296875, "_stamp": 1552162403.2232807}
+ {"itr": 68000, "IS_mean": 55.73248291015625, "IS_std": 1.0323851108551025, "FID": 15.545623779296875, "_stamp": 1552181112.676657}
+ {"itr": 70000, "IS_mean": 56.78422927856445, "IS_std": 1.211003303527832, "FID": 15.28369140625, "_stamp": 1552199498.887533}
+ {"itr": 72000, "IS_mean": 57.972999572753906, "IS_std": 0.8668608665466309, "FID": 14.86395263671875, "_stamp": 1552217782.2738616}
+ {"itr": 74000, "IS_mean": 58.845054626464844, "IS_std": 1.4297977685928345, "FID": 14.620635986328125, "_stamp": 1552251085.1781816}
+ {"itr": 76000, "IS_mean": 59.60982131958008, "IS_std": 0.9095696210861206, "FID": 14.360198974609375, "_stamp": 1552270214.9345307}
+ {"itr": 78000, "IS_mean": 60.71195602416992, "IS_std": 0.960899829864502, "FID": 14.07183837890625, "_stamp": 1552288697.1580262}
+ {"itr": 80000, "IS_mean": 61.772125244140625, "IS_std": 0.6913255453109741, "FID": 13.781585693359375, "_stamp": 1552307170.0280282}
+ {"itr": 82000, "IS_mean": 62.98079299926758, "IS_std": 1.4735801219940186, "FID": 13.55389404296875, "_stamp": 1552325252.8553352}
+ {"itr": 84000, "IS_mean": 64.95240783691406, "IS_std": 0.9018951654434204, "FID": 13.231689453125, "_stamp": 1552344135.3111835}
+ {"itr": 86000, "IS_mean": 65.13968658447266, "IS_std": 0.8772205114364624, "FID": 13.176849365234375, "_stamp": 1552362429.6782444}
+ {"itr": 88000, "IS_mean": 65.84476470947266, "IS_std": 1.167534351348877, "FID": 12.87078857421875, "_stamp": 1552380560.7988124}
+ {"itr": 90000, "IS_mean": 67.41099548339844, "IS_std": 1.6899267435073853, "FID": 12.586517333984375, "_stamp": 1552398550.2060475}
+ {"itr": 92000, "IS_mean": 68.63685607910156, "IS_std": 1.9431978464126587, "FID": 12.49505615234375, "_stamp": 1552430781.6406457}
+ {"itr": 94000, "IS_mean": 70.09907531738281, "IS_std": 1.0715738534927368, "FID": 12.047607421875, "_stamp": 1552449001.1950285}
+ {"itr": 96000, "IS_mean": 70.34623718261719, "IS_std": 1.7962944507598877, "FID": 11.896697998046875, "_stamp": 1552466989.3587568}
+ {"itr": 98000, "IS_mean": 71.08210754394531, "IS_std": 1.458209753036499, "FID": 11.73046875, "_stamp": 1552484800.7138846}
+ {"itr": 100000, "IS_mean": 72.24256896972656, "IS_std": 1.3259714841842651, "FID": 11.7386474609375, "_stamp": 1552502538.0269725}
+ {"itr": 102000, "IS_mean": 73.19488525390625, "IS_std": 1.3439149856567383, "FID": 11.50494384765625, "_stamp": 1552523284.4514356}
+ {"itr": 104000, "IS_mean": 73.38243103027344, "IS_std": 1.4162707328796387, "FID": 11.374542236328125, "_stamp": 1552541012.0651608}
+ {"itr": 106000, "IS_mean": 74.95563507080078, "IS_std": 1.089124083518982, "FID": 11.10479736328125, "_stamp": 1552558577.7458107}
+ {"itr": 108000, "IS_mean": 76.42997741699219, "IS_std": 1.9282453060150146, "FID": 10.998870849609375, "_stamp": 1552576111.9480467}
+ {"itr": 110000, "IS_mean": 76.89225769042969, "IS_std": 1.4771150350570679, "FID": 10.847015380859375, "_stamp": 1552593659.445132}
+ {"itr": 112000, "IS_mean": 78.04684448242188, "IS_std": 1.4850096702575684, "FID": 10.772552490234375, "_stamp": 1552616479.5201895}
+ {"itr": 114000, "IS_mean": 79.67677307128906, "IS_std": 2.0147368907928467, "FID": 10.528045654296875, "_stamp": 1552633850.9315467}
+ {"itr": 116000, "IS_mean": 79.8828125, "IS_std": 0.978247344493866, "FID": 10.626068115234375, "_stamp": 1552651198.9012825}
+ {"itr": 118000, "IS_mean": 79.95381164550781, "IS_std": 1.8608143329620361, "FID": 10.46771240234375, "_stamp": 1552668560.4420238}
+ {"itr": 120000, "IS_mean": 82.37217712402344, "IS_std": 1.8909310102462769, "FID": 10.259033203125, "_stamp": 1552749673.4319007}
+ {"itr": 122000, "IS_mean": 83.49666595458984, "IS_std": 2.38446044921875, "FID": 9.996185302734375, "_stamp": 1552766698.2706933}
+ {"itr": 124000, "IS_mean": 83.05189514160156, "IS_std": 1.8844469785690308, "FID": 10.164398193359375, "_stamp": 1552783762.891172}
+ {"itr": 126000, "IS_mean": 84.27763366699219, "IS_std": 0.9329544901847839, "FID": 10.03509521484375, "_stamp": 1552800953.5724175}
+ {"itr": 128000, "IS_mean": 85.84852600097656, "IS_std": 2.2698562145233154, "FID": 9.91644287109375, "_stamp": 1552818112.227726}
+ {"itr": 130000, "IS_mean": 87.356689453125, "IS_std": 2.0958640575408936, "FID": 9.771148681640625, "_stamp": 1552837539.995247}
+ {"itr": 132000, "IS_mean": 88.72562408447266, "IS_std": 1.7551432847976685, "FID": 9.8258056640625, "_stamp": 1552859685.9305944}
+ {"itr": 134000, "IS_mean": 88.0631103515625, "IS_std": 1.8199039697647095, "FID": 9.957183837890625, "_stamp": 1552880037.5408435}
+ {"itr": 136000, "IS_mean": 91.50938415527344, "IS_std": 1.9926033020019531, "FID": 9.876556396484375, "_stamp": 1552899854.652669}
+ {"itr": 138000, "IS_mean": 93.09217834472656, "IS_std": 2.3062736988067627, "FID": 9.908477783203125, "_stamp": 1552921580.958927}
BigGAN_PyTorch/logs/compare_IS.m ADDED
@@ -0,0 +1,97 @@
+ % Copyright (c) Facebook, Inc. and its affiliates.
+ % All rights reserved.
+ %
+ % All contributions by Andy Brock:
+ % Copyright (c) 2019 Andy Brock
+ %
+ % MIT License
+
+ clc
+ clear all
+ close all
+ fclose all;
+
+
+
+ %% Get All logs and sort them
+ s = {};
+ d = dir();
+ j = 1;
+ for i = 1:length(d)
+     if any(strfind(d(i).name,'.jsonl'))
+         s = [s; d(i).name];
+     end
+ end
+
+
+ j = 1;
+ for i = 1:length(s)
+     fname = s{i,1};
+     % Check if the Inception metrics log exists, and if so, plot it
+     [itr, IS, FID, t] = process_inception_log(fname(1:end - 10), 'log.jsonl');
+     s{i,2} = itr;
+     s{i,3} = IS;
+     s{i,4} = FID;
+     s{i,5} = max(IS);
+     s{i,6} = min(FID);
+     s{i,7} = t;
+ end
+ % Sort by Inception Score?
+ [IS_sorted, IS_index] = sort(cell2mat(s(:,5)));
+ % Cutoff inception scores below a certain value?
+ threshold = 22;
+ IS_index = IS_index(IS_sorted > threshold);
+
+ % Sort by FID?
+ [FID_sorted, FID_index] = sort(cell2mat(s(:,6)));
+ % Cutoff also based on IS?
+ % threshold = 0;
+ FID_index = FID_index(IS_sorted > threshold);
+
+
+
+ %% Plot things?
+ cc = hsv(length(IS_index));
+ legend1 = {};
+ legend2 = {};
+ make_axis = true; % false  % Turn this on to see the axis out to 1e6 iterations
+ for i = 1:length(IS_index)
+     legend1 = [legend1; s{IS_index(i), 1}];
+     figure(1)
+     plot(s{IS_index(i),2}, s{IS_index(i),3}, 'color', cc(i,:), 'linewidth', 2)
+     hold on;
+     xlabel('itr'); ylabel('IS');
+     grid on;
+     if make_axis
+         axis([0,1e6,0,80]); % 50% grid on;
+     end
+     legend(legend1,'Interpreter','none')
+     %pause(1) % Turn this on to animate stuff
+     legend2 = [legend2; s{IS_index(i), 1}];
+     figure(2)
+     plot(s{IS_index(i),2}, s{IS_index(i),4}, 'color', cc(i,:), 'linewidth', 2)
+     hold on;
+     xlabel('itr'); ylabel('FID');
+     j = j + 1;
+     grid on;
+     if make_axis
+         axis([0,1e6,0,50]); % grid on;
+     end
+     legend(legend2, 'Interpreter','none')
+
+ end
+
+ %% Quick script to plot IS versus timesteps
+ if 0
+     figure(3);
+     this_index = 4;
+     subplot(2,1,1);
+     %plot(s{this_index, 2}(2:end), s{this_index, 7}(2:end) - s{this_index, 7}(1:end-1), 'r*');
+     % xlabel('Iteration'); ylabel('\Delta T')
+     plot(s{this_index, 2}, s{this_index, 7}, 'r*');
+     xlabel('Iteration'); ylabel('T')
+     subplot(2,1,2);
+     plot(s{this_index, 2}, s{this_index, 3}, 'r', 'linewidth', 2);
+     xlabel('Iteration'), ylabel('Inception score')
+     title(s{this_index,1})
+ end
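For those without MATLAB, a rough Python/matplotlib stand-in for the IS curves that compare_IS.m draws (assuming the `*_log.jsonl` naming convention the script relies on):

import glob
import json
import matplotlib.pyplot as plt

for fname in glob.glob("*_log.jsonl"):
    with open(fname) as f:
        rec = [json.loads(line) for line in f]
    plt.plot([r["itr"] for r in rec], [r["IS_mean"] for r in rec], label=fname)

plt.xlabel("itr")
plt.ylabel("IS")
plt.legend()
plt.show()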
BigGAN_PyTorch/logs/metalog.txt ADDED
@@ -0,0 +1,3 @@
+ datetime: 2019-03-18 13:27:59.181225
+ config: {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'G_depth': 1, 'D_depth': 1, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'z_var': 1.0, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': True, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 400, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'config_from_name': False, 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': True, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}
+ state: {'itr': 137500, 'epoch': 2, 'save_num': 0, 'save_best_num': 1, 'best_IS': 91.509384, 'best_FID': tensor(9.7711), 'config': {'dataset': 'I128_hdf5', 'augment': False, 'num_workers': 8, 'pin_memory': True, 'shuffle': True, 'load_in_mem': True, 'use_multiepoch_sampler': True, 'model': 'model', 'G_param': 'SN', 'D_param': 'SN', 'G_ch': 96, 'D_ch': 96, 'D_wide': True, 'G_shared': True, 'shared_dim': 128, 'dim_z': 120, 'hier': True, 'cross_replica': False, 'mybn': False, 'G_nl': 'inplace_relu', 'D_nl': 'inplace_relu', 'G_attn': '64', 'D_attn': '64', 'norm_style': 'bn', 'seed': 0, 'G_init': 'ortho', 'D_init': 'ortho', 'skip_init': False, 'G_lr': 0.0001, 'D_lr': 0.0004, 'G_B1': 0.0, 'D_B1': 0.0, 'G_B2': 0.999, 'D_B2': 0.999, 'batch_size': 256, 'G_batch_size': 0, 'num_G_accumulations': 8, 'num_D_steps': 1, 'num_D_accumulations': 8, 'split_D': False, 'num_epochs': 100, 'parallel': True, 'G_fp16': False, 'D_fp16': False, 'D_mixed_precision': False, 'G_mixed_precision': False, 'accumulate_stats': False, 'num_standing_accumulations': 16, 'BN_sync': False, 'G_eval_mode': True, 'save_every': 500, 'num_save_copies': 2, 'num_best_copies': 5, 'which_best': 'IS', 'no_fid': False, 'test_every': 2000, 'num_inception_images': 50000, 'hashname': False, 'base_root': '', 'dataset_root': 'data', 'weights_root': 'weights', 'logs_root': 'logs', 'samples_root': 'samples', 'pbar': 'mine', 'name_suffix': '', 'experiment_name': 'Jade_BigGAN_B1_bs256x8_fp32', 'ema': True, 'ema_decay': 0.9999, 'use_ema': True, 'ema_start': 20000, 'adam_eps': 1e-06, 'BN_eps': 1e-05, 'SN_eps': 1e-06, 'num_G_SVs': 1, 'num_D_SVs': 1, 'num_G_SV_itrs': 1, 'num_D_SV_itrs': 1, 'G_ortho': 0.0, 'D_ortho': 0.0, 'toggle_grads': True, 'which_train_fn': 'GAN', 'load_weights': '', 'resume': False, 'logstyle': '%3.3e', 'log_G_spectra': False, 'log_D_spectra': False, 'sv_log_interval': 10, 'resolution': 128, 'n_classes': 1000, 'G_activation': ReLU(inplace), 'D_activation': ReLU(inplace)}}
BigGAN_PyTorch/logs/process_inception_log.m ADDED
@@ -0,0 +1,27 @@
+ % Copyright (c) Facebook, Inc. and its affiliates.
+ % All rights reserved.
+ %
+ % All contributions by Andy Brock:
+ % Copyright (c) 2019 Andy Brock
+ %
+ % MIT License
+ %
+ function [itr, IS, FID, t] = process_inception_log(fname, which_log)
+     f = sprintf('%s_%s', fname, which_log); %'G_loss.log');
+     fid = fopen(f,'r');
+     itr = [];
+     IS = [];
+     FID = [];
+     t = [];
+     i = 1;
+     while ~feof(fid)
+         s = fgets(fid);
+         parsed = sscanf(s,'{"itr": %d, "IS_mean": %f, "IS_std": %f, "FID": %f, "_stamp": %f}');
+         itr(i) = parsed(1);
+         IS(i) = parsed(2);
+         FID(i) = parsed(4);
+         t(i) = parsed(5);
+         i = i + 1;
+     end
+     fclose(fid);
+ end
BigGAN_PyTorch/logs/process_training.m ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ % Copyright (c) Facebook, Inc. and its affiliates.
+ % All rights reserved.
+ %
+ % All contributions by Andy Brock:
+ % Copyright (c) 2019 Andy Brock
+ %
+ % MIT License
+ %
+ clc
+ clear all
+ close all
+ fclose all;
+
+
+
+ %% Get all training logs for a given run
+ target_dir = '.';
+ s = {};
+ nm = {};
+ d = dir(target_dir);
+ j = 1;
+ for i = 1:length(d)
+     if any(strfind(d(i).name, '.log'))
+         s = [s; sprintf('%s\\%s', target_dir, d(i).name)];
+         nm = [nm; d(i).name];
+     end
+ end
+ %% Loop over training logs and acquire data
+ D_count = 0;
+ G_count = 0;
+ for i = 1:length(s)
+     fname = s{i,1};
+     fid = fopen(s{i,1}, 'r');
+     % Prepare bookkeeping for sv0
+     if any(strfind(s{i,1}, 'sv'))
+         if any(strfind(s{i,1}, 'G_'))
+             G_count = G_count + 1;
+         else
+             D_count = D_count + 1;
+         end
+     end
+     itr = [];
+     val = [];
+     j = 1;
+     while ~feof(fid)
+         line = fgets(fid);
+         parsed = sscanf(line, '%d: %e');
+         itr(j) = parsed(1);
+         val(j) = parsed(2);
+         j = j + 1;
+     end
+     s{i,2} = itr;
+     s{i,3} = val;
+     fclose(fid);
+ end
+
+ %% Plot SVs and losses
+ close all;
+ Gcc = hsv(G_count);
+ Dcc = hsv(D_count);
+ gi = 1;
+ di = 1;
+ li = 1;
+ legendG = {};
+ legendD = {};
+ legendL = {};
+ thresh = 2; % wavelet denoising threshold
+ losses = {};
+ for i = 1:length(s)
+     if any(strfind(s{i,1}, 'D_loss_real.log')) || any(strfind(s{i,1}, 'D_loss_fake.log')) || any(strfind(s{i,1}, 'G_loss.log'))
+         % Select colors
+         if any(strfind(s{i,1}, 'D_loss_real.log'))
+             color1 = [0.7, 0.7, 1.0];
+             color2 = [0, 0, 1];
+             dlr = {s{i,2}, s{i,3}, wden(s{i,3}, 'sqtwolog', 's', 'mln', thresh, 'sym4'), color1, color2};
+             losses = [losses; dlr];
+         elseif any(strfind(s{i,1}, 'D_loss_fake.log'))
+             color1 = [0.7, 1.0, 0.7];
+             color2 = [0, 1, 0];
+             dlf = {s{i,2}, s{i,3}, wden(s{i,3}, 'sqtwolog', 's', 'mln', thresh, 'sym4'), color1, color2};
+             losses = [losses; dlf];
+         else % G loss
+             color1 = [1.0, 0.7, 0.7];
+             color2 = [1, 0, 0];
+             gl = {s{i,2}, s{i,3}, wden(s{i,3}, 'sqtwolog', 's', 'mln', thresh, 'sym4'), color1, color2};
+             losses = [losses; gl];
+         end
+         figure(1); hold on;
+         % Plot the unsmoothed losses; we'll plot the smoothed losses later
+         plot(s{i,2}, s{i,3}, 'color', color1, 'HandleVisibility', 'off');
+         legendL = [legendL; nm{i}];
+         continue
+     end
+     if any(strfind(s{i,1}, 'G_'))
+         legendG = [legendG; nm{i}];
+         figure(2); hold on;
+         plot(s{i,2}, s{i,3}, 'color', Gcc(gi,:), 'linewidth', 2);
+         gi = gi + 1;
+     elseif any(strfind(s{i,1}, 'D_'))
+         legendD = [legendD; nm{i}];
+         figure(3); hold on;
+         plot(s{i,2}, s{i,3}, 'color', Dcc(di,:), 'linewidth', 2);
+         di = di + 1;
+     else
+         s{i,1} % Display the name of any log that was not processed.
+     end
+ end
+ figure(1);
+ % Plot the smoothed losses last
+ for i = 1:3
+     % plot(losses{i,1}, losses{i,2}, 'color', losses{i,4}, 'HandleVisibility', 'off');
+     plot(losses{i,1}, losses{i,3}, 'color', losses{i,5});
+ end
+ legend(legendL, 'Interpreter', 'none'); title('Losses'); xlabel('Generator itr'); ylabel('loss'); axis([0, max(s{end,2}), -1, 4]);
+
+ figure(2); legend(legendG, 'Interpreter', 'none'); title('Singular Values in G'); xlabel('Generator itr'); ylabel('SV0');
+ figure(3); legend(legendD, 'Interpreter', 'none'); title('Singular Values in D'); xlabel('Generator itr'); ylabel('SV0');
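Note that wden requires MATLAB's Wavelet Toolbox. If the toolbox is unavailable, a plain moving average gives a comparable smoothed overlay; a minimal Python sketch under that assumption (the smooth helper is illustrative, not part of this repo):

import numpy as np

def smooth(values, window=101):
    # Box-filter smoothing; mode="same" keeps the output length equal to the input.
    kernel = np.ones(window) / window
    return np.convolve(np.asarray(values, dtype=float), kernel, mode="same")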
BigGAN_PyTorch/losses.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # All contributions by Andy Brock:
+ # Copyright (c) 2019 Andy Brock
+ #
+ # MIT License
+ import torch
+ import torch.nn.functional as F
+
+ # DCGAN loss
+ def loss_dcgan_dis(dis_fake, dis_real):
+     L1 = torch.mean(F.softplus(-dis_real))
+     L2 = torch.mean(F.softplus(dis_fake))
+     return L1, L2
+
+
+ def loss_dcgan_gen(dis_fake):
+     loss = torch.mean(F.softplus(-dis_fake))
+     return loss
+
+
+ # Hinge loss
+ def loss_hinge_dis(dis_fake, dis_real):
+     loss_real = torch.mean(F.relu(1.0 - dis_real))
+     loss_fake = torch.mean(F.relu(1.0 + dis_fake))
+     return loss_real, loss_fake
+
+
+ # def loss_hinge_dis(dis_fake, dis_real):  # This version returns a single loss
+ #     loss = torch.mean(F.relu(1. - dis_real))
+ #     loss += torch.mean(F.relu(1. + dis_fake))
+ #     return loss
+
+
+ def loss_hinge_gen(dis_fake):
+     loss = -torch.mean(dis_fake)
+     return loss
+
+
+ # Default to hinge loss
+ generator_loss = loss_hinge_gen
+ discriminator_loss = loss_hinge_dis
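For reference, a minimal sketch of how the hinge losses above are typically wired into alternating D/G updates. The toy G, D, and data here are placeholders, not this repo's training loop:

import torch
import torch.nn as nn
from losses import loss_hinge_dis, loss_hinge_gen

G = nn.Linear(16, 32)   # toy generator: latent dim 16 -> "image" dim 32
D = nn.Linear(32, 1)    # toy discriminator: "image" dim 32 -> score
z = torch.randn(8, 16)
x_real = torch.randn(8, 32)

# Discriminator step: push D(x_real) above +1 and D(G(z)) below -1.
# Note the argument order: fake scores first, real scores second.
loss_real, loss_fake = loss_hinge_dis(D(G(z).detach()), D(x_real))
(loss_real + loss_fake).backward()

# Generator step: raise D's score on generated samples.
loss_hinge_gen(D(G(z))).backward()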
BigGAN_PyTorch/make_hdf5.py ADDED
@@ -0,0 +1,193 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # All contributions by Andy Brock:
+ # Copyright (c) 2019 Andy Brock
+ #
+ # MIT License
+ """ Convert dataset to HDF5
+ This script preprocesses a dataset and saves it (images and labels) to
+ an HDF5 file for improved I/O. """
+ import os
+ import sys
+ from argparse import ArgumentParser
+ from tqdm import tqdm, trange
+ import h5py as h5
+
+ import numpy as np
+ import torch
+ import torchvision.datasets as dset
+ import torchvision.transforms as transforms
+ from torchvision.utils import save_image
+ from torch.utils.data import DataLoader
+
+ import utils
+
+
+ def prepare_parser():
+     usage = "Parser for ImageNet HDF5 scripts."
+     parser = ArgumentParser(description=usage)
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=128,
+         help="Which dataset resolution to train on, out of 64, 128, 256, 512 (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--split",
+         type=str,
+         default="train",
+         help="Which dataset split to convert: train, val (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--data_root",
+         type=str,
+         default="data",
+         help="Default location where data is stored (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--out_path",
+         type=str,
+         default="data",
+         help="Default location where data in hdf5 format will be stored (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--longtail",
+         action="store_true",
+         default=False,
+         help="Use long-tail version of the dataset",
+     )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         default=256,
+         help="Default overall batchsize (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--num_workers",
+         type=int,
+         default=16,
+         help="Number of dataloader workers (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--chunk_size",
+         type=int,
+         default=500,
+         help="Chunk size along the batch axis for the HDF5 datasets (default: %(default)s)",
+     )
+     parser.add_argument(
+         "--compression",
+         action="store_true",
+         default=False,
+         help="Use LZF compression? (default: %(default)s)",
+     )
+     return parser
+
+
+ def run(config):
+     # Update compression entry
+     config["compression"] = (
+         "lzf" if config["compression"] else None
+     )  # No compression; can also use 'lzf'
+
+     # Get dataset
+     kwargs = {
+         "num_workers": config["num_workers"],
+         "pin_memory": False,
+         "drop_last": False,
+     }
+     dataset = utils.get_dataset_images(
+         config["resolution"],
+         data_path=os.path.join(config["data_root"], config["split"]),
+         longtail=config["longtail"],
+     )
+     train_loader = utils.get_dataloader(
+         dataset, config["batch_size"], shuffle=False, **kwargs
+     )
+
+     # HDF5 supports chunking and compression. You may want to experiment
+     # with different chunk sizes to see how it runs on your machines.
+     # Chunk Size/compression    Read speed @ 256x256    Read speed @ 128x128    Filesize @ 128x128    Time to write @ 128x128
+     # 1 / None                  20/s
+     # 500 / None                ramps up to 77/s        102/s                   61GB                  23min
+     # 500 / LZF                                         8/s                     56GB                  23min
+     # 1000 / None               78/s
+     # 5000 / None               81/s
+     # auto:(125,1,16,32) / None                         11/s                    61GB
+
+     print(
+         "Starting to load dataset into an HDF5 file with chunk size %i and compression %s..."
+         % (config["chunk_size"], config["compression"])
+     )
+     # Loop over train loader
+     for i, (x, y) in enumerate(tqdm(train_loader)):
+         # Stick X into the range [0, 255] since it's coming from the train loader
+         x = (255 * ((x + 1) / 2.0)).byte().numpy()
+         # Numpyify y
+         y = y.numpy()
+         # If we're on the first batch, prepare the hdf5
+         if i == 0:
+             with h5.File(
+                 config["out_path"]
+                 + "/ILSVRC%i%s_xy.hdf5"
+                 % (config["resolution"], "" if not config["longtail"] else "longtail"),
+                 "w",
+             ) as f:
+                 print("Producing dataset of len %d" % len(train_loader.dataset))
+                 imgs_dset = f.create_dataset(
+                     "imgs",
+                     x.shape,
+                     dtype="uint8",
+                     maxshape=(
+                         len(train_loader.dataset),
+                         3,
+                         config["resolution"],
+                         config["resolution"],
+                     ),
+                     chunks=(
+                         config["chunk_size"],
+                         3,
+                         config["resolution"],
+                         config["resolution"],
+                     ),
+                     compression=config["compression"],
+                 )
+                 print("Image chunks chosen as " + str(imgs_dset.chunks))
+                 imgs_dset[...] = x
+                 labels_dset = f.create_dataset(
+                     "labels",
+                     y.shape,
+                     dtype="int64",
+                     maxshape=(len(train_loader.dataset),),
+                     chunks=(config["chunk_size"],),
+                     compression=config["compression"],
+                 )
+                 print("Label chunks chosen as " + str(labels_dset.chunks))
+                 labels_dset[...] = y
+         # Else append to the hdf5
+         else:
+             with h5.File(
+                 config["out_path"]
+                 + "/ILSVRC%i%s_xy.hdf5"
+                 % (config["resolution"], "" if not config["longtail"] else "longtail"),
+                 "a",
+             ) as f:
+                 f["imgs"].resize(f["imgs"].shape[0] + x.shape[0], axis=0)
+                 f["imgs"][-x.shape[0]:] = x
+                 f["labels"].resize(f["labels"].shape[0] + y.shape[0], axis=0)
+                 f["labels"][-y.shape[0]:] = y
+
+
+ def main():
+     # Parse command line and run
+     parser = prepare_parser()
+     config = vars(parser.parse_args())
+     print(config)
+     run(config)
+
+
+ if __name__ == "__main__":
+     main()
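A quick sanity check on the output is to read a few samples back with h5py. A minimal sketch, assuming the defaults above (resolution 128, out_path "data", no long-tail), which yield the file data/ILSVRC128_xy.hdf5:

import h5py as h5

with h5.File("data/ILSVRC128_xy.hdf5", "r") as f:
    print(f["imgs"].shape, f["imgs"].dtype)      # (N, 3, 128, 128) uint8
    print(f["labels"].shape, f["labels"].dtype)  # (N,) int64
    x0, y0 = f["imgs"][0], f["labels"][0]        # reads one sample from its chunk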
BigGAN_PyTorch/run.py ADDED
@@ -0,0 +1,75 @@
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import json
+ import time
+
+ import utils
+ from trainer import run
+
+ LOCAL = False
+ try:
+     import submitit
+     from submitit.helpers import Checkpointable
+ except ImportError:
+     print(
+         "No submitit package found! Defaulting to executing the script in the local machine"
+     )
+     LOCAL = True
+     Checkpointable = object  # Fallback base class so Trainer can still be defined
+
+
+ class Trainer(Checkpointable):
+     def __call__(self, config):
+         if config["run_setup"] == "local_debug" or LOCAL:
+             run(config, "local_debug")
+         else:
+             run(config, "slurm", master_node=submitit.JobEnvironment().hostnames[0])
+
+
+ if __name__ == "__main__":
+     parser = utils.prepare_parser()
+     config = vars(parser.parse_args())
+
+     if config["json_config"] != "":
+         data = json.load(open(config["json_config"]))
+         for key in data.keys():
+             config[key] = data[key]
+     else:
+         print("Not using JSON configuration file!")
+     config["G_batch_size"] = config["batch_size"]
+     config["batch_size"] = (
+         config["batch_size"] * config["num_D_accumulations"] * config["num_D_steps"]
+     )
+
+     trainer = Trainer()
+     if config["run_setup"] == "local_debug" or LOCAL:
+         trainer(config)
+     else:
+         print(
+             "Using ",
+             config["n_nodes"],
+             " nodes and ",
+             config["n_gpus_per_node"],
+             " GPUs per node.",
+         )
+         executor = submitit.SlurmExecutor(
+             folder=config["slurm_logdir"], max_num_timeout=60
+         )
+         executor.update_parameters(
+             gpus_per_node=config["n_gpus_per_node"],
+             partition=config["partition"],
+             constraint="volta32gb",
+             nodes=config["n_nodes"],
+             ntasks_per_node=config["n_gpus_per_node"],
+             cpus_per_task=8,
+             mem=256000,
+             time=3200,
+             job_name=config["experiment_name"],
+             exclusive=config["n_gpus_per_node"] == 8,
+         )
+
+         executor.submit(trainer, config)
+         time.sleep(1)
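Note the batch-size bookkeeping above: the per-step generator batch is preserved in G_batch_size, while batch_size is rescaled to the total number of samples consumed per discriminator update cycle. A worked example with illustrative values (not defaults from this repo):

batch_size, num_D_accumulations, num_D_steps = 256, 8, 1
G_batch_size = batch_size                                     # 256 per G step
batch_size = batch_size * num_D_accumulations * num_D_steps   # 2048 per D cycle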
BigGAN_PyTorch/scripts/launch_BigGAN_bs256x8.sh ADDED
@@ -0,0 +1,26 @@
+ #!/bin/bash
+ #
+ # Copyright (c) Facebook, Inc. and its affiliates.
+ # All rights reserved.
+ #
+ # All contributions by Andy Brock:
+ # Copyright (c) 2019 Andy Brock
+ #
+ # MIT License
+ #
+ python train.py \
+ --dataset I128_hdf5 --parallel --shuffle --num_workers 8 --batch_size 256 --load_in_mem \
+ --num_G_accumulations 8 --num_D_accumulations 8 \
+ --num_D_steps 1 --G_lr 1e-4 --D_lr 4e-4 --D_B2 0.999 --G_B2 0.999 \
+ --G_attn 64 --D_attn 64 \
+ --G_nl inplace_relu --D_nl inplace_relu \
+ --SN_eps 1e-6 --BN_eps 1e-5 --adam_eps 1e-6 \
+ --G_ortho 0.0 \
+ --G_shared \
+ --G_init ortho --D_init ortho \
+ --hier --dim_z 120 --shared_dim 128 \
+ --G_eval_mode \
+ --G_ch 96 --D_ch 96 \
+ --ema --use_ema --ema_start 20000 \
+ --test_every 2000 --save_every 1000 --num_best_copies 5 --num_save_copies 2 --seed 0 \
+ --use_multiepoch_sampler