from ditk import logging
import signal
import sys
import traceback
from typing import Callable
import torch
import torch.utils.data  # torch1.1.0 compatibility
from ding.utils import read_file, save_file

logger = logging.getLogger('default_logger')


def build_checkpoint_helper(cfg):
    """
    Overview:
        Use config to build checkpoint helper. The ``cfg`` argument is currently unused.
    Arguments:
        - cfg (:obj:`dict`): ckpt_helper config
    Returns:
        - (:obj:`CheckpointHelper`): checkpoint_helper created by this function
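    Examples:
        >>> # Illustrative usage; an empty dict is enough since ``cfg`` is not used yet.
        >>> helper = build_checkpoint_helper({})
        >>> isinstance(helper, CheckpointHelper)
        True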
    """
    return CheckpointHelper()


class CheckpointHelper:
    """
    Overview:
        Help to save or load checkpoint by given args.
    Interfaces:
        ``__init__``, ``save``, ``load``, ``_remove_prefix``, ``_add_prefix``, ``_load_matched_model_state_dict``
    """

    def __init__(self):
        pass

    def _remove_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
        """
        Overview:
            Remove prefix in state_dict
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix (:obj:`str`): this prefix will be removed from keys
        Returns:
            - new_state_dict (:obj:`dict`): new state_dict after removing prefix
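        Examples:
            >>> # Illustrative sketch with a toy state_dict.
            >>> helper = CheckpointHelper()
            >>> sd = {'module.linear.weight': torch.zeros(2, 2)}
            >>> list(helper._remove_prefix(sd).keys())
            ['linear.weight']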
        """
        new_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith(prefix):
                new_k = k[len(prefix):]  # only strip the leading prefix, keep any inner occurrences
            else:
                new_k = k
            new_state_dict[new_k] = v
        return new_state_dict

    def _add_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
        """
        Overview:
            Add prefix in state_dict
        Arguments:
            - state_dict (:obj:`dict`): model's state_dict
            - prefix (:obj:`str`): this prefix will be added to keys
        Returns:
            - (:obj:`dict`): new state_dict after adding prefix
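        Examples:
            >>> # Illustrative sketch with a toy state_dict.
            >>> helper = CheckpointHelper()
            >>> sd = {'linear.weight': torch.zeros(2, 2)}
            >>> list(helper._add_prefix(sd).keys())
            ['module.linear.weight']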
        """
        return {prefix + k: v for k, v in state_dict.items()}

    def save(
            self,
            path: str,
            model: torch.nn.Module,
            optimizer: torch.optim.Optimizer = None,
            last_iter: 'CountVar' = None,  # noqa
            last_epoch: 'CountVar' = None,  # noqa
            last_frame: 'CountVar' = None,  # noqa
            dataset: torch.utils.data.Dataset = None,
            collector_info: torch.nn.Module = None,
            prefix_op: str = None,
            prefix: str = None,
    ) -> None:
        """
        Overview:
            Save checkpoint by given args
        Arguments:
            - path (:obj:`str`): the path of saving checkpoint
            - model (:obj:`torch.nn.Module`): model to be saved
            - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
            - last_iter (:obj:`CountVar`): iter num, default None
            - last_epoch (:obj:`CountVar`): epoch num, default None
            - last_frame (:obj:`CountVar`): frame num, default None
            - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be a replay dataset
            - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, saves collector info
            - prefix_op (:obj:`str`): should be one of ['remove', 'add'], the operation applied to state_dict keys
            - prefix (:obj:`str`): prefix to be processed on state_dict
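        Examples:
            >>> # Illustrative sketch, assuming ``model`` is any ``torch.nn.Module`` and
            >>> # ``optimizer`` is a matching ``torch.optim.Optimizer``.
            >>> helper = CheckpointHelper()
            >>> helper.save('ckpt_example.pth.tar', model, optimizer=optimizer, last_iter=CountVar(init_val=0))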
        """
        checkpoint = {}
        model = model.state_dict()
        if prefix_op is not None:  # remove or add prefix to model.keys()
            prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
            if prefix_op not in prefix_func.keys():
                raise KeyError('invalid prefix_op:{}'.format(prefix_op))
            else:
                model = prefix_func[prefix_op](model, prefix)
        checkpoint['model'] = model

        if optimizer is not None:  # save optimizer
            assert (last_iter is not None or last_epoch is not None)
            if last_iter is not None:
                checkpoint['last_iter'] = last_iter.val
            if last_epoch is not None:
                checkpoint['last_epoch'] = last_epoch.val
            if last_frame is not None:
                checkpoint['last_frame'] = last_frame.val
            checkpoint['optimizer'] = optimizer.state_dict()

        if dataset is not None:
            checkpoint['dataset'] = dataset.state_dict()
        if collector_info is not None:
            checkpoint['collector_info'] = collector_info.state_dict()
        save_file(path, checkpoint)
        logger.info('save checkpoint in {}'.format(path))

    def _load_matched_model_state_dict(self, model: torch.nn.Module, ckpt_state_dict: dict) -> None:
        """
        Overview:
            Load the matched part of the model state_dict, and log mismatched keys between model and checkpoint
        Arguments:
            - model (:obj:`torch.nn.Module`): model
            - ckpt_state_dict (:obj:`dict`): checkpoint's state_dict
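        Examples:
            >>> # Illustrative sketch, assuming ``model`` is any ``torch.nn.Module``;
            >>> # only keys whose names and shapes both match are actually loaded.
            >>> helper = CheckpointHelper()
            >>> helper._load_matched_model_state_dict(model, torch.load('ckpt_example.pth.tar')['model'])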
        """
        assert isinstance(model, torch.nn.Module)
        diff = {'miss_keys': [], 'redundant_keys': [], 'mismatch_shape_keys': []}
        model_state_dict = model.state_dict()
        model_keys = set(model_state_dict.keys())
        ckpt_keys = set(ckpt_state_dict.keys())
        diff['miss_keys'] = model_keys - ckpt_keys
        diff['redundant_keys'] = ckpt_keys - model_keys

        intersection_keys = model_keys.intersection(ckpt_keys)
        valid_keys = []
        for k in intersection_keys:
            if model_state_dict[k].shape == ckpt_state_dict[k].shape:
                valid_keys.append(k)
            else:
                diff['mismatch_shape_keys'].append(
                    '{}\tmodel_shape: {}\tckpt_shape: {}'.format(
                        k, model_state_dict[k].shape, ckpt_state_dict[k].shape
                    )
                )
        valid_ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if k in valid_keys}
        model.load_state_dict(valid_ckpt_state_dict, strict=False)

        for n, keys in diff.items():
            for k in keys:
                logger.info('{}: {}'.format(n, k))

    def load(
        self,
        load_path: str,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer = None,
        last_iter: 'CountVar' = None,  # noqa
        last_epoch: 'CountVar' = None,  # noqa
        last_frame: 'CountVar' = None,  # noqa
        lr_scheduler: 'Scheduler' = None,  # noqa
        dataset: torch.utils.data.Dataset = None,
        collector_info: torch.nn.Module = None,
        prefix_op: str = None,
        prefix: str = None,
        strict: bool = True,
        logger_prefix: str = '',
        state_dict_mask: list = [],
    ):
        """
        Overview:
            Load checkpoint by given path
        Arguments:
            - load_path (:obj:`str`): checkpoint's path
            - model (:obj:`torch.nn.Module`): model definition
            - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
            - last_iter (:obj:`CountVar`): iter num, default None
            - last_epoch (:obj:`CountVar`): epoch num, default None
            - last_frame (:obj:`CountVar`): frame num, default None
            - lr_scheduler (:obj:`Scheduler`): lr_scheduler obj
            - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be a replay dataset
            - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, saves collector info
            - prefix_op (:obj:`str`): should be one of ['remove', 'add'], the operation applied to state_dict keys
            - prefix (:obj:`str`): prefix to be processed on state_dict
            - strict (:obj:`bool`): the ``strict`` argument passed to ``model.load_state_dict``
            - logger_prefix (:obj:`str`): prefix of logger
            - state_dict_mask (:obj:`list`): A list containing state_dict keys, \
                which shouldn't be loaded into the model (after the prefix op)

        .. note::

            The checkpoint loaded from load_path is a dict, whose format is like {'model': OrderedDict(), ...}
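
        Examples:
            >>> # Illustrative sketch, assuming ``ckpt_example.pth.tar`` was produced by ``save``
            >>> # and ``model`` shares the saved model's architecture.
            >>> helper = CheckpointHelper()
            >>> helper.load('ckpt_example.pth.tar', model, strict=False)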
        """
        # TODO save config
        # Note: load this way to reduce memory cost on the first GPU and stay compatible with CPU-only envs
        checkpoint = read_file(load_path)
        state_dict = checkpoint['model']
        if prefix_op is not None:
            prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
            if prefix_op not in prefix_func.keys():
                raise KeyError('invalid prefix_op:{}'.format(prefix_op))
            else:
                state_dict = prefix_func[prefix_op](state_dict, prefix)
        if len(state_dict_mask) > 0:
            if strict:
                logger.info(
                    logger_prefix +
                    '[Warning] non-empty state_dict_mask expects strict=False, but finds strict=True in input argument'
                )
                strict = False
            for m in state_dict_mask:
                state_dict_keys = list(state_dict.keys())
                for k in state_dict_keys:
                    if k.startswith(m):
                        state_dict.pop(k)  # ignore return value
        if strict:
            model.load_state_dict(state_dict, strict=True)
        else:
            self._load_matched_model_state_dict(model, state_dict)
        logger.info(logger_prefix + 'load model state_dict in {}'.format(load_path))

        if dataset is not None:
            if 'dataset' in checkpoint.keys():
                dataset.load_state_dict(checkpoint['dataset'])
                logger.info(logger_prefix + 'load online data in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "dataset not in checkpoint, ignore load procedure")

        if optimizer is not None:
            if 'optimizer' in checkpoint.keys():
                optimizer.load_state_dict(checkpoint['optimizer'])
                logger.info(logger_prefix + 'load optimizer in {}'.format(load_path))
            else:
                logger.info(logger_prefix + "optimizer not in checkpoint, ignore load procedure")

        if last_iter is not None:
            if 'last_iter' in checkpoint.keys():
                last_iter.update(checkpoint['last_iter'])
                logger.info(
                    logger_prefix + 'load last_iter in {}, current last_iter is {}'.format(load_path, last_iter.val)
                )
            else:
                logger.info(logger_prefix + "last_iter not in checkpoint, ignore load procedure")

        if collector_info is not None:
            collector_info.load_state_dict(checkpoint['collector_info'])
            logger.info(logger_prefix + 'load collector info in {}'.format(load_path))

        if lr_scheduler is not None:
            assert (last_iter is not None)
            raise NotImplementedError


class CountVar(object):
    """
    Overview:
        Number counter
    Interfaces:
        ``__init__``, ``update``, ``add``
    Properties:
        - val (:obj:`int`): the value of the counter
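    Examples:
        >>> counter = CountVar(init_val=0)
        >>> counter.add(5)
        >>> counter.update(10)
        >>> counter.val
        10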
    """

    def __init__(self, init_val: int) -> None:
        """
        Overview:
            Init the var counter
        Arguments:
            - init_val (:obj:`int`): the init value of the counter
        """

        self._val = init_val

    @property
    def val(self) -> int:
        """
        Overview:
            Get the var counter
        """

        return self._val

    def update(self, val: int) -> None:
        """
        Overview:
            Update the var counter
        Arguments:
            - val (:obj:`int`): the update value of the counter
        """
        self._val = val

    def add(self, add_num: int):
        """
        Overview:
            Add the number to counter
        Arguments:
            - add_num (:obj:`int`): the number added to the counter
        """
        self._val += add_num


def auto_checkpoint(func: Callable) -> Callable:
    """
    Overview:
        Create a wrapper for the given function. The wrapper calls the ``save_checkpoint`` method of the
        function's first argument whenever an exception or a registered exit signal occurs.
    Arguments:
        - func(:obj:`Callable`): the function to be wrapped
    Returns:
        - wrapper (:obj:`Callable`): the wrapped function
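    Examples:
        >>> # Illustrative sketch: the decorated callable's first argument must expose
        >>> # ``save_checkpoint(path)``, which is called on exceptions and exit signals.
        >>> class Runner:
        ...     def save_checkpoint(self, path):
        ...         print('save to {}'.format(path))
        ...     @auto_checkpoint
        ...     def run(self):
        ...         raise RuntimeError('illustrative failure')
        >>> Runner().run()  # saves 'ckpt_exception.pth.tar' and prints the traceback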
    """
    dead_signals = ['SIGILL', 'SIGINT', 'SIGKILL', 'SIGQUIT', 'SIGSEGV', 'SIGSTOP', 'SIGTERM', 'SIGBUS']
    all_signals = dead_signals + ['SIGUSR1']

    def register_signal_handler(handler):
        valid_sig = []
        invalid_sig = []
        for sig_name in all_signals:
            try:
                sig = getattr(signal, sig_name)
                signal.signal(sig, handler)
                valid_sig.append(sig_name)
            except Exception:
                # some signals (e.g. SIGKILL/SIGSTOP) cannot be registered on this platform
                invalid_sig.append(sig_name)
        logger.info('valid sig: ({})\ninvalid sig: ({})'.format(valid_sig, invalid_sig))

    def wrapper(*args, **kwargs):
        handle = args[0]
        assert (hasattr(handle, 'save_checkpoint'))

        def signal_handler(signal_num, frame):
            sig = signal.Signals(signal_num)
            logger.info("SIGNAL: {}({})".format(sig.name, sig.value))
            handle.save_checkpoint('ckpt_interrupt.pth.tar')
            sys.exit(1)

        register_signal_handler(signal_handler)
        try:
            return func(*args, **kwargs)
        except Exception:
            # save a checkpoint on any unhandled exception, then print the traceback
            handle.save_checkpoint('ckpt_exception.pth.tar')
            traceback.print_exc()

    return wrapper