Spaces:
Paused
Paused
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
import torch.distributed as dist | |
import torch.nn.functional as F | |
from torch.autograd import Function | |
from torch.autograd.function import once_differentiable | |
from torch.nn.modules.module import Module | |
from torch.nn.parameter import Parameter | |
from annotator.uniformer.mmcv.cnn import NORM_LAYERS | |
from ..utils import ext_loader | |
# Load the compiled '_ext' extension, requesting the five sync-BN entry
# points that SyncBatchNormFunction calls in its forward and backward passes.
ext_module = ext_loader.load_ext(
    '_ext',
    [
        'sync_bn_forward_mean',
        'sync_bn_forward_var',
        'sync_bn_forward_output',
        'sync_bn_backward_param',
        'sync_bn_backward_data',
    ],
)
class SyncBatchNormFunction(Function):
    """Autograd function implementing synchronized batch normalization.

    The numerical work is delegated to the compiled ``ext_module`` kernels.
    Per-channel statistics (forward) and parameter gradients (backward) are
    combined across the process group with ``dist.all_reduce`` whenever
    ``group_size > 1``.

    ``symbolic``, ``forward`` and ``backward`` are static methods, as
    ``torch.autograd.Function`` requires; the first argument of ``forward``
    and ``backward`` is the autograd context object.
    """

    @staticmethod
    def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
                 eps, group, group_size, stats_mode):
        """Emit the custom ``mmcv::MMCVSyncBatchNorm`` node for ONNX export.

        Args:
            g: graph object providing ``op``.
            input, running_mean, running_var, weight, bias: node inputs.
            momentum, eps, group, group_size, stats_mode: exported as node
                attributes.

        Returns:
            The value returned by ``g.op``.
        """
        # Attribute keywords must carry a type suffix (``_f`` float, ``_i``
        # int, ``_s`` string); ``stats_mode`` is a string.
        # NOTE(review): ``group`` is exported as an int attribute -- confirm
        # that the value callers pass for it is representable as an int.
        return g.op(
            'mmcv::MMCVSyncBatchNorm',
            input,
            running_mean,
            running_var,
            weight,
            bias,
            momentum_f=momentum,
            eps_f=eps,
            group_i=group,
            group_size_i=group_size,
            stats_mode_s=stats_mode)

    @staticmethod
    def forward(ctx, input, running_mean, running_var, weight, bias, momentum,
                eps, group, group_size, stats_mode):
        """Normalize ``input`` with statistics synchronized across ``group``.

        Args:
            ctx: autograd context; hyper-parameters and the tensors needed
                by ``backward`` are stored on it.
            input (Tensor): half or float tensor (CPU or CUDA). Dimensions
                from index 2 onward are flattened before calling the kernels.
            running_mean, running_var: passed through to the output kernel.
            weight, bias: passed through to the output kernel; ``weight`` is
                also saved for ``backward``.
            momentum (float): momentum handed to the output kernel; forced
                to zero when no rank contributed any sample.
            eps (float): forwarded to the output kernel.
            group: process group used for ``dist.all_reduce``.
            group_size (int): number of ranks in ``group``.
            stats_mode (str): ``'default'`` divides the reduced statistics by
                ``group_size``; ``'N'`` weights them by each rank's batch
                size and divides by the total batch size.

        Returns:
            Tensor: output with the same shape as ``input``.

        Raises:
            AssertionError: if ``input`` is not a half/float tensor.
            NotImplementedError: if ``stats_mode`` is neither ``'default'``
                nor ``'N'``.
        """
        ctx.momentum = momentum
        ctx.eps = eps
        ctx.group = group
        ctx.group_size = group_size
        ctx.stats_mode = stats_mode

        assert isinstance(
                   input, (torch.HalfTensor, torch.FloatTensor,
                           torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
               f'only support Half or Float Tensor, but {input.type()}'
        output = torch.zeros_like(input)
        input3d = input.flatten(start_dim=2)
        # ``output3d`` is a view, so the kernel writes straight into
        # ``output``.
        output3d = output.view_as(input3d)
        num_channels = input3d.size(1)

        # ensure mean/var/norm/std are initialized as zeros
        # ``torch.empty()`` does not guarantee that
        mean = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        var = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        norm = torch.zeros_like(
            input3d, dtype=torch.float, device=input3d.device)
        std = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)

        batch_size = input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_forward_mean(input3d, mean)
            batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
        else:
            # skip updating mean and leave it as zeros when the input is empty
            batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)

        # synchronize mean and the batch flag in a single all_reduce
        vec = torch.cat([mean, batch_flag])
        if ctx.stats_mode == 'N':
            # weight by the local batch size; the flag slot then carries
            # this rank's batch size
            vec *= batch_size
        if ctx.group_size > 1:
            dist.all_reduce(vec, group=ctx.group)
        total_batch = vec[-1].detach()
        mean = vec[:num_channels]

        if ctx.stats_mode == 'default':
            mean = mean / ctx.group_size
        elif ctx.stats_mode == 'N':
            # clamp avoids dividing by zero when every rank is empty
            mean = mean / total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # leave var as zeros when the input is empty
        if batch_size > 0:
            ext_module.sync_bn_forward_var(input3d, mean, var)

        # every rank takes part in the reduction, even with an empty batch
        if ctx.stats_mode == 'N':
            var *= batch_size
        if ctx.group_size > 1:
            dist.all_reduce(var, group=ctx.group)

        if ctx.stats_mode == 'default':
            var /= ctx.group_size
        elif ctx.stats_mode == 'N':
            var /= total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # if the total batch size over all the ranks is zero,
        # we should not update the statistics in the current batch
        update_flag = total_batch.clamp(max=1)
        momentum = update_flag * ctx.momentum
        ext_module.sync_bn_forward_output(
            input3d,
            mean,
            var,
            weight,
            bias,
            running_mean,
            running_var,
            norm,
            std,
            output3d,
            eps=ctx.eps,
            momentum=momentum,
            group_size=ctx.group_size)
        ctx.save_for_backward(norm, std, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(ctx, grad_output):
        """Compute gradients for ``input``, ``weight`` and ``bias``.

        The gradients are filled in place by extension kernels, so this
        function is marked ``once_differentiable``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output (Tensor): gradient w.r.t. the forward output.

        Returns:
            tuple: one entry per ``forward`` input -- ``grad_input`` first,
            ``grad_weight`` and ``grad_bias`` in the ``weight``/``bias``
            slots, and ``None`` everywhere else.
        """
        norm, std, weight = ctx.saved_tensors
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(weight)
        grad_input = torch.zeros_like(grad_output)
        grad_output3d = grad_output.flatten(start_dim=2)
        # view onto ``grad_input`` so the kernel writes into it directly
        grad_input3d = grad_input.view_as(grad_output3d)

        batch_size = grad_input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_backward_param(grad_output3d, norm, grad_weight,
                                              grad_bias)

        # all reduce: average the parameter gradients over the group; ranks
        # with an empty batch contribute zeros
        if ctx.group_size > 1:
            dist.all_reduce(grad_weight, group=ctx.group)
            dist.all_reduce(grad_bias, group=ctx.group)
            grad_weight /= ctx.group_size
            grad_bias /= ctx.group_size

        if batch_size > 0:
            ext_module.sync_bn_backward_data(grad_output3d, weight,
                                             grad_weight, grad_bias, norm, std,
                                             grad_input3d)

        return grad_input, None, None, grad_weight, grad_bias, \
            None, None, None, None, None
# NOTE(review): ``NORM_LAYERS`` is imported at file level but is not referenced
# anywhere in the visible code -- confirm whether this class is meant to be
# registered with it.
class SyncBatchNorm(Module):
    """Synchronized Batch Normalization.

    Args:
        num_features (int): number of features/channels in input tensor
        eps (float, optional): a value added to the denominator for numerical
            stability. Defaults to 1e-5.
        momentum (float, optional): the value used for the running_mean and
            running_var computation. When ``None``, a cumulative moving
            average (``1 / num_batches_tracked``) is used while training
            with tracked statistics. Defaults to 0.1.
        affine (bool, optional): whether to use learnable affine parameters.
            Defaults to True.
        track_running_stats (bool, optional): whether to track the running
            mean and variance during training. When set to False, this
            module does not track such statistics, and initializes statistics
            buffers ``running_mean`` and ``running_var`` as ``None``. When
            these buffers are ``None``, this module always uses batch
            statistics in both training and eval modes. Defaults to True.
        group (optional): process group within which the statistics are
            synchronized. When ``None``, ``dist.group.WORLD`` is used, i.e.
            synchronization happens across the whole world.
            Defaults to None.
        stats_mode (str, optional): The statistical mode. Available options
            include ``'default'`` and ``'N'``. Defaults to 'default'.
            When ``stats_mode=='default'``, it computes the overall statistics
            using those from each worker with equal weight, i.e., the
            statistics are synchronized and simply divided by the group
            size. This mode will produce inaccurate statistics when empty
            tensors occur.
            When ``stats_mode=='N'``, it computes the overall statistics using
            the total number of batches in each worker, ignoring the group
            size, i.e., the statistics are synchronized and then divided by
            the total batch ``N``. This mode is beneficial when empty tensors
            occur during training, as it averages the total mean by the real
            number of batches.
    """

    def __init__(self,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True,
                 group=None,
                 stats_mode='default'):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        group = dist.group.WORLD if group is None else group
        self.group = group
        self.group_size = dist.get_world_size(group)
        assert stats_mode in ['default', 'N'], \
            f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"'
        self.stats_mode = stats_mode
        if self.affine:
            # Uninitialized storage; ``reset_parameters`` fills it below.
            self.weight = Parameter(torch.empty(num_features))
            self.bias = Parameter(torch.empty(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked',
                                 torch.tensor(0, dtype=torch.long))
        else:
            self.register_buffer('running_mean', None)
            self.register_buffer('running_var', None)
            self.register_buffer('num_batches_tracked', None)
        self.reset_parameters()

    def reset_running_stats(self):
        """Reset the running mean to 0, variance to 1 and the batch counter to 0.

        Does nothing when ``track_running_stats`` is False.
        """
        if self.track_running_stats:
            self.running_mean.zero_()
            self.running_var.fill_(1)
            self.num_batches_tracked.zero_()

    def reset_parameters(self):
        """Reset the running statistics and re-initialize the affine parameters."""
        self.reset_running_stats()
        if self.affine:
            self.weight.data.uniform_()  # pytorch use ones_()
            self.bias.data.zero_()

    def forward(self, input):
        """Apply batch normalization to ``input``.

        Uses ``SyncBatchNormFunction`` (batch statistics) while training or
        when running statistics are not tracked; otherwise falls back to
        ``F.batch_norm`` with the stored running statistics.

        Args:
            input (Tensor): tensor with at least 2 dimensions.

        Returns:
            Tensor: the normalized tensor.

        Raises:
            ValueError: if ``input`` has fewer than 2 dimensions.
        """
        if input.dim() < 2:
            raise ValueError(
                f'expected at least 2D input, got {input.dim()}D input')

        exponential_average_factor = (0.0 if self.momentum is None else
                                      self.momentum)

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(
                        self.num_batches_tracked)
                # otherwise the exponential moving average factor set above
                # (``self.momentum``) is already in place

        if self.training or not self.track_running_stats:
            return SyncBatchNormFunction.apply(
                input, self.running_mean, self.running_var, self.weight,
                self.bias, exponential_average_factor, self.eps, self.group,
                self.group_size, self.stats_mode)
        else:
            return F.batch_norm(input, self.running_mean, self.running_var,
                                self.weight, self.bias, False,
                                exponential_average_factor, self.eps)

    def __repr__(self):
        """Return a one-line summary of the module's configuration."""
        s = self.__class__.__name__
        s += f'({self.num_features}, '
        s += f'eps={self.eps}, '
        s += f'momentum={self.momentum}, '
        s += f'affine={self.affine}, '
        s += f'track_running_stats={self.track_running_stats}, '
        s += f'group_size={self.group_size},'
        s += f'stats_mode={self.stats_mode})'
        return s