File size: 18,849 Bytes
81170fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
import jax
import jax.numpy as jnp
import flax
from flax.optim import dynamic_scale as dynamic_scale_lib
from flax.core import frozen_dict
import optax
import numpy as np
import functools
import wandb
import time

import stylegan2
import data_pipeline
import checkpoint
import training_utils
import training_steps
from fid import FID

import logging

logger = logging.getLogger(__name__)


def tree_shape(item):
    """Return a pytree with the same structure as ``item`` holding each leaf's shape.

    Useful for inspecting parameter/state layouts, e.g. ``tree_shape(params_G)``.
    Uses ``jax.tree_util.tree_map`` rather than the deprecated ``jax.tree_map``
    alias, which newer JAX releases no longer provide.

    Args:
        item: Any pytree whose leaves expose a ``.shape`` attribute.

    Returns:
        A pytree of the same structure with every leaf replaced by its shape tuple.
    """
    return jax.tree_util.tree_map(lambda leaf: leaf.shape, item)


def train_and_evaluate(config):
    """Build the StyleGAN2 models, then train and periodically evaluate them.

    Steps performed:
      1. Load the training data via ``data_pipeline.get_data``.
      2. Pick the compute dtype (fp16 / bf16 / fp32) and, for fp16 on GPU,
         create dynamic loss scalers.
      3. Initialize the mapping, synthesis and discriminator networks plus a
         separate EMA generator, and wrap them in train states using Adam.
      4. Optionally resume from the latest checkpoint in ``config.ckpt_dir``
         (when ``config.resume_run_id`` is set) or warm-start model state from
         ``config.load_from_pkl``.
      5. Replicate the states across local devices and run the pmap-ed
         discriminator/generator updates until ``config.num_epochs``. The
         regularization steps run every ``config.D_reg_interval`` /
         ``config.G_reg_interval`` steps. The EMA generator is updated after
         every step.

    wandb logging, FID evaluation and checkpoint writes happen on process 0
    only. With ``config.disable_fid`` a checkpoint is written every
    ``config.save_every`` steps. Otherwise FID is computed every
    ``config.eval_fid_every`` steps and a checkpoint is written only when the
    score improves.

    Args:
        config: Run configuration; its fields are read as attributes. When
            resuming, it is replaced by the config stored in the checkpoint.
    """
    num_devices = jax.device_count()  # devices across all processes
    num_local_devices = jax.local_device_count()  # devices attached to this process
    num_workers = jax.process_count()

    # --------------------------------------
    # Data
    # --------------------------------------
    ds_train, dataset_info = data_pipeline.get_data(data_dir=config.data_dir,
                                                    img_size=config.resolution,
                                                    img_channels=config.img_channels,
                                                    num_classes=config.c_dim,
                                                    num_local_devices=num_local_devices,
                                                    batch_size=config.batch_size)

    # --------------------------------------
    # Seeding and Precision
    # --------------------------------------
    rng = jax.random.PRNGKey(config.random_seed)

    if config.mixed_precision:
        dtype = jnp.float16
    elif config.bf16:
        dtype = jnp.bfloat16
    else:
        dtype = jnp.float32
    logger.info(f'Running on dtype {dtype}')

    platform = jax.local_devices()[0].platform
    if config.mixed_precision and platform == 'gpu':
        # fp16 on GPU: one dynamic loss scaler per update step (main/reg for G and D).
        # clip_conv and num_fp16_res are forwarded to the networks built below.
        dynamic_scale_G_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_main = dynamic_scale_lib.DynamicScale()
        dynamic_scale_G_reg = dynamic_scale_lib.DynamicScale()
        dynamic_scale_D_reg = dynamic_scale_lib.DynamicScale()
        clip_conv = 256
        num_fp16_res = 4
    else:
        dynamic_scale_G_main = None
        dynamic_scale_D_main = None
        dynamic_scale_G_reg = None
        dynamic_scale_D_reg = None
        clip_conv = None
        num_fp16_res = 0

    # --------------------------------------
    # Initialize Models
    # --------------------------------------
    logger.info('Initialize models...')

    rng, init_rng = jax.random.split(rng)

    # Number of w vectors; shared by the training mapping network and the EMA generator.
    num_ws = int(np.log2(config.resolution)) * 2 - 3

    # Generator initialization for training
    start_mn = time.time()
    logger.info("Creating MappingNetwork...")
    mapping_net = stylegan2.MappingNetwork(z_dim=config.z_dim,
                                           c_dim=config.c_dim,
                                           w_dim=config.w_dim,
                                           num_ws=num_ws,
                                           num_layers=8,
                                           dtype=dtype)

    mapping_net_vars = mapping_net.init(init_rng,
                                        jnp.ones((1, config.z_dim)),
                                        jnp.ones((1, config.c_dim)))

    mapping_net_params, moving_stats = mapping_net_vars['params'], mapping_net_vars['moving_stats']

    logger.info(f"MappingNetwork took {time.time() - start_mn:.2f}s")

    logger.info("Creating SynthesisNetwork...")
    start_sn = time.time()
    synthesis_net = stylegan2.SynthesisNetwork(resolution=config.resolution,
                                               num_channels=config.img_channels,
                                               w_dim=config.w_dim,
                                               fmap_base=config.fmap_base,
                                               num_fp16_res=num_fp16_res,
                                               clip_conv=clip_conv,
                                               dtype=dtype)

    synthesis_net_vars = synthesis_net.init(init_rng,
                                            jnp.ones((1, mapping_net.num_ws, config.w_dim)))
    synthesis_net_params, noise_consts = synthesis_net_vars['params'], synthesis_net_vars['noise_consts']

    logger.info(f"SynthesisNetwork took {time.time() - start_sn:.2f}s")

    params_G = frozen_dict.FrozenDict(
        {'mapping': mapping_net_params,
         'synthesis': synthesis_net_params}
    )

    # Discriminator initialization for training
    logger.info("Creating Discriminator...")
    start_d = time.time()
    discriminator = stylegan2.Discriminator(resolution=config.resolution,
                                            num_channels=config.img_channels,
                                            c_dim=config.c_dim,
                                            mbstd_group_size=config.mbstd_group_size,
                                            num_fp16_res=num_fp16_res,
                                            clip_conv=clip_conv,
                                            dtype=dtype)
    rng, init_rng = jax.random.split(rng)
    params_D = discriminator.init(init_rng,
                                  jnp.ones((1, config.resolution, config.resolution, config.img_channels)),
                                  jnp.ones((1, config.c_dim)))
    logger.info(f"Discriminator took {time.time() - start_d:.2f}s")

    # Exponential average Generator initialization
    logger.info("Creating Generator EMA...")
    start_g = time.time()
    generator_ema = stylegan2.Generator(resolution=config.resolution,
                                        num_channels=config.img_channels,
                                        z_dim=config.z_dim,
                                        c_dim=config.c_dim,
                                        w_dim=config.w_dim,
                                        num_ws=num_ws,
                                        num_mapping_layers=8,
                                        fmap_base=config.fmap_base,
                                        num_fp16_res=num_fp16_res,
                                        clip_conv=clip_conv,
                                        dtype=dtype)

    params_ema_G = generator_ema.init(init_rng,
                                      jnp.ones((1, config.z_dim)),
                                      jnp.ones((1, config.c_dim)))
    logger.info(f"Took {time.time() - start_g:.2f}s")

    # --------------------------------------
    # Initialize States and Optimizers
    # --------------------------------------
    logger.info('Initialize states...')
    tx_G = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)
    tx_D = optax.adam(learning_rate=config.learning_rate, b1=0.0, b2=0.99)

    state_G = training_utils.TrainStateG.create(apply_fn=None,
                                                apply_mapping=mapping_net.apply,
                                                apply_synthesis=synthesis_net.apply,
                                                params=params_G,
                                                moving_stats=moving_stats,
                                                noise_consts=noise_consts,
                                                tx=tx_G,
                                                dynamic_scale_main=dynamic_scale_G_main,
                                                dynamic_scale_reg=dynamic_scale_G_reg,
                                                epoch=0)

    state_D = training_utils.TrainStateD.create(apply_fn=discriminator.apply,
                                                params=params_D,
                                                tx=tx_D,
                                                dynamic_scale_main=dynamic_scale_D_main,
                                                dynamic_scale_reg=dynamic_scale_D_reg,
                                                epoch=0)

    # Copy over the parameters from the training generator to the ema generator
    params_ema_G = training_utils.update_generator_ema(state_G, params_ema_G, config, ema_beta=0)

    # Running mean of path length for path length regularization
    pl_mean = jnp.zeros((), dtype=dtype)

    step = 0
    epoch_offset = 0
    best_fid_score = np.inf
    ckpt_path = None

    if config.resume_run_id is not None:
        # Resume training from the latest checkpoint: counters, best FID, all
        # states and the stored config are restored.
        # NOTE(review): the data pipeline and networks above were built from the
        # config passed in, not the one restored here; presumably they must match.
        ckpt_path = checkpoint.get_latest_checkpoint(config.ckpt_dir)
        logger.info(f'Resume training from checkpoint: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        step = ckpt['step']
        epoch_offset = ckpt['epoch']
        best_fid_score = ckpt['fid_score']
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']
        config = ckpt['config']
    elif config.load_from_pkl is not None:
        # Warm-start a new run: only model state and pl_mean are loaded; the
        # step/epoch counters and best FID keep their fresh values.
        ckpt_path = config.load_from_pkl
        logger.info(f'Load model state from: {ckpt_path}')
        ckpt = checkpoint.load_checkpoint(ckpt_path)
        pl_mean = ckpt['pl_mean']
        state_G = ckpt['state_G']
        state_D = ckpt['state_D']
        params_ema_G = ckpt['params_ema_G']

    # Replicate states across devices
    pl_mean = flax.jax_utils.replicate(pl_mean)
    state_G = flax.jax_utils.replicate(state_G)
    state_D = flax.jax_utils.replicate(state_D)

    # --------------------------------------
    # Precompile train and eval steps
    # --------------------------------------
    logger.info('Precompile training steps...')
    p_main_step_G = jax.pmap(training_steps.main_step_G, axis_name='batch')
    p_regul_step_G = jax.pmap(functools.partial(training_steps.regul_step_G, config=config), axis_name='batch')

    p_main_step_D = jax.pmap(training_steps.main_step_D, axis_name='batch')
    p_regul_step_D = jax.pmap(functools.partial(training_steps.regul_step_D, config=config), axis_name='batch')

    # Averages pl_mean across devices at the end of every epoch. Built once here.
    # Wrapping a fresh lambda in jax.pmap every epoch would miss pmap's
    # compilation cache and retrace/recompile each time.
    p_sync_pl_mean = jax.pmap(lambda x: jax.lax.pmean(x, axis_name='batch'), axis_name='batch')

    # --------------------------------------
    # Training
    # --------------------------------------
    logger.info('Start training...')
    fid_metric = FID(generator_ema, ds_train, config)

    # Dict to collect training statistics / losses
    metrics = {}
    num_imgs_processed = 0
    effective_batch_size = config.batch_size * num_devices
    num_steps_per_epoch = dataset_info['num_examples'] // effective_batch_size

    # The style-mixing probability is constant during training, so replicate it once.
    mixing_prob = flax.jax_utils.replicate(config.mixing_prob)

    if config.wandb and jax.process_index() == 0:
        # do some more logging
        wandb.config.effective_batch_size = effective_batch_size
        wandb.config.num_steps_per_epoch = num_steps_per_epoch
        wandb.config.num_workers = num_workers
        wandb.config.device_count = num_devices
        wandb.config.num_examples = dataset_info['num_examples']
        wandb.config.vm_name = training_utils.get_vm_name()

    for epoch in range(epoch_offset, config.num_epochs):
        if config.wandb and jax.process_index() == 0:
            wandb.log({'training/epochs': epoch}, step=step)

        for batch in data_pipeline.prefetch(ds_train, config.num_prefetch):
            # Axis 0 of each batch is the local-device axis and axis 1 is the
            # per-device batch. Presumably images are (devices, batch, H, W, C).
            assert batch['image'].shape[1] == config.batch_size, f"Mismatched batch (batch size: {config.batch_size}, this batch: {batch['image'].shape[1]})"

            iteration_start_time = time.time()

            if config.c_dim == 0:
                # No labels in the dataset
                batch['label'] = None

            # Create two latent noise vectors and combine them for the style mixing regularization
            rng, key = jax.random.split(rng)
            z_latent1 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)
            rng, key = jax.random.split(rng)
            z_latent2 = jax.random.normal(key, (num_local_devices, config.batch_size, config.z_dim), dtype)

            # Split PRNGs across devices. Draw a fresh key first, because the
            # previous key already produced z_latent2 and must not be reused.
            rng, key = jax.random.split(rng)
            rkey = jax.random.split(key, num=num_local_devices)

            # NOTE(review): JAX dispatches device work asynchronously, so the
            # time_* deltas below may measure dispatch rather than execution time.

            # --------------------------------------
            # Update Discriminator
            # --------------------------------------
            time_d_start = time.time()
            state_D, metrics = p_main_step_D(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            time_d_end = time.time()
            if step % config.D_reg_interval == 0:
                state_D, metrics = p_regul_step_D(state_D, batch, metrics)

            # --------------------------------------
            # Update Generator
            # --------------------------------------
            time_g_start = time.time()
            state_G, metrics = p_main_step_G(state_G, state_D, batch, z_latent1, z_latent2, metrics, mixing_prob, rkey)
            if step % config.G_reg_interval == 0:
                # Path-length regularization noise, scaled by 1/sqrt(H*W).
                H, W = batch['image'].shape[-3], batch['image'].shape[-2]
                rng, key = jax.random.split(rng)
                pl_noise = jax.random.normal(key, batch['image'].shape, dtype=dtype) / np.sqrt(H * W)
                state_G, metrics, pl_mean = p_regul_step_G(state_G, batch, z_latent1, pl_noise, pl_mean, metrics,
                                                           rng=rkey)

            params_ema_G = training_utils.update_generator_ema(flax.jax_utils.unreplicate(state_G),
                                                               params_ema_G,
                                                               config)
            time_g_end = time.time()

            # --------------------------------------
            # Logging and Checkpointing
            # --------------------------------------
            if step % config.save_every == 0 and config.disable_fid:
                # If FID evaluation is disabled, a checkpoint will be saved every 'save_every' steps.
                if jax.process_index() == 0:
                    logger.info('Saving checkpoint...')
                    checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step,
                                               epoch)

            num_imgs_processed += effective_batch_size
            if step % config.eval_fid_every == 0 and not config.disable_fid:
                # If FID evaluation is enabled, only save a checkpoint if FID score is better.
                if jax.process_index() == 0:
                    logger.info('Computing FID...')
                    fid_score = fid_metric.compute_fid(params_ema_G).item()
                    if config.wandb:
                        wandb.log({'training/gen/fid': fid_score}, step=step)
                    logger.info(f'Computed FID: {fid_score:.2f}')
                    if fid_score < best_fid_score:
                        best_fid_score = fid_score
                        logger.info(f'New best FID score ({best_fid_score:.3f}). Saving checkpoint...')
                        ts = time.time()
                        checkpoint.save_checkpoint(config.ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=fid_score)
                        te = time.time()
                        logger.info(f'... successfully saved checkpoint in {(te-ts)/60:.1f}min')

            sec_per_kimg = (time.time() - iteration_start_time) / (effective_batch_size / 1000.0)
            time_taken_g = time_g_end - time_g_start
            time_taken_d = time_d_end - time_d_start
            time_taken_per_step = time.time() - iteration_start_time
            g_loss = jnp.mean(metrics['G_loss']).item()
            d_loss = jnp.mean(metrics['D_loss']).item()

            if config.wandb and jax.process_index() == 0:
                # wandb logging - happens every step
                wandb.log({'training/gen/loss': g_loss}, step=step, commit=False)
                wandb.log({'training/dis/loss': d_loss}, step=step, commit=False)
                wandb.log({'training/dis/fake_logits': jnp.mean(metrics['fake_logits']).item()}, step=step, commit=False)
                wandb.log({'training/dis/real_logits': jnp.mean(metrics['real_logits']).item()}, step=step, commit=False)
                wandb.log({'training/time_taken_g': time_taken_g, 'training/time_taken_d': time_taken_d}, step=step, commit=False)
                wandb.log({'training/time_taken_per_step': time_taken_per_step}, step=step, commit=False)
                wandb.log({'training/num_imgs_trained': num_imgs_processed}, step=step, commit=False)
                wandb.log({'training/sec_per_kimg': sec_per_kimg}, step=step)

            if step % config.log_every == 0:
                # console logging - happens every log_every steps
                # Guard the modulo: num_steps_per_epoch is 0 when the dataset
                # holds fewer examples than one global batch.
                step_in_epoch = step % num_steps_per_epoch if num_steps_per_epoch > 0 else step
                logger.info(f'Total steps: {step:>6,} - epoch {epoch:>3,}/{config.num_epochs} @ {step_in_epoch:>6,}/{num_steps_per_epoch:,} - G loss: {g_loss:.5f} - D loss: {d_loss:.5f} - sec/kimg: {sec_per_kimg:.2f}s - time per step: {time_taken_per_step:.3f}s')

            if step % config.generate_samples_every == 0 and config.wandb and jax.process_index() == 0:
                # Generate training images
                train_snapshot = training_utils.get_training_snapshot(
                    image_real=flax.jax_utils.unreplicate(batch['image']),
                    image_gen=flax.jax_utils.unreplicate(metrics['image_gen']),
                    max_num=10
                )
                wandb.log({'training/snapshot': wandb.Image(train_snapshot)}, commit=False, step=step)

                # Generate evaluation images
                labels = None if config.c_dim == 0 else batch['label'][0]
                image_gen_eval = training_steps.eval_step_G(
                    generator_ema, params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=1
                )
                image_gen_eval_trunc = training_steps.eval_step_G(
                    generator_ema,
                    params=params_ema_G,
                    z_latent=z_latent1[0],
                    labels=labels,
                    truncation=0.5
                )
                eval_snapshot = training_utils.get_eval_snapshot(image=image_gen_eval, max_num=10)
                eval_snapshot_trunc = training_utils.get_eval_snapshot(image=image_gen_eval_trunc, max_num=10)
                wandb.log({'eval/snapshot': wandb.Image(eval_snapshot)}, commit=False, step=step)
                wandb.log({'eval/snapshot_trunc': wandb.Image(eval_snapshot_trunc)}, step=step)

            step += 1

        # Sync moving stats across devices
        state_G = training_utils.sync_moving_stats(state_G)

        # Sync moving average of path length mean (Generator regularization)
        pl_mean = p_sync_pl_mean(pl_mean)