|
import torch
|
|
|
|
|
|
class RMSNorm(torch.nn.Module):
    """Root Mean Square Layer Normalization.

    Scales the input by the reciprocal root-mean-square of its values along
    ``dim`` and then multiplies by a learnable per-feature ``weight``.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.

    Args:
        size: Number of elements in the learnable ``weight`` (initialized to ones).
        dim: Axis of the input over which the mean of squares is taken.
        eps: Constant added to the mean of squares before the reciprocal
            square root, to avoid division by zero.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize ``x`` along ``self.dim`` and apply the learnable scale.

        Args:
            x: Input tensor; its dtype is restored on the returned tensor.

        Returns:
            A tensor with the same shape and dtype as ``x``.
        """
        dtype = x.dtype
        # Do the reduction and scaling in float32; the result is cast back
        # to the caller's dtype at the end.
        x = x.float()

        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)

        weight = self.weight
        # A 1-D weight broadcasts against the LAST axis by default. When the
        # normalized axis is a different one, reshape the weight so that it
        # lines up with that axis instead of scaling an unrelated dimension.
        if isinstance(self.dim, int) and weight.ndim == 1 and x.ndim > 1:
            axis = self.dim % x.ndim
            if axis != x.ndim - 1:
                shape = [1] * x.ndim
                shape[axis] = -1
                weight = weight.reshape(shape)

        return (weight * x_normed).to(dtype=dtype)

    def reset_parameters(self) -> None:
        """Reset the learnable scale back to all ones (in place)."""
        torch.nn.init.ones_(self.weight)