import math
import json
import re
from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass
from typing import Callable

import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint

import xformers.ops as xops

from huggingface_hub import PyTorchModelHubMixin

from open_lm.attention import get_attn_func, xformers_attn, torch_attn
from open_lm.norms import get_norm_class
from open_lm.positional_embedding.head_rotary import HeadRotaryWithCast
from open_lm.positional_embedding.rotary import RotaryWithCast
from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast
from open_lm.positional_embedding.none import identity_with_cast

try:
    from megablocks.layers.moe import MoE
    from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
    MoE = None
    MoEArgs = None

try:
    from mamba_ssm import MambaLMHeadModel
except ImportError:
    MambaLMHeadModel = None


_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}


def _natural_key(string_):
    return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]


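# Scan the model_configs/ directory (or an explicitly provided path) for *.json files
# and register each config under its filename stem, kept in natural sort order.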
def _rescan_model_configs(model_config_paths=None):
    global _MODEL_CONFIGS

    if model_config_paths is not None:
        config_iter = [Path(model_config_paths)]
    else:
        config_iter = _MODEL_CONFIG_PATHS

    config_ext = (".json",)
    config_files = []
    for config_path in config_iter:
        if config_path.is_file() and config_path.suffix in config_ext:
            config_files.append(Path(config_path))
        elif config_path.is_dir():
            for ext in config_ext:
                config_files.extend(config_path.glob(f"*{ext}"))

    for cf in config_files:
        with open(cf, "r") as f:
            model_cfg = json.load(f)
            _MODEL_CONFIGS[cf.stem] = model_cfg

    _MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}


# Populate the model config registry once at import time.
_rescan_model_configs()


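# Params collects the architecture hyperparameters for a Transformer. In normal use the
# fields are populated from a JSON config in model_configs/ via create_params().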
@dataclass
class Params:
    dim: int = 512
    n_layers: int = 8
    n_heads: int = 8
    vocab_size: int = -1
    norm_eps: float = 1e-5
    seq_len: int = 2048
    post_embed_norm: bool = False
    weight_tying: bool = False
    norm_type: nn.Module = nn.LayerNorm
    attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn
    apply_qk_norm: bool = False
    moe_loss_weight: float = 0.1
    moe_capacity_factor: float = 1.25
    moe_expert_model_parallelism: bool = False
    moe_weight_parallelism: bool = False
    moe_num_experts: int = 8
    moe_top_k: int = 2
    moe_freq: int = 0
    positional_embedding_type: str = "rotary"
    ffn_type: str = "swiglu"


def get_pos_embed(args: Params):
    head_dim = args.dim // args.n_heads
    if args.positional_embedding_type == "rotary":
        return RotaryWithCast(head_dim, args.seq_len)
    elif args.positional_embedding_type == "llama_rotary":
        return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len)
    elif args.positional_embedding_type == "head_rotary":
        return HeadRotaryWithCast(head_dim, args.seq_len)
    elif args.positional_embedding_type == "none":
        return identity_with_cast
    else:
        raise RuntimeError(f"Unknown positional embedding type {args.positional_embedding_type}")


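# Multi-head self-attention with a fused QKV projection, optional query/key normalization,
# configurable positional embeddings, and a simple per-layer KV cache.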
class CustomAttn(nn.Module):
    def __init__(self, layer_id, args: Params):
        super().__init__()
        self.n_heads = args.n_heads
        self.head_dim = args.dim // args.n_heads
        self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
        self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.pos_embed = get_pos_embed(args)
        self.attn_fn = args.attn_func
        self.apply_qk_norm = args.apply_qk_norm

        self.q_norm = (
            args.norm_type(
                args.n_heads * self.head_dim,
                eps=args.norm_eps,
            )
            if self.apply_qk_norm
            else nn.Identity()
        )
        self.k_norm = (
            args.norm_type(
                args.n_heads * self.head_dim,
                eps=args.norm_eps,
            )
            if self.apply_qk_norm
            else nn.Identity()
        )

        self.layer_id = layer_id
        self.dim = args.dim
        self.reset_parameters()

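    # Truncated-normal init: std = 1/sqrt(dim) for the QKV projection, and the same std
    # further divided by sqrt(2 * (layer_id + 1)) for the output projection, so deeper
    # layers contribute smaller residual updates at initialization.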
    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim)
        torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)

        std = std / math.sqrt(2 * (self.layer_id + 1))
        torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)

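    # past_key_value, when provided, is a [keys, vals] pair of shape
    # (batch, past_len, n_heads, head_dim); the rotary offset is set to past_len and the
    # new keys/vals are appended before attention is computed.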
    def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, attention_mask=None):
        batchsize, q_len, _ = x.shape
        queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)

        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        queries = queries.view(batchsize, q_len, self.n_heads, self.head_dim)
        keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim)
        vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim)

        past_length = 0 if past_key_value is None else past_key_value[0].shape[1]
        queries, keys, vals = self.pos_embed(queries, keys, vals, offset=past_length)

        if past_key_value is not None and use_cache:
            keys = torch.cat([past_key_value[0], keys], dim=1)
            vals = torch.cat([past_key_value[1], vals], dim=1)

        if use_cache:
            past_key_value = [keys, vals]

        output = self.attn_fn(
            queries,
            keys,
            vals,
            is_causal=is_causal,
            attention_mask=attention_mask,
        )

        output = output.view(batchsize, q_len, -1)

        return self.out_proj(output), past_key_value


class GemmaMLP(nn.Module):
    """Google's Gemma model MLP (aka GeGLU).

    Modified from https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L182-L201
    """

    def __init__(self, dim: int, hidden_dim: int, layer_id: int):
        super().__init__()
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.gate_proj = nn.Linear(dim, hidden_dim)
        self.up_proj = nn.Linear(dim, hidden_dim)
        self.down_proj = nn.Linear(hidden_dim, dim)
        self._layer_id = layer_id

    def forward(self, x):
        gate = self.gate_proj(x)
        gate = F.gelu(gate)
        up = self.up_proj(x)
        fuse = gate * up
        outputs = self.down_proj(fuse)
        return outputs

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.dim)
        torch.nn.init.trunc_normal_(self.gate_proj.weight, std=std, a=-3 * std, b=3 * std)
        torch.nn.init.trunc_normal_(self.up_proj.weight, std=std, a=-3 * std, b=3 * std)

        std = 1.0 / math.sqrt(self.hidden_dim)
        std = std / math.sqrt(2 * (self._layer_id + 1))
        torch.nn.init.trunc_normal_(self.down_proj.weight, std=std, a=-3 * std, b=3 * std)


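# Plain-PyTorch SwiGLU, selected with ffn_type="swiglu_torch" as an alternative to the
# fused xformers implementation (xops.SwiGLU) used by ffn_type="swiglu".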
class SwiGLUTorch(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, bias=True):
        super().__init__()
        self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias)
        self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)

    def forward(self, x):
        gate, x = self.w12(x).chunk(2, dim=-1)
        x = F.silu(gate) * x
        return self.w3(x)


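# Pre-norm residual block: attention followed by a feed-forward network, each wrapped in its
# own norm and residual connection. For the SwiGLU/GeGLU variants the hidden size is
# (2/3) * 4 * dim rounded up to a multiple of 256, which keeps the parameter count of the
# three-matrix gated MLP close to that of the standard 4x GELU MLP.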
class Block(nn.Module):
    def __init__(self, layer_id, args: Params):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim

        self.head_dim = args.dim // args.n_heads
        self.attention = CustomAttn(layer_id, args)
        self._ffn_type = args.ffn_type
        if args.ffn_type == "swiglu":
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
        elif args.ffn_type == "swiglu_torch":
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False)
        elif args.ffn_type == "gelu":
            self.hidden_dim = args.dim * 4
            self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
            self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
            self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
        elif args.ffn_type == "gemma_geglu":
            self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
            self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id)
        elif args.ffn_type == "moe":
            if MoE is None or MoEArgs is None:
                raise ImportError("ffn_type='moe' requires the megablocks package to be installed.")
            moe_args = MoEArgs(
                hidden_size=args.dim,
                ffn_hidden_size=args.dim * 4,
                moe_num_experts=args.moe_num_experts,
                moe_weight_parallelism=args.moe_weight_parallelism,
                moe_expert_model_parallelism=args.moe_expert_model_parallelism,
                moe_top_k=args.moe_top_k,
                moe_capacity_factor=args.moe_capacity_factor,
                moe_loss_weight=args.moe_loss_weight,
                device=torch.cuda.current_device(),
                bf16=False,
                fp16=False,
            )
            self.feed_forward = MoE(moe_args)
        else:
            raise ValueError(f"Unknown ffn_type: {args.ffn_type}")

        self.layer_id = layer_id
        self.attention_norm = args.norm_type(
            args.dim,
            eps=args.norm_eps,
        )
        self.ffn_norm = args.norm_type(
            args.dim,
            eps=args.norm_eps,
        )
        self.attention.seq_len = args.seq_len
        self.reset_parameters()

    def reset_parameters(self):
        if self._ffn_type == "swiglu" or self._ffn_type == "swiglu_torch":
            std = 1.0 / math.sqrt(self.dim)
            torch.nn.init.trunc_normal_(self.feed_forward.w12.weight, std=std, a=-3 * std, b=3 * std)

            std = 1.0 / math.sqrt(self.hidden_dim)
            std = std / math.sqrt(2 * (self.layer_id + 1))
            torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std)
        elif self._ffn_type == "gelu":
            std = 1.0 / math.sqrt(self.dim)
            torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std)

            std = 1.0 / math.sqrt(self.hidden_dim)
            std = std / math.sqrt(2 * (self.layer_id + 1))
            torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)

    def forward(self, x, past_key_value=None, use_cache=False, attention_mask=None):
        h, past_key_value = self.attention(
            self.attention_norm(x),
            is_causal=True,
            past_key_value=past_key_value,
            use_cache=use_cache,
            attention_mask=attention_mask,
        )
        h = x + h
        if self._ffn_type == "moe":
            ffn_out, _ = self.feed_forward(self.ffn_norm(h))
        else:
            ffn_out = self.feed_forward(self.ffn_norm(h))
        out = h + ffn_out
        return out, past_key_value


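# Decoder-only transformer language model. When params.moe_freq > 0, every moe_freq-th
# block (layer_id % moe_freq == 0, starting at layer 0) uses a megablocks MoE
# feed-forward layer; the remaining blocks use params.ffn_type.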
class Transformer(nn.Module, PyTorchModelHubMixin):
    def __init__(self, params):
        super().__init__()

        self.params = params
        self.dim = params.dim
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers
        self.moe_num_experts = params.moe_num_experts
        self.seq_len = params.seq_len
        self.post_embed_norm = (
            params.norm_type(
                params.dim,
                eps=params.norm_eps,
            )
            if params.post_embed_norm
            else nn.Identity()
        )
        self.weight_tying = params.weight_tying

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)

        self.layers = torch.nn.ModuleList()
        ffn_type_ = params.ffn_type
        for layer_id in range(params.n_layers):
            if params.moe_freq > 0 and layer_id % params.moe_freq == 0:
                params.ffn_type = "moe"
            else:
                params.ffn_type = ffn_type_
            self.layers.append(Block(layer_id, params))

        self.norm = params.norm_type(
            params.dim,
            eps=params.norm_eps,
        )
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
        if self.weight_tying:
            self.tok_embeddings.weight = self.output.weight
        self.grad_checkpointing = False
        self.reset_parameters()

    def reset_parameters(self):
        std = 1.0 / math.sqrt(self.params.dim)
        torch.nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
        torch.nn.init.trunc_normal_(self.tok_embeddings.weight, std=std, a=-3 * std, b=3 * std)

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        self.grad_checkpointing = enable

    def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, use_cache=False, attention_mask=None):
        """
        Args:
            input_ids (torch.LongTensor): Token ids of shape (batch_size, sequence_len). Mutually exclusive with
                inputs_embeds.
            inputs_embeds (torch.Tensor): Precomputed token embeddings of shape (batch_size, sequence_len, dim),
                used instead of input_ids.
            past_key_values: Per-layer list of cached [keys, vals] from a previous forward pass, or None.
            use_cache (bool): If True, return updated past_key_values for incremental decoding.
            attention_mask (torch.Tensor): Shape (batch_size, sequence_len), indicates tokens that should not be
                attended to. attention_mask[s, i] = False indicates that token i should not be attended to by any
                other token for sequence s.

        Returns:
            Tuple of (logits, final hidden states, past_key_values).
        """
        if input_ids is not None:
            x = self.tok_embeddings(input_ids)
        elif inputs_embeds is not None:
            x = inputs_embeds
        else:
            raise ValueError("Either input_ids or inputs_embeds must be provided.")

        x = self.post_embed_norm(x)

        if past_key_values is None:
            past_key_values = [None] * self.n_layers
        elif isinstance(past_key_values, tuple):
            past_key_values = list(past_key_values)
        for i, layer in enumerate(self.layers):
            if self.grad_checkpointing:
                x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, attention_mask)
            else:
                x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache, attention_mask=attention_mask)
        if past_key_values[0] is None:
            past_key_values = None
        x = self.norm(x)
        output = self.output(x)

        return output.float(), x, past_key_values

    def get_input_embeddings(self):
        return self.tok_embeddings

    def get_output_embeddings(self):
        return self.output


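# Build either a Params instance (transformer models) or a dict of Mamba kwargs from a
# named config or a JSON file. Values present in the config take precedence over the
# corresponding command-line arguments on `args`.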
def create_params(args):
    cfg = None

    if args.model.endswith(".json"):
        _rescan_model_configs(model_config_paths=args.model)
        args.model = Path(args.model).stem

    if args.model in _MODEL_CONFIGS:
        cfg = deepcopy(_MODEL_CONFIGS[args.model])
    else:
        raise ValueError(f"Unknown model '{args.model}'. Pass a pre-defined open_lm model name or a json config.")

    if "mamba" in args.model:
        return {
            "d_model": cfg["d_model"],
            "n_layer": cfg["n_layer"],
            "vocab_size": cfg["vocab_size"],
            "seq_len": cfg["seq_len"],
        }
    else:
        return Params(
            dim=cfg["hidden_dim"],
            n_layers=cfg["n_layers"],
            n_heads=cfg["n_heads"],
            seq_len=cfg["seq_len"],
            vocab_size=cfg["vocab_size"],
            post_embed_norm=cfg["post_embed_norm"],
            weight_tying=cfg["weight_tying"],
            norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
            attn_func=get_attn_func(
                args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha
            ),
            apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
            positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
            ffn_type=cfg.get("ffn_type", args.ffn_type),
            moe_num_experts=cfg.get("moe_num_experts", args.moe_num_experts),
            moe_loss_weight=cfg.get("moe_loss_weight", args.moe_loss_weight),
            moe_expert_model_parallelism=cfg.get("moe_expert_model_parallelism", args.moe_expert_model_parallelism),
            moe_weight_parallelism=cfg.get("moe_weight_parallelism", args.moe_weight_parallelism),
            moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor),
            moe_freq=cfg.get("moe_freq", args.moe_freq),
            moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
        )


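# Thin wrapper around mamba_ssm's MambaLMHeadModel exposing the same
# (logits, hidden_states, past_key_values) return signature as Transformer,
# with the latter two always None.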
class Mamba(nn.Module):
    def __init__(self, params):
        if MambaLMHeadModel is None:
            raise ImportError(
                "MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'."
            )

        super().__init__()
        self.seq_len = params.pop("seq_len")
        self.vocab_size = params["vocab_size"]

        self.model = MambaLMHeadModel(**params)

    def reset_parameters(self):
        return

    def forward(self, x):
        out = self.model(x).logits
        return out, None, None


def create_model(args):
    if "mamba" in args.model:
        model = Mamba(create_params(args))
        return model
    else:
        model = Transformer(create_params(args))
        return model

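# Example usage (a minimal sketch, not part of the library API): build a small model
# directly from Params instead of going through create_model()/argparse. The values below
# are illustrative; ffn_type="swiglu_torch" avoids relying on the fused xformers SwiGLU
# at runtime, and the default attention falls back to torch_attn when CUDA is unavailable.
#
#   params = Params(dim=512, n_layers=8, n_heads=8, vocab_size=50432, seq_len=2048,
#                   ffn_type="swiglu_torch")
#   model = Transformer(params)
#   tokens = torch.randint(0, params.vocab_size, (1, 16))
#   logits, hidden_states, _ = model(tokens)  # logits: (1, 16, vocab_size)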