# open_lm/model.py
import math
import json
import re
from copy import deepcopy
from pathlib import Path
from dataclasses import dataclass
from typing import Callable
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.checkpoint import checkpoint
import xformers.ops as xops
from huggingface_hub import PyTorchModelHubMixin
from open_lm.attention import get_attn_func, xformers_attn, torch_attn
from open_lm.norms import get_norm_class
from open_lm.positional_embedding.head_rotary import HeadRotaryWithCast
from open_lm.positional_embedding.rotary import RotaryWithCast
from open_lm.positional_embedding.llama_rotary import LLaMARotaryWithCast
from open_lm.positional_embedding.none import identity_with_cast
# from open_lm.moe.mixture_of_experts import MoE
try:
from megablocks.layers.moe import MoE
from megablocks.layers.arguments import Arguments as MoEArgs
except ImportError:
MoE = None
MoEArgs = None
try: # optional import
from mamba_ssm import MambaLMHeadModel
except ImportError:
MambaLMHeadModel = None
# from openclip
_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
_MODEL_CONFIGS = {}  # dictionary of model architecture configs, keyed by model name
def _natural_key(string_):
return [int(s) if s.isdigit() else s for s in re.split(r"(\d+)", string_.lower())]
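# e.g. _natural_key("open_lm_11m") -> ["open_lm_", 11, "m"], so numeric parts of config names
# sort by value rather than lexicographically.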
def _rescan_model_configs(model_config_paths=None):
global _MODEL_CONFIGS
config_iter = None
if model_config_paths is not None:
config_iter = [
Path(model_config_paths),
]
else:
config_iter = _MODEL_CONFIG_PATHS
config_ext = (".json",)
config_files = []
for config_path in config_iter:
if config_path.is_file() and config_path.suffix in config_ext:
config_files.append(Path(config_path))
elif config_path.is_dir():
for ext in config_ext:
config_files.extend(config_path.glob(f"*{ext}"))
for cf in config_files:
with open(cf, "r") as f:
model_cfg = json.load(f)
_MODEL_CONFIGS[cf.stem] = model_cfg
_MODEL_CONFIGS = {k: v for k, v in sorted(_MODEL_CONFIGS.items(), key=lambda x: _natural_key(x[0]))}
_rescan_model_configs() # initial populate of model config registry
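# After the scan, _MODEL_CONFIGS maps config-file stems to parsed dicts; e.g. a file
# model_configs/open_lm_1b.json (if present in the package) would be exposed as _MODEL_CONFIGS["open_lm_1b"].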
# args and default params follow llama (except with LayerNorm instead of RMSNorm by default)
@dataclass
class Params:
dim: int = 512
n_layers: int = 8
n_heads: int = 8
vocab_size: int = -1
norm_eps: float = 1e-5
seq_len: int = 2048
post_embed_norm: bool = False
weight_tying: bool = False
norm_type: nn.Module = nn.LayerNorm
attn_func: Callable = xformers_attn if torch.cuda.is_available() else torch_attn
apply_qk_norm: bool = False
moe_loss_weight: float = 0.1
moe_capacity_factor: float = 1.25
moe_expert_model_parallelism: bool = False
moe_weight_parallelism: bool = False
moe_num_experts: int = 8
moe_top_k: int = 2
moe_freq: int = 0
positional_embedding_type: str = "rotary"
ffn_type: str = "swiglu"
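# Illustrative only: a small custom configuration can be built by overriding defaults, e.g.
#   Params(dim=2048, n_layers=24, n_heads=16, vocab_size=50432, seq_len=2048)
# (the numbers here are arbitrary; vocab_size has no usable default and must always be set).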
def get_pos_embed(args: Params):
head_dim = args.dim // args.n_heads
if args.positional_embedding_type == "rotary":
return RotaryWithCast(head_dim, args.seq_len)
elif args.positional_embedding_type == "llama_rotary":
return LLaMARotaryWithCast(head_dim, args.n_heads, args.seq_len)
elif args.positional_embedding_type == "head_rotary":
return HeadRotaryWithCast(head_dim, args.seq_len)
elif args.positional_embedding_type == "none":
return identity_with_cast
else:
raise RuntimeError(f"Unknown positional embedding type {args.positional_embedding_type}")
class CustomAttn(nn.Module):
def __init__(self, layer_id, args: Params):
super().__init__()
self.n_heads = args.n_heads
self.head_dim = args.dim // args.n_heads
self.in_proj = nn.Linear(args.dim, 3 * args.n_heads * self.head_dim, bias=False)
self.out_proj = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
self.pos_embed = get_pos_embed(args)
self.attn_fn = args.attn_func
self.apply_qk_norm = args.apply_qk_norm
# initialize norm layers for queries and keys if needed
self.q_norm = (
args.norm_type(
args.n_heads * self.head_dim,
eps=args.norm_eps,
)
if self.apply_qk_norm
else nn.Identity()
)
self.k_norm = (
args.norm_type(
args.n_heads * self.head_dim,
eps=args.norm_eps,
)
if self.apply_qk_norm
else nn.Identity()
)
self.layer_id = layer_id
self.dim = args.dim
self.reset_parameters()
def reset_parameters(self):
# initialize weights by trunc_normal(1/sqrt(fan_in))
std = 1.0 / math.sqrt(self.dim)
torch.nn.init.trunc_normal_(self.in_proj.weight, std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self.out_proj.weight, std=std, a=-3 * std, b=3 * std)
def forward(self, x: torch.Tensor, is_causal=True, past_key_value=None, use_cache=False, attention_mask=None):
batchsize, q_len, _ = x.shape
queries, keys, vals = self.in_proj(x).chunk(3, dim=-1)
queries = self.q_norm(queries)
keys = self.k_norm(keys)
queries = queries.view(batchsize, q_len, self.n_heads, self.head_dim)
keys = keys.view(batchsize, q_len, self.n_heads, self.head_dim)
vals = vals.view(batchsize, q_len, self.n_heads, self.head_dim)
past_length = 0 if past_key_value is None else past_key_value[0].shape[1]
queries, keys, vals = self.pos_embed(queries, keys, vals, offset=past_length)
if past_key_value is not None and use_cache:
keys = torch.cat([past_key_value[0], keys], dim=1)
vals = torch.cat([past_key_value[1], vals], dim=1)
if use_cache:
past_key_value = [keys, vals]
output = self.attn_fn(
queries,
keys,
vals,
is_causal=is_causal,
attention_mask=attention_mask,
)
output = output.view(batchsize, q_len, -1)
return self.out_proj(output), past_key_value
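# Shape walk-through for CustomAttn.forward (illustrative): x is (B, L, dim); in_proj yields
# (B, L, 3*dim), chunked into q/k/v of (B, L, dim) each and reshaped to (B, L, n_heads, head_dim).
# With a cache of length P, keys/vals grow to (B, P+L, n_heads, head_dim) before attention, and the
# result is flattened back to (B, L, dim) for out_proj.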
class GemmaMLP(nn.Module):
"""Google's Gemma model MLP (aka GeGLU).
Modified from https://github.com/google/gemma_pytorch/blob/01062c9ef4cf89ac0c985b25a734164ede017d0b/gemma/model.py#L182-L201
"""
def __init__(self, dim: int, hidden_dim: int, layer_id: int):
super().__init__()
self.dim = dim
self.hidden_dim = hidden_dim
self.gate_proj = nn.Linear(dim, hidden_dim)
self.up_proj = nn.Linear(dim, hidden_dim)
self.down_proj = nn.Linear(hidden_dim, dim)
self._layer_id = layer_id
def forward(self, x):
gate = self.gate_proj(x)
gate = F.gelu(gate)
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs
def reset_parameters(self):
std = 1.0 / math.sqrt(self.dim)
torch.nn.init.trunc_normal_(self.gate_proj.weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.up_proj.weight, std=std, a=-3 * std, b=3 * std)
std = 1.0 / math.sqrt(self.hidden_dim)
std = std / math.sqrt(2 * (self._layer_id + 1))
torch.nn.init.trunc_normal_(self.down_proj.weight, std=std, a=-3 * std, b=3 * std)
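# Note: GemmaMLP uses a GELU gate with separate gate/up projections (GeGLU), whereas the SwiGLU
# variants below use a SiLU gate and pack both input projections into a single fused w12 matmul.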
# Same as pseudocode provided from xformers SwiGLU
# https://github.com/facebookresearch/xformers
class SwiGLUTorch(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim, bias=True):
super().__init__()
self.w12 = nn.Linear(in_dim, 2 * hidden_dim, bias=bias)
self.w3 = nn.Linear(hidden_dim, out_dim, bias=bias)
def forward(self, x):
gate, x = self.w12(x).chunk(2, dim=-1)
x = F.silu(gate) * x
return self.w3(x)
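# Schematically, the SwiGLU above computes: gate, value = split(w12(x), 2); y = w3(silu(gate) * value).
# Block.reset_parameters below assumes both this class and xops.SwiGLU expose .w12 and .w3 weights.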
class Block(nn.Module):
def __init__(self, layer_id, args: Params):
super().__init__()
self.n_heads = args.n_heads
self.dim = args.dim
self.head_dim = args.dim // args.n_heads
self.attention = CustomAttn(layer_id, args)
self._ffn_type = args.ffn_type
if args.ffn_type == "swiglu":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
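            # Worked example: with dim=4096, 2 * 4 * 4096 / 3 = 10922.67 -> int 10922, which rounds up
            # to the next multiple of 256, giving hidden_dim = 11008 (the FFN width LLaMA-7B uses).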
self.feed_forward = xops.SwiGLU(args.dim, self.hidden_dim, args.dim, bias=False)
elif args.ffn_type == "swiglu_torch":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.feed_forward = SwiGLUTorch(args.dim, self.hidden_dim, args.dim, bias=False)
elif args.ffn_type == "gelu":
# Follows mosaic mpt7b, but without a bias.
self.hidden_dim = args.dim * 4
self._ff_w1 = nn.Linear(args.dim, self.hidden_dim, bias=False)
self._ff_w2 = nn.Linear(self.hidden_dim, args.dim, bias=False)
self.feed_forward = nn.Sequential(self._ff_w1, nn.GELU(approximate="none"), self._ff_w2)
elif args.ffn_type == "gemma_geglu":
# this follows llama / lit llama -- go to multiple of 256
self.hidden_dim = 256 * ((int(2 * 4 * args.dim / 3) + 256 - 1) // 256)
self.feed_forward = GemmaMLP(args.dim, self.hidden_dim, layer_id)
elif args.ffn_type == "moe":
moe_args = MoEArgs(
hidden_size=args.dim,
ffn_hidden_size=args.dim * 4,
moe_num_experts=args.moe_num_experts,
moe_weight_parallelism=args.moe_weight_parallelism,
moe_expert_model_parallelism=args.moe_expert_model_parallelism,
moe_top_k=args.moe_top_k,
moe_capacity_factor=args.moe_capacity_factor,
moe_loss_weight=args.moe_loss_weight,
device=torch.cuda.current_device(),
bf16=False,
fp16=False,
)
self.feed_forward = MoE(moe_args)
self.layer_id = layer_id
self.attention_norm = args.norm_type(
args.dim,
eps=args.norm_eps,
)
self.ffn_norm = args.norm_type(
args.dim,
eps=args.norm_eps,
)
self.attention.seq_len = args.seq_len
self.reset_parameters()
def reset_parameters(self):
if self._ffn_type == "swiglu" or self._ffn_type == "swiglu_torch":
            # initialize weights by trunc_normal(1/sqrt(fan_in))
std = 1.0 / math.sqrt(self.dim)
torch.nn.init.trunc_normal_(self.feed_forward.w12.weight, std=std, a=-3 * std, b=3 * std)
# scale init by depth as in https://arxiv.org/abs/1908.11365 -- worked slightly better.
std = 1.0 / math.sqrt(self.hidden_dim)
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self.feed_forward.w3.weight, std=std, a=-3 * std, b=3 * std)
elif self._ffn_type == "gelu":
std = 1.0 / math.sqrt(self.dim)
torch.nn.init.trunc_normal_(self._ff_w1.weight, std=std, a=-3 * std, b=3 * std)
std = 1.0 / math.sqrt(self.hidden_dim)
std = std / math.sqrt(2 * (self.layer_id + 1))
torch.nn.init.trunc_normal_(self._ff_w2.weight, std=std, a=-3 * std, b=3 * std)
def forward(self, x, past_key_value=None, use_cache=False, attention_mask=None):
h, past_key_value = self.attention(
self.attention_norm(x),
is_causal=True,
past_key_value=past_key_value,
use_cache=use_cache,
attention_mask=attention_mask,
)
h = x + h
if self._ffn_type == "moe":
ffn_out, _ = self.feed_forward(self.ffn_norm(h))
else:
ffn_out = self.feed_forward(self.ffn_norm(h))
out = h + ffn_out
return out, past_key_value
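# The block is a standard pre-norm residual pair: h = x + Attn(norm(x)); out = h + FFN(norm(h)).
# For the MoE branch the feed-forward returns a second value, which is discarded here.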
class Transformer(nn.Module, PyTorchModelHubMixin):
def __init__(self, params):
super().__init__()
# for convenience we often share param names with llama
self.params = params
self.dim = params.dim
self.vocab_size = params.vocab_size
self.n_layers = params.n_layers
self.moe_num_experts = params.moe_num_experts
self.seq_len = params.seq_len
self.post_embed_norm = (
params.norm_type(
params.dim,
eps=params.norm_eps,
)
if params.post_embed_norm
else nn.Identity()
)
self.weight_tying = params.weight_tying
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
self.layers = torch.nn.ModuleList()
ffn_type_ = params.ffn_type
for layer_id in range(params.n_layers):
if params.moe_freq > 0 and layer_id % params.moe_freq == 0:
params.ffn_type = "moe"
else:
params.ffn_type = ffn_type_
self.layers.append(Block(layer_id, params))
# get class for normalization layers
self.norm = params.norm_type(
params.dim,
eps=params.norm_eps,
)
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
if self.weight_tying:
self.tok_embeddings.weight = self.output.weight
self.grad_checkpointing = False
self.reset_parameters()
def reset_parameters(self):
        # initialize weights with std = 1/sqrt(dim)
        # for the output layer this is the default 1/sqrt(fan_in); Maciej Kilian tried another
        # option for the embed layer (from the RWKV paper) but this one worked better
std = 1.0 / math.sqrt(self.params.dim)
torch.nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
torch.nn.init.trunc_normal_(self.tok_embeddings.weight, std=std, a=-3 * std, b=3 * std)
@torch.jit.ignore
def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable
def forward(self, input_ids=None, inputs_embeds=None, past_key_values=None, use_cache=False, attention_mask=None):
"""
Args:
input
past_key_values
use_cache (bool)
attention_mask (torch.Tensor): Shape (batch_size, sequence_len), indicates tokens that should not be
attended to. attention_mask[s, i] = False indicates that token i should not be attended to by any other
token for sequence s.
"""
if input_ids is not None:
x = self.tok_embeddings(input_ids)
elif inputs_embeds is not None:
x = inputs_embeds
else:
raise ValueError("Either input_ids or inputs_embeds must be provided.")
x = self.post_embed_norm(x)
if past_key_values is None:
past_key_values = [None] * self.n_layers
elif isinstance(past_key_values, tuple):
past_key_values = list(past_key_values)
for i, layer in enumerate(self.layers):
if self.grad_checkpointing:
x, past_key_values[i] = checkpoint(layer, x, past_key_values[i], use_cache, attention_mask)
else:
x, past_key_values[i] = layer(x, past_key_values[i], use_cache=use_cache, attention_mask=attention_mask)
if past_key_values[0] is None:
past_key_values = None
x = self.norm(x)
output = self.output(x)
# follow llama in casting this to float.
return output.float(), x, past_key_values
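    # Returns (logits cast to float, final hidden states after self.norm, key/value cache or None).
    # The cache is only kept when use_cache=True; otherwise past_key_values collapses back to None above.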
def get_input_embeddings(self):
return self.tok_embeddings
def get_output_embeddings(self):
return self.output
def create_params(args):
cfg = None
if args.model.endswith(".json"):
_rescan_model_configs(model_config_paths=args.model)
args.model = Path(args.model).stem
# print(f"_MODEL_CONFIGS{_MODEL_CONFIGS}")
if args.model in _MODEL_CONFIGS:
cfg = deepcopy(_MODEL_CONFIGS[args.model])
else:
raise ValueError("Pass a pre-defined open_lm model name or a json config")
    # Note: ideally all parameters would come from the config file, but for backwards compatibility
    # newer model parameters also have a default taken from the args (matching the old behavior).
    # Those args are managed separately by the argparser.
    # If a parameter is present in the model config, it takes precedence over the args;
    # otherwise the args value is used.
if "mamba" in args.model:
return {
"d_model": cfg["d_model"],
"n_layer": cfg["n_layer"],
"vocab_size": cfg["vocab_size"],
"seq_len": cfg["seq_len"],
}
else:
return Params(
dim=cfg["hidden_dim"],
n_layers=cfg["n_layers"],
n_heads=cfg["n_heads"],
seq_len=cfg["seq_len"],
vocab_size=cfg["vocab_size"],
post_embed_norm=cfg["post_embed_norm"],
weight_tying=cfg["weight_tying"],
norm_type=get_norm_class(cfg.get("model_norm", args.model_norm)),
attn_func=get_attn_func(
args.attn_name, args.attn_activation, args.attn_seq_scalar, args.attn_seq_scalar_alpha
),
apply_qk_norm=cfg.get("qk_norm", args.qk_norm),
positional_embedding_type=cfg.get("positional_embedding_type", args.positional_embedding_type),
ffn_type=cfg.get("ffn_type", args.ffn_type),
moe_num_experts=cfg.get("moe_num_experts", args.moe_num_experts),
moe_loss_weight=cfg.get("moe_loss_weight", args.moe_loss_weight),
moe_expert_model_parallelism=cfg.get("moe_expert_model_parallelism", args.moe_expert_model_parallelism),
moe_weight_parallelism=cfg.get("moe_weight_parallelism", args.moe_weight_parallelism),
moe_capacity_factor=cfg.get("moe_capacity_factor", args.moe_capacity_factor),
moe_freq=cfg.get("moe_freq", args.moe_freq),
moe_top_k=cfg.get("moe_top_k", args.moe_top_k),
)
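# A minimal (hypothetical) JSON config covering only the keys indexed directly above would be:
#   {"hidden_dim": 2048, "n_layers": 24, "n_heads": 16, "seq_len": 2048, "vocab_size": 50432,
#    "post_embed_norm": false, "weight_tying": false}
# Everything read via cfg.get(...) is optional in the config and falls back to the args value.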
class Mamba(nn.Module):
    # Experimental architecture; requires "pip install mamba-ssm".
    # https://arxiv.org/abs/2312.00752
def __init__(self, params):
if MambaLMHeadModel is None:
raise ImportError(
"MambaLMHeadModel is not available. Please install the 'mamba_ssm' package by running 'pip install mamba-ssm'."
)
super().__init__()
self.seq_len = params.pop("seq_len")
self.vocab_size = params["vocab_size"]
self.model = MambaLMHeadModel(**params)
def reset_parameters(self):
return
def forward(self, x):
out = self.model(x).logits
return out, None, None
def create_model(args):
if "mamba" in args.model:
model = Mamba(create_params(args))
return model
else:
model = Transformer(create_params(args))
return model
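

if __name__ == "__main__":
    # Minimal smoke test -- an illustrative sketch, not part of the library API. It builds a tiny
    # Transformer directly from Params rather than from a registered JSON config; the sizes below
    # are arbitrary, and torch_attn is passed explicitly so no GPU is needed (importing this module
    # still requires xformers and the open_lm helpers to be installed).
    _params = Params(
        dim=128,
        n_layers=2,
        n_heads=4,
        vocab_size=256,
        seq_len=64,
        attn_func=torch_attn,
        ffn_type="swiglu_torch",
    )
    _model = Transformer(_params)
    _tokens = torch.randint(0, _params.vocab_size, (1, 16))
    _logits, _hidden, _cache = _model(_tokens)
    print(_logits.shape)  # expected: torch.Size([1, 16, 256])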