import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class LayerNorm(nn.Module):
    """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False"""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
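
# Usage note (illustrative, not from the original file): with bias=True this
# behaves like nn.LayerNorm(ndim) with its default eps of 1e-5; with bias=False
# the additive bias term is simply dropped.
#   >>> ln = LayerNorm(8, bias=False)
#   >>> ln(torch.randn(2, 8)).shape
#   torch.Size([2, 8])
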
class RMSNorm(torch.nn.Module):
    def __init__(self, ndim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(ndim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x) * self.weight
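
# Reference formula (descriptive comment, not from the original file):
#   RMSNorm(x) = x / sqrt(mean(x^2, dim=-1) + eps) * weight
# Unlike LayerNorm, RMSNorm does not subtract the mean and has no bias term.
# Quick sanity check: an all-ones vector already has unit RMS, so with the
# default weight of 1 it should come back (almost) unchanged.
#   >>> norm = RMSNorm(4, eps=1e-6)
#   >>> torch.allclose(norm(torch.ones(4)), torch.ones(4), atol=1e-3)
#   True
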
class SwiGLU(nn.Module):
    def __init__(self, in_dim, out_dim, bias) -> None:
        super().__init__()
        self.w1 = nn.Linear(in_dim, out_dim, bias=bias)
        self.w3 = nn.Linear(in_dim, out_dim, bias=bias)

    def forward(self, x):
        return F.silu(self.w1(x)) * self.w3(x)
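
# Descriptive note (not from the original file): this computes the gated unit
#   SwiGLU(x) = SiLU(W1 x) * (W3 x)
# i.e. the SwiGLU feed-forward gate of Shazeer, "GLU Variants Improve
# Transformer" (2020). The w1/w3 naming mirrors Llama's FeedForward, where the
# output projection (w2 there) is handled by `c_proj` in MLP below.
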
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.non_linearity = config.nonlinearity_type
        hidden_dim = 4 * config.n_embd
        if config.nonlinearity_type == "gelu":
            self.c_fc = nn.Linear(config.n_embd, hidden_dim, bias=config.bias)
            self.gelu = nn.GELU()
            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        elif config.nonlinearity_type == "swiglu":
            if config.swiglu_multiple_of is None:
                raise Exception("SwiGLU requires swiglu_multiple_of to be set")
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = config.swiglu_multiple_of * math.ceil(hidden_dim / config.swiglu_multiple_of)
            self.swiglu = SwiGLU(config.n_embd, hidden_dim, bias=config.bias)
            # set name to `c_proj` so that the right initialisation gets applied to it in GPT.__init__()
            self.c_proj = nn.Linear(hidden_dim, config.n_embd, bias=config.bias)
        else:
            raise Exception(f"Unknown nonlinearity type: {config.nonlinearity_type}")
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        if self.non_linearity == "gelu":
            x = self.c_fc(x)
            x = self.gelu(x)
        elif self.non_linearity == "swiglu":
            x = self.swiglu(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x
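
# Minimal smoke test (illustrative sketch, not part of the original file).
# The SimpleNamespace below is a stand-in config whose field names mirror the
# attributes read by MLP above; the concrete values are arbitrary.
if __name__ == "__main__":
    from types import SimpleNamespace

    x = torch.randn(2, 8, 64)
    for nonlinearity in ("gelu", "swiglu"):
        config = SimpleNamespace(
            n_embd=64,
            bias=False,
            dropout=0.0,
            nonlinearity_type=nonlinearity,
            swiglu_multiple_of=32,
        )
        mlp = MLP(config)
        # Both branches should preserve the input shape.
        assert mlp(x).shape == x.shape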