import torch
import torch.nn as nn
import torch.nn.functional as functional

try:
    from queue import Queue
except ImportError:
    from Queue import Queue

from .functions import *


class ABN(nn.Module):
    """Activated Batch Normalization

    This gathers a `BatchNorm2d` and an activation function in a single module
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `relu`, `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(ABN, self).__init__()
        self.num_features = num_features
        self.affine = affine
        self.eps = eps
        self.momentum = momentum
        self.activation = activation
        self.slope = slope
        if self.affine:
            self.weight = nn.Parameter(torch.ones(num_features))
            self.bias = nn.Parameter(torch.zeros(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.constant_(self.running_mean, 0)
        nn.init.constant_(self.running_var, 1)
        if self.affine:
            nn.init.constant_(self.weight, 1)
            nn.init.constant_(self.bias, 0)

    def forward(self, x):
        # Standard batch normalization followed by the configured activation,
        # applied in place on the normalized output.
        x = functional.batch_norm(x, self.running_mean, self.running_var, self.weight, self.bias,
                                  self.training, self.momentum, self.eps)

        # ACT_RELU, ACT_LEAKY_RELU and ACT_ELU are string constants imported from .functions
        if self.activation == ACT_RELU:
            return functional.relu(x, inplace=True)
        elif self.activation == ACT_LEAKY_RELU:
            return functional.leaky_relu(x, negative_slope=self.slope, inplace=True)
        elif self.activation == ACT_ELU:
            return functional.elu(x, inplace=True)
        else:
            return x

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)


class InPlaceABN(ABN):
    """InPlace Activated Batch Normalization"""

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, activation="leaky_relu", slope=0.01):
        """Creates an InPlace Activated Batch Normalization module

        Parameters
        ----------
        num_features : int
            Number of feature channels in the input and output.
        eps : float
            Small constant to prevent numerical issues.
        momentum : float
            Momentum factor applied to compute running statistics.
        affine : bool
            If `True` apply learned scale and shift transformation after normalization.
        activation : str
            Name of the activation function, one of: `leaky_relu`, `elu` or `none`.
        slope : float
            Negative slope for the `leaky_relu` activation.
        """
        super(InPlaceABN, self).__init__(num_features, eps, momentum, affine, activation, slope)

    def forward(self, x):
        # Fused in-place batch norm + activation provided by .functions
        x, _, _ = inplace_abn(x, self.weight, self.bias, self.running_mean, self.running_var,
                              self.training, self.momentum, self.eps, self.activation, self.slope)
        return x


class InPlaceABNSync(ABN):
    """InPlace Activated Batch Normalization with cross-GPU synchronization

    This assumes that it will be replicated across GPUs using the same mechanism as in `nn.DistributedDataParallel`.
    """

    def forward(self, x):
        # Same fused op as InPlaceABN, but batch statistics are synchronized across GPUs
        x, _, _ = inplace_abn_sync(x, self.weight, self.bias, self.running_mean, self.running_var,
                                   self.training, self.momentum, self.eps, self.activation, self.slope)
        return x

    def __repr__(self):
        rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \
              ' affine={affine}, activation={activation}'
        if self.activation == "leaky_relu":
            rep += ', slope={slope})'
        else:
            rep += ')'
        return rep.format(name=self.__class__.__name__, **self.__dict__)