# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from mmengine.dist import all_reduce_params, is_distributed | |
from mmengine.registry import HOOKS | |
from .hook import Hook | |
@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN.

    Buffers are all-reduced with ``op='mean'`` (i.e. averaged across
    processes) at the end of each training epoch, and again before a
    validation epoch if no training-epoch synchronization has happened since
    the previous validation. When not running distributed, every hook method
    is a no-op.

    The class is registered in the ``HOOKS`` registry so it can be looked up
    by name (e.g. when hooks are built from a config).
    """

    priority = 'NORMAL'

    def __init__(self) -> None:
        # Evaluated once, at construction time: the hook assumes any
        # distributed environment is already initialised when it is built.
        self.distributed: bool = is_distributed()
        # True when ``after_train_epoch`` has synchronized the buffers since
        # the last ``before_val_epoch``; lets validation skip a redundant
        # all-reduce.
        self.called_in_train: bool = False

    def before_val_epoch(self, runner) -> None:
        """All-reduce model buffers before each validation epoch.

        Synchronize the buffers before each validation if they have not been
        synchronized at the end of the previous training epoch. This is the
        path taken with ``IterBasedTrainLoop``, where validation is not
        preceded by ``after_train_epoch``. The flag is reset afterwards so
        the next validation epoch decides afresh.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return
        if not self.called_in_train:
            all_reduce_params(runner.model.buffers(), op='mean')
        # Consume the flag: the next validation must see a fresh
        # training-epoch sync to be allowed to skip its own.
        self.called_in_train = False

    def after_train_epoch(self, runner) -> None:
        """All-reduce model buffers at the end of each training epoch.

        Also records that a sync happened so the following
        ``before_val_epoch`` can skip a redundant all-reduce.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.distributed:
            return
        all_reduce_params(runner.model.buffers(), op='mean')
        self.called_in_train = True