import math
from functools import partial
from typing import Optional, Dict, Union, Tuple

import jax
from jax import numpy as jnp
from jax.interpreters import pxla
from jax.experimental.pjit import pjit, PartitionSpec, with_sharding_constraint as wsc
from jax.random import split, PRNGKey

from flax import linen as nn
from flax.core import FrozenDict

from einops import rearrange
from transformers import FlaxPreTrainedModel, PretrainedConfig
from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput

ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
}


def get_names_from_partition_spec(partition_specs):
    """Collect every mesh-axis name that appears in a (possibly nested) PartitionSpec."""
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        elif isinstance(item, str):
            names.add(item)
        else:
            names.update(get_names_from_partition_spec(item))
    return list(names)


def names_in_mesh(*names):
    """Return True when every given axis name exists in the currently active device mesh."""
    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)


def with_sharding_constraint(x, partition_specs):
    """Apply a sharding constraint only when all referenced mesh axes are available,
    so the same code runs unchanged on a single device."""
    axis_names = get_names_from_partition_spec(partition_specs)
    if names_in_mesh(*axis_names):
        x = wsc(x, partition_specs)
    return x
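
# Example (illustrative only): inside a pjit-compiled function running under a mesh
# with axes ('dp', 'fsdp', 'mp'), the call below shards an activation across the
# data/model axes; on meshes without those axes it silently becomes a no-op:
#   x = with_sharding_constraint(x, PartitionSpec(('dp', 'fsdp'), None, 'mp'))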


class FalconConfig(PretrainedConfig):
    model_type = "falcon"
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
            self,
            vocab_size=250880,
            hidden_size=64,
            n_layer=2,
            n_head=8,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            use_cache=True,
            bos_token_id=1,
            eos_token_id=2,
            apply_residual_connection_post_layernorm=False,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            multi_query=False,
            alibi=False,
            bias=False,
            parallel_attn=False,
            max_seq_len=2048,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.max_seq_len = max_seq_len
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi

    @staticmethod
    def get_partition_rules(fully_fsdp: bool = False):
        return (
            ('wte/embedding', PartitionSpec('fsdp', 'mp')),
            ('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('mlp/down/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('mlp/up/(kernel|bias)', PartitionSpec('mp', 'fsdp')),
            ('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
            ('transformer/ln_f/bias', PartitionSpec('fsdp', 'mp')),
            ('transformer/ln_f/scale', PartitionSpec('fsdp', 'mp')),
            ('transformer/post_attention_layernorm/scale', PartitionSpec('mp', 'fsdp')),
            ('transformer/post_attention_layernorm/bias', PartitionSpec('mp', 'fsdp')),
            ('.*', PartitionSpec('fsdp', 'mp')),
        ) if not fully_fsdp else (
            ('wte/embedding', PartitionSpec('fsdp')),
            ('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp')),
            ('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp')),
            ('mlp/down/(kernel|bias)', PartitionSpec('fsdp')),
            ('mlp/up/(kernel|bias)', PartitionSpec('fsdp')),
            ('lm_head/kernel', PartitionSpec('fsdp')),
            ('transformer/ln_f/bias', PartitionSpec('fsdp')),
            ('transformer/ln_f/scale', PartitionSpec('fsdp')),
            ('transformer/post_attention_layernorm/scale', PartitionSpec('fsdp')),
            ('transformer/post_attention_layernorm/bias', PartitionSpec('fsdp')),
            ('.*', PartitionSpec('fsdp')),
        )

    @staticmethod
    def get_mesh_names():
        return 'dp', 'fsdp', 'mp'
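
    # Note: each rule above is a (regex, PartitionSpec) pair. A typical (hypothetical)
    # consumer flattens the parameter tree, matches each '/'-joined parameter path
    # against the regexes in order, and uses the first matching spec with pjit, e.g.:
    #   rules = FalconConfig.get_partition_rules(fully_fsdp=True)
    # The trailing ('.*', ...) rule acts as the catch-all default.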


def build_alibi(max_length, num_attention_heads, alibi_max: int = 8):
    # relative position offsets, shape (1, 1, 1, max_length)
    w_range = jnp.arange(1 - max_length, 1).reshape(1, 1, 1, max_length)
    # per-head slopes follow the usual ALiBi geometric schedule: computed for the next
    # power of two and subsampled when the head count is not a power of two
    cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
    h_range = jnp.arange(1, 1 + cp2) * (alibi_max / cp2)
    slope = 1 / jnp.power(2, h_range)
    if cp2 != num_attention_heads:
        slope = jnp.concatenate([slope[1::2], slope[::2]], axis=-1)[:num_attention_heads]
    slope = slope.reshape(1, num_attention_heads, 1, 1)
    alibi = (w_range * slope).reshape(1, num_attention_heads, 1, max_length)
    return alibi
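
# Shape sketch (illustrative): `build_alibi(max_length=4, num_attention_heads=2)` yields a
# bias of shape (1, 2, 1, 4) that broadcasts over (batch, head, query, key) attention
# scores, penalising distant key positions with a per-head slope.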


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
    t = jnp.arange(end)
    freqs = jnp.outer(t, freqs).astype(dtype)
    sin, cos = jnp.sin(freqs), jnp.cos(freqs)
    freqs_cis = jnp.complex64(cos + 1j * sin)
    return jnp.asarray(freqs_cis)
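
# The table above has shape (end, dim // 2) with complex64 entries: one rotation per
# (position, feature-pair). The requested `dtype` only affects the intermediate sin/cos
# math; the cast back to real values happens in `apply_rotary_emb`.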


def apply_rotary_emb(
        xq: jnp.ndarray,
        xk: jnp.ndarray,
        freqs_cis: jnp.ndarray,
        dtype: jnp.dtype = jnp.bfloat16,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

    freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))

    xq_out = xq_ * freqs_cis
    xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)

    xk_out = xk_ * freqs_cis
    xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)

    return xq_out.astype(dtype), xk_out.astype(dtype)
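
# Expected shapes (a sketch): xq/xk come in as (batch, seq, heads, head_dim); the paired
# reshape views head_dim as head_dim // 2 complex numbers, `freqs_cis` of shape
# (1, seq, head_dim // 2) broadcasts over batch and heads, and the outputs keep the
# original (batch, seq, heads, head_dim) layout.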


class FlaxFalconAttention(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        head_dim = self.config.hidden_size // self.config.n_head
        # fused QKV projection; multi-query keeps a single key/value head
        self.w_qkv = nn.Dense(
            features=3 * self.config.hidden_size if not self.config.multi_query else (
                    self.config.hidden_size + 2 * head_dim),
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.factor_scale = 1 / math.sqrt(head_dim)
        self.wo = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.head_dim = head_dim
        assert self.head_dim * self.config.n_head == self.config.hidden_size
        if not self.config.alibi:
            self.freq = precompute_freqs_cis(head_dim, self.config.max_seq_len, dtype=self.dtype)

    def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray = None,
                 attention_mask: jnp.ndarray = None,
                 ):
        b, s, d = hidden_states.shape
        qkv = self.w_qkv(hidden_states)
        if not self.config.multi_query:
            q, k, v = jnp.split(qkv, 3, -1)
            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
            q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
            k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
            v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
        else:
            qkv = qkv.reshape(
                b, s, self.config.n_head + 2, -1
            )
            q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
            # broadcast the single key/value head across all query heads so the
            # einsums below see matching head dimensions
            k = jnp.broadcast_to(k, q.shape)
            v = jnp.broadcast_to(v, q.shape)

            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))

        if not self.config.alibi:
            freq = self.freq[:s].reshape(1, s, -1)
            q, k = apply_rotary_emb(q, k, freq, self.dtype)
        attn = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision)
        attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))

        if alibi is not None:
            attn += alibi
        attn = attn * self.factor_scale

        if attention_mask is not None:
            attn += attention_mask

        attn = jax.nn.softmax(attn, axis=-1)
        attn = jnp.einsum('...hqk,...khd->...qhd', attn, v, precision=self.precision).reshape((b, s, d))
        return self.wo(attn)
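
# Attention shape walk-through (illustrative): with q, k, v of shape (b, s, h, hd), the
# first einsum gives scores of shape (b, h, s, s); the ALiBi bias (1, h, 1, s) and the
# additive attention mask (b, 1, s, s) broadcast onto it before softmax, and the second
# einsum returns to (b, s, h, hd) before the final (b, s, h*hd) output projection.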


class FlaxFalconMlp(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.up = nn.Dense(
            features=self.config.hidden_size * 4,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.down = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )

    def __call__(self, x):
        return self.down(nn.gelu(self.up(x)))


class FlaxFalconBlock(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        config = self.config
        self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
                                            dtype=self.dtype)
        if not config.parallel_attn:
            self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
                                                         dtype=self.dtype)

        self.mlp = FlaxFalconMlp(
            config=config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.self_attention = FlaxFalconAttention(
            config=config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )

    def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray,
                 attention_mask: jnp.ndarray,
                 ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            alibi=alibi
        )
        if not self.config.parallel_attn:
            # sequential layout: attention output feeds a second layer norm before the MLP
            residual = attn + residual
            hidden_states = self.post_attention_layernorm(residual)

        mlp_out = self.mlp(hidden_states)
        if self.config.parallel_attn:
            # parallel layout: attention and MLP both read the same normed input
            mlp_out += attn
        return mlp_out + residual


class FlaxFalconCollection(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.blocks = [
            FlaxFalconBlock(
                config=self.config,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                name=str(i)
            )
            for i in range(self.config.n_layer)
        ]

    def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray,
                 attention_mask: jnp.ndarray,
                 ):
        for block in self.blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                alibi=alibi
            )
        return hidden_states


class FlaxFalconModule(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.wte = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype
        )
        self.h = FlaxFalconCollection(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.ln_f = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype,
                                 epsilon=self.config.layer_norm_epsilon)

    def __call__(self,
                 input_ids: jnp.ndarray = None,
                 attention_mask: Optional[jnp.ndarray] = None,
                 use_cache: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 ):
        batch, seq_len = input_ids.shape
        hidden_states = self.wte(
            inputs=input_ids.astype(jnp.int32)
        )
        if attention_mask is None:
            attention_mask = jnp.ones((batch, seq_len))

        alibi = build_alibi(seq_len, self.config.n_head, 8) if self.config.alibi else None
        causal_mask = nn.make_causal_mask(input_ids)

        # combine the padding mask and the causal mask into one additive bias of shape
        # (batch, 1, seq_len, seq_len): 0 where attention is allowed, a large negative
        # value where it is masked out
        mv = jnp.finfo(hidden_states.dtype).min
        attention_mask = jnp.where(attention_mask == 1, 0, mv)[:, None, None, :] \
                         + jnp.where(causal_mask == 1, 0, mv)

        output = self.ln_f(self.h(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            alibi=alibi
        ))

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=output,
            )
        else:
            return output,


class FlaxFalconPretrainedModel(FlaxPreTrainedModel):
    module_class: nn.Module = None
    config_class = FalconConfig

    def __init__(self, config, _do_init=False, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32,
                 input_shape: Tuple = (1, 12)):
        module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype)
        super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        if params is None:
            params = self.module.init(
                rngs=rng,
                input_ids=jnp.ones(input_shape, dtype=jnp.int32),
                attention_mask=jnp.ones(input_shape, dtype=jnp.int32)
            )
        return params['params']

    def __call__(self, input_ids,
                 attention_mask=None,
                 params: FrozenDict = None,
                 add_params_field: bool = False,
                 return_dict: bool = True):
        params = {'params': params or self.params} if add_params_field else params or self.params
        predict = self.module.apply(
            params,
            input_ids=jnp.asarray(input_ids, dtype=jnp.int32),
            attention_mask=jnp.asarray(attention_mask, dtype=jnp.int32) if attention_mask is not None else attention_mask,
            return_dict=return_dict
        )
        return predict

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None):
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }


class FlaxFalconModel(FlaxFalconPretrainedModel):
    module_class = FlaxFalconModule


class FlaxFalconForCausalLMModule(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.transformer = FlaxFalconModule(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False
        )

    def __call__(self, input_ids, attention_mask, return_dict: bool = False):
        output = self.lm_head(self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        ).last_hidden_state)
        if return_dict:
            return FlaxCausalLMOutput(
                logits=output
            )
        else:
            return output,


class FlaxFalconForCausalLM(FlaxFalconPretrainedModel):
    module_class = FlaxFalconForCausalLMModule
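
# Minimal usage sketch (shapes and values are illustrative, not a tested recipe):
#   config = FalconConfig(vocab_size=32, hidden_size=64, n_layer=2, n_head=8, alibi=True)
#   model = FlaxFalconForCausalLM(config=config, _do_init=False)
#   params = model.init_weights(PRNGKey(0), input_shape=(1, 12))
#   out = model(jnp.ones((1, 12), dtype=jnp.int32), params=params, add_params_field=True)
#   logits = out.logits  # (1, 12, vocab_size)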