Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from collections import OrderedDict | |
| from copy import deepcopy | |
| from mmcv.transforms import Compose | |
| from mmengine.hooks import Hook | |
| from mmengine.model import is_model_wrapper | |
| from mmcls.models.utils import RandomBatchAugment | |
| from mmcls.registry import HOOKS, MODEL_WRAPPERS, MODELS | |
class SwitchRecipeHook(Hook):
    """Switch the training recipe during the training loop.

    Currently the train pipeline, the batch augmentations and the loss module
    can be switched.

    Args:
        schedule (list): Every item of the schedule list should be a dict, and
            the dict should have ``action_epoch`` and some of
            ``train_pipeline``, ``batch_augments`` and ``loss`` keys:

            - ``action_epoch`` (int): switch training recipe at which epoch.
              The recipe is applied in ``before_train_epoch`` when
              ``runner.epoch + 1 == action_epoch``.
            - ``train_pipeline`` (list, optional): The new data pipeline of the
              train dataset. If not specified, keep the original settings.
            - ``batch_augments`` (dict | None, optional): The new batch
              augmentations used during training. See :mod:`Batch
              Augmentations <mmcls.models.utils.batch_augments>` for more
              details. If None, disable batch augmentations. If not specified,
              keep the original settings.
            - ``loss`` (dict, optional): The new loss module config. If not
              specified, keep the original settings.

    Example:
        To use this hook in config files.

        .. code:: python

            custom_hooks = [
                dict(
                    type='SwitchRecipeHook',
                    schedule=[
                        dict(
                            action_epoch=30,
                            train_pipeline=pipeline_after_30e,
                            batch_augments=batch_augments_after_30e,
                            loss=loss_after_30e,
                        ),
                        dict(
                            action_epoch=60,
                            # Disable batch augmentations after 60e
                            # and keep other settings.
                            batch_augments=None,
                        ),
                    ]
                )
            ]
    """

    priority = 'NORMAL'

    def __init__(self, schedule):
        recipes = {}
        for recipe in schedule:
            assert 'action_epoch' in recipe, \
                'Please set `action_epoch` in every item ' \
                'of the `schedule` in the SwitchRecipeHook.'
            # Work on a copy so the caller's config dicts are not mutated
            # by the in-place replacements below.
            recipe = deepcopy(recipe)
            if 'train_pipeline' in recipe:
                recipe['train_pipeline'] = Compose(recipe['train_pipeline'])
            if 'batch_augments' in recipe:
                batch_augments = recipe['batch_augments']
                # A dict is a config to build from; anything else (notably
                # None, meaning "disable") is stored unchanged.
                if isinstance(batch_augments, dict):
                    batch_augments = RandomBatchAugment(**batch_augments)
                recipe['batch_augments'] = batch_augments
            if 'loss' in recipe:
                loss = recipe['loss']
                # A dict is a config to build; a built module is kept as-is.
                if isinstance(loss, dict):
                    loss = MODELS.build(loss)
                recipe['loss'] = loss

            action_epoch = recipe.pop('action_epoch')
            assert action_epoch not in recipes, \
                f'The `action_epoch` {action_epoch} is repeated ' \
                'in the SwitchRecipeHook.'
            recipes[action_epoch] = recipe
        # Keep recipes sorted by epoch so that resuming can replay the
        # earlier switches in chronological order (see ``before_train``).
        self.schedule = OrderedDict(sorted(recipes.items()))

    def before_train(self, runner) -> None:
        """Replay past switches when resuming from a checkpoint.

        When training resumes, every recipe whose ``action_epoch`` is smaller
        than ``runner.epoch + 1`` is applied in epoch order. Each recipe only
        overrides the keys it specifies, so the cumulative result matches an
        uninterrupted run.

        Args:
            runner (Runner): The runner of the training, validation or testing
                process.
        """
        if runner._resume:
            for action_epoch, recipe in self.schedule.items():
                # ``self.schedule`` is sorted, so stop at the first recipe
                # that belongs to the upcoming (or a later) epoch; that one
                # is applied by ``before_train_epoch``.
                if action_epoch >= runner.epoch + 1:
                    break
                self._do_switch(runner, recipe,
                                f' (resume recipe of epoch {action_epoch})')

    def before_train_epoch(self, runner):
        """Apply the recipe scheduled for the upcoming epoch, if any."""
        recipe = self.schedule.get(runner.epoch + 1, None)
        if recipe is not None:
            self._do_switch(runner, recipe, f' at epoch {runner.epoch + 1}')

    def _do_switch(self, runner, recipe, extra_info=''):
        """Apply every component present in ``recipe`` and log each switch.

        Args:
            runner (Runner): The training runner.
            recipe (dict): A prepared schedule item (``action_epoch`` removed)
                that may contain ``batch_augments``, ``train_pipeline`` and
                ``loss``.
            extra_info (str): Suffix appended to each log message.
        """
        if 'batch_augments' in recipe:
            self._switch_batch_augments(runner, recipe['batch_augments'])
            runner.logger.info(f'Switch batch augments{extra_info}.')
        if 'train_pipeline' in recipe:
            self._switch_train_pipeline(runner, recipe['train_pipeline'])
            runner.logger.info(f'Switch train pipeline{extra_info}.')
        if 'loss' in recipe:
            self._switch_loss(runner, recipe['loss'])
            runner.logger.info(f'Switch loss{extra_info}.')

    @staticmethod
    def _unwrap_model(runner):
        """Return ``runner.model``, unwrapped from its model wrapper if any."""
        model = runner.model
        # NOTE(review): ``is_model_wrapper`` with its default (root) registry
        # also searches child registries, so wrappers registered in this
        # package's registry are recognized as well. Passing the package-local
        # registry alone would miss wrappers registered only in the root.
        if is_model_wrapper(model):
            model = model.module
        return model

    @staticmethod
    def _switch_batch_augments(runner, batch_augments):
        """Replace the batch augmentations of the model's data preprocessor."""
        model = SwitchRecipeHook._unwrap_model(runner)
        model.data_preprocessor.batch_augments = batch_augments

    @staticmethod
    def _switch_train_pipeline(runner, train_pipeline):
        """Replace the pipeline of the train dataloader's dataset.

        Raises:
            RuntimeError: If no ``pipeline`` can be found on the dataset or
                any dataset it wraps.
        """

        def switch_pipeline(dataset, pipeline):
            if hasattr(dataset, 'pipeline'):
                # for usual dataset
                dataset.pipeline = pipeline
            elif hasattr(dataset, 'datasets'):
                # for concat dataset wrapper
                for ds in dataset.datasets:
                    switch_pipeline(ds, pipeline)
            elif hasattr(dataset, 'dataset'):
                # for other dataset wrappers
                switch_pipeline(dataset.dataset, pipeline)
            else:
                raise RuntimeError(
                    'Cannot access the `pipeline` of the dataset.')

        train_loader = runner.train_loop.dataloader
        switch_pipeline(train_loader.dataset, train_pipeline)
        # To restart the iterator of dataloader when `persistent_workers=True`
        train_loader._iterator = None

    @staticmethod
    def _switch_loss(runner, loss_module):
        """Replace the loss module on the model, or on its head.

        Raises:
            RuntimeError: If neither the model nor ``model.head`` has a
                ``loss_module`` attribute.
        """
        model = SwitchRecipeHook._unwrap_model(runner)
        if hasattr(model, 'loss_module'):
            model.loss_module = loss_module
        elif hasattr(model, 'head') and hasattr(model.head, 'loss_module'):
            model.head.loss_module = loss_module
        else:
            raise RuntimeError('Cannot access the `loss_module` of the model.')