# NEOX / megatron/model/init_functions.py
# Uploaded by akswelh ("Upload 251 files", commit d90b3a8, verified)
# Copyright (c) 2024, EleutherAI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math

import torch

try:
    import mup
except ImportError:
    pass
def init_method_normal(sigma, use_mup_outer=False, mup_init_scale=1.0):
    """Build an initializer that samples from a zero-mean normal with std ``sigma``.

    Args:
        sigma: standard deviation of the normal distribution.
        use_mup_outer: default value of the returned closure's ``use_mup`` flag.
        mup_init_scale: multiplier applied to the sampled weights, on the mup
            path only.

    Returns:
        ``init_(tensor, use_mup=use_mup_outer)``, which fills ``tensor`` in
        place and returns it.
    """

    def init_(tensor, use_mup=use_mup_outer):
        if not use_mup:
            return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)
        # mup-aware sampling, then rescale without recording autograd history.
        mup.init.normal_(tensor, mean=0.0, std=sigma)
        with torch.no_grad():
            tensor.mul_(mup_init_scale)
        return tensor

    return init_
def scaled_init_method_normal(
    sigma,
    num_layers,
    use_mup_outer=False,
    mup_init_scale=1.0,
    num_residuals_per_layer=2,
):
    """Build a depth-scaled normal initializer.

    Samples from N(0, std^2) with
    ``std = sigma / sqrt(num_residuals_per_layer * num_layers)``. The default
    of two residuals per layer gives N(0, sigma / sqrt(2 * num_layers)); pass
    ``num_residuals_per_layer=1`` for blocks with a single residual per layer
    (e.g. Mamba).

    Args:
        sigma: base standard deviation before depth scaling.
        num_layers: number of layers in the model.
        use_mup_outer: default value of the returned closure's ``use_mup`` flag.
        mup_init_scale: multiplier applied to the sampled weights, on the mup
            path only.
        num_residuals_per_layer: residual branches per layer.

    Returns:
        ``init_(tensor, use_mup=use_mup_outer)``, which fills ``tensor`` in
        place and returns it.
    """
    residual_count = num_residuals_per_layer * num_layers
    scaled_std = sigma / math.sqrt(residual_count)

    def init_(tensor, use_mup=use_mup_outer):
        if not use_mup:
            return torch.nn.init.normal_(tensor, mean=0.0, std=scaled_std)
        mup.init.normal_(tensor, mean=0.0, std=scaled_std)
        # Extra mup scaling must not be tracked by autograd.
        with torch.no_grad():
            tensor.mul_(mup_init_scale)
        return tensor

    return init_
# orthogonal init does not support fp16, so have to patch it
def _orthogonal(tensor, gain=1):
if tensor.ndimension() < 2:
raise ValueError("Only tensors with 2 or more dimensions are supported")
rows = tensor.size(0)
cols = tensor.numel() // rows
flattened = tensor.new(rows, cols).normal_(0, 1)
if rows < cols:
flattened.t_()
# Compute the qr factorization
dt = flattened.dtype
flattened = flattened.to(torch.float32) # orthogonal init does not support fp16
q, r = torch.qr(flattened)
q, r = q.to(dtype=dt), r.to(dtype=dt)
# Make Q uniform according to https://arxiv.org/pdf/math-ph/0609050.pdf
d = torch.diag(r, 0)
ph = d.sign()
q *= ph
if rows < cols:
q.t_()
with torch.no_grad():
tensor.view_as(q).copy_(q)
tensor.mul_(gain)
return tensor
def orthogonal_init_method(n_layers=1, use_mup=False, mup_init_scale=1.0):
    """Return an initializer that fills a tensor with a (semi-)orthogonal matrix.

    Follows "Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks" - Saxe, A. et al. (2013). The matrix is scaled by
    a gain of ``sqrt(2 / n_layers)``. Scaling by the number of layers was
    introduced in OBST - Nestler et al. (2021, to be released).

    Args:
        n_layers: number of layers used in the gain; must be positive. The
            default of 1 gives a gain of ``sqrt(2)``.
        use_mup: must be falsy; mup is not supported by this init.
        mup_init_scale: accepted for signature parity with the other init
            factories; unused.

    Returns:
        ``init_(tensor)``, which fills ``tensor`` in place via ``_orthogonal``
        and returns it.

    Raises:
        ValueError: if ``use_mup`` is truthy or ``n_layers`` is not positive.
    """
    if use_mup:
        raise ValueError(
            "Orthogonal init needs to be patched to support mup. Disable mup or use a different init method to avoid this error"
        )
    # Validate eagerly: a non-positive layer count would otherwise surface only
    # as a ZeroDivisionError / math domain error the first time init_ runs.
    if n_layers <= 0:
        raise ValueError(f"n_layers must be positive, got {n_layers!r}")
    gain = math.sqrt(2 / n_layers)

    def init_(tensor):
        return _orthogonal(tensor, gain)

    return init_
def xavier_uniform_init_method(use_mup_outer=False, mup_init_scale=1.0):
    """Build a Glorot/Xavier initializer using a uniform distribution.

    Reference: Understanding the difficulty of training deep feedforward
    neural networks - Glorot, X. & Bengio, Y. (2010).

    Args:
        use_mup_outer: default value of the returned closure's ``use_mup`` flag.
        mup_init_scale: multiplier applied to the sampled weights, on the mup
            path only.

    Returns:
        ``init_(tensor, use_mup=use_mup_outer)``, which fills ``tensor`` in
        place and returns it.
    """

    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.xavier_uniform_(tensor)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            result = tensor
        else:
            result = torch.nn.init.xavier_uniform_(tensor)
        return result

    return init_
def xavier_normal_init_method(use_mup_outer=False, mup_init_scale=1.0):
    """Build a Glorot/Xavier initializer using a normal distribution.

    Reference: Understanding the difficulty of training deep feedforward
    neural networks - Glorot, X. & Bengio, Y. (2010).

    Args:
        use_mup_outer: default value of the returned closure's ``use_mup`` flag.
        mup_init_scale: multiplier applied to the sampled weights, on the mup
            path only.

    Returns:
        ``init_(tensor, use_mup=use_mup_outer)``, which fills ``tensor`` in
        place and returns it.
    """

    def init_(tensor, use_mup=use_mup_outer):
        if not use_mup:
            return torch.nn.init.xavier_normal_(tensor)
        mup.init.xavier_normal_(tensor)
        # Post-hoc mup rescale, kept out of the autograd graph.
        with torch.no_grad():
            tensor.mul_(mup_init_scale)
        return tensor

    return init_
def small_init_init_method(dim, use_mup_outer=False, mup_init_scale=1.0):
    """Build the "small init" normal initializer.

    Reference: Transformers without Tears: Improving the Normalization of
    Self-Attention - Nguyen, T. & Salazar, J. (2019). Samples from N(0, std^2)
    with ``std = sqrt(2 / (5 * dim))``.

    Args:
        dim: model dimension used to compute the standard deviation.
        use_mup_outer: default value of the returned closure's ``use_mup`` flag.
        mup_init_scale: multiplier applied to the sampled weights, on the mup
            path only.

    Returns:
        ``init_(tensor, use_mup=use_mup_outer)``, which fills ``tensor`` in
        place and returns it.
    """
    small_std = math.sqrt(2 / (5 * dim))

    def init_(tensor, use_mup=use_mup_outer):
        if use_mup:
            mup.init.normal_(tensor, mean=0.0, std=small_std)
            with torch.no_grad():
                tensor.mul_(mup_init_scale)
            return tensor
        return torch.nn.init.normal_(tensor, mean=0.0, std=small_std)

    return init_
def wang_init_method(n_layers, dim, use_mup_outer=False, mup_init_scale=1.0):
    """Build a normal initializer with ``std = 2 / n_layers / sqrt(dim)``.

    Args:
        n_layers: number of layers in the model.
        dim: model dimension.
        use_mup_outer: default value of the returned closure's ``use_mup`` flag.
        mup_init_scale: multiplier applied to the sampled weights, on the mup
            path only.

    Returns:
        ``init_(tensor, use_mup=use_mup_outer)``, which fills ``tensor`` in
        place and returns it.
    """
    wang_std = 2 / n_layers / math.sqrt(dim)

    def init_(tensor, use_mup=use_mup_outer):
        if not use_mup:
            return torch.nn.init.normal_(tensor, mean=0.0, std=wang_std)
        mup.init.normal_(tensor, mean=0.0, std=wang_std)
        with torch.no_grad():
            tensor.mul_(mup_init_scale)
        return tensor

    return init_
def get_init_methods(args):
    """Resolve the configured init method names into initializer callables.

    Args:
        args: config object. Always reads ``use_mup``, ``init_method`` and
            ``output_layer_init_method``. For each recognized method name it
            also reads ``mup_init_scale`` and, depending on the method,
            ``init_method_std``, ``num_layers`` and ``hidden_size``.

    Returns:
        ``(init_method, output_layer_init_method)``: the initializers for
        ``args.init_method`` and ``args.output_layer_init_method``.

    Raises:
        ImportError: if ``args.use_mup`` is set but the mup package is missing.
        NotImplementedError: if a method name is not recognized.
        ValueError: propagated from ``orthogonal_init_method`` when an
            orthogonal method is combined with ``use_mup``.
    """
    if args.use_mup:
        try:
            import mup  # noqa: F401 -- availability check only
        except ModuleNotFoundError as err:
            print("Please install mup https://github.com/microsoft/mup")
            raise ImportError(
                "use_mup is set but the mup package is not installed"
            ) from err

    def _get(name):
        if name == "normal":
            return init_method_normal(
                args.init_method_std, args.use_mup, args.mup_init_scale
            )
        elif name == "scaled_normal":
            return scaled_init_method_normal(
                args.init_method_std, args.num_layers, args.use_mup, args.mup_init_scale
            )
        elif name == "orthogonal":
            # Must pass by keyword: orthogonal_init_method's first positional
            # parameter is n_layers, so positional arguments would shift
            # use_mup into n_layers and mup_init_scale into use_mup.
            return orthogonal_init_method(
                use_mup=args.use_mup, mup_init_scale=args.mup_init_scale
            )
        elif name == "scaled_orthogonal":
            return orthogonal_init_method(
                n_layers=args.num_layers,
                use_mup=args.use_mup,
                mup_init_scale=args.mup_init_scale,
            )
        elif name == "xavier_uniform":
            return xavier_uniform_init_method(args.use_mup, args.mup_init_scale)
        elif name == "xavier_normal":
            return xavier_normal_init_method(args.use_mup, args.mup_init_scale)
        elif name == "wang_init":
            return wang_init_method(
                args.num_layers, args.hidden_size, args.use_mup, args.mup_init_scale
            )
        elif name == "small_init":
            return small_init_init_method(
                args.hidden_size, args.use_mup, args.mup_init_scale
            )
        elif name == "single_residual_scaled_normal":
            # mamba init uses scaled_normal but no need for 2 * num_layers
            # since only one residual per layer
            return scaled_init_method_normal(
                args.init_method_std,
                args.num_layers,
                args.use_mup,
                args.mup_init_scale,
                num_residuals_per_layer=1,
            )
        raise NotImplementedError(f"Unknown init method {name}")

    return _get(args.init_method), _get(args.output_layer_init_method)