# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings

from annotator.uniformer.mmcv.fileio import FileClient
from ..dist_utils import allreduce_params, master_only
from .hook import HOOKS, Hook


@HOOKS.register_module()
class CheckpointHook(Hook):
| """Save checkpoints periodically. | |
| Args: | |
| interval (int): The saving period. If ``by_epoch=True``, interval | |
| indicates epochs, otherwise it indicates iterations. | |
| Default: -1, which means "never". | |
| by_epoch (bool): Saving checkpoints by epoch or by iteration. | |
| Default: True. | |
| save_optimizer (bool): Whether to save optimizer state_dict in the | |
| checkpoint. It is usually used for resuming experiments. | |
| Default: True. | |
| out_dir (str, optional): The root directory to save checkpoints. If not | |
| specified, ``runner.work_dir`` will be used by default. If | |
| specified, the ``out_dir`` will be the concatenation of ``out_dir`` | |
| and the last level directory of ``runner.work_dir``. | |
| `Changed in version 1.3.16.` | |
| max_keep_ckpts (int, optional): The maximum checkpoints to keep. | |
| In some cases we want only the latest few checkpoints and would | |
| like to delete old ones to save the disk space. | |
| Default: -1, which means unlimited. | |
| save_last (bool, optional): Whether to force the last checkpoint to be | |
| saved regardless of interval. Default: True. | |
| sync_buffer (bool, optional): Whether to synchronize buffers in | |
| different gpus. Default: False. | |
| file_client_args (dict, optional): Arguments to instantiate a | |
| FileClient. See :class:`mmcv.fileio.FileClient` for details. | |
| Default: None. | |
| `New in version 1.3.16.` | |
| .. warning:: | |
| Before v1.3.16, the ``out_dir`` argument indicates the path where the | |
| checkpoint is stored. However, since v1.3.16, ``out_dir`` indicates the | |
| root directory and the final path to save checkpoint is the | |
| concatenation of ``out_dir`` and the last level directory of | |
| ``runner.work_dir``. Suppose the value of ``out_dir`` is "/path/of/A" | |
| and the value of ``runner.work_dir`` is "/path/of/B", then the final | |
| path will be "/path/of/A/B". | |
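
    Example:
        A minimal usage sketch (the values below are illustrative, not
        defaults, and assume a standard mmcv runner):

        >>> hook = CheckpointHook(interval=1, by_epoch=True, max_keep_ckpts=3)
        >>> # runner.register_hook(hook) then saves ``epoch_N.pth`` files
        >>> # under ``runner.work_dir`` after each epoch, keeping only the
        >>> # three most recent checkpoints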
| """ | |

    def __init__(self,
                 interval=-1,
                 by_epoch=True,
                 save_optimizer=True,
                 out_dir=None,
                 max_keep_ckpts=-1,
                 save_last=True,
                 sync_buffer=False,
                 file_client_args=None,
                 **kwargs):
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.out_dir = out_dir
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        self.args = kwargs
        self.sync_buffer = sync_buffer
        self.file_client_args = file_client_args

    def before_run(self, runner):
        if not self.out_dir:
            self.out_dir = runner.work_dir

        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        # if `self.out_dir` is not equal to `runner.work_dir`, it means that
        # `self.out_dir` is set, so the final `self.out_dir` is the
        # concatenation of `self.out_dir` and the last-level directory of
        # `runner.work_dir`
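        # e.g. out_dir='/path/of/A' and work_dir='/path/of/B' yield
        # '/path/of/A/B' (the docstring warning above walks through this)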
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir} by '
                           f'{self.file_client.name}.')

        # disable the create_symlink option because some file backends do
        # not allow creating a symlink
        if 'create_symlink' in self.args:
            if self.args[
                    'create_symlink'] and not self.file_client.allow_symlink:
                self.args['create_symlink'] = False
                warnings.warn(
                    'create_symlink is set as True by the user but is '
                    'changed to be False because creating symbolic links '
                    f'is not allowed in {self.file_client.name}')
        else:
            self.args['create_symlink'] = self.file_client.allow_symlink

    def after_train_epoch(self, runner):
        if not self.by_epoch:
            return

        # save a checkpoint for the following cases:
        # 1. every ``self.interval`` epochs
        # 2. reaching the last epoch of training
        if self.every_n_epochs(
                runner, self.interval) or (self.save_last
                                           and self.is_last_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)

    @master_only
    def _save_checkpoint(self, runner):
        """Save the current checkpoint and delete outdated checkpoints."""
        runner.save_checkpoint(
            self.out_dir, save_optimizer=self.save_optimizer, **self.args)
        if runner.meta is not None:
            if self.by_epoch:
                cur_ckpt_filename = self.args.get(
                    'filename_tmpl', 'epoch_{}.pth').format(runner.epoch + 1)
            else:
                cur_ckpt_filename = self.args.get(
                    'filename_tmpl', 'iter_{}.pth').format(runner.iter + 1)
            runner.meta.setdefault('hook_msgs', dict())
            runner.meta['hook_msgs']['last_ckpt'] = self.file_client.join_path(
                self.out_dir, cur_ckpt_filename)
        # remove other checkpoints
        if self.max_keep_ckpts > 0:
            if self.by_epoch:
                name = 'epoch_{}.pth'
                current_ckpt = runner.epoch + 1
            else:
                name = 'iter_{}.pth'
                current_ckpt = runner.iter + 1
            redundant_ckpts = range(
                current_ckpt - self.max_keep_ckpts * self.interval, 0,
                -self.interval)
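            # e.g. with interval=2, max_keep_ckpts=3 and current_ckpt=10,
            # redundant_ckpts is range(4, 0, -2), i.e. checkpoints 4 and 2
            # are removed while 10, 8 and 6 are kept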
            filename_tmpl = self.args.get('filename_tmpl', name)
            for _step in redundant_ckpts:
                ckpt_path = self.file_client.join_path(
                    self.out_dir, filename_tmpl.format(_step))
                if self.file_client.isfile(ckpt_path):
                    self.file_client.remove(ckpt_path)
                else:
                    break

    def after_train_iter(self, runner):
        if self.by_epoch:
            return

        # save a checkpoint for the following cases:
        # 1. every ``self.interval`` iterations
        # 2. reaching the last iteration of training
        if self.every_n_iters(
                runner, self.interval) or (self.save_last
                                           and self.is_last_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            if self.sync_buffer:
                allreduce_params(runner.model.buffers())
            self._save_checkpoint(runner)
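

# A hypothetical iteration-based configuration (illustrative sketch, not
# part of the original module): extra keyword arguments such as
# ``filename_tmpl`` are forwarded to ``runner.save_checkpoint`` via
# ``**kwargs``, so a custom template like the one below also drives the
# checkpoint-cleanup logic above.
#
#     CheckpointHook(interval=5000, by_epoch=False,
#                    filename_tmpl='step_{}.pth')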