# Copyright (c) OpenMMLab. All rights reserved.
# Adapted from https://github.com/facebookresearch/pycls/blob/f8cd962737e33ce9e19b3083a33551da95c2d9c0/pycls/core/net.py # noqa: E501
# Original licence: Copyright (c) 2019 Facebook, Inc under the Apache License 2.0 # noqa: E501
import itertools
import logging
from typing import List, Optional, Sequence, Union

import mmengine
import torch
import torch.nn as nn
from mmengine.hooks import Hook
from mmengine.logging import print_log
from mmengine.model import is_model_wrapper
from mmengine.runner import EpochBasedTrainLoop, IterBasedTrainLoop, Runner
from mmengine.utils import ProgressBar
from torch.functional import Tensor
from torch.nn import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.modules.instancenorm import _InstanceNorm
from torch.utils.data import DataLoader

from mmpretrain.registry import HOOKS

DATA_BATCH = Optional[Sequence[dict]]


def scaled_all_reduce(tensors: List[Tensor], num_gpus: int) -> List[Tensor]:
    """Performs the scaled all_reduce operation on the provided tensors.

    The input tensors are modified in-place. Currently supports only the sum
    reduction operator. The reduced values are scaled by the inverse size of
    the process group.

    Args:
        tensors (List[torch.Tensor]): The tensors to process.
        num_gpus (int): The number of GPUs to use.

    Returns:
        List[torch.Tensor]: The processed tensors.
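
    Example:
        A minimal sketch (assuming ``torch.distributed`` has been
        initialized across ``world_size`` processes; both names here are
        illustrative)::

            stats = [torch.ones(8), torch.zeros(8)]
            stats = scaled_all_reduce(stats, num_gpus=world_size)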
""" | |
# There is no need for reduction in the single-proc case | |
if num_gpus == 1: | |
return tensors | |
# Queue the reductions | |
reductions = [] | |
for tensor in tensors: | |
reduction = torch.distributed.all_reduce(tensor, async_op=True) | |
reductions.append(reduction) | |
# Wait for reductions to finish | |
for reduction in reductions: | |
reduction.wait() | |
# Scale the results | |
for tensor in tensors: | |
tensor.mul_(1.0 / num_gpus) | |
return tensors | |


def update_bn_stats(
        model: nn.Module,
        loader: DataLoader,
        num_samples: int = 8192,
        logger: Optional[Union[logging.Logger, str]] = None) -> None:
    """Computes precise BN stats on training data.

    Args:
        model (nn.Module): The model whose BN stats will be recomputed.
        loader (DataLoader): The PyTorch dataloader to draw samples from.
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        logger (logging.Logger or str, optional): If the type of logger is
            ``logging.Logger``, we directly use logger to log messages.
            Some special loggers are:

            - "silent": No message will be printed.
            - "current": Use the latest created logger to log messages.
            - other str: Instance name of logger. The corresponding logger
              will log messages if it has been created, otherwise a
              ``ValueError`` will be raised.
            - None: The ``print()`` method will be used to print log
              messages.
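
    Example:
        A minimal usage sketch (``model`` and ``train_loader`` are assumed
        to be already built; for multi-GPU runs the distributed process
        group must be initialized first)::

            model.train()  # BN layers must be in training mode
            update_bn_stats(model, train_loader, num_samples=8192)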
""" | |
if is_model_wrapper(model): | |
model = model.module | |
# get dist info | |
rank, world_size = mmengine.dist.get_dist_info() | |
# Compute the number of mini-batches to use, if the size of dataloader is | |
# less than num_iters, use all the samples in dataloader. | |
num_iter = num_samples // (loader.batch_size * world_size) | |
num_iter = min(num_iter, len(loader)) | |
# Retrieve the BN layers | |
bn_layers = [ | |
m for m in model.modules() | |
if m.training and isinstance(m, (_BatchNorm)) | |
] | |
if len(bn_layers) == 0: | |
print_log('No BN found in model', logger=logger, level=logging.WARNING) | |
return | |
print_log( | |
f'{len(bn_layers)} BN found, run {num_iter} iters...', logger=logger) | |

    # Finds all the other norm layers with training=True.
    other_norm_layers = [
        m for m in model.modules()
        if m.training and isinstance(m, (_InstanceNorm, GroupNorm))
    ]
    if len(other_norm_layers) > 0:
        print_log(
            'IN/GN stats will not be updated in PreciseBNHook.',
            logger=logger,
            level=logging.INFO)

    # Initialize BN stats storage for computing
    # mean(mean(batch)) and mean(var(batch))
    running_means = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
    running_vars = [torch.zeros_like(bn.running_var) for bn in bn_layers]
    # Remember momentum values
    momentums = [bn.momentum for bn in bn_layers]
    # Set momentum to 1.0 to compute BN stats that reflect the current batch
    for bn in bn_layers:
        bn.momentum = 1.0
    # Average the BN stats for each BN layer over the batches
    if rank == 0:
        prog_bar = ProgressBar(num_iter)

    for data in itertools.islice(loader, num_iter):
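        # Note: the second argument (``training=False``) runs the data
        # preprocessor in non-training mode (an assumption: this typically
        # skips training-time batch augmentation), while the model itself
        # stays in training mode so the BN buffers get refreshed below.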
        data = model.data_preprocessor(data, False)
        model(**data)

        for i, bn in enumerate(bn_layers):
            running_means[i] += bn.running_mean / num_iter
            running_vars[i] += bn.running_var / num_iter
        if rank == 0:
            prog_bar.update()

    # Sync BN stats across GPUs (no reduction if 1 GPU used)
    running_means = scaled_all_reduce(running_means, world_size)
    running_vars = scaled_all_reduce(running_vars, world_size)
    # Set BN stats and restore original momentum values
    for i, bn in enumerate(bn_layers):
        bn.running_mean = running_means[i]
        bn.running_var = running_vars[i]
        bn.momentum = momentums[i]


@HOOKS.register_module()
class PreciseBNHook(Hook):
    """Precise BN hook.

    Recompute and update the batch norm stats to make them more precise.
    During training both the BN stats and the weights change after every
    iteration, so the running average cannot precisely reflect the actual
    stats of the current model.

    With this hook, the BN stats are recomputed with fixed weights, to make
    the running average more precise. Specifically, it computes the true
    average of per-batch mean/variance instead of the running average. See
    Sec. 3 of the paper `Rethinking Batch in BatchNorm
    <https://arxiv.org/abs/2105.07576>`_ for details.

    This hook updates the BN stats, so it should be executed before
    ``CheckpointHook`` and ``EMAHook``; generally set its priority to
    "ABOVE_NORMAL".

    Args:
        num_samples (int): The number of samples to update the bn stats.
            Defaults to 8192.
        interval (int): The interval at which to perform precise BN. If the
            train loop is ``EpochBasedTrainLoop`` or ``by_epoch=True``, its
            unit is 'epoch'; if the train loop is ``IterBasedTrainLoop`` or
            ``by_epoch=False``, its unit is 'iter'. Defaults to 1.
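
    Example:
        A minimal config sketch (this assumes the standard MMEngine
        ``custom_hooks`` mechanism; adapt it to your own config layout)::

            custom_hooks = [
                dict(
                    type='PreciseBNHook',
                    num_samples=8192,
                    interval=1,
                    priority='ABOVE_NORMAL')
            ]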
""" | |
def __init__(self, num_samples: int = 8192, interval: int = 1) -> None: | |
assert interval > 0 and num_samples > 0, "'interval' and " \ | |
"'num_samples' must be bigger than 0." | |
self.interval = interval | |
self.num_samples = num_samples | |

    def _perform_precise_bn(self, runner: Runner) -> None:
        """Perform precise BN."""
        print_log(
            f'Running Precise BN for {self.num_samples} samples...',
            logger=runner.logger)
        update_bn_stats(
            runner.model,
            runner.train_loop.dataloader,
            self.num_samples,
            logger=runner.logger)
        print_log(
            'Finish Precise BN, BN stats updated.', logger=runner.logger)

    def after_train_epoch(self, runner: Runner) -> None:
        """Calculate precise BN and broadcast BN stats across GPUs.

        Args:
            runner (obj:`Runner`): The runner of the training process.
        """
        # If `EpochBasedTrainLoop` is used, perform precise BN every
        # `self.interval` epochs.
        if isinstance(runner.train_loop,
                      EpochBasedTrainLoop) and self.every_n_epochs(
                          runner, self.interval):
            self._perform_precise_bn(runner)

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Calculate precise BN and broadcast BN stats across GPUs.

        Args:
            runner (obj:`Runner`): The runner of the training process.
            batch_idx (int): The index of the current batch in the train
                loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from the model. Defaults to
                None.
        """
        # If `IterBasedTrainLoop` is used, perform precise BN every
        # `self.interval` iterations.
        if isinstance(runner.train_loop,
                      IterBasedTrainLoop) and self.every_n_train_iters(
                          runner, self.interval):
            self._perform_precise_bn(runner)