# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.autograd import Function
from torch.autograd.function import once_differentiable
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

from annotator.uniformer.mmcv.cnn import NORM_LAYERS
from ..utils import ext_loader

ext_module = ext_loader.load_ext('_ext', [
    'sync_bn_forward_mean', 'sync_bn_forward_var', 'sync_bn_forward_output',
    'sync_bn_backward_param', 'sync_bn_backward_data'
])


class SyncBatchNormFunction(Function):

    @staticmethod
    def symbolic(g, input, running_mean, running_var, weight, bias, momentum,
                 eps, group, group_size, stats_mode):
        return g.op(
            'mmcv::MMCVSyncBatchNorm',
            input,
            running_mean,
            running_var,
            weight,
            bias,
            momentum_f=momentum,
            eps_f=eps,
            group_i=group,
            group_size_i=group_size,
            stats_mode=stats_mode)

    @staticmethod
    def forward(self, input, running_mean, running_var, weight, bias,
                momentum, eps, group, group_size, stats_mode):
        self.momentum = momentum
        self.eps = eps
        self.group = group
        self.group_size = group_size
        self.stats_mode = stats_mode

        assert isinstance(
            input, (torch.HalfTensor, torch.FloatTensor,
                    torch.cuda.HalfTensor, torch.cuda.FloatTensor)), \
            f'only support Half or Float Tensor, but {input.type()}'
        output = torch.zeros_like(input)
        input3d = input.flatten(start_dim=2)
        output3d = output.view_as(input3d)
        num_channels = input3d.size(1)

        # ensure mean/var/norm/std are initialized as zeros
        # ``torch.empty()`` does not guarantee that
        mean = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        var = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)
        norm = torch.zeros_like(
            input3d, dtype=torch.float, device=input3d.device)
        std = torch.zeros(
            num_channels, dtype=torch.float, device=input3d.device)

        batch_size = input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_forward_mean(input3d, mean)
            batch_flag = torch.ones([1], device=mean.device, dtype=mean.dtype)
        else:
            # skip updating mean and leave it as zeros when the input is empty
            batch_flag = torch.zeros([1], device=mean.device, dtype=mean.dtype)

        # synchronize mean and the batch flag
        vec = torch.cat([mean, batch_flag])
        if self.stats_mode == 'N':
            vec *= batch_size
        if self.group_size > 1:
            dist.all_reduce(vec, group=self.group)
        total_batch = vec[-1].detach()
        mean = vec[:num_channels]

        if self.stats_mode == 'default':
            mean = mean / self.group_size
        elif self.stats_mode == 'N':
            mean = mean / total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # leave var as zeros when the input is empty
        if batch_size > 0:
            ext_module.sync_bn_forward_var(input3d, mean, var)

        if self.stats_mode == 'N':
            var *= batch_size
        if self.group_size > 1:
            dist.all_reduce(var, group=self.group)

        if self.stats_mode == 'default':
            var /= self.group_size
        elif self.stats_mode == 'N':
            var /= total_batch.clamp(min=1)
        else:
            raise NotImplementedError

        # if the total batch size over all the ranks is zero,
        # we should not update the statistics in the current batch
        update_flag = total_batch.clamp(max=1)
        momentum = update_flag * self.momentum
        ext_module.sync_bn_forward_output(
            input3d,
            mean,
            var,
            weight,
            bias,
            running_mean,
            running_var,
            norm,
            std,
            output3d,
            eps=self.eps,
            momentum=momentum,
            group_size=self.group_size)
        self.save_for_backward(norm, std, weight)
        return output

    @staticmethod
    @once_differentiable
    def backward(self, grad_output):
        norm, std, weight = self.saved_tensors
        grad_weight = torch.zeros_like(weight)
        grad_bias = torch.zeros_like(weight)
        grad_input = torch.zeros_like(grad_output)
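
        # The zero-initialized buffers above are filled in-place by the
        # ``_ext`` ops below: per-channel grad_weight/grad_bias are first
        # accumulated on each rank from grad_output and the saved normalized
        # input ``norm``, averaged across ranks via all_reduce, and finally
        # combined with ``norm``, ``std`` and ``weight`` to recover
        # grad_input.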
        grad_output3d = grad_output.flatten(start_dim=2)
        grad_input3d = grad_input.view_as(grad_output3d)

        batch_size = grad_input3d.size(0)
        if batch_size > 0:
            ext_module.sync_bn_backward_param(grad_output3d, norm,
                                              grad_weight, grad_bias)

        # all reduce
        if self.group_size > 1:
            dist.all_reduce(grad_weight, group=self.group)
            dist.all_reduce(grad_bias, group=self.group)
            grad_weight /= self.group_size
            grad_bias /= self.group_size

        if batch_size > 0:
            ext_module.sync_bn_backward_data(grad_output3d, weight,
                                             grad_weight, grad_bias, norm,
                                             std, grad_input3d)

        return grad_input, None, None, grad_weight, grad_bias, \
            None, None, None, None, None


@NORM_LAYERS.register_module(name='MMSyncBN')
class SyncBatchNorm(Module):
    """Synchronized Batch Normalization.

    Args:
        num_features (int): number of features/channels in the input tensor
        eps (float, optional): a value added to the denominator for numerical
            stability. Defaults to 1e-5.
        momentum (float, optional): the value used for the running_mean and
            running_var computation. Defaults to 0.1.
        affine (bool, optional): whether to use learnable affine parameters.
            Defaults to True.
        track_running_stats (bool, optional): whether to track the running
            mean and variance during training. When set to False, this
            module does not track such statistics, and initializes statistics
            buffers ``running_mean`` and ``running_var`` as ``None``. When
            these buffers are ``None``, this module always uses batch
            statistics in both training and eval modes. Defaults to True.
        group (int, optional): synchronization of stats happens within each
            process group individually. By default it is synchronization
            across the whole world. Defaults to None.
        stats_mode (str, optional): The statistical mode. Available options
            include ``'default'`` and ``'N'``. Defaults to 'default'.
            When ``stats_mode=='default'``, it computes the overall
            statistics using those from each worker with equal weight, i.e.,
            the statistics are synchronized and simply divided by the number
            of workers in the ``group``. This mode will produce inaccurate
            statistics when empty tensors occur.
            When ``stats_mode=='N'``, it computes the overall statistics
            using the total number of batches in each worker while ignoring
            the number of workers, i.e., the statistics are synchronized and
            then divided by the total batch size ``N``. This mode is
            beneficial when empty tensors occur during training, as it
            averages the total mean over the real number of batches.
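
    Example (a minimal illustrative sketch; it assumes a CUDA build of the
    ``_ext`` ops and an already-initialized default process group, and the
    tensor shapes are only placeholders)::

        >>> import torch
        >>> # construct the layer directly ...
        >>> layer = SyncBatchNorm(64, stats_mode='N').cuda()
        >>> out = layer(torch.rand(4, 64, 32, 32).cuda())
        >>> # ... or request it from a config via its registered name, e.g.
        >>> # ``norm_cfg = dict(type='MMSyncBN', requires_grad=True)``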
""" def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True, group=None, stats_mode='default'): super(SyncBatchNorm, self).__init__() self.num_features = num_features self.eps = eps self.momentum = momentum self.affine = affine self.track_running_stats = track_running_stats group = dist.group.WORLD if group is None else group self.group = group self.group_size = dist.get_world_size(group) assert stats_mode in ['default', 'N'], \ f'"stats_mode" only accepts "default" and "N", got "{stats_mode}"' self.stats_mode = stats_mode if self.affine: self.weight = Parameter(torch.Tensor(num_features)) self.bias = Parameter(torch.Tensor(num_features)) else: self.register_parameter('weight', None) self.register_parameter('bias', None) if self.track_running_stats: self.register_buffer('running_mean', torch.zeros(num_features)) self.register_buffer('running_var', torch.ones(num_features)) self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long)) else: self.register_buffer('running_mean', None) self.register_buffer('running_var', None) self.register_buffer('num_batches_tracked', None) self.reset_parameters() def reset_running_stats(self): if self.track_running_stats: self.running_mean.zero_() self.running_var.fill_(1) self.num_batches_tracked.zero_() def reset_parameters(self): self.reset_running_stats() if self.affine: self.weight.data.uniform_() # pytorch use ones_() self.bias.data.zero_() def forward(self, input): if input.dim() < 2: raise ValueError( f'expected at least 2D input, got {input.dim()}D input') if self.momentum is None: exponential_average_factor = 0.0 else: exponential_average_factor = self.momentum if self.training and self.track_running_stats: if self.num_batches_tracked is not None: self.num_batches_tracked += 1 if self.momentum is None: # use cumulative moving average exponential_average_factor = 1.0 / float( self.num_batches_tracked) else: # use exponential moving average exponential_average_factor = self.momentum if self.training or not self.track_running_stats: return SyncBatchNormFunction.apply( input, self.running_mean, self.running_var, self.weight, self.bias, exponential_average_factor, self.eps, self.group, self.group_size, self.stats_mode) else: return F.batch_norm(input, self.running_mean, self.running_var, self.weight, self.bias, False, exponential_average_factor, self.eps) def __repr__(self): s = self.__class__.__name__ s += f'({self.num_features}, ' s += f'eps={self.eps}, ' s += f'momentum={self.momentum}, ' s += f'affine={self.affine}, ' s += f'track_running_stats={self.track_running_stats}, ' s += f'group_size={self.group_size},' s += f'stats_mode={self.stats_mode})' return s