Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import platform | |
import shutil | |
import time | |
import warnings | |
import torch | |
from torch.optim import Optimizer | |
import annotator.uniformer.mmcv as mmcv | |
from .base_runner import BaseRunner | |
from .builder import RUNNERS | |
from .checkpoint import save_checkpoint | |
from .hooks import IterTimerHook | |
from .utils import get_host_info | |
class IterLoader:
    """Endless iterator over a dataloader that counts finished passes.

    ``next()`` never raises ``StopIteration`` for a non-empty dataloader:
    when the wrapped dataloader is exhausted, the epoch counter is bumped, the
    sampler is informed of the new epoch (if it has a ``set_epoch`` method)
    and a fresh iterator is created.

    Args:
        dataloader: Iterable supporting ``iter()`` and ``len()``. Its
            ``sampler`` attribute is accessed when a pass over the data ends
            (presumably a :class:`torch.utils.data.DataLoader`; confirm
            against callers).
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        """int: Number of complete passes made over the dataloader so far.

        Exposed as a property because ``IterBasedRunner.train`` reads
        ``data_loader.epoch`` as a plain attribute and later does integer
        arithmetic on the value.
        """
        return self._epoch

    def __next__(self):
        """Return the next batch, restarting the dataloader when exhausted."""
        try:
            data = next(self.iter_loader)
        except StopIteration:
            self._epoch += 1
            # Samplers exposing ``set_epoch`` are told the new epoch before
            # the dataloader is re-iterated.
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        """Return the length of the wrapped dataloader."""
        return len(self._dataloader)
@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration: each call to
    :meth:`train` or :meth:`val` consumes exactly one batch.
    """

    def train(self, data_loader, **kwargs):
        """Run a single training iteration.

        Args:
            data_loader: Object supporting ``next()`` and exposing an
                ``epoch`` attribute (:meth:`run` passes an
                :class:`IterLoader`).
            **kwargs: Forwarded to ``self.model.train_step``.

        Raises:
            TypeError: If ``self.model.train_step`` does not return a dict.
        """
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        # Mirror the loader's epoch counter on the runner.
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        """Run a single validation iteration with autograd disabled.

        Unlike :meth:`train`, this only advances ``self._inner_iter``; the
        global iteration counter ``self._iter`` is left unchanged.

        Args:
            data_loader: Object supporting ``next()`` (:meth:`run` passes an
                :class:`IterLoader`).
            **kwargs: Forwarded to ``self.model.val_step``.

        Raises:
            TypeError: If ``self.model.val_step`` does not return a dict.
        """
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.val_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g, [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training and
                1000 iterations for validation, iteratively.
            max_iters (int, optional): Deprecated. When given, it overrides
                ``self._max_iters`` and a ``DeprecationWarning`` is issued.
            **kwargs: Forwarded to the per-iteration method of every phase.

        Raises:
            ValueError: If a workflow phase is not a string naming an
                attribute of this runner.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        # Wrap every dataloader so it can be drawn from indefinitely.
        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run a '
                        'workflow')
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    # Stop a training phase as soon as the budget is spent,
                    # even in the middle of the phase.
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether resume the optimizer(s)
                if the checkpoint file includes optimizer(s). Default to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Default to 'default', which maps storages onto the current
                CUDA device when CUDA is available and onto the CPU
                otherwise.

        Raises:
            TypeError: If the optimizer has to be restored but
                ``self.optimizer`` is neither a ``torch.optim.Optimizer`` nor
                a dict.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                # Without CUDA, ``torch.cuda.current_device()`` would raise;
                # load onto the CPU instead.
                checkpoint = self.load_checkpoint(
                    checkpoint, map_location='cpu')
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                The passed dict is copied, not modified. Defaults to None.
            save_optimizer (bool, optional): Whether save optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether create symlink to the
                latest checkpoint file. Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither a dict nor None.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        else:
            # Work on a copy so the caller's dict is not mutated by the
            # updates below.
            meta = dict(meta)
        if self.meta is not None:
            meta.update(self.meta)
            # Note: meta.update(self.meta) should be done before
            # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
            # there will be problems with resumed checkpoints.
            # More details in https://github.com/open-mmlab/mmcv/pull/1108
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # in some environments, `os.symlink` is not supported, you may need to
        # set `create_symlink` to False
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                # On Windows a plain copy is made instead of a symlink.
                shutil.copy(filepath, dst_file)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                custom_hooks_config=None):
        """Register default hooks for iter-based training.

        Checkpoint hook, optimizer stepper hook and logger hooks will be set to
        `by_epoch=False` by default.

        Default hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have same priority with default hooks, custom hooks
        will be triggered after default hooks.

        Note:
            ``by_epoch`` is filled in with ``setdefault``, so the passed
            ``checkpoint_config``, ``lr_config`` and every entry of
            ``log_config['hooks']`` are modified in place when they lack
            that key.
        """
        if checkpoint_config is not None:
            checkpoint_config.setdefault('by_epoch', False)
        if lr_config is not None:
            lr_config.setdefault('by_epoch', False)
        if log_config is not None:
            for info in log_config['hooks']:
                info.setdefault('by_epoch', False)
        super(IterBasedRunner, self).register_training_hooks(
            lr_config=lr_config,
            momentum_config=momentum_config,
            optimizer_config=optimizer_config,
            checkpoint_config=checkpoint_config,
            log_config=log_config,
            timer_config=IterTimerHook(),
            custom_hooks_config=custom_hooks_config)