|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from mmengine.dist import all_reduce_params, is_distributed |
|
|
from mmengine.registry import HOOKS |
|
|
from .hook import Hook |
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Keep model buffers (e.g. BN ``running_mean``/``running_var``)
    consistent across processes by all-reducing them at epoch boundaries."""

    priority = 'NORMAL'

    def __init__(self) -> None:
        # Whether we are running under torch.distributed; when False every
        # sync below is a no-op.
        self.distributed = is_distributed()
        # Tracks whether buffers were already synchronized at the end of the
        # last training epoch, so validation does not redo the work.
        self.called_in_train = False

    def before_val_epoch(self, runner) -> None:
        """All-reduce model buffers before each validation epoch.

        Only performs the sync when it was not already done at the end of
        the previous training epoch (the IterBasedTrainLoop case, where
        ``after_train_epoch`` is not triggered between validations).

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return
        if not self.called_in_train:
            all_reduce_params(runner.model.buffers(), op='mean')
        # Reset the flag so the next validation syncs again unless a
        # training-epoch sync happens first.
        self.called_in_train = False

    def after_train_epoch(self, runner) -> None:
        """All-reduce model buffers at the end of each training epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return
        all_reduce_params(runner.model.buffers(), op='mean')
        # Remember that buffers are fresh; before_val_epoch can skip its sync.
        self.called_in_train = True