Spaces:
Running
Running
"""This file contains utilities for initializing neural network parameters.""" | |
import math | |
import warnings | |
from torch import Tensor | |
import torch | |
from typing import Optional as _Optional | |
# These no_grad_* functions are necessary as wrappers around the parts of these | |
# functions that use `with torch.no_grad()`. The JIT doesn't support context | |
# managers, so these need to be implemented as builtins. Using these wrappers | |
# lets us keep those builtins small and re-usable. | |
def _no_grad_uniform_(tensor, a, b, generator=None):
    """Fill ``tensor`` in place from U(a, b) without recording autograd history.

    Returns the tensor produced by the in-place ``uniform_`` call (``tensor`` itself).
    """
    with torch.no_grad():
        filled = tensor.uniform_(a, b, generator=generator)
    return filled
def _no_grad_normal_(tensor, mean, std, generator=None):
    """Fill ``tensor`` in place from N(mean, std^2) without recording autograd history.

    Returns the tensor produced by the in-place ``normal_`` call (``tensor`` itself).
    """
    with torch.no_grad():
        sampled = tensor.normal_(mean, std, generator=generator)
    return sampled
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
    """Fill ``tensor`` in place with N(mean, std^2) samples truncated to [a, b].

    Uses inverse-CDF sampling: draw uniformly between the (rescaled) normal CDF
    values of the two bounds, map back through ``erfinv``, then shift/scale to
    the requested mean and std. Method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    Args:
        tensor: tensor to fill in place.
        mean: mean of the underlying normal distribution.
        std: standard deviation of the underlying normal distribution.
        a: lower truncation bound.
        b: upper truncation bound.
        generator: optional torch Generator used for the uniform draw.

    Returns:
        ``tensor`` after filling.

    Warns if ``mean`` lies more than 2 std outside ``[a, b]``, where this
    method becomes inaccurate.
    """

    def _std_normal_cdf(z):
        # Standard normal CDF Phi(z) written via the error function.
        return (1. + math.erf(z / math.sqrt(2.))) / 2.

    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        cdf_low = _std_normal_cdf((a - mean) / std)
        cdf_high = _std_normal_cdf((b - mean) / std)
        # erf(x / sqrt(2)) = 2 * Phi(x) - 1, so sample that quantity uniformly ...
        tensor.uniform_(2 * cdf_low - 1, 2 * cdf_high - 1, generator=generator)
        # ... and invert it: erfinv(2p - 1) = Phi^{-1}(p) / sqrt(2).
        tensor.erfinv_()
        # Undo the 1/sqrt(2) and apply the requested scale and location.
        tensor.mul_(std * math.sqrt(2.))
        tensor.add_(mean)
        # Clamp so every value is guaranteed to lie inside [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
def _no_grad_fill_(tensor, val):
    """Set every element of ``tensor`` to ``val`` in place, outside autograd.

    Returns the tensor produced by the in-place ``fill_`` call (``tensor`` itself).
    """
    with torch.no_grad():
        result = tensor.fill_(val)
    return result
def _no_grad_zero_(tensor):
    """Zero ``tensor`` in place, outside autograd.

    Returns the tensor produced by the in-place ``zero_`` call (``tensor`` itself).
    """
    with torch.no_grad():
        result = tensor.zero_()
    return result
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function. Only used by
            ``'leaky_relu'``, where it is the negative slope (default 0.01);
            it must be an ``int`` (not ``bool``) or a ``float``.

    Returns:
        The gain as a Python number.

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or if ``param`` is
            not a valid number for ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    linear_like = ['linear', 'conv1d', 'conv2d', 'conv3d',
                   'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
    if nonlinearity in linear_like or nonlinearity == 'sigmoid':
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
            # bool is a subclass of int, so True/False must be rejected explicitly.
            negative_slope = param
        else:
            raise ValueError(f"negative_slope {param} not a valid number")
        return math.sqrt(2.0 / (1 + negative_slope ** 2))
    if nonlinearity == 'selu':
        # Empirically chosen value (https://github.com/pytorch/pytorch/pull/50664).
        return 3.0 / 4
    raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place (or whatever a ``__torch_function__``
        override returns).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_uniform_(tensor, a, b, generator)
    # Tensor-likes that override __torch_function__ handle the call themselves.
    return torch.overrides.handle_torch_function(
        uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
    )
def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place (or whatever a ``__torch_function__``
        override returns).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_normal_(tensor, mean, std, generator)
    # Tensor-likes that override __torch_function__ handle the call themselves.
    return torch.overrides.handle_torch_function(
        normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
    )
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.,
    std: float = 1.,
    a: float = -2.,
    b: float = 2.,
    generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`; a warning is emitted when
    ``mean`` is more than 2 std outside :math:`[a, b]`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    filled = _no_grad_trunc_normal_(tensor, mean, std, a, b, generator)
    return filled
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        ``tensor``, filled in place (or whatever a ``__torch_function__``
        override returns).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    if not torch.overrides.has_torch_function_variadic(tensor):
        return _no_grad_fill_(tensor, val)
    # Tensor-likes that override __torch_function__ handle the call themselves.
    return torch.overrides.handle_torch_function(
        constant_, (tensor,), tensor=tensor, val=val
    )
def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    fill_value = 1.
    return _no_grad_fill_(tensor, fill_value)
def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor``, zeroed in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    cleared = _no_grad_zero_(tensor)
    return cleared
def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        ``tensor``, overwritten in place with ones on the main diagonal and
        zeros elsewhere.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    n_rows, n_cols = tensor.shape
    with torch.no_grad():
        # Write the (possibly rectangular) identity straight into ``tensor``.
        torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
    return tensor
def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Returns:
        ``tensor``, zeroed and then set to 1 at the spatial centre for each
        matching (output, input) channel pair.

    Raises:
        ValueError: if ``tensor`` is not 3-, 4- or 5-dimensional, or if its
            first dimension is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    if tensor.ndimension() not in [3, 4, 5]:
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    sizes = tensor.size()
    if sizes[0] % groups != 0:
        raise ValueError('dim 0 must be divisible by groups')

    out_chans_per_grp = sizes[0] // groups
    min_dim = min(out_chans_per_grp, sizes[1])
    # Centre index along every trailing (temporal / spatial / volumetric) dim.
    centre = tuple(extent // 2 for extent in sizes[2:])

    with torch.no_grad():
        tensor.zero_()
        for group in range(groups):
            first_out = group * out_chans_per_grp
            for channel in range(min_dim):
                tensor[(first_out + channel, channel) + centre] = 1
    return tensor
def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor.

    ``fan_in`` is ``tensor.size(1)`` and ``fan_out`` is ``tensor.size(0)``,
    each multiplied by the product of all trailing dimensions (the
    receptive field size, 1 for a 2-D tensor).

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # Multiply manually: math.prod is not always available and
    # functools.reduce is not supported by TorchScript.
    receptive_field_size = 1
    for extent in tensor.shape[2:]:
        receptive_field_size *= extent

    fan_in = tensor.size(1) * receptive_field_size
    fan_out = tensor.size(0) * receptive_field_size
    return fan_in, fan_out
def xavier_uniform_(
    tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor` (at least 2-D, since fan-in
            and fan-out must be computable)
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    n_in, n_out = _calculate_fan_in_and_fan_out(tensor)
    denominator = float(n_in + n_out)
    std = gain * math.sqrt(2.0 / denominator)
    # U(-bound, bound) has standard deviation bound / sqrt(3).
    bound = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -bound, bound, generator)
def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor` (at least 2-D, since fan-in
            and fan-out must be computable)
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)
    """
    n_in, n_out = _calculate_fan_in_and_fan_out(tensor)
    denominator = float(n_in + n_out)
    std = gain * math.sqrt(2.0 / denominator)
    return _no_grad_normal_(tensor, 0., std, generator)
def _calculate_correct_fan(tensor, mode):
    """Return the fan selected by ``mode`` (case-insensitive ``'fan_in'``/``'fan_out'``).

    Raises:
        ValueError: if ``mode`` is neither ``'fan_in'`` nor ``'fan_out'``.
    """
    normalized = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if normalized not in valid_modes:
        raise ValueError(f"Mode {normalized} not supported, please use one of {valid_modes}")

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if normalized == 'fan_in':
        return fan_in
    return fan_out
def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place. A tensor with a zero-sized dimension is
        returned unchanged after a warning.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator,
        )

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    # U(-bound, bound) has standard deviation bound / sqrt(3).
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        tensor.uniform_(-bound, bound, generator=generator)
    return tensor
def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, filled in place. A tensor with a zero-sized dimension is
        returned unchanged after a warning.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
    """
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor

    fan = _calculate_correct_fan(tensor, mode)
    std = calculate_gain(nonlinearity, a) / math.sqrt(fan)
    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
    return tensor
def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor``, overwritten in place (unchanged if it has no elements).

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    if tensor.numel() == 0:
        return tensor  # nothing to fill

    num_rows = tensor.size(0)
    num_cols = tensor.numel() // num_rows
    # QR needs a tall matrix, so a wide one is factorised transposed.
    wide = num_rows < num_cols

    gaussian = tensor.new(num_rows, num_cols).normal_(0, 1, generator=generator)
    if wide:
        gaussian.t_()

    q, r = torch.linalg.qr(gaussian)
    # Fix the sign ambiguity of QR so Q is uniformly (Haar) distributed,
    # see https://arxiv.org/pdf/math-ph/0609050.pdf
    q *= torch.diag(r, 0).sign()

    if wide:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std`` defaults to 0.01), as
    described in `Deep learning via Hessian-free optimization` - Martens, J.
    (2010). Then, in every column, ``ceil(sparsity * rows)`` randomly chosen
    entries are set to zero.

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None). When
            given, it drives both the normal draw and the choice of zeroed
            positions, so the result is reproducible from the generator alone.

    Returns:
        ``tensor``, filled in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    # Previously randperm ignored ``generator`` and always used the global RNG,
    # so the zero pattern was not reproducible from the supplied generator.
    # randperm must run on the generator's device; with no generator keep the
    # original default (global CPU RNG) so existing seeded behaviour is unchanged.
    if generator is None:
        perm_kwargs = {}
    else:
        perm_kwargs = {"generator": generator, "device": generator.device}

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows, **perm_kwargs)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
# for backward compatibility | |
def _make_deprecate(meth):
    """Build a deprecated alias for the in-place initializer ``meth``.

    The alias is named after ``meth`` with its last character (the trailing
    ``_``) removed. Calling it emits a warning pointing at ``meth`` and then
    forwards all arguments to ``meth``, returning its result.
    """
    new_name = meth.__name__
    old_name = new_name[:-1]
    message = f"nn.init.{old_name} is now deprecated in favor of nn.init.{new_name}."

    def deprecated_init(*args, **kwargs):
        warnings.warn(message, stacklevel=2)
        return meth(*args, **kwargs)

    deprecated_init.__name__ = old_name
    deprecated_init.__doc__ = fr"""
    {old_name}(...)
    .. warning::
        This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`.
    See :func:`~torch.nn.init.{new_name}` for details."""
    return deprecated_init
# Legacy aliases without the trailing underscore (e.g. ``nn.init.uniform``).
# Each one, built by ``_make_deprecate``, emits a warning naming the
# in-place replacement and then forwards all arguments to it unchanged.
# Note: ``trunc_normal_``, ``ones_`` and ``zeros_`` have no legacy alias here.
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)