Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Optional, Sequence | |
| from mmengine.hooks import Hook | |
| from mmpretrain.registry import HOOKS | |
| from mmpretrain.utils import get_ori_model | |
# NOTE(review): ``HOOKS`` is imported at the top of this file but this class is
# not registered with it in the visible source -- confirm whether a
# ``@HOOKS.register_module()`` decorator was lost and needs to be restored.
class DenseCLHook(Hook):
    """Hook for DenseCL.

    This hook includes ``loss_lambda`` warmup in DenseCL.
    Borrowed from the authors' code: `<https://github.com/WXinlong/DenseCL>`_.

    The model's ``loss_lambda`` is recorded once in :meth:`before_train`.
    Before every training iteration it is then set to ``0.`` while
    ``runner.iter`` is below ``start_iters`` and set back to the recorded
    value once ``runner.iter`` reaches ``start_iters``.

    Args:
        start_iters (int): The number of warmup iterations to set
            ``loss_lambda=0``. Defaults to 1000.
    """

    def __init__(self, start_iters: int = 1000) -> None:
        self.start_iters = start_iters

    @staticmethod
    def _get_dense_cl_model(runner):
        """Return the unwrapped model after checking it has ``loss_lambda``.

        Args:
            runner: The runner whose ``model`` attribute is unwrapped with
                ``get_ori_model``.

        Returns:
            The object returned by ``get_ori_model(runner.model)``.

        Raises:
            AssertionError: If that object has no ``loss_lambda`` attribute.
        """
        model = get_ori_model(runner.model)
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped under ``python -O``; the exception type stays the same.
        if not hasattr(model, 'loss_lambda'):
            raise AssertionError(
                'The model must have attribute "loss_lambda" in DenseCL.')
        return model

    def before_train(self, runner) -> None:
        """Obtain ``loss_lambda`` from algorithm."""
        # Remember the configured value so it can be restored after warmup.
        self.loss_lambda = self._get_dense_cl_model(runner).loss_lambda

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: Optional[Sequence[dict]] = None) -> None:
        """Adjust ``loss_lambda`` every train iter.

        Args:
            runner: The runner; ``runner.iter`` is compared against
                ``start_iters``.
            batch_idx (int): Unused here; part of the hook signature.
            data_batch (Sequence[dict], optional): Unused here; part of the
                hook signature. Defaults to None.
        """
        model = self._get_dense_cl_model(runner)
        if runner.iter >= self.start_iters:
            # Warmup finished: use the value captured in ``before_train``.
            model.loss_lambda = self.loss_lambda
        else:
            # Still warming up: set ``loss_lambda`` to zero.
            model.loss_lambda = 0.