from annotator.uniformer.mmcv.utils import Registry, is_method_overridden

# Registry instance created under the name 'hook'; exported as HOOKS.
HOOKS = Registry('hook')


class Hook:
    """Base class for runner hooks.

    Every stage method is a no-op by default. The train/val specific stage
    methods delegate to the generic ``before_epoch`` / ``after_epoch`` /
    ``before_iter`` / ``after_iter`` methods, so a subclass can override
    either the generic method or a single specific stage.
    """

    # Stage names, in the order ``get_triggered_stages`` reports them.
    stages = ('before_run', 'before_train_epoch', 'before_train_iter',
              'after_train_iter', 'after_train_epoch', 'before_val_epoch',
              'before_val_iter', 'after_val_iter', 'after_val_epoch',
              'after_run')

    def before_run(self, runner):
        """Hook point for the ``before_run`` stage; does nothing by default."""
        pass

    def after_run(self, runner):
        """Hook point for the ``after_run`` stage; does nothing by default."""
        pass

    def before_epoch(self, runner):
        """Generic before-epoch hook point; does nothing by default."""
        pass

    def after_epoch(self, runner):
        """Generic after-epoch hook point; does nothing by default."""
        pass

    def before_iter(self, runner):
        """Generic before-iteration hook point; does nothing by default."""
        pass

    def after_iter(self, runner):
        """Generic after-iteration hook point; does nothing by default."""
        pass

    def before_train_epoch(self, runner):
        """Delegate to :meth:`before_epoch`."""
        self.before_epoch(runner)

    def before_val_epoch(self, runner):
        """Delegate to :meth:`before_epoch`."""
        self.before_epoch(runner)

    def after_train_epoch(self, runner):
        """Delegate to :meth:`after_epoch`."""
        self.after_epoch(runner)

    def after_val_epoch(self, runner):
        """Delegate to :meth:`after_epoch`."""
        self.after_epoch(runner)

    def before_train_iter(self, runner):
        """Delegate to :meth:`before_iter`."""
        self.before_iter(runner)

    def before_val_iter(self, runner):
        """Delegate to :meth:`before_iter`."""
        self.before_iter(runner)

    def after_train_iter(self, runner):
        """Delegate to :meth:`after_iter`."""
        self.after_iter(runner)

    def after_val_iter(self, runner):
        """Delegate to :meth:`after_iter`."""
        self.after_iter(runner)

    def every_n_epochs(self, runner, n):
        """Return True if ``runner.epoch + 1`` is a multiple of ``n``.

        Always returns False when ``n <= 0``.
        """
        return (runner.epoch + 1) % n == 0 if n > 0 else False

    def every_n_inner_iters(self, runner, n):
        """Return True if ``runner.inner_iter + 1`` is a multiple of ``n``.

        Always returns False when ``n <= 0``.
        """
        return (runner.inner_iter + 1) % n == 0 if n > 0 else False

    def every_n_iters(self, runner, n):
        """Return True if ``runner.iter + 1`` is a multiple of ``n``.

        Always returns False when ``n <= 0``.
        """
        return (runner.iter + 1) % n == 0 if n > 0 else False

    def end_of_epoch(self, runner):
        """Return True if ``runner.inner_iter + 1 == len(runner.data_loader)``."""
        return runner.inner_iter + 1 == len(runner.data_loader)

    def is_last_epoch(self, runner):
        """Return True if ``runner.epoch + 1`` equals ``runner._max_epochs``."""
        return runner.epoch + 1 == runner._max_epochs

    def is_last_iter(self, runner):
        """Return True if ``runner.iter + 1`` equals ``runner._max_iters``."""
        return runner.iter + 1 == runner._max_iters

    def get_triggered_stages(self):
        """Return the stages this hook overrides, ordered as in ``Hook.stages``.

        A stage is included when ``is_method_overridden`` reports its own
        method as overridden, or when the generic method it maps to is
        reported as overridden.

        Returns:
            list[str]: Subset of ``Hook.stages`` in declaration order.
        """
        # NOTE(review): is_method_overridden is imported from the utils
        # package; presumably it tells whether ``self`` overrides the named
        # method of ``Hook`` - confirm against its definition.
        trigger_stages = set()
        for stage in Hook.stages:
            if is_method_overridden(stage, Hook, self):
                trigger_stages.add(stage)

        # The default train/val stage methods forward to these generic
        # methods, so overriding a generic method affects every stage
        # listed for it.
        method_stages_map = {
            'before_epoch': ['before_train_epoch', 'before_val_epoch'],
            'after_epoch': ['after_train_epoch', 'after_val_epoch'],
            'before_iter': ['before_train_iter', 'before_val_iter'],
            'after_iter': ['after_train_iter', 'after_val_iter'],
        }

        for method, map_stages in method_stages_map.items():
            if is_method_overridden(method, Hook, self):
                trigger_stages.update(map_stages)

        # Iterate Hook.stages (not the set) to get a deterministic order.
        return [stage for stage in Hook.stages if stage in trigger_stages]