OLMo-Bitnet-1B / initialization.py
emozilla's picture
update inference code
2010c83
import math
from typing import Optional, Union
import torch
import torch.nn as nn
from .config import InitFnType, ModelConfig
from .util import StrEnum
__all__ = ["init_weights", "ModuleType"]
class ModuleType(StrEnum):
    """
    Role of a submodule within the model, passed to ``init_weights`` as ``type_of_module``.

    The ``full_megatron`` init in ``init_weights`` chooses the weight standard deviation
    based on this role; the other init methods in that function do not read it.
    """

    # Input projections: att_proj (same as QKV) and ff_proj.
    in_module = "in"
    # Output projections: attn_out and ff_out.
    out_module = "out"
    # Positional embeddings (wpe) and token embeddings (wte).
    emb = "emb"
    # The final output projection (ff_out).
    final_out = "final_out"
def init_weights(
    config: ModelConfig,
    module: Union[nn.Linear, nn.Embedding],
    d: Optional[int] = None,
    layer_id: Optional[int] = None,
    std_factor: float = 1.0,
    type_of_module: Optional[ModuleType] = None,
) -> None:
    """
    Initialize, in place, the weights of a linear or embedding module using the scheme
    selected by ``config.init_fn``.

    :param config: The model config. ``init_fn`` selects the scheme; ``init_std``,
        ``init_cutoff_factor``, ``d_model`` and ``n_layers`` are read depending on the scheme.
    :param module: The linear or embedding submodule to initialize.
    :param d: The effective input dimensionality of the weights. This could be smaller than the
        actual dimensions for fused layers. Defaults to ``config.d_model``. Only the
        ``mitchell`` and ``fan_in`` schemes use it.
    :param layer_id: When set, the standard deviation for the "mitchell" method will be divided
        by ``sqrt(2 * (layer_id + 1))``.
    :param std_factor: Multiplier applied to the standard deviation by the ``normal``,
        ``mitchell`` and ``fan_in`` schemes. The other schemes ignore it.
    :param type_of_module: Role of the module. Required by the ``full_megatron`` scheme, which
        is the only scheme that reads it.
    :raises RuntimeError: If the scheme is ``full_megatron`` and ``type_of_module`` is ``None``
        or not a known module type.
    :raises NotImplementedError: If ``config.init_fn`` is not one of the handled schemes.
    """
    width = config.d_model if d is None else d
    scheme = config.init_fn

    if scheme == InitFnType.normal:
        sigma = config.init_std * std_factor
        cutoff = config.init_cutoff_factor
        if cutoff is None:
            nn.init.normal_(module.weight, mean=0.0, std=sigma)
        else:
            # Truncate symmetrically at ``cutoff`` standard deviations.
            bound = cutoff * sigma
            nn.init.trunc_normal_(module.weight, mean=0.0, std=sigma, a=-bound, b=bound)
    elif scheme == InitFnType.mitchell:
        sigma = std_factor / math.sqrt(width)
        if layer_id is not None:
            sigma = sigma / math.sqrt(2 * (layer_id + 1))
        # Always truncated at three standard deviations.
        bound = 3 * sigma
        nn.init.trunc_normal_(module.weight, mean=0.0, std=sigma, a=-bound, b=bound)
    elif scheme == InitFnType.kaiming_normal:
        nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
    elif scheme == InitFnType.fan_in:
        nn.init.normal_(module.weight, mean=0.0, std=std_factor / math.sqrt(width))
    elif scheme == InitFnType.full_megatron:
        if type_of_module is None:
            raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
        # Fall back to a cutoff of three standard deviations when the config sets none.
        cutoff = config.init_cutoff_factor
        if cutoff is None:
            cutoff = 3
        sigma = _full_megatron_std(config, type_of_module)
        bound = cutoff * sigma
        nn.init.trunc_normal_(
            module.weight,
            mean=0.0,
            std=sigma,
            a=-bound,
            b=bound,
        )
    else:
        raise NotImplementedError(scheme)

    # Everything below applies to linear layers only.
    if not isinstance(module, nn.Linear):
        return

    if module.bias is not None:
        nn.init.zeros_(module.bias)

    # With the plain "normal" scheme, linear layers flagged ``_is_residual`` are additionally
    # scaled down by sqrt(2 * n_layers).
    if scheme == InitFnType.normal and getattr(module, "_is_residual", False):
        with torch.no_grad():
            module.weight.div_(math.sqrt(2 * config.n_layers))


def _full_megatron_std(config: ModelConfig, type_of_module: ModuleType) -> float:
    """Return the weight standard deviation the ``full_megatron`` init uses for ``type_of_module``."""
    if type_of_module == ModuleType.in_module:
        # for att_proj (same as QKV), ff_proj
        return config.init_std
    if type_of_module == ModuleType.out_module:
        # for attn_out, ff_out
        return config.init_std / math.sqrt(2.0 * config.n_layers)
    if type_of_module == ModuleType.emb:
        # positional embeddings (wpe) and token embeddings (wte)
        return config.init_std
    if type_of_module == ModuleType.final_out:
        # final output (ff_out)
        return config.d_model**-0.5
    raise RuntimeError(f"Unknown module type '{type_of_module}'")