james-oldfield committed
Commit 2a76164
1 Parent(s): 917c2bc

Upload 194 files

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. models/stylegan2_ffhq1024.pth +3 -0
  2. networks/__pycache__/load_generator.cpython-38.pyc +0 -0
  3. networks/biggan/__init__.py +6 -0
  4. networks/biggan/__pycache__/__init__.cpython-38.pyc +0 -0
  5. networks/biggan/__pycache__/config.cpython-38.pyc +0 -0
  6. networks/biggan/__pycache__/file_utils.cpython-38.pyc +0 -0
  7. networks/biggan/__pycache__/model.cpython-38.pyc +0 -0
  8. networks/biggan/__pycache__/utils.cpython-38.pyc +0 -0
  9. networks/biggan/config.py +70 -0
  10. networks/biggan/convert.sh +21 -0
  11. networks/biggan/convert_tf_to_pytorch.py +312 -0
  12. networks/biggan/download_tf.sh +21 -0
  13. networks/biggan/file_utils.py +232 -0
  14. networks/biggan/model.py +352 -0
  15. networks/biggan/utils.py +216 -0
  16. networks/genforce/.gitignore +29 -0
  17. networks/genforce/LICENSE +18 -0
  18. networks/genforce/MODEL_ZOO.md +131 -0
  19. networks/genforce/README.md +169 -0
  20. networks/genforce/__init__.py +0 -0
  21. networks/genforce/__pycache__/__init__.cpython-38.pyc +0 -0
  22. networks/genforce/configs/stylegan_demo.py +61 -0
  23. networks/genforce/configs/stylegan_ffhq1024.py +63 -0
  24. networks/genforce/configs/stylegan_ffhq1024_val.py +29 -0
  25. networks/genforce/configs/stylegan_ffhq256.py +63 -0
  26. networks/genforce/configs/stylegan_ffhq256_encoder_y.py +73 -0
  27. networks/genforce/configs/stylegan_ffhq256_val.py +29 -0
  28. networks/genforce/convert_model.py +77 -0
  29. networks/genforce/datasets/README.md +24 -0
  30. networks/genforce/datasets/__init__.py +7 -0
  31. networks/genforce/datasets/dataloaders.py +128 -0
  32. networks/genforce/datasets/datasets.py +239 -0
  33. networks/genforce/datasets/distributed_sampler.py +144 -0
  34. networks/genforce/datasets/libturbojpeg.so.0 +0 -0
  35. networks/genforce/datasets/transforms.py +201 -0
  36. networks/genforce/metrics/README.md +18 -0
  37. networks/genforce/metrics/__init__.py +0 -0
  38. networks/genforce/metrics/fid.py +59 -0
  39. networks/genforce/metrics/inception.py +520 -0
  40. networks/genforce/models/__init__.py +131 -0
  41. networks/genforce/models/__pycache__/__init__.cpython-38.pyc +0 -0
  42. networks/genforce/models/__pycache__/encoder.cpython-38.pyc +0 -0
  43. networks/genforce/models/__pycache__/model_zoo.cpython-38.pyc +0 -0
  44. networks/genforce/models/__pycache__/perceptual_model.cpython-38.pyc +0 -0
  45. networks/genforce/models/__pycache__/pggan_discriminator.cpython-38.pyc +0 -0
  46. networks/genforce/models/__pycache__/pggan_generator.cpython-38.pyc +0 -0
  47. networks/genforce/models/__pycache__/stylegan2_discriminator.cpython-38.pyc +0 -0
  48. networks/genforce/models/__pycache__/stylegan2_generator.cpython-38.pyc +0 -0
  49. networks/genforce/models/__pycache__/stylegan_discriminator.cpython-38.pyc +0 -0
  50. networks/genforce/models/__pycache__/stylegan_generator.cpython-38.pyc +0 -0
models/stylegan2_ffhq1024.pth ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d68c291b390869bee4f83f144d0480bd4c4cb7aab0fee0a9ee551073ea2d2163
size 381464183
networks/__pycache__/load_generator.cpython-38.pyc ADDED
Binary file (1.43 kB)
 
networks/biggan/__init__.py ADDED
@@ -0,0 +1,6 @@
from .config import BigGANConfig
from .model import BigGAN
from .file_utils import PYTORCH_PRETRAINED_BIGGAN_CACHE, cached_path
from .utils import (truncated_noise_sample, save_as_images,
                    convert_to_images, display_in_terminal,
                    one_hot_from_int, one_hot_from_names)
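These re-exports define the public surface of the BigGAN package. A minimal usage sketch, assuming the repository root is on PYTHONPATH so the package is importable as networks.biggan:

import torch

from networks.biggan import BigGAN, BigGANConfig, PYTORCH_PRETRAINED_BIGGAN_CACHE

print(PYTORCH_PRETRAINED_BIGGAN_CACHE)      # default download cache, ~/.pytorch_pretrained_biggan
model = BigGAN(BigGANConfig())              # randomly initialised 128x128 generator, no weights yet
print(sum(p.numel() for p in model.parameters()))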
networks/biggan/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (518 Bytes)
 
networks/biggan/__pycache__/config.cpython-38.pyc ADDED
Binary file (2.54 kB)
 
networks/biggan/__pycache__/file_utils.cpython-38.pyc ADDED
Binary file (6.22 kB)
 
networks/biggan/__pycache__/model.cpython-38.pyc ADDED
Binary file (10.9 kB)
 
networks/biggan/__pycache__/utils.cpython-38.pyc ADDED
Binary file (21.5 kB)
 
networks/biggan/config.py ADDED
@@ -0,0 +1,70 @@
# coding: utf-8
"""
BigGAN config.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import copy
import json

class BigGANConfig(object):
    """ Configuration class to store the configuration of a `BigGAN`.
        Defaults are for the 128x128 model.
        Each `layers` tuple is (up-sample in the layer?, input channels, output channels).
    """
    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=[(False, 16, 16),
                         (True, 16, 16),
                         (False, 16, 16),
                         (True, 16, 8),
                         (False, 8, 8),
                         (True, 8, 4),
                         (False, 4, 4),
                         (True, 4, 2),
                         (False, 2, 2),
                         (True, 2, 1)],
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig. """
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
        config = BigGANConfig()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BigGANConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
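The dictionary/JSON helpers above are what the conversion script uses to persist a configuration next to the converted weights. A minimal sketch of the round trip (values shown are just the defaults from this file):

import json

from networks.biggan import BigGANConfig

config = BigGANConfig()                        # defaults for the 128x128 generator
blob = config.to_json_string()                 # JSON text, as saved alongside converted weights
restored = BigGANConfig.from_dict(json.loads(blob))

print(restored.output_dim, restored.z_dim, restored.n_stats)     # 128 128 51
print(len(restored.layers), restored.attention_layer_position)   # 10 8

Note that the JSON round trip turns the layer tuples into lists; the generator construction code only indexes into them, so it handles both identically.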
networks/biggan/convert.sh ADDED
@@ -0,0 +1,21 @@
# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

set -e
set -x

models="128 256 512"

mkdir -p models/model_128
mkdir -p models/model_256
mkdir -p models/model_512

# Convert TF Hub models.
for model in $models
do
    pytorch_pretrained_biggan --model_type $model --tf_model_path models/model_$model --pt_save_path models/model_$model
done
networks/biggan/convert_tf_to_pytorch.py ADDED
@@ -0,0 +1,312 @@
# coding: utf-8
"""
Convert a TF Hub BigGAN model into a PyTorch one.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

from itertools import chain

import os
import argparse
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import normalize

from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME
from .config import BigGANConfig

logger = logging.getLogger(__name__)


def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None):
    try:
        import numpy as np
        import tensorflow as tf
        import tensorflow_hub as hub
    except ImportError:
        raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow and TF Hub to be installed. "
                          "Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow. "
                          "And see https://github.com/tensorflow/hub for installing Hub. "
                          "Probably pip install tensorflow tensorflow-hub")
    tf.reset_default_graph()
    logger.info('Loading BigGAN module from: {}'.format(tf_model_path))
    module = hub.Module(tf_model_path)
    inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
              for k, v in module.get_input_info_dict().items()}
    output = module(inputs)

    initializer = tf.global_variables_initializer()
    sess = tf.Session()
    stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ())
    numpy_stacks = []
    for i in stacks:
        logger.info("Retrieving module_apply_default/stack_{}".format(i))
        try:
            stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i)
        except KeyError:
            break  # We have all the stats
        numpy_stacks.append(sess.run(stack_var))

    if batch_norm_stats_path is not None:
        torch.save(numpy_stacks, batch_norm_stats_path)
    else:
        return numpy_stacks


def build_tf_to_pytorch_map(model, config):
    """ Build a map from TF variables to PyTorch modules. """
    tf_to_pt_map = {}

    # Embeddings and GenZ
    tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight,
                         'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias,
                         'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig,
                         'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u})

    # GBlock blocks
    model_layer_idx = 0
    for i, (up, in_channels, out_channels) in enumerate(config.layers):
        if i == config.attention_layer_position:
            model_layer_idx += 1
        layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/"
        layer_pnt = model.generator.layers[model_layer_idx]
        for j in range(4):  # Batchnorms
            batch_str = layer_str + ("BatchNorm_%d/" % j if j > 0 else "BatchNorm/")
            batch_pnt = getattr(layer_pnt, 'bn_%d' % j)
            for name in ('offset', 'scale'):
                sub_module_str = batch_str + name + "/"
                sub_module_pnt = getattr(batch_pnt, name)
                tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
                                     sub_module_str + "u0": sub_module_pnt.weight_u})
        for j in range(4):  # Convolutions
            conv_str = layer_str + "conv%d/" % j
            conv_pnt = getattr(layer_pnt, 'conv_%d' % j)
            tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias,
                                 conv_str + "w/ema_0.9999": conv_pnt.weight_orig,
                                 conv_str + "u0": conv_pnt.weight_u})
        model_layer_idx += 1

    # Attention block
    layer_str = "Generator/attention/"
    layer_pnt = model.generator.layers[config.attention_layer_position]
    tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma})
    for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'],
                                ['g/', 'o_conv/', 'phi/', 'theta/']):
        sub_module_str = layer_str + tf_name
        sub_module_pnt = getattr(layer_pnt, pt_name)
        tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
                             sub_module_str + "u0": sub_module_pnt.weight_u})

    # final batch norm and conv to rgb
    layer_str = "Generator/BatchNorm/"
    layer_pnt = model.generator.bn
    tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias,
                         layer_str + "scale/ema_0.9999": layer_pnt.weight})
    layer_str = "Generator/conv_to_rgb/"
    layer_pnt = model.generator.conv_to_rgb
    tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias,
                         layer_str + "w/ema_0.9999": layer_pnt.weight_orig,
                         layer_str + "u0": layer_pnt.weight_u})
    return tf_to_pt_map


def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None):
    """ Load tf checkpoints and standing statistics in a pytorch model
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
                          "https://www.tensorflow.org/install/ for installation instructions.")
    # Load weights from TF model
    checkpoint_path = tf_model_path + "/variables/variables"
    init_vars = tf.train.list_variables(checkpoint_path)
    from pprint import pprint
    pprint(init_vars)

    # Extract batch norm statistics from model if needed
    if batch_norm_stats_path:
        stats = torch.load(batch_norm_stats_path)
    else:
        logger.info("Extracting batch norm stats")
        stats = extract_batch_norm_stats(tf_model_path)

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    tf_weights = {}
    for name in tf_to_pt_map.keys():
        array = tf.train.load_variable(checkpoint_path, name)
        tf_weights[name] = array
        # logger.info("Loading TF weight {} with shape {}".format(name, array.shape))

    # Load parameters
    with torch.no_grad():
        pt_params_pnt = set()
        for name, pointer in tf_to_pt_map.items():
            array = tf_weights[name]
            if pointer.dim() == 1:
                if pointer.dim() < array.ndim:
                    array = np.squeeze(array)
            elif pointer.dim() == 2:  # Weights
                array = np.transpose(array)
            elif pointer.dim() == 4:  # Convolutions
                array = np.transpose(array, (3, 2, 0, 1))
            else:
                raise ValueError("Wrong dimensions to adjust: " + str((pointer.shape, array.shape)))
            if pointer.shape != array.shape:
                raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape)))
            logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape))
            pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array)
            tf_weights.pop(name, None)
            pt_params_pnt.add(pointer.data_ptr())

        # Prepare SpectralNorm buffers by running one step of Spectral Norm (no need to train the model):
        for module in model.modules():
            for n, buffer in module.named_buffers():
                if n == 'weight_v':
                    weight_mat = module.weight_orig
                    weight_mat = weight_mat.reshape(weight_mat.size(0), -1)
                    u = module.weight_u

                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps)
                    buffer.data = v
                    pt_params_pnt.add(buffer.data_ptr())

                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps)
                    module.weight_u.data = u
                    pt_params_pnt.add(module.weight_u.data_ptr())

        # Load batch norm statistics
        index = 0
        for layer in model.generator.layers:
            if not hasattr(layer, 'bn_0'):
                continue
            for i in range(4):  # Batchnorms
                bn_pointer = getattr(layer, 'bn_%d' % i)
                pointer = bn_pointer.running_means
                if pointer.shape != stats[index].shape:
                    raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
                pointer.data = torch.from_numpy(stats[index])
                pt_params_pnt.add(pointer.data_ptr())

                pointer = bn_pointer.running_vars
                if pointer.shape != stats[index+1].shape:
                    raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
                pointer.data = torch.from_numpy(stats[index+1])
                pt_params_pnt.add(pointer.data_ptr())

                index += 2

        bn_pointer = model.generator.bn
        pointer = bn_pointer.running_means
        if pointer.shape != stats[index].shape:
            raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
        pointer.data = torch.from_numpy(stats[index])
        pt_params_pnt.add(pointer.data_ptr())

        pointer = bn_pointer.running_vars
        if pointer.shape != stats[index+1].shape:
            raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
        pointer.data = torch.from_numpy(stats[index+1])
        pt_params_pnt.add(pointer.data_ptr())

    remaining_params = list(n for n, t in chain(model.named_parameters(), model.named_buffers()) \
                            if t.data_ptr() not in pt_params_pnt)

    logger.info("TF Weights not copied to PyTorch model: {} -".format(', '.join(tf_weights.keys())))
    logger.info("Remaining parameters/buffers from PyTorch model: {} -".format(', '.join(remaining_params)))

    return model


BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1),
                                 (False, 1, 1),
                                 (True, 1, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)


def main():
    parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model into a PyTorch model")
    parser.add_argument("--model_type", type=str, default="", required=True,
                        help="BigGAN model type (128, 256, 512)")
    parser.add_argument("--tf_model_path", type=str, default="", required=True,
                        help="Path of the downloaded TF Hub model")
    parser.add_argument("--pt_save_path", type=str, default="",
                        help="Folder to save the PyTorch model (default: Folder of the TF Hub model)")
    parser.add_argument("--batch_norm_stats_path", type=str, default="",
                        help="Path of previously extracted batch norm statistics")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    if not args.pt_save_path:
        args.pt_save_path = args.tf_model_path

    if args.model_type == "128":
        config = BigGAN128
    elif args.model_type == "256":
        config = BigGAN256
    elif args.model_type == "512":
        config = BigGAN512
    else:
        raise ValueError("model_type should be one of 128, 256 or 512")

    model = BigGAN(config)
    model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path)

    model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME)
    config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME)

    logger.info("Save model dump to {}".format(model_save_path))
    torch.save(model.state_dict(), model_save_path)
    logger.info("Save configuration file to {}".format(config_save_path))
    with open(config_save_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())

if __name__ == "__main__":
    main()
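The same conversion that convert.sh drives through the CLI can be done from Python. A minimal sketch, assuming a TensorFlow 1.x environment with tensorflow_hub installed and a BigGAN-deep 128 TF Hub module already unpacked under models/model_128/ (paths are illustrative, matching download_tf.sh below):

import torch

from networks.biggan.model import BigGAN
from networks.biggan.convert_tf_to_pytorch import BigGAN128, load_tf_weights_in_biggan

tf_model_path = 'models/model_128'      # illustrative path to the unpacked TF Hub module
model = BigGAN(BigGAN128)               # empty PyTorch generator with the 128x128 layout
model = load_tf_weights_in_biggan(model, BigGAN128, tf_model_path)

# Same outputs main() would write next to the TF Hub module.
torch.save(model.state_dict(), 'models/model_128/pytorch_model.bin')
with open('models/model_128/config.json', 'w') as f:
    f.write(BigGAN128.to_json_string())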
networks/biggan/download_tf.sh ADDED
@@ -0,0 +1,21 @@
# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

set -e
set -x

models="128 256 512"

mkdir -p models/model_128
mkdir -p models/model_256
mkdir -p models/model_512

# Download TF Hub models.
for model in $models
do
    curl -L "https://tfhub.dev/deepmind/biggan-deep-$model/1?tf-hub-format=compressed" | tar -zxvC models/model_$model
done
networks/biggan/file_utils.py ADDED
@@ -0,0 +1,232 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open

import requests
from tqdm import tqdm

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                     Path.home() / '.pytorch_pretrained_biggan'))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    Note: `ClientError` comes from botocore; the s3 code path is disabled in this demo.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        raise NotImplementedError('s3 paths are not supported due to the colab demo. Sorry!')
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                raise NotImplementedError('s3 paths are not supported due to the colab demo. Sorry!')
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
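cached_path is what BigGAN.from_pretrained (in model.py below) uses to turn either a local path or a downloadable URL into a file on disk. A small sketch of calling it directly, using one of the config URLs listed in model.py (downloads land in PYTORCH_PRETRAINED_BIGGAN_CACHE unless cache_dir is given, and the URL must still be reachable):

from networks.biggan.file_utils import cached_path, url_to_filename

config_url = "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json"

# First call downloads and caches; later calls return the same local path.
local_path = cached_path(config_url)
print(local_path)

# The cache filename is just a hash of the URL (plus a hash of the ETag when available).
print(url_to_filename(config_url))

# Passing an existing local path returns it unchanged.
print(cached_path(local_path))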
networks/biggan/model.py ADDED
@@ -0,0 +1,352 @@
# coding: utf-8
""" BigGAN PyTorch model.
    From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
    By Andrew Brock, Jeff Donahue and Karen Simonyan.
    https://openreview.net/forum?id=B1xsqj09Fm

    PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
    Some parts of the code are adapted from https://github.com/brain-research/self-attention-gan

    This version only comprises the generator (since the discriminator's weights are not released).
    This version only comprises the "deep" version of BigGAN (see publication).
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import os
import logging
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import BigGANConfig
from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'


def snconv2d(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

def snlinear(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

def sn_embedding(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)

class SelfAttn(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch//8, h*w)
        # Phi path
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch//8, h*w//4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch//2, h*w//4)
        # Attn_g - o_conv
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)
        # Out
        out = x + self.gamma*attn_g
        return out


class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.

        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.
    """
    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            self.weight = torch.nn.Parameter(torch.Tensor(num_features))
            self.bias = torch.nn.Parameter(torch.Tensor(num_features))

    def forward(self, x, truncation, condition_vector=None):
        # Retrieve pre-computed statistics associated to this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        if coef != 0.0:  # Interpolate
            running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
            running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out


class GenBlock(nn.Module):
    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        x0 = x

        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)

        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)

        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)

        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)

        if self.drop_channels:
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')

        out = x + x0
        return out

class Generator(nn.Module):
    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch*layer[1], eps=config.eps))
            layers.append(GenBlock(ch*layer[1],
                                   ch*layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation, z=None, start=0, stop=None):
        # We use this conversion step to be able to use TF weights:
        # TF convention on shape is [batch, height, width, channels]
        # PT convention on shape is [batch, channels, height, width]
        if start == 0 and z is None:
            z = self.gen_z(cond_vector)
            z = z.view(-1, 4, 4, 16 * self.config.channel_width)
            z = z.permute(0, 3, 1, 2).contiguous()

        if stop is None:
            stop = len(self.layers)

        # Run only the requested slice of layers, so features can be read out or fed in mid-network.
        for i in range(start, stop):
            if isinstance(self.layers[i], GenBlock):
                z = self.layers[i](z, cond_vector, truncation)
            else:
                z = self.layers[i](z)

        if stop == len(self.layers):
            z = self.bn(z, truncation)
            z = self.relu(z)
            z = self.conv_to_rgb(z)
            z = z[:, :3, ...]
            z = self.tanh(z)

        return z

class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a {} file and a {} file or a model name in {}".format(
                         WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super(BigGAN, self).__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)
        # self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # print(f'device: {self.device}')
        # self.generator.to(self.device)

    def forward(self, z, class_label, truncation, cond_vector=None, start=0, stop=None):
        assert 0 < truncation <= 1

        results = {}
        if start == 0 and cond_vector is None:
            embed = self.embeddings(class_label)
            cond_vector = torch.cat((z, embed), dim=1)
            results['cond_vector'] = cond_vector

        results['z'] = self.generator(cond_vector, truncation, z=None if start == 0 else z, start=start, stop=stop)
        return results


if __name__ == "__main__":
    import PIL
    from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
    from .convert_tf_to_pytorch import load_tf_weights_in_biggan

    load_cache = False
    cache_path = './saved_model.pt'
    config = BigGANConfig()
    model = BigGAN(config)
    if not load_cache:
        model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin')
        torch.save(model.state_dict(), cache_path)
    else:
        model.load_state_dict(torch.load(cache_path))

    model.eval()

    truncation = 0.4
    noise = truncated_noise_sample(batch_size=2, truncation=truncation)
    label = one_hot_from_names('diver', batch_size=2)

    # Tests
    # noise = np.zeros((1, 128))
    # label = [983]

    noise = torch.tensor(noise, dtype=torch.float)
    label = torch.tensor(label, dtype=torch.float)
    with torch.no_grad():
        outputs = model(noise, label, truncation)
    print(outputs['z'].shape)        # forward() returns a dict; 'z' holds the generated image tensor

    save_as_images(outputs['z'])
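For reference, a minimal sketch of generating images with the hosted weights via from_pretrained, mirroring the __main__ block above. It assumes the repository root is importable and that the s3 URLs in the archive maps are still reachable; note that in this fork forward() returns a dict whose 'z' entry holds the image tensor:

import torch

from networks.biggan import (BigGAN, truncated_noise_sample,
                             one_hot_from_int, save_as_images)

model = BigGAN.from_pretrained('biggan-deep-128')   # resolves weights + config through cached_path
model.eval()

truncation = 0.4
noise = torch.tensor(truncated_noise_sample(batch_size=2, truncation=truncation), dtype=torch.float)
label = torch.tensor(one_hot_from_int([207, 285], batch_size=2), dtype=torch.float)  # two ImageNet indices

with torch.no_grad():
    outputs = model(noise, label, truncation)

print(outputs['z'].shape)        # (2, 3, 128, 128), values in [-1, 1]
save_as_images(outputs['z'])     # writes output_0.png, output_1.png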
networks/biggan/utils.py ADDED
@@ -0,0 +1,216 @@
# coding: utf-8
""" BigGAN utilities to prepare truncated noise samples and convert/save/display output images.
    It also comprises ImageNet utilities to prepare one-hot input vectors for ImageNet classes.
    We use WordNet so you can just input a class name as a string and automatically get a corresponding
    ImageNet class if it exists (or a hypo/hypernym exists in ImageNet).
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
from io import BytesIO

import numpy as np
from scipy.stats import truncnorm

logger = logging.getLogger(__name__)

NUM_CLASSES = 1000


def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):
    """ Create a truncated noise vector.
        Params:
            batch_size: batch size.
            dim_z: dimension of z
            truncation: truncation value to use
            seed: seed for the random generator
        Output:
            array of shape (batch_size, dim_z)
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
    return truncation * values


def convert_to_images(obj):
    """ Convert an output tensor from BigGAN into a list of images.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
        Output:
            list of Pillow Images of size (height, width)
    """
    try:
        import PIL
    except ImportError:
        raise ImportError("Please install Pillow to use images: pip install Pillow")

    if not isinstance(obj, np.ndarray):
        obj = obj.detach().numpy()

    obj = obj.transpose((0, 2, 3, 1))
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)

    img = []
    for i, out in enumerate(obj):
        out_array = np.asarray(np.uint8(out), dtype=np.uint8)
        img.append(PIL.Image.fromarray(out_array))
    return img


def save_as_images(obj, file_name='output'):
    """ Convert and save an output tensor from BigGAN as a list of saved images.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
            file_name: path and beginning of the filename to save.
                Images will be saved as `file_name_{image_number}.png`
    """
    img = convert_to_images(obj)

    for i, out in enumerate(img):
        current_file_name = file_name + '_%d.png' % i
        logger.info("Saving image to {}".format(current_file_name))
        out.save(current_file_name, 'png')


def display_in_terminal(obj):
    """ Convert and display an output tensor from BigGAN in the terminal.
        This function uses `libsixel` and will only work in a libsixel-compatible terminal.
        Please refer to https://github.com/saitoha/libsixel for more details.

        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
    """
    try:
        import PIL
        from libsixel import (sixel_output_new, sixel_dither_new, sixel_dither_initialize,
                              sixel_dither_set_palette, sixel_dither_set_pixelformat,
                              sixel_dither_get, sixel_encode, sixel_dither_unref,
                              sixel_output_unref, SIXEL_PIXELFORMAT_RGBA8888,
                              SIXEL_PIXELFORMAT_RGB888, SIXEL_PIXELFORMAT_PAL8,
                              SIXEL_PIXELFORMAT_G8, SIXEL_PIXELFORMAT_G1)
    except ImportError:
        raise ImportError("Display in Terminal requires Pillow, libsixel "
                          "and a libsixel compatible terminal. "
                          "Please read info at https://github.com/saitoha/libsixel "
                          "and install with pip install Pillow libsixel-python")

    s = BytesIO()

    images = convert_to_images(obj)
    widths, heights = zip(*(i.size for i in images))

    output_width = sum(widths)
    output_height = max(heights)

    output_image = PIL.Image.new('RGB', (output_width, output_height))

    x_offset = 0
    for im in images:
        output_image.paste(im, (x_offset, 0))
        x_offset += im.size[0]

    try:
        data = output_image.tobytes()
    except NotImplementedError:
        data = output_image.tostring()
    output = sixel_output_new(lambda data, s: s.write(data), s)

    try:
        if output_image.mode == 'RGBA':
            dither = sixel_dither_new(256)
            sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGBA8888)
        elif output_image.mode == 'RGB':
            dither = sixel_dither_new(256)
            sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGB888)
        elif output_image.mode == 'P':
            palette = output_image.getpalette()
            dither = sixel_dither_new(256)
            sixel_dither_set_palette(dither, palette)
            sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_PAL8)
        elif output_image.mode == 'L':
            dither = sixel_dither_get(SIXEL_BUILTIN_G8)  # note: SIXEL_BUILTIN_G8 would also need importing above
            sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G8)
        elif output_image.mode == '1':
            dither = sixel_dither_get(SIXEL_BUILTIN_G1)  # note: SIXEL_BUILTIN_G1 would also need importing above
            sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G1)
        else:
            raise RuntimeError('unexpected output_image mode')
        try:
            sixel_encode(data, output_width, output_height, 1, dither, output)
            print(s.getvalue().decode('ascii'))
        finally:
            sixel_dither_unref(dither)
    finally:
        sixel_output_unref(output)


def one_hot_from_int(int_or_list, batch_size=1):
    """ Create a one-hot vector from a class index or a list of class indices.
        Params:
            int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
            batch_size: batch size.
                If int_or_list is an int create a batch of identical classes.
                If int_or_list is a list, we should have `len(int_or_list) == batch_size`
        Output:
            array of shape (batch_size, 1000)
    """
    if isinstance(int_or_list, int):
        int_or_list = [int_or_list]

    if len(int_or_list) == 1 and batch_size > 1:
        int_or_list = [int_or_list[0]] * batch_size

    assert batch_size == len(int_or_list)

    array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)
    for i, j in enumerate(int_or_list):
        array[i, j] = 1.0
    return array


def one_hot_from_names(class_name_or_list, batch_size=1):
    """ Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...).
        We use NLTK's wordnet search to try to find the relevant synset of ImageNet and take the first one.
        If we can't find it directly, we look at the hyponyms and hypernyms of the class name.

        Params:
            class_name_or_list: string containing the name of an imagenet object or a list of such strings (for a batch).
        Output:
            array of shape (batch_size, 1000)
    """
    try:
        from nltk.corpus import wordnet as wn
    except ImportError:
        raise ImportError("You need to install nltk to use this function")

    if not isinstance(class_name_or_list, (list, tuple)):
        class_name_or_list = [class_name_or_list]
    else:
        batch_size = max(batch_size, len(class_name_or_list))

    classes = []
    for class_name in class_name_or_list:
        class_name = class_name.replace(" ", "_")

        original_synsets = wn.synsets(class_name)
        original_synsets = list(filter(lambda s: s.pos() == 'n', original_synsets))  # keep only nouns
        if not original_synsets:
            return None

        possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, original_synsets))
        if possible_synsets:
            classes.append(IMAGENET[possible_synsets[0].offset()])
        else:
            # try hypernyms and hyponyms
            possible_synsets = sum([s.hypernyms() + s.hyponyms() for s in original_synsets], [])
            possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, possible_synsets))
            if possible_synsets:
                classes.append(IMAGENET[possible_synsets[0].offset()])

    return one_hot_from_int(classes, batch_size=batch_size)

+ IMAGENET = {1440764: 0, 1443537: 1, 1484850: 2, 1491361: 3, 1494475: 4, 1496331: 5, 1498041: 6, 1514668: 7, 1514859: 8, 1518878: 9, 1530575: 10, 1531178: 11, 1532829: 12, 1534433: 13, 1537544: 14, 1558993: 15, 1560419: 16, 1580077: 17, 1582220: 18, 1592084: 19, 1601694: 20, 1608432: 21, 1614925: 22, 1616318: 23, 1622779: 24, 1629819: 25, 1630670: 26, 1631663: 27, 1632458: 28, 1632777: 29, 1641577: 30, 1644373: 31, 1644900: 32, 1664065: 33, 1665541: 34, 1667114: 35, 1667778: 36, 1669191: 37, 1675722: 38, 1677366: 39, 1682714: 40, 1685808: 41, 1687978: 42, 1688243: 43, 1689811: 44, 1692333: 45, 1693334: 46, 1694178: 47, 1695060: 48, 1697457: 49, 1698640: 50, 1704323: 51, 1728572: 52, 1728920: 53, 1729322: 54, 1729977: 55, 1734418: 56, 1735189: 57, 1737021: 58, 1739381: 59, 1740131: 60, 1742172: 61, 1744401: 62, 1748264: 63, 1749939: 64, 1751748: 65, 1753488: 66, 1755581: 67, 1756291: 68, 1768244: 69, 1770081: 70, 1770393: 71, 1773157: 72, 1773549: 73, 1773797: 74, 1774384: 75, 1774750: 76, 1775062: 77, 1776313: 78, 1784675: 79, 1795545: 80, 1796340: 81, 1797886: 82, 1798484: 83, 1806143: 84, 1806567: 85, 1807496: 86, 1817953: 87, 1818515: 88, 1819313: 89, 1820546: 90, 1824575: 91, 1828970: 92, 1829413: 93, 1833805: 94, 1843065: 95, 1843383: 96, 1847000: 97, 1855032: 98, 1855672: 99, 1860187: 100, 1871265: 101, 1872401: 102, 1873310: 103, 1877812: 104, 1882714: 105, 1883070: 106, 1910747: 107, 1914609: 108, 1917289: 109, 1924916: 110, 1930112: 111, 1943899: 112, 1944390: 113, 1945685: 114, 1950731: 115, 1955084: 116, 1968897: 117, 1978287: 118, 1978455: 119, 1980166: 120, 1981276: 121, 1983481: 122, 1984695: 123, 1985128: 124, 1986214: 125, 1990800: 126, 2002556: 127, 2002724: 128, 2006656: 129, 2007558: 130, 2009229: 131, 2009912: 132, 2011460: 133, 2012849: 134, 2013706: 135, 2017213: 136, 2018207: 137, 2018795: 138, 2025239: 139, 2027492: 140, 2028035: 141, 2033041: 142, 2037110: 143, 2051845: 144, 2056570: 145, 2058221: 146, 2066245: 147, 2071294: 148, 2074367: 149, 2077923: 150, 2085620: 151, 2085782: 152, 2085936: 153, 2086079: 154, 2086240: 155, 2086646: 156, 2086910: 157, 2087046: 158, 2087394: 159, 2088094: 160, 2088238: 161, 2088364: 162, 2088466: 163, 2088632: 164, 2089078: 165, 2089867: 166, 2089973: 167, 2090379: 168, 2090622: 169, 2090721: 170, 2091032: 171, 2091134: 172, 2091244: 173, 2091467: 174, 2091635: 175, 2091831: 176, 2092002: 177, 2092339: 178, 2093256: 179, 2093428: 180, 2093647: 181, 2093754: 182, 2093859: 183, 2093991: 184, 2094114: 185, 2094258: 186, 2094433: 187, 2095314: 188, 2095570: 189, 2095889: 190, 2096051: 191, 2096177: 192, 2096294: 193, 2096437: 194, 2096585: 195, 2097047: 196, 2097130: 197, 2097209: 198, 2097298: 199, 2097474: 200, 2097658: 201, 2098105: 202, 2098286: 203, 2098413: 204, 2099267: 205, 2099429: 206, 2099601: 207, 2099712: 208, 2099849: 209, 2100236: 210, 2100583: 211, 2100735: 212, 2100877: 213, 2101006: 214, 2101388: 215, 2101556: 216, 2102040: 217, 2102177: 218, 2102318: 219, 2102480: 220, 2102973: 221, 2104029: 222, 2104365: 223, 2105056: 224, 2105162: 225, 2105251: 226, 2105412: 227, 2105505: 228, 2105641: 229, 2105855: 230, 2106030: 231, 2106166: 232, 2106382: 233, 2106550: 234, 2106662: 235, 2107142: 236, 2107312: 237, 2107574: 238, 2107683: 239, 2107908: 240, 2108000: 241, 2108089: 242, 2108422: 243, 2108551: 244, 2108915: 245, 2109047: 246, 2109525: 247, 2109961: 248, 2110063: 249, 2110185: 250, 2110341: 251, 2110627: 252, 2110806: 253, 2110958: 254, 2111129: 255, 2111277: 256, 2111500: 257, 2111889: 258, 2112018: 259, 2112137: 
260, 2112350: 261, 2112706: 262, 2113023: 263, 2113186: 264, 2113624: 265, 2113712: 266, 2113799: 267, 2113978: 268, 2114367: 269, 2114548: 270, 2114712: 271, 2114855: 272, 2115641: 273, 2115913: 274, 2116738: 275, 2117135: 276, 2119022: 277, 2119789: 278, 2120079: 279, 2120505: 280, 2123045: 281, 2123159: 282, 2123394: 283, 2123597: 284, 2124075: 285, 2125311: 286, 2127052: 287, 2128385: 288, 2128757: 289, 2128925: 290, 2129165: 291, 2129604: 292, 2130308: 293, 2132136: 294, 2133161: 295, 2134084: 296, 2134418: 297, 2137549: 298, 2138441: 299, 2165105: 300, 2165456: 301, 2167151: 302, 2168699: 303, 2169497: 304, 2172182: 305, 2174001: 306, 2177972: 307, 2190166: 308, 2206856: 309, 2219486: 310, 2226429: 311, 2229544: 312, 2231487: 313, 2233338: 314, 2236044: 315, 2256656: 316, 2259212: 317, 2264363: 318, 2268443: 319, 2268853: 320, 2276258: 321, 2277742: 322, 2279972: 323, 2280649: 324, 2281406: 325, 2281787: 326, 2317335: 327, 2319095: 328, 2321529: 329, 2325366: 330, 2326432: 331, 2328150: 332, 2342885: 333, 2346627: 334, 2356798: 335, 2361337: 336, 2363005: 337, 2364673: 338, 2389026: 339, 2391049: 340, 2395406: 341, 2396427: 342, 2397096: 343, 2398521: 344, 2403003: 345, 2408429: 346, 2410509: 347, 2412080: 348, 2415577: 349, 2417914: 350, 2422106: 351, 2422699: 352, 2423022: 353, 2437312: 354, 2437616: 355, 2441942: 356, 2442845: 357, 2443114: 358, 2443484: 359, 2444819: 360, 2445715: 361, 2447366: 362, 2454379: 363, 2457408: 364, 2480495: 365, 2480855: 366, 2481823: 367, 2483362: 368, 2483708: 369, 2484975: 370, 2486261: 371, 2486410: 372, 2487347: 373, 2488291: 374, 2488702: 375, 2489166: 376, 2490219: 377, 2492035: 378, 2492660: 379, 2493509: 380, 2493793: 381, 2494079: 382, 2497673: 383, 2500267: 384, 2504013: 385, 2504458: 386, 2509815: 387, 2510455: 388, 2514041: 389, 2526121: 390, 2536864: 391, 2606052: 392, 2607072: 393, 2640242: 394, 2641379: 395, 2643566: 396, 2655020: 397, 2666196: 398, 2667093: 399, 2669723: 400, 2672831: 401, 2676566: 402, 2687172: 403, 2690373: 404, 2692877: 405, 2699494: 406, 2701002: 407, 2704792: 408, 2708093: 409, 2727426: 410, 2730930: 411, 2747177: 412, 2749479: 413, 2769748: 414, 2776631: 415, 2777292: 416, 2782093: 417, 2783161: 418, 2786058: 419, 2787622: 420, 2788148: 421, 2790996: 422, 2791124: 423, 2791270: 424, 2793495: 425, 2794156: 426, 2795169: 427, 2797295: 428, 2799071: 429, 2802426: 430, 2804414: 431, 2804610: 432, 2807133: 433, 2808304: 434, 2808440: 435, 2814533: 436, 2814860: 437, 2815834: 438, 2817516: 439, 2823428: 440, 2823750: 441, 2825657: 442, 2834397: 443, 2835271: 444, 2837789: 445, 2840245: 446, 2841315: 447, 2843684: 448, 2859443: 449, 2860847: 450, 2865351: 451, 2869837: 452, 2870880: 453, 2871525: 454, 2877765: 455, 2879718: 456, 2883205: 457, 2892201: 458, 2892767: 459, 2894605: 460, 2895154: 461, 2906734: 462, 2909870: 463, 2910353: 464, 2916936: 465, 2917067: 466, 2927161: 467, 2930766: 468, 2939185: 469, 2948072: 470, 2950826: 471, 2951358: 472, 2951585: 473, 2963159: 474, 2965783: 475, 2966193: 476, 2966687: 477, 2971356: 478, 2974003: 479, 2977058: 480, 2978881: 481, 2979186: 482, 2980441: 483, 2981792: 484, 2988304: 485, 2992211: 486, 2992529: 487, 2999410: 488, 3000134: 489, 3000247: 490, 3000684: 491, 3014705: 492, 3016953: 493, 3017168: 494, 3018349: 495, 3026506: 496, 3028079: 497, 3032252: 498, 3041632: 499, 3042490: 500, 3045698: 501, 3047690: 502, 3062245: 503, 3063599: 504, 3063689: 505, 3065424: 506, 3075370: 507, 3085013: 508, 3089624: 509, 3095699: 510, 3100240: 511, 3109150: 512, 3110669: 513, 
3124043: 514, 3124170: 515, 3125729: 516, 3126707: 517, 3127747: 518, 3127925: 519, 3131574: 520, 3133878: 521, 3134739: 522, 3141823: 523, 3146219: 524, 3160309: 525, 3179701: 526, 3180011: 527, 3187595: 528, 3188531: 529, 3196217: 530, 3197337: 531, 3201208: 532, 3207743: 533, 3207941: 534, 3208938: 535, 3216828: 536, 3218198: 537, 3220513: 538, 3223299: 539, 3240683: 540, 3249569: 541, 3250847: 542, 3255030: 543, 3259280: 544, 3271574: 545, 3272010: 546, 3272562: 547, 3290653: 548, 3291819: 549, 3297495: 550, 3314780: 551, 3325584: 552, 3337140: 553, 3344393: 554, 3345487: 555, 3347037: 556, 3355925: 557, 3372029: 558, 3376595: 559, 3379051: 560, 3384352: 561, 3388043: 562, 3388183: 563, 3388549: 564, 3393912: 565, 3394916: 566, 3400231: 567, 3404251: 568, 3417042: 569, 3424325: 570, 3425413: 571, 3443371: 572, 3444034: 573, 3445777: 574, 3445924: 575, 3447447: 576, 3447721: 577, 3450230: 578, 3452741: 579, 3457902: 580, 3459775: 581, 3461385: 582, 3467068: 583, 3476684: 584, 3476991: 585, 3478589: 586, 3481172: 587, 3482405: 588, 3483316: 589, 3485407: 590, 3485794: 591, 3492542: 592, 3494278: 593, 3495258: 594, 3496892: 595, 3498962: 596, 3527444: 597, 3529860: 598, 3530642: 599, 3532672: 600, 3534580: 601, 3535780: 602, 3538406: 603, 3544143: 604, 3584254: 605, 3584829: 606, 3590841: 607, 3594734: 608, 3594945: 609, 3595614: 610, 3598930: 611, 3599486: 612, 3602883: 613, 3617480: 614, 3623198: 615, 3627232: 616, 3630383: 617, 3633091: 618, 3637318: 619, 3642806: 620, 3649909: 621, 3657121: 622, 3658185: 623, 3661043: 624, 3662601: 625, 3666591: 626, 3670208: 627, 3673027: 628, 3676483: 629, 3680355: 630, 3690938: 631, 3691459: 632, 3692522: 633, 3697007: 634, 3706229: 635, 3709823: 636, 3710193: 637, 3710637: 638, 3710721: 639, 3717622: 640, 3720891: 641, 3721384: 642, 3724870: 643, 3729826: 644, 3733131: 645, 3733281: 646, 3733805: 647, 3742115: 648, 3743016: 649, 3759954: 650, 3761084: 651, 3763968: 652, 3764736: 653, 3769881: 654, 3770439: 655, 3770679: 656, 3773504: 657, 3775071: 658, 3775546: 659, 3776460: 660, 3777568: 661, 3777754: 662, 3781244: 663, 3782006: 664, 3785016: 665, 3786901: 666, 3787032: 667, 3788195: 668, 3788365: 669, 3791053: 670, 3792782: 671, 3792972: 672, 3793489: 673, 3794056: 674, 3796401: 675, 3803284: 676, 3804744: 677, 3814639: 678, 3814906: 679, 3825788: 680, 3832673: 681, 3837869: 682, 3838899: 683, 3840681: 684, 3841143: 685, 3843555: 686, 3854065: 687, 3857828: 688, 3866082: 689, 3868242: 690, 3868863: 691, 3871628: 692, 3873416: 693, 3874293: 694, 3874599: 695, 3876231: 696, 3877472: 697, 3877845: 698, 3884397: 699, 3887697: 700, 3888257: 701, 3888605: 702, 3891251: 703, 3891332: 704, 3895866: 705, 3899768: 706, 3902125: 707, 3903868: 708, 3908618: 709, 3908714: 710, 3916031: 711, 3920288: 712, 3924679: 713, 3929660: 714, 3929855: 715, 3930313: 716, 3930630: 717, 3933933: 718, 3935335: 719, 3937543: 720, 3938244: 721, 3942813: 722, 3944341: 723, 3947888: 724, 3950228: 725, 3954731: 726, 3956157: 727, 3958227: 728, 3961711: 729, 3967562: 730, 3970156: 731, 3976467: 732, 3976657: 733, 3977966: 734, 3980874: 735, 3982430: 736, 3983396: 737, 3991062: 738, 3992509: 739, 3995372: 740, 3998194: 741, 4004767: 742, 4005630: 743, 4008634: 744, 4009552: 745, 4019541: 746, 4023962: 747, 4026417: 748, 4033901: 749, 4033995: 750, 4037443: 751, 4039381: 752, 4040759: 753, 4041544: 754, 4044716: 755, 4049303: 756, 4065272: 757, 4067472: 758, 4069434: 759, 4070727: 760, 4074963: 761, 4081281: 762, 4086273: 763, 4090263: 764, 4099969: 765, 4111531: 766, 4116512: 
767, 4118538: 768, 4118776: 769, 4120489: 770, 4125021: 771, 4127249: 772, 4131690: 773, 4133789: 774, 4136333: 775, 4141076: 776, 4141327: 777, 4141975: 778, 4146614: 779, 4147183: 780, 4149813: 781, 4152593: 782, 4153751: 783, 4154565: 784, 4162706: 785, 4179913: 786, 4192698: 787, 4200800: 788, 4201297: 789, 4204238: 790, 4204347: 791, 4208210: 792, 4209133: 793, 4209239: 794, 4228054: 795, 4229816: 796, 4235860: 797, 4238763: 798, 4239074: 799, 4243546: 800, 4251144: 801, 4252077: 802, 4252225: 803, 4254120: 804, 4254680: 805, 4254777: 806, 4258138: 807, 4259630: 808, 4263257: 809, 4264628: 810, 4265275: 811, 4266014: 812, 4270147: 813, 4273569: 814, 4275548: 815, 4277352: 816, 4285008: 817, 4286575: 818, 4296562: 819, 4310018: 820, 4311004: 821, 4311174: 822, 4317175: 823, 4325704: 824, 4326547: 825, 4328186: 826, 4330267: 827, 4332243: 828, 4335435: 829, 4336792: 830, 4344873: 831, 4346328: 832, 4347754: 833, 4350905: 834, 4355338: 835, 4355933: 836, 4356056: 837, 4357314: 838, 4366367: 839, 4367480: 840, 4370456: 841, 4371430: 842, 4371774: 843, 4372370: 844, 4376876: 845, 4380533: 846, 4389033: 847, 4392985: 848, 4398044: 849, 4399382: 850, 4404412: 851, 4409515: 852, 4417672: 853, 4418357: 854, 4423845: 855, 4428191: 856, 4429376: 857, 4435653: 858, 4442312: 859, 4443257: 860, 4447861: 861, 4456115: 862, 4458633: 863, 4461696: 864, 4462240: 865, 4465501: 866, 4467665: 867, 4476259: 868, 4479046: 869, 4482393: 870, 4483307: 871, 4485082: 872, 4486054: 873, 4487081: 874, 4487394: 875, 4493381: 876, 4501370: 877, 4505470: 878, 4507155: 879, 4509417: 880, 4515003: 881, 4517823: 882, 4522168: 883, 4523525: 884, 4525038: 885, 4525305: 886, 4532106: 887, 4532670: 888, 4536866: 889, 4540053: 890, 4542943: 891, 4548280: 892, 4548362: 893, 4550184: 894, 4552348: 895, 4553703: 896, 4554684: 897, 4557648: 898, 4560804: 899, 4562935: 900, 4579145: 901, 4579432: 902, 4584207: 903, 4589890: 904, 4590129: 905, 4591157: 906, 4591713: 907, 4592741: 908, 4596742: 909, 4597913: 910, 4599235: 911, 4604644: 912, 4606251: 913, 4612504: 914, 4613696: 915, 6359193: 916, 6596364: 917, 6785654: 918, 6794110: 919, 6874185: 920, 7248320: 921, 7565083: 922, 7579787: 923, 7583066: 924, 7584110: 925, 7590611: 926, 7613480: 927, 7614500: 928, 7615774: 929, 7684084: 930, 7693725: 931, 7695742: 932, 7697313: 933, 7697537: 934, 7711569: 935, 7714571: 936, 7714990: 937, 7715103: 938, 7716358: 939, 7716906: 940, 7717410: 941, 7717556: 942, 7718472: 943, 7718747: 944, 7720875: 945, 7730033: 946, 7734744: 947, 7742313: 948, 7745940: 949, 7747607: 950, 7749582: 951, 7753113: 952, 7753275: 953, 7753592: 954, 7754684: 955, 7760859: 956, 7768694: 957, 7802026: 958, 7831146: 959, 7836838: 960, 7860988: 961, 7871810: 962, 7873807: 963, 7875152: 964, 7880968: 965, 7892512: 966, 7920052: 967, 7930864: 968, 7932039: 969, 9193705: 970, 9229709: 971, 9246464: 972, 9256479: 973, 9288635: 974, 9332890: 975, 9399592: 976, 9421951: 977, 9428293: 978, 9468604: 979, 9472597: 980, 9835506: 981, 10148035: 982, 10565667: 983, 11879895: 984, 11939491: 985, 12057211: 986, 12144580: 987, 12267677: 988, 12620546: 989, 12768682: 990, 12985857: 991, 12998815: 992, 13037406: 993, 13040303: 994, 13044778: 995, 13052670: 996, 13054560: 997, 13133613: 998, 15075141: 999}
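The table above maps the integer part of each ImageNet WordNet synset ID (e.g. `n01440764` → `1440764`) to its class index in [0, 999]. Below is a minimal sketch of how such a table can be used to build a class-conditional input; the helper names `wnid_to_class_index` and `one_hot_from_wnid` are ours for illustration and are not part of the upstream utilities.

```python
import numpy as np

def wnid_to_class_index(wnid, mapping=IMAGENET):
    """Converts a WordNet synset ID such as 'n01440764' to an ImageNet class index."""
    # Drop the leading 'n' and parse the remaining digits as the integer key.
    return mapping[int(wnid.lstrip('n'))]

def one_hot_from_wnid(wnid, num_classes=1000):
    """Builds a one-hot class vector, e.g. for conditioning a class-conditional GAN."""
    one_hot = np.zeros((1, num_classes), dtype=np.float32)
    one_hot[0, wnid_to_class_index(wnid)] = 1.0
    return one_hot

# 'n01440764' (tench) maps to class index 0 in the table above.
assert wnid_to_class_index('n01440764') == 0
```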
networks/genforce/.gitignore ADDED
@@ -0,0 +1,29 @@
1
+ __pycache__/
2
+ *.py[cod]
3
+
4
+ /.vscode/
5
+ /.idea/
6
+ *.sw[pon]
7
+
8
+ /data/
9
+ /work_dirs/
10
+ *.jpg
11
+ *.png
12
+ *.jpeg
13
+ *.gif
14
+ *.avi
15
+ *.mp4
16
+
17
+ *.npy
18
+ *.txt
19
+ *.json
20
+ *.log
21
+ *.html
22
+ *.tar
23
+ *.zip
24
+ events.*
25
+
26
+ *.pth
27
+ *.pkl
28
+ *.h5
29
+ *.dat
networks/genforce/LICENSE ADDED
@@ -0,0 +1,18 @@
1
+ Copyright (c) 2020 GenForce
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
4
+ this software and associated documentation files (the "Software"), to deal in
5
+ the Software without restriction, including without limitation the rights to
6
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
7
+ of the Software, and to permit persons to whom the Software is furnished to do
8
+ so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in all
11
+ copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
networks/genforce/MODEL_ZOO.md ADDED
@@ -0,0 +1,131 @@
1
+ # Model Zoo
2
+
3
+ ## Pre-trained Models
4
+
5
+ First of all, we thank the following repositories for their work on high-quality image synthesis:
6
+
7
+ - [PGGAN](https://github.com/tkarras/progressive_growing_of_gans)
8
+ - [StyleGAN](https://github.com/NVlabs/stylegan)
9
+ - [StyleGAN2](https://github.com/NVlabs/stylegan2)
10
+
11
+ Please download the models you need and save them to `checkpoints/`.
12
+
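 As a quick sketch, a checkpoint can also be fetched into `checkpoints/` directly from Python; substitute the download link of the model you need from the tables below, and name the file as the scripts and configs expect (e.g. `stylegan_ffhq256.pth`):

 ```python
 import os
 import torch

 os.makedirs('checkpoints', exist_ok=True)
 # Paste the download link of the desired model from the tables below.
 url = '<download link from the tables below>'
 torch.hub.download_url_to_file(url, 'checkpoints/stylegan_ffhq256.pth')
 ```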
13
+ | PGGAN Official | | | |
14
+ | :-- | :-- | :-- | :-- |
15
+ | *Face*
16
+ | [celebahq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1)
17
+ | *Indoor Scene*
18
+ | [bedroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EUZQWGz2GT5Bh_GJLalP63IBvCsXDTOxDFIC_ZBsmoEacA?e=VNXiDb&download=1) | [livingroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Efzh6qQv6QtCm0YN1lulH-YByqdE3AqlI-E6US_hXMuiig?e=ppdyB2&download=1) | [diningroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcLb3_hGUkdClompZo27xk0BNmotgbFqdIeu-ZOGJsBMRg?e=xjYpN3&download=1) | [kitchen-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ESCyg6hpNn1LlHVX_un1wLsBZAORUNkW9MO2kU1X5kafAQ?e=09TbGC&download=1)
19
+ | *Outdoor Scene*
20
+ | [churchoutdoor-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EQ8cKujs2TVGjCL_j6bsnk8BqD9REF2ME2lBnpbTPsqIvA?e=zH55fT&download=1) | [tower-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EeyBJvgRVGJClKr1KKYDF_cBT1FDepRU1-GLqYNh8W9-fQ?e=nrpa5N&download=1) | [bridge-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EZ2QScfPy19PiDERLJQ3gPMBP4WmvZHwhNFLzfaP2YD8hQ?e=bef1U9&download=1)
21
+ | *Other Scene*
22
+ | [restaurant-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERvJ4pz8jgtMrcuJXUfcOQEBDugZ099_TetCQs-9-ILCVg?e=qYsVdQ&download=1) | [classroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EUU9SCOPUxhMoUS4Ceo9kl0BQkVK7d69lA-JeOP-zOWvXw?e=YIB4no&download=1) | [conferenceroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EX8AF0_6NoJAl5vKFewHWnsBk0r4PK4WsqsMrJyj84TrqQ?e=oNQIZS&download=1)
23
+ | *Animal*
24
+ | [person-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWu4SqR42YpCoqsVJOcM2cMBcdfXA0j5wZ2hno9X0R9ydQ?e=KuDRns&download=1) | [cat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EQdveyUNOMtAue52n6BxoHoB6Yup5-PTvBDmyfUn7Un4Hw?e=7acGbT&download=1) | [dog-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ESaKyXA5fGlOvXJYDDFbT2kB9c0HlXh9n_wnyhiP05nhow?e=d4aKDV&download=1) | [bird-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ef2p4Pd3AKVCmSm00YikCIABhylh2dLPaFjPfPVn3RiTXA?e=9bRitp&download=1)
25
+ | [horse-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EXwCPdv6XqJFtuvFFoswRScBmLJbhKzaC5D_iovl1GFOTw?e=WDdD77&download=1) | [sheep-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ER6J5EKjAUNFtm9VwLf-uUsBZ5dnqxeKsPxY9ijiPtMhcQ?e=OKtfva&download=1) | [cow-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERZLxw7N7xJPm72FyePTbpcByzrr0pH-Fg7qyLt5tYGXwQ?e=ovIPCl&download=1)
26
+ | *Transportation*
27
+ | [car-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EfGc2we47aFDtAY1548pRvsByIju-uXRbkZEFpJotuPKZw?e=DQqVj8&download=1) | [bicycle-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ed1dN_FgwmdBgeNWhaRUry8BgwT88-n2ppicSDPx-f7f_Q?e=bxTxnf&download=1) | [motorbike-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EV3yQdeJXIdPjZbMO0mp2-MBJbKuuBdypzBL4gnedO57Dw?e=tXdvtD&download=1) | [bus-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ed7-OYLnq0RCqRlM8qK8wZ8B87dz_NUxIKBrvyFUwRCEbg?e=VP5bmX&download=1)
28
+ | [train-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EedE2cozKOVAkhvbdLd4SfwBknFW8vWZnKiqgeIBbAvCCA?e=BrLpTl&download=1) | [boat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Eb39waqQFr9Bp4wO0rC5NHwB0Vz2NGCuqbRPucguBIkDrg?e=lddSyL&download=1) | [airplane-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ee6FzIx3KjNDhxrS5mDvpCEB3iQ7TgErmKhbwbV-eF07iw?e=xflPXa&download=1)
29
+ | *Furniture*
30
+ | [bottle-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWhoy2AFCTZGtEG1UoayWjcB9Kdc_wreJ8p4RlBB93nbNg?e=DMZceU&download=1) | [chair-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EbQRTfwdostBhXG30Uacn7ABsEUFa-tEW3oxiM5zDYQbRw?e=FkB7T0&download=1) | [pottedplant-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWg7hnoGATBOuJvXWr4m7CQBJL9o7nqnD6nOMRhtH2SKXg?e=Zi3hjD&download=1) | [tvmonitor-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVXwttoJVtBMuhHNDdK3cMwBdMiZARJV38PMTsL6whnFlA?e=RbG0ru&download=1)
31
+ | [diningtable-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EXVzBkbmTCVImMtuHLCTBeMBXZmv0RWyx5KXQQAe7-7D5w?e=6RYSnm&download=1) | [sofa-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EaADQYDXwY9NrzbiUFcRYRgBOu1GdJMG8YgNZZmbNjbn-Q?e=DqKrXG&download=1)
32
+
33
+ | StyleGAN Official | | | |
34
+ | :-- | :--: | :--: | :--: |
35
+ | Model (Dataset) | Training Samples | Training Duration (K Images) | FID
36
+ | [ffhq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EdfMxgb0hU9BoXwiR3dqYDEBowCSEF1IcsW3n4kwfoZ9OQ?e=VwIV58&download=1) | 70,000 | 25,000 | 4.40 |
37
+ | [celebahq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcCdXHddE7FOvyfmqeOyc9ABqVuWh8PQYFnV6JM1CXvFig?e=1nUYZ5&download=1) | 30,000 | 25,000 | 5.06 |
38
+ | [bedroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ea6RBPddjcRNoFMXm8AyEBcBUHdlRNtjtclNKFe89amjBw?e=Og8Vff&download=1) | 3,033,042 | 70,000 | 2.65 |
39
+ | [cat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1) | 1,657,266 | 70,000 | 8.53 |
40
+ | [car-512x384](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1) | 5,520,756 | 46,000 | 3.27 |
41
+
42
+ | StyleGAN Ours | | | |
43
+ | :-- | :--: | :--: | :--: |
44
+ | Model (Dataset) | Training Samples | Training Duration (K Images) | FID
45
+ | *Face ("partial" means faces are not fully aligned to center)*
46
+ | [celeba_partial-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ET2etKNzMS9JmHj5j60fqMcBRJfQfYNvqUrujaIXxCvKDQ?e=QReLE6&download=1) | 103,706 | 50,000 | 7.03 |
47
+ | [ffhq-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ES-NAUCC2qdHg87BftvlBiQBVpbJ8-005Q4TNr5KrOxQEw?e=00AnWt&download=1) | 70,000 | 25,000 | 5.70 |
48
+ | [ffhq-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EZYrrwOiEgVOg-PfGv7QTegBzFQ9yq2v7o1WxNq5JJ9KNA?e=SZU8PI&download=1) | 70,000 | 25,000 | 5.15 |
49
+ | *LSUN Indoor Scene*
50
+ | [livingroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EfFCYLHjqbFDmjOvCCFJgDcBZ1QYgETfZJxp4ZTHjLxZBg?e=InVd0n&download=1) | 1,315,802 | 30,000 | 5.16 |
51
+ | [diningroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERsUza_hSFRIm4iZCag7P0kBQ9EIdfQKByw4QYt_ay97lg?e=Cimh7S&download=1) | 657,571 | 25,000 | 4.13 |
52
+ | [kitchen-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERcYvoingQNKix35lUs0vUkBQQkAZMp1rtDxjwNlOJAoaA?e=a1Tcwr&download=1) | 1,000,000 | 30,000 | 5.06 |
53
+ | *LSUN Indoor Scene Mixture*
54
+ | [apartment-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EfurPNSB2BRFtXdqGkmDD6YBwyKN8YK2v7nKwnJQdsbf6A?e=w3oYa4&download=1) | 4 * 200,000 | 60,000 | 4.18 |
55
+ | *LSUN Outdoor Scene*
56
+ | [church-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ETMgG1_d06tAlbUkJD1qA9IBaLZ9zJKPkG2kO-4jxhVV5w?e=Dbkb7o&download=1) | 126,227 | 30,000 | 4.82 |
57
+ | [tower-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ebm9QMgqB2VDqyIE5rFhreEBgZ_RyKcRf8bQ333K453u3w?e=if8sDj&download=1) | 708,264 | 30,000 | 5.99 |
58
+ | [bridge-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ed9QM6OP9sVHnazSp4cqPSEBb-ALfBPXRxP1hD7FsTYh8w?e=3vv06p&download=1) | 818,687 | 25,000 | 6.42 |
59
+ | *LSUN Other Scene*
60
+ | [restaurant-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ESDhYr01WtlEvBNFrVpFezcB2l9lF1rBYuHFoeNpBr5B7A?e=uFWFNh&download=1) | 626,331 | 50,000 | 4.03 |
61
+ | [classroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EbWnI3oto9NPk-lxwZlWqPQB2atWpGiTWMIT59MzF9ij9Q?e=KvcNBg&download=1) | 168,103 | 50,000 | 10.10 |
62
+ | [conferenceroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Eb1gVi3pGa9PgJ4XYYu_6yABQZ0ZcGDak4FEHaTHaeYFzw?e=0BeE8t&download=1) | 229,069 | 50,000 | 6.20 |
63
+
64
+ | StyleGAN Third-Party | |
65
+ | :-- | :--: |
66
+ | Model (Dataset) | Source |
67
+ | [animeface-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1) | [link](https://www.gwern.net/Faces#portrait-results)
68
+ | [animeportrait-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EXBvhTBi-v5NsnQtrxhFEKsBin4xg-Dud9Jr62AEwFTIxg?e=bMGK7r&download=1) | [link](https://www.gwern.net/Faces#portrait-results)
69
+ | [artface-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Eca0OiGqhyZMmoPbKahSBWQBWvcAH4q2CE3zdZJflp2jkQ?e=h4rWAm&download=1) | [link](https://github.com/ak9250/stylegan-art)
70
+
71
+ | StyleGAN2 Official | | | |
72
+ | :-- | :--: | :--: | :--: |
73
+ | Model (Dataset) | Training Samples | Training Duration (K Images) | FID
74
+ | [ffhq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EX0DNWiBvl5FuOQTF4oMPBYBNSalcxTK0AbLwBn9Y3vfgg?e=Q0sZit&download=1) | 70,000 | 25,000 | 2.84 |
75
+ | [church-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EQzDtJUdQ4ROunMGn2sZouEBmNeFX4QWvxjermVE5cZvNA?e=tQ7r9r&download=1) | 126,227 | 48,000 | 3.86 |
76
+ | [cat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EUKXeBwUUbZJr6kup7PW4ekBx2-vmTp8FjcGb10v8bgJxQ?e=nkerMF&download=1) | 1,657,266 | 88,000 | 6.93 |
77
+ | [horse-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EconoT6tb69OuAIqfXRtGlsBZz4vBx01UmmFO-JAS356Jg?e=bcSCC4&download=1) | 2,000,340 | 100,000 | 3.43 |
78
+ | [car-512x384](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1) | 5,520,756 | 57,000 | 2.32 |
79
+
80
+ ## Training Datasets
81
+
82
+ - [MNIST](http://yann.lecun.com/exdb/mnist/) (60,000 training samples and 10,000 test samples on 10 digits)
83
+ - [SVHN](http://ufldl.stanford.edu/housenumbers/) (73,257 training samples, 26,032 test samples, and 531,131 additional samples on 10 digits)
84
+ - [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) (50,000 training samples and 10,000 test samples on 10 classes)
85
+ - [CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html) (50,000 training samples and 10,000 test samples on 100 classes)
86
+ - [ImageNet](http://www.image-net.org/) (1,281,167 training samples, 50,000 validation samples, and 100,000 test samples on 1,000 classes)
87
+ - [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) (202,599 samples from 10,177 identities, with 5 landmarks and 40 binary facial attributes)
88
+ - [CelebA-HQ](https://github.com/tkarras/progressive_growing_of_gans) (30,000 samples)
89
+ - [FF-HQ](https://github.com/NVlabs/ffhq-dataset) (70,000 samples)
90
+ - [LSUN](https://github.com/fyu/lsun) (see statistical information below)
91
+ - [Places](http://places2.csail.mit.edu/) (around 1.8M training samples covering 365 classes)
92
+ - [Cityscapes](https://www.cityscapes-dataset.com/) (2,975 training samples, 19,998 extra training samples (one broken), 500 validation samples, and 1,525 test samples)
93
+ - [Streetscapes](http://streetscore.media.mit.edu/data.html)
94
+
95
+ Statistical information of [LSUN](https://github.com/fyu/lsun) dataset is summarized as follows:
96
+
97
+ | LSUN Datasets Stats | | |
98
+ | :-- | :--: | :--: |
99
+ | Name | Number of Samples | Size |
100
+ | *Scenes*
101
+ | bedroom (train) | 3,033,042 | 43G |
102
+ | bridge (train) | 818,687 | 15G |
103
+ | churchoutdoor (train) | 126,227 | 2G |
104
+ | classroom (train) | 168,103 | 3G |
105
+ | conferenceroom (train) | 229,069 | 4G |
106
+ | diningroom (train) | 657,571 | 11G |
107
+ | kitchen (train) | 2,212,277 | 33G |
108
+ | livingroom (train) | 1,315,802 | 21G |
109
+ | restaurant (train) | 626,331 | 13G |
110
+ | tower (train) | 708,264 | 11G |
111
+ | *Objects*
112
+ | airplane | 1,530,696 | 34G |
113
+ | bicycle | 3,347,211 | 129G |
114
+ | bird | 2,310,362 | 65G |
115
+ | boat | 2,651,165 | 86G |
116
+ | bottle | 3,202,760 | 64G |
117
+ | bus | 695,891 | 24G |
118
+ | car | 5,520,756 | 173G |
119
+ | cat | 1,657,266 | 42G |
120
+ | chair | 5,037,807 | 116G |
121
+ | cow | 377,379 | 15G |
122
+ | diningtable | 1,537,123 | 48G |
123
+ | dog | 5,054,817 | 145G |
124
+ | horse | 2,000,340 | 69G |
125
+ | motorbike | 1,194,101 | 42G |
126
+ | person | 18,890,816 | 477G |
127
+ | pottedplant | 1,104,859 | 43G |
128
+ | sheep | 418,983 | 18G |
129
+ | sofa | 2,365,870 | 56G |
130
+ | train | 1,148,020 | 43G |
131
+ | tvmonitor | 2,463,284 | 46G |
networks/genforce/README.md ADDED
@@ -0,0 +1,169 @@
1
+ # GenForce Lib for Generative Modeling
2
+
3
+ An efficient PyTorch library for deep generative modeling. May the Generative Force (GenForce) be with You.
4
+
5
+ ![image](./teaser.gif)
6
+
7
+ ## Updates
8
+
9
+ - **Encoder Training:** We support training encoders on top of pre-trained GANs for GAN inversion.
10
+ - **Model Converters:** You can easily migrate projects you have already started to this repository. Please check [here](./converters/README.md) for more details.
11
+
12
+ ## Highlights
13
+
14
+ - **Distributed** training framework.
15
+ - **Fast** training speed.
16
+ - **Modular** design for prototyping new models.
17
+ - **Model zoo** containing a rich set of pretrained GAN models, with a [Colab live demo](https://colab.research.google.com/github/genforce/genforce/blob/master/docs/synthesize_demo.ipynb) to play with.
18
+
19
+ ## Installation
20
+
21
+ 1. Create a virtual environment via `conda`.
22
+
23
+ ```shell
24
+ conda create -n genforce python=3.7
25
+ conda activate genforce
26
+ ```
27
+
28
+ 2. Install `cuda` and `cudnn`. (We use `CUDA 10.0` in case you would like to use `TensorFlow 1.15` for model conversion.)
29
+
30
+ ```shell
31
+ conda install cudatoolkit=10.0 cudnn=7.6.5
32
+ ```
33
+
34
+ 3. Install `torch` and `torchvision`.
35
+
36
+ ```shell
37
+ pip install torch==1.7 torchvision==0.8
38
+ ```
39
+
40
+ 4. Install requirements
41
+
42
+ ```shell
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ ## Quick Demo
47
+
48
+ We provide a quick training demo, `scripts/stylegan_training_demo.py`, which allows you to train StyleGAN on a toy dataset (500 animeface images at 64 x 64 resolution). Try it via
49
+
50
+ ```shell
51
+ ./scripts/stylegan_training_demo.sh
52
+ ```
53
+
54
+ We also provide an inference demo, `synthesize.py`, which allows you to synthesize images with pre-trained models. Generated images are saved to `work_dirs/synthesis_results/`. Try it via
55
+
56
+ ```shell
57
+ python synthesize.py stylegan_ffhq1024
58
+ ```
59
+
60
+ You can also play the demo at [Colab](https://colab.research.google.com/github/genforce/genforce/blob/master/docs/synthesize_demo.ipynb).
61
+
62
+ ## Play with GANs
63
+
64
+ ### Test
65
+
66
+ Pre-trained models can be found at [model zoo](MODEL_ZOO.md).
67
+
68
+ - On local machine:
69
+
70
+ ```shell
71
+ GPUS=8
72
+ CONFIG=configs/stylegan_ffhq256_val.py
73
+ WORK_DIR=work_dirs/stylegan_ffhq256_val
74
+ CHECKPOINT=checkpoints/stylegan_ffhq256.pth
75
+ ./scripts/dist_test.sh ${GPUS} ${CONFIG} ${WORK_DIR} ${CHECKPOINT}
76
+ ```
77
+
78
+ - Using `slurm`:
79
+
80
+ ```shell
81
+ CONFIG=configs/stylegan_ffhq256_val.py
82
+ WORK_DIR=work_dirs/stylegan_ffhq256_val
83
+ CHECKPOINT=checkpoints/stylegan_ffhq256.pth
84
+ GPUS=8 ./scripts/slurm_test.sh ${PARTITION} ${JOB_NAME} \
85
+ ${CONFIG} ${WORK_DIR} ${CHECKPOINT}
86
+ ```
87
+
88
+ ### Train
89
+
90
+ All files produced during training, such as log messages, checkpoints, and synthesis snapshots, will be saved to the work directory.
91
+
92
+ - On local machine:
93
+
94
+ ```shell
95
+ GPUS=8
96
+ CONFIG=configs/stylegan_ffhq256.py
97
+ WORK_DIR=work_dirs/stylegan_ffhq256_train
98
+ ./scripts/dist_train.sh ${GPUS} ${CONFIG} ${WORK_DIR} \
99
+ [--options additional_arguments]
100
+ ```
101
+
102
+ - Using `slurm`:
103
+
104
+ ```shell
105
+ CONFIG=configs/stylegan_ffhq256.py
106
+ WORK_DIR=work_dirs/stylegan_ffhq256_train
107
+ GPUS=8 ./scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} \
108
+ ${CONFIG} ${WORK_DIR} \
109
+ [--options additional_arguments]
110
+ ```
111
+
112
+ ## Play with Encoders for GAN Inversion
113
+
114
+ ### Train
115
+
116
+ - On local machine:
117
+
118
+ ```shell
119
+ GPUS=8
120
+ CONFIG=configs/stylegan_ffhq256_encoder_y.py
121
+ WORK_DIR=work_dirs/stylegan_ffhq256_encoder_y
122
+ ./scripts/dist_train.sh ${GPUS} ${CONFIG} ${WORK_DIR} \
123
+ [--options additional_arguments]
124
+ ```
125
+
126
+
127
+ - Using `slurm`:
128
+
129
+ ```shell
130
+ CONFIG=configs/stylegan_ffhq256_encoder_y.py
131
+ WORK_DIR=work_dirs/stylegan_ffhq256_encoder_y
132
+ GPUS=8 ./scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} \
133
+ ${CONFIG} ${WORK_DIR} \
134
+ [--options additional_arguments]
135
+ ```
136
+ ## Contributors
137
+
138
+ | Member | Module |
139
+ | :-- | :-- |
140
+ |[Yujun Shen](http://shenyujun.github.io/) | models and running controllers
141
+ |[Yinghao Xu](https://justimyhxu.github.io/) | runner and loss functions
142
+ |[Ceyuan Yang](http://ceyuan.me/) | data loader
143
+ |[Jiapeng Zhu](https://zhujiapeng.github.io/) | evaluation metrics
144
+ |[Bolei Zhou](http://bzhou.ie.cuhk.edu.hk/) | cheerleader
145
+
146
+ **NOTE:** The table above only lists the person in charge of each module. We help each other a lot and develop as a **TEAM**.
147
+
148
+ *We welcome external contributors to join us for improving this library.*
149
+
150
+ ## License
151
+
152
+ The project is under the [MIT License](./LICENSE).
153
+
154
+ ## Acknowledgement
155
+
156
+ We thank [PGGAN](https://github.com/tkarras/progressive_growing_of_gans), [StyleGAN](https://github.com/NVlabs/stylegan), [StyleGAN2](https://github.com/NVlabs/stylegan2), [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada) for their work on high-quality image synthesis. We thank [IDInvert](https://github.com/genforce/idinvert) and [GHFeat](https://github.com/genforce/ghfeat) for their contribution to GAN inversion. We also thank [MMCV](https://github.com/open-mmlab/mmcv) for the inspiration on the design of controllers.
157
+
158
+ ## BibTex
159
+
160
+ We open-source this library to the community to facilitate research on generative modeling. If you find our work useful and use the codebase or models in your research, please cite our work as follows.
161
+
162
+ ```bibtex
163
+ @misc{genforce2020,
164
+ title = {GenForce},
165
+ author = {Shen, Yujun and Xu, Yinghao and Yang, Ceyuan and Zhu, Jiapeng and Zhou, Bolei},
166
+ howpublished = {\url{https://github.com/genforce/genforce}},
167
+ year = {2020}
168
+ }
169
+ ```
networks/genforce/__init__.py ADDED
File without changes
networks/genforce/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (162 Bytes). View file
 
networks/genforce/configs/stylegan_demo.py ADDED
@@ -0,0 +1,61 @@
1
+ # python3.7
2
+ """Configuration for StyleGAN training demo.
3
+
4
+ All settings are particularly used for one replica (GPU), such as `batch_size`
5
+ and `num_workers`.
6
+ """
7
+
8
+ runner_type = 'StyleGANRunner'
9
+ gan_type = 'stylegan'
10
+ resolution = 64
11
+ batch_size = 4
12
+ val_batch_size = 32
13
+ total_img = 100_000
14
+
15
+ # Training dataset is repeated at the beginning to avoid loading dataset
16
+ # repeatedly at the end of each epoch. This can save some I/O time.
17
+ data = dict(
18
+ num_workers=4,
19
+ repeat=500,
20
+ train=dict(root_dir='data/demo.zip', data_format='zip',
21
+ resolution=resolution, mirror=0.5),
22
+ val=dict(root_dir='data/demo.zip', data_format='zip',
23
+ resolution=resolution),
24
+ )
25
+
26
+ controllers = dict(
27
+ RunningLogger=dict(every_n_iters=10),
28
+ ProgressScheduler=dict(
29
+ every_n_iters=1, init_res=8, minibatch_repeats=4,
30
+ lod_training_img=5_000, lod_transition_img=5_000,
31
+ batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8),
32
+ ),
33
+ Snapshoter=dict(every_n_iters=500, first_iter=True, num=200),
34
+ FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000),
35
+ Checkpointer=dict(every_n_iters=5000, first_iter=True),
36
+ )
37
+
38
+ modules = dict(
39
+ discriminator=dict(
40
+ model=dict(gan_type=gan_type, resolution=resolution),
41
+ lr=dict(lr_type='FIXED'),
42
+ opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
43
+ kwargs_train=dict(),
44
+ kwargs_val=dict(),
45
+ ),
46
+ generator=dict(
47
+ model=dict(gan_type=gan_type, resolution=resolution),
48
+ lr=dict(lr_type='FIXED'),
49
+ opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
50
+ kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9,
51
+ trunc_psi=1.0, trunc_layers=0, randomize_noise=True),
52
+ kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False),
53
+ g_smooth_img=10000,
54
+ )
55
+ )
56
+
57
+ loss = dict(
58
+ type='LogisticGANLoss',
59
+ d_loss_kwargs=dict(r1_gamma=10.0),
60
+ g_loss_kwargs=dict(),
61
+ )
networks/genforce/configs/stylegan_ffhq1024.py ADDED
@@ -0,0 +1,63 @@
1
+ # python3.7
2
+ """Configuration for training StyleGAN on FF-HQ (1024) dataset.
3
+
4
+ All settings are particularly used for one replica (GPU), such as `batch_size`
5
+ and `num_workers`.
6
+ """
7
+
8
+ runner_type = 'StyleGANRunner'
9
+ gan_type = 'stylegan'
10
+ resolution = 1024
11
+ batch_size = 4
12
+ val_batch_size = 16
13
+ total_img = 25000_000
14
+
15
+ # Training dataset is repeated at the beginning to avoid loading dataset
16
+ # repeatedly at the end of each epoch. This can save some I/O time.
17
+ data = dict(
18
+ num_workers=4,
19
+ repeat=500,
20
+ # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
21
+ # val=dict(root_dir='data/ffhq', resolution=resolution),
22
+ train=dict(root_dir='data/ffhq.zip', data_format='zip',
23
+ resolution=resolution, mirror=0.5),
24
+ val=dict(root_dir='data/ffhq.zip', data_format='zip',
25
+ resolution=resolution),
26
+ )
27
+
28
+ controllers = dict(
29
+ RunningLogger=dict(every_n_iters=10),
30
+ ProgressScheduler=dict(
31
+ every_n_iters=1, init_res=8, minibatch_repeats=4,
32
+ lod_training_img=600_000, lod_transition_img=600_000,
33
+ batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8),
34
+ ),
35
+ Snapshoter=dict(every_n_iters=500, first_iter=True, num=200),
36
+ FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000),
37
+ Checkpointer=dict(every_n_iters=5000, first_iter=True),
38
+ )
39
+
40
+ modules = dict(
41
+ discriminator=dict(
42
+ model=dict(gan_type=gan_type, resolution=resolution),
43
+ lr=dict(lr_type='FIXED'),
44
+ opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
45
+ kwargs_train=dict(),
46
+ kwargs_val=dict(),
47
+ ),
48
+ generator=dict(
49
+ model=dict(gan_type=gan_type, resolution=resolution),
50
+ lr=dict(lr_type='FIXED'),
51
+ opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
52
+ kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9,
53
+ trunc_psi=1.0, trunc_layers=0, randomize_noise=True),
54
+ kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False),
55
+ g_smooth_img=10_000,
56
+ )
57
+ )
58
+
59
+ loss = dict(
60
+ type='LogisticGANLoss',
61
+ d_loss_kwargs=dict(r1_gamma=10.0),
62
+ g_loss_kwargs=dict(),
63
+ )
networks/genforce/configs/stylegan_ffhq1024_val.py ADDED
@@ -0,0 +1,29 @@
1
+ # python3.7
2
+ """Configuration for testing StyleGAN on FF-HQ (1024) dataset.
3
+
4
+ All settings are particularly used for one replica (GPU), such as `batch_size`
5
+ and `num_workers`.
6
+ """
7
+
8
+ runner_type = 'StyleGANRunner'
9
+ gan_type = 'stylegan'
10
+ resolution = 1024
11
+ batch_size = 16
12
+
13
+ data = dict(
14
+ num_workers=4,
15
+ # val=dict(root_dir='data/ffhq', resolution=resolution),
16
+ val=dict(root_dir='data/ffhq.zip', data_format='zip',
17
+ resolution=resolution),
18
+ )
19
+
20
+ modules = dict(
21
+ discriminator=dict(
22
+ model=dict(gan_type=gan_type, resolution=resolution),
23
+ kwargs_val=dict(),
24
+ ),
25
+ generator=dict(
26
+ model=dict(gan_type=gan_type, resolution=resolution),
27
+ kwargs_val=dict(trunc_psi=0.7, trunc_layers=8, randomize_noise=False),
28
+ )
29
+ )
networks/genforce/configs/stylegan_ffhq256.py ADDED
@@ -0,0 +1,63 @@
1
+ # python3.7
2
+ """Configuration for training StyleGAN on FF-HQ (256) dataset.
3
+
4
+ All settings are particularly used for one replica (GPU), such as `batch_size`
5
+ and `num_workers`.
6
+ """
7
+
8
+ runner_type = 'StyleGANRunner'
9
+ gan_type = 'stylegan'
10
+ resolution = 256
11
+ batch_size = 4
12
+ val_batch_size = 64
13
+ total_img = 25000_000
14
+
15
+ # Training dataset is repeated at the beginning to avoid loading dataset
16
+ # repeatedly at the end of each epoch. This can save some I/O time.
17
+ data = dict(
18
+ num_workers=4,
19
+ repeat=500,
20
+ # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
21
+ # val=dict(root_dir='data/ffhq', resolution=resolution),
22
+ train=dict(root_dir='data/ffhq.zip', data_format='zip',
23
+ resolution=resolution, mirror=0.5),
24
+ val=dict(root_dir='data/ffhq.zip', data_format='zip',
25
+ resolution=resolution),
26
+ )
27
+
28
+ controllers = dict(
29
+ RunningLogger=dict(every_n_iters=10),
30
+ ProgressScheduler=dict(
31
+ every_n_iters=1, init_res=8, minibatch_repeats=4,
32
+ lod_training_img=600_000, lod_transition_img=600_000,
33
+ batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8),
34
+ ),
35
+ Snapshoter=dict(every_n_iters=500, first_iter=True, num=200),
36
+ FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000),
37
+ Checkpointer=dict(every_n_iters=5000, first_iter=True),
38
+ )
39
+
40
+ modules = dict(
41
+ discriminator=dict(
42
+ model=dict(gan_type=gan_type, resolution=resolution),
43
+ lr=dict(lr_type='FIXED'),
44
+ opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
45
+ kwargs_train=dict(),
46
+ kwargs_val=dict(),
47
+ ),
48
+ generator=dict(
49
+ model=dict(gan_type=gan_type, resolution=resolution),
50
+ lr=dict(lr_type='FIXED'),
51
+ opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
52
+ kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9,
53
+ trunc_psi=1.0, trunc_layers=0, randomize_noise=True),
54
+ kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False),
55
+ g_smooth_img=10_000,
56
+ )
57
+ )
58
+
59
+ loss = dict(
60
+ type='LogisticGANLoss',
61
+ d_loss_kwargs=dict(r1_gamma=10.0),
62
+ g_loss_kwargs=dict(),
63
+ )
networks/genforce/configs/stylegan_ffhq256_encoder_y.py ADDED
@@ -0,0 +1,73 @@
1
+ # python3.7
2
+ """Configuration for training StyleGAN Encoder on FF-HQ (256) dataset.
3
+
4
+ All settings are particularly used for one replica (GPU), such as `batch_size`
5
+ and `num_workers`.
6
+ """
7
+
8
+ gan_model_path = 'checkpoints/stylegan_ffhq256.pth'
9
+ perceptual_model_path = 'checkpoints/vgg16.pth'
10
+
11
+ runner_type = 'EncoderRunner'
12
+ gan_type = 'stylegan'
13
+ resolution = 256
14
+ batch_size = 12
15
+ val_batch_size = 25
16
+ total_img = 14000_000
17
+ space_of_latent = 'y'
18
+
19
+ # Training dataset is repeated at the beginning to avoid loading dataset
20
+ # repeatedly at the end of each epoch. This can save some I/O time.
21
+ data = dict(
22
+ num_workers=4,
23
+ repeat=500,
24
+ # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
25
+ # val=dict(root_dir='data/ffhq', resolution=resolution),
26
+ train=dict(root_dir='data/', data_format='list',
27
+ image_list_path='data/ffhq/ffhq_train_list.txt',
28
+ resolution=resolution, mirror=0.5),
29
+ val=dict(root_dir='data/', data_format='list',
30
+ image_list_path='./data/ffhq/ffhq_val_list.txt',
31
+ resolution=resolution),
32
+ )
33
+
34
+ controllers = dict(
35
+ RunningLogger=dict(every_n_iters=50),
36
+ Snapshoter=dict(every_n_iters=10000, first_iter=True, num=200),
37
+ Checkpointer=dict(every_n_iters=10000, first_iter=False),
38
+ )
39
+
40
+ modules = dict(
41
+ discriminator=dict(
42
+ model=dict(gan_type=gan_type, resolution=resolution),
43
+ lr=dict(lr_type='ExpSTEP', decay_factor=0.8, decay_step=36458 // 2),
44
+ opt=dict(opt_type='Adam', base_lr=1e-4, betas=(0.9, 0.99)),
45
+ kwargs_train=dict(),
46
+ kwargs_val=dict(),
47
+ ),
48
+ generator=dict(
49
+ model=dict(gan_type=gan_type, resolution=resolution, repeat_w=True),
50
+ kwargs_val=dict(randomize_noise=False),
51
+ ),
52
+ encoder=dict(
53
+ model=dict(gan_type=gan_type, resolution=resolution, network_depth=18,
54
+ latent_dim = [1024] * 8 + [512, 512, 256, 256, 128, 128],
55
+ num_latents_per_head=[4, 4, 6],
56
+ use_fpn=True,
57
+ fpn_channels=512,
58
+ use_sam=True,
59
+ sam_channels=512),
60
+ lr=dict(lr_type='ExpSTEP', decay_factor=0.8, decay_step=36458 // 2),
61
+ opt=dict(opt_type='Adam', base_lr=1e-4, betas=(0.9, 0.99)),
62
+ kwargs_train=dict(),
63
+ kwargs_val=dict(),
64
+ ),
65
+ )
66
+
67
+ loss = dict(
68
+ type='EncoderLoss',
69
+ d_loss_kwargs=dict(r1_gamma=10.0),
70
+ e_loss_kwargs=dict(adv_lw=0.08, perceptual_lw=5e-5),
71
+ perceptual_kwargs=dict(output_layer_idx=23,
72
+ pretrained_weight_path=perceptual_model_path),
73
+ )
networks/genforce/configs/stylegan_ffhq256_val.py ADDED
@@ -0,0 +1,29 @@
1
+ # python3.7
2
+ """Configuration for testing StyleGAN on FF-HQ (256) dataset.
3
+
4
+ All settings are particularly used for one replica (GPU), such as `batch_size`
5
+ and `num_workers`.
6
+ """
7
+
8
+ runner_type = 'StyleGANRunner'
9
+ gan_type = 'stylegan'
10
+ resolution = 256
11
+ batch_size = 64
12
+
13
+ data = dict(
14
+ num_workers=4,
15
+ # val=dict(root_dir='data/ffhq', resolution=resolution),
16
+ val=dict(root_dir='data/ffhq.zip', data_format='zip',
17
+ resolution=resolution),
18
+ )
19
+
20
+ modules = dict(
21
+ discriminator=dict(
22
+ model=dict(gan_type=gan_type, resolution=resolution),
23
+ kwargs_val=dict(),
24
+ ),
25
+ generator=dict(
26
+ model=dict(gan_type=gan_type, resolution=resolution),
27
+ kwargs_val=dict(trunc_psi=0.7, trunc_layers=8, randomize_noise=False),
28
+ )
29
+ )
networks/genforce/convert_model.py ADDED
@@ -0,0 +1,77 @@
1
+ """Script to convert officially released models to match this repository."""
2
+
3
+ import argparse
4
+
5
+ from converters import convert_pggan_weight
6
+ from converters import convert_stylegan_weight
7
+ from converters import convert_stylegan2_weight
8
+ from converters import convert_stylegan2ada_tf_weight
9
+ from converters import convert_stylegan2ada_pth_weight
10
+
11
+
12
+ def parse_args():
13
+ """Parses arguments."""
14
+ parser = argparse.ArgumentParser(description='Convert pre-trained models.')
15
+ parser.add_argument('model_type', type=str,
16
+ choices=['pggan', 'stylegan', 'stylegan2',
17
+ 'stylegan2ada_tf', 'stylegan2ada_pth'],
18
+ help='Type of the model to convert')
19
+ parser.add_argument('--source_model_path', type=str, required=True,
20
+ help='Path to load the model for conversion.')
21
+ parser.add_argument('--target_model_path', type=str, default=None,
22
+ help='Path to save the converted model. If not '
23
+ 'specified, the model will be saved to the same '
24
+ 'directory of the source model.')
25
+ parser.add_argument('--test_num', type=int, default=10,
26
+ help='Number of test samples used to check the '
27
+ 'precision of the converted model. (default: 10)')
28
+ parser.add_argument('--save_test_image', action='store_true',
29
+ help='Whether to save the test image. (default: False)')
30
+ parser.add_argument('--verbose_log', action='store_true',
31
+ help='Whether to print verbose log. (default: False)')
32
+ return parser.parse_args()
33
+
34
+
35
+ def main():
36
+ """Main function."""
37
+ args = parse_args()
38
+ if args.target_model_path is None:
39
+ args.target_model_path = args.source_model_path.replace('.pkl', '.pth')
40
+
41
+ if args.model_type == 'pggan':
42
+ convert_pggan_weight(tf_weight_path=args.source_model_path,
43
+ pth_weight_path=args.target_model_path,
44
+ test_num=args.test_num,
45
+ save_test_image=args.save_test_image,
46
+ verbose=args.verbose_log)
47
+ elif args.model_type == 'stylegan':
48
+ convert_stylegan_weight(tf_weight_path=args.source_model_path,
49
+ pth_weight_path=args.target_model_path,
50
+ test_num=args.test_num,
51
+ save_test_image=args.save_test_image,
52
+ verbose=args.verbose_log)
53
+ elif args.model_type == 'stylegan2':
54
+ convert_stylegan2_weight(tf_weight_path=args.source_model_path,
55
+ pth_weight_path=args.target_model_path,
56
+ test_num=args.test_num,
57
+ save_test_image=args.save_test_image,
58
+ verbose=args.verbose_log)
59
+ elif args.model_type == 'stylegan2ada_tf':
60
+ convert_stylegan2ada_tf_weight(tf_weight_path=args.source_model_path,
61
+ pth_weight_path=args.target_model_path,
62
+ test_num=args.test_num,
63
+ save_test_image=args.save_test_image,
64
+ verbose=args.verbose_log)
65
+ elif args.model_type == 'stylegan2ada_pth':
66
+ convert_stylegan2ada_pth_weight(src_weight_path=args.source_model_path,
67
+ dst_weight_path=args.target_model_path,
68
+ test_num=args.test_num,
69
+ save_test_image=args.save_test_image,
70
+ verbose=args.verbose_log)
71
+ else:
72
+ raise NotImplementedError(f'Model type `{args.model_type}` is not '
73
+ f'supported!')
74
+
75
+
76
+ if __name__ == '__main__':
77
+ main()
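 As a usage sketch, the conversion above can also be invoked programmatically through the imported converter functions; the file names below are hypothetical and should point at the official pickle you downloaded and the `.pth` path you want to write:

 ```python
 from converters import convert_stylegan_weight

 convert_stylegan_weight(
     tf_weight_path='checkpoints/karras2019stylegan-ffhq-1024x1024.pkl',  # hypothetical source path
     pth_weight_path='checkpoints/stylegan_ffhq1024.pth',                 # hypothetical target path
     test_num=10,
     save_test_image=False,
     verbose=False,
 )
 ```

 Equivalently, `python convert_model.py stylegan --source_model_path <pkl> --target_model_path <pth>` runs the same conversion from the command line.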
networks/genforce/datasets/README.md ADDED
@@ -0,0 +1,24 @@
1
+ # Data Preparation
2
+
3
+ ## Data Format
4
+
5
+ Currently, our dataloader is able to load data from
6
+
7
+ - a directory full of images (supports using [`turbojpeg`](https://pypi.org/project/PyTurboJPEG/) to speed up image decoding)
8
+ - a `lmdb` file
9
+ - an image list
10
+ - a compressed file (i.e., `zip` package)
11
+
12
+ by modifying `data_format` in the configuration.
13
+
14
+ **NOTE:** For computing clusters whose I/O may be slow, we recommend the `zip` format for two reasons. First, a `zip` file is easy to create. Second, it loads one large file at once instead of repeatedly loading many small files. A minimal sketch of creating such an archive follows.
15
+
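 As a minimal sketch of the `zip` route, a flat archive of images can be created with the standard library and then pointed to from a config via `root_dir='data/<name>.zip', data_format='zip'`, as in the configs above (the folder and archive names below are placeholders):

 ```python
 import os
 import zipfile

 image_dir = 'data/my_images'         # placeholder: folder of .jpg/.png images
 archive_path = 'data/my_images.zip'  # placeholder: archive consumed by the dataloader

 with zipfile.ZipFile(archive_path, 'w') as zf:
     for name in sorted(os.listdir(image_dir)):
         if name.lower().endswith(('.jpg', '.jpeg', '.png')):
             # Store images flat inside the archive so they are easy to enumerate.
             zf.write(os.path.join(image_dir, name), arcname=name)
 ```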
16
+ ## Data Sampling
17
+
18
+ Since most generative models are trained in units of iterations rather than epochs, we change the default data loader to an *iteration-based* one. In addition, the original distributed data sampler is modified so that shuffling corresponds to iterations instead of epochs.
19
+
20
+ **NOTE:** To reduce the data re-loading cost between epochs, we manually extend the length of the sampled indices, which makes loading much more efficient.
21
+
22
+ ## Data Augmentation
23
+
24
+ To better align with the original implementations of PGGAN and StyleGAN (i.e., models that require progressive training), we support progressive resizing in `transforms.py`, which downsamples images by a factor of at most 2 at each step, as sketched below.
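 The actual implementation lives in `transforms.py` (not included in this excerpt); the snippet below is only a sketch of the idea for square images, assuming `cv2` area interpolation:

 ```python
 import cv2

 def progressive_resize_sketch(image, target_size):
     """Halves the spatial size step by step, never shrinking by more than 2x at once."""
     size = image.shape[0]
     while size > target_size:
         size = max((size + 1) // 2, target_size)
         image = cv2.resize(image, (size, size), interpolation=cv2.INTER_AREA)
     return image
 ```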
networks/genforce/datasets/__init__.py ADDED
@@ -0,0 +1,7 @@
1
+ # python3.7
2
+ """Collects datasets and data loaders."""
3
+
4
+ from .datasets import BaseDataset
5
+ from .dataloaders import IterDataLoader
6
+
7
+ __all__ = ['BaseDataset', 'IterDataLoader']
networks/genforce/datasets/dataloaders.py ADDED
@@ -0,0 +1,128 @@
1
+ # python3.7
2
+ """Contains the class of data loader."""
3
+
4
+ import argparse
5
+
6
+ from torch.utils.data import DataLoader
7
+ from .distributed_sampler import DistributedSampler
8
+ from .datasets import BaseDataset
9
+
10
+
11
+ __all__ = ['IterDataLoader']
12
+
13
+
14
+ class IterDataLoader(object):
15
+ """Iteration-based data loader."""
16
+
17
+ def __init__(self,
18
+ dataset,
19
+ batch_size,
20
+ shuffle=True,
21
+ num_workers=1,
22
+ current_iter=0,
23
+ repeat=1):
24
+ """Initializes the data loader.
25
+
26
+ Args:
27
+ dataset: The dataset to load data from.
28
+ batch_size: The batch size on each GPU.
29
+ shuffle: Whether to shuffle the data. (default: True)
30
+ num_workers: Number of data workers for each GPU. (default: 1)
31
+ current_iter: The current number of iterations. (default: 0)
32
+ repeat: Number of times to repeat the dataset within the sampler. (default: 1)
33
+ """
34
+ self._dataset = dataset
35
+ self.batch_size = batch_size
36
+ self.shuffle = shuffle
37
+ self.num_workers = num_workers
38
+ self._dataloader = None
39
+ self.iter_loader = None
40
+ self._iter = current_iter
41
+ self.repeat = repeat
42
+ self.build_dataloader()
43
+
44
+ def build_dataloader(self):
45
+ """Builds data loader."""
46
+ dist_sampler = DistributedSampler(self._dataset,
47
+ shuffle=self.shuffle,
48
+ current_iter=self._iter,
49
+ repeat=self.repeat)
50
+
51
+ self._dataloader = DataLoader(self._dataset,
52
+ batch_size=self.batch_size,
53
+ shuffle=(dist_sampler is None),
54
+ num_workers=self.num_workers,
55
+ drop_last=self.shuffle,
56
+ pin_memory=True,
57
+ sampler=dist_sampler)
58
+ self.iter_loader = iter(self._dataloader)
59
+
60
+
61
+ def overwrite_param(self, batch_size=None, resolution=None):
62
+ """Overwrites some parameters for progressive training."""
63
+ if (not batch_size) and (not resolution):
64
+ return
65
+ if (batch_size == self.batch_size) and (
66
+ resolution == self.dataset.resolution):
67
+ return
68
+ if batch_size:
69
+ self.batch_size = batch_size
70
+ if resolution:
71
+ self._dataset.resolution = resolution
72
+ self.build_dataloader()
73
+
74
+ @property
75
+ def iter(self):
76
+ """Returns the current iteration."""
77
+ return self._iter
78
+
79
+ @property
80
+ def dataset(self):
81
+ """Returns the dataset."""
82
+ return self._dataset
83
+
84
+ @property
85
+ def dataloader(self):
86
+ """Returns the data loader."""
87
+ return self._dataloader
88
+
89
+ def __next__(self):
90
+ try:
91
+ data = next(self.iter_loader)
92
+ self._iter += 1
93
+ except StopIteration:
94
+ self._dataloader.sampler.__reset__(self._iter)
95
+ self.iter_loader = iter(self._dataloader)
96
+ data = next(self.iter_loader)
97
+ self._iter += 1
98
+ return data
99
+
100
+ def __len__(self):
101
+ return len(self._dataloader)
102
+
103
+
104
+ def dataloader_test(root_dir, test_num=10):
105
+ """Tests data loader."""
106
+ res = 2
107
+ bs = 2
108
+ dataset = BaseDataset(root_dir=root_dir, resolution=res)
109
+ dataloader = IterDataLoader(dataset=dataset,
110
+ batch_size=bs,
111
+ shuffle=False)
112
+ for _ in range(test_num):
113
+ data_batch = next(dataloader)
114
+ image = data_batch['image']
115
+ assert image.shape == (bs, 3, res, res)
116
+ res *= 2
117
+ bs += 1
118
+ dataloader.overwrite_param(batch_size=bs, resolution=res)
119
+
120
+
121
+ if __name__ == '__main__':
122
+ parser = argparse.ArgumentParser(description='Test Data Loader.')
123
+ parser.add_argument('root_dir', type=str,
124
+ help='Root directory of the dataset.')
125
+ parser.add_argument('--test_num', type=int, default=10,
126
+ help='Number of tests. (default: %(default)s)')
127
+ args = parser.parse_args()
128
+ dataloader_test(args.root_dir, args.test_num)
networks/genforce/datasets/datasets.py ADDED
@@ -0,0 +1,239 @@
1
+ # python3.7
2
+ """Contains the class of dataset."""
3
+
4
+ import os
5
+ import pickle
6
+ import string
7
+ import zipfile
8
+ import numpy as np
9
+ import cv2
10
+ import lmdb
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+
15
+ from .transforms import progressive_resize_image
16
+ from .transforms import crop_resize_image
17
+ from .transforms import resize_image
18
+ from .transforms import normalize_image
19
+
20
+ try:
21
+ import turbojpeg
22
+ BASE_DIR = os.path.dirname(os.path.relpath(__file__))
23
+ LIBRARY_NAME = 'libturbojpeg.so.0'
24
+ LIBRARY_PATH = os.path.join(BASE_DIR, LIBRARY_NAME)
25
+ jpeg = turbojpeg.TurboJPEG(LIBRARY_PATH)
26
+ except ImportError:
27
+ jpeg = None
28
+
29
+ __all__ = ['BaseDataset']
30
+
31
+ _FORMATS_ALLOWED = ['dir', 'lmdb', 'list', 'zip']
32
+
33
+
34
+ class ZipLoader(object):
35
+ """Defines a class to load zip file.
36
+
37
+ This is a static class, which is used to solve the problem that different
38
+ data workers cannot share the same memory.
39
+ """
40
+ files = dict()
41
+
42
+ @staticmethod
43
+ def get_zipfile(file_path):
44
+ """Fetches a zip file."""
45
+ zip_files = ZipLoader.files
46
+ if file_path not in zip_files:
47
+ zip_files[file_path] = zipfile.ZipFile(file_path, 'r')
48
+ return zip_files[file_path]
49
+
50
+ @staticmethod
51
+ def get_image(file_path, image_path):
52
+ """Decodes an image from a particular zip file."""
53
+ zip_file = ZipLoader.get_zipfile(file_path)
54
+ image_str = zip_file.read(image_path)
55
+ image_np = np.frombuffer(image_str, np.uint8)
56
+ image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
57
+ return image
58
+
59
+
60
+ class LmdbLoader(object):
61
+ """Defines a class to load lmdb file.
62
+
63
+ This is a static class, used to avoid the lmdb loading error that occurs
64
+ when num_workers > 0.
65
+ """
66
+ files = dict()
67
+
68
+ @staticmethod
69
+ def get_lmdbfile(file_path):
70
+ """Fetches a lmdb file"""
71
+ lmdb_files = LmdbLoader.files
72
+ if 'env' not in lmdb_files:
73
+ env = lmdb.open(file_path,
74
+ max_readers=1,
75
+ readonly=True,
76
+ lock=False,
77
+ readahead=False,
78
+ meminit=False)
79
+ with env.begin(write=False) as txn:
80
+ num_samples = txn.stat()['entries']
81
+ cache_file = '_cache_' + ''.join(
82
+ c for c in file_path if c in string.ascii_letters)
83
+ if os.path.isfile(cache_file):
84
+ keys = pickle.load(open(cache_file, "rb"))
85
+ else:
86
+ with env.begin(write=False) as txn:
87
+ keys = [key for key, _ in txn.cursor()]
88
+ pickle.dump(keys, open(cache_file, "wb"))
89
+ lmdb_files['env'] = env
90
+ lmdb_files['num_samples'] = num_samples
91
+ lmdb_files['keys'] = keys
92
+ return lmdb_files
93
+
94
+ @staticmethod
95
+ def get_image(file_path, idx):
96
+ """Decodes an image from a particular lmdb file"""
97
+ lmdb_files = LmdbLoader.get_lmdbfile(file_path)
98
+ env = lmdb_files['env']
99
+ keys = lmdb_files['keys']
100
+ with env.begin(write=False) as txn:
101
+ imagebuf = txn.get(keys[idx])
102
+ image_np = np.frombuffer(imagebuf, np.uint8)
103
+ image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
104
+ return image
105
+
106
+
107
+ class BaseDataset(Dataset):
108
+ """Defines the base dataset class.
109
+
110
+ This class supports loading data from a folder of images, an lmdb
111
+ database, an image list, or a zip archive. Images will be pre-processed with
112
+ the given `transform` function before being fed into the data loader.
113
+
114
+ NOTE: The loaded data will be returned as a dictionary, which must contain
115
+ the key `image`.
116
+ """
117
+ def __init__(self,
118
+ root_dir,
119
+ resolution,
120
+ data_format='dir',
121
+ image_list_path=None,
122
+ mirror=0.0,
123
+ progressive_resize=True,
124
+ crop_resize_resolution=-1,
125
+ transform=normalize_image,
126
+ transform_kwargs=None,
127
+ **_unused_kwargs):
128
+ """Initializes the dataset.
129
+
130
+ Args:
131
+ root_dir: Root directory containing the dataset.
132
+ resolution: The resolution of the returned image.
133
+ data_format: Format in which the dataset is stored. Supports `dir`,
134
+ `lmdb`, `list`, and `zip`. (default: `dir`)
135
+ image_list_path: Path to the image list. This field is required if
136
+ `data_format` is `list`. (default: None)
137
+ mirror: The probability to do mirror augmentation. (default: 0.0)
138
+ progressive_resize: Whether to resize images progressively.
139
+ (default: True)
140
+ crop_resize_resolution: The resolution of the output after crop
141
+ and resize. (default: -1)
142
+ transform: The transform function for pre-processing.
143
+ (default: `datasets.transforms.normalize_image()`)
144
+ transform_kwargs: The additional arguments for the `transform`
145
+ function. (default: None)
146
+
147
+ Raises:
148
+ ValueError: If the input `data_format` is not supported.
149
+ NotImplementedError: If the input `data_format` is not implemented.
150
+ """
151
+ if data_format.lower() not in _FORMATS_ALLOWED:
152
+ raise ValueError(f'Invalid data format `{data_format}`!\n'
153
+ f'Supported formats: {_FORMATS_ALLOWED}.')
154
+
155
+ self.root_dir = root_dir
156
+ self.resolution = resolution
157
+ self.data_format = data_format.lower()
158
+ self.image_list_path = image_list_path
159
+ self.mirror = np.clip(mirror, 0.0, 1.0)
160
+ self.progressive_resize = progressive_resize
161
+ self.crop_resize_resolution = crop_resize_resolution
162
+ self.transform = transform
163
+ self.transform_kwargs = transform_kwargs or dict()
164
+
165
+ if self.data_format == 'dir':
166
+ self.image_paths = sorted(os.listdir(self.root_dir))
167
+ self.num_samples = len(self.image_paths)
168
+ elif self.data_format == 'lmdb':
169
+ lmdb_file = LmdbLoader.get_lmdbfile(self.root_dir)
170
+ self.num_samples = lmdb_file['num_samples']
171
+ elif self.data_format == 'list':
172
+ self.metas = []
173
+ assert os.path.isfile(self.image_list_path)
174
+ with open(self.image_list_path) as f:
175
+ for line in f:
176
+ fields = line.rstrip().split(' ')
177
+ if len(fields) == 1:
178
+ self.metas.append((fields[0], None))
179
+ else:
180
+ assert len(fields) == 2
181
+ self.metas.append((fields[0], int(fields[1])))
182
+ self.num_samples = len(self.metas)
183
+ elif self.data_format == 'zip':
184
+ zip_file = ZipLoader.get_zipfile(self.root_dir)
185
+ image_paths = [f for f in zip_file.namelist()
186
+ if ('.jpg' in f or '.jpeg' in f or '.png' in f)]
187
+ self.image_paths = sorted(image_paths)
188
+ self.num_samples = len(self.image_paths)
189
+ else:
190
+ raise NotImplementedError(f'Not implemented data format '
191
+ f'`{self.data_format}`!')
192
+
193
+ def __len__(self):
194
+ return self.num_samples
195
+
196
+ def __getitem__(self, idx):
197
+ data = dict()
198
+
199
+ # Load data.
200
+ if self.data_format == 'dir':
201
+ image_path = self.image_paths[idx]
202
+ try:
203
+ in_file = open(os.path.join(self.root_dir, image_path), 'rb')
204
+ image = jpeg.decode(in_file.read())
205
+ except: # pylint: disable=bare-except
206
+ image = cv2.imread(os.path.join(self.root_dir, image_path))
207
+ elif self.data_format == 'lmdb':
208
+ image = LmdbLoader.get_image(self.root_dir, idx)
209
+ elif self.data_format == 'list':
210
+ image_path, label = self.metas[idx]
211
+ image = cv2.imread(os.path.join(self.root_dir, image_path))
212
+ label = None if label is None else torch.tensor(label, dtype=torch.long)
213
+ # data.update({'label': label})
214
+ elif self.data_format == 'zip':
215
+ image_path = self.image_paths[idx]
216
+ image = ZipLoader.get_image(self.root_dir, image_path)
217
+ else:
218
+ raise NotImplementedError(f'Not implemented data format '
219
+ f'`{self.data_format}`!')
220
+
221
+ image = image[:, :, ::-1] # Converts BGR (cv2) to RGB.
222
+
223
+ # Transform image.
224
+ if self.crop_resize_resolution > 0:
225
+ image = crop_resize_image(image, self.crop_resize_resolution)
226
+ if self.progressive_resize:
227
+ image = progressive_resize_image(image, self.resolution)
228
+ image = image.transpose(2, 0, 1).astype(np.float32)
229
+ if np.random.uniform() < self.mirror:
230
+ image = image[:, :, ::-1]  # Flips horizontally (image is in CHW format).
231
+ image = torch.FloatTensor(image.copy())
232
+ if not self.progressive_resize:
233
+ image = resize_image(image, self.resolution)
234
+
235
+ if self.transform is not None:
236
+ image = self.transform(image, **self.transform_kwargs)
237
+ data.update({'image': image})
238
+
239
+ return data
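As a quick orientation for the dataset class above, the following is a minimal, hypothetical usage sketch; the root directory, resolution, and loader settings are illustrative assumptions, and the import path assumes the `networks/genforce` directory is on `PYTHONPATH`.

```python
# Hypothetical usage sketch of `BaseDataset` (paths/settings are illustrative).
from torch.utils.data import DataLoader

from datasets.datasets import BaseDataset

dataset = BaseDataset(root_dir='/path/to/images',  # folder of .jpg/.png files
                      resolution=256,              # target output resolution
                      data_format='dir',           # or 'lmdb', 'list', 'zip'
                      mirror=0.5)                  # 50% horizontal-flip augmentation
loader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

batch = next(iter(loader))
images = batch['image']  # float tensor of shape [8, 3, 256, 256], roughly in [-1, 1]
```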
networks/genforce/datasets/distributed_sampler.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the distributed data sampler.
3
+
4
+ This file is mostly borrowed from `torch/utils/data/distributed.py`.
5
+
6
+ However, initializing the data loader and data sampler can sometimes be time
7
+ consuming (since it will load a large amount of data at one time). To avoid
8
+ re-initializing the data loader again and again, we modified the sampler to
9
+ support loading the data only once and then repeating the data loader.
10
+ Please use the class member `repeat` to control how many times you want the
11
+ data loader to repeat. After `repeat` times, the data will be re-loaded.
12
+
13
+ NOTE: The number of repeat times should not be very large, especially when there
14
+ are many samples in the dataset. We recommend setting `repeat = 500` for
15
+ datasets with ~50K samples.
16
+ """
17
+
18
+ # pylint: disable=line-too-long
19
+
20
+ import math
21
+ from typing import TypeVar, Optional, Iterator
22
+
23
+ import torch
24
+ from torch.utils.data import Sampler, Dataset
25
+ import torch.distributed as dist
26
+
27
+
28
+ T_co = TypeVar('T_co', covariant=True)
29
+
30
+
31
+ class DistributedSampler(Sampler):
32
+ r"""Sampler that restricts data loading to a subset of the dataset.
33
+
34
+ It is especially useful in conjunction with
35
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
36
+ process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
37
+ :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
38
+ original dataset that is exclusive to it.
39
+
40
+ .. note::
41
+ Dataset is assumed to be of constant size.
42
+
43
+ Arguments:
44
+ dataset: Dataset used for sampling.
45
+ num_replicas (int, optional): Number of processes participating in
46
+ distributed training. By default, :attr:`rank` is retrieved from the
47
+ current distributed group.
48
+ rank (int, optional): Rank of the current process within :attr:`num_replicas`.
49
+ By default, :attr:`rank` is retrieved from the current distributed
50
+ group.
51
+ shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
52
+ indices.
53
+ seed (int, optional): random seed used to shuffle the sampler if
54
+ :attr:`shuffle=True`. This number should be identical across all
55
+ processes in the distributed group. Default: ``0``.
56
+ drop_last (bool, optional): if ``True``, then the sampler will drop the
57
+ tail of the data to make it evenly divisible across the number of
58
+ replicas. If ``False``, the sampler will add extra indices to make
59
+ the data evenly divisible across the replicas. Default: ``False``.
60
+ current_iter (int, optional): Number of current iteration. Default: ``0``.
61
+ repeat (int, optional): Repeating number of the whole dataloader. Default: ``1000``.
62
+
63
+ .. warning::
64
+ In distributed mode, calling the :meth:`set_epoch` method at
65
+ the beginning of each epoch **before** creating the :class:`DataLoader` iterator
66
+ is necessary to make shuffling work properly across multiple epochs. Otherwise,
67
+ the same ordering will be always used.
68
+
69
+ """
70
+
71
+ def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
72
+ rank: Optional[int] = None, shuffle: bool = True,
73
+ seed: int = 0, drop_last: bool = False, current_iter: int = 0,
74
+ repeat: int = 1000) -> None:
75
+ super().__init__(None)
76
+ if num_replicas is None:
77
+ if not dist.is_available():
78
+ raise RuntimeError("Requires distributed package to be available")
79
+ num_replicas = dist.get_world_size()
80
+ if rank is None:
81
+ if not dist.is_available():
82
+ raise RuntimeError("Requires distributed package to be available")
83
+ rank = dist.get_rank()
84
+ self.dataset = dataset
85
+ self.num_replicas = num_replicas
86
+ self.rank = rank
87
+ self.iter = current_iter
88
+ self.drop_last = drop_last
89
+
90
+ # NOTE: self.dataset_length is `repeat X len(self.dataset)`
91
+ self.repeat = repeat
92
+ self.dataset_length = len(self.dataset) * self.repeat
93
+
94
+ if self.drop_last and self.dataset_length % self.num_replicas != 0:
95
+ # Split to nearest available length that is evenly divisible.
96
+ # This is to ensure each rank receives the same amount of data when
97
+ # using this Sampler.
98
+ self.num_samples = math.ceil(
99
+ (self.dataset_length - self.num_replicas) / self.num_replicas
100
+ )
101
+ else:
102
+ self.num_samples = math.ceil(self.dataset_length / self.num_replicas)
103
+
104
+
105
+ self.total_size = self.num_samples * self.num_replicas
106
+ self.shuffle = shuffle
107
+ self.seed = seed
108
+ self.__generate_indices__()
109
+
110
+ def __generate_indices__(self) -> None:
111
+ g = torch.Generator()
112
+ indices_bank = []
113
+ for iter_ in range(self.iter, self.iter + self.repeat):
114
+ g.manual_seed(self.seed + iter_)
115
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
116
+ indices_bank.extend(indices)
117
+ self.indices = indices_bank
118
+
119
+ def __iter__(self) -> Iterator[T_co]:
120
+ if self.shuffle:
121
+ # deterministically shuffle based on iter and seed
122
+ indices = self.indices
123
+ else:
124
+ indices = list(range(self.dataset_length))
125
+
126
+ if not self.drop_last:
127
+ # add extra samples to make it evenly divisible
128
+ indices += indices[:(self.total_size - len(indices))]
129
+ else:
130
+ # remove tail of data to make it evenly divisible.
131
+ indices = indices[:self.total_size]
132
+
133
+ # subsample
134
+ indices = indices[self.rank:self.total_size:self.num_replicas]
135
+ return iter(indices)
136
+
137
+ def __len__(self) -> int:
138
+ return self.num_samples
139
+
140
+ def __reset__(self, iteration: int) -> None:
141
+ self.iter = iteration
142
+ self.__generate_indices__()
143
+
144
+ # pylint: enable=line-too-long
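To illustrate the repeating behaviour described in the docstring above, here is a sketch that constructs the sampler without initializing `torch.distributed`, by passing `num_replicas` and `rank` explicitly; all values are assumptions for demonstration.

```python
# Hypothetical sketch of the repeating DistributedSampler (values are illustrative).
from torch.utils.data import DataLoader

from datasets.datasets import BaseDataset
from datasets.distributed_sampler import DistributedSampler

dataset = BaseDataset(root_dir='/path/to/images', resolution=256)
sampler = DistributedSampler(dataset,
                             num_replicas=2,  # total number of processes
                             rank=0,          # rank of this process
                             shuffle=True,
                             repeat=500)      # traverse the dataset 500 times per load
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

# One pass over `loader` now covers len(dataset) * 500 / 2 samples on this rank,
# so the loader does not need to be re-created between epochs.
```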
networks/genforce/datasets/libturbojpeg.so.0 ADDED
Binary file (396 kB). View file
 
networks/genforce/datasets/transforms.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Contains transform functions."""
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ __all__ = [
13
+ 'crop_resize_image', 'progressive_resize_image', 'resize_image',
14
+ 'normalize_image', 'normalize_latent_code', 'ImageResizing',
15
+ 'ImageNormalization', 'LatentCodeNormalization',
16
+ ]
17
+
18
+
19
+ def crop_resize_image(image, size):
20
+ """Crops a square patch and then resizes it to the given size.
21
+
22
+ Args:
23
+ image: The input image to crop and resize.
24
+ size: An integer, indicating the target size.
25
+
26
+ Returns:
27
+ An image with target size.
28
+
29
+ Raises:
30
+ TypeError: If the input `image` is not with type `numpy.ndarray`.
31
+ ValueError: If the input `image` is not with shape [H, W, C].
32
+ """
33
+ if not isinstance(image, np.ndarray):
34
+ raise TypeError(f'Input image should be with type `numpy.ndarray`, '
35
+ f'but `{type(image)}` is received!')
36
+ if image.ndim != 3:
37
+ raise ValueError(f'Input image should be with shape [H, W, C], '
38
+ f'but `{image.shape}` is received!')
39
+
40
+ height, width, channel = image.shape
41
+ short_side = min(height, width)
42
+ image = image[(height - short_side) // 2:(height + short_side) // 2,
43
+ (width - short_side) // 2:(width + short_side) // 2]
44
+ pil_image = PIL.Image.fromarray(image)
45
+ pil_image = pil_image.resize((size, size), PIL.Image.ANTIALIAS)
46
+ image = np.asarray(pil_image)
47
+ assert image.shape == (size, size, channel)
48
+ return image
49
+
50
+
51
+ def progressive_resize_image(image, size):
52
+ """Resizes image to target size progressively.
53
+
54
+ Different from normal resize, this function will reduce the image size
55
+ progressively. In each step, the maximum reduction factor is 2.
56
+
57
+ NOTE: This function can only handle square images, and can only be used for
58
+ downsampling.
59
+
60
+ Args:
61
+ image: The input (square) image to resize.
62
+ size: An integer, indicating the target size.
63
+
64
+ Returns:
65
+ An image with target size.
66
+
67
+ Raises:
68
+ TypeError: If the input `image` is not with type `numpy.ndarray`.
69
+ ValueError: If the input `image` is not with shape [H, W, C].
70
+ """
71
+ if not isinstance(image, np.ndarray):
72
+ raise TypeError(f'Input image should be with type `numpy.ndarray`, '
73
+ f'but `{type(image)}` is received!')
74
+ if image.ndim != 3:
75
+ raise ValueError(f'Input image should be with shape [H, W, C], '
76
+ f'but `{image.shape}` is received!')
77
+
78
+ height, width, channel = image.shape
79
+ assert height == width
80
+ assert height >= size
81
+ num_iters = int(np.log2(height) - np.log2(size))
82
+ for _ in range(num_iters):
83
+ height = max(height // 2, size)
84
+ image = cv2.resize(image, (height, height),
85
+ interpolation=cv2.INTER_LINEAR)
86
+ assert image.shape == (size, size, channel)
87
+ return image
88
+
89
+
90
+ def resize_image(image, size):
91
+ """Resizes image to target size.
92
+
93
+ NOTE: We use adaptive average pooling for image resizing. Instead of bilinear
94
+ interpolation, average pooling is able to acquire information from more
95
+ pixels, such that the resized results are of higher quality.
96
+
97
+ Args:
98
+ image: The input image tensor, with shape [C, H, W], to resize.
99
+ size: An integer or a tuple of integer, indicating the target size.
100
+
101
+ Returns:
102
+ An image tensor with target size.
103
+
104
+ Raises:
105
+ TypeError: If the input `image` is not with type `torch.Tensor`.
106
+ ValueError: If the input `image` is not with shape [C, H, W].
107
+ """
108
+ if not isinstance(image, torch.Tensor):
109
+ raise TypeError(f'Input image should be with type `torch.Tensor`, '
110
+ f'but `{type(image)}` is received!')
111
+ if image.ndim != 3:
112
+ raise ValueError(f'Input image should be with shape [C, H, W], '
113
+ f'but `{image.shape}` is received!')
114
+
115
+ image = F.adaptive_avg_pool2d(image.unsqueeze(0), size).squeeze(0)
116
+ return image
117
+
118
+
119
+ def normalize_image(image, mean=127.5, std=127.5):
120
+ """Normalizes image by subtracting mean and dividing std.
121
+
122
+ Args:
123
+ image: The input image tensor to normalize.
124
+ mean: The mean value to subtract from the input tensor. (default: 127.5)
125
+ std: The standard deviation to normalize the input tensor. (default:
126
+ 127.5)
127
+
128
+ Returns:
129
+ A normalized image tensor.
130
+
131
+ Raises:
132
+ TypeError: If the input `image` is not with type `torch.Tensor`.
133
+ """
134
+ if not isinstance(image, torch.Tensor):
135
+ raise TypeError(f'Input image should be with type `torch.Tensor`, '
136
+ f'but `{type(image)}` is received!')
137
+ out = (image - mean) / std
138
+ return out
139
+
140
+
141
+ def normalize_latent_code(latent_code, adjust_norm=True):
142
+ """Normalizes latent code.
143
+
144
+ NOTE: The latent code will always be normalized along the last axis.
145
+ Meanwhile, if `adjust_norm` is set as `True`, the norm of the result will be
146
+ adjusted to `sqrt(latent_code.shape[-1])` in order to avoid too small a value.
147
+
148
+ Args:
149
+ latent_code: The input latent code tensor to normalize.
150
+ adjust_norm: Whether to adjust the norm of the output. (default: True)
151
+
152
+ Returns:
153
+ A normalized latent code tensor.
154
+
155
+ Raises:
156
+ TypeError: If the input `latent_code` is not with type `torch.Tensor`.
157
+ """
158
+ if not isinstance(latent_code, torch.Tensor):
159
+ raise TypeError(f'Input latent code should be with type '
160
+ f'`torch.Tensor`, but `{type(latent_code)}` is '
161
+ f'received!')
162
+ dim = latent_code.shape[-1]
163
+ norm = latent_code.pow(2).sum(-1, keepdim=True).pow(0.5)
164
+ out = latent_code / norm
165
+ if adjust_norm:
166
+ out = out * (dim ** 0.5)
167
+ return out
168
+
169
+
170
+ class ImageResizing(nn.Module):
171
+ """Implements the image resizing layer."""
172
+
173
+ def __init__(self, size):
174
+ super().__init__()
175
+ self.size = size
176
+
177
+ def forward(self, image):
178
+ return resize_image(image, self.size)
179
+
180
+
181
+ class ImageNormalization(nn.Module):
182
+ """Implements the image normalization layer."""
183
+
184
+ def __init__(self, mean=127.5, std=127.5):
185
+ super().__init__()
186
+ self.mean = mean
187
+ self.std = std
188
+
189
+ def forward(self, image):
190
+ return normalize_image(image, self.mean, self.std)
191
+
192
+
193
+ class LatentCodeNormalization(nn.Module):
194
+ """Implements the latent code normalization layer."""
195
+
196
+ def __init__(self, adjust_norm=True):
197
+ super().__init__()
198
+ self.adjust_norm = adjust_norm
199
+
200
+ def forward(self, latent_code):
201
+ return normalize_latent_code(latent_code, self.adjust_norm)
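A short sketch of how the transform functions above chain together on a single image; the input array here is random data standing in for a decoded image, and the resolutions are arbitrary examples.

```python
# Minimal sketch of the transform utilities above (input is a made-up HxWxC array).
import numpy as np
import torch

from datasets.transforms import crop_resize_image, normalize_image, resize_image

raw = np.random.randint(0, 256, size=(720, 1280, 3), dtype=np.uint8)  # H x W x C
square = crop_resize_image(raw, 256)                            # center-crop, then resize
tensor = torch.from_numpy(square.transpose(2, 0, 1).copy()).float()  # to C x H x W
tensor = resize_image(tensor, 128)                              # adaptive-average-pool resize
tensor = normalize_image(tensor)                                # (x - 127.5) / 127.5
```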
networks/genforce/metrics/README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Evaluation Metrics
2
+
3
+ Frechet Inception Distance (FID) is commonly used to evaluate generative models. It employs an [Inception Model](https://arxiv.org/abs/1512.00567) (pretrained on ImageNet) to extract features from both real and synthesized images.
4
+
5
+ ## Inception Model
6
+
7
+ Models such as [PGGAN](https://github.com/tkarras/progressive_growing_of_gans) and [StyleGAN](https://github.com/NVlabs/stylegan) use the Inception model from the [TensorFlow Models](https://github.com/tensorflow/models) repository, whose implementation is slightly different from that of `torchvision`. Hence, to make the evaluation metric comparable across training frameworks (i.e., PyTorch and TensorFlow), we modify `torchvision/models/inception.py` into `inception.py`. The ported pre-trained weights are borrowed from [this repo](https://github.com/mseitzer/pytorch-fid).
8
+
9
+ **NOTE:** We also support using the model from `torchvision` to compute the FID. However, please be aware that the FID value from `torchvision` is usually ~1.5 smaller than that from the TensorFlow model.
10
+
11
+ Please use the following code to choose which model to use.
12
+
13
+ ```python
14
+ from metrics.inception import build_inception_model
15
+
16
+ inception_model_tf = build_inception_model(align_tf=True)
17
+ inception_model_pth = build_inception_model(align_tf=False)
18
+ ```
networks/genforce/metrics/__init__.py ADDED
File without changes
networks/genforce/metrics/fid.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the functions to compute Frechet Inception Distance (FID).
3
+
4
+ FID metric is introduced in paper
5
+
6
+ GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash
7
+ Equilibrium. Heusel et al. NeurIPS 2017.
8
+
9
+ See details at https://arxiv.org/pdf/1706.08500.pdf
10
+ """
11
+
12
+ import numpy as np
13
+ import scipy.linalg
14
+
15
+ __all__ = ['extract_feature', 'compute_fid']
16
+
17
+
18
+ def extract_feature(inception_model, images):
19
+ """Extracts feature from input images with given model.
20
+
21
+ NOTE: The input images are assumed to be with pixel range [-1, 1].
22
+
23
+ Args:
24
+ inception_model: The model used to extract features.
25
+ images: The input image tensor to extract features from.
26
+
27
+ Returns:
28
+ A `numpy.ndarray`, containing the extracted features.
29
+ """
30
+ features = inception_model(images, output_logits=False)
31
+ features = features.detach().cpu().numpy()
32
+ assert features.ndim == 2 and features.shape[1] == 2048
33
+ return features
34
+
35
+
36
+ def compute_fid(fake_features, real_features):
37
+ """Computes FID based on the features extracted from fake and real data.
38
+
39
+ Given the mean and covariance (m_f, C_f) of fake data and (m_r, C_r) of real
40
+ data, the FID metric can be computed by
41
+
42
+ d^2 = ||m_f - m_r||_2^2 + Tr(C_f + C_r - 2(C_f C_r)^0.5)
43
+
44
+ Args:
45
+ fake_features: The features extracted from fake data.
46
+ real_features: The features extracted from real data.
47
+
48
+ Returns:
49
+ A real number, suggesting the FID value.
50
+ """
51
+
52
+ m_f = np.mean(fake_features, axis=0)
53
+ C_f = np.cov(fake_features, rowvar=False)
54
+ m_r = np.mean(real_features, axis=0)
55
+ C_r = np.cov(real_features, rowvar=False)
56
+
57
+ fid = np.sum((m_f - m_r) ** 2) + np.trace(
58
+ C_f + C_r - 2 * scipy.linalg.sqrtm(np.dot(C_f, C_r)))
59
+ return np.real(fid)
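Putting the two helpers together with the Inception model from `metrics/inception.py`, a hypothetical end-to-end FID computation could look as follows; the random tensors merely stand in for real and generated batches in [-1, 1], and a meaningful FID would require far more samples.

```python
# Illustrative FID sketch; random tensors stand in for real/generated images.
import torch

from metrics.inception import build_inception_model
from metrics.fid import extract_feature, compute_fid

inception = build_inception_model(align_tf=True)

fake_images = torch.rand(64, 3, 256, 256) * 2 - 1  # placeholder generated batch
real_images = torch.rand(64, 3, 256, 256) * 2 - 1  # placeholder real batch

fake_features = extract_feature(inception, fake_images)  # [64, 2048]
real_features = extract_feature(inception, real_images)  # [64, 2048]
print('FID:', compute_fid(fake_features, real_features))
```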
networks/genforce/metrics/inception.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Contains the Inception V3 model.
3
+
4
+ This file is mostly borrowed from `torchvision/models/inception.py`.
5
+
6
+ Inception model is widely used to compute FID or IS metric for evaluating
7
+ generative models. However, the pre-trained models from torchvision are slightly
8
+ different from the TensorFlow version.
9
+
10
+ http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
11
+
12
+ In particular:
13
+
14
+ (1) The number of classes in the TensorFlow model is 1008 instead of 1000.
15
+ (2) The avg_pool() layers in the TensorFlow model do not include the padded zeros.
16
+ (3) The last Inception E block in the TensorFlow model uses max_pool() instead of
17
+ avg_pool().
18
+
19
+ Hence, to align the evaluation results with those from the TensorFlow
20
+ implementation, we modified the inception model to support both versions. Please
21
+ use the `align_tf` argument to control the version.
22
+ """
23
+
24
+ # pylint: disable=line-too-long
25
+ # pylint: disable=missing-function-docstring
26
+ # pylint: disable=missing-class-docstring
27
+ # pylint: disable=super-with-arguments
28
+ # pylint: disable=consider-merging-isinstance
29
+ # pylint: disable=import-outside-toplevel
30
+ # pylint: disable=no-else-return
31
+
32
+ from collections import namedtuple
33
+ import warnings
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from torch.jit.annotations import Optional
38
+ from torch import Tensor
39
+ from torchvision.models.utils import load_state_dict_from_url
40
+
41
+
42
+ __all__ = ['build_inception_model', 'Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']
43
+
44
+ model_urls = {
45
+ # Inception v3 ported from TensorFlow
46
+ 'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
47
+
48
+ # Inception v3 ported from http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
49
+ # This model is provided by https://github.com/mseitzer/pytorch-fid
50
+ 'tf_inception_v3': 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
51
+ }
52
+
53
+ InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
54
+ InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}
55
+
56
+ # Script annotations failed with _GoogleNetOutputs = namedtuple ...
57
+ # _InceptionOutputs set here for backwards compat
58
+ _InceptionOutputs = InceptionOutputs
59
+
60
+
61
+ def build_inception_model(align_tf=True):
62
+ """Builds Inception V3 model.
63
+
64
+ This model is particularly used for inference, such that `requires_grad` and
65
+ the training `mode` will both be set to `False`.
66
+
67
+ Args:
68
+ align_tf: Whether to align the implementation with TensorFlow version. (default: True)
69
+
70
+ Returns:
71
+ A `torch.nn.Module` with pre-trained weight.
72
+ """
73
+ if align_tf:
74
+ num_classes = 1008
75
+ model_url = model_urls['tf_inception_v3']
76
+ else:
77
+ num_classes = 1000
78
+ model_url = model_urls['inception_v3_google']
79
+ model = Inception3(num_classes=num_classes,
80
+ aux_logits=False,
81
+ transform_input=False,
82
+ align_tf=align_tf)
83
+ state_dict = load_state_dict_from_url(model_url)
84
+ model.load_state_dict(state_dict, strict=False)
85
+ model.eval()
86
+ for param in model.parameters():
87
+ param.requires_grad = False
88
+ return model
89
+
90
+
91
+ def inception_v3(pretrained=False, progress=True, **kwargs):
92
+ r"""Inception v3 model architecture from
93
+ `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.
94
+
95
+ .. note::
96
+ **Important**: In contrast to the other models the inception_v3 expects tensors with a size of
97
+ N x 3 x 299 x 299, so ensure your images are sized accordingly.
98
+
99
+ Args:
100
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
101
+ progress (bool): If True, displays a progress bar of the download to stderr
102
+ aux_logits (bool): If True, add an auxiliary branch that can improve training.
103
+ Default: *True*
104
+ transform_input (bool): If True, preprocesses the input according to the method with which it
105
+ was trained on ImageNet. Default: *False*
106
+ """
107
+ if pretrained:
108
+ if 'transform_input' not in kwargs:
109
+ kwargs['transform_input'] = True
110
+ if 'aux_logits' in kwargs:
111
+ original_aux_logits = kwargs['aux_logits']
112
+ kwargs['aux_logits'] = True
113
+ else:
114
+ original_aux_logits = True
115
+ model = Inception3(**kwargs)
116
+ state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
117
+ progress=progress)
118
+ model.load_state_dict(state_dict)
119
+ if not original_aux_logits:
120
+ model.aux_logits = False
121
+ del model.AuxLogits
122
+ return model
123
+
124
+ return Inception3(**kwargs)
125
+
126
+
127
+ class Inception3(nn.Module):
128
+
129
+ def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
130
+ inception_blocks=None, init_weights=True, align_tf=True):
131
+ super(Inception3, self).__init__()
132
+ if inception_blocks is None:
133
+ inception_blocks = [
134
+ BasicConv2d, InceptionA, InceptionB, InceptionC,
135
+ InceptionD, InceptionE, InceptionAux
136
+ ]
137
+ assert len(inception_blocks) == 7
138
+ conv_block = inception_blocks[0]
139
+ inception_a = inception_blocks[1]
140
+ inception_b = inception_blocks[2]
141
+ inception_c = inception_blocks[3]
142
+ inception_d = inception_blocks[4]
143
+ inception_e = inception_blocks[5]
144
+ inception_aux = inception_blocks[6]
145
+
146
+ self.aux_logits = aux_logits
147
+ self.transform_input = transform_input
148
+ self.align_tf = align_tf
149
+ self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
150
+ self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
151
+ self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
152
+ self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
153
+ self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
154
+ self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
155
+ self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
156
+ self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
157
+ self.Mixed_6a = inception_b(288)
158
+ self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
159
+ self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
160
+ self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
161
+ self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
162
+ if aux_logits:
163
+ self.AuxLogits = inception_aux(768, num_classes)
164
+ self.Mixed_7a = inception_d(768)
165
+ self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
166
+ self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
167
+ self.fc = nn.Linear(2048, num_classes)
168
+ if init_weights:
169
+ for m in self.modules():
170
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
171
+ import scipy.stats as stats
172
+ stddev = m.stddev if hasattr(m, 'stddev') else 0.1
173
+ X = stats.truncnorm(-2, 2, scale=stddev)
174
+ values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
175
+ values = values.view(m.weight.size())
176
+ with torch.no_grad():
177
+ m.weight.copy_(values)
178
+ elif isinstance(m, nn.BatchNorm2d):
179
+ nn.init.constant_(m.weight, 1)
180
+ nn.init.constant_(m.bias, 0)
181
+
182
+ def _transform_input(self, x):
183
+ if self.transform_input:
184
+ x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
185
+ x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
186
+ x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
187
+ x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
188
+ return x
189
+
190
+ def _forward(self, x, output_logits=False):
191
+ # Resize to 299x299 if necessary
192
+ if x.shape[2] != 299 or x.shape[3] != 299:
193
+ x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)
194
+
195
+ # N x 3 x 299 x 299
196
+ x = self.Conv2d_1a_3x3(x)
197
+ # N x 32 x 149 x 149
198
+ x = self.Conv2d_2a_3x3(x)
199
+ # N x 32 x 147 x 147
200
+ x = self.Conv2d_2b_3x3(x)
201
+ # N x 64 x 147 x 147
202
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
203
+ # N x 64 x 73 x 73
204
+ x = self.Conv2d_3b_1x1(x)
205
+ # N x 80 x 73 x 73
206
+ x = self.Conv2d_4a_3x3(x)
207
+ # N x 192 x 71 x 71
208
+ x = F.max_pool2d(x, kernel_size=3, stride=2)
209
+ # N x 192 x 35 x 35
210
+ x = self.Mixed_5b(x)
211
+ # N x 256 x 35 x 35
212
+ x = self.Mixed_5c(x)
213
+ # N x 288 x 35 x 35
214
+ x = self.Mixed_5d(x)
215
+ # N x 288 x 35 x 35
216
+ x = self.Mixed_6a(x)
217
+ # N x 768 x 17 x 17
218
+ x = self.Mixed_6b(x)
219
+ # N x 768 x 17 x 17
220
+ x = self.Mixed_6c(x)
221
+ # N x 768 x 17 x 17
222
+ x = self.Mixed_6d(x)
223
+ # N x 768 x 17 x 17
224
+ x = self.Mixed_6e(x)
225
+ # N x 768 x 17 x 17
226
+ aux_defined = self.training and self.aux_logits
227
+ if aux_defined:
228
+ aux = self.AuxLogits(x)
229
+ else:
230
+ aux = None
231
+ # N x 768 x 17 x 17
232
+ x = self.Mixed_7a(x)
233
+ # N x 1280 x 8 x 8
234
+ x = self.Mixed_7b(x)
235
+ # N x 2048 x 8 x 8
236
+ x = self.Mixed_7c(x)
237
+ # N x 2048 x 8 x 8
238
+ # Adaptive average pooling
239
+ x = F.adaptive_avg_pool2d(x, (1, 1))
240
+ # N x 2048 x 1 x 1
241
+ x = F.dropout(x, training=self.training)
242
+ # N x 2048 x 1 x 1
243
+ x = torch.flatten(x, 1)
244
+ # N x 2048
245
+ if output_logits:
246
+ x = self.fc(x)
247
+ # N x 1000 (num_classes)
248
+ return x, aux
249
+
250
+ @torch.jit.unused
251
+ def eager_outputs(self, x, aux):
252
+ # type: (Tensor, Optional[Tensor]) -> InceptionOutputs
253
+ if self.training and self.aux_logits:
254
+ return InceptionOutputs(x, aux)
255
+ else:
256
+ return x
257
+
258
+ def forward(self, x, output_logits=False):
259
+ x = self._transform_input(x)
260
+ x, aux = self._forward(x, output_logits)
261
+ aux_defined = self.training and self.aux_logits
262
+ if torch.jit.is_scripting():
263
+ if not aux_defined:
264
+ warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
265
+ return InceptionOutputs(x, aux)
266
+ else:
267
+ return self.eager_outputs(x, aux)
268
+
269
+
270
+ class InceptionA(nn.Module):
271
+
272
+ def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
273
+ super(InceptionA, self).__init__()
274
+ if conv_block is None:
275
+ conv_block = BasicConv2d
276
+ self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
277
+
278
+ self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
279
+ self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
280
+
281
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
282
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
283
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
284
+
285
+ self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
286
+ self.pool_include_padding = not align_tf
287
+
288
+ def _forward(self, x):
289
+ branch1x1 = self.branch1x1(x)
290
+
291
+ branch5x5 = self.branch5x5_1(x)
292
+ branch5x5 = self.branch5x5_2(branch5x5)
293
+
294
+ branch3x3dbl = self.branch3x3dbl_1(x)
295
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
296
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
297
+
298
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
299
+ count_include_pad=self.pool_include_padding)
300
+ branch_pool = self.branch_pool(branch_pool)
301
+
302
+ outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
303
+ return outputs
304
+
305
+ def forward(self, x):
306
+ outputs = self._forward(x)
307
+ return torch.cat(outputs, 1)
308
+
309
+
310
+ class InceptionB(nn.Module):
311
+
312
+ def __init__(self, in_channels, conv_block=None):
313
+ super(InceptionB, self).__init__()
314
+ if conv_block is None:
315
+ conv_block = BasicConv2d
316
+ self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
317
+
318
+ self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
319
+ self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
320
+ self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
321
+
322
+ def _forward(self, x):
323
+ branch3x3 = self.branch3x3(x)
324
+
325
+ branch3x3dbl = self.branch3x3dbl_1(x)
326
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
327
+ branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
328
+
329
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
330
+
331
+ outputs = [branch3x3, branch3x3dbl, branch_pool]
332
+ return outputs
333
+
334
+ def forward(self, x):
335
+ outputs = self._forward(x)
336
+ return torch.cat(outputs, 1)
337
+
338
+
339
+ class InceptionC(nn.Module):
340
+
341
+ def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
342
+ super(InceptionC, self).__init__()
343
+ if conv_block is None:
344
+ conv_block = BasicConv2d
345
+ self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
346
+
347
+ c7 = channels_7x7
348
+ self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
349
+ self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
350
+ self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
351
+
352
+ self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
353
+ self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
354
+ self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
355
+ self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
356
+ self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
357
+
358
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
359
+ self.pool_include_padding = not align_tf
360
+
361
+ def _forward(self, x):
362
+ branch1x1 = self.branch1x1(x)
363
+
364
+ branch7x7 = self.branch7x7_1(x)
365
+ branch7x7 = self.branch7x7_2(branch7x7)
366
+ branch7x7 = self.branch7x7_3(branch7x7)
367
+
368
+ branch7x7dbl = self.branch7x7dbl_1(x)
369
+ branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
370
+ branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
371
+ branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
372
+ branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
373
+
374
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
375
+ count_include_pad=self.pool_include_padding)
376
+ branch_pool = self.branch_pool(branch_pool)
377
+
378
+ outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
379
+ return outputs
380
+
381
+ def forward(self, x):
382
+ outputs = self._forward(x)
383
+ return torch.cat(outputs, 1)
384
+
385
+
386
+ class InceptionD(nn.Module):
387
+
388
+ def __init__(self, in_channels, conv_block=None):
389
+ super(InceptionD, self).__init__()
390
+ if conv_block is None:
391
+ conv_block = BasicConv2d
392
+ self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
393
+ self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
394
+
395
+ self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
396
+ self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
397
+ self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
398
+ self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
399
+
400
+ def _forward(self, x):
401
+ branch3x3 = self.branch3x3_1(x)
402
+ branch3x3 = self.branch3x3_2(branch3x3)
403
+
404
+ branch7x7x3 = self.branch7x7x3_1(x)
405
+ branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
406
+ branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
407
+ branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
408
+
409
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
410
+ outputs = [branch3x3, branch7x7x3, branch_pool]
411
+ return outputs
412
+
413
+ def forward(self, x):
414
+ outputs = self._forward(x)
415
+ return torch.cat(outputs, 1)
416
+
417
+
418
+ class InceptionE(nn.Module):
419
+
420
+ def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
421
+ super(InceptionE, self).__init__()
422
+ if conv_block is None:
423
+ conv_block = BasicConv2d
424
+ self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
425
+
426
+ self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
427
+ self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
428
+ self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
429
+
430
+ self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
431
+ self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
432
+ self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
433
+ self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
434
+
435
+ self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
436
+ self.pool_include_padding = not align_tf
437
+ self.use_max_pool = use_max_pool
438
+
439
+ def _forward(self, x):
440
+ branch1x1 = self.branch1x1(x)
441
+
442
+ branch3x3 = self.branch3x3_1(x)
443
+ branch3x3 = [
444
+ self.branch3x3_2a(branch3x3),
445
+ self.branch3x3_2b(branch3x3),
446
+ ]
447
+ branch3x3 = torch.cat(branch3x3, 1)
448
+
449
+ branch3x3dbl = self.branch3x3dbl_1(x)
450
+ branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
451
+ branch3x3dbl = [
452
+ self.branch3x3dbl_3a(branch3x3dbl),
453
+ self.branch3x3dbl_3b(branch3x3dbl),
454
+ ]
455
+ branch3x3dbl = torch.cat(branch3x3dbl, 1)
456
+
457
+ if self.use_max_pool:
458
+ branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
459
+ else:
460
+ branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
461
+ count_include_pad=self.pool_include_padding)
462
+ branch_pool = self.branch_pool(branch_pool)
463
+
464
+ outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
465
+ return outputs
466
+
467
+ def forward(self, x):
468
+ outputs = self._forward(x)
469
+ return torch.cat(outputs, 1)
470
+
471
+
472
+ class InceptionAux(nn.Module):
473
+
474
+ def __init__(self, in_channels, num_classes, conv_block=None):
475
+ super(InceptionAux, self).__init__()
476
+ if conv_block is None:
477
+ conv_block = BasicConv2d
478
+ self.conv0 = conv_block(in_channels, 128, kernel_size=1)
479
+ self.conv1 = conv_block(128, 768, kernel_size=5)
480
+ self.conv1.stddev = 0.01
481
+ self.fc = nn.Linear(768, num_classes)
482
+ self.fc.stddev = 0.001
483
+
484
+ def forward(self, x):
485
+ # N x 768 x 17 x 17
486
+ x = F.avg_pool2d(x, kernel_size=5, stride=3)
487
+ # N x 768 x 5 x 5
488
+ x = self.conv0(x)
489
+ # N x 128 x 5 x 5
490
+ x = self.conv1(x)
491
+ # N x 768 x 1 x 1
492
+ # Adaptive average pooling
493
+ x = F.adaptive_avg_pool2d(x, (1, 1))
494
+ # N x 768 x 1 x 1
495
+ x = torch.flatten(x, 1)
496
+ # N x 768
497
+ x = self.fc(x)
498
+ # N x 1000
499
+ return x
500
+
501
+
502
+ class BasicConv2d(nn.Module):
503
+
504
+ def __init__(self, in_channels, out_channels, **kwargs):
505
+ super(BasicConv2d, self).__init__()
506
+ self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
507
+ self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
508
+
509
+ def forward(self, x):
510
+ x = self.conv(x)
511
+ x = self.bn(x)
512
+ return F.relu(x, inplace=True)
513
+
514
+ # pylint: enable=line-too-long
515
+ # pylint: enable=missing-function-docstring
516
+ # pylint: enable=missing-class-docstring
517
+ # pylint: enable=super-with-arguments
518
+ # pylint: enable=consider-merging-isinstance
519
+ # pylint: enable=import-outside-toplevel
520
+ # pylint: enable=no-else-return
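For reference, a small sketch of the two output modes of the model defined above; the input tensor is an illustrative assumption, and the shapes follow from the code (2048-dim pooled features, or 1008-way logits when `align_tf=True`).

```python
# Quick shape check (illustrative inputs): pooled features vs. class logits.
import torch

from metrics.inception import build_inception_model

model = build_inception_model(align_tf=True)  # TF-ported weights, 1008 classes
x = torch.rand(4, 3, 299, 299) * 2 - 1        # inputs assumed to lie in [-1, 1]
features = model(x, output_logits=False)      # shape [4, 2048]
logits = model(x, output_logits=True)         # shape [4, 1008]
```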
networks/genforce/models/__init__.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python3.7
2
+ """Collects all available models together."""
3
+
4
+ from .model_zoo import MODEL_ZOO
5
+ from .pggan_generator import PGGANGenerator
6
+ from .pggan_discriminator import PGGANDiscriminator
7
+ from .stylegan_generator import StyleGANGenerator
8
+ from .stylegan_discriminator import StyleGANDiscriminator
9
+ from .stylegan2_generator import StyleGAN2Generator
10
+ from .stylegan2_discriminator import StyleGAN2Discriminator
11
+ from .encoder import EncoderNet
12
+ from .perceptual_model import PerceptualModel
13
+
14
+ __all__ = [
15
+ 'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator',
16
+ 'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator',
17
+ 'EncoderNet', 'PerceptualModel', 'build_generator', 'build_discriminator',
18
+ 'build_encoder', 'build_perceptual', 'build_model'
19
+ ]
20
+
21
+ _GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2']
22
+ _MODULES_ALLOWED = ['generator', 'discriminator', 'encoder', 'perceptual']
23
+
24
+
25
+ def build_generator(gan_type, resolution, **kwargs):
26
+ """Builds generator by GAN type.
27
+
28
+ Args:
29
+ gan_type: GAN type to which the generator belongs.
30
+ resolution: Synthesis resolution.
31
+ **kwargs: Additional arguments to build the generator.
32
+
33
+ Raises:
34
+ ValueError: If the `gan_type` is not supported.
35
+ NotImplementedError: If the `gan_type` is not implemented.
36
+ """
37
+ if gan_type not in _GAN_TYPES_ALLOWED:
38
+ raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
39
+ f'Types allowed: {_GAN_TYPES_ALLOWED}.')
40
+
41
+ if gan_type == 'pggan':
42
+ return PGGANGenerator(resolution, **kwargs)
43
+ if gan_type == 'stylegan':
44
+ return StyleGANGenerator(resolution, **kwargs)
45
+ if gan_type == 'stylegan2':
46
+ return StyleGAN2Generator(resolution, **kwargs)
47
+ raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')
48
+
49
+
50
+ def build_discriminator(gan_type, resolution, **kwargs):
51
+ """Builds discriminator by GAN type.
52
+
53
+ Args:
54
+ gan_type: GAN type to which the discriminator belongs.
55
+ resolution: Synthesis resolution.
56
+ **kwargs: Additional arguments to build the discriminator.
57
+
58
+ Raises:
59
+ ValueError: If the `gan_type` is not supported.
60
+ NotImplementedError: If the `gan_type` is not implemented.
61
+ """
62
+ if gan_type not in _GAN_TYPES_ALLOWED:
63
+ raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
64
+ f'Types allowed: {_GAN_TYPES_ALLOWED}.')
65
+
66
+ if gan_type == 'pggan':
67
+ return PGGANDiscriminator(resolution, **kwargs)
68
+ if gan_type == 'stylegan':
69
+ return StyleGANDiscriminator(resolution, **kwargs)
70
+ if gan_type == 'stylegan2':
71
+ return StyleGAN2Discriminator(resolution, **kwargs)
72
+ raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')
73
+
74
+
75
+ def build_encoder(gan_type, resolution, **kwargs):
76
+ """Builds encoder by GAN type.
77
+
78
+ Args:
79
+ gan_type: GAN type to which the encoder belongs.
80
+ resolution: Input resolution for encoder.
81
+ **kwargs: Additional arguments to build the encoder.
82
+
83
+ Raises:
84
+ ValueError: If the `gan_type` is not supported.
85
+ NotImplementedError: If the `gan_type` is not implemented.
86
+ """
87
+ if gan_type not in _GAN_TYPES_ALLOWED:
88
+ raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
89
+ f'Types allowed: {_GAN_TYPES_ALLOWED}.')
90
+
91
+ if gan_type in ['stylegan', 'stylegan2']:
92
+ return EncoderNet(resolution, **kwargs)
93
+
94
+ raise NotImplementedError(f'Unsupported GAN type `{gan_type}` for encoder!')
95
+
96
+
97
+ def build_perceptual(**kwargs):
98
+ """Builds perceptual model.
99
+
100
+ Args:
101
+ **kwargs: Additional arguments to build the perceptual model.
102
+ """
103
+ return PerceptualModel(**kwargs)
104
+
105
+
106
+ def build_model(gan_type, module, resolution, **kwargs):
107
+ """Builds a GAN module (generator/discriminator/etc).
108
+
109
+ Args:
110
+ gan_type: GAN type to which the model belongs.
111
+ module: GAN module to build, such as generator or discriminator.
112
+ resolution: Synthesis resolution.
113
+ **kwargs: Additional arguments to build the module.
114
+
115
+ Raises:
116
+ ValueError: If the `module` is not supported.
117
+ NotImplementedError: If the `module` is not implemented.
118
+ """
119
+ if module not in _MODULES_ALLOWED:
120
+ raise ValueError(f'Invalid module: `{module}`!\n'
121
+ f'Modules allowed: {_MODULES_ALLOWED}.')
122
+
123
+ if module == 'generator':
124
+ return build_generator(gan_type, resolution, **kwargs)
125
+ if module == 'discriminator':
126
+ return build_discriminator(gan_type, resolution, **kwargs)
127
+ if module == 'encoder':
128
+ return build_encoder(gan_type, resolution, **kwargs)
129
+ if module == 'perceptual':
130
+ return build_perceptual(**kwargs)
131
+ raise NotImplementedError(f'Unsupported module `{module}`!')
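As a closing sketch of the factory functions above (the import path assumes `networks/genforce` is on `PYTHONPATH`, and the resolution is an arbitrary example):

```python
# Hypothetical usage of the model factories (resolution chosen for illustration).
from models import build_generator, build_discriminator, build_model

generator = build_generator('stylegan2', resolution=256)
discriminator = build_discriminator('stylegan2', resolution=256)

# Equivalent construction through the generic entry point:
generator = build_model('stylegan2', module='generator', resolution=256)
```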
networks/genforce/models/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (4.06 kB). View file
 
networks/genforce/models/__pycache__/encoder.cpython-38.pyc ADDED
Binary file (13.4 kB). View file
 
networks/genforce/models/__pycache__/model_zoo.cpython-38.pyc ADDED
Binary file (11.2 kB). View file
 
networks/genforce/models/__pycache__/perceptual_model.cpython-38.pyc ADDED
Binary file (5.58 kB). View file
 
networks/genforce/models/__pycache__/pggan_discriminator.cpython-38.pyc ADDED
Binary file (11.6 kB). View file
 
networks/genforce/models/__pycache__/pggan_generator.cpython-38.pyc ADDED
Binary file (9.34 kB). View file
 
networks/genforce/models/__pycache__/stylegan2_discriminator.cpython-38.pyc ADDED
Binary file (13.5 kB). View file
 
networks/genforce/models/__pycache__/stylegan2_generator.cpython-38.pyc ADDED
Binary file (28.6 kB). View file
 
networks/genforce/models/__pycache__/stylegan_discriminator.cpython-38.pyc ADDED
Binary file (15.2 kB). View file
 
networks/genforce/models/__pycache__/stylegan_generator.cpython-38.pyc ADDED
Binary file (26.5 kB). View file