akhaliq HF staff commited on
Commit
81170fd
1 Parent(s): 0edc624
checkpoint.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import flax
2
+ import dill as pickle
3
+ import os
4
+ import builtins
5
+ from jax._src.lib import xla_client
6
+ import tensorflow as tf
7
+
8
+
9
+ # Hack: this is the module reported by this object.
10
+ # https://github.com/google/jax/issues/8505
11
+ builtins.bfloat16 = xla_client.bfloat16
12
+
13
+
14
+ def pickle_dump(obj, filename):
15
+ """ Wrapper to dump an object to a file."""
16
+ with tf.io.gfile.GFile(filename, "wb") as f:
17
+ f.write(pickle.dumps(obj))
18
+
19
+
20
+ def pickle_load(filename):
21
+ """ Wrapper to load an object from a file."""
22
+ with tf.io.gfile.GFile(filename, 'rb') as f:
23
+ pickled = pickle.loads(f.read())
24
+ return pickled
25
+
26
+
27
+ def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
28
+ """
29
+ Saves checkpoint.
30
+
31
+ Args:
32
+ ckpt_dir (str): Path to the directory, where checkpoints are saved.
33
+ state_G (train_state.TrainState): Generator state.
34
+ state_D (train_state.TrainState): Discriminator state.
35
+ params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
36
+ pl_mean (array): Moving average of the path length (generator regularization).
37
+ config (argparse.Namespace): Configuration.
38
+ step (int): Current step.
39
+ epoch (int): Current epoch.
40
+ fid_score (float): FID score corresponding to the checkpoint.
41
+ keep (int): Number of checkpoints to keep.
42
+ """
43
+ state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
44
+ 'state_D': flax.jax_utils.unreplicate(state_D),
45
+ 'params_ema_G': params_ema_G,
46
+ 'pl_mean': flax.jax_utils.unreplicate(pl_mean),
47
+ 'config': config,
48
+ 'fid_score': fid_score,
49
+ 'step': step,
50
+ 'epoch': epoch}
51
+
52
+ pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))
53
+ ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
54
+ if len(ckpts) > keep:
55
+ modified_times = {}
56
+ for ckpt in ckpts:
57
+ stats = tf.io.gfile.stat(ckpt)
58
+ modified_times[ckpt] = stats.mtime_nsec
59
+ oldest_ckpt = sorted(modified_times, key=modified_times.get)[0]
60
+ tf.io.gfile.remove(oldest_ckpt)
61
+
62
+
63
+ def load_checkpoint(filename):
64
+ """
65
+ Loads checkpoints.
66
+
67
+ Args:
68
+ filename (str): Path to the checkpoint file.
69
+
70
+ Returns:
71
+ (dict): Checkpoint.
72
+ """
73
+ state_dict = pickle_load(filename)
74
+ return state_dict
75
+
76
+
77
+ def get_latest_checkpoint(ckpt_dir):
78
+ """
79
+ Returns the path of the latest checkpoint.
80
+
81
+ Args:
82
+ ckpt_dir (str): Path to the directory, where checkpoints are saved.
83
+
84
+ Returns:
85
+ (str): Path to latest checkpoint (if it exists).
86
+ """
87
+ ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
88
+ if len(ckpts) == 0:
89
+ return None
90
+
91
+ modified_times = {}
92
+ for ckpt in ckpts:
93
+ stats = tf.io.gfile.stat(ckpt)
94
+ modified_times[ckpt] = stats.mtime_nsec
95
+ latest_ckpt = sorted(modified_times, key=modified_times.get)[-1]
96
+ return latest_ckpt
data_pipeline.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow_datasets as tfds
3
+ import jax
4
+ import flax
5
+ import numpy as np
6
+ from PIL import Image
7
+ import os
8
+ from typing import Sequence
9
+ from tqdm import tqdm
10
+ import json
11
+ from tqdm import tqdm
12
+ import logging
13
+
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
+ def prefetch(dataset, n_prefetch):
18
+ # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
19
+ ds_iter = iter(dataset)
20
+ ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
21
+ ds_iter)
22
+ if n_prefetch:
23
+ ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
24
+ return ds_iter
25
+
26
+
27
+ def get_data(data_dir, img_size, img_channels, num_classes, num_local_devices, batch_size, shuffle_buffer=1000):
28
+ """
29
+
30
+ Args:
31
+ data_dir (str): Root directory of the dataset.
32
+ img_size (int): Image size for training.
33
+ img_channels (int): Number of image channels.
34
+ num_classes (int): Number of classes, 0 for no classes.
35
+ num_local_devices (int): Number of devices.
36
+ batch_size (int): Batch size (per device).
37
+ shuffle_buffer (int): Buffer used for shuffling the dataset.
38
+
39
+ Returns:
40
+ (tf.data.Dataset): Dataset.
41
+ """
42
+
43
+ def pre_process(serialized_example):
44
+ feature = {'height': tf.io.FixedLenFeature([], tf.int64),
45
+ 'width': tf.io.FixedLenFeature([], tf.int64),
46
+ 'channels': tf.io.FixedLenFeature([], tf.int64),
47
+ 'image': tf.io.FixedLenFeature([], tf.string),
48
+ 'label': tf.io.FixedLenFeature([], tf.int64)}
49
+ example = tf.io.parse_single_example(serialized_example, feature)
50
+
51
+ height = tf.cast(example['height'], dtype=tf.int64)
52
+ width = tf.cast(example['width'], dtype=tf.int64)
53
+ channels = tf.cast(example['channels'], dtype=tf.int64)
54
+
55
+ image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
56
+ image = tf.reshape(image, shape=[height, width, channels])
57
+
58
+ image = tf.cast(image, dtype='float32')
59
+ image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
60
+ image = tf.image.random_flip_left_right(image)
61
+
62
+ image = (image - 127.5) / 127.5
63
+
64
+ label = tf.one_hot(example['label'], num_classes)
65
+ return {'image': image, 'label': label}
66
+
67
+ def shard(data):
68
+ # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
69
+ # because the first dimension will be mapped across devices using jax.pmap
70
+ data['image'] = tf.reshape(data['image'], [num_local_devices, -1, img_size, img_size, img_channels])
71
+ data['label'] = tf.reshape(data['label'], [num_local_devices, -1, num_classes])
72
+ return data
73
+
74
+ logger.info('Loading TFRecord...')
75
+ with tf.io.gfile.GFile(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
76
+ dataset_info = json.load(fin)
77
+
78
+ ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
79
+ ds = ds.shard(jax.process_count(), jax.process_index())
80
+ ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
81
+ ds = ds.map(pre_process, tf.data.AUTOTUNE)
82
+ ds = ds.batch(batch_size * num_local_devices, drop_remainder=True) # uses per-worker batch size
83
+ ds = ds.map(shard, tf.data.AUTOTUNE)
84
+ ds = ds.prefetch(1) # prefetches the next batch
85
+ return ds, dataset_info
dataset_utils/crop_image_borders.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import os
4
+ from tqdm import tqdm
5
+ import argparse
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ """
11
+ Crops the black borders around images.
12
+ """
13
+
14
+
15
+ def crop_border(x, constant=0.0):
16
+ top = 0
17
+ while True:
18
+ if np.sum(x[top] != constant) != 0.0:
19
+ break
20
+ top += 1
21
+ bottom = x.shape[0] - 1
22
+ while True:
23
+ if np.sum(x[bottom] != constant) != 0.0:
24
+ bottom += 1
25
+ break
26
+ bottom -= 1
27
+ left = 0
28
+ while True:
29
+ if np.sum(x[:, left] != constant) != 0.0:
30
+ break
31
+ left += 1
32
+ right = x.shape[1] - 1
33
+ while True:
34
+ if np.sum(x[:, right] != constant) != 0.0:
35
+ right += 1
36
+ break
37
+ right -= 1
38
+ return x[top:bottom, left:right]
39
+
40
+
41
+ def crop_images(path, constant_value):
42
+ logger.info('Crop image borders...')
43
+ for f in tqdm(os.listdir(path)):
44
+ img = Image.open(os.path.join(path, f))
45
+ img = crop_border(np.array(img), constant=constant_value)
46
+ img = Image.fromarray(img)
47
+ img.save(os.path.join(path, f))
48
+
49
+
50
+ if __name__ == '__main__':
51
+ parser = argparse.ArgumentParser()
52
+ parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
53
+ parser.add_argument('--constant_value', type=float, default=0.0, help='Value of the border that should be cropped.')
54
+
55
+ args = parser.parse_args()
56
+
57
+ crop_images(args.image_dir, args.constant_value)
dataset_utils/images_to_tfrecords.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import numpy as np
3
+ from PIL import Image
4
+ from typing import Sequence
5
+ from tqdm import tqdm
6
+ import argparse
7
+ import json
8
+ import os
9
+ import logging
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def images_to_tfrecords(image_dir, data_dir, has_labels):
15
+ """
16
+ Converts a folder of images to a TFRecord file.
17
+
18
+ The image directory should have one of the following structures:
19
+
20
+ If has_labels = False, image_dir should look like this:
21
+
22
+ path/to/image_dir/
23
+ 0.jpg
24
+ 1.jpg
25
+ 2.jpg
26
+ 4.jpg
27
+ ...
28
+
29
+
30
+ If has_labels = True, image_dir should look like this:
31
+
32
+ path/to/image_dir/
33
+ label0/
34
+ 0.jpg
35
+ 1.jpg
36
+ ...
37
+ label1/
38
+ a.jpg
39
+ b.jpg
40
+ c.jpg
41
+ ...
42
+ ...
43
+
44
+
45
+ The labels will be label0 -> 0, label1 -> 1.
46
+
47
+ Args:
48
+ image_dir (str): Path to images.
49
+ data_dir (str): Path where the TFrecords dataset is stored.
50
+ has_labels (bool): If True, 'image_dir' contains label directories.
51
+
52
+ Returns:
53
+ (dict): Dataset info.
54
+ """
55
+
56
+ def _bytes_feature(value):
57
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
58
+
59
+ def _int64_feature(value):
60
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
61
+
62
+ os.makedirs(data_dir, exist_ok=True)
63
+ writer = tf.io.TFRecordWriter(os.path.join(data_dir, 'dataset.tfrecords'))
64
+
65
+ num_examples = 0
66
+ num_classes = 0
67
+
68
+ if has_labels:
69
+ for label_dir in os.listdir(image_dir):
70
+ if not os.path.isdir(os.path.join(image_dir, label_dir)):
71
+ logger.warning('The image directory should contain one directory for each label.')
72
+ logger.warning('These label directories should contain the image files.')
73
+ if os.path.exists(os.path.join(data_dir, 'dataset.tfrecords')):
74
+ os.remove(os.path.join(data_dir, 'dataset.tfrecords'))
75
+ return
76
+
77
+ for img_file in tqdm(os.listdir(os.path.join(image_dir, label_dir))):
78
+ file_format = img_file[img_file.rfind('.') + 1:]
79
+ if file_format not in ['png', 'jpg', 'jpeg']:
80
+ continue
81
+
82
+ #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
83
+ img = Image.open(os.path.join(image_dir, label_dir, img_file))
84
+ img = np.array(img, dtype=np.uint8)
85
+
86
+ height = img.shape[0]
87
+ width = img.shape[1]
88
+ channels = img.shape[2]
89
+
90
+ img_encoded = img.tobytes()
91
+
92
+ example = tf.train.Example(features=tf.train.Features(feature={
93
+ 'height': _int64_feature(height),
94
+ 'width': _int64_feature(width),
95
+ 'channels': _int64_feature(channels),
96
+ 'image': _bytes_feature(img_encoded),
97
+ 'label': _int64_feature(num_classes)}))
98
+
99
+ writer.write(example.SerializeToString())
100
+ num_examples += 1
101
+
102
+ num_classes += 1
103
+ else:
104
+ for img_file in tqdm(os.listdir(os.path.join(image_dir))):
105
+ file_format = img_file[img_file.rfind('.') + 1:]
106
+ if file_format not in ['png', 'jpg', 'jpeg']:
107
+ continue
108
+
109
+ #img = Image.open(os.path.join(image_dir, label_dir, img_file)).resize(img_size)
110
+ img = Image.open(os.path.join(image_dir, img_file))
111
+ img = np.array(img, dtype=np.uint8)
112
+
113
+ height = img.shape[0]
114
+ width = img.shape[1]
115
+ channels = img.shape[2]
116
+
117
+ img_encoded = img.tobytes()
118
+
119
+ example = tf.train.Example(features=tf.train.Features(feature={
120
+ 'height': _int64_feature(height),
121
+ 'width': _int64_feature(width),
122
+ 'channels': _int64_feature(channels),
123
+ 'image': _bytes_feature(img_encoded),
124
+ 'label': _int64_feature(num_classes)})) # dummy label
125
+
126
+ writer.write(example.SerializeToString())
127
+ num_examples += 1
128
+
129
+ writer.close()
130
+
131
+ dataset_info = {'num_examples': num_examples, 'num_classes': num_classes}
132
+ with open(os.path.join(data_dir, 'dataset_info.json'), 'w') as fout:
133
+ json.dump(dataset_info, fout)
134
+
135
+
136
+ if __name__ == '__main__':
137
+ parser = argparse.ArgumentParser()
138
+ parser.add_argument('--image_dir', type=str, help='Path to the image directory.')
139
+ parser.add_argument('--data_dir', type=str, help='Path where the TFRecords dataset is stored.')
140
+ parser.add_argument('--has_labels', action='store_true', help='If True, image_dir contains label directories.')
141
+
142
+ args = parser.parse_args()
143
+
144
+ images_to_tfrecords(args.image_dir, args.data_dir, args.has_labels)
145
+
fid/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .core import FID
fid/core.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax
4
+ import flax.linen as nn
5
+ import numpy as np
6
+ import os
7
+ import functools
8
+ import argparse
9
+ import scipy
10
+ from tqdm import tqdm
11
+ import logging
12
+
13
+ from . import inception
14
+ from . import utils
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ class FID:
19
+
20
+ def __init__(self, generator, dataset, config, use_cache=True, truncation_psi=1.0):
21
+ """
22
+ Evaluates the FID score for a given generator and a given dataset.
23
+ Implementation mostly taken from https://github.com/matthias-wright/jax-fid
24
+
25
+ Reference: https://arxiv.org/abs/1706.08500
26
+
27
+ Args:
28
+ generator (nn.Module): Generator network.
29
+ dataset (tf.data.Dataset): Dataset containing the real images.
30
+ config (argparse.Namespace): Configuration.
31
+ use_cache (bool): If True, only compute the activation stats once for the real images and store them.
32
+ truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
33
+ """
34
+ self.num_images = config.num_fid_images
35
+ self.batch_size = config.batch_size
36
+ self.c_dim = config.c_dim
37
+ self.z_dim = config.z_dim
38
+ self.dataset = dataset
39
+ self.num_devices = jax.device_count()
40
+ self.num_local_devices = jax.local_device_count()
41
+ self.use_cache = use_cache
42
+
43
+ if self.use_cache:
44
+ self.cache = {}
45
+
46
+ rng = jax.random.PRNGKey(0)
47
+ inception_net = inception.InceptionV3(pretrained=True)
48
+ self.inception_params = inception_net.init(rng, jnp.ones((1, config.resolution, config.resolution, 3)))
49
+ self.inception_params = flax.jax_utils.replicate(self.inception_params)
50
+ #self.inception = jax.jit(functools.partial(model.apply, train=False))
51
+ self.inception_apply = jax.pmap(functools.partial(inception_net.apply, train=False), axis_name='batch')
52
+
53
+ self.generator_apply = jax.pmap(functools.partial(generator.apply, truncation_psi=truncation_psi, train=False, noise_mode='const'), axis_name='batch')
54
+
55
+ def compute_fid(self, generator_params, seed_offset=0):
56
+ generator_params = flax.jax_utils.replicate(generator_params)
57
+ mu_real, sigma_real = self.compute_stats_for_dataset()
58
+ mu_fake, sigma_fake = self.compute_stats_for_generator(generator_params, seed_offset)
59
+ fid_score = self.compute_frechet_distance(mu_real, mu_fake, sigma_real, sigma_fake, eps=1e-6)
60
+ return fid_score
61
+
62
+ def compute_frechet_distance(self, mu1, mu2, sigma1, sigma2, eps=1e-6):
63
+ # Taken from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
64
+ mu1 = np.atleast_1d(mu1)
65
+ mu2 = np.atleast_1d(mu2)
66
+ sigma1 = np.atleast_1d(sigma1)
67
+ sigma2 = np.atleast_1d(sigma2)
68
+
69
+ assert mu1.shape == mu2.shape
70
+ assert sigma1.shape == sigma2.shape
71
+
72
+ diff = mu1 - mu2
73
+
74
+ covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False)
75
+ if not np.isfinite(covmean).all():
76
+ msg = ('fid calculation produces singular product; '
77
+ 'adding %s to diagonal of cov estimates') % eps
78
+ logger.info(msg)
79
+ offset = np.eye(sigma1.shape[0]) * eps
80
+ covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
81
+
82
+ # Numerical error might give slight imaginary component
83
+ if np.iscomplexobj(covmean):
84
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
85
+ m = np.max(np.abs(covmean.imag))
86
+ raise ValueError('Imaginary component {}'.format(m))
87
+ covmean = covmean.real
88
+
89
+ tr_covmean = np.trace(covmean)
90
+ return (diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean)
91
+
92
+ def compute_stats_for_dataset(self):
93
+ if self.use_cache and 'mu' in self.cache and 'sigma' in self.cache:
94
+ logger.info('Use cached statistics for dataset...')
95
+ return self.cache['mu'], self.cache['sigma']
96
+
97
+ print()
98
+ logger.info('Compute statistics for dataset...')
99
+ image_count = 0
100
+
101
+ activations = []
102
+ for batch in utils.prefetch(self.dataset, n_prefetch=2):
103
+ act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(batch['image']))
104
+ act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
105
+ activations.append(act)
106
+
107
+ image_count += self.num_local_devices * self.batch_size
108
+ if image_count >= self.num_images:
109
+ break
110
+
111
+ activations = jnp.concatenate(activations, axis=0)
112
+ activations = activations[:self.num_images]
113
+ mu = np.mean(activations, axis=0)
114
+ sigma = np.cov(activations, rowvar=False)
115
+ self.cache['mu'] = mu
116
+ self.cache['sigma'] = sigma
117
+ return mu, sigma
118
+
119
+ def compute_stats_for_generator(self, generator_params, seed_offset):
120
+ print()
121
+ logger.info('Compute statistics for generator...')
122
+ num_batches = int(np.ceil(self.num_images / (self.batch_size * self.num_local_devices)))
123
+
124
+ activations = []
125
+
126
+ for i in range(num_batches):
127
+ rng = jax.random.PRNGKey(seed_offset + i)
128
+ z_latent = jax.random.normal(rng, shape=(self.num_local_devices, self.batch_size, self.z_dim))
129
+
130
+ labels = None
131
+ if self.c_dim > 0:
132
+ labels = jax.random.randint(rng, shape=(self.num_local_devices * self.batch_size,), minval=0, maxval=self.c_dim)
133
+ labels = jax.nn.one_hot(labels, num_classes=self.c_dim)
134
+ labels = jnp.reshape(labels, (self.num_local_devices, self.batch_size, self.c_dim))
135
+
136
+ image = self.generator_apply(generator_params, jax.lax.stop_gradient(z_latent), labels)
137
+ image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
138
+
139
+ image = 2 * image - 1
140
+ act = self.inception_apply(self.inception_params, jax.lax.stop_gradient(image))
141
+ act = jnp.reshape(act, (self.num_local_devices * self.batch_size, -1))
142
+ activations.append(act)
143
+
144
+ activations = jnp.concatenate(activations, axis=0)
145
+ activations = activations[:self.num_images]
146
+ mu = np.mean(activations, axis=0)
147
+ sigma = np.cov(activations, rowvar=False)
148
+ return mu, sigma
149
+
150
+
fid/inception.py ADDED
@@ -0,0 +1,655 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ from jax import lax
3
+ from jax.nn import initializers
4
+ import jax.numpy as jnp
5
+ import flax
6
+ from flax.linen.module import merge_param
7
+ import flax.linen as nn
8
+ from typing import Callable, Iterable, Optional, Tuple, Union, Any
9
+ import functools
10
+ import pickle
11
+ from . import utils
12
+
13
+ PRNGKey = Any
14
+ Array = Any
15
+ Shape = Tuple[int]
16
+ Dtype = Any
17
+
18
+
19
+ class InceptionV3(nn.Module):
20
+ """
21
+ InceptionV3 network.
22
+ Reference: https://arxiv.org/abs/1512.00567
23
+ Ported mostly from: https://github.com/pytorch/vision/blob/master/torchvision/models/inception.py
24
+
25
+ Attributes:
26
+ include_head (bool): If True, include classifier head.
27
+ num_classes (int): Number of classes.
28
+ pretrained (bool): If True, use pretrained weights.
29
+ transform_input (bool): If True, preprocesses the input according to the method with which it
30
+ was trained on ImageNet.
31
+ aux_logits (bool): If True, add an auxiliary branch that can improve training.
32
+ dtype (str): Data type.
33
+ """
34
+ include_head: bool=False
35
+ num_classes: int=1000
36
+ pretrained: bool=False
37
+ transform_input: bool=False
38
+ aux_logits: bool=False
39
+ ckpt_path: str='https://www.dropbox.com/s/0zo4pd6cfwgzem7/inception_v3_weights_fid.pickle?dl=1'
40
+ dtype: str='float32'
41
+
42
+ def setup(self):
43
+ if self.pretrained:
44
+ ckpt_file = utils.download(self.ckpt_path)
45
+ self.params_dict = pickle.load(open(ckpt_file, 'rb'))
46
+ self.num_classes_ = 1000
47
+ else:
48
+ self.params_dict = None
49
+ self.num_classes_ = self.num_classes
50
+
51
+ @nn.compact
52
+ def __call__(self, x, train=True, rng=jax.random.PRNGKey(0)):
53
+ """
54
+ Args:
55
+ x (tensor): Input image, shape [B, H, W, C].
56
+ train (bool): If True, training mode.
57
+ rng (jax.random.PRNGKey): Random seed.
58
+ """
59
+ x = self._transform_input(x)
60
+ x = BasicConv2d(out_channels=32,
61
+ kernel_size=(3, 3),
62
+ strides=(2, 2),
63
+ params_dict=utils.get(self.params_dict, 'Conv2d_1a_3x3'),
64
+ dtype=self.dtype)(x, train)
65
+ x = BasicConv2d(out_channels=32,
66
+ kernel_size=(3, 3),
67
+ params_dict=utils.get(self.params_dict, 'Conv2d_2a_3x3'),
68
+ dtype=self.dtype)(x, train)
69
+ x = BasicConv2d(out_channels=64,
70
+ kernel_size=(3, 3),
71
+ padding=((1, 1), (1, 1)),
72
+ params_dict=utils.get(self.params_dict, 'Conv2d_2b_3x3'),
73
+ dtype=self.dtype)(x, train)
74
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
75
+ x = BasicConv2d(out_channels=80,
76
+ kernel_size=(1, 1),
77
+ params_dict=utils.get(self.params_dict, 'Conv2d_3b_1x1'),
78
+ dtype=self.dtype)(x, train)
79
+ x = BasicConv2d(out_channels=192,
80
+ kernel_size=(3, 3),
81
+ params_dict=utils.get(self.params_dict, 'Conv2d_4a_3x3'),
82
+ dtype=self.dtype)(x, train)
83
+ x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
84
+ x = InceptionA(pool_features=32,
85
+ params_dict=utils.get(self.params_dict, 'Mixed_5b'),
86
+ dtype=self.dtype)(x, train)
87
+ x = InceptionA(pool_features=64,
88
+ params_dict=utils.get(self.params_dict, 'Mixed_5c'),
89
+ dtype=self.dtype)(x, train)
90
+ x = InceptionA(pool_features=64,
91
+ params_dict=utils.get(self.params_dict, 'Mixed_5d'),
92
+ dtype=self.dtype)(x, train)
93
+ x = InceptionB(params_dict=utils.get(self.params_dict, 'Mixed_6a'),
94
+ dtype=self.dtype)(x, train)
95
+ x = InceptionC(channels_7x7=128,
96
+ params_dict=utils.get(self.params_dict, 'Mixed_6b'),
97
+ dtype=self.dtype)(x, train)
98
+ x = InceptionC(channels_7x7=160,
99
+ params_dict=utils.get(self.params_dict, 'Mixed_6c'),
100
+ dtype=self.dtype)(x, train)
101
+ x = InceptionC(channels_7x7=160,
102
+ params_dict=utils.get(self.params_dict, 'Mixed_6d'),
103
+ dtype=self.dtype)(x, train)
104
+ x = InceptionC(channels_7x7=192,
105
+ params_dict=utils.get(self.params_dict, 'Mixed_6e'),
106
+ dtype=self.dtype)(x, train)
107
+ aux = None
108
+ if self.aux_logits and train:
109
+ aux = InceptionAux(num_classes=self.num_classes_,
110
+ params_dict=utils.get(self.params_dict, 'AuxLogits'),
111
+ dtype=self.dtype)(x, train)
112
+ x = InceptionD(params_dict=utils.get(self.params_dict, 'Mixed_7a'),
113
+ dtype=self.dtype)(x, train)
114
+ x = InceptionE(avg_pool, params_dict=utils.get(self.params_dict, 'Mixed_7b'),
115
+ dtype=self.dtype)(x, train)
116
+ # Following the implementation by @mseitzer, we use max pooling instead
117
+ # of average pooling here.
118
+ # See: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/inception.py#L320
119
+ x = InceptionE(nn.max_pool, params_dict=utils.get(self.params_dict, 'Mixed_7c'),
120
+ dtype=self.dtype)(x, train)
121
+ x = jnp.mean(x, axis=(1, 2), keepdims=True)
122
+ if not self.include_head:
123
+ return x
124
+ x = nn.Dropout(rate=0.5)(x, deterministic=not train, rng=rng)
125
+ x = jnp.reshape(x, newshape=(x.shape[0], -1))
126
+ x = Dense(features=self.num_classes_,
127
+ params_dict=utils.get(self.params_dict, 'fc'),
128
+ dtype=self.dtype)(x)
129
+ if self.aux_logits:
130
+ return x, aux
131
+ return x
132
+
133
+ def _transform_input(self, x):
134
+ if self.transform_input:
135
+ x_ch0 = jnp.expand_dims(x[..., 0], axis=-1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
136
+ x_ch1 = jnp.expand_dims(x[..., 1], axis=-1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
137
+ x_ch2 = jnp.expand_dims(x[..., 2], axis=-1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
138
+ x = jnp.concatenate((x_ch0, x_ch1, x_ch2), axis=-1)
139
+ return x
140
+
141
+
142
+ class Dense(nn.Module):
143
+ features: int
144
+ kernel_init: functools.partial=nn.initializers.lecun_normal()
145
+ bias_init: functools.partial=nn.initializers.zeros
146
+ params_dict: dict=None
147
+ dtype: str='float32'
148
+
149
+ @nn.compact
150
+ def __call__(self, x):
151
+ x = nn.Dense(features=self.features,
152
+ kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['kernel']),
153
+ bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['bias']))(x)
154
+ return x
155
+
156
+
157
+ class BasicConv2d(nn.Module):
158
+ out_channels: int
159
+ kernel_size: Union[int, Iterable[int]]=(3, 3)
160
+ strides: Optional[Iterable[int]]=(1, 1)
161
+ padding: Union[str, Iterable[Tuple[int, int]]]='valid'
162
+ use_bias: bool=False
163
+ kernel_init: functools.partial=nn.initializers.lecun_normal()
164
+ bias_init: functools.partial=nn.initializers.zeros
165
+ params_dict: dict=None
166
+ dtype: str='float32'
167
+
168
+ @nn.compact
169
+ def __call__(self, x, train=True):
170
+ x = nn.Conv(features=self.out_channels,
171
+ kernel_size=self.kernel_size,
172
+ strides=self.strides,
173
+ padding=self.padding,
174
+ use_bias=self.use_bias,
175
+ kernel_init=self.kernel_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['kernel']),
176
+ bias_init=self.bias_init if self.params_dict is None else lambda *_ : jnp.array(self.params_dict['conv']['bias']),
177
+ dtype=self.dtype)(x)
178
+ if self.params_dict is None:
179
+ x = BatchNorm(epsilon=0.001,
180
+ momentum=0.1,
181
+ use_running_average=not train,
182
+ dtype=self.dtype)(x)
183
+ else:
184
+ x = BatchNorm(epsilon=0.001,
185
+ momentum=0.1,
186
+ bias_init=lambda *_ : jnp.array(self.params_dict['bn']['bias']),
187
+ scale_init=lambda *_ : jnp.array(self.params_dict['bn']['scale']),
188
+ mean_init=lambda *_ : jnp.array(self.params_dict['bn']['mean']),
189
+ var_init=lambda *_ : jnp.array(self.params_dict['bn']['var']),
190
+ use_running_average=not train,
191
+ dtype=self.dtype)(x)
192
+ x = jax.nn.relu(x)
193
+ return x
194
+
195
+
196
+ class InceptionA(nn.Module):
197
+ pool_features: int
198
+ params_dict: dict=None
199
+ dtype: str='float32'
200
+
201
+ @nn.compact
202
+ def __call__(self, x, train=True):
203
+ branch1x1 = BasicConv2d(out_channels=64,
204
+ kernel_size=(1, 1),
205
+ params_dict=utils.get(self.params_dict, 'branch1x1'),
206
+ dtype=self.dtype)(x, train)
207
+ branch5x5 = BasicConv2d(out_channels=48,
208
+ kernel_size=(1, 1),
209
+ params_dict=utils.get(self.params_dict, 'branch5x5_1'),
210
+ dtype=self.dtype)(x, train)
211
+ branch5x5 = BasicConv2d(out_channels=64,
212
+ kernel_size=(5, 5),
213
+ padding=((2, 2), (2, 2)),
214
+ params_dict=utils.get(self.params_dict, 'branch5x5_2'),
215
+ dtype=self.dtype)(branch5x5, train)
216
+
217
+ branch3x3dbl = BasicConv2d(out_channels=64,
218
+ kernel_size=(1, 1),
219
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
220
+ dtype=self.dtype)(x, train)
221
+ branch3x3dbl = BasicConv2d(out_channels=96,
222
+ kernel_size=(3, 3),
223
+ padding=((1, 1), (1, 1)),
224
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
225
+ dtype=self.dtype)(branch3x3dbl, train)
226
+ branch3x3dbl = BasicConv2d(out_channels=96,
227
+ kernel_size=(3, 3),
228
+ padding=((1, 1), (1, 1)),
229
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
230
+ dtype=self.dtype)(branch3x3dbl, train)
231
+
232
+ branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
233
+ branch_pool = BasicConv2d(out_channels=self.pool_features,
234
+ kernel_size=(1, 1),
235
+ params_dict=utils.get(self.params_dict, 'branch_pool'),
236
+ dtype=self.dtype)(branch_pool, train)
237
+
238
+ output = jnp.concatenate((branch1x1, branch5x5, branch3x3dbl, branch_pool), axis=-1)
239
+ return output
240
+
241
+
242
+ class InceptionB(nn.Module):
243
+ params_dict: dict=None
244
+ dtype: str='float32'
245
+
246
+ @nn.compact
247
+ def __call__(self, x, train=True):
248
+ branch3x3 = BasicConv2d(out_channels=384,
249
+ kernel_size=(3, 3),
250
+ strides=(2, 2),
251
+ params_dict=utils.get(self.params_dict, 'branch3x3'),
252
+ dtype=self.dtype)(x, train)
253
+
254
+ branch3x3dbl = BasicConv2d(out_channels=64,
255
+ kernel_size=(1, 1),
256
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
257
+ dtype=self.dtype)(x, train)
258
+ branch3x3dbl = BasicConv2d(out_channels=96,
259
+ kernel_size=(3, 3),
260
+ padding=((1, 1), (1, 1)),
261
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
262
+ dtype=self.dtype)(branch3x3dbl, train)
263
+ branch3x3dbl = BasicConv2d(out_channels=96,
264
+ kernel_size=(3, 3),
265
+ strides=(2, 2),
266
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_3'),
267
+ dtype=self.dtype)(branch3x3dbl, train)
268
+
269
+ branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
270
+
271
+ output = jnp.concatenate((branch3x3, branch3x3dbl, branch_pool), axis=-1)
272
+ return output
273
+
274
+
275
+ class InceptionC(nn.Module):
276
+ channels_7x7: int
277
+ params_dict: dict=None
278
+ dtype: str='float32'
279
+
280
+ @nn.compact
281
+ def __call__(self, x, train=True):
282
+ branch1x1 = BasicConv2d(out_channels=192,
283
+ kernel_size=(1, 1),
284
+ params_dict=utils.get(self.params_dict, 'branch1x1'),
285
+ dtype=self.dtype)(x, train)
286
+
287
+ branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
288
+ kernel_size=(1, 1),
289
+ params_dict=utils.get(self.params_dict, 'branch7x7_1'),
290
+ dtype=self.dtype)(x, train)
291
+ branch7x7 = BasicConv2d(out_channels=self.channels_7x7,
292
+ kernel_size=(1, 7),
293
+ padding=((0, 0), (3, 3)),
294
+ params_dict=utils.get(self.params_dict, 'branch7x7_2'),
295
+ dtype=self.dtype)(branch7x7, train)
296
+ branch7x7 = BasicConv2d(out_channels=192,
297
+ kernel_size=(7, 1),
298
+ padding=((3, 3), (0, 0)),
299
+ params_dict=utils.get(self.params_dict, 'branch7x7_3'),
300
+ dtype=self.dtype)(branch7x7, train)
301
+
302
+ branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
303
+ kernel_size=(1, 1),
304
+ params_dict=utils.get(self.params_dict, 'branch7x7dbl_1'),
305
+ dtype=self.dtype)(x, train)
306
+ branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
307
+ kernel_size=(7, 1),
308
+ padding=((3, 3), (0, 0)),
309
+ params_dict=utils.get(self.params_dict, 'branch7x7dbl_2'),
310
+ dtype=self.dtype)(branch7x7dbl, train)
311
+ branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
312
+ kernel_size=(1, 7),
313
+ padding=((0, 0), (3, 3)),
314
+ params_dict=utils.get(self.params_dict, 'branch7x7dbl_3'),
315
+ dtype=self.dtype)(branch7x7dbl, train)
316
+ branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
317
+ kernel_size=(7, 1),
318
+ padding=((3, 3), (0, 0)),
319
+ params_dict=utils.get(self.params_dict, 'branch7x7dbl_4'),
320
+ dtype=self.dtype)(branch7x7dbl, train)
321
+ branch7x7dbl = BasicConv2d(out_channels=self.channels_7x7,
322
+ kernel_size=(1, 7),
323
+ padding=((0, 0), (3, 3)),
324
+ params_dict=utils.get(self.params_dict, 'branch7x7dbl_5'),
325
+ dtype=self.dtype)(branch7x7dbl, train)
326
+
327
+ branch_pool = avg_pool(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
328
+ branch_pool = BasicConv2d(out_channels=192,
329
+ kernel_size=(1, 1),
330
+ params_dict=utils.get(self.params_dict, 'branch_pool'),
331
+ dtype=self.dtype)(branch_pool, train)
332
+
333
+ output = jnp.concatenate((branch1x1, branch7x7, branch7x7dbl, branch_pool), axis=-1)
334
+ return output
335
+
336
+
337
+ class InceptionD(nn.Module):
338
+ params_dict: dict=None
339
+ dtype: str='float32'
340
+
341
+ @nn.compact
342
+ def __call__(self, x, train=True):
343
+ branch3x3 = BasicConv2d(out_channels=192,
344
+ kernel_size=(1, 1),
345
+ params_dict=utils.get(self.params_dict, 'branch3x3_1'),
346
+ dtype=self.dtype)(x, train)
347
+ branch3x3 = BasicConv2d(out_channels=320,
348
+ kernel_size=(3, 3),
349
+ strides=(2, 2),
350
+ params_dict=utils.get(self.params_dict, 'branch3x3_2'),
351
+ dtype=self.dtype)(branch3x3, train)
352
+
353
+ branch7x7x3 = BasicConv2d(out_channels=192,
354
+ kernel_size=(1, 1),
355
+ params_dict=utils.get(self.params_dict, 'branch7x7x3_1'),
356
+ dtype=self.dtype)(x, train)
357
+ branch7x7x3 = BasicConv2d(out_channels=192,
358
+ kernel_size=(1, 7),
359
+ padding=((0, 0), (3, 3)),
360
+ params_dict=utils.get(self.params_dict, 'branch7x7x3_2'),
361
+ dtype=self.dtype)(branch7x7x3, train)
362
+ branch7x7x3 = BasicConv2d(out_channels=192,
363
+ kernel_size=(7, 1),
364
+ padding=((3, 3), (0, 0)),
365
+ params_dict=utils.get(self.params_dict, 'branch7x7x3_3'),
366
+ dtype=self.dtype)(branch7x7x3, train)
367
+ branch7x7x3 = BasicConv2d(out_channels=192,
368
+ kernel_size=(3, 3),
369
+ strides=(2, 2),
370
+ params_dict=utils.get(self.params_dict, 'branch7x7x3_4'),
371
+ dtype=self.dtype)(branch7x7x3, train)
372
+
373
+ branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))
374
+
375
+ output = jnp.concatenate((branch3x3, branch7x7x3, branch_pool), axis=-1)
376
+ return output
377
+
378
+
379
+ class InceptionE(nn.Module):
380
+ pooling: Callable
381
+ params_dict: dict=None
382
+ dtype: str='float32'
383
+
384
+ @nn.compact
385
+ def __call__(self, x, train=True):
386
+ branch1x1 = BasicConv2d(out_channels=320,
387
+ kernel_size=(1, 1),
388
+ params_dict=utils.get(self.params_dict, 'branch1x1'),
389
+ dtype=self.dtype)(x, train)
390
+
391
+ branch3x3 = BasicConv2d(out_channels=384,
392
+ kernel_size=(1, 1),
393
+ params_dict=utils.get(self.params_dict, 'branch3x3_1'),
394
+ dtype=self.dtype)(x, train)
395
+ branch3x3_a = BasicConv2d(out_channels=384,
396
+ kernel_size=(1, 3),
397
+ padding=((0, 0), (1, 1)),
398
+ params_dict=utils.get(self.params_dict, 'branch3x3_2a'),
399
+ dtype=self.dtype)(branch3x3, train)
400
+ branch3x3_b = BasicConv2d(out_channels=384,
401
+ kernel_size=(3, 1),
402
+ padding=((1, 1), (0, 0)),
403
+ params_dict=utils.get(self.params_dict, 'branch3x3_2b'),
404
+ dtype=self.dtype)(branch3x3, train)
405
+ branch3x3 = jnp.concatenate((branch3x3_a, branch3x3_b), axis=-1)
406
+
407
+ branch3x3dbl = BasicConv2d(out_channels=448,
408
+ kernel_size=(1, 1),
409
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_1'),
410
+ dtype=self.dtype)(x, train)
411
+ branch3x3dbl = BasicConv2d(out_channels=384,
412
+ kernel_size=(3, 3),
413
+ padding=((1, 1), (1, 1)),
414
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_2'),
415
+ dtype=self.dtype)(branch3x3dbl, train)
416
+ branch3x3dbl_a = BasicConv2d(out_channels=384,
417
+ kernel_size=(1, 3),
418
+ padding=((0, 0), (1, 1)),
419
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_3a'),
420
+ dtype=self.dtype)(branch3x3dbl, train)
421
+ branch3x3dbl_b = BasicConv2d(out_channels=384,
422
+ kernel_size=(3, 1),
423
+ padding=((1, 1), (0, 0)),
424
+ params_dict=utils.get(self.params_dict, 'branch3x3dbl_3b'),
425
+ dtype=self.dtype)(branch3x3dbl, train)
426
+ branch3x3dbl = jnp.concatenate((branch3x3dbl_a, branch3x3dbl_b), axis=-1)
427
+
428
+ branch_pool = self.pooling(x, window_shape=(3, 3), strides=(1, 1), padding=((1, 1), (1, 1)))
429
+ branch_pool = BasicConv2d(out_channels=192,
430
+ kernel_size=(1, 1),
431
+ params_dict=utils.get(self.params_dict, 'branch_pool'),
432
+ dtype=self.dtype)(branch_pool, train)
433
+
434
+ output = jnp.concatenate((branch1x1, branch3x3, branch3x3dbl, branch_pool), axis=-1)
435
+ return output
436
+
437
+
438
+ class InceptionAux(nn.Module):
439
+ num_classes: int
440
+ kernel_init: functools.partial=nn.initializers.lecun_normal()
441
+ bias_init: functools.partial=nn.initializers.zeros
442
+ params_dict: dict=None
443
+ dtype: str='float32'
444
+
445
+ @nn.compact
446
+ def __call__(self, x, train=True):
447
+ x = avg_pool(x, window_shape=(5, 5), strides=(3, 3))
448
+ x = BasicConv2d(out_channels=128,
449
+ kernel_size=(1, 1),
450
+ params_dict=utils.get(self.params_dict, 'conv0'),
451
+ dtype=self.dtype)(x, train)
452
+ x = BasicConv2d(out_channels=768,
453
+ kernel_size=(5, 5),
454
+ params_dict=utils.get(self.params_dict, 'conv1'),
455
+ dtype=self.dtype)(x, train)
456
+ x = jnp.mean(x, axis=(1, 2))
457
+ x = jnp.reshape(x, newshape=(x.shape[0], -1))
458
+ x = Dense(features=self.num_classes,
459
+ params_dict=utils.get(self.params_dict, 'fc'),
460
+ dtype=self.dtype)(x)
461
+ return x
462
+
463
+ def _absolute_dims(rank, dims):
464
+ return tuple([rank + dim if dim < 0 else dim for dim in dims])
465
+
466
+
467
+ class BatchNorm(nn.Module):
468
+ """BatchNorm Module.
469
+ Taken from: https://github.com/google/flax/blob/master/flax/linen/normalization.py
470
+ Attributes:
471
+ use_running_average: if True, the statistics stored in batch_stats
472
+ will be used instead of computing the batch statistics on the input.
473
+ axis: the feature or non-batch axis of the input.
474
+ momentum: decay rate for the exponential moving average of the batch statistics.
475
+ epsilon: a small float added to variance to avoid dividing by zero.
476
+ dtype: the dtype of the computation (default: float32).
477
+ use_bias: if True, bias (beta) is added.
478
+ use_scale: if True, multiply by scale (gamma).
479
+ When the next layer is linear (also e.g. nn.relu), this can be disabled
480
+ since the scaling will be done by the next layer.
481
+ bias_init: initializer for bias, by default, zero.
482
+ scale_init: initializer for scale, by default, one.
483
+ axis_name: the axis name used to combine batch statistics from multiple
484
+ devices. See `jax.pmap` for a description of axis names (default: None).
485
+ axis_index_groups: groups of axis indices within that named axis
486
+ representing subsets of devices to reduce over (default: None). For
487
+ example, `[[0, 1], [2, 3]]` would independently batch-normalize over
488
+ the examples on the first two and last two devices. See `jax.lax.psum`
489
+ for more details.
490
+ """
491
+ use_running_average: Optional[bool] = None
492
+ axis: int = -1
493
+ momentum: float = 0.99
494
+ epsilon: float = 1e-5
495
+ dtype: Dtype = jnp.float32
496
+ use_bias: bool = True
497
+ use_scale: bool = True
498
+ bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros
499
+ scale_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones
500
+ mean_init: Callable[[Shape], Array] = lambda s: jnp.zeros(s, jnp.float32)
501
+ var_init: Callable[[Shape], Array] = lambda s: jnp.ones(s, jnp.float32)
502
+ axis_name: Optional[str] = None
503
+ axis_index_groups: Any = None
504
+
505
+ @nn.compact
506
+ def __call__(self, x, use_running_average: Optional[bool] = None):
507
+ """Normalizes the input using batch statistics.
508
+
509
+ NOTE:
510
+ During initialization (when parameters are mutable) the running average
511
+ of the batch statistics will not be updated. Therefore, the inputs
512
+ fed during initialization don't need to match that of the actual input
513
+ distribution and the reduction axis (set with `axis_name`) does not have
514
+ to exist.
515
+ Args:
516
+ x: the input to be normalized.
517
+ use_running_average: if true, the statistics stored in batch_stats
518
+ will be used instead of computing the batch statistics on the input.
519
+ Returns:
520
+ Normalized inputs (the same shape as inputs).
521
+ """
522
+ use_running_average = merge_param(
523
+ 'use_running_average', self.use_running_average, use_running_average)
524
+ x = jnp.asarray(x, jnp.float32)
525
+ axis = self.axis if isinstance(self.axis, tuple) else (self.axis,)
526
+ axis = _absolute_dims(x.ndim, axis)
527
+ feature_shape = tuple(d if i in axis else 1 for i, d in enumerate(x.shape))
528
+ reduced_feature_shape = tuple(d for i, d in enumerate(x.shape) if i in axis)
529
+ reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)
530
+
531
+ # see NOTE above on initialization behavior
532
+ initializing = self.is_mutable_collection('params')
533
+
534
+ ra_mean = self.variable('batch_stats', 'mean',
535
+ self.mean_init,
536
+ reduced_feature_shape)
537
+ ra_var = self.variable('batch_stats', 'var',
538
+ self.var_init,
539
+ reduced_feature_shape)
540
+
541
+ if use_running_average:
542
+ mean, var = ra_mean.value, ra_var.value
543
+ else:
544
+ mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
545
+ mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False)
546
+ if self.axis_name is not None and not initializing:
547
+ concatenated_mean = jnp.concatenate([mean, mean2])
548
+ mean, mean2 = jnp.split(
549
+ lax.pmean(
550
+ concatenated_mean,
551
+ axis_name=self.axis_name,
552
+ axis_index_groups=self.axis_index_groups), 2)
553
+ var = mean2 - lax.square(mean)
554
+
555
+ if not initializing:
556
+ ra_mean.value = self.momentum * ra_mean.value + (1 - self.momentum) * mean
557
+ ra_var.value = self.momentum * ra_var.value + (1 - self.momentum) * var
558
+
559
+ y = x - mean.reshape(feature_shape)
560
+ mul = lax.rsqrt(var + self.epsilon)
561
+ if self.use_scale:
562
+ scale = self.param('scale',
563
+ self.scale_init,
564
+ reduced_feature_shape).reshape(feature_shape)
565
+ mul = mul * scale
566
+ y = y * mul
567
+ if self.use_bias:
568
+ bias = self.param('bias',
569
+ self.bias_init,
570
+ reduced_feature_shape).reshape(feature_shape)
571
+ y = y + bias
572
+ return jnp.asarray(y, self.dtype)
573
+
574
+
575
+ def pool(inputs, init, reduce_fn, window_shape, strides, padding):
576
+ """
577
+ Taken from: https://github.com/google/flax/blob/main/flax/linen/pooling.py
578
+
579
+ Helper function to define pooling functions.
580
+ Pooling functions are implemented using the ReduceWindow XLA op.
581
+ NOTE: Be aware that pooling is not generally differentiable.
582
+ That means providing a reduce_fn that is differentiable does not imply
583
+ that pool is differentiable.
584
+ Args:
585
+ inputs: input data with dimensions (batch, window dims..., features).
586
+ init: the initial value for the reduction
587
+ reduce_fn: a reduce function of the form `(T, T) -> T`.
588
+ window_shape: a shape tuple defining the window to reduce over.
589
+ strides: a sequence of `n` integers, representing the inter-window
590
+ strides.
591
+ padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
592
+ of `n` `(low, high)` integer pairs that give the padding to apply before
593
+ and after each spatial dimension.
594
+ Returns:
595
+ The output of the reduction for each window slice.
596
+ """
597
+ strides = strides or (1,) * len(window_shape)
598
+ assert len(window_shape) == len(strides), (
599
+ f"len({window_shape}) == len({strides})")
600
+ strides = (1,) + strides + (1,)
601
+ dims = (1,) + window_shape + (1,)
602
+
603
+ is_single_input = False
604
+ if inputs.ndim == len(dims) - 1:
605
+ # add singleton batch dimension because lax.reduce_window always
606
+ # needs a batch dimension.
607
+ inputs = inputs[None]
608
+ is_single_input = True
609
+
610
+ assert inputs.ndim == len(dims), f"len({inputs.shape}) != len({dims})"
611
+ if not isinstance(padding, str):
612
+ padding = tuple(map(tuple, padding))
613
+ assert(len(padding) == len(window_shape)), (
614
+ f"padding {padding} must specify pads for same number of dims as "
615
+ f"window_shape {window_shape}")
616
+ assert(all([len(x) == 2 for x in padding])), (
617
+ f"each entry in padding {padding} must be length 2")
618
+ padding = ((0,0),) + padding + ((0,0),)
619
+ y = jax.lax.reduce_window(inputs, init, reduce_fn, dims, strides, padding)
620
+ if is_single_input:
621
+ y = jnp.squeeze(y, axis=0)
622
+ return y
623
+
624
+
625
+ def avg_pool(inputs, window_shape, strides=None, padding='VALID'):
626
+ """
627
+ Pools the input by taking the average over a window.
628
+
629
+ In comparison to flax.linen.avg_pool, this pooling operation does not
630
+ consider the padded zero's for the average computation.
631
+
632
+ Args:
633
+ inputs: input data with dimensions (batch, window dims..., features).
634
+ window_shape: a shape tuple defining the window to reduce over.
635
+ strides: a sequence of `n` integers, representing the inter-window
636
+ strides (default: `(1, ..., 1)`).
637
+ padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
638
+ of `n` `(low, high)` integer pairs that give the padding to apply before
639
+ and after each spatial dimension (default: `'VALID'`).
640
+ Returns:
641
+ The average for each window slice.
642
+ """
643
+ assert inputs.ndim == 4
644
+ assert len(window_shape) == 2
645
+
646
+ y = pool(inputs, 0., jax.lax.add, window_shape, strides, padding)
647
+ ones = jnp.ones(shape=(1, inputs.shape[1], inputs.shape[2], 1)).astype(inputs.dtype)
648
+ counts = jax.lax.conv_general_dilated(ones,
649
+ jnp.expand_dims(jnp.ones(window_shape).astype(inputs.dtype), axis=(-2, -1)),
650
+ window_strides=(1, 1),
651
+ padding=((1, 1), (1, 1)),
652
+ dimension_numbers=nn.linear._conv_dimension_numbers(ones.shape),
653
+ feature_group_count=1)
654
+ y = y / counts
655
+ return y
fid/utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import flax
3
+ import numpy as np
4
+ from tqdm import tqdm
5
+ import requests
6
+ import os
7
+ import tempfile
8
+ import logging
9
+
10
+ logger = logging.getLogger(__name__)
11
+
12
+
13
+ def download(url, ckpt_dir=None):
14
+ name = url[url.rfind('/') + 1 : url.rfind('?')]
15
+ if ckpt_dir is None:
16
+ ckpt_dir = tempfile.gettempdir()
17
+ ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
18
+ ckpt_file = os.path.join(ckpt_dir, name)
19
+ if not os.path.exists(ckpt_file):
20
+ logger.info(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
21
+ if not os.path.exists(ckpt_dir):
22
+ os.makedirs(ckpt_dir)
23
+
24
+ response = requests.get(url, stream=True)
25
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
26
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
27
+
28
+ # first create temp file, in case the download fails
29
+ ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
30
+ with open(ckpt_file_temp, 'wb') as file:
31
+ for data in response.iter_content(chunk_size=1024):
32
+ progress_bar.update(len(data))
33
+ file.write(data)
34
+ progress_bar.close()
35
+
36
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
37
+ logger.error('An error occured while downloading, please try again.')
38
+ if os.path.exists(ckpt_file_temp):
39
+ os.remove(ckpt_file_temp)
40
+ else:
41
+ # if download was successful, rename the temp file
42
+ os.rename(ckpt_file_temp, ckpt_file)
43
+ return ckpt_file
44
+
45
+
46
+ def get(dictionary, key):
47
+ if dictionary is None or key not in dictionary:
48
+ return None
49
+ return dictionary[key]
50
+
51
+
52
+ def prefetch(dataset, n_prefetch):
53
+ # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
54
+ ds_iter = iter(dataset)
55
+ ds_iter = map(lambda x: jax.tree_map(lambda t: np.asarray(memoryview(t)), x),
56
+ ds_iter)
57
+ if n_prefetch:
58
+ ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
59
+ return ds_iter
generate_images.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import logging
4
+ import os
5
+
6
+ import jax
7
+ import jax.numpy as jnp
8
+ import numpy as np
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ import checkpoint
13
+ from stylegan2.generator import Generator
14
+
15
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s', force=True)
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
+ def generate_images(args):
20
+ logger.info(f"Loading checking '{args.checkpoint}'...")
21
+ ckpt = checkpoint.load_checkpoint(args.checkpoint)
22
+ config = ckpt['config']
23
+ params_ema_G = ckpt['params_ema_G']
24
+
25
+ generator_ema = Generator(
26
+ resolution=config.resolution,
27
+ num_channels=config.img_channels,
28
+ z_dim=config.z_dim,
29
+ c_dim=config.c_dim,
30
+ w_dim=config.w_dim,
31
+ num_ws=int(np.log2(config.resolution)) * 2 - 3,
32
+ num_mapping_layers=8,
33
+ fmap_base=config.fmap_base,
34
+ dtype=jnp.float32
35
+ )
36
+
37
+ generator_apply = jax.jit(
38
+ functools.partial(generator_ema.apply, truncation_psi=args.truncation_psi, train=False, noise_mode='const')
39
+ )
40
+
41
+ logger.info(f"Generating {len(args.seeds)} images with truncation {args.truncation_psi}...")
42
+ for seed in tqdm(args.seeds):
43
+ rng = jax.random.PRNGKey(seed)
44
+ z_latent = jax.random.normal(rng, shape=(1, config.z_dim))
45
+ image = generator_apply(params_ema_G, jax.lax.stop_gradient(z_latent), None)
46
+ image = (image - jnp.min(image)) / (jnp.max(image) - jnp.min(image))
47
+
48
+ Image.fromarray(np.uint8(np.clip(image[0] * 255, 0, 255))).save(os.path.join(args.out_path, f'{seed}.png'))
49
+ logger.info(f"Images saved in '{args.out_path}/'")
50
+
51
+
52
+ if __name__ == '__main__':
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument('--checkpoint', type=str, help='Path to the checkpoint.', required=True)
55
+ parser.add_argument('--out_path', type=str, default='generated_images', help='Path where the generated images are stored.')
56
+ parser.add_argument('--truncation_psi', type=float, default=0.5, help='Controls truncation (trading off variation for quality). If 1, truncation is disabled.')
57
+ parser.add_argument('--seeds', type=int, nargs='*', default=[0], help='List of random seeds.')
58
+ args = parser.parse_args()
59
+ os.makedirs(args.out_path, exist_ok=True)
60
+
61
+ generate_images(args)
main.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import os
3
+ import jax
4
+ import wandb
5
+ import training
6
+ import logging
7
+ import json
8
+
9
+
10
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(levelname)-5.5s] [%(name)-12.12s]: %(message)s', force=True)
11
+ logger = logging.getLogger(__name__)
12
+
13
+
14
+ def main():
15
+ parser = argparse.ArgumentParser()
16
+ # Paths
17
+ parser.add_argument('--data_dir', type=str, required=True, help='Directory of the dataset.')
18
+ parser.add_argument('--save_dir', type=str, default='gs://ig-standard-usc1/sg2-flax/checkpoints/', help='Directory where checkpoints will be written to. A subfolder with run_id will be created.')
19
+ parser.add_argument('--load_from_pkl', type=str, help='If provided, start training from an existing checkpoint pickle file.')
20
+ parser.add_argument('--resume_run_id', type=str, help='If provided, resume existing training run. If --wandb is enabled W&B will also resume.')
21
+ parser.add_argument('--project', type=str, default='sg2-flax', help='Name of this project.')
22
+ # Training
23
+ parser.add_argument('--num_epochs', type=int, default=10000, help='Number of epochs.')
24
+ parser.add_argument('--learning_rate', type=float, default=0.002, help='Learning rate.')
25
+ parser.add_argument('--batch_size', type=int, default=8, help='Batch size.')
26
+ parser.add_argument('--num_prefetch', type=int, default=2, help='Number of prefetched examples for the data pipeline.')
27
+ parser.add_argument('--resolution', type=int, default=128, help='Image resolution. Must be a multiple of 2.')
28
+ parser.add_argument('--img_channels', type=int, default=3, help='Number of image channels.')
29
+ parser.add_argument('--mixed_precision', action='store_true', help='Use mixed precision training.')
30
+ parser.add_argument('--random_seed', type=int, default=0, help='Random seed.')
31
+ parser.add_argument('--bf16', action='store_true', help='Use bf16 dtype (This is still WIP).')
32
+ # Generator
33
+ parser.add_argument('--fmap_base', type=int, default=16384, help='Overall multiplier for the number of feature maps.')
34
+ # Discriminator
35
+ parser.add_argument('--mbstd_group_size', type=int, help='Group size for the minibatch standard deviation layer, None = entire minibatch.')
36
+ # Exponentially Moving Average of Generator Weights
37
+ parser.add_argument('--ema_kimg', type=float, default=20.0, help='Controls the ema of the generator weights (larger value -> larger beta).')
38
+ # Losses
39
+ parser.add_argument('--pl_decay', type=float, default=0.01, help='Exponentially decay for mean of path length (Path length regul).')
40
+ parser.add_argument('--pl_weight', type=float, default=2, help='Weight for path length regularization.')
41
+ # Regularization
42
+ parser.add_argument('--mixing_prob', type=float, default=0.9, help='Probability for style mixing.')
43
+ parser.add_argument('--G_reg_interval', type=int, default=4, help='How often to perform regularization for G.')
44
+ parser.add_argument('--D_reg_interval', type=int, default=16, help='How often to perform regularization for D.')
45
+ parser.add_argument('--r1_gamma', type=float, default=10.0, help='Weight for R1 regularization.')
46
+ # Model
47
+ parser.add_argument('--z_dim', type=int, default=512, help='Input latent (Z) dimensionality.')
48
+ parser.add_argument('--c_dim', type=int, default=0, help='Conditioning label (C) dimensionality, 0 = no label.')
49
+ parser.add_argument('--w_dim', type=int, default=512, help='Conditioning label (W) dimensionality.')
50
+ # Logging
51
+ parser.add_argument('--log_every', type=int, default=100, help='Log every log_every steps.')
52
+ parser.add_argument('--save_every', type=int, default=2000, help='Save every save_every steps. Will be ignored if FID evaluation is enabled.')
53
+ parser.add_argument('--generate_samples_every', type=int, default=10000, help='Generate samples every generate_samples_every steps.')
54
+ parser.add_argument('--debug', action='store_true', help='Show debug log.')
55
+ # FID
56
+ parser.add_argument('--eval_fid_every', type=int, default=1000, help='Compute FID score every eval_fid_every steps.')
57
+ parser.add_argument('--num_fid_images', type=int, default=10000, help='Number of images to use for FID computation.')
58
+ parser.add_argument('--disable_fid', action='store_true', help='Disable FID evaluation.')
59
+ # W&B
60
+ parser.add_argument('--wandb', action='store_true', help='Log to Weights&Biases.')
61
+ parser.add_argument('--name', type=str, default=None, help='Name of this experiment in Weights&Biases.')
62
+ parser.add_argument('--entity', type=str, default='nyxai', help='Entity for this experiment in Weights&Biases.')
63
+ parser.add_argument('--group', type=str, default=None, help='Group name of this experiment for Weights&Biases.')
64
+
65
+ args = parser.parse_args()
66
+
67
+ # debug mode
68
+ if args.debug:
69
+ logging.getLogger().setLevel(logging.DEBUG)
70
+
71
+ # some validation
72
+ if args.resume_run_id is not None:
73
+ assert args.load_from_pkl is None, 'When resuming a run one cannot also specify --load_from_pkl'
74
+
75
+ # set unique Run ID
76
+ if args.resume_run_id:
77
+ resume = 'must' # throw error if cannot find id to be resumed
78
+ args.run_id = args.resume_run_id
79
+ else:
80
+ resume = None # default
81
+ args.run_id = wandb.util.generate_id()
82
+ args.ckpt_dir = os.path.join(args.save_dir, args.run_id)
83
+
84
+ if jax.process_index() == 0:
85
+ if not args.ckpt_dir.startswith('gs://') and not os.path.exists(args.ckpt_dir):
86
+ os.makedirs(args.ckpt_dir)
87
+ if args.wandb:
88
+ wandb.init(id=args.run_id,
89
+ project=args.project,
90
+ group=args.group,
91
+ config=args,
92
+ name=args.name,
93
+ entity=args.entity,
94
+ resume=resume)
95
+ logger.info('Starting new run with config:')
96
+ print(json.dumps(vars(args), indent=4))
97
+
98
+ training.train_and_evaluate(args)
99
+
100
+
101
+ if __name__ == '__main__':
102
+ main()
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ flaxmodels==0.1.1
2
+ flax==0.4.1
3
+ jax==0.3.14
4
+ tensorflow==2.4.1
5
+ optax==0.0.9
6
+ numpy
7
+ tensorflow-datasets
8
+ argparse
9
+ wandb
10
+ tqdm
11
+ dill
12
+ h5py
13
+ dataclasses
14
+ tqdm
stylegan2/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .generator import SynthesisNetwork
2
+ from .generator import MappingNetwork
3
+ from .generator import Generator
4
+ from .discriminator import Discriminator
5
+
stylegan2/discriminator.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import jax
3
+ from jax import random
4
+ import jax.numpy as jnp
5
+ import flax.linen as nn
6
+ from typing import Any, Tuple, List, Callable
7
+ import h5py
8
+ from . import ops
9
+ from stylegan2 import utils
10
+
11
+
12
+ URLS = {'afhqcat': 'https://www.dropbox.com/s/qygbjkefyqyu9k9/stylegan2_discriminator_afhqcat.h5?dl=1',
13
+ 'afhqdog': 'https://www.dropbox.com/s/kmoxbp33qswz64p/stylegan2_discriminator_afhqdog.h5?dl=1',
14
+ 'afhqwild': 'https://www.dropbox.com/s/jz1hpsyt3isj6e7/stylegan2_discriminator_afhqwild.h5?dl=1',
15
+ 'brecahad': 'https://www.dropbox.com/s/h0cb89hruo6pmyj/stylegan2_discriminator_brecahad.h5?dl=1',
16
+ 'car': 'https://www.dropbox.com/s/2ghjrmxih7cic76/stylegan2_discriminator_car.h5?dl=1',
17
+ 'cat': 'https://www.dropbox.com/s/zfhjsvlsny5qixd/stylegan2_discriminator_cat.h5?dl=1',
18
+ 'church': 'https://www.dropbox.com/s/jlno7zeivkjtk8g/stylegan2_discriminator_church.h5?dl=1',
19
+ 'cifar10': 'https://www.dropbox.com/s/eldpubfkl4c6rur/stylegan2_discriminator_cifar10.h5?dl=1',
20
+ 'ffhq': 'https://www.dropbox.com/s/m42qy9951b7lq1s/stylegan2_discriminator_ffhq.h5?dl=1',
21
+ 'horse': 'https://www.dropbox.com/s/19f5pxrcdh2g8cw/stylegan2_discriminator_horse.h5?dl=1',
22
+ 'metfaces': 'https://www.dropbox.com/s/xnokaunql12glkd/stylegan2_discriminator_metfaces.h5?dl=1'}
23
+
24
+ RESOLUTION = {'metfaces': 1024,
25
+ 'ffhq': 1024,
26
+ 'church': 256,
27
+ 'cat': 256,
28
+ 'horse': 256,
29
+ 'car': 512,
30
+ 'brecahad': 512,
31
+ 'afhqwild': 512,
32
+ 'afhqdog': 512,
33
+ 'afhqcat': 512,
34
+ 'cifar10': 32}
35
+
36
+ C_DIM = {'metfaces': 0,
37
+ 'ffhq': 0,
38
+ 'church': 0,
39
+ 'cat': 0,
40
+ 'horse': 0,
41
+ 'car': 0,
42
+ 'brecahad': 0,
43
+ 'afhqwild': 0,
44
+ 'afhqdog': 0,
45
+ 'afhqcat': 0,
46
+ 'cifar10': 10}
47
+
48
+ ARCHITECTURE = {'metfaces': 'resnet',
49
+ 'ffhq': 'resnet',
50
+ 'church': 'resnet',
51
+ 'cat': 'resnet',
52
+ 'horse': 'resnet',
53
+ 'car': 'resnet',
54
+ 'brecahad': 'resnet',
55
+ 'afhqwild': 'resnet',
56
+ 'afhqdog': 'resnet',
57
+ 'afhqcat': 'resnet',
58
+ 'cifar10': 'orig'}
59
+
60
+ MBSTD_GROUP_SIZE = {'metfaces': None,
61
+ 'ffhq': None,
62
+ 'church': None,
63
+ 'cat': None,
64
+ 'horse': None,
65
+ 'car': None,
66
+ 'brecahad': None,
67
+ 'afhqwild': None,
68
+ 'afhqdog': None,
69
+ 'afhqcat': None,
70
+ 'cifar10': 32}
71
+
72
+
73
+ class FromRGBLayer(nn.Module):
74
+ """
75
+ From RGB Layer.
76
+
77
+ Attributes:
78
+ fmaps (int): Number of output channels of the convolution.
79
+ kernel (int): Kernel size of the convolution.
80
+ lr_multiplier (float): Learning rate multiplier.
81
+ activation (str): Activation function: 'relu', 'lrelu', etc.
82
+ param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
83
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
84
+ dtype (str): Data dtype.
85
+ rng (jax.random.PRNGKey): PRNG for initialization.
86
+ """
87
+ fmaps: int
88
+ kernel: int=1
89
+ lr_multiplier: float=1
90
+ activation: str='leaky_relu'
91
+ param_dict: h5py.Group=None
92
+ clip_conv: float=None
93
+ dtype: str='float32'
94
+ rng: Any=random.PRNGKey(0)
95
+
96
+ @nn.compact
97
+ def __call__(self, x, y):
98
+ """
99
+ Run From RGB Layer.
100
+
101
+ Args:
102
+ x (tensor): Input image of shape [N, H, W, num_channels].
103
+ y (tensor): Input tensor of shape [N, H, W, out_channels].
104
+
105
+ Returns:
106
+ (tensor): Output tensor of shape [N, H, W, out_channels].
107
+ """
108
+ w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
109
+ w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'fromrgb', self.rng)
110
+
111
+ w = self.param(name='weight', init_fn=lambda *_ : w)
112
+ b = self.param(name='bias', init_fn=lambda *_ : b)
113
+ w = ops.equalize_lr_weight(w, self.lr_multiplier)
114
+ b = ops.equalize_lr_bias(b, self.lr_multiplier)
115
+
116
+ x = x.astype(self.dtype)
117
+ x = ops.conv2d(x, w.astype(x.dtype))
118
+ x += b.astype(x.dtype)
119
+ x = ops.apply_activation(x, activation=self.activation)
120
+ if self.clip_conv is not None:
121
+ x = jnp.clip(x, -self.clip_conv, self.clip_conv)
122
+ if y is not None:
123
+ x += y
124
+ return x
125
+
126
+
127
+ class DiscriminatorLayer(nn.Module):
128
+ """
129
+ Discriminator Layer.
130
+
131
+ Attributes:
132
+ fmaps (int): Number of output channels of the convolution.
133
+ kernel (int): Kernel size of the convolution.
134
+ use_bias (bool): If True, use bias.
135
+ down (bool): If True, downsample the spatial resolution.
136
+ resample_kernel (Tuple): Kernel that is used for FIR filter.
137
+ activation (str): Activation function: 'relu', 'lrelu', etc.
138
+ layer_name (str): Layer name.
139
+ param_dict (h5py.Group): Parameter dict with pretrained parameters.
140
+ lr_multiplier (float): Learning rate multiplier.
141
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
142
+ dtype (str): Data dtype.
143
+ rng (jax.random.PRNGKey): PRNG for initialization.
144
+ """
145
+ fmaps: int
146
+ kernel: int=3
147
+ use_bias: bool=True
148
+ down: bool=False
149
+ resample_kernel: Tuple=None
150
+ activation: str='leaky_relu'
151
+ layer_name: str=None
152
+ param_dict: h5py.Group=None
153
+ lr_multiplier: float=1
154
+ clip_conv: float=None
155
+ dtype: str='float32'
156
+ rng: Any=random.PRNGKey(0)
157
+
158
+ @nn.compact
159
+ def __call__(self, x):
160
+ """
161
+ Run Discriminator Layer.
162
+
163
+ Args:
164
+ x (tensor): Input tensor of shape [N, H, W, C].
165
+
166
+ Returns:
167
+ (tensor): Output tensor of shape [N, H, W, fmaps].
168
+ """
169
+ w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
170
+ if self.use_bias:
171
+ w, b = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
172
+ else:
173
+ w = ops.get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
174
+
175
+ w = self.param(name='weight', init_fn=lambda *_ : w)
176
+ w = ops.equalize_lr_weight(w, self.lr_multiplier)
177
+ if self.use_bias:
178
+ b = self.param(name='bias', init_fn=lambda *_ : b)
179
+ b = ops.equalize_lr_bias(b, self.lr_multiplier)
180
+
181
+ x = x.astype(self.dtype)
182
+ x = ops.conv2d(x, w, down=self.down, resample_kernel=self.resample_kernel)
183
+ if self.use_bias: x += b.astype(x.dtype)
184
+ x = ops.apply_activation(x, activation=self.activation)
185
+ if self.clip_conv is not None:
186
+ x = jnp.clip(x, -self.clip_conv, self.clip_conv)
187
+ return x
188
+
189
+
190
+ class DiscriminatorBlock(nn.Module):
191
+ """
192
+ Discriminator Block.
193
+
194
+ Attributes:
195
+ fmaps (int): Number of output channels of the convolution.
196
+ kernel (int): Kernel size of the convolution.
197
+ resample_kernel (Tuple): Kernel that is used for FIR filter.
198
+ activation (str): Activation function: 'relu', 'lrelu', etc.
199
+ param_dict (h5py.Group): Parameter dict with pretrained parameters.
200
+ lr_multiplier (float): Learning rate multiplier.
201
+ architecture (str): Architecture: 'orig', 'resnet'.
202
+ nf (Callable): Callable that returns the number of feature maps for a given layer.
203
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
204
+ dtype (str): Data dtype.
205
+ rng (jax.random.PRNGKey): Random seed for initialization.
206
+ """
207
+ res: int
208
+ kernel: int=3
209
+ resample_kernel: Tuple=(1, 3, 3, 1)
210
+ activation: str='leaky_relu'
211
+ param_dict: Any=None
212
+ lr_multiplier: float=1
213
+ architecture: str='resnet'
214
+ nf: Callable=None
215
+ clip_conv: float=None
216
+ dtype: str='float32'
217
+ rng: Any=random.PRNGKey(0)
218
+
219
+ @nn.compact
220
+ def __call__(self, x):
221
+ """
222
+ Run Discriminator Block.
223
+
224
+ Args:
225
+ x (tensor): Input tensor of shape [N, H, W, C].
226
+
227
+ Returns:
228
+ (tensor): Output tensor of shape [N, H, W, fmaps].
229
+ """
230
+ init_rng = self.rng
231
+ x = x.astype(self.dtype)
232
+ residual = x
233
+ for i in range(2):
234
+ init_rng, init_key = random.split(init_rng)
235
+ x = DiscriminatorLayer(fmaps=self.nf(self.res - (i + 1)),
236
+ kernel=self.kernel,
237
+ down=i == 1,
238
+ resample_kernel=self.resample_kernel if i == 1 else None,
239
+ activation=self.activation,
240
+ layer_name=f'conv{i}',
241
+ param_dict=self.param_dict,
242
+ lr_multiplier=self.lr_multiplier,
243
+ clip_conv=self.clip_conv,
244
+ dtype=self.dtype,
245
+ rng=init_key)(x)
246
+
247
+
248
+ if self.architecture == 'resnet':
249
+ init_rng, init_key = random.split(init_rng)
250
+ residual = DiscriminatorLayer(fmaps=self.nf(self.res - 2),
251
+ kernel=1,
252
+ use_bias=False,
253
+ down=True,
254
+ resample_kernel=self.resample_kernel,
255
+ activation='linear',
256
+ layer_name='skip',
257
+ param_dict=self.param_dict,
258
+ lr_multiplier=self.lr_multiplier,
259
+ dtype=self.dtype,
260
+ rng=init_key)(residual)
261
+
262
+ x = (x + residual) * np.sqrt(0.5, dtype=x.dtype)
263
+ return x
264
+
265
+
266
+ class Discriminator(nn.Module):
267
+ """
268
+ Discriminator.
269
+
270
+ Attributes:
271
+ resolution (int): Input resolution. Overridden based on dataset.
272
+ num_channels (int): Number of input color channels. Overridden based on dataset.
273
+ c_dim (int): Dimensionality of the labels (c), 0 if no labels. Overrttten based on dataset.
274
+ fmap_base (int): Overall multiplier for the number of feature maps.
275
+ fmap_decay (int): Log2 feature map reduction when doubling the resolution.
276
+ fmap_min (int): Minimum number of feature maps in any layer.
277
+ fmap_max (int): Maximum number of feature maps in any layer.
278
+ mapping_layers (int): Number of additional mapping layers for the conditioning labels.
279
+ mapping_fmaps (int): Number of activations in the mapping layers, None = default.
280
+ mapping_lr_multiplier (float): Learning rate multiplier for the mapping layers.
281
+ architecture (str): Architecture: 'orig', 'resnet'.
282
+ activation (int): Activation function: 'relu', 'leaky_relu', etc.
283
+ mbstd_group_size (int): Group size for the minibatch standard deviation layer, None = entire minibatch.
284
+ mbstd_num_features (int): Number of features for the minibatch standard deviation layer, 0 = disable.
285
+ resample_kernel (Tuple): Low-pass filter to apply when resampling activations, None = box filter.
286
+ num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
287
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
288
+ pretrained (str): Use pretrained model, None for random initialization.
289
+ ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
290
+ dtype (str): Data type.
291
+ rng (jax.random.PRNGKey): PRNG for initialization.
292
+ """
293
+ # Input dimensions.
294
+ resolution: int=1024
295
+ num_channels: int=3
296
+ c_dim: int=0
297
+
298
+ # Capacity.
299
+ fmap_base: int=16384
300
+ fmap_decay: int=1
301
+ fmap_min: int=1
302
+ fmap_max: int=512
303
+
304
+ # Internal details.
305
+ mapping_layers: int=0
306
+ mapping_fmaps: int=None
307
+ mapping_lr_multiplier: float=0.1
308
+ architecture: str='resnet'
309
+ activation: str='leaky_relu'
310
+ mbstd_group_size: int=None
311
+ mbstd_num_features: int=1
312
+ resample_kernel: Tuple=(1, 3, 3, 1)
313
+ num_fp16_res: int=0
314
+ clip_conv: float=None
315
+
316
+ # Pretraining
317
+ pretrained: str=None
318
+ ckpt_dir: str=None
319
+
320
+ dtype: str='float32'
321
+ rng: Any=random.PRNGKey(0)
322
+
323
+ def setup(self):
324
+ self.resolution_ = self.resolution
325
+ self.c_dim_ = self.c_dim
326
+ self.architecture_ = self.architecture
327
+ self.mbstd_group_size_ = self.mbstd_group_size
328
+ self.param_dict = None
329
+ if self.pretrained is not None:
330
+ assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
331
+ ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
332
+ self.param_dict = h5py.File(ckpt_file, 'r')['discriminator']
333
+ self.resolution_ = RESOLUTION[self.pretrained]
334
+ self.architecture_ = ARCHITECTURE[self.pretrained]
335
+ self.mbstd_group_size_ = MBSTD_GROUP_SIZE[self.pretrained]
336
+ self.c_dim_ = C_DIM[self.pretrained]
337
+
338
+ assert self.architecture in ['orig', 'resnet']
339
+
340
+ @nn.compact
341
+ def __call__(self, x, c=None):
342
+ """
343
+ Run Discriminator.
344
+
345
+ Args:
346
+ x (tensor): Input image of shape [N, H, W, num_channels].
347
+ c (tensor): Input labels, shape [N, c_dim].
348
+
349
+ Returns:
350
+ (tensor): Output tensor of shape [N, 1].
351
+ """
352
+ resolution_log2 = int(np.log2(self.resolution_))
353
+ assert self.resolution_ == 2**resolution_log2 and self.resolution_ >= 4
354
+ def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
355
+ if self.mapping_fmaps is None:
356
+ mapping_fmaps = nf(0)
357
+ else:
358
+ mapping_fmaps = self.mapping_fmaps
359
+
360
+ init_rng = self.rng
361
+ # Label embedding and mapping.
362
+ if self.c_dim_ > 0:
363
+ c = ops.LinearLayer(in_features=self.c_dim_,
364
+ out_features=mapping_fmaps,
365
+ lr_multiplier=self.mapping_lr_multiplier,
366
+ param_dict=self.param_dict,
367
+ layer_name='label_embedding',
368
+ dtype=self.dtype,
369
+ rng=init_rng)(c)
370
+
371
+ c = ops.normalize_2nd_moment(c)
372
+ for i in range(self.mapping_layers):
373
+ init_rng, init_key = random.split(init_rng)
374
+ c = ops.LinearLayer(in_features=self.c_dim_,
375
+ out_features=mapping_fmaps,
376
+ lr_multiplier=self.mapping_lr_multiplier,
377
+ param_dict=self.param_dict,
378
+ layer_name=f'fc{i}',
379
+ dtype=self.dtype,
380
+ rng=init_key)(c)
381
+
382
+ # Layers for >=8x8 resolutions.
383
+ y = None
384
+ for res in range(resolution_log2, 2, -1):
385
+ res_str = f'block_{2**res}x{2**res}'
386
+ if res == resolution_log2:
387
+ init_rng, init_key = random.split(init_rng)
388
+ x = FromRGBLayer(fmaps=nf(res - 1),
389
+ kernel=1,
390
+ activation=self.activation,
391
+ param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
392
+ clip_conv=self.clip_conv,
393
+ dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32',
394
+ rng=init_key)(x, y)
395
+
396
+ init_rng, init_key = random.split(init_rng)
397
+ x = DiscriminatorBlock(res=res,
398
+ kernel=3,
399
+ resample_kernel=self.resample_kernel,
400
+ activation=self.activation,
401
+ param_dict=self.param_dict[res_str] if self.param_dict is not None else None,
402
+ architecture=self.architecture_,
403
+ nf=nf,
404
+ clip_conv=self.clip_conv,
405
+ dtype=self.dtype if res >= resolution_log2 + 1 - self.num_fp16_res else 'float32',
406
+ rng=init_key)(x)
407
+
408
+ # Layers for 4x4 resolution.
409
+ dtype = jnp.float32
410
+ x = x.astype(dtype)
411
+ if self.mbstd_num_features > 0:
412
+ x = ops.minibatch_stddev_layer(x, self.mbstd_group_size_, self.mbstd_num_features)
413
+ init_rng, init_key = random.split(init_rng)
414
+ x = DiscriminatorLayer(fmaps=nf(1),
415
+ kernel=3,
416
+ use_bias=True,
417
+ activation=self.activation,
418
+ layer_name='conv0',
419
+ param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
420
+ clip_conv=self.clip_conv,
421
+ dtype=dtype,
422
+ rng=init_rng)(x)
423
+
424
+ # Switch to NCHW so that the pretrained weights still work after reshaping
425
+ x = jnp.transpose(x, axes=(0, 3, 1, 2))
426
+ x = jnp.reshape(x, newshape=(-1, x.shape[1] * x.shape[2] * x.shape[3]))
427
+
428
+ init_rng, init_key = random.split(init_rng)
429
+ x = ops.LinearLayer(in_features=x.shape[1],
430
+ out_features=nf(0),
431
+ activation=self.activation,
432
+ param_dict=self.param_dict['block_4x4'] if self.param_dict is not None else None,
433
+ layer_name='fc0',
434
+ dtype=dtype,
435
+ rng=init_key)(x)
436
+
437
+ # Output layer.
438
+ init_rng, init_key = random.split(init_rng)
439
+ x = ops.LinearLayer(in_features=x.shape[1],
440
+ out_features=1 if self.c_dim_ == 0 else mapping_fmaps,
441
+ param_dict=self.param_dict,
442
+ layer_name='output',
443
+ dtype=dtype,
444
+ rng=init_key)(x)
445
+
446
+ if self.c_dim_ > 0:
447
+ x = jnp.sum(x * c, axis=1, keepdims=True) / jnp.sqrt(mapping_fmaps)
448
+ return x
449
+
450
+
451
+
stylegan2/generator.py ADDED
@@ -0,0 +1,713 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import jax
3
+ from jax import random
4
+ import jax.numpy as jnp
5
+ import flax.linen as nn
6
+ from typing import Any, Tuple, List
7
+ import h5py
8
+ from . import ops
9
+ from stylegan2 import utils
10
+
11
+
12
+ URLS = {'afhqcat': 'https://www.dropbox.com/s/lv1r0bwvg5ta51f/stylegan2_generator_afhqcat.h5?dl=1',
13
+ 'afhqdog': 'https://www.dropbox.com/s/px6ply9hv0vdwen/stylegan2_generator_afhqdog.h5?dl=1',
14
+ 'afhqwild': 'https://www.dropbox.com/s/p1slbtmzhcnw9q8/stylegan2_generator_afhqwild.h5?dl=1',
15
+ 'brecahad': 'https://www.dropbox.com/s/28uykhj0ku6hwg2/stylegan2_generator_brecahad.h5?dl=1',
16
+ 'car': 'https://www.dropbox.com/s/67o834b6xfg9x1q/stylegan2_generator_car.h5?dl=1',
17
+ 'cat': 'https://www.dropbox.com/s/cu9egc4e74e1nig/stylegan2_generator_cat.h5?dl=1',
18
+ 'church': 'https://www.dropbox.com/s/kwvokfwbrhsn58m/stylegan2_generator_church.h5?dl=1',
19
+ 'cifar10': 'https://www.dropbox.com/s/h1kmymjzfwwkftk/stylegan2_generator_cifar10.h5?dl=1',
20
+ 'ffhq': 'https://www.dropbox.com/s/e8de1peq7p8gq9d/stylegan2_generator_ffhq.h5?dl=1',
21
+ 'horse': 'https://www.dropbox.com/s/3e5bimv2d41bc13/stylegan2_generator_horse.h5?dl=1',
22
+ 'metfaces': 'https://www.dropbox.com/s/75klr5k6mgm7qdy/stylegan2_generator_metfaces.h5?dl=1'}
23
+
24
+ RESOLUTION = {'metfaces': 1024,
25
+ 'ffhq': 1024,
26
+ 'church': 256,
27
+ 'cat': 256,
28
+ 'horse': 256,
29
+ 'car': 512,
30
+ 'brecahad': 512,
31
+ 'afhqwild': 512,
32
+ 'afhqdog': 512,
33
+ 'afhqcat': 512,
34
+ 'cifar10': 32}
35
+
36
+ C_DIM = {'metfaces': 0,
37
+ 'ffhq': 0,
38
+ 'church': 0,
39
+ 'cat': 0,
40
+ 'horse': 0,
41
+ 'car': 0,
42
+ 'brecahad': 0,
43
+ 'afhqwild': 0,
44
+ 'afhqdog': 0,
45
+ 'afhqcat': 0,
46
+ 'cifar10': 10}
47
+
48
+ NUM_MAPPING_LAYERS = {'metfaces': 8,
49
+ 'ffhq': 8,
50
+ 'church': 8,
51
+ 'cat': 8,
52
+ 'horse': 8,
53
+ 'car': 8,
54
+ 'brecahad': 8,
55
+ 'afhqwild': 8,
56
+ 'afhqdog': 8,
57
+ 'afhqcat': 8,
58
+ 'cifar10': 2}
59
+
60
+
61
+ class MappingNetwork(nn.Module):
62
+ """
63
+ Mapping Network.
64
+
65
+ Attributes:
66
+ z_dim (int): Input latent (Z) dimensionality.
67
+ c_dim (int): Conditioning label (C) dimensionality, 0 = no label.
68
+ w_dim (int): Intermediate latent (W) dimensionality.
69
+ embed_features (int): Label embedding dimensionality, None = same as w_dim.
70
+ layer_features (int): Number of intermediate features in the mapping layers, None = same as w_dim.
71
+ num_ws (int): Number of intermediate latents to output, None = do not broadcast.
72
+ num_layers (int): Number of mapping layers.
73
+ pretrained (str): Which pretrained model to use, None for random initialization.
74
+ param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
75
+ ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
76
+ activation (str): Activation function: 'relu', 'lrelu', etc.
77
+ lr_multiplier (float): Learning rate multiplier for the mapping layers.
78
+ w_avg_beta (float): Decay for tracking the moving average of W during training, None = do not track.
79
+ dtype (str): Data type.
80
+ rng (jax.random.PRNGKey): PRNG for initialization.
81
+ """
82
+ # Dimensionality
83
+ z_dim: int=512
84
+ c_dim: int=0
85
+ w_dim: int=512
86
+ embed_features: int=None
87
+ layer_features: int=512
88
+
89
+ # Layers
90
+ num_ws: int=18
91
+ num_layers: int=8
92
+
93
+ # Pretrained
94
+ pretrained: str=None
95
+ param_dict: h5py.Group=None
96
+ ckpt_dir: str=None
97
+
98
+ # Internal details
99
+ activation: str='leaky_relu'
100
+ lr_multiplier: float=0.01
101
+ w_avg_beta: float=0.995
102
+ dtype: str='float32'
103
+ rng: Any=random.PRNGKey(0)
104
+
105
+ def setup(self):
106
+ self.embed_features_ = self.embed_features
107
+ self.c_dim_ = self.c_dim
108
+ self.layer_features_ = self.layer_features
109
+ self.num_layers_ = self.num_layers
110
+ self.param_dict_ = self.param_dict
111
+
112
+ if self.pretrained is not None and self.param_dict is None:
113
+ assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
114
+ ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
115
+ self.param_dict_ = h5py.File(ckpt_file, 'r')['mapping_network']
116
+ self.c_dim_ = C_DIM[self.pretrained]
117
+ self.num_layers_ = NUM_MAPPING_LAYERS[self.pretrained]
118
+
119
+ if self.embed_features_ is None:
120
+ self.embed_features_ = self.w_dim
121
+ if self.c_dim_ == 0:
122
+ self.embed_features_ = 0
123
+ if self.layer_features_ is None:
124
+ self.layer_features_ = self.w_dim
125
+
126
+ if self.param_dict_ is not None and 'w_avg' in self.param_dict_:
127
+ self.w_avg = self.variable('moving_stats', 'w_avg', lambda *_ : jnp.array(self.param_dict_['w_avg']), [self.w_dim])
128
+ else:
129
+ self.w_avg = self.variable('moving_stats', 'w_avg', jnp.zeros, [self.w_dim])
130
+
131
+ @nn.compact
132
+ def __call__(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, train=True):
133
+ """
134
+ Run Mapping Network.
135
+
136
+ Args:
137
+ z (tensor): Input noise, shape [N, z_dim].
138
+ c (tensor): Input labels, shape [N, c_dim].
139
+ truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
140
+ truncation_cutoff (int): Controls truncation. None = disable.
141
+ skip_w_avg_update (bool): If True, updates the exponential moving average of W.
142
+ train (bool): Training mode.
143
+
144
+ Returns:
145
+ (tensor): Intermediate latent W.
146
+ """
147
+ init_rng = self.rng
148
+ # Embed, normalize, and concat inputs.
149
+ x = None
150
+ if self.z_dim > 0:
151
+ x = ops.normalize_2nd_moment(z.astype(jnp.float32))
152
+ if self.c_dim_ > 0:
153
+ # Conditioning label
154
+ y = ops.LinearLayer(in_features=self.c_dim_,
155
+ out_features=self.embed_features_,
156
+ use_bias=True,
157
+ lr_multiplier=self.lr_multiplier,
158
+ activation='linear',
159
+ param_dict=self.param_dict_,
160
+ layer_name='label_embedding',
161
+ dtype=self.dtype,
162
+ rng=init_rng)(c.astype(jnp.float32))
163
+
164
+ y = ops.normalize_2nd_moment(y)
165
+ x = jnp.concatenate((x, y), axis=1) if x is not None else y
166
+
167
+ # Main layers.
168
+ for i in range(self.num_layers_):
169
+ init_rng, init_key = random.split(init_rng)
170
+ x = ops.LinearLayer(in_features=x.shape[1],
171
+ out_features=self.layer_features_,
172
+ use_bias=True,
173
+ lr_multiplier=self.lr_multiplier,
174
+ activation=self.activation,
175
+ param_dict=self.param_dict_,
176
+ layer_name=f'fc{i}',
177
+ dtype=self.dtype,
178
+ rng=init_key)(x)
179
+
180
+ # Update moving average of W.
181
+ if self.w_avg_beta is not None and train and not skip_w_avg_update:
182
+ self.w_avg.value = self.w_avg_beta * self.w_avg.value + (1 - self.w_avg_beta) * jnp.mean(x, axis=0)
183
+
184
+ # Broadcast.
185
+ if self.num_ws is not None:
186
+ x = jnp.repeat(jnp.expand_dims(x, axis=-2), repeats=self.num_ws, axis=-2)
187
+
188
+ # Apply truncation.
189
+ if truncation_psi != 1:
190
+ assert self.w_avg_beta is not None
191
+ if self.num_ws is None or truncation_cutoff is None:
192
+ x = truncation_psi * x + (1 - truncation_psi) * self.w_avg.value
193
+ else:
194
+ x[:, :truncation_cutoff] = truncation_psi * x[:, :truncation_cutoff] + (1 - truncation_psi) * self.w_avg.value
195
+
196
+ return x
197
+
198
+
199
+ class SynthesisLayer(nn.Module):
200
+ """
201
+ Synthesis Layer.
202
+
203
+ Attributes:
204
+ fmaps (int): Number of output channels of the modulated convolution.
205
+ kernel (int): Kernel size of the modulated convolution.
206
+ layer_idx (int): Layer index. Used to access the latent code for a specific layer.
207
+ res (int): Resolution (log2) of the current layer.
208
+ lr_multiplier (float): Learning rate multiplier.
209
+ up (bool): If True, upsample the spatial resolution.
210
+ activation (str): Activation function: 'relu', 'lrelu', etc.
211
+ use_noise (bool): If True, add spatial-specific noise.
212
+ resample_kernel (Tuple): Kernel that is used for FIR filter.
213
+ fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
214
+ param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
215
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
216
+ dtype (str): Data dtype.
217
+ rng (jax.random.PRNGKey): PRNG for initialization.
218
+ """
219
+ fmaps: int
220
+ kernel: int
221
+ layer_idx: int
222
+ res: int
223
+ lr_multiplier: float=1
224
+ up: bool=False
225
+ activation: str='leaky_relu'
226
+ use_noise: bool=True
227
+ resample_kernel: Tuple=(1, 3, 3, 1)
228
+ fused_modconv: bool=False
229
+ param_dict: h5py.Group=None
230
+ clip_conv: float=None
231
+ dtype: str='float32'
232
+ rng: Any=random.PRNGKey(0)
233
+
234
+ def setup(self):
235
+ if self.param_dict is not None:
236
+ noise_const = jnp.array(self.param_dict['noise_const'], dtype=self.dtype)
237
+ else:
238
+ noise_const = random.normal(self.rng, shape=(1, 2 ** self.res, 2 ** self.res, 1), dtype=self.dtype)
239
+ self.noise_const = self.variable('noise_consts', 'noise_const', lambda *_: noise_const)
240
+
241
+ @nn.compact
242
+ def __call__(self, x, dlatents, noise_mode='random', rng=random.PRNGKey(0)):
243
+ """
244
+ Run Synthesis Layer.
245
+
246
+ Args:
247
+ x (tensor): Input tensor of the shape [N, H, W, C].
248
+ dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
249
+ noise_mode (str): Noise type.
250
+ - 'const': Constant noise.
251
+ - 'random': Random noise.
252
+ - 'none': No noise.
253
+ rng (jax.random.PRNGKey): PRNG for spatialwise noise.
254
+
255
+ Returns:
256
+ (tensor): Output tensor of shape [N, H', W', fmaps].
257
+ """
258
+ assert noise_mode in ['const', 'random', 'none']
259
+
260
+ linear_rng, conv_rng = random.split(self.rng)
261
+ # Affine transformation to obtain style variable.
262
+ s = ops.LinearLayer(in_features=dlatents[:, self.layer_idx].shape[1],
263
+ out_features=x.shape[3],
264
+ use_bias=True,
265
+ bias_init=1,
266
+ lr_multiplier=self.lr_multiplier,
267
+ param_dict=self.param_dict,
268
+ layer_name='affine',
269
+ dtype=self.dtype,
270
+ rng=linear_rng)(dlatents[:, self.layer_idx])
271
+
272
+ # Noise variables.
273
+ if self.param_dict is None:
274
+ noise_strength = jnp.zeros(())
275
+ else:
276
+ noise_strength = jnp.array(self.param_dict['noise_strength'])
277
+ noise_strength = self.param(name='noise_strength', init_fn=lambda *_ : noise_strength)
278
+
279
+ # Weight and bias for convolution operation.
280
+ w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
281
+ w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'conv', conv_rng)
282
+ w = self.param(name='weight', init_fn=lambda *_ : w)
283
+ b = self.param(name='bias', init_fn=lambda *_ : b)
284
+ w = ops.equalize_lr_weight(w, self.lr_multiplier)
285
+ b = ops.equalize_lr_bias(b, self.lr_multiplier)
286
+
287
+ x = ops.modulated_conv2d_layer(x=x,
288
+ w=w,
289
+ s=s,
290
+ fmaps=self.fmaps,
291
+ kernel=self.kernel,
292
+ up=self.up,
293
+ resample_kernel=self.resample_kernel,
294
+ fused_modconv=self.fused_modconv)
295
+
296
+ if self.use_noise and noise_mode != 'none':
297
+ if noise_mode == 'const':
298
+ noise = self.noise_const.value
299
+ elif noise_mode == 'random':
300
+ noise = random.normal(rng, shape=(x.shape[0], x.shape[1], x.shape[2], 1), dtype=self.dtype)
301
+ x += noise * noise_strength.astype(self.dtype)
302
+ x += b.astype(x.dtype)
303
+ x = ops.apply_activation(x, activation=self.activation)
304
+ if self.clip_conv is not None:
305
+ x = jnp.clip(x, -self.clip_conv, self.clip_conv)
306
+ return x
307
+
308
+
309
+ class ToRGBLayer(nn.Module):
310
+ """
311
+ To RGB Layer.
312
+
313
+ Attributes:
314
+ fmaps (int): Number of output channels of the modulated convolution.
315
+ layer_idx (int): Layer index. Used to access the latent code for a specific layer.
316
+ kernel (int): Kernel size of the modulated convolution.
317
+ lr_multiplier (float): Learning rate multiplier.
318
+ fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
319
+ param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
320
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
321
+ dtype (str): Data dtype.
322
+ rng (jax.random.PRNGKey): PRNG for initialization.
323
+ """
324
+ fmaps: int
325
+ layer_idx: int
326
+ kernel: int=1
327
+ lr_multiplier: float=1
328
+ fused_modconv: bool=False
329
+ param_dict: h5py.Group=None
330
+ clip_conv: float=None
331
+ dtype: str='float32'
332
+ rng: Any=random.PRNGKey(0)
333
+
334
+ @nn.compact
335
+ def __call__(self, x, y, dlatents):
336
+ """
337
+ Run To RGB Layer.
338
+
339
+ Args:
340
+ x (tensor): Input tensor of shape [N, H, W, C].
341
+ y (tensor): Image of shape [N, H', W', fmaps].
342
+ dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
343
+
344
+ Returns:
345
+ (tensor): Output tensor of shape [N, H', W', fmaps].
346
+ """
347
+ # Affine transformation to obtain style variable.
348
+ s = ops.LinearLayer(in_features=dlatents[:, self.layer_idx].shape[1],
349
+ out_features=x.shape[3],
350
+ use_bias=True,
351
+ bias_init=1,
352
+ lr_multiplier=self.lr_multiplier,
353
+ param_dict=self.param_dict,
354
+ layer_name='affine',
355
+ dtype=self.dtype,
356
+ rng=self.rng)(dlatents[:, self.layer_idx])
357
+
358
+ # Weight and bias for convolution operation.
359
+ w_shape = [self.kernel, self.kernel, x.shape[3], self.fmaps]
360
+ w, b = ops.get_weight(w_shape, self.lr_multiplier, True, self.param_dict, 'conv', self.rng)
361
+ w = self.param(name='weight', init_fn=lambda *_ : w)
362
+ b = self.param(name='bias', init_fn=lambda *_ : b)
363
+ w = ops.equalize_lr_weight(w, self.lr_multiplier)
364
+ b = ops.equalize_lr_bias(b, self.lr_multiplier)
365
+
366
+ x = ops.modulated_conv2d_layer(x, w, s, fmaps=self.fmaps, kernel=self.kernel, demodulate=False, fused_modconv=self.fused_modconv)
367
+ x += b.astype(x.dtype)
368
+ x = ops.apply_activation(x, activation='linear')
369
+ if self.clip_conv is not None:
370
+ x = jnp.clip(x, -self.clip_conv, self.clip_conv)
371
+ if y is not None:
372
+ x += y.astype(jnp.float32)
373
+ return x
374
+
375
+
376
+ class SynthesisBlock(nn.Module):
377
+ """
378
+ Synthesis Block.
379
+
380
+ Attributes:
381
+ fmaps (int): Number of output channels of the modulated convolution.
382
+ res (int): Resolution (log2) of the current block.
383
+ num_layers (int): Number of layers in the current block.
384
+ num_channels (int): Number of output color channels.
385
+ lr_multiplier (float): Learning rate multiplier.
386
+ activation (str): Activation function: 'relu', 'lrelu', etc.
387
+ use_noise (bool): If True, add spatial-specific noise.
388
+ resample_kernel (Tuple): Kernel that is used for FIR filter.
389
+ fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
390
+ param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
391
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
392
+ dtype (str): Data dtype.
393
+ rng (jax.random.PRNGKey): PRNG for initialization.
394
+ """
395
+ fmaps: int
396
+ res: int
397
+ num_layers: int=2
398
+ num_channels: int=3
399
+ lr_multiplier: float=1
400
+ activation: str='leaky_relu'
401
+ use_noise: bool=True
402
+ resample_kernel: Tuple=(1, 3, 3, 1)
403
+ fused_modconv: bool=False
404
+ param_dict: h5py.Group=None
405
+ clip_conv: float=None
406
+ dtype: str='float32'
407
+ rng: Any=random.PRNGKey(0)
408
+
409
+ @nn.compact
410
+ def __call__(self, x, y, dlatents_in, noise_mode='random', rng=random.PRNGKey(0)):
411
+ """
412
+ Run Synthesis Block.
413
+
414
+ Args:
415
+ x (tensor): Input tensor of shape [N, H, W, C].
416
+ y (tensor): Image of shape [N, H', W', fmaps].
417
+ dlatents (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
418
+ noise_mode (str): Noise type.
419
+ - 'const': Constant noise.
420
+ - 'random': Random noise.
421
+ - 'none': No noise.
422
+ rng (jax.random.PRNGKey): PRNG for spatialwise noise.
423
+
424
+ Returns:
425
+ (tensor): Output tensor of shape [N, H', W', fmaps].
426
+ """
427
+ x = x.astype(self.dtype)
428
+ init_rng = self.rng
429
+ for i in range(self.num_layers):
430
+ init_rng, init_key = random.split(init_rng)
431
+ x = SynthesisLayer(fmaps=self.fmaps,
432
+ kernel=3,
433
+ layer_idx=self.res * 2 - (5 - i) if self.res > 2 else 0,
434
+ res=self.res,
435
+ lr_multiplier=self.lr_multiplier,
436
+ up=i == 0 and self.res != 2,
437
+ activation=self.activation,
438
+ use_noise=self.use_noise,
439
+ resample_kernel=self.resample_kernel,
440
+ fused_modconv=self.fused_modconv,
441
+ param_dict=self.param_dict[f'layer{i}'] if self.param_dict is not None else None,
442
+ dtype=self.dtype,
443
+ rng=init_key)(x, dlatents_in, noise_mode, rng)
444
+
445
+ if self.num_layers == 2:
446
+ k = ops.setup_filter(self.resample_kernel)
447
+ y = ops.upsample2d(y, f=k, up=2)
448
+
449
+ init_rng, init_key = random.split(init_rng)
450
+ y = ToRGBLayer(fmaps=self.num_channels,
451
+ layer_idx=self.res * 2 - 3,
452
+ lr_multiplier=self.lr_multiplier,
453
+ param_dict=self.param_dict['torgb'] if self.param_dict is not None else None,
454
+ dtype=self.dtype,
455
+ rng=init_key)(x, y, dlatents_in)
456
+ return x, y
457
+
458
+
459
+ class SynthesisNetwork(nn.Module):
460
+ """
461
+ Synthesis Network.
462
+
463
+ Attributes:
464
+ resolution (int): Output resolution.
465
+ num_channels (int): Number of output color channels.
466
+ w_dim (int): Input latent (Z) dimensionality.
467
+ fmap_base (int): Overall multiplier for the number of feature maps.
468
+ fmap_decay (int): Log2 feature map reduction when doubling the resolution.
469
+ fmap_min (int): Minimum number of feature maps in any layer.
470
+ fmap_max (int): Maximum number of feature maps in any layer.
471
+ fmap_const (int): Number of feature maps in the constant input layer. None = default.
472
+ pretrained (str): Which pretrained model to use, None for random initialization.
473
+ param_dict (h5py.Group): Parameter dict with pretrained parameters. If not None, 'pretrained' will be ignored.
474
+ ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
475
+ activation (str): Activation function: 'relu', 'lrelu', etc.
476
+ use_noise (bool): If True, add spatial-specific noise.
477
+ resample_kernel (Tuple): Kernel that is used for FIR filter.
478
+ fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
479
+ num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
480
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
481
+ dtype (str): Data type.
482
+ rng (jax.random.PRNGKey): PRNG for initialization.
483
+ """
484
+ # Dimensionality
485
+ resolution: int=1024
486
+ num_channels: int=3
487
+ w_dim: int=512
488
+
489
+ # Capacity
490
+ fmap_base: int=16384
491
+ fmap_decay: int=1
492
+ fmap_min: int=1
493
+ fmap_max: int=512
494
+ fmap_const: int=None
495
+
496
+ # Pretraining
497
+ pretrained: str=None
498
+ param_dict: h5py.Group=None
499
+ ckpt_dir: str=None
500
+
501
+ # Internal details
502
+ activation: str='leaky_relu'
503
+ use_noise: bool=True
504
+ resample_kernel: Tuple=(1, 3, 3, 1)
505
+ fused_modconv: bool=False
506
+ num_fp16_res: int=0
507
+ clip_conv: float=None
508
+ dtype: str='float32'
509
+ rng: Any=random.PRNGKey(0)
510
+
511
+ def setup(self):
512
+ self.resolution_ = self.resolution
513
+ self.param_dict_ = self.param_dict
514
+ if self.pretrained is not None and self.param_dict is None:
515
+ assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
516
+ ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
517
+ self.param_dict_ = h5py.File(ckpt_file, 'r')['synthesis_network']
518
+ self.resolution_ = RESOLUTION[self.pretrained]
519
+
520
+ @nn.compact
521
+ def __call__(self, dlatents_in, noise_mode='random', rng=random.PRNGKey(0)):
522
+ """
523
+ Run Synthesis Network.
524
+
525
+ Args:
526
+ dlatents_in (tensor): Intermediate latents (W) of shape [N, num_ws, w_dim].
527
+ noise_mode (str): Noise type.
528
+ - 'const': Constant noise.
529
+ - 'random': Random noise.
530
+ - 'none': No noise.
531
+ rng (jax.random.PRNGKey): PRNG for spatialwise noise.
532
+
533
+ Returns:
534
+ (tensor): Image of shape [N, H, W, num_channels].
535
+ """
536
+ resolution_log2 = int(np.log2(self.resolution_))
537
+ assert self.resolution_ == 2 ** resolution_log2 and self.resolution_ >= 4
538
+
539
+ def nf(stage): return np.clip(int(self.fmap_base / (2.0 ** (stage * self.fmap_decay))), self.fmap_min, self.fmap_max)
540
+ num_layers = resolution_log2 * 2 - 2
541
+
542
+ fmaps = self.fmap_const if self.fmap_const is not None else nf(1)
543
+
544
+ if self.param_dict_ is None:
545
+ const = random.normal(self.rng, (1, 4, 4, fmaps), dtype=self.dtype)
546
+ else:
547
+ const = jnp.array(self.param_dict_['const'], dtype=self.dtype)
548
+ x = self.param(name='const', init_fn=lambda *_ : const)
549
+ x = jnp.repeat(x, repeats=dlatents_in.shape[0], axis=0)
550
+
551
+ y = None
552
+
553
+ dlatents_in = dlatents_in.astype(jnp.float32)
554
+
555
+ init_rng = self.rng
556
+ for res in range(2, resolution_log2 + 1):
557
+ init_rng, init_key = random.split(init_rng)
558
+ x, y = SynthesisBlock(fmaps=nf(res - 1),
559
+ res=res,
560
+ num_layers=1 if res == 2 else 2,
561
+ num_channels=self.num_channels,
562
+ activation=self.activation,
563
+ use_noise=self.use_noise,
564
+ resample_kernel=self.resample_kernel,
565
+ fused_modconv=self.fused_modconv,
566
+ param_dict=self.param_dict_[f'block_{2 ** res}x{2 ** res}'] if self.param_dict_ is not None else None,
567
+ clip_conv=self.clip_conv,
568
+ dtype=self.dtype if res > resolution_log2 - self.num_fp16_res else 'float32',
569
+ rng=init_key)(x, y, dlatents_in, noise_mode, rng)
570
+
571
+ return y
572
+
573
+
574
+ class Generator(nn.Module):
575
+ """
576
+ Generator.
577
+
578
+ Attributes:
579
+ resolution (int): Output resolution.
580
+ num_channels (int): Number of output color channels.
581
+ z_dim (int): Input latent (Z) dimensionality.
582
+ c_dim (int): Conditioning label (C) dimensionality, 0 = no label.
583
+ w_dim (int): Intermediate latent (W) dimensionality.
584
+ mapping_layer_features (int): Number of intermediate features in the mapping layers, None = same as w_dim.
585
+ mapping_embed_features (int): Label embedding dimensionality, None = same as w_dim.
586
+ num_ws (int): Number of intermediate latents to output, None = do not broadcast.
587
+ num_mapping_layers (int): Number of mapping layers.
588
+ fmap_base (int): Overall multiplier for the number of feature maps.
589
+ fmap_decay (int): Log2 feature map reduction when doubling the resolution.
590
+ fmap_min (int): Minimum number of feature maps in any layer.
591
+ fmap_max (int): Maximum number of feature maps in any layer.
592
+ fmap_const (int): Number of feature maps in the constant input layer. None = default.
593
+ pretrained (str): Which pretrained model to use, None for random initialization.
594
+ ckpt_dir (str): Directory to which the pretrained weights are downloaded. If None, a temp directory will be used.
595
+ use_noise (bool): If True, add spatial-specific noise.
596
+ activation (str): Activation function: 'relu', 'lrelu', etc.
597
+ w_avg_beta (float): Decay for tracking the moving average of W during training, None = do not track.
598
+ mapping_lr_multiplier (float): Learning rate multiplier for the mapping network.
599
+ resample_kernel (Tuple): Kernel that is used for FIR filter.
600
+ fused_modconv (bool): If True, Perform modulation, convolution, and demodulation as a single fused operation.
601
+ num_fp16_res (int): Use float16 for the 'num_fp16_res' highest resolutions.
602
+ clip_conv (float): Clip the output of convolution layers to [-clip_conv, +clip_conv], None = disable clipping.
603
+ dtype (str): Data type.
604
+ rng (jax.random.PRNGKey): PRNG for initialization.
605
+ """
606
+ # Dimensionality
607
+ resolution: int=1024
608
+ num_channels: int=3
609
+ z_dim: int=512
610
+ c_dim: int=0
611
+ w_dim: int=512
612
+ mapping_layer_features: int=512
613
+ mapping_embed_features: int=None
614
+
615
+ # Layers
616
+ num_ws: int=18
617
+ num_mapping_layers: int=8
618
+
619
+ # Capacity
620
+ fmap_base: int=16384
621
+ fmap_decay: int=1
622
+ fmap_min: int=1
623
+ fmap_max: int=512
624
+ fmap_const: int=None
625
+
626
+ # Pretraining
627
+ pretrained: str=None
628
+ ckpt_dir: str=None
629
+
630
+ # Internal details
631
+ use_noise: bool=True
632
+ activation: str='leaky_relu'
633
+ w_avg_beta: float=0.995
634
+ mapping_lr_multiplier: float=0.01
635
+ resample_kernel: Tuple=(1, 3, 3, 1)
636
+ fused_modconv: bool=False
637
+ num_fp16_res: int=0
638
+ clip_conv: float=None
639
+ dtype: str='float32'
640
+ rng: Any=random.PRNGKey(0)
641
+
642
+ def setup(self):
643
+ self.resolution_ = self.resolution
644
+ self.c_dim_ = self.c_dim
645
+ self.num_mapping_layers_ = self.num_mapping_layers
646
+ if self.pretrained is not None:
647
+ assert self.pretrained in URLS.keys(), f'Pretrained model not available: {self.pretrained}'
648
+ ckpt_file = utils.download(self.ckpt_dir, URLS[self.pretrained])
649
+ self.param_dict = h5py.File(ckpt_file, 'r')
650
+ self.resolution_ = RESOLUTION[self.pretrained]
651
+ self.c_dim_ = C_DIM[self.pretrained]
652
+ self.num_mapping_layers_ = NUM_MAPPING_LAYERS[self.pretrained]
653
+ else:
654
+ self.param_dict = None
655
+ self.init_rng_mapping, self.init_rng_synthesis = random.split(self.rng)
656
+
657
+ @nn.compact
658
+ def __call__(self, z, c=None, truncation_psi=1, truncation_cutoff=None, skip_w_avg_update=False, train=True, noise_mode='random', rng=random.PRNGKey(0)):
659
+ """
660
+ Run Generator.
661
+
662
+ Args:
663
+ z (tensor): Input noise, shape [N, z_dim].
664
+ c (tensor): Input labels, shape [N, c_dim].
665
+ truncation_psi (float): Controls truncation (trading off variation for quality). If 1, truncation is disabled.
666
+ truncation_cutoff (int): Controls truncation. None = disable.
667
+ skip_w_avg_update (bool): If True, updates the exponential moving average of W.
668
+ train (bool): Training mode.
669
+ noise_mode (str): Noise type.
670
+ - 'const': Constant noise.
671
+ - 'random': Random noise.
672
+ - 'none': No noise.
673
+ rng (jax.random.PRNGKey): PRNG for spatialwise noise.
674
+
675
+ Returns:
676
+ (tensor): Image of shape [N, H, W, num_channels].
677
+ """
678
+ dlatents_in = MappingNetwork(z_dim=self.z_dim,
679
+ c_dim=self.c_dim_,
680
+ w_dim=self.w_dim,
681
+ num_ws=self.num_ws,
682
+ num_layers=self.num_mapping_layers_,
683
+ embed_features=self.mapping_embed_features,
684
+ layer_features=self.mapping_layer_features,
685
+ activation=self.activation,
686
+ lr_multiplier=self.mapping_lr_multiplier,
687
+ w_avg_beta=self.w_avg_beta,
688
+ param_dict=self.param_dict['mapping_network'] if self.param_dict is not None else None,
689
+ dtype=self.dtype,
690
+ rng=self.init_rng_mapping,
691
+ name='mapping_network')(z, c, truncation_psi, truncation_cutoff, skip_w_avg_update, train)
692
+
693
+ x = SynthesisNetwork(resolution=self.resolution_,
694
+ num_channels=self.num_channels,
695
+ w_dim=self.w_dim,
696
+ fmap_base=self.fmap_base,
697
+ fmap_decay=self.fmap_decay,
698
+ fmap_min=self.fmap_min,
699
+ fmap_max=self.fmap_max,
700
+ fmap_const=self.fmap_const,
701
+ param_dict=self.param_dict['synthesis_network'] if self.param_dict is not None else None,
702
+ activation=self.activation,
703
+ use_noise=self.use_noise,
704
+ resample_kernel=self.resample_kernel,
705
+ fused_modconv=self.fused_modconv,
706
+ num_fp16_res=self.num_fp16_res,
707
+ clip_conv=self.clip_conv,
708
+ dtype=self.dtype,
709
+ rng=self.init_rng_synthesis,
710
+ name='synthesis_network')(dlatents_in, noise_mode, rng)
711
+
712
+ return x
713
+
stylegan2/ops.py ADDED
@@ -0,0 +1,674 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jax import random
4
+ import flax.linen as nn
5
+ from jax import jit
6
+ import numpy as np
7
+ from functools import partial
8
+ from typing import Any
9
+ import h5py
10
+
11
+
12
+ #------------------------------------------------------
13
+ # Other
14
+ #------------------------------------------------------
15
+ def minibatch_stddev_layer(x, group_size=None, num_new_features=1):
16
+ if group_size is None:
17
+ group_size = x.shape[0]
18
+ else:
19
+ # Minibatch must be divisible by (or smaller than) group_size.
20
+ group_size = min(group_size, x.shape[0])
21
+
22
+ G = group_size
23
+ F = num_new_features
24
+ _, H, W, C = x.shape
25
+ c = C // F
26
+
27
+ # [NHWC] Cast to FP32.
28
+ y = x.astype(jnp.float32)
29
+ # [GnHWFc] Split minibatch N into n groups of size G, and channels C into F groups of size c.
30
+ y = jnp.reshape(y, newshape=(G, -1, H, W, F, c))
31
+ # [GnHWFc] Subtract mean over group.
32
+ y -= jnp.mean(y, axis=0)
33
+ # [nHWFc] Calc variance over group.
34
+ y = jnp.mean(jnp.square(y), axis=0)
35
+ # [nHWFc] Calc stddev over group.
36
+ y = jnp.sqrt(y + 1e-8)
37
+ # [nF] Take average over channels and pixels.
38
+ y = jnp.mean(y, axis=(1, 2, 4))
39
+ # [nF] Cast back to original data type.
40
+ y = y.astype(x.dtype)
41
+ # [n11F] Add missing dimensions.
42
+ y = jnp.reshape(y, newshape=(-1, 1, 1, F))
43
+ # [NHWC] Replicate over group and pixels.
44
+ y = jnp.tile(y, (G, H, W, 1))
45
+ return jnp.concatenate((x, y), axis=3)
46
+
47
+
48
+ #------------------------------------------------------
49
+ # Activation
50
+ #------------------------------------------------------
51
+ def apply_activation(x, activation='linear', alpha=0.2, gain=np.sqrt(2)):
52
+ gain = jnp.array(gain, dtype=x.dtype)
53
+ if activation == 'relu':
54
+ return jax.nn.relu(x) * gain
55
+ if activation == 'leaky_relu':
56
+ return jax.nn.leaky_relu(x, negative_slope=alpha) * gain
57
+ return x
58
+
59
+
60
+ #------------------------------------------------------
61
+ # Weights
62
+ #------------------------------------------------------
63
+ def get_weight(shape, lr_multiplier=1, bias=True, param_dict=None, layer_name='', key=None):
64
+ if param_dict is None:
65
+ w = random.normal(key, shape=shape, dtype=jnp.float32) / lr_multiplier
66
+ if bias: b = jnp.zeros(shape=(shape[-1],), dtype=jnp.float32)
67
+ else:
68
+ w = jnp.array(param_dict[layer_name]['weight']).astype(jnp.float32)
69
+ if bias: b = jnp.array(param_dict[layer_name]['bias']).astype(jnp.float32)
70
+
71
+ if bias: return w, b
72
+ return w
73
+
74
+
75
+ def equalize_lr_weight(w, lr_multiplier=1):
76
+ """
77
+ Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.
78
+
79
+ Args:
80
+ w (tensor): Weight parameter. Shape [kernel, kernel, fmaps_in, fmaps_out]
81
+ for convolutions and shape [in, out] for MLPs.
82
+ lr_multiplier (float): Learning rate multiplier.
83
+
84
+ Returns:
85
+ (tensor): Scaled weight parameter.
86
+ """
87
+ in_features = np.prod(w.shape[:-1])
88
+ gain = lr_multiplier / np.sqrt(in_features)
89
+ w *= gain
90
+ return w
91
+
92
+
93
+ def equalize_lr_bias(b, lr_multiplier=1):
94
+ """
95
+ Equalized learning rate, see: https://arxiv.org/pdf/1710.10196.pdf.
96
+
97
+ Args:
98
+ b (tensor): Bias parameter.
99
+ lr_multiplier (float): Learning rate multiplier.
100
+
101
+ Returns:
102
+ (tensor): Scaled bias parameter.
103
+ """
104
+ gain = lr_multiplier
105
+ b *= gain
106
+ return b
107
+
108
+
109
+ #------------------------------------------------------
110
+ # Normalization
111
+ #------------------------------------------------------
112
+ def normalize_2nd_moment(x, eps=1e-8):
113
+ return x * jax.lax.rsqrt(jnp.mean(jnp.square(x), axis=1, keepdims=True) + eps)
114
+
115
+
116
+ #------------------------------------------------------
117
+ # Upsampling
118
+ #------------------------------------------------------
119
+ def setup_filter(f, normalize=True, flip_filter=False, gain=1, separable=None):
120
+ """
121
+ Convenience function to setup 2D FIR filter for `upfirdn2d()`.
122
+
123
+ Args:
124
+ f (tensor): Tensor or python list of the shape.
125
+ normalize (bool): Normalize the filter so that it retains the magnitude.
126
+ for constant input signal (DC)? (default: True).
127
+ flip_filter (bool): Flip the filter? (default: False).
128
+ gain (int): Overall scaling factor for signal magnitude (default: 1).
129
+ separable: Return a separable filter? (default: select automatically).
130
+
131
+ Returns:
132
+ (tensor): Output filter of shape [filter_height, filter_width] or [filter_taps]
133
+ """
134
+ # Validate.
135
+ if f is None:
136
+ f = 1
137
+ f = jnp.array(f, dtype=jnp.float32)
138
+ assert f.ndim in [0, 1, 2]
139
+ assert f.size > 0
140
+ if f.ndim == 0:
141
+ f = f[jnp.newaxis]
142
+
143
+ # Separable?
144
+ if separable is None:
145
+ separable = (f.ndim == 1 and f.size >= 8)
146
+ if f.ndim == 1 and not separable:
147
+ f = jnp.outer(f, f)
148
+ assert f.ndim == (1 if separable else 2)
149
+
150
+ # Apply normalize, flip, gain, and device.
151
+ if normalize:
152
+ f /= jnp.sum(f)
153
+ if flip_filter:
154
+ for i in range(f.ndim):
155
+ f = jnp.flip(f, axis=i)
156
+ f = f * (gain ** (f.ndim / 2))
157
+ return f
158
+
159
+
160
+ def upfirdn2d(x, f, padding=(2, 1, 2, 1), up=1, down=1, strides=(1, 1), flip_filter=False, gain=1):
161
+
162
+ if f is None:
163
+ f = jnp.ones((1, 1), dtype=jnp.float32)
164
+
165
+ B, H, W, C = x.shape
166
+ padx0, padx1, pady0, pady1 = padding
167
+
168
+ # upsample by inserting zeros
169
+ x = jnp.reshape(x, newshape=(B, H, 1, W, 1, C))
170
+ x = jnp.pad(x, pad_width=((0, 0), (0, 0), (0, up - 1), (0, 0), (0, up - 1), (0, 0)))
171
+ x = jnp.reshape(x, newshape=(B, H * up, W * up, C))
172
+
173
+ # padding
174
+ x = jnp.pad(x, pad_width=((0, 0), (max(pady0, 0), max(pady1, 0)), (max(padx0, 0), max(padx1, 0)), (0, 0)))
175
+ x = x[:, max(-pady0, 0) : x.shape[1] - max(-pady1, 0), max(-padx0, 0) : x.shape[2] - max(-padx1, 0)]
176
+
177
+ # setup filter
178
+ f = f * (gain ** (f.ndim / 2))
179
+ if not flip_filter:
180
+ for i in range(f.ndim):
181
+ f = jnp.flip(f, axis=i)
182
+
183
+ # convole filter
184
+ f = jnp.repeat(jnp.expand_dims(f, axis=(-2, -1)), repeats=C, axis=-1)
185
+ if f.ndim == 4:
186
+ x = jax.lax.conv_general_dilated(x,
187
+ f.astype(x.dtype),
188
+ window_strides=strides or (1,) * (x.ndim - 2),
189
+ padding='valid',
190
+ dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
191
+ feature_group_count=C)
192
+ else:
193
+ x = jax.lax.conv_general_dilated(x,
194
+ jnp.expand_dims(f, axis=0).astype(x.dtype),
195
+ window_strides=strides or (1,) * (x.ndim - 2),
196
+ padding='valid',
197
+ dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
198
+ feature_group_count=C)
199
+ x = jax.lax.conv_general_dilated(x,
200
+ jnp.expand_dims(f, axis=1).astype(x.dtype),
201
+ window_strides=strides or (1,) * (x.ndim - 2),
202
+ padding='valid',
203
+ dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
204
+ feature_group_count=C)
205
+ x = x[:, ::down, ::down]
206
+ return x
207
+
208
+
209
+ def upsample2d(x, f, up=2, padding=0, flip_filter=False, gain=1):
210
+ if f.ndim == 1:
211
+ fh, fw = f.shape[0], f.shape[0]
212
+ elif f.ndim == 2:
213
+ fh, fw = f.shape[0], f.shape[1]
214
+ else:
215
+ raise ValueError('Invalid filter shape:', f.shape)
216
+ padx0 = padding + (fw + up - 1) // 2
217
+ padx1 = padding + (fw - up) // 2
218
+ pady0 = padding + (fh + up - 1) // 2
219
+ pady1 = padding + (fh - up) // 2
220
+ return upfirdn2d(x, f=f, up=up, padding=(padx0, padx1, pady0, pady1), flip_filter=flip_filter, gain=gain * up * up)
221
+
222
+
223
+ #------------------------------------------------------
224
+ # Linear
225
+ #------------------------------------------------------
226
+ class LinearLayer(nn.Module):
227
+ """
228
+ Linear Layer.
229
+
230
+ Attributes:
231
+ in_features (int): Input dimension.
232
+ out_features (int): Output dimension.
233
+ use_bias (bool): If True, use bias.
234
+ bias_init (int): Bias init.
235
+ lr_multiplier (float): Learning rate multiplier.
236
+ activation (str): Activation function: 'relu', 'lrelu', etc.
237
+ param_dict (h5py.Group): Parameter dict with pretrained parameters.
238
+ layer_name (str): Layer name.
239
+ dtype (str): Data type.
240
+ rng (jax.random.PRNGKey): Random seed for initialization.
241
+ """
242
+ in_features: int
243
+ out_features: int
244
+ use_bias: bool=True
245
+ bias_init: int=0
246
+ lr_multiplier: float=1
247
+ activation: str='linear'
248
+ param_dict: h5py.Group=None
249
+ layer_name: str=None
250
+ dtype: str='float32'
251
+ rng: Any=random.PRNGKey(0)
252
+
253
+ @nn.compact
254
+ def __call__(self, x):
255
+ """
256
+ Run Linear Layer.
257
+
258
+ Args:
259
+ x (tensor): Input tensor of shape [N, in_features].
260
+
261
+ Returns:
262
+ (tensor): Output tensor of shape [N, out_features].
263
+ """
264
+ w_shape = [self.in_features, self.out_features]
265
+ params = get_weight(w_shape, self.lr_multiplier, self.use_bias, self.param_dict, self.layer_name, self.rng)
266
+
267
+ if self.use_bias:
268
+ w, b = params
269
+ else:
270
+ w = params
271
+
272
+ w = self.param(name='weight', init_fn=lambda *_ : w)
273
+ w = equalize_lr_weight(w, self.lr_multiplier)
274
+ x = jnp.matmul(x, w.astype(x.dtype))
275
+
276
+ if self.use_bias:
277
+ b = self.param(name='bias', init_fn=lambda *_ : b)
278
+ b = equalize_lr_bias(b, self.lr_multiplier)
279
+ x += b.astype(x.dtype)
280
+ x += self.bias_init
281
+
282
+ x = apply_activation(x, activation=self.activation)
283
+ return x
284
+
285
+
286
+ #------------------------------------------------------
287
+ # Convolution
288
+ #------------------------------------------------------
289
+ def conv_downsample_2d(x, w, k=None, factor=2, gain=1, padding=0):
290
+ """
291
+ Fused downsample convolution.
292
+
293
+ Padding is performed only once at the beginning, not between the operations.
294
+ The fused op is considerably more efficient than performing the same calculation
295
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
296
+
297
+ Args:
298
+ x (tensor): Input tensor of the shape [N, H, W, C].
299
+ w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
300
+ Grouped convolution can be performed by inChannels = x.shape[0] // numGroups.
301
+ k (tensor): FIR filter of the shape [firH, firW] or [firN].
302
+ The default is `[1] * factor`, which corresponds to average pooling.
303
+ factor (int): Downsampling factor (default: 2).
304
+ gain (float): Scaling factor for signal magnitude (default: 1.0).
305
+ padding (int): Number of pixels to pad or crop the output on each side (default: 0).
306
+
307
+ Returns:
308
+ (tensor): Output of the shape [N, H // factor, W // factor, C].
309
+ """
310
+ assert isinstance(factor, int) and factor >= 1
311
+ assert isinstance(padding, int)
312
+
313
+ # Check weight shape.
314
+ ch, cw, _inC, _outC = w.shape
315
+ assert cw == ch
316
+
317
+ # Setup filter kernel.
318
+ k = setup_filter(k, gain=gain)
319
+ assert k.shape[0] == k.shape[1]
320
+
321
+ # Execute.
322
+ pad0 = (k.shape[0] - factor + cw) // 2 + padding * factor
323
+ pad1 = (k.shape[0] - factor + cw - 1) // 2 + padding * factor
324
+ x = upfirdn2d(x=x, f=k, padding=(pad0, pad0, pad1, pad1))
325
+
326
+ x = jax.lax.conv_general_dilated(x,
327
+ w,
328
+ window_strides=(factor, factor),
329
+ padding='VALID',
330
+ dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
331
+ return x
332
+
333
+
334
+ def upsample_conv_2d(x, w, k=None, factor=2, gain=1, padding=0):
335
+ """
336
+ Fused upsample convolution.
337
+
338
+ Padding is performed only once at the beginning, not between the operations.
339
+ The fused op is considerably more efficient than performing the same calculation
340
+ using standard TensorFlow ops. It supports gradients of arbitrary order.
341
+
342
+ Args:
343
+ x (tensor): Input tensor of the shape [N, H, W, C].
344
+ w (tensor): Weight tensor of the shape [filterH, filterW, inChannels, outChannels].
345
+ Grouped convolution can be performed by inChannels = x.shape[0] // numGroups.
346
+ k (tensor): FIR filter of the shape [firH, firW] or [firN].
347
+ The default is [1] * factor, which corresponds to nearest-neighbor upsampling.
348
+ factor (int): Integer upsampling factor (default: 2).
349
+ gain (float): Scaling factor for signal magnitude (default: 1.0).
350
+ padding (int): Number of pixels to pad or crop the output on each side (default: 0).
351
+
352
+ Returns:
353
+ (tensor): Output of the shape [N, H * factor, W * factor, C].
354
+ """
355
+ assert isinstance(factor, int) and factor >= 1
356
+ assert isinstance(padding, int)
357
+
358
+ # Check weight shape.
359
+ ch, cw, _inC, _outC = w.shape
360
+ inC = w.shape[2]
361
+ outC = w.shape[3]
362
+ assert cw == ch
363
+
364
+ # Fast path for 1x1 convolution.
365
+ if cw == 1 and ch == 1:
366
+ x = jax.lax.conv_general_dilated(x,
367
+ w,
368
+ window_strides=(1, 1),
369
+ padding='VALID',
370
+ dimension_numbers=nn.linear._conv_dimension_numbers(x.shape))
371
+ k = setup_filter(k, gain=gain * (factor ** 2))
372
+ pad0 = (k.shape[0] + factor - cw) // 2 + padding
373
+ pad1 = (k.shape[0] - factor) // 2 + padding
374
+ x = upfirdn2d(x, f=k, up=factor, padding=(pad0, pad1, pad0, pad1))
375
+ return x
376
+
377
+ # Setup filter kernel.
378
+ k = setup_filter(k, gain=gain * (factor ** 2))
379
+ assert k.shape[0] == k.shape[1]
380
+
381
+ # Determine data dimensions.
382
+ stride = (factor, factor)
383
+ output_shape = ((x.shape[1] - 1) * factor + ch, (x.shape[2] - 1) * factor + cw)
384
+ num_groups = x.shape[3] // inC
385
+
386
+ # Transpose weights.
387
+ w = jnp.reshape(w, (ch, cw, inC, num_groups, -1))
388
+ w = jnp.transpose(w[::-1, ::-1], (0, 1, 4, 3, 2))
389
+ w = jnp.reshape(w, (ch, cw, -1, num_groups * inC))
390
+
391
+ # Execute.
392
+ x = gradient_based_conv_transpose(lhs=x,
393
+ rhs=w,
394
+ strides=stride,
395
+ padding='VALID',
396
+ output_padding=(0, 0, 0, 0),
397
+ output_shape=output_shape,
398
+ )
399
+
400
+ pad0 = (k.shape[0] + factor - cw) // 2 + padding
401
+ pad1 = (k.shape[0] - factor - cw + 3) // 2 + padding
402
+ x = upfirdn2d(x=x, f=k, padding=(pad0, pad1, pad0, pad1))
403
+ return x
404
+
405
+
406
+ def conv2d(x, w, up=False, down=False, resample_kernel=None, padding=0):
407
+ assert not (up and down)
408
+ kernel = w.shape[0]
409
+ assert w.shape[1] == kernel
410
+ assert kernel >= 1 and kernel % 2 == 1
411
+
412
+ num_groups = x.shape[3] // w.shape[2]
413
+
414
+ w = w.astype(x.dtype)
415
+ if up:
416
+ x = upsample_conv_2d(x, w, k=resample_kernel, padding=padding)
417
+ elif down:
418
+ x = conv_downsample_2d(x, w, k=resample_kernel, padding=padding)
419
+ else:
420
+ padding_mode = {0: 'SAME', -(kernel // 2): 'VALID'}[padding]
421
+ x = jax.lax.conv_general_dilated(x,
422
+ w,
423
+ window_strides=(1, 1),
424
+ padding=padding_mode,
425
+ dimension_numbers=nn.linear._conv_dimension_numbers(x.shape),
426
+ feature_group_count=num_groups)
427
+ return x
428
+
429
+
430
+ def modulated_conv2d_layer(x, w, s, fmaps, kernel, up=False, down=False, demodulate=True, resample_kernel=None, fused_modconv=False):
431
+ assert not (up and down)
432
+ assert kernel >= 1 and kernel % 2 == 1
433
+
434
+ # Get weight.
435
+ wshape = (kernel, kernel, x.shape[3], fmaps)
436
+ if x.dtype.name == 'float16' and not fused_modconv and demodulate:
437
+ w *= jnp.sqrt(1 / np.prod(wshape[:-1])) / jnp.max(jnp.abs(w), axis=(0, 1, 2)) # Pre-normalize to avoid float16 overflow.
438
+ ww = w[jnp.newaxis] # [BkkIO] Introduce minibatch dimension.
439
+
440
+ # Modulate.
441
+ if x.dtype.name == 'float16' and not fused_modconv and demodulate:
442
+ s *= 1 / jnp.max(jnp.abs(s)) # Pre-normalize to avoid float16 overflow.
443
+ ww *= s[:, jnp.newaxis, jnp.newaxis, :, jnp.newaxis].astype(w.dtype) # [BkkIO] Scale input feature maps.
444
+
445
+ # Demodulate.
446
+ if demodulate:
447
+ d = jax.lax.rsqrt(jnp.sum(jnp.square(ww), axis=(1, 2, 3)) + 1e-8) # [BO] Scaling factor.
448
+ ww *= d[:, jnp.newaxis, jnp.newaxis, jnp.newaxis, :] # [BkkIO] Scale output feature maps.
449
+
450
+ # Reshape/scale input.
451
+ if fused_modconv:
452
+ x = jnp.transpose(x, axes=(0, 3, 1, 2))
453
+ x = jnp.reshape(x, (1, -1, x.shape[2], x.shape[3])) # Fused => reshape minibatch to convolution groups.
454
+ x = jnp.transpose(x, axes=(0, 2, 3, 1))
455
+ w = jnp.reshape(jnp.transpose(ww, (1, 2, 3, 0, 4)), (ww.shape[1], ww.shape[2], ww.shape[3], -1))
456
+ else:
457
+ x *= s[:, jnp.newaxis, jnp.newaxis].astype(x.dtype) # [BIhw] Not fused => scale input activations.
458
+
459
+ # 2D convolution.
460
+ x = conv2d(x, w.astype(x.dtype), up=up, down=down, resample_kernel=resample_kernel)
461
+
462
+ # Reshape/scale output.
463
+ if fused_modconv:
464
+ x = jnp.transpose(x, axes=(0, 3, 1, 2))
465
+ x = jnp.reshape(x, (-1, fmaps, x.shape[2], x.shape[3])) # Fused => reshape convolution groups back to minibatch.
466
+ x = jnp.transpose(x, axes=(0, 2, 3, 1))
467
+ elif demodulate:
468
+ x *= d[:, jnp.newaxis, jnp.newaxis].astype(x.dtype) # [BOhw] Not fused => scale output activations.
469
+
470
+ return x
471
+
472
+
473
+ def _deconv_output_length(input_length, filter_size, padding, output_padding=None, stride=0, dilation=1):
474
+ """
475
+ Taken from: https://github.com/google/jax/pull/5772/commits
476
+
477
+ Determines the output length of a transposed convolution given the input length.
478
+ Function modified from Keras.
479
+ Arguments:
480
+ input_length: Integer.
481
+ filter_size: Integer.
482
+ padding: one of `"SAME"`, `"VALID"`, or a 2-integer tuple.
483
+ output_padding: Integer, amount of padding along the output dimension. Can
484
+ be set to `None` in which case the output length is inferred.
485
+ stride: Integer.
486
+ dilation: Integer.
487
+ Returns:
488
+ The output length (integer).
489
+ """
490
+ if input_length is None:
491
+ return None
492
+
493
+ # Get the dilated kernel size
494
+ filter_size = filter_size + (filter_size - 1) * (dilation - 1)
495
+
496
+ # Infer length if output padding is None, else compute the exact length
497
+ if output_padding is None:
498
+ if padding == 'VALID':
499
+ length = input_length * stride + max(filter_size - stride, 0)
500
+ elif padding == 'SAME':
501
+ length = input_length * stride
502
+ else:
503
+ length = ((input_length - 1) * stride + filter_size - padding[0] - padding[1])
504
+
505
+ else:
506
+ if padding == 'SAME':
507
+ pad = filter_size // 2
508
+ total_pad = pad * 2
509
+ elif padding == 'VALID':
510
+ total_pad = 0
511
+ else:
512
+ total_pad = padding[0] + padding[1]
513
+
514
+ length = ((input_length - 1) * stride + filter_size - total_pad + output_padding)
515
+ return length
516
+
517
+
518
+ def _compute_adjusted_padding(input_size, output_size, kernel_size, stride, padding, dilation=1):
519
+ """
520
+ Taken from: https://github.com/google/jax/pull/5772/commits
521
+
522
+ Computes adjusted padding for desired ConvTranspose `output_size`.
523
+ Ported from DeepMind Haiku.
524
+ """
525
+ kernel_size = (kernel_size - 1) * dilation + 1
526
+ if padding == 'VALID':
527
+ expected_input_size = (output_size - kernel_size + stride) // stride
528
+ if input_size != expected_input_size:
529
+ raise ValueError(f'The expected input size with the current set of input '
530
+ f'parameters is {expected_input_size} which doesn\'t '
531
+ f'match the actual input size {input_size}.')
532
+ padding_before = 0
533
+ elif padding == 'SAME':
534
+ expected_input_size = (output_size + stride - 1) // stride
535
+ if input_size != expected_input_size:
536
+ raise ValueError(f'The expected input size with the current set of input '
537
+ f'parameters is {expected_input_size} which doesn\'t '
538
+ f'match the actual input size {input_size}.')
539
+ padding_needed = max(0, (input_size - 1) * stride + kernel_size - output_size)
540
+ padding_before = padding_needed // 2
541
+ else:
542
+ padding_before = padding[0] # type: ignore[assignment]
543
+
544
+ expanded_input_size = (input_size - 1) * stride + 1
545
+ padded_out_size = output_size + kernel_size - 1
546
+ pad_before = kernel_size - 1 - padding_before
547
+ pad_after = padded_out_size - expanded_input_size - pad_before
548
+ return (pad_before, pad_after)
549
+
550
+
551
+ def _flip_axes(x, axes):
552
+ """
553
+ Taken from: https://github.com/google/jax/blob/master/jax/_src/lax/lax.py
554
+
555
+ Flip ndarray 'x' along each axis specified in axes tuple.
556
+ """
557
+ for axis in axes:
558
+ x = jnp.flip(x, axis)
559
+ return x
560
+
561
+
562
+ def gradient_based_conv_transpose(lhs,
563
+ rhs,
564
+ strides,
565
+ padding,
566
+ output_padding,
567
+ output_shape=None,
568
+ dilation=None,
569
+ dimension_numbers=None,
570
+ transpose_kernel=True,
571
+ feature_group_count=1,
572
+ precision=None):
573
+ """
574
+ Taken from: https://github.com/google/jax/pull/5772/commits
575
+
576
+ Convenience wrapper for calculating the N-d transposed convolution.
577
+ Much like `conv_transpose`, this function calculates transposed convolutions
578
+ via fractionally strided convolution rather than calculating the gradient
579
+ (transpose) of a forward convolution. However, the latter is more common
580
+ among deep learning frameworks, such as TensorFlow, PyTorch, and Keras.
581
+ This function provides the same set of APIs to help reproduce results in these frameworks.
582
+ Args:
583
+ lhs: a rank `n+2` dimensional input array.
584
+ rhs: a rank `n+2` dimensional array of kernel weights.
585
+ strides: sequence of `n` integers, amounts to strides of the corresponding forward convolution.
586
+ padding: `"SAME"`, `"VALID"`, or a sequence of `n` integer 2-tuples that controls
587
+ the before-and-after padding for each `n` spatial dimension of
588
+ the corresponding forward convolution.
589
+ output_padding: A sequence of integers specifying the amount of padding along
590
+ each spacial dimension of the output tensor, used to disambiguate the output shape of
591
+ transposed convolutions when the stride is larger than 1.
592
+ (see a detailed description at https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html)
593
+ The amount of output padding along a given dimension must
594
+ be lower than the stride along that same dimension.
595
+ If set to `None` (default), the output shape is inferred.
596
+ If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
597
+ output_shape: Output shape of the spatial dimensions of a transpose
598
+ convolution. Can be `None` or an iterable of `n` integers. If a `None` value is given (default),
599
+ the shape is automatically calculated.
600
+ Similar to `output_padding`, `output_shape` is also for disambiguating the output shape
601
+ when stride > 1 (see also
602
+ https://www.tensorflow.org/api_docs/python/tf/nn/conv2d_transpose)
603
+ If both `output_padding` and `output_shape` are specified, they have to be mutually compatible.
604
+ dilation: `None`, or a sequence of `n` integers, giving the
605
+ dilation factor to apply in each spatial dimension of `rhs`. Dilated convolution
606
+ is also known as atrous convolution.
607
+ dimension_numbers: tuple of dimension descriptors as in lax.conv_general_dilated. Defaults to tensorflow convention.
608
+ transpose_kernel: if `True` flips spatial axes and swaps the input/output
609
+ channel axes of the kernel. This makes the output of this function identical
610
+ to the gradient-derived functions like keras.layers.Conv2DTranspose and
611
+ torch.nn.ConvTranspose2d applied to the same kernel.
612
+ Although for typical use in neural nets this is unnecessary
613
+ and makes input/output channel specification confusing, you need to set this to `True`
614
+ in order to match the behavior in many deep learning frameworks, such as TensorFlow, Keras, and PyTorch.
615
+ precision: Optional. Either ``None``, which means the default precision for
616
+ the backend, a ``lax.Precision`` enum value (``Precision.DEFAULT``,
617
+ ``Precision.HIGH`` or ``Precision.HIGHEST``) or a tuple of two
618
+ ``lax.Precision`` enums indicating precision of ``lhs``` and ``rhs``.
619
+ Returns:
620
+ Transposed N-d convolution.
621
+ """
622
+ assert len(lhs.shape) == len(rhs.shape) and len(lhs.shape) >= 2
623
+ ndims = len(lhs.shape)
624
+ one = (1,) * (ndims - 2)
625
+ # Set dimensional layout defaults if not specified.
626
+ if dimension_numbers is None:
627
+ if ndims == 2:
628
+ dimension_numbers = ('NC', 'IO', 'NC')
629
+ elif ndims == 3:
630
+ dimension_numbers = ('NHC', 'HIO', 'NHC')
631
+ elif ndims == 4:
632
+ dimension_numbers = ('NHWC', 'HWIO', 'NHWC')
633
+ elif ndims == 5:
634
+ dimension_numbers = ('NHWDC', 'HWDIO', 'NHWDC')
635
+ else:
636
+ raise ValueError('No 4+ dimensional dimension_number defaults.')
637
+ dn = jax.lax.conv_dimension_numbers(lhs.shape, rhs.shape, dimension_numbers)
638
+ k_shape = np.take(rhs.shape, dn.rhs_spec)
639
+ k_sdims = k_shape[2:] # type: ignore[index]
640
+ i_shape = np.take(lhs.shape, dn.lhs_spec)
641
+ i_sdims = i_shape[2:] # type: ignore[index]
642
+
643
+ # Calculate correct output shape given padding and strides.
644
+ if dilation is None:
645
+ dilation = (1,) * (rhs.ndim - 2)
646
+
647
+ if output_padding is None:
648
+ output_padding = [None] * (rhs.ndim - 2) # type: ignore[list-item]
649
+
650
+ if isinstance(padding, str):
651
+ if padding in {'SAME', 'VALID'}:
652
+ padding = [padding] * (rhs.ndim - 2) # type: ignore[list-item]
653
+ else:
654
+ raise ValueError(f"`padding` must be 'VALID' or 'SAME'. Passed: {padding}.")
655
+
656
+ inferred_output_shape = tuple(map(_deconv_output_length, i_sdims, k_sdims, padding, output_padding, strides, dilation))
657
+
658
+ if output_shape is None:
659
+ output_shape = inferred_output_shape # type: ignore[assignment]
660
+ else:
661
+ if not output_shape == inferred_output_shape:
662
+ raise ValueError(f'`output_padding` and `output_shape` are not compatible.'
663
+ f'Inferred output shape from `output_padding`: {inferred_output_shape}, '
664
+ f'but got `output_shape` {output_shape}')
665
+
666
+ pads = tuple(map(_compute_adjusted_padding, i_sdims, output_shape, k_sdims, strides, padding, dilation))
667
+
668
+ if transpose_kernel:
669
+ # flip spatial dims and swap input / output channel axes
670
+ rhs = _flip_axes(rhs, np.array(dn.rhs_spec)[2:])
671
+ rhs = np.swapaxes(rhs, dn.rhs_spec[0], dn.rhs_spec[1])
672
+ return jax.lax.conv_general_dilated(lhs, rhs, one, pads, strides, dilation, dn, feature_group_count, precision=precision)
673
+
674
+
stylegan2/utils.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tqdm import tqdm
2
+ import requests
3
+ import os
4
+ import tempfile
5
+
6
+
7
+ def download(ckpt_dir, url):
8
+ name = url[url.rfind('/') + 1 : url.rfind('?')]
9
+ if ckpt_dir is None:
10
+ ckpt_dir = tempfile.gettempdir()
11
+ ckpt_dir = os.path.join(ckpt_dir, 'flaxmodels')
12
+ ckpt_file = os.path.join(ckpt_dir, name)
13
+ if not os.path.exists(ckpt_file):
14
+ print(f'Downloading: \"{url[:url.rfind("?")]}\" to {ckpt_file}')
15
+ if not os.path.exists(ckpt_dir):
16
+ os.makedirs(ckpt_dir)
17
+
18
+ response = requests.get(url, stream=True)
19
+ total_size_in_bytes = int(response.headers.get('content-length', 0))
20
+ progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
21
+
22
+ # first create temp file, in case the download fails
23
+ ckpt_file_temp = os.path.join(ckpt_dir, name + '.temp')
24
+ with open(ckpt_file_temp, 'wb') as file:
25
+ for data in response.iter_content(chunk_size=1024):
26
+ progress_bar.update(len(data))
27
+ file.write(data)
28
+ progress_bar.close()
29
+
30
+ if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
31
+ print('An error occured while downloading, please try again.')
32
+ if os.path.exists(ckpt_file_temp):
33
+ os.remove(ckpt_file_temp)
34
+ else:
35
+ # if download was successful, rename the temp file
36
+ os.rename(ckpt_file_temp, ckpt_file)
37
+ return ckpt_file
training.py ADDED
@@ -0,0 +1,382 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import flax
4
+ from flax.optim import dynamic_scale as dynamic_scale_lib
5
+ from flax.core import frozen_dict
6
+ import optax
7
+ import numpy as np
8
+ import functools
9
+ import wandb
10
+ import time
11
+
12
+ import stylegan2
13
+ import data_pipeline
14
+ import checkpoint
15
+ import training_utils
16
+ import training_steps
17
+ from fid import FID
18
+
19
+ import logging
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def tree_shape(item):
25
+ return jax.tree_map(lambda c: c.shape, item)
26
+
27
+
28
+ def train_and_evaluate(config):
29
+ num_devices = jax.device_count() # 8
30
+ num_local_devices = jax.local_device_count() # 4
31
+ num_workers = jax.process_count()
32
+
33
+ # --------------------------------------
34
+ # Data
35
+ # --------------------------------------
36
+ ds_train, dataset_info = data_pipeline.get_data(data_dir=config.data_dir,
37
+ img_size=config.resolution,
38
+ img_channels=config.img_channels,
39
+ num_classes=config.c_dim,
40
+ num_local_devices=num_local_devices,
41
+ batch_size=config.batch_size)
42
+
43
+ # --------------------------------------
44
+ # Seeding and Precision
45
+ # --------------------------------------
46
+ rng = jax.random.PRNGKey(config.random_seed)
47
+
48
+ if config.mixed_precision:
49
+ dtype = jnp.float16
50
+ elif config.bf16:
51
+ dtype = jnp.bfloat16
52
+ else:
53
+ dtype = jnp.float32
54
+ logger.info(f'Running on dtype {dtype}')
55
+
56
+ platform = jax.local_devices()[0].platform
57
+ if config.mixed_precision and platform == 'gpu':
58
+ dynamic_scale_G_main = dynamic_scale_lib.DynamicScale()
59
+ dynamic_scale_D_main = dynamic_scale_lib.DynamicScale()
60
+ dynamic_scale_G_reg = dynamic_scale_lib.DynamicScale()
61
+ dynamic_scale_D_reg = dynamic_scale_lib.DynamicScale()
62
+ clip_conv = 256
63
+ num_fp16_res = 4
64
+ else:
65
+ dynamic_scale_G_main = None
66
+ dynamic_scale_D_main = None
67
+ dynamic_scale_G_reg = None
68
+ dynamic_scale_D_reg = None
69
+ clip_conv = None
70
+ num_fp16_res = 0
71
+
72
+ # --------------------------------------
73
+ # Initialize Models
74
+ # --------------------------------------
75
+ logger.info('Initialize models...')
76
+
77
+ rng, init_rng = jax.random.split(rng)
78
+
79
+ # Generator initialization for training
80
+ start_mn = time.time()
81
+ logger.info("Creating MappingNetwork...")
82
+ mapping_net = stylegan2.MappingNetwork(z_dim=config.z_dim,
83
+ c_dim=config.c_dim,
84
+ w_dim=config.w_dim,
85
+ num_ws=int(np.log2(config.resolution)) * 2 - 3,
86
+ num_layers=8,
87
+ dtype=dtype)
88
+
89
+ mapping_net_vars = mapping_net.init(init_rng,
90
+ jnp.ones((1, config.z_dim)),
91
+ jnp.ones((1, config.c_dim)))
92
+
93
+ mapping_net_params, moving_stats = mapping_net_vars['params'], mapping_net_vars['moving_stats']
94
+
95
+ logger.info(f"MappingNetwork took {time.time() - start_mn:.2f}s")
96
+
97
+ logger.info("Creating SynthesisNetwork...")
98
+ start_sn = time.time()
99
+ synthesis_net = stylegan2.SynthesisNetwork(resolution=config.resolution,
100
+ num_channels=config.img_channels,
101
+ w_dim=config.w_dim,
102
+ fmap_base=config.fmap_base,
103
+ num_fp16_res=num_fp16_res,
104
+ clip_conv=clip_conv,
105
+ dtype=dtype)
106
+
107
+ synthesis_net_vars = synthesis_net.init(init_rng,
108
+ jnp.ones((1, mapping_net.num_ws, config.w_dim)))
109
+ synthesis_net_params, noise_consts = synthesis_net_vars['params'], synthesis_net_vars['noise_consts']
110
+
111
+ logger.info(f"SynthesisNetwork took {time.time() - start_sn:.2f}s")
112
+
113
+ params_G = frozen_dict.FrozenDict(
114
+ {'mapping': mapping_net_params,
115
+ 'synthesis': synthesis_net_params}
116
+ )
117
+
118
+ # Discriminator initialization for training
119
+ logger.info("Creating Discriminator...")
120
+ start_d = time.time()
121
+ discriminator = stylegan2.Discriminator(resolution=config.resolution,
122
+ num_channels=config.img_channels,
123
+ c_dim=config.c_dim,
124
+ mbstd_group_size=config.mbstd_group_size,
125
+ num_fp16_res=num_fp16_res,
126
+ clip_conv=clip_conv,
127
+ dtype=dtype)
128
+ rng, init_rng = jax.random.split(rng)
129
+ params_D = discriminator.init(init_rng,
130
+ jnp.ones((1, config.resolution, config.resolution, config.img_channels)),
131
+ jnp.ones((1, config.c_dim)))
132
+ logger.info(f"Discriminator took {time.time() - start_d:.2f}s")
133
+
134
+ # Exponential average Generator initialization
135
+ logger.info("Creating Generator EMA...")
136
+ start_g = time.time()
137
+ generator_ema = stylegan2.Generator(resolution=config.resolution,
138
+ num_channels=config.img_channels,
139
+ z_dim=config.z_dim,
140
+ c_dim=config.c_dim,
141
+ w_dim=config.w_dim,
142
+ num_ws=int(np.log2(config.resolution)) * 2 - 3,
143
+ num_mapping_layers=8,
144
+ fmap_base=config.fmap_base,
145
+ num_fp16_res=num_fp16_res,
146
+ clip_conv=clip_conv,
147
+ dtype=dtype)
148
+
149
+ params_ema_G = generator_ema.init(init_rng,
150
+ jnp.ones((1, config.z_dim)),
151
+ jnp.ones((1, config.c_dim)))
152
+ logger.info(f"Took {time.time() - start_g:.2f}s")
153
+
154
+ # --------------------------------------
155
+ # Initialize States and Optimizers
156
+ # --------------------------------------
157
+ logger.info('Initialize states...')
158
+ tx_G = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
159
+ tx_D = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
160
+
161
+ state_G = training_utils.TrainStateG.create(apply_fn=None,
162
+ apply_mapping=mapping_net.apply,
163
+ apply_synthesis=synthesis_net.apply,
164
+ params=params_G,
165
+ moving_stats=moving_stats,
166
+ noise_consts=noise_consts,
167
+ tx=tx_G,
168
+ dynamic_scale_main=dynamic_scale_G_main,
169
+ dynamic_scale_reg=dynamic_scale_G_reg,
170
+ epoch=0)
171
+
172
+ state_D = training_utils.TrainStateD.create(apply_fn=discriminator.apply,
173
+ params=params_D,
174
+ tx=tx_D,
175
+ dynamic_scale_main=dynamic_scale_D_main,
176
+ dynamic_scale_reg=dynamic_scale_D_reg,
177
+ epoch=0)
178
+
179
+ # Copy over the parameters from the training generator to the ema generator
180
+ params_ema_G = training_utils.update_generator_ema(state_G, params_ema_G, config, ema_beta=0)
181
+
182
+ # Running mean of path length for path length regularization
183
+ pl_mean = jnp.zeros((), dtype=dtype)
184
+
185
+ step = 0
186
+ epoch_offset = 0
187
+ best_fid_score = np.inf
188
+ ckpt_path = None
189
+
190
+ if config.resume_run_id is not None:
191
+ # Resume training from existing checkpoint
192
+ ckpt_path = checkpoint.get_latest_checkpoint(config.ckpt_dir)
193
+ logger.info(f'Resume training from checkpoint: {ckpt_path}')
194
+ ckpt = checkpoint.load_checkpoint(ckpt_path)
195
+ step = ckpt['step']
196
+ epoch_offset = ckpt['epoch']
197
+ best_fid_score = ckpt['fid_score']
198
+ pl_mean = ckpt['pl_mean']
199
+ state_G = ckpt['state_G']
200
+ state_D = ckpt['state_D']
201
+ params_ema_G = ckpt['params_ema_G']
202
+ config = ckpt['config']
203
+ elif config.load_from_pkl is not None:
204
+ # Load checkpoint and start new run
205
+ ckpt_path = config.load_from_pkl
206
+ logger.info(f'Load model state from from : {ckpt_path}')
207
+ ckpt = checkpoint.load_checkpoint(ckpt_path)
208
+ pl_mean = ckpt['pl_mean']
209
+ state_G = ckpt['state_G']
210
+ state_D = ckpt['state_D']
211
+ params_ema_G = ckpt['params_ema_G']
212
+
213
+ # Replicate states across devices
214
+ pl_mean = flax.jax_utils.replicate(pl_mean)
215
+ state_G = flax.jax_utils.replicate(state_G)
216
+ state_D = flax.jax_utils.replicate(state_D)
217
+
218
+ # --------------------------------------
219
+ # Precompile train and eval steps
220
+ # --------------------------------------
221
+ logger.info('Precompile training steps...')
222
+ p_main_step_G = jax.pmap(training_steps.main_step_G, axis_name='batch')
223
+ p_regul_step_G = jax.pmap(functools.partial(training_steps.regul_step_G, config=config), axis_name='batch')
224
+
225
+ p_main_step_D = jax.pmap(training_steps.main_step_D, axis_name='batch')
226
+ p_regul_step_D = jax.pmap(functools.partial(training_steps.regul_step_D, config=config), axis_name='batch')
227
+
228
+ # --------------------------------------
229
+ # Training
230
+ # --------------------------------------
231
+ logger.info('Start training...')
232
+ fid_metric = FID(generator_ema, ds_train, config)
233
+
234
+ # Dict to collect training statistics / losses
235
+ metrics = {}
236
+ num_imgs_processed = 0
237
+ num_steps_per_epoch = dataset_info['num_examples'] // (config.batch_size * num_devices)
238
+ effective_batch_size = config.batch_size * num_devices
239
+ if config.wandb and jax.process_index() == 0:
240
+ # do some more logging
241
+ wandb.config.effective_batch_size = effective_batch_size
242
+ wandb.config.num_steps_per_epoch = num_steps_per_epoch
243
+ wandb.config.num_workers = num_workers
244
+ wandb.config.device_count = num_devices
245
+ wandb.config.num_examples = dataset_info['num_examples']
246
+ wandb.config.vm_name = training_utils.get_vm_name()
247
+
248
+ for epoch in range(epoch_offset, config.num_epochs):
249
+ if config.wandb and jax.process_index() == 0:
250
+ wandb.log({'training/epochs': epoch}, step=step)
251
+
252
+ for batch in data_pipeline.prefetch(ds_train, config.num_prefetch):
253
+ assert batch['image'].shape[1] == config.batch_size, f"Mismatched batch (batch size: {config.batch_size}, this batch: {batch['image'].shape[1]})"
254
+
255
+ # pbar.update(num_devices * config.batch_size)
256
+ iteration_start_time = time.time()
257
+
258
+ if config.c_dim == 0:
259
+ # No labels in the dataset
260
+ batch['label'] = None
261
+
262
+ # Create two latent noise vectors and combine them for the style mixing regularization
263
+ rng, key = jax.random.split(rng)
264
+ z_latent1 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
265
+ rng, key = jax.random.split(rng)
266
+ z_latent2 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
267
+
268
+ # Split PRNGs across devices
269
+ rkey = jax.random.split(key, num=num_local_devices)
270
+ mixing_prob = flax.jax_utils.replicate(config.mixing_prob)
271
+
272
+ # --------------------------------------
273
+ # Update Discriminator
274
+ # --------------------------------------
275
+ time_d_start = time.time()
276
+ state_D, metrics = p_main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
277
+ time_d_end = time.time()
278
+ if step % config.D_reg_interval == 0:
279
+ state_D, metrics = p_regul_step_D(state_D, batch, metrics)
280
+
281
+ # --------------------------------------
282
+ # Update Generator
283
+ # --------------------------------------
284
+ time_g_start = time.time()
285
+ state_G, metrics = p_main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
286
+ if step % config.G_reg_interval == 0:
287
+ H, W = batch['image'].shape[-3], batch['image'].shape[-2]
288
+ rng, key = jax.random.split(rng)
289
+ pl_noise = jax.random.normal(key, batch['image'].shape, dtype=dtype) / np.sqrt(H * W)
290
+ state_G, metrics, pl_mean = p_regul_step_G(state_G, batch, z_latent1, pl_noise, pl_mean, metrics,
291
+ rng=rkey)
292
+
293
+ params_ema_G = training_utils.update_generator_ema(flax.jax_utils.unreplicate(state_G),
294
+ params_ema_G,
295
+ config)
296
+ time_g_end = time.time()
297
+
298
+ # --------------------------------------
299
+ # Logging and Checkpointing
300
+ # --------------------------------------
301
+ if step % config.save_every == 0 and config.disable_fid:
302
+ # If FID evaluation is disabled, a checkpoint will be saved every 'save_every' steps.
303
+ if jax.process_index() == 0:
304
+ logger.info('Saving checkpoint...')
305
+ checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step,
306
+ epoch)
307
+
308
+ num_imgs_processed += num_devices * config.batch_size
309
+ if step % config.eval_fid_every == 0 and not config.disable_fid:
310
+ # If FID evaluation is enabled, only save a checkpoint if FID score is better.
311
+ if jax.process_index() == 0:
312
+ logger.info('Computing FID...')
313
+ fid_score = fid_metric.compute_fid(params_ema_G).item()
314
+ if config.wandb:
315
+ wandb.log({'training/gen/fid': fid_score}, step=step)
316
+ logger.info(f'Computed FID: {fid_score:.2f}')
317
+ if fid_score < best_fid_score:
318
+ best_fid_score = fid_score
319
+ logger.info(f'New best FID score ({best_fid_score:.3f}). Saving checkpoint...')
320
+ ts = time.time()
321
+ checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=fid_score)
322
+ te = time.time()
323
+ logger.info(f'... successfully saved checkpoint in {(te-ts)/60:.1f}min')
324
+
325
+ sec_per_kimg = (time.time() - iteration_start_time) / (num_devices * config.batch_size / 1000.0)
326
+ time_taken_g = time_g_end - time_g_start
327
+ time_taken_d = time_d_end - time_d_start
328
+ time_taken_per_step = time.time() - iteration_start_time
329
+ g_loss = jnp.mean(metrics['G_loss']).item()
330
+ d_loss = jnp.mean(metrics['D_loss']).item()
331
+
332
+ if config.wandb and jax.process_index() == 0:
333
+ # wandb logging - happens every step
334
+ wandb.log({'training/gen/loss': jnp.mean(metrics['G_loss']).item()}, step=step, commit=False)
335
+ wandb.log({'training/dis/loss': jnp.mean(metrics['D_loss']).item()}, step=step, commit=False)
336
+ wandb.log({'training/dis/fake_logits': jnp.mean(metrics['fake_logits']).item()}, step=step, commit=False)
337
+ wandb.log({'training/dis/real_logits': jnp.mean(metrics['real_logits']).item()}, step=step, commit=False)
338
+ wandb.log({'training/time_taken_g': time_taken_g, 'training/time_taken_d': time_taken_d}, step=step, commit=False)
339
+ wandb.log({'training/time_taken_per_step': time_taken_per_step}, step=step, commit=False)
340
+ wandb.log({'training/num_imgs_trained': num_imgs_processed}, step=step, commit=False)
341
+ wandb.log({'training/sec_per_kimg': sec_per_kimg}, step=step)
342
+
343
+ if step % config.log_every == 0:
344
+ # console logging - happens every log_every steps
345
+ logger.info(f'Total steps: {step:>6,} - epoch {epoch:>3,}/{config.num_epochs} @ {step % num_steps_per_epoch:>6,}/{num_steps_per_epoch:,} - G loss: {g_loss:.5f} - D loss: {d_loss:.5f} - sec/kimg: {sec_per_kimg:.2f}s - time per step: {time_taken_per_step:.3f}s')
346
+
347
+ if step % config.generate_samples_every == 0 and config.wandb and jax.process_index() == 0:
348
+ # Generate training images
349
+ train_snapshot = training_utils.get_training_snapshot(
350
+ image_real=flax.jax_utils.unreplicate(batch['image']),
351
+ image_gen=flax.jax_utils.unreplicate(metrics['image_gen']),
352
+ max_num=10
353
+ )
354
+ wandb.log({'training/snapshot': wandb.Image(train_snapshot)}, commit=False, step=step)
355
+
356
+ # Generate evaluation images
357
+ labels = None if config.c_dim == 0 else batch['label'][0]
358
+ image_gen_eval = training_steps.eval_step_G(
359
+ generator_ema, params=params_ema_G,
360
+ z_latent=z_latent1[0],
361
+ labels=labels,
362
+ truncation=1
363
+ )
364
+ image_gen_eval_trunc = training_steps.eval_step_G(
365
+ generator_ema,
366
+ params=params_ema_G,
367
+ z_latent=z_latent1[0],
368
+ labels=labels,
369
+ truncation=0.5
370
+ )
371
+ eval_snapshot = training_utils.get_eval_snapshot(image=image_gen_eval, max_num=10)
372
+ eval_snapshot_trunc = training_utils.get_eval_snapshot(image=image_gen_eval_trunc, max_num=10)
373
+ wandb.log({'eval/snapshot': wandb.Image(eval_snapshot)}, commit=False, step=step)
374
+ wandb.log({'eval/snapshot_trunc': wandb.Image(eval_snapshot_trunc)}, step=step)
375
+
376
+ step += 1
377
+
378
+ # Sync moving stats across devices
379
+ state_G = training_utils.sync_moving_stats(state_G)
380
+
381
+ # Sync moving average of path length mean (Generator regularization)
382
+ pl_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name='batch'), axis_name='batch')(pl_mean)
training_steps.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import functools
4
+
5
+
6
+ def main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rng):
7
+
8
+ def loss_fn(params):
9
+ w_latent1, new_state_G = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
10
+ z_latent1,
11
+ batch['label'],
12
+ mutable=['moving_stats'])
13
+ w_latent2 = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
14
+ z_latent2,
15
+ batch['label'],
16
+ skip_w_avg_update=True)
17
+
18
+ # style mixing
19
+ cutoff_rng, layer_select_rng, synth_rng = jax.random.split(rng, num=3)
20
+ num_layers = w_latent1.shape[1]
21
+ layer_idx = jnp.arange(num_layers)[jnp.newaxis, :, jnp.newaxis]
22
+ mixing_cutoff = jax.lax.cond(jax.random.uniform(cutoff_rng, (), minval=0.0, maxval=1.0) < mixing_prob,
23
+ lambda _: jax.random.randint(layer_select_rng, (), 1, num_layers, dtype=jnp.int32),
24
+ lambda _: num_layers,
25
+ operand=None)
26
+ mixing_cond = jnp.broadcast_to(layer_idx < mixing_cutoff, w_latent1.shape)
27
+ w_latent = jnp.where(mixing_cond, w_latent1, w_latent2)
28
+
29
+ image_gen = state_G.apply_synthesis({'params': params['synthesis'], 'noise_consts': state_G.noise_consts},
30
+ w_latent,
31
+ rng=synth_rng)
32
+
33
+ fake_logits = state_D.apply_fn(state_D.params, image_gen, batch['label'])
34
+ loss = jnp.mean(jax.nn.softplus(-fake_logits))
35
+ return loss, (fake_logits, image_gen, new_state_G)
36
+
37
+ dynamic_scale = state_G.dynamic_scale_main
38
+
39
+ if dynamic_scale:
40
+ grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True, axis_name='batch')
41
+ dynamic_scale, is_fin, aux, grads = grad_fn(state_G.params)
42
+ else:
43
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
44
+ aux, grads = grad_fn(state_G.params)
45
+ grads = jax.lax.pmean(grads, axis_name='batch')
46
+
47
+ loss = aux[0]
48
+ _, image_gen, new_state = aux[1]
49
+ metrics['G_loss'] = loss
50
+ metrics['image_gen'] = image_gen
51
+
52
+ new_state_G = state_G.apply_gradients(grads=grads, moving_stats=new_state['moving_stats'])
53
+
54
+ if dynamic_scale:
55
+ new_state_G = new_state_G.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
56
+ new_state_G.opt_state,
57
+ state_G.opt_state),
58
+ params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
59
+ new_state_G.params,
60
+ state_G.params))
61
+ metrics['G_scale'] = dynamic_scale.scale
62
+
63
+ return new_state_G, metrics
64
+
65
+
66
+ def regul_step_G(state_G, batch, z_latent, pl_noise, pl_mean, metrics, config, rng):
67
+
68
+ def loss_fn(params):
69
+ w_latent, new_state_G = state_G.apply_mapping({'params': params['mapping'], 'moving_stats': state_G.moving_stats},
70
+ z_latent,
71
+ batch['label'],
72
+ mutable=['moving_stats'])
73
+
74
+ pl_grads = jax.grad(lambda *args: jnp.sum(state_G.apply_synthesis(*args) * pl_noise), argnums=1)({'params': params['synthesis'],
75
+ 'noise_consts': state_G.noise_consts},
76
+ w_latent,
77
+ 'random',
78
+ rng)
79
+ pl_lengths = jnp.sqrt(jnp.mean(jnp.sum(jnp.square(pl_grads), axis=2), axis=1))
80
+ pl_mean_new = pl_mean + config.pl_decay * (jnp.mean(pl_lengths) - pl_mean)
81
+ pl_penalty = jnp.square(pl_lengths - pl_mean_new) * config.pl_weight
82
+ loss = jnp.mean(pl_penalty) * config.G_reg_interval
83
+
84
+ return loss, pl_mean_new
85
+
86
+ dynamic_scale = state_G.dynamic_scale_reg
87
+
88
+ if dynamic_scale:
89
+ grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
90
+ dynamic_scale, is_fin, aux, grads = grad_fn(state_G.params)
91
+ else:
92
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
93
+ aux, grads = grad_fn(state_G.params)
94
+ grads = jax.lax.pmean(grads, axis_name='batch')
95
+
96
+ loss = aux[0]
97
+ pl_mean_new = aux[1]
98
+
99
+ metrics['G_regul_loss'] = loss
100
+ new_state_G = state_G.apply_gradients(grads=grads)
101
+
102
+ if dynamic_scale:
103
+ new_state_G = new_state_G.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
104
+ new_state_G.opt_state,
105
+ state_G.opt_state),
106
+ params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
107
+ new_state_G.params,
108
+ state_G.params))
109
+ metrics['G_regul_scale'] = dynamic_scale.scale
110
+
111
+ return new_state_G, metrics, pl_mean_new
112
+
113
+
114
+ def main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rng):
115
+
116
+ def loss_fn(params):
117
+ w_latent1 = state_G.apply_mapping({'params': state_G.params['mapping'], 'moving_stats': state_G.moving_stats},
118
+ z_latent1,
119
+ batch['label'],
120
+ train=False)
121
+
122
+ w_latent2 = state_G.apply_mapping({'params': state_G.params['mapping'], 'moving_stats': state_G.moving_stats},
123
+ z_latent2,
124
+ batch['label'],
125
+ train=False)
126
+
127
+ # style mixing
128
+ cutoff_rng, layer_select_rng, synth_rng = jax.random.split(rng, num=3)
129
+ num_layers = w_latent1.shape[1]
130
+ layer_idx = jnp.arange(num_layers)[jnp.newaxis, :, jnp.newaxis]
131
+ mixing_cutoff = jax.lax.cond(jax.random.uniform(cutoff_rng, (), minval=0.0, maxval=1.0) < mixing_prob,
132
+ lambda _: jax.random.randint(layer_select_rng, (), 1, num_layers, dtype=jnp.int32),
133
+ lambda _: num_layers,
134
+ operand=None)
135
+ mixing_cond = jnp.broadcast_to(layer_idx < mixing_cutoff, w_latent1.shape)
136
+ w_latent = jnp.where(mixing_cond, w_latent1, w_latent2)
137
+
138
+ image_gen = state_G.apply_synthesis({'params': state_G.params['synthesis'], 'noise_consts': state_G.noise_consts},
139
+ w_latent,
140
+ rng=synth_rng)
141
+
142
+ fake_logits = state_D.apply_fn(params, image_gen, batch['label'])
143
+ real_logits = state_D.apply_fn(params, batch['image'], batch['label'])
144
+
145
+ loss_fake = jax.nn.softplus(fake_logits)
146
+ loss_real = jax.nn.softplus(-real_logits)
147
+ loss = jnp.mean(loss_fake + loss_real)
148
+
149
+ return loss, (fake_logits, real_logits)
150
+
151
+ dynamic_scale = state_D.dynamic_scale_main
152
+
153
+ if dynamic_scale:
154
+ grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
155
+ dynamic_scale, is_fin, aux, grads = grad_fn(state_D.params)
156
+ else:
157
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
158
+ aux, grads = grad_fn(state_D.params)
159
+ grads = jax.lax.pmean(grads, axis_name='batch')
160
+
161
+ loss = aux[0]
162
+ fake_logits, real_logits = aux[1]
163
+ metrics['D_loss'] = loss
164
+ metrics['fake_logits'] = jnp.mean(fake_logits)
165
+ metrics['real_logits'] = jnp.mean(real_logits)
166
+
167
+ new_state_D = state_D.apply_gradients(grads=grads)
168
+
169
+ if dynamic_scale:
170
+ new_state_D = new_state_D.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
171
+ new_state_D.opt_state,
172
+ state_D.opt_state),
173
+ params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
174
+ new_state_D.params,
175
+ state_D.params))
176
+ metrics['D_scale'] = dynamic_scale.scale
177
+
178
+ return new_state_D, metrics
179
+
180
+
181
+ def regul_step_D(state_D, batch, metrics, config):
182
+
183
+ def loss_fn(params):
184
+ r1_grads = jax.grad(lambda *args: jnp.sum(state_D.apply_fn(*args)), argnums=1)(params, batch['image'], batch['label'])
185
+ r1_penalty = jnp.sum(jnp.square(r1_grads), axis=(1, 2, 3)) * (config.r1_gamma / 2) * config.D_reg_interval
186
+ loss = jnp.mean(r1_penalty)
187
+ return loss, None
188
+
189
+ dynamic_scale = state_D.dynamic_scale_reg
190
+
191
+ if dynamic_scale:
192
+ grad_fn = dynamic_scale.value_and_grad(loss_fn, has_aux=True)
193
+ dynamic_scale, is_fin, aux, grads = grad_fn(state_D.params)
194
+ else:
195
+ grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
196
+ aux, grads = grad_fn(state_D.params)
197
+ grads = jax.lax.pmean(grads, axis_name='batch')
198
+
199
+ loss = aux[0]
200
+ metrics['D_regul_loss'] = loss
201
+
202
+ new_state_D = state_D.apply_gradients(grads=grads)
203
+
204
+ if dynamic_scale:
205
+ new_state_D = new_state_D.replace(opt_state=jax.tree_multimap(functools.partial(jnp.where, is_fin),
206
+ new_state_D.opt_state,
207
+ state_D.opt_state),
208
+ params=jax.tree_multimap(functools.partial(jnp.where, is_fin),
209
+ new_state_D.params,
210
+ state_D.params))
211
+ metrics['D_regul_scale'] = dynamic_scale.scale
212
+
213
+ return new_state_D, metrics
214
+
215
+
216
+ def eval_step_G(generator, params, z_latent, labels, truncation):
217
+ image_gen = generator.apply(params, z_latent, labels, truncation_psi=truncation, train=False, noise_mode='const')
218
+ return image_gen
219
+
training_utils.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import jax
2
+ import jax.numpy as jnp
3
+ from jaxlib.xla_extension import DeviceArray
4
+ import flax
5
+ from flax.optim import dynamic_scale as dynamic_scale_lib
6
+ from flax.core import frozen_dict
7
+ from flax.training import train_state
8
+ from flax import struct
9
+ import numpy as np
10
+ from PIL import Image
11
+ from urllib.request import Request, urlopen
12
+ import urllib.error
13
+ from typing import Any, Callable
14
+
15
+
16
+ def sync_moving_stats(state):
17
+ """
18
+ Sync moving statistics across devices.
19
+
20
+ Args:
21
+ state (train_state.TrainState): Training state.
22
+
23
+ Returns:
24
+ (train_state.TrainState): Updated training state.
25
+ """
26
+ cross_replica_mean = jax.pmap(lambda x: jax.lax.pmean(x, 'x'), 'x')
27
+ return state.replace(moving_stats=cross_replica_mean(state.moving_stats))
28
+
29
+
30
+ def update_generator_ema(state_G, params_ema_G, config, ema_beta=None):
31
+ """
32
+ Update exponentially moving average of the generator weights.
33
+ Moving stats and noise constants will be copied over.
34
+
35
+ Args:
36
+ state_G (train_state.TrainState): Generator state.
37
+ params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
38
+ config (Any): Config object.
39
+ ema_beta (float): Beta parameter of the ema. If None, will be computed
40
+ from 'ema_nimg' and 'batch_size'.
41
+
42
+ Returns:
43
+ (frozen_dict.FrozenDict): Updates parameters of the ema generator.
44
+ """
45
+ def _update_ema(src, trg, beta):
46
+ for name, src_child in src.items():
47
+ if isinstance(src_child, DeviceArray):
48
+ trg[name] = src[name] + ema_beta * (trg[name] - src[name])
49
+ else:
50
+ _update_ema(src_child, trg[name], beta)
51
+
52
+ if ema_beta is None:
53
+ ema_nimg = config.ema_kimg * 1000
54
+ ema_beta = 0.5 ** (config.batch_size / max(ema_nimg, 1e-8))
55
+
56
+ params_ema_G = params_ema_G.unfreeze()
57
+
58
+ # Copy over moving stats
59
+ params_ema_G['moving_stats']['mapping_network'] = state_G.moving_stats
60
+ params_ema_G['noise_consts']['synthesis_network'] = state_G.noise_consts
61
+
62
+ # Update exponentially moving average of the trainable parameters
63
+ _update_ema(state_G.params['mapping'], params_ema_G['params']['mapping_network'], ema_beta)
64
+ _update_ema(state_G.params['synthesis'], params_ema_G['params']['synthesis_network'], ema_beta)
65
+
66
+ params_ema_G = frozen_dict.freeze(params_ema_G)
67
+ return params_ema_G
68
+
69
+
70
+ class TrainStateG(train_state.TrainState):
71
+ """
72
+ Generator train state for a single Optax optimizer.
73
+
74
+ Attributes:
75
+ apply_mapping (Callable): Apply function of the Mapping Network.
76
+ apply_synthesis (Callable): Apply function of the Synthesis Network.
77
+ dynamic_scale (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients.
78
+ epoch (int): Current epoch.
79
+ moving_stats (Any): Moving average of the latent W.
80
+ noise_consts (Any): Noise constants from synthesis layers.
81
+ """
82
+ apply_mapping: Callable = struct.field(pytree_node=False)
83
+ apply_synthesis: Callable = struct.field(pytree_node=False)
84
+ dynamic_scale_main: dynamic_scale_lib.DynamicScale
85
+ dynamic_scale_reg: dynamic_scale_lib.DynamicScale
86
+ epoch: int
87
+ moving_stats: Any=None
88
+ noise_consts: Any=None
89
+
90
+
91
+ class TrainStateD(train_state.TrainState):
92
+ """
93
+ Discriminator train state for a single Optax optimizer.
94
+
95
+ Attributes:
96
+ dynamic_scale (dynamic_scale_lib.DynamicScale): Dynamic loss scaling for mixed precision gradients.
97
+ epoch (int): Current epoch.
98
+ """
99
+ dynamic_scale_main: dynamic_scale_lib.DynamicScale
100
+ dynamic_scale_reg: dynamic_scale_lib.DynamicScale
101
+ epoch: int
102
+
103
+
104
+ def get_training_snapshot(image_real, image_gen, max_num=10):
105
+ """
106
+ Creates a snapshot of generated images and real images.
107
+
108
+ Args:
109
+ images_real (DeviceArray): Batch of real images, shape [B, H, W, C].
110
+ images_gen (DeviceArray): Batch of generated images, shape [B, H, W, C].
111
+ max_num (int): Maximum number of images used for snapshot.
112
+
113
+ Returns:
114
+ (PIL.Image): Training snapshot. Top row: generated images, bottom row: real images.
115
+ """
116
+ if image_real.shape[0] > max_num:
117
+ image_real = image_real[:max_num]
118
+ if image_gen.shape[0] > max_num:
119
+ image_gen = image_gen[:max_num]
120
+
121
+ image_real = jnp.split(image_real, image_real.shape[0], axis=0)
122
+ image_gen = jnp.split(image_gen, image_gen.shape[0], axis=0)
123
+
124
+ image_real = [jnp.squeeze(x, axis=0) for x in image_real]
125
+ image_gen = [jnp.squeeze(x, axis=0) for x in image_gen]
126
+
127
+ image_real = jnp.concatenate(image_real, axis=1)
128
+ image_gen = jnp.concatenate(image_gen, axis=1)
129
+
130
+ image_gen = (image_gen - np.min(image_gen)) / (np.max(image_gen) - np.min(image_gen))
131
+ image_real = (image_real - np.min(image_real)) / (np.max(image_real) - np.min(image_real))
132
+ image = jnp.concatenate((image_gen, image_real), axis=0)
133
+
134
+ image = np.uint8(image * 255)
135
+ if image.shape[-1] == 1:
136
+ image = np.repeat(image, 3, axis=-1)
137
+ return Image.fromarray(image)
138
+
139
+
140
+ def get_eval_snapshot(image, max_num=10):
141
+ """
142
+ Creates a snapshot of generated images.
143
+
144
+ Args:
145
+ image (DeviceArray): Generated images, shape [B, H, W, C].
146
+
147
+ Returns:
148
+ (PIL.Image): Eval snapshot.
149
+ """
150
+ if image.shape[0] > max_num:
151
+ image = image[:max_num]
152
+
153
+ image = jnp.split(image, image.shape[0], axis=0)
154
+ image = [jnp.squeeze(x, axis=0) for x in image]
155
+ image = jnp.concatenate(image, axis=1)
156
+ image = (image - np.min(image)) / (np.max(image) - np.min(image))
157
+ image = np.uint8(image * 255)
158
+ if image.shape[-1] == 1:
159
+ image = np.repeat(image, 3, axis=-1)
160
+ return Image.fromarray(image)
161
+
162
+
163
+ def get_vm_name():
164
+ gcp_metadata_url = "http://metadata.google.internal/computeMetadata/v1/instance/attributes/instance-id"
165
+ req = Request(gcp_metadata_url)
166
+ req.add_header('Metadata-Flavor', 'Google')
167
+ instance_id = None
168
+ try:
169
+ with urlopen(req) as url:
170
+ instance_id = url.read().decode()
171
+ except urllib.error.URLError:
172
+ # metadata.google.internal not reachable: use dev
173
+ pass
174
+ return instance_id