Commit 2a76164
Parent(s): 917c2bc
Upload 194 files

This view is limited to 50 files because it contains too many changes. See raw diff.
- models/stylegan2_ffhq1024.pth +3 -0
- networks/__pycache__/load_generator.cpython-38.pyc +0 -0
- networks/biggan/__init__.py +6 -0
- networks/biggan/__pycache__/__init__.cpython-38.pyc +0 -0
- networks/biggan/__pycache__/config.cpython-38.pyc +0 -0
- networks/biggan/__pycache__/file_utils.cpython-38.pyc +0 -0
- networks/biggan/__pycache__/model.cpython-38.pyc +0 -0
- networks/biggan/__pycache__/utils.cpython-38.pyc +0 -0
- networks/biggan/config.py +70 -0
- networks/biggan/convert.sh +21 -0
- networks/biggan/convert_tf_to_pytorch.py +312 -0
- networks/biggan/download_tf.sh +21 -0
- networks/biggan/file_utils.py +232 -0
- networks/biggan/model.py +352 -0
- networks/biggan/utils.py +216 -0
- networks/genforce/.gitignore +29 -0
- networks/genforce/LICENSE +18 -0
- networks/genforce/MODEL_ZOO.md +131 -0
- networks/genforce/README.md +169 -0
- networks/genforce/__init__.py +0 -0
- networks/genforce/__pycache__/__init__.cpython-38.pyc +0 -0
- networks/genforce/configs/stylegan_demo.py +61 -0
- networks/genforce/configs/stylegan_ffhq1024.py +63 -0
- networks/genforce/configs/stylegan_ffhq1024_val.py +29 -0
- networks/genforce/configs/stylegan_ffhq256.py +63 -0
- networks/genforce/configs/stylegan_ffhq256_encoder_y.py +73 -0
- networks/genforce/configs/stylegan_ffhq256_val.py +29 -0
- networks/genforce/convert_model.py +77 -0
- networks/genforce/datasets/README.md +24 -0
- networks/genforce/datasets/__init__.py +7 -0
- networks/genforce/datasets/dataloaders.py +128 -0
- networks/genforce/datasets/datasets.py +239 -0
- networks/genforce/datasets/distributed_sampler.py +144 -0
- networks/genforce/datasets/libturbojpeg.so.0 +0 -0
- networks/genforce/datasets/transforms.py +201 -0
- networks/genforce/metrics/README.md +18 -0
- networks/genforce/metrics/__init__.py +0 -0
- networks/genforce/metrics/fid.py +59 -0
- networks/genforce/metrics/inception.py +520 -0
- networks/genforce/models/__init__.py +131 -0
- networks/genforce/models/__pycache__/__init__.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/encoder.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/model_zoo.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/perceptual_model.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/pggan_discriminator.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/pggan_generator.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/stylegan2_discriminator.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/stylegan2_generator.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/stylegan_discriminator.cpython-38.pyc +0 -0
- networks/genforce/models/__pycache__/stylegan_generator.cpython-38.pyc +0 -0
models/stylegan2_ffhq1024.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d68c291b390869bee4f83f144d0480bd4c4cb7aab0fee0a9ee551073ea2d2163
size 381464183
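The weights file itself is stored with Git LFS, so the diff shows only the three-line pointer above (spec version, SHA-256 object id, byte size). A minimal sketch of checking a fetched blob against such a pointer; `verify_lfs_pointer` and both paths are hypothetical names for illustration:

import hashlib
import os

def verify_lfs_pointer(pointer_path, blob_path):
    # A pointer file is three "key value" lines: version, oid, size.
    fields = dict(line.split(' ', 1) for line in open(pointer_path) if ' ' in line)
    expected_oid = fields['oid'].strip().split(':', 1)[1]  # strip the "sha256:" prefix
    expected_size = int(fields['size'])

    # Cheap size check first, then the full SHA-256 of the blob.
    if os.path.getsize(blob_path) != expected_size:
        return False
    h = hashlib.sha256()
    with open(blob_path, 'rb') as f:
        for chunk in iter(lambda: f.read(1 << 20), b''):
            h.update(chunk)
    return h.hexdigest() == expected_oid

# e.g. verify_lfs_pointer('stylegan2_ffhq1024.pointer', 'models/stylegan2_ffhq1024.pth')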
networks/__pycache__/load_generator.cpython-38.pyc
ADDED
Binary file (1.43 kB)
networks/biggan/__init__.py
ADDED
@@ -0,0 +1,6 @@
from .config import BigGANConfig
from .model import BigGAN
from .file_utils import PYTORCH_PRETRAINED_BIGGAN_CACHE, cached_path
from .utils import (truncated_noise_sample, save_as_images,
                    convert_to_images, display_in_terminal,
                    one_hot_from_int, one_hot_from_names)
networks/biggan/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (518 Bytes)
networks/biggan/__pycache__/config.cpython-38.pyc
ADDED
Binary file (2.54 kB)
networks/biggan/__pycache__/file_utils.cpython-38.pyc
ADDED
Binary file (6.22 kB)
networks/biggan/__pycache__/model.cpython-38.pyc
ADDED
Binary file (10.9 kB)
networks/biggan/__pycache__/utils.cpython-38.pyc
ADDED
Binary file (21.5 kB)
networks/biggan/config.py
ADDED
@@ -0,0 +1,70 @@
# coding: utf-8
"""
BigGAN config.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import copy
import json

class BigGANConfig(object):
    """ Configuration class to store the configuration of a `BigGAN`.
        Defaults are for the 128x128 model.
        Each entry of `layers` is a tuple (up-sample in this layer?, input channel multiplier, output channel multiplier).
    """
    def __init__(self,
                 output_dim=128,
                 z_dim=128,
                 class_embed_dim=128,
                 channel_width=128,
                 num_classes=1000,
                 layers=[(False, 16, 16),
                         (True, 16, 16),
                         (False, 16, 16),
                         (True, 16, 8),
                         (False, 8, 8),
                         (True, 8, 4),
                         (False, 4, 4),
                         (True, 4, 2),
                         (False, 2, 2),
                         (True, 2, 1)],
                 attention_layer_position=8,
                 eps=1e-4,
                 n_stats=51):
        """Constructs BigGANConfig. """
        self.output_dim = output_dim
        self.z_dim = z_dim
        self.class_embed_dim = class_embed_dim
        self.channel_width = channel_width
        self.num_classes = num_classes
        self.layers = layers
        self.attention_layer_position = attention_layer_position
        self.eps = eps
        self.n_stats = n_stats

    @classmethod
    def from_dict(cls, json_object):
        """Constructs a `BigGANConfig` from a Python dictionary of parameters."""
        config = BigGANConfig()
        for key, value in json_object.items():
            config.__dict__[key] = value
        return config

    @classmethod
    def from_json_file(cls, json_file):
        """Constructs a `BigGANConfig` from a json file of parameters."""
        with open(json_file, "r", encoding='utf-8') as reader:
            text = reader.read()
        return cls.from_dict(json.loads(text))

    def __repr__(self):
        return str(self.to_json_string())

    def to_dict(self):
        """Serializes this instance to a Python dictionary."""
        output = copy.deepcopy(self.__dict__)
        return output

    def to_json_string(self):
        """Serializes this instance to a JSON string."""
        return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n"
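A quick round-trip through the serialization helpers above; a minimal sketch, assuming the package is importable as `networks.biggan`:

import json

from networks.biggan import BigGANConfig

# to_json_string() and from_dict() are inverses over the config's __dict__.
config = BigGANConfig(output_dim=256)
restored = BigGANConfig.from_dict(json.loads(config.to_json_string()))
assert restored.output_dim == 256
# JSON turns the layer tuples into lists, so normalize before comparing.
assert [tuple(l) for l in restored.layers] == config.layers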
networks/biggan/convert.sh
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

set -e
set -x

models="128 256 512"

mkdir -p models/model_128
mkdir -p models/model_256
mkdir -p models/model_512

# Convert TF Hub models.
for model in $models
do
    pytorch_pretrained_biggan --model_type $model --tf_model_path models/model_$model --pt_save_path models/model_$model
done
networks/biggan/convert_tf_to_pytorch.py
ADDED
@@ -0,0 +1,312 @@
# coding: utf-8
"""
Convert a TF Hub BigGAN model into a PyTorch one.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

from itertools import chain

import os
import argparse
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import normalize

from .model import BigGAN, WEIGHTS_NAME, CONFIG_NAME
from .config import BigGANConfig

logger = logging.getLogger(__name__)


def extract_batch_norm_stats(tf_model_path, batch_norm_stats_path=None):
    try:
        import numpy as np
        import tensorflow as tf
        import tensorflow_hub as hub
    except ImportError:
        raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow and TF Hub to be installed. "
                          "Please see https://www.tensorflow.org/install/ for installation instructions for TensorFlow. "
                          "And see https://github.com/tensorflow/hub for installing Hub. "
                          "Probably pip install tensorflow tensorflow-hub")
    tf.reset_default_graph()
    logger.info('Loading BigGAN module from: {}'.format(tf_model_path))
    module = hub.Module(tf_model_path)
    inputs = {k: tf.placeholder(v.dtype, v.get_shape().as_list(), k)
              for k, v in module.get_input_info_dict().items()}
    output = module(inputs)

    initializer = tf.global_variables_initializer()
    sess = tf.Session()
    stacks = sum(((i*10 + 1, i*10 + 3, i*10 + 6, i*10 + 8) for i in range(50)), ())
    numpy_stacks = []
    for i in stacks:
        logger.info("Retrieving module_apply_default/stack_{}".format(i))
        try:
            stack_var = tf.get_default_graph().get_tensor_by_name("module_apply_default/stack_%d:0" % i)
        except KeyError:
            break  # We have all the stats
        numpy_stacks.append(sess.run(stack_var))

    if batch_norm_stats_path is not None:
        torch.save(numpy_stacks, batch_norm_stats_path)
    else:
        return numpy_stacks


def build_tf_to_pytorch_map(model, config):
    """ Build a map from TF variables to PyTorch modules. """
    tf_to_pt_map = {}

    # Embeddings and GenZ
    tf_to_pt_map.update({'linear/w/ema_0.9999': model.embeddings.weight,
                         'Generator/GenZ/G_linear/b/ema_0.9999': model.generator.gen_z.bias,
                         'Generator/GenZ/G_linear/w/ema_0.9999': model.generator.gen_z.weight_orig,
                         'Generator/GenZ/G_linear/u0': model.generator.gen_z.weight_u})

    # GBlock blocks
    model_layer_idx = 0
    for i, (up, in_channels, out_channels) in enumerate(config.layers):
        if i == config.attention_layer_position:
            model_layer_idx += 1
        layer_str = "Generator/GBlock_%d/" % i if i > 0 else "Generator/GBlock/"
        layer_pnt = model.generator.layers[model_layer_idx]
        for i in range(4):  # Batchnorms
            batch_str = layer_str + ("BatchNorm_%d/" % i if i > 0 else "BatchNorm/")
            batch_pnt = getattr(layer_pnt, 'bn_%d' % i)
            for name in ('offset', 'scale'):
                sub_module_str = batch_str + name + "/"
                sub_module_pnt = getattr(batch_pnt, name)
                tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
                                     sub_module_str + "u0": sub_module_pnt.weight_u})
        for i in range(4):  # Convolutions
            conv_str = layer_str + "conv%d/" % i
            conv_pnt = getattr(layer_pnt, 'conv_%d' % i)
            tf_to_pt_map.update({conv_str + "b/ema_0.9999": conv_pnt.bias,
                                 conv_str + "w/ema_0.9999": conv_pnt.weight_orig,
                                 conv_str + "u0": conv_pnt.weight_u})
        model_layer_idx += 1

    # Attention block
    layer_str = "Generator/attention/"
    layer_pnt = model.generator.layers[config.attention_layer_position]
    tf_to_pt_map.update({layer_str + "gamma/ema_0.9999": layer_pnt.gamma})
    for pt_name, tf_name in zip(['snconv1x1_g', 'snconv1x1_o_conv', 'snconv1x1_phi', 'snconv1x1_theta'],
                                ['g/', 'o_conv/', 'phi/', 'theta/']):
        sub_module_str = layer_str + tf_name
        sub_module_pnt = getattr(layer_pnt, pt_name)
        tf_to_pt_map.update({sub_module_str + "w/ema_0.9999": sub_module_pnt.weight_orig,
                             sub_module_str + "u0": sub_module_pnt.weight_u})

    # Final batch norm and conv to rgb
    layer_str = "Generator/BatchNorm/"
    layer_pnt = model.generator.bn
    tf_to_pt_map.update({layer_str + "offset/ema_0.9999": layer_pnt.bias,
                         layer_str + "scale/ema_0.9999": layer_pnt.weight})
    layer_str = "Generator/conv_to_rgb/"
    layer_pnt = model.generator.conv_to_rgb
    tf_to_pt_map.update({layer_str + "b/ema_0.9999": layer_pnt.bias,
                         layer_str + "w/ema_0.9999": layer_pnt.weight_orig,
                         layer_str + "u0": layer_pnt.weight_u})
    return tf_to_pt_map


def load_tf_weights_in_biggan(model, config, tf_model_path, batch_norm_stats_path=None):
    """ Load tf checkpoints and standing statistics in a pytorch model
    """
    try:
        import numpy as np
        import tensorflow as tf
    except ImportError:
        raise ImportError("Loading a TensorFlow model in PyTorch requires TensorFlow to be installed. Please see "
                          "https://www.tensorflow.org/install/ for installation instructions.")
    # Load weights from TF model
    checkpoint_path = tf_model_path + "/variables/variables"
    init_vars = tf.train.list_variables(checkpoint_path)
    from pprint import pprint
    pprint(init_vars)

    # Extract batch norm statistics from model if needed
    if batch_norm_stats_path:
        stats = torch.load(batch_norm_stats_path)
    else:
        logger.info("Extracting batch norm stats")
        stats = extract_batch_norm_stats(tf_model_path)

    # Build TF to PyTorch weights loading map
    tf_to_pt_map = build_tf_to_pytorch_map(model, config)

    tf_weights = {}
    for name in tf_to_pt_map.keys():
        array = tf.train.load_variable(checkpoint_path, name)
        tf_weights[name] = array
        # logger.info("Loading TF weight {} with shape {}".format(name, array.shape))

    # Load parameters
    with torch.no_grad():
        pt_params_pnt = set()
        for name, pointer in tf_to_pt_map.items():
            array = tf_weights[name]
            if pointer.dim() == 1:
                if pointer.dim() < array.ndim:
                    array = np.squeeze(array)
            elif pointer.dim() == 2:  # Weights
                array = np.transpose(array)
            elif pointer.dim() == 4:  # Convolutions
                array = np.transpose(array, (3, 2, 0, 1))
            else:
                raise ValueError("Wrong dimensions to adjust: " + str((pointer.shape, array.shape)))
            if pointer.shape != array.shape:
                raise ValueError("Wrong dimensions: " + str((pointer.shape, array.shape)))
            logger.info("Initialize PyTorch weight {} with shape {}".format(name, pointer.shape))
            pointer.data = torch.from_numpy(array) if isinstance(array, np.ndarray) else torch.tensor(array)
            tf_weights.pop(name, None)
            pt_params_pnt.add(pointer.data_ptr())

        # Prepare SpectralNorm buffers by running one step of Spectral Norm (no need to train the model):
        for module in model.modules():
            for n, buffer in module.named_buffers():
                if n == 'weight_v':
                    weight_mat = module.weight_orig
                    weight_mat = weight_mat.reshape(weight_mat.size(0), -1)
                    u = module.weight_u

                    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=config.eps)
                    buffer.data = v
                    pt_params_pnt.add(buffer.data_ptr())

                    u = normalize(torch.mv(weight_mat, v), dim=0, eps=config.eps)
                    module.weight_u.data = u
                    pt_params_pnt.add(module.weight_u.data_ptr())

        # Load batch norm statistics
        index = 0
        for layer in model.generator.layers:
            if not hasattr(layer, 'bn_0'):
                continue
            for i in range(4):  # Batchnorms
                bn_pointer = getattr(layer, 'bn_%d' % i)
                pointer = bn_pointer.running_means
                if pointer.shape != stats[index].shape:
                    raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
                pointer.data = torch.from_numpy(stats[index])
                pt_params_pnt.add(pointer.data_ptr())

                pointer = bn_pointer.running_vars
                if pointer.shape != stats[index+1].shape:
                    raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
                pointer.data = torch.from_numpy(stats[index+1])
                pt_params_pnt.add(pointer.data_ptr())

                index += 2

        bn_pointer = model.generator.bn
        pointer = bn_pointer.running_means
        if pointer.shape != stats[index].shape:
            raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index].shape)))
        pointer.data = torch.from_numpy(stats[index])
        pt_params_pnt.add(pointer.data_ptr())

        pointer = bn_pointer.running_vars
        if pointer.shape != stats[index+1].shape:
            raise ValueError("Wrong dimensions: " + str((pointer.shape, stats[index+1].shape)))
        pointer.data = torch.from_numpy(stats[index+1])
        pt_params_pnt.add(pointer.data_ptr())

    remaining_params = list(n for n, t in chain(model.named_parameters(), model.named_buffers()) \
                            if t.data_ptr() not in pt_params_pnt)

    logger.info("TF weights not copied to PyTorch model: {}".format(', '.join(tf_weights.keys())))
    logger.info("Remaining parameters/buffers from PyTorch model: {}".format(', '.join(remaining_params)))

    return model


BigGAN128 = BigGANConfig(output_dim=128, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

BigGAN256 = BigGANConfig(output_dim=256, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)

BigGAN512 = BigGANConfig(output_dim=512, z_dim=128, class_embed_dim=128, channel_width=128, num_classes=1000,
                         layers=[(False, 16, 16),
                                 (True, 16, 16),
                                 (False, 16, 16),
                                 (True, 16, 8),
                                 (False, 8, 8),
                                 (True, 8, 8),
                                 (False, 8, 8),
                                 (True, 8, 4),
                                 (False, 4, 4),
                                 (True, 4, 2),
                                 (False, 2, 2),
                                 (True, 2, 1),
                                 (False, 1, 1),
                                 (True, 1, 1)],
                         attention_layer_position=8, eps=1e-4, n_stats=51)


def main():
    parser = argparse.ArgumentParser(description="Convert a BigGAN TF Hub model into a PyTorch model")
    parser.add_argument("--model_type", type=str, default="", required=True,
                        help="BigGAN model type (128, 256, 512)")
    parser.add_argument("--tf_model_path", type=str, default="", required=True,
                        help="Path of the downloaded TF Hub model")
    parser.add_argument("--pt_save_path", type=str, default="",
                        help="Folder to save the PyTorch model (default: folder of the TF Hub model)")
    parser.add_argument("--batch_norm_stats_path", type=str, default="",
                        help="Path of previously extracted batch norm statistics")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    if not args.pt_save_path:
        args.pt_save_path = args.tf_model_path

    if args.model_type == "128":
        config = BigGAN128
    elif args.model_type == "256":
        config = BigGAN256
    elif args.model_type == "512":
        config = BigGAN512
    else:
        raise ValueError("model_type should be one of 128, 256 or 512")

    model = BigGAN(config)
    model = load_tf_weights_in_biggan(model, config, args.tf_model_path, args.batch_norm_stats_path)

    model_save_path = os.path.join(args.pt_save_path, WEIGHTS_NAME)
    config_save_path = os.path.join(args.pt_save_path, CONFIG_NAME)

    logger.info("Save model dump to {}".format(model_save_path))
    torch.save(model.state_dict(), model_save_path)
    logger.info("Save configuration file to {}".format(config_save_path))
    with open(config_save_path, "w", encoding="utf-8") as f:
        f.write(config.to_json_string())

if __name__ == "__main__":
    main()
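The "one step of Spectral Norm" above seeds the `weight_u`/`weight_v` buffers that PyTorch's `nn.utils.spectral_norm` maintains by power iteration. A minimal standalone sketch of that update, separate from the converter (`power_iteration_step` is a hypothetical helper for illustration; `torch.linalg.svdvals` needs a recent PyTorch):

import torch
from torch.nn.functional import normalize

def power_iteration_step(weight_orig, u, eps=1e-4):
    # Flatten the weight to 2D, as spectral_norm does internally.
    weight_mat = weight_orig.reshape(weight_orig.size(0), -1)
    v = normalize(torch.mv(weight_mat.t(), u), dim=0, eps=eps)  # right singular vector estimate
    u = normalize(torch.mv(weight_mat, v), dim=0, eps=eps)      # left singular vector estimate
    sigma = torch.dot(u, torch.mv(weight_mat, v))               # spectral norm estimate
    return u, v, sigma

w = torch.randn(16, 8)
u = normalize(torch.randn(16), dim=0)
for _ in range(10):
    u, v, sigma = power_iteration_step(w, u)
print(float(sigma), float(torch.linalg.svdvals(w)[0]))  # the two values should agree closely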
networks/biggan/download_tf.sh
ADDED
@@ -0,0 +1,21 @@
# Copyright (c) 2019-present, Thomas Wolf, Huggingface Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

set -e
set -x

models="128 256 512"

mkdir -p models/model_128
mkdir -p models/model_256
mkdir -p models/model_512

# Download TF Hub models.
for model in $models
do
    curl -L "https://tfhub.dev/deepmind/biggan-deep-$model/1?tf-hub-format=compressed" | tar -zxvC models/model_$model
done
networks/biggan/file_utils.py
ADDED
@@ -0,0 +1,232 @@
"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import json
import logging
import os
import shutil
import tempfile
from functools import wraps
from hashlib import sha256
import sys
from io import open

import requests
from tqdm import tqdm

try:
    from urllib.parse import urlparse
except ImportError:
    from urlparse import urlparse

try:
    from pathlib import Path
    PYTORCH_PRETRAINED_BIGGAN_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                     Path.home() / '.pytorch_pretrained_biggan'))
except (AttributeError, ImportError):
    PYTORCH_PRETRAINED_BIGGAN_CACHE = os.getenv('PYTORCH_PRETRAINED_BIGGAN_CACHE',
                                                os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_biggan'))

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


def url_to_filename(url, etag=None):
    """
    Convert `url` into a hashed filename in a repeatable way.
    If `etag` is specified, append its hash to the url's, delimited
    by a period.
    """
    url_bytes = url.encode('utf-8')
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode('utf-8')
        etag_hash = sha256(etag_bytes)
        filename += '.' + etag_hash.hexdigest()

    return filename


def filename_to_url(filename, cache_dir=None):
    """
    Return the url and etag (which may be ``None``) stored for `filename`.
    Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    cache_path = os.path.join(cache_dir, filename)
    if not os.path.exists(cache_path):
        raise EnvironmentError("file {} not found".format(cache_path))

    meta_path = cache_path + '.json'
    if not os.path.exists(meta_path):
        raise EnvironmentError("file {} not found".format(meta_path))

    with open(meta_path, encoding="utf-8") as meta_file:
        metadata = json.load(meta_file)
    url = metadata['url']
    etag = metadata['etag']

    return url, etag


def cached_path(url_or_filename, cache_dir=None):
    """
    Given something that might be a URL (or might be a local path),
    determine which. If it's a URL, download the file and cache it, and
    return the path to the cached file. If it's already a local path,
    make sure the file exists and then return the path.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    parsed = urlparse(url_or_filename)

    if parsed.scheme in ('http', 'https', 's3'):
        # URL, so get it from the cache (downloading if necessary)
        return get_from_cache(url_or_filename, cache_dir)
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        return url_or_filename
    elif parsed.scheme == '':
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))


def split_s3_path(url):
    """Split a full s3 path into the bucket name and path."""
    parsed = urlparse(url)
    if not parsed.netloc or not parsed.path:
        raise ValueError("bad s3 path {}".format(url))
    bucket_name = parsed.netloc
    s3_path = parsed.path
    # Remove '/' at beginning of path.
    if s3_path.startswith("/"):
        s3_path = s3_path[1:]
    return bucket_name, s3_path


def s3_request(func):
    """
    Wrapper function for s3 requests in order to create more helpful error
    messages.
    NOTE: `ClientError` comes from botocore; s3 support is disabled in this
    demo, so this decorator is unused unless botocore is installed and imported.
    """

    @wraps(func)
    def wrapper(url, *args, **kwargs):
        try:
            return func(url, *args, **kwargs)
        except ClientError as exc:
            if int(exc.response["Error"]["Code"]) == 404:
                raise EnvironmentError("file {} not found".format(url))
            else:
                raise

    return wrapper


def http_get(url, temp_file):
    req = requests.get(url, stream=True)
    content_length = req.headers.get('Content-Length')
    total = int(content_length) if content_length is not None else None
    progress = tqdm(unit="B", total=total)
    for chunk in req.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()


def get_from_cache(url, cache_dir=None):
    """
    Given a URL, look for the corresponding dataset in the local cache.
    If it's not there, download it. Then return the path to the cached file.
    """
    if cache_dir is None:
        cache_dir = PYTORCH_PRETRAINED_BIGGAN_CACHE
    if sys.version_info[0] == 3 and isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if not os.path.exists(cache_dir):
        os.makedirs(cache_dir)

    # Get eTag to add to filename, if it exists.
    if url.startswith("s3://"):
        raise NotImplementedError("s3 paths are not supported in this colab demo. Sorry!")
    else:
        response = requests.head(url, allow_redirects=True)
        if response.status_code != 200:
            raise IOError("HEAD request failed for url {} with status code {}"
                          .format(url, response.status_code))
        etag = response.headers.get("ETag")

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    if not os.path.exists(cache_path):
        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with tempfile.NamedTemporaryFile() as temp_file:
            logger.info("%s not found in cache, downloading to %s", url, temp_file.name)

            # GET file object
            if url.startswith("s3://"):
                raise NotImplementedError("s3 paths are not supported in this colab demo. Sorry!")
            else:
                http_get(url, temp_file)

            # we are copying the file before closing it, so flush to avoid truncation
            temp_file.flush()
            # shutil.copyfileobj() starts at the current position, so go to the start
            temp_file.seek(0)

            logger.info("copying %s to cache at %s", temp_file.name, cache_path)
            with open(cache_path, 'wb') as cache_file:
                shutil.copyfileobj(temp_file, cache_file)

            logger.info("creating metadata file for %s", cache_path)
            meta = {'url': url, 'etag': etag}
            meta_path = cache_path + '.json'
            with open(meta_path, 'w', encoding="utf-8") as meta_file:
                json.dump(meta, meta_file)

            logger.info("removing temp file %s", temp_file.name)

    return cache_path


def read_set_from_file(filename):
    '''
    Extract a de-duped collection (set) of text from a file.
    Expected file format is one item per line.
    '''
    collection = set()
    with open(filename, 'r', encoding='utf-8') as file_:
        for line in file_:
            collection.add(line.rstrip())
    return collection


def get_file_extension(path, dot=True, lower=True):
    ext = os.path.splitext(path)[1]
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
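A minimal sketch of `cached_path` in use, assuming network access and a writable cache directory (the URL is one of the real pretrained-config URLs from `model.py` below):

from networks.biggan.file_utils import cached_path, PYTORCH_PRETRAINED_BIGGAN_CACHE

# The first call downloads into the cache under sha256(url) (plus the ETag hash);
# subsequent calls return the same local path without re-downloading.
url = "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json"
local_path = cached_path(url)
print(PYTORCH_PRETRAINED_BIGGAN_CACHE, local_path)

# Existing local paths pass through unchanged.
assert cached_path(local_path) == local_path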
networks/biggan/model.py
ADDED
@@ -0,0 +1,352 @@
# coding: utf-8
""" BigGAN PyTorch model.
From "Large Scale GAN Training for High Fidelity Natural Image Synthesis"
By Andrew Brock, Jeff Donahue and Karen Simonyan.
https://openreview.net/forum?id=B1xsqj09Fm

PyTorch version implemented from the computational graph of the TF Hub module for BigGAN.
Some parts of the code are adapted from https://github.com/brain-research/self-attention-gan

This version only comprises the generator (since the discriminator's weights are not released).
This version only comprises the "deep" version of BigGAN (see publication).
"""
from __future__ import (absolute_import, division, print_function, unicode_literals)

import os
import logging
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import BigGANConfig
from .file_utils import cached_path

logger = logging.getLogger(__name__)

PRETRAINED_MODEL_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-pytorch_model.bin",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-pytorch_model.bin",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-pytorch_model.bin",
}

PRETRAINED_CONFIG_ARCHIVE_MAP = {
    'biggan-deep-128': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-128-config.json",
    'biggan-deep-256': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-256-config.json",
    'biggan-deep-512': "https://s3.amazonaws.com/models.huggingface.co/biggan/biggan-deep-512-config.json",
}

WEIGHTS_NAME = 'pytorch_model.bin'
CONFIG_NAME = 'config.json'


def snconv2d(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Conv2d(**kwargs), eps=eps)

def snlinear(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Linear(**kwargs), eps=eps)

def sn_embedding(eps=1e-12, **kwargs):
    return nn.utils.spectral_norm(nn.Embedding(**kwargs), eps=eps)

class SelfAttn(nn.Module):
    """ Self attention Layer"""
    def __init__(self, in_channels, eps=1e-12):
        super(SelfAttn, self).__init__()
        self.in_channels = in_channels
        self.snconv1x1_theta = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                        kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_phi = snconv2d(in_channels=in_channels, out_channels=in_channels//8,
                                      kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_g = snconv2d(in_channels=in_channels, out_channels=in_channels//2,
                                    kernel_size=1, bias=False, eps=eps)
        self.snconv1x1_o_conv = snconv2d(in_channels=in_channels//2, out_channels=in_channels,
                                         kernel_size=1, bias=False, eps=eps)
        self.maxpool = nn.MaxPool2d(2, stride=2, padding=0)
        self.softmax = nn.Softmax(dim=-1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        _, ch, h, w = x.size()
        # Theta path
        theta = self.snconv1x1_theta(x)
        theta = theta.view(-1, ch//8, h*w)
        # Phi path
        phi = self.snconv1x1_phi(x)
        phi = self.maxpool(phi)
        phi = phi.view(-1, ch//8, h*w//4)
        # Attn map
        attn = torch.bmm(theta.permute(0, 2, 1), phi)
        attn = self.softmax(attn)
        # g path
        g = self.snconv1x1_g(x)
        g = self.maxpool(g)
        g = g.view(-1, ch//2, h*w//4)
        # Attn_g - o_conv
        attn_g = torch.bmm(g, attn.permute(0, 2, 1))
        attn_g = attn_g.view(-1, ch//2, h, w)
        attn_g = self.snconv1x1_o_conv(attn_g)
        # Out
        out = x + self.gamma*attn_g
        return out


class BigGANBatchNorm(nn.Module):
    """ This is a batch norm module that can handle conditional input and can be provided with pre-computed
        activation means and variances for various truncation parameters.

        We cannot just rely on torch.batch_norm since it cannot handle
        batched weights (pytorch 1.0.1). We compute batch norm ourselves without updating running means and variances.
        If you want to train this model you should add running means and variance computation logic.
    """
    def __init__(self, num_features, condition_vector_dim=None, n_stats=51, eps=1e-4, conditional=True):
        super(BigGANBatchNorm, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.conditional = conditional

        # We use pre-computed statistics for n_stats values of truncation between 0 and 1
        self.register_buffer('running_means', torch.zeros(n_stats, num_features))
        self.register_buffer('running_vars', torch.ones(n_stats, num_features))
        self.step_size = 1.0 / (n_stats - 1)

        if conditional:
            assert condition_vector_dim is not None
            self.scale = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
            self.offset = snlinear(in_features=condition_vector_dim, out_features=num_features, bias=False, eps=eps)
        else:
            self.weight = torch.nn.Parameter(torch.Tensor(num_features))
            self.bias = torch.nn.Parameter(torch.Tensor(num_features))

    def forward(self, x, truncation, condition_vector=None):
        # Retrieve pre-computed statistics associated with this truncation
        coef, start_idx = math.modf(truncation / self.step_size)
        start_idx = int(start_idx)
        if coef != 0.0:  # Interpolate
            running_mean = self.running_means[start_idx] * coef + self.running_means[start_idx + 1] * (1 - coef)
            running_var = self.running_vars[start_idx] * coef + self.running_vars[start_idx + 1] * (1 - coef)
        else:
            running_mean = self.running_means[start_idx]
            running_var = self.running_vars[start_idx]

        if self.conditional:
            running_mean = running_mean.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
            running_var = running_var.unsqueeze(0).unsqueeze(-1).unsqueeze(-1)

            weight = 1 + self.scale(condition_vector).unsqueeze(-1).unsqueeze(-1)
            bias = self.offset(condition_vector).unsqueeze(-1).unsqueeze(-1)

            out = (x - running_mean) / torch.sqrt(running_var + self.eps) * weight + bias
        else:
            out = F.batch_norm(x, running_mean, running_var, self.weight, self.bias,
                               training=False, momentum=0.0, eps=self.eps)

        return out


class GenBlock(nn.Module):
    def __init__(self, in_size, out_size, condition_vector_dim, reduction_factor=4, up_sample=False,
                 n_stats=51, eps=1e-12):
        super(GenBlock, self).__init__()
        self.up_sample = up_sample
        self.drop_channels = (in_size != out_size)
        middle_size = in_size // reduction_factor

        self.bn_0 = BigGANBatchNorm(in_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_0 = snconv2d(in_channels=in_size, out_channels=middle_size, kernel_size=1, eps=eps)

        self.bn_1 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_1 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_2 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_2 = snconv2d(in_channels=middle_size, out_channels=middle_size, kernel_size=3, padding=1, eps=eps)

        self.bn_3 = BigGANBatchNorm(middle_size, condition_vector_dim, n_stats=n_stats, eps=eps, conditional=True)
        self.conv_3 = snconv2d(in_channels=middle_size, out_channels=out_size, kernel_size=1, eps=eps)

        self.relu = nn.ReLU()

    def forward(self, x, cond_vector, truncation):
        x0 = x

        x = self.bn_0(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_0(x)

        x = self.bn_1(x, truncation, cond_vector)
        x = self.relu(x)
        if self.up_sample:
            x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv_1(x)

        x = self.bn_2(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_2(x)

        x = self.bn_3(x, truncation, cond_vector)
        x = self.relu(x)
        x = self.conv_3(x)

        if self.drop_channels:
            new_channels = x0.shape[1] // 2
            x0 = x0[:, :new_channels, ...]
        if self.up_sample:
            x0 = F.interpolate(x0, scale_factor=2, mode='nearest')

        out = x + x0
        return out

class Generator(nn.Module):
    def __init__(self, config):
        super(Generator, self).__init__()
        self.config = config
        ch = config.channel_width
        condition_vector_dim = config.z_dim * 2

        self.gen_z = snlinear(in_features=condition_vector_dim,
                              out_features=4 * 4 * 16 * ch, eps=config.eps)

        layers = []
        for i, layer in enumerate(config.layers):
            if i == config.attention_layer_position:
                layers.append(SelfAttn(ch*layer[1], eps=config.eps))
            layers.append(GenBlock(ch*layer[1],
                                   ch*layer[2],
                                   condition_vector_dim,
                                   up_sample=layer[0],
                                   n_stats=config.n_stats,
                                   eps=config.eps))
        self.layers = nn.ModuleList(layers)

        self.bn = BigGANBatchNorm(ch, n_stats=config.n_stats, eps=config.eps, conditional=False)
        self.relu = nn.ReLU()
        self.conv_to_rgb = snconv2d(in_channels=ch, out_channels=ch, kernel_size=3, padding=1, eps=config.eps)
        self.tanh = nn.Tanh()

    def forward(self, cond_vector, truncation, z=None, start=0, stop=None):
        # We use this conversion step to be able to use TF weights:
        # TF convention on shape is [batch, height, width, channels]
        # PT convention on shape is [batch, channels, height, width]
        if start == 0 and z is None:
            z = self.gen_z(cond_vector)
            z = z.view(-1, 4, 4, 16 * self.config.channel_width)
            z = z.permute(0, 3, 1, 2).contiguous()

        if stop is None:
            stop = len(self.layers)

        for i in range(start, stop):
            if isinstance(self.layers[i], GenBlock):
                z = self.layers[i](z, cond_vector, truncation)
            else:
                z = self.layers[i](z)

        if stop == len(self.layers):
            z = self.bn(z, truncation)
            z = self.relu(z)
            z = self.conv_to_rgb(z)
            z = z[:, :3, ...]
            z = self.tanh(z)

        return z

class BigGAN(nn.Module):
    """BigGAN Generator."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
        if pretrained_model_name_or_path in PRETRAINED_MODEL_ARCHIVE_MAP:
            model_file = PRETRAINED_MODEL_ARCHIVE_MAP[pretrained_model_name_or_path]
            config_file = PRETRAINED_CONFIG_ARCHIVE_MAP[pretrained_model_name_or_path]
        else:
            model_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)

        try:
            resolved_model_file = cached_path(model_file, cache_dir=cache_dir)
            resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
        except EnvironmentError:
            logger.error("Wrong model name, should be a valid path to a folder containing "
                         "a {} file and a {} file or a model name in {}".format(
                             WEIGHTS_NAME, CONFIG_NAME, PRETRAINED_MODEL_ARCHIVE_MAP.keys()))
            raise

        logger.info("loading model {} from cache at {}".format(pretrained_model_name_or_path, resolved_model_file))

        # Load config
        config = BigGANConfig.from_json_file(resolved_config_file)
        logger.info("Model config {}".format(config))

        # Instantiate model.
        model = cls(config, *inputs, **kwargs)
        state_dict = torch.load(resolved_model_file, map_location='cpu' if not torch.cuda.is_available() else None)
        model.load_state_dict(state_dict, strict=False)
        return model

    def __init__(self, config):
        super(BigGAN, self).__init__()
        self.config = config
        self.embeddings = nn.Linear(config.num_classes, config.z_dim, bias=False)
        self.generator = Generator(config)

    def forward(self, z, class_label, truncation, cond_vector=None, start=0, stop=None):
        assert 0 < truncation <= 1

        results = {}
        if start == 0 and cond_vector is None:
            embed = self.embeddings(class_label)
            cond_vector = torch.cat((z, embed), dim=1)
            results['cond_vector'] = cond_vector

        results['z'] = self.generator(cond_vector, truncation, z=None if start == 0 else z, start=start, stop=stop)
        return results


if __name__ == "__main__":
    import PIL
    from .utils import truncated_noise_sample, save_as_images, one_hot_from_names
    from .convert_tf_to_pytorch import load_tf_weights_in_biggan

    load_cache = False
    cache_path = './saved_model.pt'
    config = BigGANConfig()
    model = BigGAN(config)
    if not load_cache:
        model = load_tf_weights_in_biggan(model, config, './models/model_128/', './models/model_128/batchnorms_stats.bin')
        torch.save(model.state_dict(), cache_path)
    else:
        model.load_state_dict(torch.load(cache_path))

    model.eval()

    truncation = 0.4
    noise = truncated_noise_sample(batch_size=2, truncation=truncation)
    label = one_hot_from_names('diver', batch_size=2)

    noise = torch.tensor(noise, dtype=torch.float)
    label = torch.tensor(label, dtype=torch.float)
    with torch.no_grad():
        outputs = model(noise, label, truncation)['z']  # forward returns a dict; 'z' holds the images
    print(outputs.shape)

    save_as_images(outputs)
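The `start`/`stop` arguments threaded through `forward` above appear to be this repo's addition relative to the upstream `pytorch_pretrained_biggan` code: they let the generator run a layer range and resume later from the intermediate activation. A minimal sketch, assuming the pretrained weights are reachable over the network:

import torch
from networks.biggan import BigGAN, truncated_noise_sample, one_hot_from_int

model = BigGAN.from_pretrained('biggan-deep-128').eval()

truncation = 0.4
z = torch.tensor(truncated_noise_sample(batch_size=1, truncation=truncation))
label = torch.tensor(one_hot_from_int(207, batch_size=1))  # ImageNet class 207, golden retriever

with torch.no_grad():
    # Full pass: the result dict carries the conditioning vector and the image tensor.
    full = model(z, label, truncation)
    img = full['z']  # shape (1, 3, 128, 128), values in [-1, 1]

    # Split pass: run layers [0, 6), then resume from the intermediate activation.
    head = model(z, label, truncation, stop=6)
    tail = model(head['z'], label, truncation, cond_vector=head['cond_vector'], start=6)

assert torch.allclose(img, tail['z'])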
networks/biggan/utils.py
ADDED
@@ -0,0 +1,216 @@
# coding: utf-8
""" BigGAN utilities to prepare truncated noise samples and convert/save/display output images.
    Also comprises ImageNet utilities to prepare one-hot input vectors for ImageNet classes.
    We use WordNet so you can just input a name as a string and automatically get a corresponding
    ImageNet class if it exists (or a hypo/hypernym exists in ImageNet).
"""
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
from io import BytesIO

import numpy as np
from scipy.stats import truncnorm

logger = logging.getLogger(__name__)

NUM_CLASSES = 1000


def truncated_noise_sample(batch_size=1, dim_z=128, truncation=1., seed=None):
    """ Create a truncated noise vector.
        Params:
            batch_size: batch size.
            dim_z: dimension of z
            truncation: truncation value to use
            seed: seed for the random generator
        Output:
            array of shape (batch_size, dim_z)
    """
    state = None if seed is None else np.random.RandomState(seed)
    values = truncnorm.rvs(-2, 2, size=(batch_size, dim_z), random_state=state).astype(np.float32)
    return truncation * values


def convert_to_images(obj):
    """ Convert an output tensor from BigGAN into a list of images.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
        Output:
            list of Pillow Images of size (height, width)
    """
    try:
        import PIL
    except ImportError:
        raise ImportError("Please install Pillow to use images: pip install Pillow")

    if not isinstance(obj, np.ndarray):
        obj = obj.detach().numpy()

    obj = obj.transpose((0, 2, 3, 1))
    obj = np.clip(((obj + 1) / 2.0) * 256, 0, 255)

    img = []
    for i, out in enumerate(obj):
        out_array = np.asarray(np.uint8(out), dtype=np.uint8)
        img.append(PIL.Image.fromarray(out_array))
    return img


def save_as_images(obj, file_name='output'):
    """ Convert and save an output tensor from BigGAN as a list of saved images.
        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
            file_name: path and beginning of filename to save.
                Images will be saved as `file_name_{image_number}.png`
    """
    img = convert_to_images(obj)

    for i, out in enumerate(img):
        current_file_name = file_name + '_%d.png' % i
        logger.info("Saving image to {}".format(current_file_name))
        out.save(current_file_name, 'png')


def display_in_terminal(obj):
    """ Convert and display an output tensor from BigGAN in the terminal.
        This function uses `libsixel` and will only work in a libsixel-compatible terminal.
        Please refer to https://github.com/saitoha/libsixel for more details.

        Params:
            obj: tensor or numpy array of shape (batch_size, channels, height, width)
    """
    try:
        import PIL
        # SIXEL_BUILTIN_G8 and SIXEL_BUILTIN_G1 were missing from the original
        # import list although used below; added here.
        from libsixel import (sixel_output_new, sixel_dither_new, sixel_dither_initialize,
                              sixel_dither_set_palette, sixel_dither_set_pixelformat,
                              sixel_dither_get, sixel_encode, sixel_dither_unref,
                              sixel_output_unref, SIXEL_PIXELFORMAT_RGBA8888,
                              SIXEL_PIXELFORMAT_RGB888, SIXEL_PIXELFORMAT_PAL8,
                              SIXEL_PIXELFORMAT_G8, SIXEL_PIXELFORMAT_G1,
                              SIXEL_BUILTIN_G8, SIXEL_BUILTIN_G1)
    except ImportError:
        raise ImportError("Display in Terminal requires Pillow, libsixel "
                          "and a libsixel compatible terminal. "
                          "Please read info at https://github.com/saitoha/libsixel "
                          "and install with pip install Pillow libsixel-python")

    s = BytesIO()

    images = convert_to_images(obj)
    widths, heights = zip(*(i.size for i in images))

    output_width = sum(widths)
    output_height = max(heights)

    output_image = PIL.Image.new('RGB', (output_width, output_height))

    x_offset = 0
    for im in images:
        output_image.paste(im, (x_offset, 0))
        x_offset += im.size[0]

    try:
        data = output_image.tobytes()
    except NotImplementedError:
        data = output_image.tostring()
    output = sixel_output_new(lambda data, s: s.write(data), s)

    try:
        if output_image.mode == 'RGBA':
            dither = sixel_dither_new(256)
            sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGBA8888)
        elif output_image.mode == 'RGB':
            dither = sixel_dither_new(256)
            sixel_dither_initialize(dither, data, output_width, output_height, SIXEL_PIXELFORMAT_RGB888)
        elif output_image.mode == 'P':
            palette = output_image.getpalette()
            dither = sixel_dither_new(256)
            sixel_dither_set_palette(dither, palette)
            sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_PAL8)
        elif output_image.mode == 'L':
            dither = sixel_dither_get(SIXEL_BUILTIN_G8)
            sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G8)
        elif output_image.mode == '1':
            dither = sixel_dither_get(SIXEL_BUILTIN_G1)
            sixel_dither_set_pixelformat(dither, SIXEL_PIXELFORMAT_G1)
        else:
            raise RuntimeError('unexpected output_image mode')
        try:
            sixel_encode(data, output_width, output_height, 1, dither, output)
            print(s.getvalue().decode('ascii'))
        finally:
            sixel_dither_unref(dither)
    finally:
        sixel_output_unref(output)


def one_hot_from_int(int_or_list, batch_size=1):
    """ Create a one-hot vector from a class index or a list of class indices.
        Params:
            int_or_list: int, or list of int, of the imagenet classes (between 0 and 999)
            batch_size: batch size.
                If int_or_list is an int create a batch of identical classes.
                If int_or_list is a list, we should have `len(int_or_list) == batch_size`
        Output:
            array of shape (batch_size, 1000)
    """
    if isinstance(int_or_list, int):
        int_or_list = [int_or_list]

    if len(int_or_list) == 1 and batch_size > 1:
        int_or_list = [int_or_list[0]] * batch_size

    assert batch_size == len(int_or_list)

    array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)
    for i, j in enumerate(int_or_list):
        array[i, j] = 1.0
    return array


def one_hot_from_names(class_name_or_list, batch_size=1):
    """ Create a one-hot vector from the name of an imagenet class ('tennis ball', 'daisy', ...).
        We use NLTK's wordnet search to try to find the relevant synset of ImageNet and take the first one.
        If we can't find it directly, we look at the hyponyms and hypernyms of the class name.

        Params:
            class_name_or_list: string containing the name of an imagenet object or a list of such strings (for a batch).
        Output:
            array of shape (batch_size, 1000)
    """
    try:
        from nltk.corpus import wordnet as wn
    except ImportError:
        raise ImportError("You need to install nltk to use this function")

    if not isinstance(class_name_or_list, (list, tuple)):
        class_name_or_list = [class_name_or_list]
    else:
        batch_size = max(batch_size, len(class_name_or_list))

    classes = []
    for class_name in class_name_or_list:
        class_name = class_name.replace(" ", "_")

        original_synsets = wn.synsets(class_name)
        original_synsets = list(filter(lambda s: s.pos() == 'n', original_synsets))  # keep only nouns
        if not original_synsets:
            return None

        possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, original_synsets))
        if possible_synsets:
            classes.append(IMAGENET[possible_synsets[0].offset()])
        else:
            # try hypernyms and hyponyms
            possible_synsets = sum([s.hypernyms() + s.hyponyms() for s in original_synsets], [])
            possible_synsets = list(filter(lambda s: s.offset() in IMAGENET, possible_synsets))
            if possible_synsets:
                classes.append(IMAGENET[possible_synsets[0].offset()])

    return one_hot_from_int(classes, batch_size=batch_size)
|
214 |
+
|
215 |
+
|
216 |
+
IMAGENET = {1440764: 0, 1443537: 1, 1484850: 2, 1491361: 3, 1494475: 4, 1496331: 5, 1498041: 6, 1514668: 7, 1514859: 8, 1518878: 9, 1530575: 10, 1531178: 11, 1532829: 12, 1534433: 13, 1537544: 14, 1558993: 15, 1560419: 16, 1580077: 17, 1582220: 18, 1592084: 19, 1601694: 20, 1608432: 21, 1614925: 22, 1616318: 23, 1622779: 24, 1629819: 25, 1630670: 26, 1631663: 27, 1632458: 28, 1632777: 29, 1641577: 30, 1644373: 31, 1644900: 32, 1664065: 33, 1665541: 34, 1667114: 35, 1667778: 36, 1669191: 37, 1675722: 38, 1677366: 39, 1682714: 40, 1685808: 41, 1687978: 42, 1688243: 43, 1689811: 44, 1692333: 45, 1693334: 46, 1694178: 47, 1695060: 48, 1697457: 49, 1698640: 50, 1704323: 51, 1728572: 52, 1728920: 53, 1729322: 54, 1729977: 55, 1734418: 56, 1735189: 57, 1737021: 58, 1739381: 59, 1740131: 60, 1742172: 61, 1744401: 62, 1748264: 63, 1749939: 64, 1751748: 65, 1753488: 66, 1755581: 67, 1756291: 68, 1768244: 69, 1770081: 70, 1770393: 71, 1773157: 72, 1773549: 73, 1773797: 74, 1774384: 75, 1774750: 76, 1775062: 77, 1776313: 78, 1784675: 79, 1795545: 80, 1796340: 81, 1797886: 82, 1798484: 83, 1806143: 84, 1806567: 85, 1807496: 86, 1817953: 87, 1818515: 88, 1819313: 89, 1820546: 90, 1824575: 91, 1828970: 92, 1829413: 93, 1833805: 94, 1843065: 95, 1843383: 96, 1847000: 97, 1855032: 98, 1855672: 99, 1860187: 100, 1871265: 101, 1872401: 102, 1873310: 103, 1877812: 104, 1882714: 105, 1883070: 106, 1910747: 107, 1914609: 108, 1917289: 109, 1924916: 110, 1930112: 111, 1943899: 112, 1944390: 113, 1945685: 114, 1950731: 115, 1955084: 116, 1968897: 117, 1978287: 118, 1978455: 119, 1980166: 120, 1981276: 121, 1983481: 122, 1984695: 123, 1985128: 124, 1986214: 125, 1990800: 126, 2002556: 127, 2002724: 128, 2006656: 129, 2007558: 130, 2009229: 131, 2009912: 132, 2011460: 133, 2012849: 134, 2013706: 135, 2017213: 136, 2018207: 137, 2018795: 138, 2025239: 139, 2027492: 140, 2028035: 141, 2033041: 142, 2037110: 143, 2051845: 144, 2056570: 145, 2058221: 146, 2066245: 147, 2071294: 148, 2074367: 149, 2077923: 150, 2085620: 151, 2085782: 152, 2085936: 153, 2086079: 154, 2086240: 155, 2086646: 156, 2086910: 157, 2087046: 158, 2087394: 159, 2088094: 160, 2088238: 161, 2088364: 162, 2088466: 163, 2088632: 164, 2089078: 165, 2089867: 166, 2089973: 167, 2090379: 168, 2090622: 169, 2090721: 170, 2091032: 171, 2091134: 172, 2091244: 173, 2091467: 174, 2091635: 175, 2091831: 176, 2092002: 177, 2092339: 178, 2093256: 179, 2093428: 180, 2093647: 181, 2093754: 182, 2093859: 183, 2093991: 184, 2094114: 185, 2094258: 186, 2094433: 187, 2095314: 188, 2095570: 189, 2095889: 190, 2096051: 191, 2096177: 192, 2096294: 193, 2096437: 194, 2096585: 195, 2097047: 196, 2097130: 197, 2097209: 198, 2097298: 199, 2097474: 200, 2097658: 201, 2098105: 202, 2098286: 203, 2098413: 204, 2099267: 205, 2099429: 206, 2099601: 207, 2099712: 208, 2099849: 209, 2100236: 210, 2100583: 211, 2100735: 212, 2100877: 213, 2101006: 214, 2101388: 215, 2101556: 216, 2102040: 217, 2102177: 218, 2102318: 219, 2102480: 220, 2102973: 221, 2104029: 222, 2104365: 223, 2105056: 224, 2105162: 225, 2105251: 226, 2105412: 227, 2105505: 228, 2105641: 229, 2105855: 230, 2106030: 231, 2106166: 232, 2106382: 233, 2106550: 234, 2106662: 235, 2107142: 236, 2107312: 237, 2107574: 238, 2107683: 239, 2107908: 240, 2108000: 241, 2108089: 242, 2108422: 243, 2108551: 244, 2108915: 245, 2109047: 246, 2109525: 247, 2109961: 248, 2110063: 249, 2110185: 250, 2110341: 251, 2110627: 252, 2110806: 253, 2110958: 254, 2111129: 255, 2111277: 256, 2111500: 257, 2111889: 258, 2112018: 259, 2112137: 
260, 2112350: 261, 2112706: 262, 2113023: 263, 2113186: 264, 2113624: 265, 2113712: 266, 2113799: 267, 2113978: 268, 2114367: 269, 2114548: 270, 2114712: 271, 2114855: 272, 2115641: 273, 2115913: 274, 2116738: 275, 2117135: 276, 2119022: 277, 2119789: 278, 2120079: 279, 2120505: 280, 2123045: 281, 2123159: 282, 2123394: 283, 2123597: 284, 2124075: 285, 2125311: 286, 2127052: 287, 2128385: 288, 2128757: 289, 2128925: 290, 2129165: 291, 2129604: 292, 2130308: 293, 2132136: 294, 2133161: 295, 2134084: 296, 2134418: 297, 2137549: 298, 2138441: 299, 2165105: 300, 2165456: 301, 2167151: 302, 2168699: 303, 2169497: 304, 2172182: 305, 2174001: 306, 2177972: 307, 2190166: 308, 2206856: 309, 2219486: 310, 2226429: 311, 2229544: 312, 2231487: 313, 2233338: 314, 2236044: 315, 2256656: 316, 2259212: 317, 2264363: 318, 2268443: 319, 2268853: 320, 2276258: 321, 2277742: 322, 2279972: 323, 2280649: 324, 2281406: 325, 2281787: 326, 2317335: 327, 2319095: 328, 2321529: 329, 2325366: 330, 2326432: 331, 2328150: 332, 2342885: 333, 2346627: 334, 2356798: 335, 2361337: 336, 2363005: 337, 2364673: 338, 2389026: 339, 2391049: 340, 2395406: 341, 2396427: 342, 2397096: 343, 2398521: 344, 2403003: 345, 2408429: 346, 2410509: 347, 2412080: 348, 2415577: 349, 2417914: 350, 2422106: 351, 2422699: 352, 2423022: 353, 2437312: 354, 2437616: 355, 2441942: 356, 2442845: 357, 2443114: 358, 2443484: 359, 2444819: 360, 2445715: 361, 2447366: 362, 2454379: 363, 2457408: 364, 2480495: 365, 2480855: 366, 2481823: 367, 2483362: 368, 2483708: 369, 2484975: 370, 2486261: 371, 2486410: 372, 2487347: 373, 2488291: 374, 2488702: 375, 2489166: 376, 2490219: 377, 2492035: 378, 2492660: 379, 2493509: 380, 2493793: 381, 2494079: 382, 2497673: 383, 2500267: 384, 2504013: 385, 2504458: 386, 2509815: 387, 2510455: 388, 2514041: 389, 2526121: 390, 2536864: 391, 2606052: 392, 2607072: 393, 2640242: 394, 2641379: 395, 2643566: 396, 2655020: 397, 2666196: 398, 2667093: 399, 2669723: 400, 2672831: 401, 2676566: 402, 2687172: 403, 2690373: 404, 2692877: 405, 2699494: 406, 2701002: 407, 2704792: 408, 2708093: 409, 2727426: 410, 2730930: 411, 2747177: 412, 2749479: 413, 2769748: 414, 2776631: 415, 2777292: 416, 2782093: 417, 2783161: 418, 2786058: 419, 2787622: 420, 2788148: 421, 2790996: 422, 2791124: 423, 2791270: 424, 2793495: 425, 2794156: 426, 2795169: 427, 2797295: 428, 2799071: 429, 2802426: 430, 2804414: 431, 2804610: 432, 2807133: 433, 2808304: 434, 2808440: 435, 2814533: 436, 2814860: 437, 2815834: 438, 2817516: 439, 2823428: 440, 2823750: 441, 2825657: 442, 2834397: 443, 2835271: 444, 2837789: 445, 2840245: 446, 2841315: 447, 2843684: 448, 2859443: 449, 2860847: 450, 2865351: 451, 2869837: 452, 2870880: 453, 2871525: 454, 2877765: 455, 2879718: 456, 2883205: 457, 2892201: 458, 2892767: 459, 2894605: 460, 2895154: 461, 2906734: 462, 2909870: 463, 2910353: 464, 2916936: 465, 2917067: 466, 2927161: 467, 2930766: 468, 2939185: 469, 2948072: 470, 2950826: 471, 2951358: 472, 2951585: 473, 2963159: 474, 2965783: 475, 2966193: 476, 2966687: 477, 2971356: 478, 2974003: 479, 2977058: 480, 2978881: 481, 2979186: 482, 2980441: 483, 2981792: 484, 2988304: 485, 2992211: 486, 2992529: 487, 2999410: 488, 3000134: 489, 3000247: 490, 3000684: 491, 3014705: 492, 3016953: 493, 3017168: 494, 3018349: 495, 3026506: 496, 3028079: 497, 3032252: 498, 3041632: 499, 3042490: 500, 3045698: 501, 3047690: 502, 3062245: 503, 3063599: 504, 3063689: 505, 3065424: 506, 3075370: 507, 3085013: 508, 3089624: 509, 3095699: 510, 3100240: 511, 3109150: 512, 3110669: 513, 
3124043: 514, 3124170: 515, 3125729: 516, 3126707: 517, 3127747: 518, 3127925: 519, 3131574: 520, 3133878: 521, 3134739: 522, 3141823: 523, 3146219: 524, 3160309: 525, 3179701: 526, 3180011: 527, 3187595: 528, 3188531: 529, 3196217: 530, 3197337: 531, 3201208: 532, 3207743: 533, 3207941: 534, 3208938: 535, 3216828: 536, 3218198: 537, 3220513: 538, 3223299: 539, 3240683: 540, 3249569: 541, 3250847: 542, 3255030: 543, 3259280: 544, 3271574: 545, 3272010: 546, 3272562: 547, 3290653: 548, 3291819: 549, 3297495: 550, 3314780: 551, 3325584: 552, 3337140: 553, 3344393: 554, 3345487: 555, 3347037: 556, 3355925: 557, 3372029: 558, 3376595: 559, 3379051: 560, 3384352: 561, 3388043: 562, 3388183: 563, 3388549: 564, 3393912: 565, 3394916: 566, 3400231: 567, 3404251: 568, 3417042: 569, 3424325: 570, 3425413: 571, 3443371: 572, 3444034: 573, 3445777: 574, 3445924: 575, 3447447: 576, 3447721: 577, 3450230: 578, 3452741: 579, 3457902: 580, 3459775: 581, 3461385: 582, 3467068: 583, 3476684: 584, 3476991: 585, 3478589: 586, 3481172: 587, 3482405: 588, 3483316: 589, 3485407: 590, 3485794: 591, 3492542: 592, 3494278: 593, 3495258: 594, 3496892: 595, 3498962: 596, 3527444: 597, 3529860: 598, 3530642: 599, 3532672: 600, 3534580: 601, 3535780: 602, 3538406: 603, 3544143: 604, 3584254: 605, 3584829: 606, 3590841: 607, 3594734: 608, 3594945: 609, 3595614: 610, 3598930: 611, 3599486: 612, 3602883: 613, 3617480: 614, 3623198: 615, 3627232: 616, 3630383: 617, 3633091: 618, 3637318: 619, 3642806: 620, 3649909: 621, 3657121: 622, 3658185: 623, 3661043: 624, 3662601: 625, 3666591: 626, 3670208: 627, 3673027: 628, 3676483: 629, 3680355: 630, 3690938: 631, 3691459: 632, 3692522: 633, 3697007: 634, 3706229: 635, 3709823: 636, 3710193: 637, 3710637: 638, 3710721: 639, 3717622: 640, 3720891: 641, 3721384: 642, 3724870: 643, 3729826: 644, 3733131: 645, 3733281: 646, 3733805: 647, 3742115: 648, 3743016: 649, 3759954: 650, 3761084: 651, 3763968: 652, 3764736: 653, 3769881: 654, 3770439: 655, 3770679: 656, 3773504: 657, 3775071: 658, 3775546: 659, 3776460: 660, 3777568: 661, 3777754: 662, 3781244: 663, 3782006: 664, 3785016: 665, 3786901: 666, 3787032: 667, 3788195: 668, 3788365: 669, 3791053: 670, 3792782: 671, 3792972: 672, 3793489: 673, 3794056: 674, 3796401: 675, 3803284: 676, 3804744: 677, 3814639: 678, 3814906: 679, 3825788: 680, 3832673: 681, 3837869: 682, 3838899: 683, 3840681: 684, 3841143: 685, 3843555: 686, 3854065: 687, 3857828: 688, 3866082: 689, 3868242: 690, 3868863: 691, 3871628: 692, 3873416: 693, 3874293: 694, 3874599: 695, 3876231: 696, 3877472: 697, 3877845: 698, 3884397: 699, 3887697: 700, 3888257: 701, 3888605: 702, 3891251: 703, 3891332: 704, 3895866: 705, 3899768: 706, 3902125: 707, 3903868: 708, 3908618: 709, 3908714: 710, 3916031: 711, 3920288: 712, 3924679: 713, 3929660: 714, 3929855: 715, 3930313: 716, 3930630: 717, 3933933: 718, 3935335: 719, 3937543: 720, 3938244: 721, 3942813: 722, 3944341: 723, 3947888: 724, 3950228: 725, 3954731: 726, 3956157: 727, 3958227: 728, 3961711: 729, 3967562: 730, 3970156: 731, 3976467: 732, 3976657: 733, 3977966: 734, 3980874: 735, 3982430: 736, 3983396: 737, 3991062: 738, 3992509: 739, 3995372: 740, 3998194: 741, 4004767: 742, 4005630: 743, 4008634: 744, 4009552: 745, 4019541: 746, 4023962: 747, 4026417: 748, 4033901: 749, 4033995: 750, 4037443: 751, 4039381: 752, 4040759: 753, 4041544: 754, 4044716: 755, 4049303: 756, 4065272: 757, 4067472: 758, 4069434: 759, 4070727: 760, 4074963: 761, 4081281: 762, 4086273: 763, 4090263: 764, 4099969: 765, 4111531: 766, 4116512: 
767, 4118538: 768, 4118776: 769, 4120489: 770, 4125021: 771, 4127249: 772, 4131690: 773, 4133789: 774, 4136333: 775, 4141076: 776, 4141327: 777, 4141975: 778, 4146614: 779, 4147183: 780, 4149813: 781, 4152593: 782, 4153751: 783, 4154565: 784, 4162706: 785, 4179913: 786, 4192698: 787, 4200800: 788, 4201297: 789, 4204238: 790, 4204347: 791, 4208210: 792, 4209133: 793, 4209239: 794, 4228054: 795, 4229816: 796, 4235860: 797, 4238763: 798, 4239074: 799, 4243546: 800, 4251144: 801, 4252077: 802, 4252225: 803, 4254120: 804, 4254680: 805, 4254777: 806, 4258138: 807, 4259630: 808, 4263257: 809, 4264628: 810, 4265275: 811, 4266014: 812, 4270147: 813, 4273569: 814, 4275548: 815, 4277352: 816, 4285008: 817, 4286575: 818, 4296562: 819, 4310018: 820, 4311004: 821, 4311174: 822, 4317175: 823, 4325704: 824, 4326547: 825, 4328186: 826, 4330267: 827, 4332243: 828, 4335435: 829, 4336792: 830, 4344873: 831, 4346328: 832, 4347754: 833, 4350905: 834, 4355338: 835, 4355933: 836, 4356056: 837, 4357314: 838, 4366367: 839, 4367480: 840, 4370456: 841, 4371430: 842, 4371774: 843, 4372370: 844, 4376876: 845, 4380533: 846, 4389033: 847, 4392985: 848, 4398044: 849, 4399382: 850, 4404412: 851, 4409515: 852, 4417672: 853, 4418357: 854, 4423845: 855, 4428191: 856, 4429376: 857, 4435653: 858, 4442312: 859, 4443257: 860, 4447861: 861, 4456115: 862, 4458633: 863, 4461696: 864, 4462240: 865, 4465501: 866, 4467665: 867, 4476259: 868, 4479046: 869, 4482393: 870, 4483307: 871, 4485082: 872, 4486054: 873, 4487081: 874, 4487394: 875, 4493381: 876, 4501370: 877, 4505470: 878, 4507155: 879, 4509417: 880, 4515003: 881, 4517823: 882, 4522168: 883, 4523525: 884, 4525038: 885, 4525305: 886, 4532106: 887, 4532670: 888, 4536866: 889, 4540053: 890, 4542943: 891, 4548280: 892, 4548362: 893, 4550184: 894, 4552348: 895, 4553703: 896, 4554684: 897, 4557648: 898, 4560804: 899, 4562935: 900, 4579145: 901, 4579432: 902, 4584207: 903, 4589890: 904, 4590129: 905, 4591157: 906, 4591713: 907, 4592741: 908, 4596742: 909, 4597913: 910, 4599235: 911, 4604644: 912, 4606251: 913, 4612504: 914, 4613696: 915, 6359193: 916, 6596364: 917, 6785654: 918, 6794110: 919, 6874185: 920, 7248320: 921, 7565083: 922, 7579787: 923, 7583066: 924, 7584110: 925, 7590611: 926, 7613480: 927, 7614500: 928, 7615774: 929, 7684084: 930, 7693725: 931, 7695742: 932, 7697313: 933, 7697537: 934, 7711569: 935, 7714571: 936, 7714990: 937, 7715103: 938, 7716358: 939, 7716906: 940, 7717410: 941, 7717556: 942, 7718472: 943, 7718747: 944, 7720875: 945, 7730033: 946, 7734744: 947, 7742313: 948, 7745940: 949, 7747607: 950, 7749582: 951, 7753113: 952, 7753275: 953, 7753592: 954, 7754684: 955, 7760859: 956, 7768694: 957, 7802026: 958, 7831146: 959, 7836838: 960, 7860988: 961, 7871810: 962, 7873807: 963, 7875152: 964, 7880968: 965, 7892512: 966, 7920052: 967, 7930864: 968, 7932039: 969, 9193705: 970, 9229709: 971, 9246464: 972, 9256479: 973, 9288635: 974, 9332890: 975, 9399592: 976, 9421951: 977, 9428293: 978, 9468604: 979, 9472597: 980, 9835506: 981, 10148035: 982, 10565667: 983, 11879895: 984, 11939491: 985, 12057211: 986, 12144580: 987, 12267677: 988, 12620546: 989, 12768682: 990, 12985857: 991, 12998815: 992, 13037406: 993, 13040303: 994, 13044778: 995, 13052670: 996, 13054560: 997, 13133613: 998, 15075141: 999}
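
For orientation, here is a minimal end-to-end sketch of how these helpers combine. This is a sketch under assumptions: the `biggan-deep-256` weight name, its availability through `from_pretrained`, and the output filenames are not guaranteed by this file alone.

```python
# Minimal sampling sketch: one-hot class vector -> truncated noise -> images.
import torch
from networks.biggan import (BigGAN, one_hot_from_names,
                             truncated_noise_sample, save_as_images)

model = BigGAN.from_pretrained('biggan-deep-256')  # assumed model name
class_vector = torch.from_numpy(one_hot_from_names('daisy', batch_size=2))
noise_vector = torch.from_numpy(truncated_noise_sample(truncation=0.4, batch_size=2))

with torch.no_grad():
    output = model(noise_vector, class_vector, truncation=0.4)

save_as_images(output, file_name='daisy')  # writes daisy_0.png, daisy_1.png
```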
networks/genforce/.gitignore
ADDED
@@ -0,0 +1,29 @@
__pycache__/
*.py[cod]

/.vscode/
/.idea/
*.sw[pon]

/data/
/work_dirs/
*.jpg
*.png
*.jpeg
*.gif
*.avi
*.mp4

*.npy
*.txt
*.json
*.log
*.html
*.tar
*.zip
events.*

*.pth
*.pkl
*.h5
*.dat
networks/genforce/LICENSE
ADDED
@@ -0,0 +1,18 @@
Copyright (c) 2020 GenForce

Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
of the Software, and to permit persons to whom the Software is furnished to do
so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
networks/genforce/MODEL_ZOO.md
ADDED
@@ -0,0 +1,131 @@
# Model Zoo

## Pre-trained Models

First of all, we thank the following repositories for their work on high-quality image synthesis:

- [PGGAN](https://github.com/tkarras/progressive_growing_of_gans)
- [StyleGAN](https://github.com/NVlabs/stylegan)
- [StyleGAN2](https://github.com/NVlabs/stylegan2)

Please download the models you need and save them to `checkpoints/`.
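
For convenience, a minimal sketch of fetching one of the checkpoints below into `checkpoints/`; the URL is the StyleGAN `ffhq-1024x1024` link from the tables (they already carry `&download=1`), and the local filename is an assumption matching the configs in this repo.

```python
# Minimal sketch: download a pre-trained checkpoint into `checkpoints/`.
import os
import urllib.request

URL = ('https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/'
       'EdfMxgb0hU9BoXwiR3dqYDEBowCSEF1IcsW3n4kwfoZ9OQ?e=VwIV58&download=1')

os.makedirs('checkpoints', exist_ok=True)
urllib.request.urlretrieve(URL, 'checkpoints/stylegan_ffhq1024.pth')
```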

| PGGAN Official | | | |
| :-- | :-- | :-- | :-- |
| *Face*
| [celebahq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EW_3jQ6E7xlKvCSHYrbmkQQBAB8tgIv5W5evdT6-GuXiWw?e=gRifVa&download=1)
| *Indoor Scene*
| [bedroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EUZQWGz2GT5Bh_GJLalP63IBvCsXDTOxDFIC_ZBsmoEacA?e=VNXiDb&download=1) | [livingroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Efzh6qQv6QtCm0YN1lulH-YByqdE3AqlI-E6US_hXMuiig?e=ppdyB2&download=1) | [diningroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcLb3_hGUkdClompZo27xk0BNmotgbFqdIeu-ZOGJsBMRg?e=xjYpN3&download=1) | [kitchen-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ESCyg6hpNn1LlHVX_un1wLsBZAORUNkW9MO2kU1X5kafAQ?e=09TbGC&download=1)
| *Outdoor Scene*
| [churchoutdoor-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EQ8cKujs2TVGjCL_j6bsnk8BqD9REF2ME2lBnpbTPsqIvA?e=zH55fT&download=1) | [tower-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EeyBJvgRVGJClKr1KKYDF_cBT1FDepRU1-GLqYNh8W9-fQ?e=nrpa5N&download=1) | [bridge-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EZ2QScfPy19PiDERLJQ3gPMBP4WmvZHwhNFLzfaP2YD8hQ?e=bef1U9&download=1)
| *Other Scene*
| [restaurant-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERvJ4pz8jgtMrcuJXUfcOQEBDugZ099_TetCQs-9-ILCVg?e=qYsVdQ&download=1) | [classroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EUU9SCOPUxhMoUS4Ceo9kl0BQkVK7d69lA-JeOP-zOWvXw?e=YIB4no&download=1) | [conferenceroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EX8AF0_6NoJAl5vKFewHWnsBk0r4PK4WsqsMrJyj84TrqQ?e=oNQIZS&download=1)
| *Animal*
| [person-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWu4SqR42YpCoqsVJOcM2cMBcdfXA0j5wZ2hno9X0R9ydQ?e=KuDRns&download=1) | [cat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EQdveyUNOMtAue52n6BxoHoB6Yup5-PTvBDmyfUn7Un4Hw?e=7acGbT&download=1) | [dog-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ESaKyXA5fGlOvXJYDDFbT2kB9c0HlXh9n_wnyhiP05nhow?e=d4aKDV&download=1) | [bird-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ef2p4Pd3AKVCmSm00YikCIABhylh2dLPaFjPfPVn3RiTXA?e=9bRitp&download=1)
| [horse-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EXwCPdv6XqJFtuvFFoswRScBmLJbhKzaC5D_iovl1GFOTw?e=WDdD77&download=1) | [sheep-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ER6J5EKjAUNFtm9VwLf-uUsBZ5dnqxeKsPxY9ijiPtMhcQ?e=OKtfva&download=1) | [cow-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERZLxw7N7xJPm72FyePTbpcByzrr0pH-Fg7qyLt5tYGXwQ?e=ovIPCl&download=1)
| *Transportation*
| [car-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EfGc2we47aFDtAY1548pRvsByIju-uXRbkZEFpJotuPKZw?e=DQqVj8&download=1) | [bicycle-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ed1dN_FgwmdBgeNWhaRUry8BgwT88-n2ppicSDPx-f7f_Q?e=bxTxnf&download=1) | [motorbike-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EV3yQdeJXIdPjZbMO0mp2-MBJbKuuBdypzBL4gnedO57Dw?e=tXdvtD&download=1) | [bus-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ed7-OYLnq0RCqRlM8qK8wZ8B87dz_NUxIKBrvyFUwRCEbg?e=VP5bmX&download=1)
| [train-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EedE2cozKOVAkhvbdLd4SfwBknFW8vWZnKiqgeIBbAvCCA?e=BrLpTl&download=1) | [boat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Eb39waqQFr9Bp4wO0rC5NHwB0Vz2NGCuqbRPucguBIkDrg?e=lddSyL&download=1) | [airplane-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ee6FzIx3KjNDhxrS5mDvpCEB3iQ7TgErmKhbwbV-eF07iw?e=xflPXa&download=1)
| *Furniture*
| [bottle-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWhoy2AFCTZGtEG1UoayWjcB9Kdc_wreJ8p4RlBB93nbNg?e=DMZceU&download=1) | [chair-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EbQRTfwdostBhXG30Uacn7ABsEUFa-tEW3oxiM5zDYQbRw?e=FkB7T0&download=1) | [pottedplant-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWg7hnoGATBOuJvXWr4m7CQBJL9o7nqnD6nOMRhtH2SKXg?e=Zi3hjD&download=1) | [tvmonitor-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVXwttoJVtBMuhHNDdK3cMwBdMiZARJV38PMTsL6whnFlA?e=RbG0ru&download=1)
| [diningtable-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EXVzBkbmTCVImMtuHLCTBeMBXZmv0RWyx5KXQQAe7-7D5w?e=6RYSnm&download=1) | [sofa-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EaADQYDXwY9NrzbiUFcRYRgBOu1GdJMG8YgNZZmbNjbn-Q?e=DqKrXG&download=1)

| StyleGAN Official | | | |
| :-- | :--: | :--: | :--: |
| Model (Dataset) | Training Samples | Training Duration (K Images) | FID |
| [ffhq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EdfMxgb0hU9BoXwiR3dqYDEBowCSEF1IcsW3n4kwfoZ9OQ?e=VwIV58&download=1) | 70,000 | 25,000 | 4.40 |
| [celebahq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcCdXHddE7FOvyfmqeOyc9ABqVuWh8PQYFnV6JM1CXvFig?e=1nUYZ5&download=1) | 30,000 | 25,000 | 5.06 |
| [bedroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ea6RBPddjcRNoFMXm8AyEBcBUHdlRNtjtclNKFe89amjBw?e=Og8Vff&download=1) | 3,033,042 | 70,000 | 2.65 |
| [cat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EVjX8u9HuehLip3z0hRfIHcB7QtoFkTB7NiRDb8nrKOl2w?e=lHcp1B&download=1) | 1,657,266 | 70,000 | 8.53 |
| [car-512x384](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EcRJNNzzUzJGjI2X53S9HjkBhXkKT5JRd6Q3IIhCY1AyRw?e=FvMRNj&download=1) | 5,520,756 | 46,000 | 3.27 |

| StyleGAN Ours | | | |
| :-- | :--: | :--: | :--: |
| Model (Dataset) | Training Samples | Training Duration (K Images) | FID |
| *Face ("partial" means faces are not fully aligned to center)*
| [celeba_partial-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ET2etKNzMS9JmHj5j60fqMcBRJfQfYNvqUrujaIXxCvKDQ?e=QReLE6&download=1) | 103,706 | 50,000 | 7.03 |
| [ffhq-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ES-NAUCC2qdHg87BftvlBiQBVpbJ8-005Q4TNr5KrOxQEw?e=00AnWt&download=1) | 70,000 | 25,000 | 5.70 |
| [ffhq-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EZYrrwOiEgVOg-PfGv7QTegBzFQ9yq2v7o1WxNq5JJ9KNA?e=SZU8PI&download=1) | 70,000 | 25,000 | 5.15 |
| *LSUN Indoor Scene*
| [livingroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EfFCYLHjqbFDmjOvCCFJgDcBZ1QYgETfZJxp4ZTHjLxZBg?e=InVd0n&download=1) | 1,315,802 | 30,000 | 5.16 |
| [diningroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERsUza_hSFRIm4iZCag7P0kBQ9EIdfQKByw4QYt_ay97lg?e=Cimh7S&download=1) | 657,571 | 25,000 | 4.13 |
| [kitchen-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ERcYvoingQNKix35lUs0vUkBQQkAZMp1rtDxjwNlOJAoaA?e=a1Tcwr&download=1) | 1,000,000 | 30,000 | 5.06 |
| *LSUN Indoor Scene Mixture*
| [apartment-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EfurPNSB2BRFtXdqGkmDD6YBwyKN8YK2v7nKwnJQdsbf6A?e=w3oYa4&download=1) | 4 * 200,000 | 60,000 | 4.18 |
| *LSUN Outdoor Scene*
| [church-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ETMgG1_d06tAlbUkJD1qA9IBaLZ9zJKPkG2kO-4jxhVV5w?e=Dbkb7o&download=1) | 126,227 | 30,000 | 4.82 |
| [tower-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ebm9QMgqB2VDqyIE5rFhreEBgZ_RyKcRf8bQ333K453u3w?e=if8sDj&download=1) | 708,264 | 30,000 | 5.99 |
| [bridge-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Ed9QM6OP9sVHnazSp4cqPSEBb-ALfBPXRxP1hD7FsTYh8w?e=3vv06p&download=1) | 818,687 | 25,000 | 6.42 |
| *LSUN Other Scene*
| [restaurant-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/ESDhYr01WtlEvBNFrVpFezcB2l9lF1rBYuHFoeNpBr5B7A?e=uFWFNh&download=1) | 626,331 | 50,000 | 4.03 |
| [classroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EbWnI3oto9NPk-lxwZlWqPQB2atWpGiTWMIT59MzF9ij9Q?e=KvcNBg&download=1) | 168,103 | 50,000 | 10.10 |
| [conferenceroom-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Eb1gVi3pGa9PgJ4XYYu_6yABQZ0ZcGDak4FEHaTHaeYFzw?e=0BeE8t&download=1) | 229,069 | 50,000 | 6.20 |

| StyleGAN Third-Party | |
| :-- | :--: |
| Model (Dataset) | Source |
| [animeface-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EWDWflY6lBpGgX0CGQpd2Z4B5wTEVamTOA9JRYne7zdCvA?e=tOzgYA&download=1) | [link](https://www.gwern.net/Faces#portrait-results)
| [animeportrait-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EXBvhTBi-v5NsnQtrxhFEKsBin4xg-Dud9Jr62AEwFTIxg?e=bMGK7r&download=1) | [link](https://www.gwern.net/Faces#portrait-results)
| [artface-512x512](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/Eca0OiGqhyZMmoPbKahSBWQBWvcAH4q2CE3zdZJflp2jkQ?e=h4rWAm&download=1) | [link](https://github.com/ak9250/stylegan-art)

| StyleGAN2 Official | | | |
| :-- | :--: | :--: | :--: |
| Model (Dataset) | Training Samples | Training Duration (K Images) | FID |
| [ffhq-1024x1024](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EX0DNWiBvl5FuOQTF4oMPBYBNSalcxTK0AbLwBn9Y3vfgg?e=Q0sZit&download=1) | 70,000 | 25,000 | 2.84 |
| [church-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EQzDtJUdQ4ROunMGn2sZouEBmNeFX4QWvxjermVE5cZvNA?e=tQ7r9r&download=1) | 126,227 | 48,000 | 3.86 |
| [cat-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EUKXeBwUUbZJr6kup7PW4ekBx2-vmTp8FjcGb10v8bgJxQ?e=nkerMF&download=1) | 1,657,266 | 88,000 | 6.93 |
| [horse-256x256](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EconoT6tb69OuAIqfXRtGlsBZz4vBx01UmmFO-JAS356Jg?e=bcSCC4&download=1) | 2,000,340 | 100,000 | 3.43 |
| [car-512x384](https://mycuhk-my.sharepoint.com/:u:/g/personal/1155082926_link_cuhk_edu_hk/EYSnUsxU8KJFuMHhZm-JLWoB0nHxdlbrLHNZ_Qkoe3b9LA?e=Ycjp5A&download=1) | 5,520,756 | 57,000 | 2.32 |

## Training Datasets

- [MNIST](http://yann.lecun.com/exdb/mnist/) (60,000 training samples and 10,000 test samples on 10 digits)
- [SVHN](http://ufldl.stanford.edu/housenumbers/) (73,257 training samples, 26,032 testing samples, and 531,131 additional samples on 10 digits)
- [CIFAR10](https://www.cs.toronto.edu/~kriz/cifar.html) (50,000 training samples and 10,000 test samples on 10 classes)
- [CIFAR100](https://www.cs.toronto.edu/~kriz/cifar.html) (50,000 training samples and 10,000 test samples on 100 classes)
- [ImageNet](http://www.image-net.org/) (1,281,167 training samples, 50,000 validation samples, and 100,100 testing samples on 1000 classes)
- [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) (202,599 samples from 10,177 identities, with 5 landmarks and 40 binary facial attributes)
- [CelebA-HQ](https://github.com/tkarras/progressive_growing_of_gans) (30,000 samples)
- [FF-HQ](https://github.com/NVlabs/ffhq-dataset) (70,000 samples)
- [LSUN](https://github.com/fyu/lsun) (see statistical information below)
- [Places](http://places2.csail.mit.edu/) (around 1.8M training samples covering 365 classes)
- [Cityscapes](https://www.cityscapes-dataset.com/) (2,975 training samples, 19,998 extra training samples (one broken), 500 validation samples, and 1,525 test samples)
- [Streetscapes](http://streetscore.media.mit.edu/data.html)

Statistical information of the [LSUN](https://github.com/fyu/lsun) dataset is summarized as follows:

| LSUN Datasets Stats | | |
| :-- | :--: | :--: |
| Name | Number of Samples | Size |
| *Scenes*
| bedroom (train) | 3,033,042 | 43G |
| bridge (train) | 818,687 | 15G |
| churchoutdoor (train) | 126,227 | 2G |
| classroom (train) | 168,103 | 3G |
| conferenceroom (train) | 229,069 | 4G |
| diningroom (train) | 657,571 | 11G |
| kitchen (train) | 2,212,277 | 33G |
| livingroom (train) | 1,315,802 | 21G |
| restaurant (train) | 626,331 | 13G |
| tower (train) | 708,264 | 11G |
| *Objects*
| airplane | 1,530,696 | 34G |
| bicycle | 3,347,211 | 129G |
| bird | 2,310,362 | 65G |
| boat | 2,651,165 | 86G |
| bottle | 3,202,760 | 64G |
| bus | 695,891 | 24G |
| car | 5,520,756 | 173G |
| cat | 1,657,266 | 42G |
| chair | 5,037,807 | 116G |
| cow | 377,379 | 15G |
| diningtable | 1,537,123 | 48G |
| dog | 5,054,817 | 145G |
| horse | 2,000,340 | 69G |
| motorbike | 1,194,101 | 42G |
| person | 18,890,816 | 477G |
| pottedplant | 1,104,859 | 43G |
| sheep | 418,983 | 18G |
| sofa | 2,365,870 | 56G |
| train | 1,148,020 | 43G |
| tvmonitor | 2,463,284 | 46G |
networks/genforce/README.md
ADDED
@@ -0,0 +1,169 @@
# GenForce Lib for Generative Modeling

An efficient PyTorch library for deep generative modeling. May the Generative Force (GenForce) be with You.

![image](./teaser.gif)

## Updates

- **Encoder Training:** We support training encoders on top of pre-trained GANs for GAN inversion.
- **Model Converters:** You can easily migrate already-started projects to this repository. Please check [here](./converters/README.md) for more details.

## Highlights

- **Distributed** training framework.
- **Fast** training speed.
- **Modular** design for prototyping new models.
- **Model zoo** containing a rich set of pretrained GAN models, with a [Colab live demo](https://colab.research.google.com/github/genforce/genforce/blob/master/docs/synthesize_demo.ipynb) to play with.

## Installation

1. Create a virtual environment via `conda`.

   ```shell
   conda create -n genforce python=3.7
   conda activate genforce
   ```

2. Install `cuda` and `cudnn`. (We use `CUDA 10.0` in case you would like to use `TensorFlow 1.15` for model conversion.)

   ```shell
   conda install cudatoolkit=10.0 cudnn=7.6.5
   ```

3. Install `torch` and `torchvision`.

   ```shell
   pip install torch==1.7 torchvision==0.8
   ```

4. Install the requirements.

   ```shell
   pip install -r requirements.txt
   ```

## Quick Demo

We provide a quick training demo, `scripts/stylegan_training_demo.py`, which allows you to train StyleGAN on a toy dataset (500 animeface images at 64 x 64 resolution). Try it via

```shell
./scripts/stylegan_training_demo.sh
```

We also provide an inference demo, `synthesize.py`, which allows you to synthesize images with pre-trained models. Generated images can be found at `work_dirs/synthesis_results/`. Try it via

```shell
python synthesize.py stylegan_ffhq1024
```

You can also play with the demo on [Colab](https://colab.research.google.com/github/genforce/genforce/blob/master/docs/synthesize_demo.ipynb).

## Play with GANs

### Test

Pre-trained models can be found in the [model zoo](MODEL_ZOO.md).

- On a local machine:

  ```shell
  GPUS=8
  CONFIG=configs/stylegan_ffhq256_val.py
  WORK_DIR=work_dirs/stylegan_ffhq256_val
  CHECKPOINT=checkpoints/stylegan_ffhq256.pth
  ./scripts/dist_test.sh ${GPUS} ${CONFIG} ${WORK_DIR} ${CHECKPOINT}
  ```

- Using `slurm`:

  ```shell
  CONFIG=configs/stylegan_ffhq256_val.py
  WORK_DIR=work_dirs/stylegan_ffhq256_val
  CHECKPOINT=checkpoints/stylegan_ffhq256.pth
  GPUS=8 ./scripts/slurm_test.sh ${PARTITION} ${JOB_NAME} \
      ${CONFIG} ${WORK_DIR} ${CHECKPOINT}
  ```

### Train

All files produced during training, such as log messages, checkpoints, and synthesis snapshots, will be saved to the work directory.

- On a local machine:

  ```shell
  GPUS=8
  CONFIG=configs/stylegan_ffhq256.py
  WORK_DIR=work_dirs/stylegan_ffhq256_train
  ./scripts/dist_train.sh ${GPUS} ${CONFIG} ${WORK_DIR} \
      [--options additional_arguments]
  ```

- Using `slurm`:

  ```shell
  CONFIG=configs/stylegan_ffhq256.py
  WORK_DIR=work_dirs/stylegan_ffhq256_train
  GPUS=8 ./scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} \
      ${CONFIG} ${WORK_DIR} \
      [--options additional_arguments]
  ```

## Play with Encoders for GAN Inversion

### Train

- On a local machine:

  ```shell
  GPUS=8
  CONFIG=configs/stylegan_ffhq256_encoder_y.py
  WORK_DIR=work_dirs/stylegan_ffhq256_encoder_y
  ./scripts/dist_train.sh ${GPUS} ${CONFIG} ${WORK_DIR} \
      [--options additional_arguments]
  ```

- Using `slurm`:

  ```shell
  CONFIG=configs/stylegan_ffhq256_encoder_y.py
  WORK_DIR=work_dirs/stylegan_ffhq256_encoder_y
  GPUS=8 ./scripts/slurm_train.sh ${PARTITION} ${JOB_NAME} \
      ${CONFIG} ${WORK_DIR} \
      [--options additional_arguments]
  ```

## Contributors

| Member | Module |
| :-- | :-- |
| [Yujun Shen](http://shenyujun.github.io/) | models and running controllers
| [Yinghao Xu](https://justimyhxu.github.io/) | runner and loss functions
| [Ceyuan Yang](http://ceyuan.me/) | data loader
| [Jiapeng Zhu](https://zhujiapeng.github.io/) | evaluation metrics
| [Bolei Zhou](http://bzhou.ie.cuhk.edu.hk/) | cheerleader

**NOTE:** The above table only lists the person in charge of each module. We help each other a lot and develop as a **TEAM**.

*We welcome external contributors to join us in improving this library.*

## License

The project is under the [MIT License](./LICENSE).

## Acknowledgement

We thank [PGGAN](https://github.com/tkarras/progressive_growing_of_gans), [StyleGAN](https://github.com/NVlabs/stylegan), [StyleGAN2](https://github.com/NVlabs/stylegan2), and [StyleGAN2-ADA](https://github.com/NVlabs/stylegan2-ada) for their work on high-quality image synthesis. We thank [IDInvert](https://github.com/genforce/idinvert) and [GHFeat](https://github.com/genforce/ghfeat) for their contribution to GAN inversion. We also thank [MMCV](https://github.com/open-mmlab/mmcv) for the inspiration on the design of controllers.

## BibTex

We open-source this library to the community to facilitate research on generative modeling. If you like our work and use the codebase or models in your research, please cite our work as follows.

```bibtex
@misc{genforce2020,
    title = {GenForce},
    author = {Shen, Yujun and Xu, Yinghao and Yang, Ceyuan and Zhu, Jiapeng and Zhou, Bolei},
    howpublished = {\url{https://github.com/genforce/genforce}},
    year = {2020}
}
```
networks/genforce/__init__.py
ADDED
File without changes
networks/genforce/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (162 Bytes). View file
networks/genforce/configs/stylegan_demo.py
ADDED
@@ -0,0 +1,61 @@
# python3.7
"""Configuration for StyleGAN training demo.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""

runner_type = 'StyleGANRunner'
gan_type = 'stylegan'
resolution = 64
batch_size = 4
val_batch_size = 32
total_img = 100_000

# Training dataset is repeated at the beginning to avoid loading dataset
# repeatedly at the end of each epoch. This can save some I/O time.
data = dict(
    num_workers=4,
    repeat=500,
    train=dict(root_dir='data/demo.zip', data_format='zip',
               resolution=resolution, mirror=0.5),
    val=dict(root_dir='data/demo.zip', data_format='zip',
             resolution=resolution),
)

controllers = dict(
    RunningLogger=dict(every_n_iters=10),
    ProgressScheduler=dict(
        every_n_iters=1, init_res=8, minibatch_repeats=4,
        lod_training_img=5_000, lod_transition_img=5_000,
        batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8),
    ),
    Snapshoter=dict(every_n_iters=500, first_iter=True, num=200),
    FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000),
    Checkpointer=dict(every_n_iters=5000, first_iter=True),
)

modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='FIXED'),
        opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='FIXED'),
        opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
        kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9,
                          trunc_psi=1.0, trunc_layers=0, randomize_noise=True),
        kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False),
        g_smooth_img=10000,
    )
)

loss = dict(
    type='LogisticGANLoss',
    d_loss_kwargs=dict(r1_gamma=10.0),
    g_loss_kwargs=dict(),
)
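
As a quick sanity check on the numbers in this demo config: with `batch_size = 4` per replica and `total_img = 100_000`, a single-GPU run amounts to roughly 25,000 iterations; the `batch_size_schedule` overrides at low resolutions will shift this in practice.

```python
# Rough iteration count implied by the demo config on one replica; the
# batch-size overrides (64 at res 4, 32 at res 8, ...) are ignored here.
total_img = 100_000
batch_size = 4
print(total_img // batch_size)  # -> 25000
```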
networks/genforce/configs/stylegan_ffhq1024.py
ADDED
@@ -0,0 +1,63 @@
# python3.7
"""Configuration for training StyleGAN on FF-HQ (1024) dataset.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""

runner_type = 'StyleGANRunner'
gan_type = 'stylegan'
resolution = 1024
batch_size = 4
val_batch_size = 16
total_img = 25_000_000

# Training dataset is repeated at the beginning to avoid loading dataset
# repeatedly at the end of each epoch. This can save some I/O time.
data = dict(
    num_workers=4,
    repeat=500,
    # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
    # val=dict(root_dir='data/ffhq', resolution=resolution),
    train=dict(root_dir='data/ffhq.zip', data_format='zip',
               resolution=resolution, mirror=0.5),
    val=dict(root_dir='data/ffhq.zip', data_format='zip',
             resolution=resolution),
)

controllers = dict(
    RunningLogger=dict(every_n_iters=10),
    ProgressScheduler=dict(
        every_n_iters=1, init_res=8, minibatch_repeats=4,
        lod_training_img=600_000, lod_transition_img=600_000,
        batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8),
    ),
    Snapshoter=dict(every_n_iters=500, first_iter=True, num=200),
    FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000),
    Checkpointer=dict(every_n_iters=5000, first_iter=True),
)

modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='FIXED'),
        opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='FIXED'),
        opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
        kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9,
                          trunc_psi=1.0, trunc_layers=0, randomize_noise=True),
        kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False),
        g_smooth_img=10_000,
    )
)

loss = dict(
    type='LogisticGANLoss',
    d_loss_kwargs=dict(r1_gamma=10.0),
    g_loss_kwargs=dict(),
)
networks/genforce/configs/stylegan_ffhq1024_val.py
ADDED
@@ -0,0 +1,29 @@
# python3.7
"""Configuration for testing StyleGAN on FF-HQ (1024) dataset.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""

runner_type = 'StyleGANRunner'
gan_type = 'stylegan'
resolution = 1024
batch_size = 16

data = dict(
    num_workers=4,
    # val=dict(root_dir='data/ffhq', resolution=resolution),
    val=dict(root_dir='data/ffhq.zip', data_format='zip',
             resolution=resolution),
)

modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        kwargs_val=dict(trunc_psi=0.7, trunc_layers=8, randomize_noise=False),
    )
)
networks/genforce/configs/stylegan_ffhq256.py
ADDED
@@ -0,0 +1,63 @@
# python3.7
"""Configuration for training StyleGAN on FF-HQ (256) dataset.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""

runner_type = 'StyleGANRunner'
gan_type = 'stylegan'
resolution = 256
batch_size = 4
val_batch_size = 64
total_img = 25_000_000

# Training dataset is repeated at the beginning to avoid loading dataset
# repeatedly at the end of each epoch. This can save some I/O time.
data = dict(
    num_workers=4,
    repeat=500,
    # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
    # val=dict(root_dir='data/ffhq', resolution=resolution),
    train=dict(root_dir='data/ffhq.zip', data_format='zip',
               resolution=resolution, mirror=0.5),
    val=dict(root_dir='data/ffhq.zip', data_format='zip',
             resolution=resolution),
)

controllers = dict(
    RunningLogger=dict(every_n_iters=10),
    ProgressScheduler=dict(
        every_n_iters=1, init_res=8, minibatch_repeats=4,
        lod_training_img=600_000, lod_transition_img=600_000,
        batch_size_schedule=dict(res4=64, res8=32, res16=16, res32=8),
    ),
    Snapshoter=dict(every_n_iters=500, first_iter=True, num=200),
    FIDEvaluator=dict(every_n_iters=5000, first_iter=True, num=50000),
    Checkpointer=dict(every_n_iters=5000, first_iter=True),
)

modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='FIXED'),
        opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='FIXED'),
        opt=dict(opt_type='Adam', base_lr=1e-3, betas=(0.0, 0.99)),
        kwargs_train=dict(w_moving_decay=0.995, style_mixing_prob=0.9,
                          trunc_psi=1.0, trunc_layers=0, randomize_noise=True),
        kwargs_val=dict(trunc_psi=1.0, trunc_layers=0, randomize_noise=False),
        g_smooth_img=10_000,
    )
)

loss = dict(
    type='LogisticGANLoss',
    d_loss_kwargs=dict(r1_gamma=10.0),
    g_loss_kwargs=dict(),
)
networks/genforce/configs/stylegan_ffhq256_encoder_y.py
ADDED
@@ -0,0 +1,73 @@
# python3.7
"""Configuration for training StyleGAN Encoder on FF-HQ (256) dataset.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""

gan_model_path = 'checkpoints/stylegan_ffhq256.pth'
perceptual_model_path = 'checkpoints/vgg16.pth'

runner_type = 'EncoderRunner'
gan_type = 'stylegan'
resolution = 256
batch_size = 12
val_batch_size = 25
total_img = 14_000_000
space_of_latent = 'y'

# Training dataset is repeated at the beginning to avoid loading dataset
# repeatedly at the end of each epoch. This can save some I/O time.
data = dict(
    num_workers=4,
    repeat=500,
    # train=dict(root_dir='data/ffhq', resolution=resolution, mirror=0.5),
    # val=dict(root_dir='data/ffhq', resolution=resolution),
    train=dict(root_dir='data/', data_format='list',
               image_list_path='data/ffhq/ffhq_train_list.txt',
               resolution=resolution, mirror=0.5),
    val=dict(root_dir='data/', data_format='list',
             image_list_path='./data/ffhq/ffhq_val_list.txt',
             resolution=resolution),
)

controllers = dict(
    RunningLogger=dict(every_n_iters=50),
    Snapshoter=dict(every_n_iters=10000, first_iter=True, num=200),
    Checkpointer=dict(every_n_iters=10000, first_iter=False),
)

modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        lr=dict(lr_type='ExpSTEP', decay_factor=0.8, decay_step=36458 // 2),
        opt=dict(opt_type='Adam', base_lr=1e-4, betas=(0.9, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution, repeat_w=True),
        kwargs_val=dict(randomize_noise=False),
    ),
    encoder=dict(
        model=dict(gan_type=gan_type, resolution=resolution, network_depth=18,
                   latent_dim=[1024] * 8 + [512, 512, 256, 256, 128, 128],
                   num_latents_per_head=[4, 4, 6],
                   use_fpn=True,
                   fpn_channels=512,
                   use_sam=True,
                   sam_channels=512),
        lr=dict(lr_type='ExpSTEP', decay_factor=0.8, decay_step=36458 // 2),
        opt=dict(opt_type='Adam', base_lr=1e-4, betas=(0.9, 0.99)),
        kwargs_train=dict(),
        kwargs_val=dict(),
    ),
)

loss = dict(
    type='EncoderLoss',
    d_loss_kwargs=dict(r1_gamma=10.0),
    e_loss_kwargs=dict(adv_lw=0.08, perceptual_lw=5e-5),
    perceptual_kwargs=dict(output_layer_idx=23,
                           pretrained_weight_path=perceptual_model_path),
)
networks/genforce/configs/stylegan_ffhq256_val.py
ADDED
@@ -0,0 +1,29 @@
# python3.7
"""Configuration for testing StyleGAN on FF-HQ (256) dataset.

All settings are particularly used for one replica (GPU), such as `batch_size`
and `num_workers`.
"""

runner_type = 'StyleGANRunner'
gan_type = 'stylegan'
resolution = 256
batch_size = 64

data = dict(
    num_workers=4,
    # val=dict(root_dir='data/ffhq', resolution=resolution),
    val=dict(root_dir='data/ffhq.zip', data_format='zip',
             resolution=resolution),
)

modules = dict(
    discriminator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        kwargs_val=dict(),
    ),
    generator=dict(
        model=dict(gan_type=gan_type, resolution=resolution),
        kwargs_val=dict(trunc_psi=0.7, trunc_layers=8, randomize_noise=False),
    )
)
networks/genforce/convert_model.py
ADDED
"""Script to convert officially released models to match this repository."""

import argparse

from converters import convert_pggan_weight
from converters import convert_stylegan_weight
from converters import convert_stylegan2_weight
from converters import convert_stylegan2ada_tf_weight
from converters import convert_stylegan2ada_pth_weight


def parse_args():
    """Parses arguments."""
    parser = argparse.ArgumentParser(description='Convert pre-trained models.')
    parser.add_argument('model_type', type=str,
                        choices=['pggan', 'stylegan', 'stylegan2',
                                 'stylegan2ada_tf', 'stylegan2ada_pth'],
                        help='Type of the model to convert.')
    parser.add_argument('--source_model_path', type=str, required=True,
                        help='Path to load the model for conversion.')
    parser.add_argument('--target_model_path', type=str, default=None,
                        help='Path to save the converted model. If not '
                             'specified, the model will be saved to the same '
                             'directory as the source model.')
    parser.add_argument('--test_num', type=int, default=10,
                        help='Number of test samples used to check the '
                             'precision of the converted model. (default: 10)')
    parser.add_argument('--save_test_image', action='store_true',
                        help='Whether to save the test image. (default: False)')
    parser.add_argument('--verbose_log', action='store_true',
                        help='Whether to print verbose log. (default: False)')
    return parser.parse_args()


def main():
    """Main function."""
    args = parse_args()
    if args.target_model_path is None:
        args.target_model_path = args.source_model_path.replace('.pkl', '.pth')

    if args.model_type == 'pggan':
        convert_pggan_weight(tf_weight_path=args.source_model_path,
                             pth_weight_path=args.target_model_path,
                             test_num=args.test_num,
                             save_test_image=args.save_test_image,
                             verbose=args.verbose_log)
    elif args.model_type == 'stylegan':
        convert_stylegan_weight(tf_weight_path=args.source_model_path,
                                pth_weight_path=args.target_model_path,
                                test_num=args.test_num,
                                save_test_image=args.save_test_image,
                                verbose=args.verbose_log)
    elif args.model_type == 'stylegan2':
        convert_stylegan2_weight(tf_weight_path=args.source_model_path,
                                 pth_weight_path=args.target_model_path,
                                 test_num=args.test_num,
                                 save_test_image=args.save_test_image,
                                 verbose=args.verbose_log)
    elif args.model_type == 'stylegan2ada_tf':
        convert_stylegan2ada_tf_weight(tf_weight_path=args.source_model_path,
                                       pth_weight_path=args.target_model_path,
                                       test_num=args.test_num,
                                       save_test_image=args.save_test_image,
                                       verbose=args.verbose_log)
    elif args.model_type == 'stylegan2ada_pth':
        convert_stylegan2ada_pth_weight(src_weight_path=args.source_model_path,
                                        dst_weight_path=args.target_model_path,
                                        test_num=args.test_num,
                                        save_test_image=args.save_test_image,
                                        verbose=args.verbose_log)
    else:
        raise NotImplementedError(f'Model type `{args.model_type}` is not '
                                  f'supported!')


if __name__ == '__main__':
    main()
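A typical invocation of this script (the weight filename is illustrative; any officially released `.pkl` should work):

    python convert_model.py stylegan2 --source_model_path stylegan2-ffhq-config-f.pkl

Since `--target_model_path` is omitted, the converted weights are written next to the source as `stylegan2-ffhq-config-f.pth`, per the `.pkl` -> `.pth` default in `main()`.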
networks/genforce/datasets/README.md
ADDED
# Data Preparation

## Data Format

Currently, our dataloader is able to load data from

- a directory full of images (supports [`turbojpeg`](https://pypi.org/project/PyTurboJPEG/) to speed up image decoding)
- a `lmdb` file
- an image list
- a compressed file (i.e., a `zip` package)

by modifying `data_format` in the configuration.

**NOTE:** For computing clusters whose I/O speed may be slow, we recommend the `zip` format for two reasons. First, a `zip` file is easy to create. Second, it loads one large file at a time instead of loading many small files repeatedly.

## Data Sampling

Considering that most generative models are trained in units of iterations instead of epochs, we change the default data loader to an *iter-based* one. Besides, the original distributed data sampler is also modified so that shuffling corresponds to iterations instead of epochs.

**NOTE:** To reduce the data re-loading cost between epochs, we manually extend the length of the sampled indices, which makes loading much more efficient.

## Data Augmentation

To better align with the original implementation of PGGAN and StyleGAN (i.e., models that require progressive training), we support progressive resizing in `transforms.py`, which downsamples images with a maximum resize factor of 2 at each step.
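Switching between formats only touches the dataset config. A sketch that mirrors the configs elsewhere in this commit (paths are placeholders):

    data = dict(
        num_workers=4,
        # Plain image directory:
        # train=dict(root_dir='data/ffhq', data_format='dir', resolution=256),
        # Zip package, recommended when cluster I/O is slow:
        train=dict(root_dir='data/ffhq.zip', data_format='zip',
                   resolution=256, mirror=0.5),
    )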
networks/genforce/datasets/__init__.py
ADDED
# python3.7
"""Collects datasets and data loaders."""

from .datasets import BaseDataset
from .dataloaders import IterDataLoader

__all__ = ['BaseDataset', 'IterDataLoader']
networks/genforce/datasets/dataloaders.py
ADDED
# python3.7
"""Contains the class of data loader."""

import argparse

from torch.utils.data import DataLoader
from .distributed_sampler import DistributedSampler
from .datasets import BaseDataset


__all__ = ['IterDataLoader']


class IterDataLoader(object):
    """Iteration-based data loader."""

    def __init__(self,
                 dataset,
                 batch_size,
                 shuffle=True,
                 num_workers=1,
                 current_iter=0,
                 repeat=1):
        """Initializes the data loader.

        Args:
            dataset: The dataset to load data from.
            batch_size: The batch size on each GPU.
            shuffle: Whether to shuffle the data. (default: True)
            num_workers: Number of data workers for each GPU. (default: 1)
            current_iter: The current number of iterations. (default: 0)
            repeat: The repeating number of the whole dataloader. (default: 1)
        """
        self._dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self._dataloader = None
        self.iter_loader = None
        self._iter = current_iter
        self.repeat = repeat
        self.build_dataloader()

    def build_dataloader(self):
        """Builds data loader."""
        dist_sampler = DistributedSampler(self._dataset,
                                          shuffle=self.shuffle,
                                          current_iter=self._iter,
                                          repeat=self.repeat)
        self._dataloader = DataLoader(self._dataset,
                                      batch_size=self.batch_size,
                                      shuffle=(dist_sampler is None),
                                      num_workers=self.num_workers,
                                      drop_last=self.shuffle,
                                      pin_memory=True,
                                      sampler=dist_sampler)
        self.iter_loader = iter(self._dataloader)

    def overwrite_param(self, batch_size=None, resolution=None):
        """Overwrites some parameters for progressive training."""
        if (not batch_size) and (not resolution):
            return
        if (batch_size == self.batch_size) and (
                resolution == self.dataset.resolution):
            return
        if batch_size:
            self.batch_size = batch_size
        if resolution:
            self._dataset.resolution = resolution
        self.build_dataloader()

    @property
    def iter(self):
        """Returns the current iteration."""
        return self._iter

    @property
    def dataset(self):
        """Returns the dataset."""
        return self._dataset

    @property
    def dataloader(self):
        """Returns the data loader."""
        return self._dataloader

    def __next__(self):
        try:
            data = next(self.iter_loader)
            self._iter += 1
        except StopIteration:
            self._dataloader.sampler.__reset__(self._iter)
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)
            self._iter += 1
        return data

    def __len__(self):
        return len(self._dataloader)


def dataloader_test(root_dir, test_num=10):
    """Tests data loader."""
    res = 2
    bs = 2
    dataset = BaseDataset(root_dir=root_dir, resolution=res)
    dataloader = IterDataLoader(dataset=dataset,
                                batch_size=bs,
                                shuffle=False)
    for _ in range(test_num):
        data_batch = next(dataloader)
        image = data_batch['image']
        assert image.shape == (bs, 3, res, res)
        res *= 2
        bs += 1
        dataloader.overwrite_param(batch_size=bs, resolution=res)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Test Data Loader.')
    parser.add_argument('root_dir', type=str,
                        help='Root directory of the dataset.')
    parser.add_argument('--test_num', type=int, default=10,
                        help='Number of tests. (default: %(default)s)')
    args = parser.parse_args()
    dataloader_test(args.root_dir, args.test_num)
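The loader is meant to be driven with `next()` from an iteration-based training loop: `StopIteration` never escapes, because `__next__` resets the sampler at the current iteration and rebuilds the iterator. A minimal sketch, assuming a distributed process group is already initialized (the sampler requires one) and with placeholder path and iteration count:

    from datasets import BaseDataset, IterDataLoader

    dataset = BaseDataset(root_dir='data/ffhq.zip', data_format='zip',
                          resolution=256)
    loader = IterDataLoader(dataset, batch_size=64, num_workers=4)
    for _ in range(50000):               # counts iterations, not epochs
        image = next(loader)['image']    # never raises StopIteration
        # ... one optimization step on `image` ...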
networks/genforce/datasets/datasets.py
ADDED
# python3.7
"""Contains the class of dataset."""

import os
import pickle
import string
import zipfile
import numpy as np
import cv2
import lmdb

import torch
from torch.utils.data import Dataset

from .transforms import progressive_resize_image
from .transforms import crop_resize_image
from .transforms import resize_image
from .transforms import normalize_image

try:
    import turbojpeg
    BASE_DIR = os.path.dirname(os.path.relpath(__file__))
    LIBRARY_NAME = 'libturbojpeg.so.0'
    LIBRARY_PATH = os.path.join(BASE_DIR, LIBRARY_NAME)
    jpeg = turbojpeg.TurboJPEG(LIBRARY_PATH)
except ImportError:
    jpeg = None

__all__ = ['BaseDataset']

_FORMATS_ALLOWED = ['dir', 'lmdb', 'list', 'zip']


class ZipLoader(object):
    """Defines a class to load zip files.

    This is a static class, which is used to solve the problem that different
    data workers can not share the same memory.
    """
    files = dict()

    @staticmethod
    def get_zipfile(file_path):
        """Fetches a zip file."""
        zip_files = ZipLoader.files
        if file_path not in zip_files:
            zip_files[file_path] = zipfile.ZipFile(file_path, 'r')
        return zip_files[file_path]

    @staticmethod
    def get_image(file_path, image_path):
        """Decodes an image from a particular zip file."""
        zip_file = ZipLoader.get_zipfile(file_path)
        image_str = zip_file.read(image_path)
        image_np = np.frombuffer(image_str, np.uint8)
        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        return image


class LmdbLoader(object):
    """Defines a class to load lmdb files.

    This is a static class, which is used to solve the lmdb loading error
    when `num_workers > 0`.
    """
    files = dict()

    @staticmethod
    def get_lmdbfile(file_path):
        """Fetches a lmdb file."""
        lmdb_files = LmdbLoader.files
        if 'env' not in lmdb_files:
            env = lmdb.open(file_path,
                            max_readers=1,
                            readonly=True,
                            lock=False,
                            readahead=False,
                            meminit=False)
            with env.begin(write=False) as txn:
                num_samples = txn.stat()['entries']
            cache_file = '_cache_' + ''.join(
                c for c in file_path if c in string.ascii_letters)
            if os.path.isfile(cache_file):
                keys = pickle.load(open(cache_file, 'rb'))
            else:
                with env.begin(write=False) as txn:
                    keys = [key for key, _ in txn.cursor()]
                pickle.dump(keys, open(cache_file, 'wb'))
            lmdb_files['env'] = env
            lmdb_files['num_samples'] = num_samples
            lmdb_files['keys'] = keys
        return lmdb_files

    @staticmethod
    def get_image(file_path, idx):
        """Decodes an image from a particular lmdb file."""
        lmdb_files = LmdbLoader.get_lmdbfile(file_path)
        env = lmdb_files['env']
        keys = lmdb_files['keys']
        with env.begin(write=False) as txn:
            imagebuf = txn.get(keys[idx])
        image_np = np.frombuffer(imagebuf, np.uint8)
        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
        return image


class BaseDataset(Dataset):
    """Defines the base dataset class.

    This class supports loading data from a folder full of images, a lmdb
    database, an image list, or a zip package. Images will be pre-processed
    based on the given `transform` function before being fed into the data
    loader.

    NOTE: The loaded data will be returned as a dictionary, in which there
    must be a key `image`.
    """
    def __init__(self,
                 root_dir,
                 resolution,
                 data_format='dir',
                 image_list_path=None,
                 mirror=0.0,
                 progressive_resize=True,
                 crop_resize_resolution=-1,
                 transform=normalize_image,
                 transform_kwargs=None,
                 **_unused_kwargs):
        """Initializes the dataset.

        Args:
            root_dir: Root directory containing the dataset.
            resolution: The resolution of the returned image.
            data_format: Format in which the dataset is stored. Supports
                `dir`, `lmdb`, `list`, and `zip`. (default: `dir`)
            image_list_path: Path to the image list. This field is required if
                `data_format` is `list`. (default: None)
            mirror: The probability to do mirror augmentation. (default: 0.0)
            progressive_resize: Whether to resize images progressively.
                (default: True)
            crop_resize_resolution: The resolution of the output after crop
                and resize. (default: -1)
            transform: The transform function for pre-processing.
                (default: `datasets.transforms.normalize_image()`)
            transform_kwargs: The additional arguments for the `transform`
                function. (default: None)

        Raises:
            ValueError: If the input `data_format` is not supported.
            NotImplementedError: If the input `data_format` is not implemented.
        """
        if data_format.lower() not in _FORMATS_ALLOWED:
            raise ValueError(f'Invalid data format `{data_format}`!\n'
                             f'Supported formats: {_FORMATS_ALLOWED}.')

        self.root_dir = root_dir
        self.resolution = resolution
        self.data_format = data_format.lower()
        self.image_list_path = image_list_path
        self.mirror = np.clip(mirror, 0.0, 1.0)
        self.progressive_resize = progressive_resize
        self.crop_resize_resolution = crop_resize_resolution
        self.transform = transform
        self.transform_kwargs = transform_kwargs or dict()

        if self.data_format == 'dir':
            self.image_paths = sorted(os.listdir(self.root_dir))
            self.num_samples = len(self.image_paths)
        elif self.data_format == 'lmdb':
            lmdb_file = LmdbLoader.get_lmdbfile(self.root_dir)
            self.num_samples = lmdb_file['num_samples']
        elif self.data_format == 'list':
            self.metas = []
            assert os.path.isfile(self.image_list_path)
            with open(self.image_list_path) as f:
                for line in f:
                    fields = line.rstrip().split(' ')
                    if len(fields) == 1:
                        self.metas.append((fields[0], None))
                    else:
                        assert len(fields) == 2
                        self.metas.append((fields[0], int(fields[1])))
            self.num_samples = len(self.metas)
        elif self.data_format == 'zip':
            zip_file = ZipLoader.get_zipfile(self.root_dir)
            image_paths = [f for f in zip_file.namelist()
                           if ('.jpg' in f or '.jpeg' in f or '.png' in f)]
            self.image_paths = sorted(image_paths)
            self.num_samples = len(self.image_paths)
        else:
            raise NotImplementedError(f'Not implemented data format '
                                      f'`{self.data_format}`!')

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        data = dict()

        # Load data.
        if self.data_format == 'dir':
            image_path = self.image_paths[idx]
            try:
                # Decode with turbojpeg when available; fall back to cv2.
                with open(os.path.join(self.root_dir, image_path), 'rb') as f:
                    image = jpeg.decode(f.read())
            except:  # pylint: disable=bare-except
                image = cv2.imread(os.path.join(self.root_dir, image_path))
        elif self.data_format == 'lmdb':
            image = LmdbLoader.get_image(self.root_dir, idx)
        elif self.data_format == 'list':
            image_path, label = self.metas[idx]
            image = cv2.imread(os.path.join(self.root_dir, image_path))
            # Wrap the (optional) class label as a 0-dim long tensor.
            label = None if label is None else torch.tensor(label,
                                                            dtype=torch.long)
            # data.update({'label': label})
        elif self.data_format == 'zip':
            image_path = self.image_paths[idx]
            image = ZipLoader.get_image(self.root_dir, image_path)
        else:
            raise NotImplementedError(f'Not implemented data format '
                                      f'`{self.data_format}`!')

        image = image[:, :, ::-1]  # Converts BGR (cv2) to RGB.

        # Transform image.
        if self.crop_resize_resolution > 0:
            image = crop_resize_image(image, self.crop_resize_resolution)
        if self.progressive_resize:
            image = progressive_resize_image(image, self.resolution)
        image = image.transpose(2, 0, 1).astype(np.float32)
        if np.random.uniform() < self.mirror:
            image = image[:, :, ::-1]  # CHW
        image = torch.FloatTensor(image.copy())
        if not self.progressive_resize:
            image = resize_image(image, self.resolution)

        if self.transform is not None:
            image = self.transform(image, **self.transform_kwargs)
        data.update({'image': image})

        return data
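A quick illustrative read of one sample (the archive path is hypothetical); the returned dictionary holds a normalized CHW float tensor:

    from datasets import BaseDataset

    dataset = BaseDataset(root_dir='data/ffhq.zip', data_format='zip',
                          resolution=256)
    sample = dataset[0]
    assert sample['image'].shape == (3, 256, 256)  # CHW, roughly in [-1, 1]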
networks/genforce/datasets/distributed_sampler.py
ADDED
# python3.7
"""Contains the distributed data sampler.

This file is mostly borrowed from `torch/utils/data/distributed.py`.

However, initializing the data loader and the data sampler can sometimes be
time-consuming (since it will load a large amount of data at one time). To
avoid re-initializing the data loader again and again, we modified the sampler
to support loading the data only once and then repeating the data loader.
Please use the class member `repeat` to control how many times you want the
data loader to repeat. After `repeat` times, the data will be re-loaded.

NOTE: The number of repeat times should not be very large, especially when
there are many samples in the dataset. We recommend setting `repeat = 500` for
datasets with ~50K samples.
"""

# pylint: disable=line-too-long

import math
from typing import TypeVar, Optional, Iterator

import torch
from torch.utils.data import Sampler, Dataset
import torch.distributed as dist


T_co = TypeVar('T_co', covariant=True)


class DistributedSampler(Sampler):
    r"""Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
    process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
    :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
    original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (int, optional): Number of processes participating in
            distributed training. By default, :attr:`num_replicas` is retrieved
            from the current distributed group.
        rank (int, optional): Rank of the current process within :attr:`num_replicas`.
            By default, :attr:`rank` is retrieved from the current distributed
            group.
        shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
            indices.
        seed (int, optional): random seed used to shuffle the sampler if
            :attr:`shuffle=True`. This number should be identical across all
            processes in the distributed group. Default: ``0``.
        drop_last (bool, optional): if ``True``, then the sampler will drop the
            tail of the data to make it evenly divisible across the number of
            replicas. If ``False``, the sampler will add extra indices to make
            the data evenly divisible across the replicas. Default: ``False``.
        current_iter (int, optional): Number of current iteration. Default: ``0``.
        repeat (int, optional): Repeating number of the whole dataloader. Default: ``1000``.

    .. warning::
        In this iteration-based variant, :meth:`__reset__` plays the role of
        :meth:`set_epoch`: it must be called (as `IterDataLoader` does
        automatically) before re-creating the :class:`DataLoader` iterator,
        otherwise the same ordering will always be used.
    """

    def __init__(self, dataset: Dataset, num_replicas: Optional[int] = None,
                 rank: Optional[int] = None, shuffle: bool = True,
                 seed: int = 0, drop_last: bool = False, current_iter: int = 0,
                 repeat: int = 1000) -> None:
        super().__init__(None)
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.iter = current_iter
        self.drop_last = drop_last

        # NOTE: self.dataset_length is `repeat x len(self.dataset)`.
        self.repeat = repeat
        self.dataset_length = len(self.dataset) * self.repeat

        if self.drop_last and self.dataset_length % self.num_replicas != 0:
            # Split to nearest available length that is evenly divisible.
            # This is to ensure each rank receives the same amount of data when
            # using this Sampler.
            self.num_samples = math.ceil(
                (self.dataset_length - self.num_replicas) / self.num_replicas
            )
        else:
            self.num_samples = math.ceil(self.dataset_length / self.num_replicas)

        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle
        self.seed = seed
        self.__generate_indices__()

    def __generate_indices__(self) -> None:
        g = torch.Generator()
        indices_bank = []
        for iter_ in range(self.iter, self.iter + self.repeat):
            g.manual_seed(self.seed + iter_)
            indices = torch.randperm(len(self.dataset), generator=g).tolist()
            indices_bank.extend(indices)
        self.indices = indices_bank

    def __iter__(self) -> Iterator[T_co]:
        if self.shuffle:
            # Deterministically shuffle based on iteration and seed.
            indices = self.indices
        else:
            indices = list(range(self.dataset_length))

        if not self.drop_last:
            # Add extra samples to make the data evenly divisible.
            indices += indices[:(self.total_size - len(indices))]
        else:
            # Remove tail of data to make the data evenly divisible.
            indices = indices[:self.total_size]

        # Subsample for the current rank.
        indices = indices[self.rank:self.total_size:self.num_replicas]
        return iter(indices)

    def __len__(self) -> int:
        return self.num_samples

    def __reset__(self, iteration: int) -> None:
        self.iter = iteration
        self.__generate_indices__()

# pylint: enable=line-too-long
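The core of the scheme sits in `__generate_indices__`: one full permutation per repeated pass, each seeded by `seed + iteration`, so a job resumed at iteration `k` regenerates exactly the same index stream. A standalone sketch of that idea:

    import torch

    def generate_index_bank(num_samples, seed=0, current_iter=0, repeat=3):
        """Mirrors DistributedSampler.__generate_indices__: concatenates one
        seeded permutation per repeat, giving a resumable, deterministic
        shuffle without per-epoch re-initialization."""
        g = torch.Generator()
        bank = []
        for it in range(current_iter, current_iter + repeat):
            g.manual_seed(seed + it)
            bank.extend(torch.randperm(num_samples, generator=g).tolist())
        return bank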
networks/genforce/datasets/libturbojpeg.so.0
ADDED
Binary file (396 kB).
networks/genforce/datasets/transforms.py
ADDED
"""Contains transform functions."""

import cv2
import numpy as np
import PIL.Image

import torch
import torch.nn as nn
import torch.nn.functional as F


__all__ = [
    'crop_resize_image', 'progressive_resize_image', 'resize_image',
    'normalize_image', 'normalize_latent_code', 'ImageResizing',
    'ImageNormalization', 'LatentCodeNormalization',
]


def crop_resize_image(image, size):
    """Crops a square patch and then resizes it to the given size.

    Args:
        image: The input image to crop and resize.
        size: An integer, indicating the target size.

    Returns:
        An image with target size.

    Raises:
        TypeError: If the input `image` is not with type `numpy.ndarray`.
        ValueError: If the input `image` is not with shape [H, W, C].
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f'Input image should be with type `numpy.ndarray`, '
                        f'but `{type(image)}` is received!')
    if image.ndim != 3:
        raise ValueError(f'Input image should be with shape [H, W, C], '
                         f'but `{image.shape}` is received!')

    height, width, channel = image.shape
    short_side = min(height, width)
    image = image[(height - short_side) // 2:(height + short_side) // 2,
                  (width - short_side) // 2:(width + short_side) // 2]
    pil_image = PIL.Image.fromarray(image)
    pil_image = pil_image.resize((size, size), PIL.Image.ANTIALIAS)
    image = np.asarray(pil_image)
    assert image.shape == (size, size, channel)
    return image


def progressive_resize_image(image, size):
    """Resizes image to target size progressively.

    Different from normal resize, this function will reduce the image size
    progressively. In each step, the maximum reduce factor is 2.

    NOTE: This function can only handle square images, and can only be used
    for downsampling.

    Args:
        image: The input (square) image to resize.
        size: An integer, indicating the target size.

    Returns:
        An image with target size.

    Raises:
        TypeError: If the input `image` is not with type `numpy.ndarray`.
        ValueError: If the input `image` is not with shape [H, W, C].
    """
    if not isinstance(image, np.ndarray):
        raise TypeError(f'Input image should be with type `numpy.ndarray`, '
                        f'but `{type(image)}` is received!')
    if image.ndim != 3:
        raise ValueError(f'Input image should be with shape [H, W, C], '
                         f'but `{image.shape}` is received!')

    height, width, channel = image.shape
    assert height == width
    assert height >= size
    num_iters = int(np.log2(height) - np.log2(size))
    for _ in range(num_iters):
        height = max(height // 2, size)
        image = cv2.resize(image, (height, height),
                           interpolation=cv2.INTER_LINEAR)
    assert image.shape == (size, size, channel)
    return image


def resize_image(image, size):
    """Resizes image to target size.

    NOTE: We use adaptive average pooling for image resizing. Instead of
    bilinear interpolation, average pooling is able to acquire information
    from more pixels, such that the resized results are of higher quality.

    Args:
        image: The input image tensor, with shape [C, H, W], to resize.
        size: An integer or a tuple of integers, indicating the target size.

    Returns:
        An image tensor with target size.

    Raises:
        TypeError: If the input `image` is not with type `torch.Tensor`.
        ValueError: If the input `image` is not with shape [C, H, W].
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f'Input image should be with type `torch.Tensor`, '
                        f'but `{type(image)}` is received!')
    if image.ndim != 3:
        raise ValueError(f'Input image should be with shape [C, H, W], '
                         f'but `{image.shape}` is received!')

    image = F.adaptive_avg_pool2d(image.unsqueeze(0), size).squeeze(0)
    return image


def normalize_image(image, mean=127.5, std=127.5):
    """Normalizes image by subtracting mean and dividing std.

    Args:
        image: The input image tensor to normalize.
        mean: The mean value to subtract from the input tensor.
            (default: 127.5)
        std: The standard deviation to normalize the input tensor.
            (default: 127.5)

    Returns:
        A normalized image tensor.

    Raises:
        TypeError: If the input `image` is not with type `torch.Tensor`.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError(f'Input image should be with type `torch.Tensor`, '
                        f'but `{type(image)}` is received!')
    out = (image - mean) / std
    return out


def normalize_latent_code(latent_code, adjust_norm=True):
    """Normalizes latent code.

    NOTE: The latent code will always be normalized along the last axis.
    Meanwhile, if `adjust_norm` is set as `True`, the norm of the result will
    be adjusted to `sqrt(latent_code.shape[-1])` in order to avoid too small
    a value.

    Args:
        latent_code: The input latent code tensor to normalize.
        adjust_norm: Whether to adjust the norm of the output. (default: True)

    Returns:
        A normalized latent code tensor.

    Raises:
        TypeError: If the input `latent_code` is not with type `torch.Tensor`.
    """
    if not isinstance(latent_code, torch.Tensor):
        raise TypeError(f'Input latent code should be with type '
                        f'`torch.Tensor`, but `{type(latent_code)}` is '
                        f'received!')
    dim = latent_code.shape[-1]
    norm = latent_code.pow(2).sum(-1, keepdim=True).pow(0.5)
    out = latent_code / norm
    if adjust_norm:
        out = out * (dim ** 0.5)
    return out


class ImageResizing(nn.Module):
    """Implements the image resizing layer."""

    def __init__(self, size):
        super().__init__()
        self.size = size

    def forward(self, image):
        return resize_image(image, self.size)


class ImageNormalization(nn.Module):
    """Implements the image normalization layer."""

    def __init__(self, mean=127.5, std=127.5):
        super().__init__()
        self.mean = mean
        self.std = std

    def forward(self, image):
        return normalize_image(image, self.mean, self.std)


class LatentCodeNormalization(nn.Module):
    """Implements the latent code normalization layer."""

    def __init__(self, adjust_norm=True):
        super().__init__()
        self.adjust_norm = adjust_norm

    def forward(self, latent_code):
        return normalize_latent_code(latent_code, self.adjust_norm)
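A short illustrative chain of the functions above, run on a dummy square image (the values are chosen only to show the shapes and value range):

    import numpy as np
    import torch

    image = np.zeros((1024, 1024, 3), dtype=np.uint8)    # dummy HWC image
    image = progressive_resize_image(image, 256)          # 1024 -> 512 -> 256
    tensor = torch.FloatTensor(image.transpose(2, 0, 1).copy())
    tensor = normalize_image(tensor)                      # [0, 255] -> [-1, 1]
    assert tensor.shape == (3, 256, 256)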
networks/genforce/metrics/README.md
ADDED
# Evaluation Metrics

Frechet Inception Distance (FID) is commonly used to evaluate generative models. It employs an [Inception Model](https://arxiv.org/abs/1512.00567) (pre-trained on ImageNet) to extract features from both real and synthesized images.

## Inception Model

[PGGAN](https://github.com/tkarras/progressive_growing_of_gans), [StyleGAN](https://github.com/NVlabs/stylegan), etc. use the inception model from the [TensorFlow Models](https://github.com/tensorflow/models) repository, whose implementation is slightly different from that of `torchvision`. Hence, to make the evaluation metric comparable across training frameworks (i.e., PyTorch and TensorFlow), we modify `torchvision/models/inception.py` as `inception.py`. The ported pre-trained weights are borrowed from [this repo](https://github.com/mseitzer/pytorch-fid).

**NOTE:** We also support using the model from `torchvision` to compute the FID. However, please be aware that the FID value from `torchvision` is usually ~1.5 smaller than that from the TensorFlow model.

Please use the following code to choose which model to use.

```python
from metrics.inception import build_inception_model

inception_model_tf = build_inception_model(align_tf=True)
inception_model_pth = build_inception_model(align_tf=False)
```
networks/genforce/metrics/__init__.py
ADDED
File without changes
networks/genforce/metrics/fid.py
ADDED
# python3.7
"""Contains the functions to compute Frechet Inception Distance (FID).

The FID metric is introduced in the paper

GANs Trained by a Two Time-Scale Update Rule Converge to a Local Nash
Equilibrium. Heusel et al. NeurIPS 2017.

See details at https://arxiv.org/pdf/1706.08500.pdf
"""

import numpy as np
import scipy.linalg

__all__ = ['extract_feature', 'compute_fid']


def extract_feature(inception_model, images):
    """Extracts features from input images with the given model.

    NOTE: The input images are assumed to be with pixel range [-1, 1].

    Args:
        inception_model: The model used to extract features.
        images: The input image tensor to extract features from.

    Returns:
        A `numpy.ndarray`, containing the extracted features.
    """
    features = inception_model(images, output_logits=False)
    features = features.detach().cpu().numpy()
    assert features.ndim == 2 and features.shape[1] == 2048
    return features


def compute_fid(fake_features, real_features):
    """Computes FID based on the features extracted from fake and real data.

    Given the mean and covariance (m_f, C_f) of fake data and (m_r, C_r) of
    real data, the FID metric can be computed by

    d^2 = ||m_f - m_r||_2^2 + Tr(C_f + C_r - 2(C_f C_r)^0.5)

    Args:
        fake_features: The features extracted from fake data.
        real_features: The features extracted from real data.

    Returns:
        A real number, suggesting the FID value.
    """
    m_f = np.mean(fake_features, axis=0)
    C_f = np.cov(fake_features, rowvar=False)
    m_r = np.mean(real_features, axis=0)
    C_r = np.cov(real_features, rowvar=False)

    fid = np.sum((m_f - m_r) ** 2) + np.trace(
        C_f + C_r - 2 * scipy.linalg.sqrtm(np.dot(C_f, C_r)))
    return np.real(fid)
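Tying the two metric modules together, a minimal end-to-end sketch (random tensors stand in for data here; real use accumulates features over thousands of images before computing the statistics):

    import torch

    from metrics.inception import build_inception_model
    from metrics.fid import extract_feature, compute_fid

    inception = build_inception_model(align_tf=True)
    # Placeholders for batches of fake/real images in [-1, 1]:
    fake_images = torch.randn(8, 3, 299, 299).clamp(-1, 1)
    real_images = torch.randn(8, 3, 299, 299).clamp(-1, 1)
    fid = compute_fid(extract_feature(inception, fake_images),
                      extract_feature(inception, real_images))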
networks/genforce/metrics/inception.py
ADDED
# python3.7
"""Contains the Inception V3 model.

This file is mostly borrowed from `torchvision/models/inception.py`.

The Inception model is widely used to compute the FID or IS metric for
evaluating generative models. However, the pre-trained models from torchvision
are slightly different from the TensorFlow version:

http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz

In particular:

(1) The number of classes in the TensorFlow model is 1008 instead of 1000.
(2) The avg_pool() layers in the TensorFlow model do not include the padded
    zeros.
(3) The last Inception E block in the TensorFlow model uses max_pool() instead
    of avg_pool().

Hence, to align the evaluation results with those from the TensorFlow
implementation, we modified the inception model to support both versions.
Please use the `align_tf` argument to control the version.
"""

# pylint: disable=line-too-long
# pylint: disable=missing-function-docstring
# pylint: disable=missing-class-docstring
# pylint: disable=super-with-arguments
# pylint: disable=consider-merging-isinstance
# pylint: disable=import-outside-toplevel
# pylint: disable=no-else-return

from collections import namedtuple
import warnings
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit.annotations import Optional
from torch import Tensor
from torchvision.models.utils import load_state_dict_from_url


__all__ = ['build_inception_model', 'Inception3', 'inception_v3', 'InceptionOutputs', '_InceptionOutputs']

model_urls = {
    # Inception v3 ported from TensorFlow
    'inception_v3_google': 'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',

    # Inception v3 ported from http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
    # This model is provided by https://github.com/mseitzer/pytorch-fid
    'tf_inception_v3': 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth'
}

InceptionOutputs = namedtuple('InceptionOutputs', ['logits', 'aux_logits'])
InceptionOutputs.__annotations__ = {'logits': torch.Tensor, 'aux_logits': Optional[torch.Tensor]}

# Script annotations failed with _GoogleNetOutputs = namedtuple ...
# _InceptionOutputs set here for backwards compat
_InceptionOutputs = InceptionOutputs


def build_inception_model(align_tf=True):
    """Builds Inception V3 model.

    This model is particularly used for inference, such that `requires_grad`
    is set to `False` for all parameters and the model is switched to
    evaluation mode.

    Args:
        align_tf: Whether to align the implementation with the TensorFlow version. (default: True)

    Returns:
        A `torch.nn.Module` with pre-trained weight.
    """
    if align_tf:
        num_classes = 1008
        model_url = model_urls['tf_inception_v3']
    else:
        num_classes = 1000
        model_url = model_urls['inception_v3_google']
    model = Inception3(num_classes=num_classes,
                       aux_logits=False,
                       transform_input=False,
                       align_tf=align_tf)
    state_dict = load_state_dict_from_url(model_url)
    model.load_state_dict(state_dict, strict=False)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False
    return model


def inception_v3(pretrained=False, progress=True, **kwargs):
    r"""Inception v3 model architecture from
    `"Rethinking the Inception Architecture for Computer Vision" <http://arxiv.org/abs/1512.00567>`_.

    .. note::
        **Important**: In contrast to the other models, inception_v3 expects tensors with a size of
        N x 3 x 299 x 299, so ensure your images are sized accordingly.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
        aux_logits (bool): If True, adds an auxiliary branch that can improve training.
            Default: *True*
        transform_input (bool): If True, preprocesses the input according to the method with which it
            was trained on ImageNet. Default: *False*
    """
    if pretrained:
        if 'transform_input' not in kwargs:
            kwargs['transform_input'] = True
        if 'aux_logits' in kwargs:
            original_aux_logits = kwargs['aux_logits']
            kwargs['aux_logits'] = True
        else:
            original_aux_logits = True
        model = Inception3(**kwargs)
        state_dict = load_state_dict_from_url(model_urls['inception_v3_google'],
                                              progress=progress)
        model.load_state_dict(state_dict)
        if not original_aux_logits:
            model.aux_logits = False
            del model.AuxLogits
        return model

    return Inception3(**kwargs)


class Inception3(nn.Module):

    def __init__(self, num_classes=1000, aux_logits=True, transform_input=False,
                 inception_blocks=None, init_weights=True, align_tf=True):
        super(Inception3, self).__init__()
        if inception_blocks is None:
            inception_blocks = [
                BasicConv2d, InceptionA, InceptionB, InceptionC,
                InceptionD, InceptionE, InceptionAux
            ]
        assert len(inception_blocks) == 7
        conv_block = inception_blocks[0]
        inception_a = inception_blocks[1]
        inception_b = inception_blocks[2]
        inception_c = inception_blocks[3]
        inception_d = inception_blocks[4]
        inception_e = inception_blocks[5]
        inception_aux = inception_blocks[6]

        self.aux_logits = aux_logits
        self.transform_input = transform_input
        self.align_tf = align_tf
        self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
        self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
        self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
        self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
        self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
        self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
        self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
        self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
        self.Mixed_6a = inception_b(288)
        self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
        self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
        self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
        self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
        if aux_logits:
            self.AuxLogits = inception_aux(768, num_classes)
        self.Mixed_7a = inception_d(768)
        self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
        self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
        self.fc = nn.Linear(2048, num_classes)
        if init_weights:
            for m in self.modules():
                if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                    import scipy.stats as stats
                    stddev = m.stddev if hasattr(m, 'stddev') else 0.1
                    X = stats.truncnorm(-2, 2, scale=stddev)
                    values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
                    values = values.view(m.weight.size())
                    with torch.no_grad():
                        m.weight.copy_(values)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def _transform_input(self, x):
        if self.transform_input:
            x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
            x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
            x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
            x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
        return x

    def _forward(self, x, output_logits=False):
        # Upsample if necessary.
        if x.shape[2] != 299 or x.shape[3] != 299:
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=False)

        # N x 3 x 299 x 299
        x = self.Conv2d_1a_3x3(x)
        # N x 32 x 149 x 149
        x = self.Conv2d_2a_3x3(x)
        # N x 32 x 147 x 147
        x = self.Conv2d_2b_3x3(x)
        # N x 64 x 147 x 147
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 64 x 73 x 73
        x = self.Conv2d_3b_1x1(x)
        # N x 80 x 73 x 73
        x = self.Conv2d_4a_3x3(x)
        # N x 192 x 71 x 71
        x = F.max_pool2d(x, kernel_size=3, stride=2)
        # N x 192 x 35 x 35
        x = self.Mixed_5b(x)
        # N x 256 x 35 x 35
        x = self.Mixed_5c(x)
        # N x 288 x 35 x 35
        x = self.Mixed_5d(x)
        # N x 288 x 35 x 35
        x = self.Mixed_6a(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6b(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6c(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6d(x)
        # N x 768 x 17 x 17
        x = self.Mixed_6e(x)
        # N x 768 x 17 x 17
        aux_defined = self.training and self.aux_logits
        if aux_defined:
            aux = self.AuxLogits(x)
        else:
            aux = None
        # N x 768 x 17 x 17
        x = self.Mixed_7a(x)
        # N x 1280 x 8 x 8
        x = self.Mixed_7b(x)
        # N x 2048 x 8 x 8
        x = self.Mixed_7c(x)
        # N x 2048 x 8 x 8
        # Adaptive average pooling
        x = F.adaptive_avg_pool2d(x, (1, 1))
        # N x 2048 x 1 x 1
        x = F.dropout(x, training=self.training)
        # N x 2048 x 1 x 1
        x = torch.flatten(x, 1)
        # N x 2048
        if output_logits:
            x = self.fc(x)
            # N x 1000 (num_classes)
        return x, aux

    @torch.jit.unused
    def eager_outputs(self, x, aux):
        # type: (Tensor, Optional[Tensor]) -> InceptionOutputs
        if self.training and self.aux_logits:
            return InceptionOutputs(x, aux)
        else:
            return x

    def forward(self, x, output_logits=False):
        x = self._transform_input(x)
        x, aux = self._forward(x, output_logits)
        aux_defined = self.training and self.aux_logits
        if torch.jit.is_scripting():
            if not aux_defined:
                warnings.warn("Scripted Inception3 always returns Inception3 Tuple")
            return InceptionOutputs(x, aux)
        else:
            return self.eager_outputs(x, aux)


class InceptionA(nn.Module):

    def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
        super(InceptionA, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)

        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)

        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
        self.pool_include_padding = not align_tf

    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_1(x)
        branch5x5 = self.branch5x5_2(branch5x5)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=self.pool_include_padding)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionB(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionB, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)

        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)

    def _forward(self, x):
        branch3x3 = self.branch3x3(x)

        branch3x3dbl = self.branch3x3dbl_1(x)
        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)

        outputs = [branch3x3, branch3x3dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionC(nn.Module):

    def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
        super(InceptionC, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)

        c7 = channels_7x7
        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))

        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))

        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
        self.pool_include_padding = not align_tf

    def _forward(self, x):
        branch1x1 = self.branch1x1(x)

        branch7x7 = self.branch7x7_1(x)
        branch7x7 = self.branch7x7_2(branch7x7)
        branch7x7 = self.branch7x7_3(branch7x7)

        branch7x7dbl = self.branch7x7dbl_1(x)
        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)

        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                                   count_include_pad=self.pool_include_padding)
        branch_pool = self.branch_pool(branch_pool)

        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionD(nn.Module):

    def __init__(self, in_channels, conv_block=None):
        super(InceptionD, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)

        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)

    def _forward(self, x):
        branch3x3 = self.branch3x3_1(x)
        branch3x3 = self.branch3x3_2(branch3x3)

        branch7x7x3 = self.branch7x7x3_1(x)
        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)

        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
        outputs = [branch3x3, branch7x7x3, branch_pool]
        return outputs

    def forward(self, x):
        outputs = self._forward(x)
        return torch.cat(outputs, 1)


class InceptionE(nn.Module):

    def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
        super(InceptionE, self).__init__()
        if conv_block is None:
            conv_block = BasicConv2d
        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)

        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
428 |
+
self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
429 |
+
|
430 |
+
self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
|
431 |
+
self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
|
432 |
+
self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
|
433 |
+
self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
|
434 |
+
|
435 |
+
self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
|
436 |
+
self.pool_include_padding = not align_tf
|
437 |
+
self.use_max_pool = use_max_pool
|
438 |
+
|
439 |
+
def _forward(self, x):
|
440 |
+
branch1x1 = self.branch1x1(x)
|
441 |
+
|
442 |
+
branch3x3 = self.branch3x3_1(x)
|
443 |
+
branch3x3 = [
|
444 |
+
self.branch3x3_2a(branch3x3),
|
445 |
+
self.branch3x3_2b(branch3x3),
|
446 |
+
]
|
447 |
+
branch3x3 = torch.cat(branch3x3, 1)
|
448 |
+
|
449 |
+
branch3x3dbl = self.branch3x3dbl_1(x)
|
450 |
+
branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
|
451 |
+
branch3x3dbl = [
|
452 |
+
self.branch3x3dbl_3a(branch3x3dbl),
|
453 |
+
self.branch3x3dbl_3b(branch3x3dbl),
|
454 |
+
]
|
455 |
+
branch3x3dbl = torch.cat(branch3x3dbl, 1)
|
456 |
+
|
457 |
+
if self.use_max_pool:
|
458 |
+
branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
|
459 |
+
else:
|
460 |
+
branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
|
461 |
+
count_include_pad=self.pool_include_padding)
|
462 |
+
branch_pool = self.branch_pool(branch_pool)
|
463 |
+
|
464 |
+
outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
|
465 |
+
return outputs
|
466 |
+
|
467 |
+
def forward(self, x):
|
468 |
+
outputs = self._forward(x)
|
469 |
+
return torch.cat(outputs, 1)
|
470 |
+
|
471 |
+
|
472 |
+
class InceptionAux(nn.Module):
|
473 |
+
|
474 |
+
def __init__(self, in_channels, num_classes, conv_block=None):
|
475 |
+
super(InceptionAux, self).__init__()
|
476 |
+
if conv_block is None:
|
477 |
+
conv_block = BasicConv2d
|
478 |
+
self.conv0 = conv_block(in_channels, 128, kernel_size=1)
|
479 |
+
self.conv1 = conv_block(128, 768, kernel_size=5)
|
480 |
+
self.conv1.stddev = 0.01
|
481 |
+
self.fc = nn.Linear(768, num_classes)
|
482 |
+
self.fc.stddev = 0.001
|
483 |
+
|
484 |
+
def forward(self, x):
|
485 |
+
# N x 768 x 17 x 17
|
486 |
+
x = F.avg_pool2d(x, kernel_size=5, stride=3)
|
487 |
+
# N x 768 x 5 x 5
|
488 |
+
x = self.conv0(x)
|
489 |
+
# N x 128 x 5 x 5
|
490 |
+
x = self.conv1(x)
|
491 |
+
# N x 768 x 1 x 1
|
492 |
+
# Adaptive average pooling
|
493 |
+
x = F.adaptive_avg_pool2d(x, (1, 1))
|
494 |
+
# N x 768 x 1 x 1
|
495 |
+
x = torch.flatten(x, 1)
|
496 |
+
# N x 768
|
497 |
+
x = self.fc(x)
|
498 |
+
# N x 1000
|
499 |
+
return x
|
500 |
+
|
501 |
+
|
502 |
+
class BasicConv2d(nn.Module):
|
503 |
+
|
504 |
+
def __init__(self, in_channels, out_channels, **kwargs):
|
505 |
+
super(BasicConv2d, self).__init__()
|
506 |
+
self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
|
507 |
+
self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
|
508 |
+
|
509 |
+
def forward(self, x):
|
510 |
+
x = self.conv(x)
|
511 |
+
x = self.bn(x)
|
512 |
+
return F.relu(x, inplace=True)
|
513 |
+
|
514 |
+
# pylint: enable=line-too-long
|
515 |
+
# pylint: enable=missing-function-docstring
|
516 |
+
# pylint: enable=missing-class-docstring
|
517 |
+
# pylint: enable=super-with-arguments
|
518 |
+
# pylint: enable=consider-merging-isinstance
|
519 |
+
# pylint: enable=import-outside-toplevel
|
520 |
+
# pylint: enable=no-else-return
|
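The `align_tf` flag threaded through InceptionA/C/E above toggles TF-style average pooling: TensorFlow excludes zero padding from the pooling divisor, while PyTorch counts it by default, so `pool_include_padding = not align_tf` reproduces the TF behavior through `count_include_pad`. A minimal standalone sketch of the difference (not part of the committed file; the tensor values are illustrative only):

import torch
import torch.nn.functional as F

x = torch.ones(1, 1, 4, 4)  # a single-channel map of ones

# PyTorch default: padded zeros are counted in the mean.
torch_style = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                           count_include_pad=True)
# TF-aligned (align_tf=True above): only real pixels enter the mean.
tf_style = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
                        count_include_pad=False)

print(torch_style[0, 0, 0, 0].item())  # 0.444... (4 ones / 9 window cells)
print(tf_style[0, 0, 0, 0].item())     # 1.0 (4 ones / 4 real cells)

The two outputs differ only at feature-map borders, which is exactly where FID features computed with the two conventions drift apart.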
networks/genforce/models/__init__.py
ADDED
@@ -0,0 +1,131 @@
# python3.7
"""Collects all available models together."""

from .model_zoo import MODEL_ZOO
from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .encoder import EncoderNet
from .perceptual_model import PerceptualModel

__all__ = [
    'MODEL_ZOO', 'PGGANGenerator', 'PGGANDiscriminator', 'StyleGANGenerator',
    'StyleGANDiscriminator', 'StyleGAN2Generator', 'StyleGAN2Discriminator',
    'EncoderNet', 'PerceptualModel', 'build_generator', 'build_discriminator',
    'build_encoder', 'build_perceptual', 'build_model'
]

_GAN_TYPES_ALLOWED = ['pggan', 'stylegan', 'stylegan2']
_MODULES_ALLOWED = ['generator', 'discriminator', 'encoder', 'perceptual']


def build_generator(gan_type, resolution, **kwargs):
    """Builds generator by GAN type.

    Args:
        gan_type: GAN type to which the generator belongs.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the generator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANGenerator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Generator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_discriminator(gan_type, resolution, **kwargs):
    """Builds discriminator by GAN type.

    Args:
        gan_type: GAN type to which the discriminator belongs.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the discriminator.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type == 'pggan':
        return PGGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan':
        return StyleGANDiscriminator(resolution, **kwargs)
    if gan_type == 'stylegan2':
        return StyleGAN2Discriminator(resolution, **kwargs)
    raise NotImplementedError(f'Unsupported GAN type `{gan_type}`!')


def build_encoder(gan_type, resolution, **kwargs):
    """Builds encoder by GAN type.

    Args:
        gan_type: GAN type to which the encoder belongs.
        resolution: Input resolution for encoder.
        **kwargs: Additional arguments to build the encoder.

    Raises:
        ValueError: If the `gan_type` is not supported.
        NotImplementedError: If the `gan_type` is not implemented.
    """
    if gan_type not in _GAN_TYPES_ALLOWED:
        raise ValueError(f'Invalid GAN type: `{gan_type}`!\n'
                         f'Types allowed: {_GAN_TYPES_ALLOWED}.')

    if gan_type in ['stylegan', 'stylegan2']:
        return EncoderNet(resolution, **kwargs)

    raise NotImplementedError(f'Unsupported GAN type `{gan_type}` for encoder!')


def build_perceptual(**kwargs):
    """Builds perceptual model.

    Args:
        **kwargs: Additional arguments to build the perceptual model.
    """
    return PerceptualModel(**kwargs)


def build_model(gan_type, module, resolution, **kwargs):
    """Builds a GAN module (generator/discriminator/etc).

    Args:
        gan_type: GAN type to which the model belongs.
        module: GAN module to build, such as generator or discriminator.
        resolution: Synthesis resolution.
        **kwargs: Additional arguments to build the module.

    Raises:
        ValueError: If the `module` is not supported.
        NotImplementedError: If the `module` is not implemented.
    """
    if module not in _MODULES_ALLOWED:
        raise ValueError(f'Invalid module: `{module}`!\n'
                         f'Modules allowed: {_MODULES_ALLOWED}.')

    if module == 'generator':
        return build_generator(gan_type, resolution, **kwargs)
    if module == 'discriminator':
        return build_discriminator(gan_type, resolution, **kwargs)
    if module == 'encoder':
        return build_encoder(gan_type, resolution, **kwargs)
    if module == 'perceptual':
        return build_perceptual(**kwargs)
    raise NotImplementedError(f'Unsupported module `{module}`!')
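A standalone usage sketch for the factory functions above (not part of the committed file; the import path assumes the repository root is on `sys.path`, and 'stylegan2' / 256 are illustrative arguments, not values fixed by the module):

from networks.genforce.models import build_generator, build_model

# The generic entry point dispatches on `module`...
G = build_model(gan_type='stylegan2', module='generator', resolution=256)

# ...and is equivalent to calling the specific builder directly.
G = build_generator('stylegan2', 256)

# Unknown types fail fast, per the checks above:
# build_generator('dcgan', 64)  # -> ValueError: Invalid GAN type: `dcgan`!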
networks/genforce/models/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (4.06 kB). View file
networks/genforce/models/__pycache__/encoder.cpython-38.pyc
ADDED
Binary file (13.4 kB). View file
networks/genforce/models/__pycache__/model_zoo.cpython-38.pyc
ADDED
Binary file (11.2 kB). View file
networks/genforce/models/__pycache__/perceptual_model.cpython-38.pyc
ADDED
Binary file (5.58 kB). View file
networks/genforce/models/__pycache__/pggan_discriminator.cpython-38.pyc
ADDED
Binary file (11.6 kB). View file
networks/genforce/models/__pycache__/pggan_generator.cpython-38.pyc
ADDED
Binary file (9.34 kB). View file
networks/genforce/models/__pycache__/stylegan2_discriminator.cpython-38.pyc
ADDED
Binary file (13.5 kB). View file
networks/genforce/models/__pycache__/stylegan2_generator.cpython-38.pyc
ADDED
Binary file (28.6 kB). View file
networks/genforce/models/__pycache__/stylegan_discriminator.cpython-38.pyc
ADDED
Binary file (15.2 kB). View file
networks/genforce/models/__pycache__/stylegan_generator.cpython-38.pyc
ADDED
Binary file (26.5 kB). View file