# open_lm/norms.py
import numbers
from functools import partial
from typing import Union, List
import torch
from torch import Tensor, Size
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class LayerNorm(nn.Module):
    # NOTE: taken from the official PyTorch implementation and modified
    # to allow removal of the gain and bias independently
def __init__(
self,
normalized_shape: Union[int, List[int], Size],
eps: float = 0.00001,
elementwise_gain: bool = True,
elementwise_bias: bool = True,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.elementwise_gain = elementwise_gain
self.elementwise_bias = elementwise_bias
if self.elementwise_gain:
self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
if self.elementwise_bias:
self.bias = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("bias", None)
self.reset_parameters()
def reset_parameters(self) -> None:
if self.elementwise_gain:
with torch.no_grad():
self.weight.fill_(1.0)
if self.elementwise_bias:
with torch.no_grad():
self.bias.zero_()
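    # Delegates to F.layer_norm; weight and/or bias may be None when the
    # corresponding elementwise term has been disabled.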
def forward(self, input: Tensor) -> Tensor:
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_gain={elementwise_gain}, "
"elementwise_bias={elementwise_bias}".format(**self.__dict__)
)
class LPLayerNorm(LayerNorm):
"""From MosaicML composer.
See: https://github.com/mosaicml/composer/blob/6acca4c70425455be7280a5459dbf02e1ac5591d/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py#L63
"""
def forward(self, x):
module_device = x.device
downcast_x = _cast_if_autocast_enabled(x)
downcast_weight = _cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
downcast_bias = _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
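        # Disable autocast for the layer_norm call itself so it runs directly in the
        # dtype produced by _cast_if_autocast_enabled instead of being re-cast.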
with torch.autocast(enabled=False, device_type=module_device.type):
return F.layer_norm(
downcast_x,
self.normalized_shape,
downcast_weight,
downcast_bias,
self.eps,
)
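# Helper used by LPLayerNorm: when autocast is enabled, cast the tensor to the
# autocast dtype for its device; otherwise return the tensor unchanged.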
def _cast_if_autocast_enabled(tensor):
if torch.is_autocast_enabled():
if tensor.device.type == "cuda":
dtype = torch.get_autocast_gpu_dtype()
elif tensor.device.type == "cpu":
dtype = torch.get_autocast_cpu_dtype()
else:
raise NotImplementedError()
return tensor.to(dtype=dtype)
return tensor
class RmsNorm(nn.Module):
def __init__(
self,
normalized_shape: Union[int, List[int], Size],
eps: float = 1e-6,
device=None,
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
if isinstance(normalized_shape, numbers.Integral):
# mypy error: incompatible types in assignment
normalized_shape = (normalized_shape,) # type: ignore[assignment]
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
self.eps = eps
self.weight = Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.reset_parameters()
def _norm(self, x):
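        # Scale by the reciprocal root-mean-square over the last dimension;
        # unlike LayerNorm, no mean is subtracted.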
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
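    # Compute the normalization in float32 for numerical stability, then cast
    # back to the input dtype before applying the learned gain.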
def forward(self, x):
output = self._norm(x.float()).type_as(x)
return output * self.weight
def reset_parameters(self) -> None:
with torch.no_grad():
self.weight.fill_(1.0)
def extra_repr(self) -> str:
return "{normalized_shape}, eps={eps} ".format(**self.__dict__)
def get_norm_class(model_norm):
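    """Map a norm name to the corresponding normalization class, or a partial
    with the gain/bias options preset."""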
if model_norm == "default_layer_norm":
return torch.nn.LayerNorm
elif model_norm == "lp_layer_norm":
return LPLayerNorm
elif model_norm == "gain_only_lp_layer_norm":
return partial(LPLayerNorm, elementwise_gain=True, elementwise_bias=False)
elif model_norm == "gain_only_layer_norm":
return partial(LayerNorm, elementwise_gain=True, elementwise_bias=False)
elif model_norm == "no_wb_layer_norm":
return partial(LayerNorm, elementwise_gain=False, elementwise_bias=False)
elif model_norm == "rms_norm":
return RmsNorm
else:
raise ValueError(f"Unsupported model-norm: {model_norm}")
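

# A minimal usage sketch (not part of the original open_lm module); the
# dimensions below are illustrative assumptions, not values from any model.
if __name__ == "__main__":
    hidden_dim = 16  # assumed embedding width for the demo
    x = torch.randn(2, 4, hidden_dim)  # (batch, sequence, hidden)

    gain_only = get_norm_class("gain_only_layer_norm")(hidden_dim)
    print(gain_only)           # extra_repr shows the gain/bias configuration
    print(gain_only(x).shape)  # torch.Size([2, 4, 16])

    rms = get_norm_class("rms_norm")(hidden_dim)
    print(rms(x).shape)        # torch.Size([2, 4, 16])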