Update model.py
model.py CHANGED
@@ -1,8 +1,5 @@
-import …
-from …
-
-from flax.linen import remat
-
+from typing import Dict, Optional, Tuple, Union
+from einops import einops
 import jax
 import jax.numpy as jnp
 from jax import lax
@@ -11,93 +8,63 @@ import flax.linen as nn
 from flax.linen.attention import dot_product_attention_weights
 from flax.traverse_util import flatten_dict, unflatten_dict
 from flax.linen import partitioning as nn_partitioning
-
 from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
 from flax.linen import combine_masks, make_causal_mask
-
 from transformers.configuration_utils import PretrainedConfig
 from transformers.modeling_flax_utils import FlaxPreTrainedModel
-
-from jax.interpreters import pxla
 from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput, FlaxSequenceClassifierOutput
-
-from …
-
-
-def get_names_from_parition_spec(partition_specs):
-    names = set()
-    if isinstance(partition_specs, dict):
-        partition_specs = partition_specs.values()
-    for item in partition_specs:
-        if item is None:
-            continue
-        elif isinstance(item, str):
-            names.add(item)
-        else:
-            names.update(get_names_from_parition_spec(item))
-
-    return list(names)
-
-
-def names_in_mesh(*names):
-    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)
-
-
-def with_sharding_constraint(x, partition_specs):
-    axis_names = get_names_from_parition_spec(partition_specs)
-    if names_in_mesh(*axis_names):
-        x = wsc(x, partition_specs)
-    return x
-
-
-def get_gradient_checkpoint_policy(name):
-    return {
-        'everything_saveable': jax.checkpoint_policies.everything_saveable,
-        'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
-        'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
-        'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
-    }[name]
+from fjformer.attention import blockwise_dot_product_attention
+from ..flax_modelling_utils import with_sharding_constraint, \
+    get_gradient_checkpoint_policy, repeat_kv_bnsh, apply_rotary_pos_emb, precompute_freq_cis
+import chex
 
 
 class LlamaConfig(PretrainedConfig):
-    model_type = "…
+    model_type = "llama"
 
     def __init__(
             self,
-            vocab_size=32000,
-            hidden_size=4096,
-            intermediate_size=11008,
-            num_hidden_layers=32,
-            num_attention_heads=32,
-            …
+            vocab_size: int = 32000,
+            hidden_size: int = 4096,
+            intermediate_size: int = 11008,
+            num_hidden_layers: int = 32,
+            num_attention_heads: int = 32,
+            number_rep_kv: int = 1,
+            num_key_value_heads: Optional[int] = None,
+            max_position_embeddings: int = 2048,
+            rms_norm_eps: float = 1e-6,
+            initializer_range: float = 0.02,
+            use_cache: bool = True,
+            bos_token_id: int = 0,
+            eos_token_id: int = 1,
+            resid_pdrop: float = 0.0,
+            embd_pdrop: float = 0.0,
+            attn_pdrop: float = 0.0,
+            tie_word_embeddings: bool = False,
+            gradient_checkpointing: str = 'nothing_saveable',
+            fcm_min_ratio: float = -1,
+            fcm_max_ratio: float = -1,
             use_pjit_attention_force: bool = True,
-            rope_scaling=None,
+            rope_scaling: Dict[str, Union[str, float]] = None,
+            use_flash_attention: bool = False,
+            use_sacn_mlp: bool = False,
+            flash_attn_query_chunk_size: int = 1024,
+            flash_attn_key_chunk_size: int = 1024,
+            scan_mlp_chunk_size: int = 1024,
             **kwargs,
     ):
-        …
-            "factor": 8.0,
-            "type": "linear"
-        }
+        num_key_value_heads = num_key_value_heads or number_rep_kv * num_attention_heads
+        self.num_key_value_heads = num_key_value_heads
         self.vocab_size = vocab_size
+
+        self.number_rep_kv = number_rep_kv
         self.hidden_size = hidden_size
         self.initializer_range = initializer_range
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
+
         self.num_attention_heads = num_attention_heads
-        self.…
+        self.max_position_embeddings = max_position_embeddings
         self.rms_norm_eps = rms_norm_eps
         self.use_cache = use_cache
         self.resid_pdrop = resid_pdrop
@@ -108,6 +75,12 @@ class LlamaConfig(PretrainedConfig):
         self.fcm_min_ratio = fcm_min_ratio
         self.fcm_max_ratio = fcm_max_ratio
         self.rope_scaling = rope_scaling
+        self.use_flash_attention = use_flash_attention
+        self.use_sacn_mlp = use_sacn_mlp
+        self.flash_attn_key_chunk_size = flash_attn_key_chunk_size
+        self.flash_attn_query_chunk_size = flash_attn_query_chunk_size
+        self.scan_mlp_chunk_size = scan_mlp_chunk_size
+
         super().__init__(
             # pad_token_id=pad_token_id,
             bos_token_id=bos_token_id,
@@ -120,40 +93,73 @@ class LlamaConfig(PretrainedConfig):
     def get_partition_rules(fully_fsdp: bool = True):
         return (
 
-            ("…
+            ("model/embed_tokens/embedding", PS("dp", "fsdp")),
 
-            ("…
-            ("…
+            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PS("fsdp", "dp")),
+            ("self_attn/o_proj/kernel", PS("dp", "fsdp")),
 
-            ("…
-            ("…
-            ("…
+            ("mlp/gate_proj/kernel", PS("fsdp", "dp")),
+            ("mlp/down_proj/kernel", PS("dp", "fsdp")),
+            ("mlp/up_proj/kernel", PS("fsdp", "dp")),
 
-            ("…
-            ("…
+            ("input_layernorm/kernel", PS(None)),
+            ("post_attention_layernorm/kernel", PS(None)),
 
-            ("…
+            ("model/norm/kernel", PS(None)),
             ("lm_head/kernel", PS("fsdp", "dp")),
             ('.*', PS(None)),
         ) if not fully_fsdp else (
 
-            ("…
+            ("model/embed_tokens/embedding", PS("fsdp")),
 
-            ("…
-            ("…
+            ("self_attn/(q_proj|k_proj|v_proj)/kernel", PS("fsdp")),
+            ("self_attn/o_proj/kernel", PS("fsdp")),
 
-            ("…
-            ("…
-            ("…
+            ("mlp/gate_proj/kernel", PS("fsdp")),
+            ("mlp/down_proj/kernel", PS("fsdp")),
+            ("mlp/up_proj/kernel", PS("fsdp")),
 
-            ("…
-            ("…
+            ("input_layernorm/kernel", PS(None)),
+            ("post_attention_layernorm/kernel", PS(None)),
 
-            ("…
+            ("model/norm/kernel", PS(None)),
             ("lm_head/kernel", PS("fsdp")),
-            ('.*', PS(…
+            ('.*', PS('fsdp')),
         )
 
+    def add_jax_args(self,
+                     resid_pdrop: float = 0.0,
+                     embd_pdrop: float = 0.0,
+                     attn_pdrop: float = 0.0,
+                     tie_word_embeddings: bool = False,
+                     gradient_checkpointing: str = 'nothing_saveable',
+                     fcm_min_ratio: float = 0.0,
+                     fcm_max_ratio: float = 0.0,
+                     use_pjit_attention_force: bool = True,
+                     use_flash_attention: bool = False,
+                     use_sacn_mlp: bool = False,
+                     flash_attn_query_chunk_size: int = 1024,
+                     flash_attn_key_chunk_size: int = 1024,
+                     scan_mlp_chunk_size: int = 1024,
+                     number_rep_kv: int = 1,
+                     ):
+        self.use_flash_attention = use_flash_attention
+        self.embd_pdrop = embd_pdrop
+        self.number_rep_kv = number_rep_kv
+        self.resid_pdrop = resid_pdrop
+
+        self.attn_pdrop = attn_pdrop
+        self.tie_word_embeddings = tie_word_embeddings
+        self.gradient_checkpointing = gradient_checkpointing
+        self.fcm_min_ratio = fcm_min_ratio
+        self.fcm_max_ratio = fcm_max_ratio
+        self.use_pjit_attention_force = use_pjit_attention_force
+
+        self.use_sacn_mlp = use_sacn_mlp
+        self.flash_attn_query_chunk_size = flash_attn_query_chunk_size
+        self.flash_attn_key_chunk_size = flash_attn_key_chunk_size
+        self.scan_mlp_chunk_size = scan_mlp_chunk_size
+
     @staticmethod
     def get_weight_decay_exclusions():
         return tuple()
@@ -163,14 +169,41 @@ class LlamaConfig(PretrainedConfig):
         return ('params', 'dropout', 'fcm')
 
 
-…
+re_mat = nn_partitioning.remat
+
+
+class FlaxLlamaEmbedding(nn.Module):
+    dtype: jnp.dtype = jnp.float32
+
+    def __call__(self, query, key, freq_cis, position_ids):
+        sin, cos = freq_cis
+
+        sin = sin[position_ids][:, None, :, :]
+        cos = cos[position_ids][:, None, :, :]
+
+        key = apply_rotary_pos_emb(key, sin, cos)
+        query = apply_rotary_pos_emb(query, sin, cos)
+
+        return query.astype(self.dtype), key.astype(self.dtype)
+
+
+def repeat_kv(x: chex.Array, n_rep: int) -> chex.Array:
+    bs, s, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    x = x[:, :, jnp.newaxis, :, :]
+    x = jnp.repeat(x, n_rep, axis=2)
+
+    return x.reshape(bs, s,
+                     n_kv_heads * n_rep,
+                     head_dim)
 
 
 class RMSNorm(nn.Module):
     dim: int
     eps: float = 1e-6
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
 
     def setup(self) -> None:
         self.weight = self.param(
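Note: the hunk above adds a `repeat_kv` helper for grouped-query attention, repeating each of the `n_kv_heads` key/value heads `n_rep` times so the key/value tensors match the query head count. A minimal standalone shape check of that repeat sequence (a sketch assuming only `jax` is installed; not part of the commit):

import jax.numpy as jnp

# Mirror of the repeat_kv logic added above: expand 8 KV heads to
# 32 effective heads by repeating each KV head n_rep = 4 times.
x = jnp.zeros((2, 16, 8, 64))        # (batch, seq, n_kv_heads, head_dim)
n_rep = 4
y = x[:, :, jnp.newaxis, :, :]       # (batch, seq, 1, n_kv_heads, head_dim)
y = jnp.repeat(y, n_rep, axis=2)     # (batch, seq, n_rep, n_kv_heads, head_dim)
y = y.reshape(2, 16, 8 * n_rep, 64)  # (batch, seq, n_kv_heads * n_rep, head_dim)
assert y.shape == (2, 16, 32, 64)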
@@ -184,124 +217,65 @@ class RMSNorm(nn.Module):
         return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
 
     def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
-        x = x.astype(jnp.promote_types(self.dtype, jnp.…
+        x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
         output = self._norm(x).astype(self.dtype)
         weight = jnp.asarray(self.weight, self.dtype)
         return output * weight
 
 
-def rotate_half(x):
-    x1 = x[..., : x.shape[-1] // 2]
-    x2 = x[..., x.shape[-1] // 2:]
-    return jnp.concatenate([-x2, x1], axis=-1)
-
-
-def precompute_freqs_cis(
-        method: str,
-        dim: int, end: int, theta: float = 10000.0,
-        scaling_factor: float = 8.,
-        dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
-    if method == 'linear':
-        freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
-    elif method == 'dynamic':
-        base = theta * (
-            (scaling_factor * end / end) - (scaling_factor - 1)
-        ) ** (dim / (dim - 2))
-        freqs = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))
-    else:
-        raise ValueError(f'unknown {method} method for precompute_freqs_cis')
-    t = jnp.arange(end)  # type: ignore
-    freqs = jnp.outer(t, freqs).astype(dtype)
-    sin, cos = jnp.sin(freqs), jnp.cos(freqs)
-    freqs_cis = jnp.complex64(cos + 1j * sin)
-    return jnp.asarray(freqs_cis)
-
-
-def apply_rotary_emb(
-        xq: jnp.ndarray,
-        xk: jnp.ndarray,
-        freqs_cis: jnp.ndarray,
-        dtype: jnp.dtype = jnp.bfloat16,
-) -> Tuple[jnp.ndarray, jnp.ndarray]:
-    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
-    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
-
-    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
-    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
-
-    freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
-
-    xq_out = xq_ * freqs_cis
-    xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
-
-    xk_out = xk_ * freqs_cis
-    xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
-
-    return xq_out.astype(dtype), xk_out.astype(dtype)
-
-
 class FlaxLlamaAttention(nn.Module):
     config: LlamaConfig
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self):
         config = self.config
-        self.…
-        self.…
-        self.…
+        self.hidden_size = config.hidden_size
+        self.head_dim = self.config.hidden_size // self.config.num_attention_heads
+        self.number_of_reps = self.config.num_attention_heads // self.config.num_key_value_heads
 
-        self.…
+        if self.number_of_reps == 1:
+            assert self.config.num_attention_heads == self.config.num_key_value_heads
+        self.q_proj = nn.Dense(
             config.num_attention_heads * self.head_dim,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             use_bias=False,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             precision=self.precision
         )
-        self.…
-            config.…
+        self.k_proj = nn.Dense(
+            config.num_key_value_heads * self.head_dim,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             use_bias=False,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             precision=self.precision
         )
-        self.…
-            config.…
+        self.v_proj = nn.Dense(
+            config.num_key_value_heads * self.head_dim,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             use_bias=False,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             precision=self.precision
         )
-        self.…
+        self.o_proj = nn.Dense(
             config.hidden_size,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             use_bias=False,
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             precision=self.precision
         )
 
-        self.…
-
-        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
+        self.rotary = FlaxLlamaEmbedding(self.dtype)
 
-        self.…
-            method=self.config.rope_scaling['type'],
-            scaling_factor=float(self.config.rope_scaling['factor']),
-            dim=self.head_dim,
-            end=config.max_sequence_length * 2,
-            dtype=self.dtype,
-        )
-
-    def _split_heads(self, hidden_states):
-        return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
+        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
 
     def _merge_heads(self, hidden_states):
-        return hidden_states.reshape(hidden_states.shape[:2] + (self.…
+        return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
 
     @nn.compact
     def _concatenate_to_cache(self, key, value, query, attention_mask):
@@ -320,6 +294,7 @@ class FlaxLlamaAttention(nn.Module):
         cached_value.value = value
         num_updated_cache_vectors = query.shape[1]
         cache_index.value = cache_index.value + num_updated_cache_vectors
+
         pad_mask = jnp.broadcast_to(
             jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
             tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
@@ -327,44 +302,69 @@ class FlaxLlamaAttention(nn.Module):
         attention_mask = combine_masks(pad_mask, attention_mask)
         return key, value, attention_mask
 
+    @staticmethod
+    def _t(query, key, value):
+        return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))
+
+    def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
+        query = query.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
+        key = key.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
+        value = value.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
+
+        query, key, value = self._t(query, key, value)
+        query, key = self.rotary(position_ids=position_ids, query=query, key=key, freq_cis=freq_cis)
+        key = repeat_kv_bnsh(key, self.number_of_reps)
+        value = repeat_kv_bnsh(value, self.number_of_reps)
+        return self._t(query, key, value)
+
     def __call__(
             self,
-            hidden_states,
-            …
+            hidden_states: chex.Array,
+            freq_cis: chex.Array,
+            attention_mask: chex.Array,
+            position_ids: chex.Array,
+            causal_mask: chex.Array,
             deterministic: bool = True,
             init_cache: bool = False,
             output_attentions: bool = False,
             fcm_mask=None,
     ):
-        xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
-        if self.config.use_pjit_attention_force:
-            xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp"))
-            xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp"))
-            xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp"))
-
-        xq = self._split_heads(xq)
-        xk = self._split_heads(xk)
-        xv = self._split_heads(xv)
+        batch_size, sequence_length = hidden_states.shape[:2]
+        query_state, key_state, value_state = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(
+            hidden_states)
+        if self.config.use_pjit_attention_force:
+            query_state = with_sharding_constraint(query_state, PS(("dp", "fsdp"), None, "mp"))
+            key_state = with_sharding_constraint(key_state, PS(("dp", "fsdp"), None, "mp"))
+            value_state = with_sharding_constraint(value_state, PS(("dp", "fsdp"), None, "mp"))
+
+        query_state = query_state.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
+        key_state = key_state.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
+        value_state = value_state.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
+
+        query_state, key_state, value_state = self.apply_rotary(
+            query=query_state,
+            key=key_state,
+            value=value_state,
+            position_ids=position_ids,
+            freq_cis=freq_cis,
+            batch_size=batch_size,
+            sequence_length=sequence_length
+        )
 
-        …
-
-        query_length, key_length = …
+        query_length, key_length = query_state.shape[1], key_state.shape[1]
 
         if self.has_variable("cache", "cached_key"):
             mask_shift = self.variables["cache"]["cache_index"]
             max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
             causal_mask = lax.dynamic_slice(
-                …
+                causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
             )
         else:
-            causal_mask = …
+            causal_mask = causal_mask[:, :, :query_length, :key_length]
 
         batch_size = hidden_states.shape[0]
         causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
-
         attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
         attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
 
@@ -373,30 +373,58 @@ class FlaxLlamaAttention(nn.Module):
             dropout_rng = self.make_rng("dropout")
 
         if self.has_variable("cache", "cached_key") or init_cache:
-            …
+            key_state, value_state, attention_mask = self._concatenate_to_cache(key_state, value_state, query_state,
+                                                                                attention_mask)
 
         attention_bias = lax.select(
             attention_mask > 0,
             jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
             jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
         )
 
-        …
-            xk,
-            bias=attention_bias,
-            dropout_rng=dropout_rng,
-            dropout_rate=self.config.attn_pdrop,
-            deterministic=deterministic,
-            dtype=jnp.promote_types(self.dtype, jnp.bfloat16),
-            precision=self.precision,
-        )
-        if self.config.use_pjit_attention_force:
-            attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
-
-        attn_output = …
-        attn_output = self._merge_heads(attn_output)
-        attn_output = self.wo(attn_output)
+        if self.config.use_flash_attention and not (self.has_variable("cache", "cached_key") or init_cache):
+            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
+            attention_bias = lax.select(
+                attention_mask > 0,
+                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
+                jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
+            )
+            attn_weights = None
+            attn_output = blockwise_dot_product_attention(
+                query_state,
+                key_state,
+                value_state,
+                bias=attention_bias,
+                deterministic=deterministic,
+                dropout_rng=dropout_rng,
+                attn_pdrop=self.config.attn_pdrop,
+                causal=True,
+                query_chunk_size=self.config.scan_query_chunk_size,
+                key_chunk_size=self.config.scan_key_chunk_size,
+                dtype=self.dtype,
+                policy=get_gradient_checkpoint_policy('nothing_saveable'),
+                precision=self.precision,
+                float32_logits=True,
+            )
+            if self.config.use_pjit_attention_force:
+                attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
+            attn_output = self._merge_heads(attn_output)
+        else:
+            attn_weights = dot_product_attention_weights(
+                query=query_state,
+                key=key_state,
+                bias=attention_bias,
+                dtype=jnp.promote_types(self.dtype, jnp.float32),
+                deterministic=deterministic,
+                dropout_rate=self.config.attn_pdrop,
+                precision=self.precision,
+            )
+            if self.config.use_pjit_attention_force:
+                attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
 
+            attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_state)
+            attn_output = self._merge_heads(attn_output)
+
+        attn_output = self.o_proj(attn_output)
         attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
         outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
         return outputs
@@ -404,14 +432,14 @@ class FlaxLlamaAttention(nn.Module):
 
 class FlaxLlamaMLP(nn.Module):
     config: LlamaConfig
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self) -> None:
         config = self.config
 
-        self.…
+        self.gate_proj = nn.Dense(
             config.intermediate_size,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
@@ -419,7 +447,7 @@ class FlaxLlamaMLP(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             precision=self.precision,
         )
-        self.…
+        self.down_proj = nn.Dense(
             config.hidden_size,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
@@ -427,7 +455,7 @@ class FlaxLlamaMLP(nn.Module):
             kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
             precision=self.precision,
         )
-        self.…
+        self.up_proj = nn.Dense(
             config.intermediate_size,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
@@ -438,69 +466,116 @@ class FlaxLlamaMLP(nn.Module):
         self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
 
     def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
-        x = self.…
+        x = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
         x = self.dropout(x, deterministic=deterministic)
         return x
 
 
 class FlaxLlamaBlock(nn.Module):
     config: LlamaConfig
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self) -> None:
-        …
+        attn_block = FlaxLlamaAttention
+        if self.config.gradient_checkpointing != '':
+            attn_block = re_mat(
+                FlaxLlamaAttention, static_argnums=(5, 6, 7),
+                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
+            )
+
+        self.self_attn = attn_block(
             self.config,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             precision=self.precision
         )
-        …
+        mlp_block = FlaxLlamaMLP
+
+        if self.config.gradient_checkpointing != '':
+            mlp_block = re_mat(
+                FlaxLlamaMLP, static_argnums=(1,),
+                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
+            )
+
+        self.mlp = mlp_block(
             self.config,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             precision=self.precision,
         )
-        self.…
+        self.input_layernorm = RMSNorm(
             self.config.hidden_size,
             eps=self.config.rms_norm_eps,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
         )
-        self.…
+        self.post_attention_layernorm = RMSNorm(
             self.config.hidden_size,
             eps=self.config.rms_norm_eps,
             dtype=self.dtype,
             param_dtype=self.param_dtype,
+
         )
 
     def __call__(
             self,
-            hidden_states,
-            …
+            hidden_states: chex.Array,
+            freq_cis: chex.Array,
+            attention_mask: chex.Array,
+            position_ids: chex.Array,
+            causal_mask: chex.Array,
             deterministic: bool = True,
             init_cache: bool = False,
             output_attentions: bool = False,
             fcm_mask: Optional[jnp.ndarray] = None,
     ):
-        attn_outputs = self.…
-            self.…
-            …
+        attn_outputs = self.self_attn(
+            self.input_layernorm(hidden_states),
+            freq_cis,
+            attention_mask,
+            position_ids,
+            causal_mask,
+            deterministic,
+            init_cache,
+            output_attentions,
+            fcm_mask,
         )
         attn_output = attn_outputs[0]
         hidden_states = hidden_states + attn_output
 
-        …
+        feed_forward_input = self.post_attention_layernorm(hidden_states)
+
+        if self.config.use_sacn_mlp:
+            feed_forward_input = einops.rearrange(
+                feed_forward_input,
+                '... (b s) d -> ... b s d',
+                b=self.config.scan_mlp_chunk_size
+            )
+
+            def mlp_forward(mlp, carry, x):
+                return None, mlp(x, deterministic)
+
+            scan_axis = feed_forward_input.ndim - 3
+
+            _, feed_forward_hidden_states = nn.scan(
+                mlp_forward,
+                variable_broadcast="params",
+                split_rngs={"params": False, "dropout": True},
+                in_axes=scan_axis,
+                out_axes=scan_axis,
+            )(self.mlp, None, feed_forward_input)
+            feed_forward_hidden_states = einops.rearrange(
+                feed_forward_hidden_states,
+                '... b s d -> ... (b s) d'
+            )
+        else:
+            feed_forward_hidden_states = self.mlp(
+                feed_forward_input,
+                deterministic,
+            )
+
         hidden_states = hidden_states + feed_forward_hidden_states
 
         return (hidden_states,) + attn_outputs[1:]
@@ -508,7 +583,7 @@ class FlaxLlamaBlock(nn.Module):
 
 class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
     config_class = LlamaConfig
-    base_model_prefix = "…
+    base_model_prefix = "model"
    module_class: nn.Module = None
 
     def __init__(
@@ -516,7 +591,7 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
             config: LlamaConfig,
             input_shape: Tuple = (1, 1),
             seed: int = 0,
-            dtype: jnp.dtype = jnp.…
+            dtype: jnp.dtype = jnp.float32,
             _do_init: bool = True,
             **kwargs,
     ):
@@ -571,9 +646,9 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
 
     def __call__(
             self,
-            input_ids,
-            attention_mask=None,
-            position_ids=None,
+            input_ids: chex.Array,
+            attention_mask: chex.Array = None,
+            position_ids: chex.Array = None,
             params: dict = None,
             past_key_values: dict = None,
             dropout_rng: jax.random.PRNGKey = None,
@@ -581,8 +656,10 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
             output_attentions: Optional[bool] = None,
             output_hidden_states: Optional[bool] = None,
             return_dict: Optional[bool] = None,
+            extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
             add_params_field: bool = False
     ):
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -591,6 +668,11 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
 
         batch_size, sequence_length = input_ids.shape
 
+        assert sequence_length <= self.config.max_position_embeddings, (f'Position out of range '
+                                                                        f'(Model Support '
+                                                                        f'{self.config.max_position_embeddings} got'
+                                                                        f' {sequence_length})')
+
         if position_ids is None:
             if past_key_values is not None:
                 raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
@@ -622,6 +704,7 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
             output_attentions,
             output_hidden_states,
             return_dict,
+            extra_embedding,
             rngs=rngs,
             mutable=mutable,
         )
@@ -639,29 +722,24 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
 
 class FlaxLlamaBlockCollection(nn.Module):
     config: LlamaConfig
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self):
-        block = FlaxLlamaBlock
-
-        if self.config.gradient_checkpointing != '':
-            block = remat(
-                block, static_argnums=(3, 4, 5),
-                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
-            )
-
         self.blocks = [
-            …
+            FlaxLlamaBlock(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype,
+                           precision=self.precision)
             for i in range(self.config.num_hidden_layers)
         ]
 
     def __call__(
             self,
-            hidden_states,
-            …
+            hidden_states: chex.Array,
+            freq_cis: chex.Array,
+            attention_mask: chex.Array,
+            position_ids: chex.Array,
+            causal_mask: chex.Array,
             deterministic: bool = True,
             init_cache: bool = False,
             output_attentions: bool = False,
@@ -693,20 +771,21 @@ class FlaxLlamaBlockCollection(nn.Module):
                 all_hidden_states += (hidden_states,)
 
             layer_outputs = block(
-                hidden_states,
-                …
+                hidden_states=hidden_states,
+                freq_cis=freq_cis,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                causal_mask=causal_mask,
+                deterministic=deterministic,
+                init_cache=init_cache,
+                output_attentions=output_attentions,
+                fcm_mask=fcm_mask,
             )
             hidden_states = layer_outputs[0]
 
             if output_attentions:
                 all_attentions += (layer_outputs[1],)
 
-        # this contains possible `None` values - `FlaxGPTJModule` will filter them out
         outputs = (hidden_states, all_hidden_states, all_attentions)
 
         return outputs
@@ -714,14 +793,13 @@ class FlaxLlamaBlockCollection(nn.Module):
 
 class FlaxLlamaModule(nn.Module):
     config: LlamaConfig
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self):
-        self.embed_dim = self.config.hidden_size
 
-        self.…
+        self.embed_tokens = nn.Embed(
             self.config.vocab_size,
             self.config.hidden_size,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
@@ -729,30 +807,47 @@ class FlaxLlamaModule(nn.Module):
             param_dtype=self.param_dtype,
         )
         self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
-        self.…
-        …
-        self.…
+        self.layers = FlaxLlamaBlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype,
+                                               precision=self.precision)
+        self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype,
                             param_dtype=self.param_dtype)
+        config = self.config
+        self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings)))
+        self.freq_cis = precompute_freq_cis(
+            max_position_embedding=config.max_position_embeddings,
+            head_dim=config.hidden_size // config.num_attention_heads
+        )
 
     def __call__(
             self,
-            input_ids,
-            attention_mask,
-            position_ids,
-            deterministic=True,
+            input_ids: chex.Array,
+            attention_mask: chex.Array,
+            position_ids: chex.Array,
+            deterministic: bool = True,
+            input_embeds: chex.Array = None,
             init_cache: bool = False,
             output_attentions: bool = False,
             output_hidden_states: bool = False,
             return_dict: bool = True,
+            extra_embedding: Optional[Union[jnp.ndarray, None]] = None
     ):
-        input_embeds …
+        if input_embeds is None:
+            input_embeds = self.embed_tokens(input_ids.astype("i4"))
 
+        batch_size, sequence_length = input_ids.shape
+        assert sequence_length <= self.config.max_position_embeddings, (f'Position out of range '
+                                                                        f'(Model Support '
+                                                                        f'{self.config.max_position_embeddings} got'
+                                                                        f' {sequence_length})')
+        input_embeds = input_embeds + extra_embedding if extra_embedding is not None else input_embeds
         hidden_states = self.dropout(input_embeds, deterministic=deterministic)
 
-        outputs = self.…
-            hidden_states,
-            …
+        outputs = self.layers(
+            hidden_states=hidden_states,
+            freq_cis=self.freq_cis,
+            attention_mask=attention_mask,
             position_ids=position_ids,
+            causal_mask=self.causal_mask,
             deterministic=deterministic,
             init_cache=init_cache,
             output_attentions=output_attentions,
|
|
761 |
)
|
762 |
|
763 |
hidden_states = outputs[0]
|
764 |
-
hidden_states = self.
|
765 |
|
766 |
if output_hidden_states:
|
767 |
all_hidden_states = outputs[1] + (hidden_states,)
|
@@ -785,12 +880,16 @@ class FlaxLlamaModel(FlaxLlamaPreTrainedModel):
 
 class FlaxLlamaForCausalLMModule(nn.Module):
     config: LlamaConfig
-    dtype: jnp.dtype = jnp.…
-    param_dtype: jnp.dtype = jnp.…
+    dtype: jnp.dtype = jnp.float32
+    param_dtype: jnp.dtype = jnp.float32
     precision: Optional[Union[jax.lax.Precision, str]] = None
 
     def setup(self):
-        self.…
+        self.model = FlaxLlamaModule(self.config,
+                                     dtype=self.dtype,
+                                     param_dtype=self.param_dtype,
+                                     precision=self.precision,
+                                     )
         self.lm_head = nn.Dense(
             self.config.vocab_size,
             dtype=self.dtype,
|
|
802 |
|
803 |
def __call__(
|
804 |
self,
|
805 |
-
input_ids,
|
806 |
-
attention_mask=None,
|
807 |
-
position_ids=None,
|
808 |
deterministic: bool = True,
|
809 |
init_cache: bool = False,
|
810 |
output_attentions: bool = False,
|
811 |
output_hidden_states: bool = False,
|
812 |
return_dict: bool = True,
|
|
|
813 |
):
|
814 |
batch_size, seq_length = input_ids.shape
|
815 |
if attention_mask is None:
|
@@ -819,7 +919,7 @@ class FlaxLlamaForCausalLMModule(nn.Module):
|
|
819 |
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
|
820 |
(batch_size, seq_length)
|
821 |
)
|
822 |
-
outputs = self.
|
823 |
input_ids,
|
824 |
attention_mask,
|
825 |
position_ids,
|
@@ -828,16 +928,19 @@ class FlaxLlamaForCausalLMModule(nn.Module):
|
|
828 |
output_attentions=output_attentions,
|
829 |
output_hidden_states=output_hidden_states,
|
830 |
return_dict=return_dict,
|
|
|
831 |
)
|
832 |
|
833 |
hidden_states = outputs[0]
|
834 |
|
835 |
if self.config.tie_word_embeddings:
|
836 |
-
shared_kernel = self.
|
837 |
lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
|
838 |
else:
|
839 |
lm_logits = self.lm_head(hidden_states)
|
840 |
|
|
|
|
|
841 |
if not return_dict:
|
842 |
return (lm_logits,) + outputs[1:]
|
843 |
|
@@ -847,7 +950,7 @@ class FlaxLlamaForCausalLMModule(nn.Module):
|
|
847 |
class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
|
848 |
module_class = FlaxLlamaForCausalLMModule
|
849 |
|
850 |
-
def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[
|
851 |
batch_size, seq_length = input_ids.shape
|
852 |
|
853 |
past_key_values = self.init_cache(batch_size, max_length)
|
@@ -873,12 +976,12 @@ class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
|
|
873 |
class FlaxLlamaForSequenceClassificationModule(nn.Module):
|
874 |
num_classes: int
|
875 |
config: LlamaConfig
|
876 |
-
dtype: jnp.dtype = jnp.
|
877 |
-
param_dtype: jnp.dtype = jnp.
|
878 |
precision: Optional[Union[jax.lax.Precision, str]] = None
|
879 |
|
880 |
def setup(self):
|
881 |
-
self.
|
882 |
self.classifier = nn.Dense(
|
883 |
self.num_classes,
|
884 |
dtype=self.dtype,
|
@@ -890,14 +993,15 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
|
|
890 |
|
891 |
def __call__(
|
892 |
self,
|
893 |
-
input_ids,
|
894 |
-
attention_mask=None,
|
895 |
-
position_ids=None,
|
896 |
deterministic: bool = True,
|
897 |
init_cache: bool = False,
|
898 |
output_attentions: bool = False,
|
899 |
output_hidden_states: bool = False,
|
900 |
return_dict: bool = True,
|
|
|
901 |
):
|
902 |
batch_size, seq_length = input_ids.shape
|
903 |
if attention_mask is None:
|
@@ -907,7 +1011,7 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
|
|
907 |
jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
|
908 |
(batch_size, seq_length)
|
909 |
)
|
910 |
-
outputs = self.
|
911 |
input_ids,
|
912 |
attention_mask,
|
913 |
position_ids,
|
@@ -916,6 +1020,7 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
|
|
916 |
output_attentions=output_attentions,
|
917 |
output_hidden_states=output_hidden_states,
|
918 |
return_dict=return_dict,
|
|
|
919 |
)
|
920 |
|
921 |
hidden_states = outputs[0]
|
@@ -930,4 +1035,4 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
|
|
930 |
|
931 |
|
932 |
class FlaxLlamaForSequenceClassification(FlaxLlamaPreTrainedModel):
|
933 |
-
module_class = FlaxLlamaForSequenceClassificationModule
|
|
|