# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

"""BatchNorm (BN) utility functions and custom batch-size BN implementations."""

from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.autograd.function import Function

import timesformer.utils.distributed as du


def get_norm(cfg):
    """
    Args:
        cfg (CfgNode): model building configs, details are in the
            comments of the config file.
    Returns:
        nn.Module: the normalization layer.
    """
    if cfg.BN.NORM_TYPE == "batchnorm":
        return nn.BatchNorm3d
    elif cfg.BN.NORM_TYPE == "sub_batchnorm":
        return partial(SubBatchNorm3d, num_splits=cfg.BN.NUM_SPLITS)
    elif cfg.BN.NORM_TYPE == "sync_batchnorm":
        return partial(
            NaiveSyncBatchNorm3d, num_sync_devices=cfg.BN.NUM_SYNC_DEVICES
        )
    else:
        raise NotImplementedError(
            "Norm type {} is not supported".format(cfg.BN.NORM_TYPE)
        )


class SubBatchNorm3d(nn.Module):
    """
    The standard BN layer computes stats across all examples in a GPU. In some
    cases it is desirable to compute stats across only a subset of examples
    (e.g., in multigrid training, https://arxiv.org/abs/1912.00998).
    SubBatchNorm3d splits the batch dimension into N splits and runs BN on
    each split separately, so that the stats are computed independently on
    each subset of examples (1/N of the batch). During evaluation, it
    aggregates the stats from all splits into one BN.
    """

    def __init__(self, num_splits, **args):
        """
        Args:
            num_splits (int): number of splits.
            args (dict): other arguments, forwarded to nn.BatchNorm3d.
        """
        super(SubBatchNorm3d, self).__init__()
        self.num_splits = num_splits
        num_features = args["num_features"]
        # Keep only one shared set of weight and bias; the inner BN layers
        # are created without affine parameters.
        if args.get("affine", True):
            self.affine = True
            args["affine"] = False
            self.weight = torch.nn.Parameter(torch.ones(num_features))
            self.bias = torch.nn.Parameter(torch.zeros(num_features))
        else:
            self.affine = False
        self.bn = nn.BatchNorm3d(**args)
        args["num_features"] = num_features * num_splits
        self.split_bn = nn.BatchNorm3d(**args)

    def _get_aggregated_mean_std(self, means, stds, n):
        """
        Calculate the aggregated mean and variance (law of total variance).
        Note that despite the name, `stds` holds per-split variances,
        matching the semantics of running_var.
        Args:
            means (tensor): per-split mean values.
            stds (tensor): per-split variances.
            n (int): number of sets of means and variances.
        """
        mean = means.view(n, -1).sum(0) / n
        std = (
            stds.view(n, -1).sum(0) / n
            + ((means.view(n, -1) - mean) ** 2).sum(0) / n
        )
        return mean.detach(), std.detach()

    def aggregate_stats(self):
        """
        Synchronize running_mean and running_var. Call this before eval.
        """
        if self.split_bn.track_running_stats:
            (
                self.bn.running_mean.data,
                self.bn.running_var.data,
            ) = self._get_aggregated_mean_std(
                self.split_bn.running_mean,
                self.split_bn.running_var,
                self.num_splits,
            )

    def forward(self, x):
        if self.training:
            # Fold the splits into the channel dimension so that split_bn
            # normalizes each split of the batch independently.
            n, c, t, h, w = x.shape
            x = x.view(n // self.num_splits, c * self.num_splits, t, h, w)
            x = self.split_bn(x)
            x = x.view(n, c, t, h, w)
        else:
            x = self.bn(x)
        if self.affine:
            x = x * self.weight.view((-1, 1, 1, 1))
            x = x + self.bias.view((-1, 1, 1, 1))
        return x
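
# Illustrative usage sketch (not part of the original module): a typical
# SubBatchNorm3d train/eval cycle. The helper name and tensor shapes are
# hypothetical; the batch size must be divisible by num_splits.
def _sub_batchnorm3d_usage_sketch():
    bn = SubBatchNorm3d(num_splits=2, num_features=16)
    x = torch.randn(8, 16, 4, 32, 32)  # (N, C, T, H, W)
    bn.train()
    y = bn(x)  # stats computed independently on each split of 4 examples
    bn.aggregate_stats()  # fold the per-split stats into self.bn
    bn.eval()
    return y, bn(x)  # eval path uses the aggregated running stats
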
""" ctx.num_sync_devices = num_sync_devices ctx.num_groups = num_groups input_list = [ torch.zeros_like(input) for k in range(du.get_local_size()) ] dist.all_gather( input_list, input, async_op=False, group=du._LOCAL_PROCESS_GROUP ) inputs = torch.stack(input_list, dim=0) if num_groups > 1: rank = du.get_local_rank() group_idx = rank // num_sync_devices inputs = inputs[ group_idx * num_sync_devices : (group_idx + 1) * num_sync_devices ] inputs = torch.sum(inputs, dim=0) return inputs @staticmethod def backward(ctx, grad_output): """ Perform backwarding, gathering the gradients across different process/ GPU group. """ grad_output_list = [ torch.zeros_like(grad_output) for k in range(du.get_local_size()) ] dist.all_gather( grad_output_list, grad_output, async_op=False, group=du._LOCAL_PROCESS_GROUP, ) grads = torch.stack(grad_output_list, dim=0) if ctx.num_groups > 1: rank = du.get_local_rank() group_idx = rank // ctx.num_sync_devices grads = grads[ group_idx * ctx.num_sync_devices : (group_idx + 1) * ctx.num_sync_devices ] grads = torch.sum(grads, dim=0) return grads, None, None class NaiveSyncBatchNorm3d(nn.BatchNorm3d): def __init__(self, num_sync_devices, **args): """ Naive version of Synchronized 3D BatchNorm. Args: num_sync_devices (int): number of device to sync. args (list): other arguments. """ self.num_sync_devices = num_sync_devices if self.num_sync_devices > 0: assert du.get_local_size() % self.num_sync_devices == 0, ( du.get_local_size(), self.num_sync_devices, ) self.num_groups = du.get_local_size() // self.num_sync_devices else: self.num_sync_devices = du.get_local_size() self.num_groups = 1 super(NaiveSyncBatchNorm3d, self).__init__(**args) def forward(self, input): if du.get_local_size() == 1 or not self.training: return super().forward(input) assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs" C = input.shape[1] mean = torch.mean(input, dim=[0, 2, 3, 4]) meansqr = torch.mean(input * input, dim=[0, 2, 3, 4]) vec = torch.cat([mean, meansqr], dim=0) vec = GroupGather.apply(vec, self.num_sync_devices, self.num_groups) * ( 1.0 / self.num_sync_devices ) mean, meansqr = torch.split(vec, C) var = meansqr - mean * mean self.running_mean += self.momentum * (mean.detach() - self.running_mean) self.running_var += self.momentum * (var.detach() - self.running_var) invstd = torch.rsqrt(var + self.eps) scale = self.weight * invstd bias = self.bias - mean * scale scale = scale.reshape(1, -1, 1, 1, 1) bias = bias.reshape(1, -1, 1, 1, 1) return input * scale + bias