# FlaxFalcon / model.py
import math
from flax import linen as nn
from flax.core import FrozenDict
from typing import Optional, Dict, Union, Tuple
from transformers import FlaxPreTrainedModel, PretrainedConfig
from jax import numpy as jnp
import jax
from jax.interpreters import pxla
from jax.experimental.pjit import PartitionSpec, with_sharding_constraint as wsc
from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput
from functools import partial
from einops import rearrange
ACT2FN = {
"gelu": partial(nn.gelu, approximate=False),
"relu": nn.relu,
"silu": nn.swish,
"swish": nn.swish,
"gelu_new": partial(nn.gelu, approximate=True),
}
def get_names_from_partition_spec(partition_specs):
names = set()
if isinstance(partition_specs, dict):
partition_specs = partition_specs.values()
for item in partition_specs:
if item is None:
continue
elif isinstance(item, str):
names.add(item)
else:
            names.update(get_names_from_partition_spec(item))
return list(names)
def names_in_mesh(*names):
return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)
def with_sharding_constraint(x, partition_specs):
    axis_names = get_names_from_partition_spec(partition_specs)
if names_in_mesh(*axis_names):
x = wsc(x, partition_specs)
return x
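# Usage sketch (illustrative only; mesh construction differs slightly across
# JAX versions, and the axis names below are the ones this file expects):
#
#   import numpy as np
#   from jax.sharding import Mesh
#   devices = np.array(jax.devices()).reshape(1, -1, 1)
#   with Mesh(devices, ('dp', 'fsdp', 'mp')):
#       y = with_sharding_constraint(x, PartitionSpec(('dp', 'fsdp'), 'mp'))
#
# Outside of a mesh context (or when the named axes are absent) the helper is a
# no-op, so model code can call it unconditionally.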
class FalconConfig(PretrainedConfig):
model_type = "falcon"
attribute_map = {
"num_hidden_layers": "n_layer",
"num_attention_heads": "n_head",
}
def __init__(
self,
vocab_size=250880,
hidden_size=64,
n_layer=2,
n_head=8,
layer_norm_epsilon=1e-5,
initializer_range=0.02,
use_cache=True,
bos_token_id=1,
eos_token_id=2,
apply_residual_connection_post_layernorm=False,
hidden_dropout=0.0,
attention_dropout=0.0,
multi_query=False,
alibi=False,
bias=False,
parallel_attn=False,
max_seq_len=2048,
use_pjit_attention_force: bool = False,
gradient_checkpointing: str = 'nothing_saveable',
**kwargs,
):
self.vocab_size = vocab_size
n_embed = kwargs.pop("n_embed", None)
self.hidden_size = hidden_size if n_embed is None else n_embed
self.n_layer = n_layer
self.n_head = n_head
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_range = initializer_range
self.max_seq_len = max_seq_len
self.use_cache = use_cache
self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bos_token_id = bos_token_id
self.use_pjit_attention_force = use_pjit_attention_force
self.eos_token_id = eos_token_id
self.multi_query = multi_query
self.alibi = alibi
self.bias = bias
self.gradient_checkpointing = gradient_checkpointing
self.parallel_attn = parallel_attn
super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@property
def head_dim(self):
return self.hidden_size // self.n_head
@property
def rotary(self):
return not self.alibi
@staticmethod
def get_partition_rules(fully_fsdp: bool = False):
return (
('wte/embedding', PartitionSpec('fsdp', 'mp')),
('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
('mlp/down/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
('mlp/up/(kernel|bias)', PartitionSpec('mp', 'fsdp')),
('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
('transformer/ln_f/bias', PartitionSpec('fsdp', 'mp')),
('transformer/ln_f/scale', PartitionSpec('fsdp', 'mp')),
('transformer/post_attention_layernorm/scale', PartitionSpec('mp', 'fsdp')),
('transformer/post_attention_layernorm/bias', PartitionSpec('mp', 'fsdp')),
('.*', PartitionSpec('fsdp', 'mp'))
) if not fully_fsdp else (
('wte/embedding', PartitionSpec('fsdp')),
('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp')),
('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp')),
('mlp/down/(kernel|bias)', PartitionSpec('fsdp')),
('mlp/up/(kernel|bias)', PartitionSpec('fsdp')),
('lm_head/kernel', PartitionSpec('fsdp')),
('transformer/ln_f/bias', PartitionSpec('fsdp')),
('transformer/ln_f/scale', PartitionSpec('fsdp')),
('transformer/post_attention_layernorm/scale', PartitionSpec('fsdp')),
('transformer/post_attention_layernorm/bias', PartitionSpec('fsdp')),
('.*', PartitionSpec('fsdp'))
)
@staticmethod
def get_mesh_names():
return 'dp', 'fsdp', 'mp'
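# Configuration sketch (illustrative values, not tied to any released checkpoint):
#
#   config = FalconConfig(vocab_size=65024, hidden_size=4544, n_layer=32,
#                         n_head=71, multi_query=True, parallel_attn=True)
#   config.head_dim  # 64
#   config.rotary    # True -- rotary embeddings are used whenever alibi is off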
def build_bloom_alibi(attention_mask, num_attention_heads):
b, s = attention_mask.shape
cp2 = 2 ** math.floor(math.log2(num_attention_heads))
base = jnp.asarray(
2 ** (- (2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
)
powers = jnp.arange(1, 1 + cp2, dtype=jnp.float32)
    slopes = jnp.power(base, powers)
if cp2 != num_attention_heads:
extra_base = jnp.asarray(
2 ** (-(2 ** -(math.log2(2 * cp2) - 3))), dtype=jnp.float32
)
num_rem_heads = min(cp2, num_attention_heads - cp2)
        extra_powers = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.float32)
        slopes = jnp.concatenate([slopes, jnp.power(extra_base, extra_powers)], axis=0)
arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
    alibi = slopes[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
return alibi.reshape(b, num_attention_heads, 1, s)
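# Shape sketch: for an attention_mask of shape (batch, seq_len), the returned
# ALiBi bias has shape (batch, num_attention_heads, 1, seq_len), so it
# broadcasts against attention scores of shape (batch, heads, q_len, k_len):
#
#   build_bloom_alibi(jnp.ones((2, 16)), 8).shape  # (2, 8, 1, 16)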
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
t = jnp.arange(end) # type: ignore
freqs = jnp.outer(t, freqs).astype(dtype)
sin, cos = jnp.sin(freqs), jnp.cos(freqs)
freqs_cis = jnp.complex64(cos + 1j * sin)
return jnp.asarray(freqs_cis)
def apply_rotary_emb(
xq: jnp.ndarray,
xk: jnp.ndarray,
freqs_cis: jnp.ndarray,
dtype: jnp.dtype = jnp.bfloat16,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
xq_out = xq_ * freqs_cis
xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
xk_out = xk_ * freqs_cis
xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
return xq_out.astype(dtype), xk_out.astype(dtype)
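# Rotary sketch: feature pairs are viewed as complex numbers and rotated by a
# position-dependent phase, so relative offsets show up as phase differences in
# the query/key dot products. A quick shape check with hypothetical sizes:
#
#   q = k = jnp.zeros((1, 8, 4, 32))                       # (batch, seq, heads, head_dim)
#   freqs = precompute_freqs_cis(32, 8).reshape(1, 8, -1)  # (1, seq, head_dim // 2)
#   xq, xk = apply_rotary_emb(q, k, freqs)                 # shapes are preserved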
class FlaxFalconAttention(nn.Module):
config: FalconConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
precision: Optional[Union[jax.lax.Precision, str]] = None
def setup(self) -> None:
head_dim = self.config.hidden_size // self.config.n_head
self.w_qkv = nn.Dense(
features=3 * self.config.hidden_size if not self.config.multi_query else (
self.config.hidden_size + 2 * head_dim),
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=self.config.bias
)
self.factor_scale = 1 / math.sqrt(head_dim)
self.wo = nn.Dense(
features=self.config.hidden_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=self.config.bias
)
self.head_dim = head_dim
assert self.head_dim * self.config.n_head == self.config.hidden_size
if not self.config.alibi:
self.freq = precompute_freqs_cis(head_dim, self.config.max_seq_len, dtype=self.dtype)
def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: Optional[jnp.ndarray] = None,
                 attention_mask: Optional[jnp.ndarray] = None,
):
b, s, d = hidden_states.shape
qkv = self.w_qkv(hidden_states)
if not self.config.multi_query:
q, k, v = jnp.split(qkv, 3, -1)
if self.config.use_pjit_attention_force:
q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
else:
qkv = qkv.reshape(
b, s, self.config.n_head + 2, -1
)
            # multi-query: a single shared key/value head; its singleton head
            # axis broadcasts against the query heads in the einsum below
            q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
if self.config.use_pjit_attention_force:
q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
if not self.config.alibi:
freq = self.freq[:s].reshape(1, s, -1)
q, k = apply_rotary_emb(q, k, freq, self.dtype)
attn = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision)
if self.config.use_pjit_attention_force:
attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))
if alibi is not None:
attn += alibi
attn = attn * self.factor_scale
if attention_mask is not None:
attn += attention_mask
attn = jax.nn.softmax(attn, axis=-1)
attn = jnp.einsum('...hqk,...khd->...qhd', attn, v, precision=self.precision).reshape((b, s, d))
return self.wo(attn)
class FlaxFalconMlp(nn.Module):
config: FalconConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
precision: Optional[Union[jax.lax.Precision, str]] = None
def setup(self) -> None:
self.up = nn.Dense(
features=self.config.hidden_size * 4,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=self.config.bias
)
self.down = nn.Dense(
features=self.config.hidden_size,
dtype=self.dtype,
param_dtype=self.param_dtype,
use_bias=self.config.bias
)
def __call__(self, x):
return self.down(nn.gelu(self.up(x)))
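# Note: the MLP applies nn.gelu directly (Flax defaults to the tanh
# approximation) rather than looking up an activation in ACT2FN, and uses the
# standard 4x transformer expansion (hidden -> 4*hidden -> hidden).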
class FlaxFalconBlock(nn.Module):
config: FalconConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
precision: Optional[Union[jax.lax.Precision, str]] = None
def setup(self) -> None:
config = self.config
self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
dtype=self.dtype)
if not config.parallel_attn:
self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
dtype=self.dtype)
self.mlp = FlaxFalconMlp(
config=config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision
)
self.self_attention = FlaxFalconAttention(
config=config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision
)
def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray,
                 attention_mask: jnp.ndarray,
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
attn = self.self_attention(
hidden_states=hidden_states,
attention_mask=attention_mask,
alibi=alibi
)
if not self.config.parallel_attn:
residual = attn + residual
hidden_states = self.post_attention_layernorm(residual)
mlp_out = self.mlp(hidden_states)
if self.config.parallel_attn:
mlp_out += attn
return mlp_out + residual
def get_gradient_checkpoint_policy(name):
return {
'everything_saveable': jax.checkpoint_policies.everything_saveable,
'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
}[name]
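# Example: constructing the config with gradient_checkpointing='nothing_saveable'
# wraps every block in nn.remat below, so activations are recomputed during the
# backward pass (trading compute for memory); pass '' to disable rematerialization.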
class FlaxFalconCollection(nn.Module):
config: FalconConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
precision: Optional[Union[jax.lax.Precision, str]] = None
def setup(self) -> None:
block = FlaxFalconBlock
if self.config.gradient_checkpointing != '':
block = nn.remat(
block,
policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
)
self.blocks = [
block(
config=self.config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision,
name=str(i)
)
for i in range(
self.config.n_layer
)
]
def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray,
                 attention_mask: jnp.ndarray,
):
for b in self.blocks:
hidden_states = b(
attention_mask=attention_mask,
hidden_states=hidden_states,
alibi=alibi
)
return hidden_states
class FlaxFalconModule(nn.Module):
config: FalconConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
precision: Optional[Union[jax.lax.Precision, str]] = None
def setup(self) -> None:
self.wte = nn.Embed(
num_embeddings=self.config.vocab_size,
features=self.config.hidden_size,
dtype=self.dtype,
param_dtype=self.param_dtype
)
self.h = FlaxFalconCollection(
config=self.config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision
)
self.ln_f = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype, epsilon=self.config.layer_norm_epsilon)
def __call__(self,
                 input_ids: jnp.ndarray = None,
                 attention_mask: Optional[jnp.ndarray] = None,
use_cache: Optional[bool] = None,
return_dict: Optional[bool] = None,
):
batch, seq_len = input_ids.shape
hidden_states = self.wte(
inputs=input_ids.astype(jnp.int32)
)
if attention_mask is None:
attention_mask = jnp.ones(
(batch, seq_len)
)
        alibi = build_bloom_alibi(
            attention_mask, self.config.n_head
        ).astype(hidden_states.dtype) if self.config.alibi else None
causal_mask = nn.make_causal_mask(
input_ids,
)
        # Combine padding and causal masks into one additive bias of shape
        # (batch, 1, q_len, k_len): masked positions get the most negative
        # representable value, attended positions get zero.
        mv = jnp.finfo(hidden_states.dtype).min
        attention_mask = attention_mask[:, jnp.newaxis, jnp.newaxis, :]
        attention_mask = jnp.where(attention_mask == 1, 0, mv) + jnp.where(causal_mask == 1, 0, mv)
output = self.ln_f(self.h(
hidden_states=hidden_states,
attention_mask=attention_mask,
alibi=alibi
))
if return_dict:
return FlaxBaseModelOutput(
last_hidden_state=output,
)
else:
return output,
class FlaxFalconPretrainedModel(FlaxPreTrainedModel):
module_class: nn.Module = None
config_class = FalconConfig
def __init__(self, config, _do_init=False, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32,
input_shape: Tuple = (1, 12)):
module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype)
super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape)
    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        if params is None:
            params = self.module.init(
                rngs=rng,
                input_ids=jnp.ones(input_shape, dtype=jnp.int32),
                attention_mask=jnp.ones(input_shape, dtype=jnp.int32)
            )['params']
        return params
def __call__(self, input_ids,
attention_mask=None,
params: FrozenDict = None,
add_params_field: bool = False,
return_dict: bool = True):
params = {'params': params or self.params} if add_params_field else params or self.params
predict = self.module.apply(
params,
input_ids=jnp.asarray(input_ids, dtype=jnp.int32),
attention_mask=jnp.asarray(attention_mask,
dtype=jnp.int32) if attention_mask is not None else attention_mask,
return_dict=return_dict
)
return predict
def prepare_inputs_for_generation(self, input_ids, attention_mask=None):
return {
'input_ids': input_ids,
'attention_mask': attention_mask
}
class FlaxFalconModel(FlaxFalconPretrainedModel):
module_class = FlaxFalconModule
class FlaxFalconForCausalLMModule(nn.Module):
config: FalconConfig
dtype: jnp.dtype = jnp.float32
param_dtype: jnp.dtype = jnp.float32
precision: Optional[Union[jax.lax.Precision, str]] = None
def setup(self) -> None:
self.transformer = FlaxFalconModule(
config=self.config,
dtype=self.dtype,
param_dtype=self.param_dtype,
precision=self.precision
)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=False
        )
def __call__(self, input_ids, attention_mask, return_dict: bool = False):
output = self.lm_head(self.transformer(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True
).last_hidden_state)
if return_dict:
return FlaxCausalLMOutput(
logits=output
)
else:
return output,
class FlaxFalconForCausalLM(FlaxFalconPretrainedModel):
module_class = FlaxFalconForCausalLMModule
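# End-to-end sketch (toy sizes and randomly initialized weights -- not a real
# checkpoint; shown only to illustrate the call convention of this file):
#
#   config = FalconConfig(vocab_size=128, hidden_size=64, n_layer=2, n_head=8, alibi=True)
#   model = FlaxFalconForCausalLM(config=config, _do_init=True)
#   input_ids = jnp.ones((1, 12), dtype=jnp.int32)
#   out = model(input_ids, add_params_field=True)  # FlaxCausalLMOutput
#   out.logits.shape                               # (1, 12, 128)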