File size: 3,042 Bytes
81170fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import flax
import dill as pickle
import os
import builtins
from jax._src.lib import xla_client
import tensorflow as tf


# HACK: expose xla_client's bfloat16 type as `builtins.bfloat16`.
# Pickled bfloat16 values report `builtins` as their module, so without this
# alias deserialization fails to resolve the type — TODO confirm against
# https://github.com/google/jax/issues/8505
builtins.bfloat16 = xla_client.bfloat16


def pickle_dump(obj, filename):
    """Serialize ``obj`` with dill and write the resulting bytes to ``filename``.

    Args:
        obj: Any picklable object.
        filename (str): Destination path (any scheme tf.io.gfile supports).
    """
    payload = pickle.dumps(obj)
    with tf.io.gfile.GFile(filename, "wb") as fout:
        fout.write(payload)


def pickle_load(filename):
    """Read ``filename`` and deserialize its contents with dill.

    Args:
        filename (str): Path to a file written by :func:`pickle_dump`.

    Returns:
        The deserialized object.
    """
    with tf.io.gfile.GFile(filename, 'rb') as fin:
        raw = fin.read()
    return pickle.loads(raw)


def save_checkpoint(ckpt_dir, state_G, state_D, params_ema_G, pl_mean, config, step, epoch, fid_score=None, keep=2):
    """
    Saves a checkpoint and prunes old checkpoints down to ``keep``.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.
        state_G (train_state.TrainState): Generator state.
        state_D (train_state.TrainState): Discriminator state.
        params_ema_G (frozen_dict.FrozenDict): Parameters of the ema generator.
        pl_mean (array): Moving average of the path length (generator regularization).
        config (argparse.Namespace): Configuration.
        step (int): Current step.
        epoch (int): Current epoch.
        fid_score (float): FID score corresponding to the checkpoint.
        keep (int): Number of checkpoints to keep.
    """
    # Unreplicate device-replicated pytrees so a single copy is serialized.
    state_dict = {'state_G': flax.jax_utils.unreplicate(state_G),
                  'state_D': flax.jax_utils.unreplicate(state_D),
                  'params_ema_G': params_ema_G,
                  'pl_mean': flax.jax_utils.unreplicate(pl_mean),
                  'config': config,
                  'fid_score': fid_score,
                  'step': step,
                  'epoch': epoch}

    pickle_dump(state_dict, os.path.join(ckpt_dir, f'ckpt_{step}.pickle'))

    # Prune oldest checkpoints. The original removed only a single file per
    # call, so a directory that already held more than keep+1 checkpoints
    # never shrank back to `keep`; remove as many as needed instead.
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if len(ckpts) > keep:
        modified_times = {ckpt: tf.io.gfile.stat(ckpt).mtime_nsec
                          for ckpt in ckpts}
        oldest_first = sorted(modified_times, key=modified_times.get)
        for stale_ckpt in oldest_first[:len(ckpts) - keep]:
            tf.io.gfile.remove(stale_ckpt)


def load_checkpoint(filename):
    """
    Loads a checkpoint previously written by :func:`save_checkpoint`.

    Args:
        filename (str): Path to the checkpoint file.

    Returns:
        (dict): Checkpoint.
    """
    return pickle_load(filename)


def get_latest_checkpoint(ckpt_dir):
    """
    Returns the path of the most recently modified checkpoint.

    Args:
        ckpt_dir (str): Path to the directory, where checkpoints are saved.

    Returns:
        (str): Path to the latest checkpoint, or None if none exist.
    """
    ckpts = tf.io.gfile.glob(os.path.join(ckpt_dir, '*.pickle'))
    if not ckpts:
        return None

    # Map each checkpoint to its modification time (nanoseconds) and pick
    # the newest; a stable sort keeps tie-breaking identical to iteration order.
    modified_times = {ckpt: tf.io.gfile.stat(ckpt).mtime_nsec for ckpt in ckpts}
    return sorted(ckpts, key=modified_times.get)[-1]