import copy
import pathlib
from typing import Any, Dict, List, Optional

from loguru import logger
import torch
from torch import nn

import virtex.utils.distributed as dist


class CheckpointManager(object):
    r"""
    A helper class to periodically serialize models and other checkpointable
    objects (optimizers, LR schedulers etc., anything that implements a
    ``state_dict`` method) during training, and to optionally keep track of
    the best-performing checkpoint based on an observed metric.

    .. note::

        For :class:`~torch.nn.parallel.DistributedDataParallel` objects, the
        ``state_dict`` of the wrapped module is serialized.

    .. note::

        The observed metric for keeping the best checkpoint is assumed to be
        "higher is better"; flip its sign otherwise (e.g. pass ``-val_loss``).

    Parameters
    ----------
    serialization_dir: str
        Path to a directory to save checkpoints.
    keep_recent: int, optional (default = 200)
        Number of recent ``k`` checkpoints to keep on disk. Older checkpoints
        will be removed. Set to a very large value for keeping all checkpoints.
    checkpointables: Any
        Keyword arguments with any checkpointable objects, for example: model,
        optimizer, learning rate scheduler.

    Examples
    --------
    >>> model = torch.nn.Linear(10, 2)
    >>> optimizer = torch.optim.Adam(model.parameters())
    >>> ckpt_manager = CheckpointManager("/tmp", model=model, optimizer=optimizer)
    >>> num_epochs = 20
    >>> for epoch in range(num_epochs):
    ...     train(model)
    ...     val_loss = validate(model)
    ...     ckpt_manager.step(epoch, metric=-val_loss)
    """

    def __init__(
        self,
        serialization_dir: str = "/tmp",
        keep_recent: int = 200,
        **checkpointables: Any,
    ):
        self.serialization_dir = pathlib.Path(serialization_dir)
        self.keep_recent = keep_recent

        # Shallow copy, keeps references to tensors as original objects.
        self.checkpointables = copy.copy(checkpointables)

        # Initialize members to hold the state dict of the best checkpoint and
        # its performance. Start at ``-inf`` so that the first observed metric
        # (even a negative one, e.g. a negated validation loss) becomes best.
        self._best_metric: float = -float("inf")
        self._best_ckpt: Dict[str, Any] = {}

        # Keep epoch/iteration numbers of recently saved 'k' checkpoints.
        self._recent_iterations: List[int] = []

    def step(self, iteration: int, metric: Optional[float] = None):
        r"""
        Serialize checkpoint and update best checkpoint based on metric. Keys
        in serialized checkpoint match those in :attr:`checkpointables`.

        Parameters
        ----------
        iteration: int
            Current training iteration. Will be saved with other checkpointables.
        metric: float, optional (default = None)
            Observed metric (higher is better) for keeping track of the best
            checkpoint. If this is ``None``, the best checkpoint will not be
            recorded/updated.
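
        Examples
        --------
        A minimal sketch of a single call, assuming ``ckpt_manager`` and
        ``model`` were built as in the class docstring and that ``validate``
        is a hypothetical user function returning a validation loss (lower is
        better, hence the sign flip):

        >>> val_loss = validate(model)
        >>> ckpt_manager.step(iteration=1000, metric=-val_loss)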
        """

        checkpointable_state_dict: Dict[str, Any] = self._state_dict()

        # We also checkpoint current iteration.
        checkpointable_state_dict["iteration"] = iteration

        # Update the best checkpoint based on metric, if provided. Serialize it
        # immediately: the state dicts hold references to live parameter
        # tensors, so saving it on a later call would capture weights from a
        # different iteration.
        if metric is not None and metric > self._best_metric:
            self._best_metric = metric
            self._best_ckpt = copy.copy(checkpointable_state_dict)
            torch.save(
                self._best_ckpt, self.serialization_dir / "checkpoint_best.pth"
            )

        # Serialize checkpoint corresponding to current iteration.
        torch.save(
            checkpointable_state_dict,
            self.serialization_dir / f"checkpoint_{iteration}.pth",
        )

        # Remove the earliest checkpoint if too many are on disk.
        self._recent_iterations.append(iteration)
        if len(self._recent_iterations) > self.keep_recent:
            self.remove_earliest_checkpoint()

    def _state_dict(self):
        r"""Return a dict containing state dict of all checkpointables."""

        __state_dict: Dict[str, Any] = {}
        for key in self.checkpointables:
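            # Unwrap DistributedDataParallel so serialized keys do not carry a
            # ``module.`` prefix and can be loaded into a bare (non-DDP) model.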
            if isinstance(
                self.checkpointables[key], nn.parallel.DistributedDataParallel
            ):
                __state_dict[key] = self.checkpointables[key].module.state_dict()
            else:
                __state_dict[key] = self.checkpointables[key].state_dict()

        return __state_dict

    def remove_earliest_checkpoint(self):
        r"""Remove earliest serialized checkpoint from disk."""

        earliest_iteration = self._recent_iterations.pop(0)
        (self.serialization_dir / f"checkpoint_{earliest_iteration}.pth").unlink()

    def load(self, checkpoint_path: str):
        r"""
        Load a serialized checkpoint from a path. This method will try to find
        each of :attr:`checkpointables` in the file and load its state dict.
        Since our checkpointables are held as references, this method does not
        return them.

        Parameters
        ----------
        checkpoint_path: str
            Path to a checkpoint serialized by :meth:`step`.

        Returns
        -------
        int
            Iteration corresponding to the loaded checkpoint, useful for
            resuming training. This will be -1 for the best checkpoint, or if
            the iteration was not recorded in the file.
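
        Examples
        --------
        A sketch of resuming training; the checkpoint path and
        ``max_iterations`` are illustrative:

        >>> start_iteration = ckpt_manager.load("/tmp/checkpoint_1000.pth")
        >>> for iteration in range(start_iteration + 1, max_iterations):
        ...     ...  # continue training, calling ``ckpt_manager.step``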
        """

        # Each process will log a message after loading checkpoint.
        rank = dist.get_rank()

        logger.info(f"Rank {rank}: Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        iteration = checkpoint.pop("iteration", -1)

        # Keep flags for all checkpointables to log which ones were not loaded.
        is_loaded = {key: False for key in self.checkpointables}

        # Load each checkpointable from checkpoint.
        for key in checkpoint:
            if key in self.checkpointables:
                logger.info(f"Rank {rank}: Loading {key} from {checkpoint_path}")

                if isinstance(
                    self.checkpointables[key], nn.parallel.DistributedDataParallel
                ):
                    self.checkpointables[key].module.load_state_dict(checkpoint[key])
                else:
                    self.checkpointables[key].load_state_dict(checkpoint[key])

                is_loaded[key] = True
            else:
                logger.info(f"Rank {rank}: {key} not found in `checkpointables`.")

        not_loaded: List[str] = [key for key in is_loaded if not is_loaded[key]]
        if len(not_loaded) > 0:
            logger.info(
                f"Rank {rank}: Checkpointables not found in file: {not_loaded}"
            )
        return iteration
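

if __name__ == "__main__":
    # A minimal smoke test of the save/load roundtrip, assuming a single
    # (non-distributed) process; it is illustrative, not part of the library.
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        model = nn.Linear(10, 2)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
        manager = CheckpointManager(tmp_dir, model=model, optimizer=optimizer)

        for iteration in range(1, 4):
            # Pretend the metric improves every iteration (higher is better).
            manager.step(iteration, metric=float(iteration))

        # Reload the latest checkpoint into the same objects and verify that
        # the stored iteration number round-trips.
        loaded_iteration = manager.load(f"{tmp_dir}/checkpoint_3.pth")
        assert loaded_iteration == 3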