|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import collections |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
from torch.nn.parallel._functions import Broadcast, ReduceAddCoalesced |
|
|
|
from .comm import SyncMaster |
|
|
|
# Public API of this module: the three synchronized batch-norm layers.
__all__ = [
    'SynchronizedBatchNorm1d',
    'SynchronizedBatchNorm2d',
    'SynchronizedBatchNorm3d',
]
|
|
|
|
|
def _sum_ft(tensor):
    """Sum *tensor* over its first and its last dimension.

    For a 3-D tensor of shape ``(A, B, C)`` the result has shape ``(B,)``.
    """
    over_first = torch.sum(tensor, dim=0)
    return torch.sum(over_first, dim=-1)
|
|
|
|
|
def _unsqueeze_ft(tensor):
    """Insert a size-1 dimension at the front and another at the tail of *tensor*.

    A tensor of shape ``(C,)`` becomes ``(1, C, 1)``, ready to broadcast
    against an ``(N, C, L)`` tensor.
    """
    with_front = torch.unsqueeze(tensor, 0)
    return torch.unsqueeze(with_front, -1)
|
|
|
|
|
# Message each replica sends to the master copy: per-channel sum, per-channel
# sum of squares, and the number of elements that went into each sum.
_ChildMessage = collections.namedtuple('_ChildMessage', ('sum', 'ssum', 'sum_size'))

# Reply from the master copy to every replica. Despite its name, the field
# ``sum`` carries the global per-channel *mean*; ``inv_std`` is the inverse
# standard deviation.
_MasterMessage = collections.namedtuple('_MasterMessage', ('sum', 'inv_std'))
|
|
|
|
|
class _SynchronizedBatchNorm(_BatchNorm):
    """Batch normalization whose training statistics are reduced across replicas.

    Outside a parallel training step (single device, CPU, or evaluation mode)
    the layer calls ``F.batch_norm`` directly. Inside one, every replica
    computes per-channel partial sums. Replica 0 reduces them into a global
    mean and inverse standard deviation, updates the running statistics, and
    broadcasts the result back to all replicas.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        # The sync master is handed this instance's ``_data_parallel_master``.
        # NOTE(review): replicas presumably share this object (shallow copy),
        # so the reduction runs on the original module — confirm in .comm.
        self._sync_master = SyncMaster(self._data_parallel_master)

        # Filled in per replica by ``__data_parallel_replicate__``.
        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

        # Running-statistics state used by the synchronized path. The
        # invariant is running_* == _tmp_running_* / _running_iter. Both the
        # numerator and the denominator decay by ``1 - momentum`` per step.
        self._moving_average_fraction = 1. - momentum
        self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
        self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
        self.register_buffer('_running_iter', torch.ones(1))
        self._tmp_running_mean = self.running_mean.clone() * self._running_iter
        self._tmp_running_var = self.running_var.clone() * self._running_iter

    def _check_input_dim(self, input):
        """Validate that *input* is at least 2-D and has ``num_features`` channels.

        Subclasses add their own rank check and then delegate here through
        ``super()``. This intentionally does not call ``_BatchNorm``'s
        version, which is abstract (raises ``NotImplementedError``) in recent
        PyTorch releases.

        Raises:
            ValueError: on a rank below 2 or a channel-count mismatch.
        """
        if input.dim() < 2:
            raise ValueError('expected at least 2D input (got {}D input)'
                             .format(input.dim()))
        if input.size(1) != self.num_features:
            raise ValueError('got {}-feature tensor, expected {}'
                             .format(input.size(1), self.num_features))

    def forward(self, input):
        """Normalize *input*, synchronizing batch statistics across replicas
        during a ``DataParallel`` training step.

        Args:
            input: tensor of shape ``(N, C, *)`` with ``C == num_features``.

        Returns:
            A tensor with the same shape as *input*.

        Raises:
            ValueError: if *input* fails ``_check_input_dim``.
        """
        # Validate up front: the parallel path below reshapes with ``view``,
        # which would silently regroup an input with the wrong channel count.
        self._check_input_dim(input)

        # Not a parallel training step: use the built-in implementation.
        # NOTE: this path updates running_mean/var in place via F.batch_norm
        # but leaves _tmp_running_* and _running_iter untouched.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Flatten everything after the channel axis: (N, C, *) -> (N, C, L).
        # ``contiguous`` makes ``view`` safe for e.g. permuted inputs.
        input_shape = input.size()
        input = input.contiguous().view(input.size(0), self.num_features, -1)

        # Per-channel partial statistics of this replica's shard.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Replica 0 performs the global reduction; the others send their
        # partial sums through their pipe and receive the result.
        message = _ChildMessage(input_sum, input_ssum, sum_size)
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(message)
        else:
            mean, inv_std = self._slave_pipe.run_slave(message)

        if self.affine:
            # Fold the scale into inv_std: one broadcasted multiply per element.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        """Replication hook: mark this copy as parallel and connect it to the master.

        NOTE(review): presumably invoked once per replica by the project's
        replication callback, with a *ctx* shared by all copies of this module
        and copy 0 handled first (other copies read ``ctx.sync_master``).
        Confirm against the replicate wrapper.
        """
        self._is_parallel = True
        self._parallel_id = copy_id

        # Copy 0 publishes the sync master on the shared context; every other
        # copy registers with it and keeps the returned pipe.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sums and square-sums of all replicas, compute the global
        statistics, and broadcast them back.

        Args:
            intermediates: sequence of pairs whose second item is a
                ``_ChildMessage``. The first item is passed through unchanged
                (presumably a replica identifier assigned by the sync master).

        Returns:
            A list of ``(first_item, _MasterMessage)`` pairs, ordered by the
            device index of each replica's message.
        """
        # Order by device so that target_gpus[0] is the lowest device index.
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        # Flatten to [sum_0, ssum_0, sum_1, ssum_1, ...] for the coalesced reduce.
        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum(i[1].sum_size for i in intermediates)
        # 2 = number of tensors contributed by each device (sum and ssum).
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)

        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        # The slicing below expects outputs grouped per device:
        # [mean@gpu0, inv_std@gpu0, mean@gpu1, inv_std@gpu1, ...].
        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
        """Return ``dest * alpha + delta * beta + bias`` as a new value (*dest* is not modified)."""
        return dest * alpha + delta * beta + bias

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and inverse standard deviation from sum and square-sum.

        Also updates the running mean/variance of this module.

        Args:
            sum_: per-channel sum of the inputs over all replicas.
            ssum: per-channel sum of the squared inputs over all replicas.
            size: number of elements contributing to each channel's sums.

        Returns:
            ``(mean, inv_std)`` where ``inv_std = 1 / sqrt(biased_var + eps)``,
            the same normalization formula used by ``F.batch_norm``.

        Raises:
            ValueError: if ``size <= 1``. The unbiased variance used for the
                running estimate divides by ``size - 1``.
        """
        # Validate before touching any running state. Use an explicit raise
        # rather than ``assert``, which disappears under ``python -O``.
        if size <= 1:
            raise ValueError('BatchNorm computes unbiased standard-deviation, which requires size > 1.')

        mean = sum_ / size
        # sum((x - mean)^2) == sum(x^2) - sum(x) * mean
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        # Normalized exponential moving average. The weighted sum and the
        # total weight both decay by (1 - momentum), so running = tmp / iter.
        self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
        self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
        self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)

        self.running_mean = self._tmp_running_mean / self._running_iter
        self.running_var = self._tmp_running_var / self._running_iter

        # Add eps instead of clamping to it. This matches the documented
        # formula, single-device training and evaluation (all via
        # F.batch_norm). The clamp at 0 guards against a slightly negative
        # variance caused by cancellation in ``ssum - sum_ * mean``.
        return mean, (bias_var.clamp(min=0) + self.eps) ** -0.5
|
|
|
|
|
class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation with the same ``momentum``.
    The default momentum here is 0.001, not PyTorch's 0.1.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.001. On a single device
    the update is the one performed by ``F.batch_norm``. When statistics are
    synchronized across devices, the estimate is a normalized exponentially
    weighted average with decay factor ``1 - momentum``.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.001
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Reject inputs that are not 2-D or 3-D, then run the parent class's checks.

        Raises:
            ValueError: if ``input.dim()`` is neither 2 nor 3.
        """
        if input.dim() not in (2, 3):
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        # NOTE(review): on recent PyTorch, _BatchNorm._check_input_dim is
        # abstract, so the parent chain must supply a concrete override.
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
|
|
|
|
|
class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 4d input that is seen as a
    mini-batch of 3d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation with the same ``momentum``.
    The default momentum here is 0.001, not PyTorch's 0.1.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.001. On a single device
    the update is the one performed by ``F.batch_norm``. When statistics are
    synchronized across devices, the estimate is a normalized exponentially
    weighted average with decay factor ``1 - momentum``.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.001
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Reject inputs that are not 4-D, then run the parent class's checks.

        Raises:
            ValueError: if ``input.dim()`` is not 4.
        """
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        # NOTE(review): on recent PyTorch, _BatchNorm._check_input_dim is
        # abstract, so the parent chain must supply a concrete override.
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
|
|
|
|
|
class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 5d input that is seen as a
    mini-batch of 4d inputs.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    only the statistics of that device. This speeds up the computation and is
    easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics are computed
    over all training samples distributed across multiple devices.

    Note that, in the one-GPU or CPU-only case, this module behaves exactly the
    same as the built-in PyTorch implementation with the same ``momentum``.
    The default momentum here is 0.001, not PyTorch's 0.1.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches, and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance, updated with a default momentum of 0.001. On a single device
    the update is the one performed by ``F.batch_norm``. When statistics are
    synchronized across devices, the estimate is a normalized exponentially
    weighted average with decay factor ``1 - momentum``.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric
    BatchNorm or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.001
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        """Reject inputs that are not 5-D, then run the parent class's checks.

        Raises:
            ValueError: if ``input.dim()`` is not 5.
        """
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        # NOTE(review): on recent PyTorch, _BatchNorm._check_input_dim is
        # abstract, so the parent chain must supply a concrete override.
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
|