Create model.py
model.py
ADDED
@@ -0,0 +1,488 @@
import math

from flax import linen as nn
from flax.core import FrozenDict
from typing import Optional, Dict, Union, Tuple
from transformers import FlaxPreTrainedModel, PretrainedConfig
from jax import numpy as jnp
import jax
from jax.interpreters import pxla
from jax.experimental.pjit import pjit, PartitionSpec, with_sharding_constraint as wsc
from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput
from jax.random import split, PRNGKey
from functools import partial
from einops import rearrange

ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
}


def get_names_from_partition_spec(partition_specs):
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        elif isinstance(item, str):
            names.add(item)
        else:
            names.update(get_names_from_partition_spec(item))

    return list(names)


def names_in_mesh(*names):
    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)


def with_sharding_constraint(x, partition_specs):
    axis_names = get_names_from_partition_spec(partition_specs)
    if names_in_mesh(*axis_names):
        x = wsc(x, partition_specs)
    return x


class FalconConfig(PretrainedConfig):
    model_type = "falcon"
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
            self,
            vocab_size=250880,
            hidden_size=64,
            n_layer=2,
            n_head=8,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            use_cache=True,
            bos_token_id=1,
            eos_token_id=2,
            apply_residual_connection_post_layernorm=False,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            multi_query=False,
            alibi=False,
            bias=False,
            parallel_attn=False,
            max_seq_len=2048,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.max_seq_len = max_seq_len
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi

    @staticmethod
    def get_partition_rules(fully_fsdp: bool = False):
        return (
            ('wte/embedding', PartitionSpec('fsdp', 'mp')),
            ('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('mlp/down/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('mlp/up/(kernel|bias)', PartitionSpec('mp', 'fsdp')),
            ('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
            ('transformer/ln_f/bias', PartitionSpec('fsdp', 'mp')),
            ('transformer/ln_f/scale', PartitionSpec('fsdp', 'mp')),
            ('transformer/post_attention_layernorm/scale', PartitionSpec('mp', 'fsdp')),
            ('transformer/post_attention_layernorm/bias', PartitionSpec('mp', 'fsdp')),
            ('.*', PartitionSpec('fsdp', 'mp'))
        ) if not fully_fsdp else (
            ('wte/embedding', PartitionSpec('fsdp')),
            ('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp')),
            ('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp')),
            ('mlp/down/(kernel|bias)', PartitionSpec('fsdp')),
            ('mlp/up/(kernel|bias)', PartitionSpec('fsdp')),
            ('lm_head/kernel', PartitionSpec('fsdp')),
            ('transformer/ln_f/bias', PartitionSpec('fsdp')),
            ('transformer/ln_f/scale', PartitionSpec('fsdp')),
            ('transformer/post_attention_layernorm/scale', PartitionSpec('fsdp')),
            ('transformer/post_attention_layernorm/bias', PartitionSpec('fsdp')),
            ('.*', PartitionSpec('fsdp'))
        )

    @staticmethod
    def get_mesh_names():
        return 'dp', 'fsdp', 'mp'


def build_alibi(max_length, num_attention_heads, alibi_max: int = 8):
    # relative-position term: 0 for the current token, increasingly negative further back
    w_range = jnp.arange(1 - max_length, 1).reshape(1, 1, 1, max_length)
    # ALiBi slopes are defined on the next power of two of the head count,
    # then interleaved and truncated when the head count is not a power of two
    cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
    h_range = jnp.arange(1, 1 + cp2) * (alibi_max / cp2)
    slope = 1 / jnp.power(2, h_range)
    if cp2 != num_attention_heads:
        slope = jnp.concatenate([slope[1::2], slope[::2]], axis=-1)[:num_attention_heads]
    slope = slope.reshape(1, num_attention_heads, 1, 1)
    alibi = (w_range * slope).reshape(1, num_attention_heads, 1, max_length)
    return alibi


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
    t = jnp.arange(end)  # type: ignore
    freqs = jnp.outer(t, freqs).astype(dtype)
    sin, cos = jnp.sin(freqs), jnp.cos(freqs)
    freqs_cis = jnp.complex64(cos + 1j * sin)
    return jnp.asarray(freqs_cis)


def apply_rotary_emb(
        xq: jnp.ndarray,
        xk: jnp.ndarray,
        freqs_cis: jnp.ndarray,
        dtype: jnp.dtype = jnp.bfloat16,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

    freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))

    xq_out = xq_ * freqs_cis
    xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)

    xk_out = xk_ * freqs_cis
    xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)

    return xq_out.astype(dtype), xk_out.astype(dtype)


class FlaxFalconAttention(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        head_dim = self.config.hidden_size // self.config.n_head
        self.w_qkv = nn.Dense(
            features=self.config.hidden_size * 3,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.factor_scale = 1 / math.sqrt(head_dim)
        self.wo = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.head_dim = head_dim
        if not self.config.alibi:
            self.freq = precompute_freqs_cis(head_dim, self.config.max_seq_len, dtype=self.dtype)

    def __call__(self,
                 hidden_states: jnp.DeviceArray,
                 alibi: jnp.DeviceArray = None,
                 attention_mask: jnp.DeviceArray = None,
                 ):
        b, s, d = hidden_states.shape
        q, k, v = jnp.split(self.w_qkv(hidden_states), 3, -1)
        q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
        k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
        v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
        k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
        q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
        v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
        if not self.config.alibi:
            freq = self.freq[:s].reshape(1, s, -1)
            q, k = apply_rotary_emb(q, k, freq, self.dtype)
        attn = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision)
        attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))

        if alibi is not None:
            # add the ALiBi positional bias to the raw attention scores
            attn += alibi
        attn = attn * self.factor_scale
        if attention_mask is not None:
            attn += attention_mask
        attn = jax.nn.softmax(attn, axis=-1)
        attn = jnp.einsum('...hqk,...khd->...qhd', attn, v, precision=self.precision).reshape((b, s, d))
        return self.wo(attn)


class FlaxFalconMlp(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.up = nn.Dense(
            features=self.config.hidden_size * 4,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.down = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )

    def __call__(self, x):
        return self.down(nn.gelu(self.up(x)))


class FlaxFalconBlock(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        config = self.config
        self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
                                            dtype=self.dtype)
        if not config.parallel_attn:
            self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
                                                         dtype=self.dtype)

        self.mlp = FlaxFalconMlp(
            config=config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.self_attention = FlaxFalconAttention(
            config=config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )

    def __call__(self,
                 hidden_states: jnp.DeviceArray,
                 alibi: jnp.DeviceArray,
                 attention_mask: jnp.DeviceArray,
                 ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            alibi=alibi
        )
        if not self.config.parallel_attn:
            residual = attn + residual
            hidden_states = self.post_attention_layernorm(residual)

        mlp_out = self.mlp(hidden_states)
        if self.config.parallel_attn:
            mlp_out += attn
        return mlp_out + residual


class FlaxFalconCollection(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.blocks = [
            FlaxFalconBlock(
                config=self.config,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                name=str(i)
            )
            for i in range(self.config.n_layer)
        ]

    def __call__(self,
                 hidden_states: jnp.DeviceArray,
                 alibi: jnp.DeviceArray,
                 attention_mask: jnp.DeviceArray,
                 ):
        for b in self.blocks:
            hidden_states = b(
                attention_mask=attention_mask,
                hidden_states=hidden_states,
                alibi=alibi
            )
        return hidden_states


class FlaxFalconModule(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.wte = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype
        )
        self.h = FlaxFalconCollection(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.ln_f = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype, epsilon=self.config.layer_norm_epsilon)

    def __call__(self,
                 input_ids: jnp.int32 = None,
                 attention_mask: Optional[jnp.DeviceArray] = None,
                 use_cache: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 ):
        batch, seq_len = input_ids.shape
        hidden_states = self.wte(
            inputs=input_ids.astype(jnp.int32)
        )
        if attention_mask is None:
            attention_mask = jnp.ones(
                (batch, seq_len)
            )

        alibi = build_alibi(seq_len, self.config.n_head, 8) if self.config.alibi else None
        causal_mask = nn.make_causal_mask(
            input_ids,
        )

        # combine the padding mask and the causal mask into one additive bias of shape
        # (batch, 1, seq_len, seq_len); masked positions get the most negative finite value
        mv = jnp.finfo(hidden_states.dtype).min
        attention_mask = attention_mask[:, jnp.newaxis, jnp.newaxis, :]
        attention_mask = jnp.where(attention_mask == 1, 0, mv) + jnp.where(causal_mask == 1, 0, mv)

        output = self.ln_f(self.h(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            alibi=alibi
        ))

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=output,
            )
        else:
            return output,


class FlaxFalconPretrainedModel(FlaxPreTrainedModel):
    module_class: nn.Module = None
    config_class = FalconConfig

    def __init__(self, config, _do_init=False, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32,
                 input_shape: Tuple = (1, 12)):
        module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype)
        super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        if params is None:
            params = self.module.init(
                rngs=rng,
                input_ids=jnp.ones(input_shape),
                attention_mask=jnp.ones(input_shape)
            )
        return params['params']

    def __call__(self, input_ids,
                 attention_mask=None,
                 params: FrozenDict = None,
                 add_params_field: bool = False,
                 return_dict: bool = True):
        params = {'params': params or self.params} if add_params_field else params or self.params
        predict = self.module.apply(
            params,
            input_ids=jnp.asarray(input_ids, dtype=jnp.int32),
            attention_mask=jnp.asarray(attention_mask,
                                       dtype=jnp.int32) if attention_mask is not None else attention_mask,
            return_dict=return_dict
        )
        return predict

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None):
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }


class FlaxFalconModel(FlaxFalconPretrainedModel):
    module_class = FlaxFalconModule


class FlaxFalconForCausalLMModule(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.transformer = FlaxFalconModule(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False
        )

    def __call__(self, input_ids, attention_mask, return_dict: bool = False):
        output = self.lm_head(self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        ).last_hidden_state)
        if return_dict:
            return FlaxCausalLMOutput(
                logits=output
            )
        else:
            return output,


class FlaxFalconForCausalLM(FlaxFalconPretrainedModel):
    module_class = FlaxFalconForCausalLMModule
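
A quick smoke test of the modules above (a minimal sketch, not part of model.py: it assumes the file is importable as model, and the tiny configuration values are illustrative only):

import jax
from jax import numpy as jnp

from model import FalconConfig, FlaxFalconForCausalLMModule

# tiny, CPU-friendly configuration purely for shape checking
config = FalconConfig(vocab_size=128, hidden_size=64, n_layer=2, n_head=8,
                      alibi=False, parallel_attn=True, max_seq_len=64)
module = FlaxFalconForCausalLMModule(config=config)

input_ids = jnp.zeros((1, 16), dtype=jnp.int32)
attention_mask = jnp.ones((1, 16), dtype=jnp.int32)

# initialise the parameters and run one forward pass through the raw Flax module
params = module.init(jax.random.PRNGKey(0), input_ids=input_ids, attention_mask=attention_mask)
(logits,) = module.apply(params, input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
print(logits.shape)  # expected: (1, 16, 128)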