erfanzar committed
Commit 885054d · 1 Parent(s): 36691a3

Update model.py

Files changed (1)
  1. model.py +402 -297
model.py CHANGED
@@ -1,8 +1,5 @@
1
- import os
2
- from typing import Any, Dict, List, Optional, Tuple, Union
3
-
4
- from flax.linen import remat
5
-
6
  import jax
7
  import jax.numpy as jnp
8
  from jax import lax
@@ -11,93 +8,63 @@ import flax.linen as nn
11
  from flax.linen.attention import dot_product_attention_weights
12
  from flax.traverse_util import flatten_dict, unflatten_dict
13
  from flax.linen import partitioning as nn_partitioning
14
-
15
  from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
16
  from flax.linen import combine_masks, make_causal_mask
17
-
18
  from transformers.configuration_utils import PretrainedConfig
19
  from transformers.modeling_flax_utils import FlaxPreTrainedModel
20
-
21
- from jax.interpreters import pxla
22
  from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput, FlaxSequenceClassifierOutput
23
-
24
- from jax.experimental.pjit import with_sharding_constraint as wsc
25
-
26
-
27
- def get_names_from_parition_spec(partition_specs):
28
- names = set()
29
- if isinstance(partition_specs, dict):
30
- partition_specs = partition_specs.values()
31
- for item in partition_specs:
32
- if item is None:
33
- continue
34
- elif isinstance(item, str):
35
- names.add(item)
36
- else:
37
- names.update(get_names_from_parition_spec(item))
38
-
39
- return list(names)
40
-
41
-
42
- def names_in_mesh(*names):
43
- return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)
44
-
45
-
46
- def with_sharding_constraint(x, partition_specs):
47
- axis_names = get_names_from_parition_spec(partition_specs)
48
- if names_in_mesh(*axis_names):
49
- x = wsc(x, partition_specs)
50
- return x
51
-
52
-
53
- def get_gradient_checkpoint_policy(name):
54
- return {
55
- 'everything_saveable': jax.checkpoint_policies.everything_saveable,
56
- 'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
57
- 'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
58
- 'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
59
- }[name]
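In the updated file these helpers are imported from `..flax_modelling_utils` instead of being defined locally (see the new import block in the second half of this diff). For reference, a minimal standalone sketch of the mesh-aware constraint guard the removed code implemented, reusing the same imports the old file used; treat it as illustrative rather than the shared implementation:

from jax.experimental.pjit import with_sharding_constraint as wsc
from jax.interpreters import pxla

def spec_axis_names(spec):
    # Collect every mesh-axis name mentioned in a (possibly nested) PartitionSpec.
    names = set()
    for item in spec:
        if item is None:
            continue
        if isinstance(item, str):
            names.add(item)
        else:
            names.update(spec_axis_names(item))
    return names

def safe_sharding_constraint(x, spec):
    # Apply the constraint only when every requested axis exists in the active mesh,
    # so the same model code runs with or without a Mesh context.
    mesh_axes = set(pxla.thread_resources.env.physical_mesh.axis_names)
    if spec_axis_names(spec) <= mesh_axes:
        x = wsc(x, spec)
    return x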
60
 
61
 
62
  class LlamaConfig(PretrainedConfig):
63
- model_type = "Llama"
64
 
65
  def __init__(
66
  self,
67
- vocab_size=32000,
68
- hidden_size=4096,
69
- intermediate_size=11008,
70
- num_hidden_layers=32,
71
- num_attention_heads=32,
72
- max_sequence_length=2048,
73
- rms_norm_eps=1e-6,
74
- initializer_range=0.02,
75
- use_cache=True,
76
- bos_token_id=0,
77
- eos_token_id=1,
78
- resid_pdrop=0.0,
79
- embd_pdrop=0.0,
80
- attn_pdrop=0.0,
81
- tie_word_embeddings=False,
82
- gradient_checkpointing='nothing_saveable',
83
- fcm_min_ratio=0.0,
84
- fcm_max_ratio=0.0,
 
 
85
  use_pjit_attention_force: bool = True,
86
- rope_scaling=None,
87
  **kwargs,
88
  ):
89
- if rope_scaling is None:
90
- rope_scaling = {
91
- "factor": 8.0,
92
- "type": "linear"
93
- }
94
  self.vocab_size = vocab_size
 
 
95
  self.hidden_size = hidden_size
96
  self.initializer_range = initializer_range
97
  self.intermediate_size = intermediate_size
98
  self.num_hidden_layers = num_hidden_layers
 
99
  self.num_attention_heads = num_attention_heads
100
- self.max_sequence_length = max_sequence_length
101
  self.rms_norm_eps = rms_norm_eps
102
  self.use_cache = use_cache
103
  self.resid_pdrop = resid_pdrop
@@ -108,6 +75,12 @@ class LlamaConfig(PretrainedConfig):
108
  self.fcm_min_ratio = fcm_min_ratio
109
  self.fcm_max_ratio = fcm_max_ratio
110
  self.rope_scaling = rope_scaling
111
  super().__init__(
112
  # pad_token_id=pad_token_id,
113
  bos_token_id=bos_token_id,
@@ -120,40 +93,73 @@ class LlamaConfig(PretrainedConfig):
120
  def get_partition_rules(fully_fsdp: bool = True):
121
  return (
122
 
123
- ("transformer/wte/embedding", PS("dp", "fsdp")),
124
 
125
- ("attention/(wq|wk|wv)/kernel", PS("fsdp", "dp")),
126
- ("attention/wo/kernel", PS("dp", "fsdp")),
127
 
128
- ("feed_forward/w1/kernel", PS("fsdp", "dp")),
129
- ("feed_forward/w2/kernel", PS("dp", "fsdp")),
130
- ("feed_forward/w3/kernel", PS("fsdp", "dp")),
131
 
132
- ("attention_norm/kernel", PS(None)),
133
- ("ffn_norm/kernel", PS(None)),
134
 
135
- ("transformer/ln_f/kernel", PS(None)),
136
  ("lm_head/kernel", PS("fsdp", "dp")),
137
  ('.*', PS(None)),
138
  ) if not fully_fsdp else (
139
 
140
- ("transformer/wte/embedding", PS("fsdp")),
141
 
142
- ("attention/(wq|wk|wv)/kernel", PS("fsdp")),
143
- ("attention/wo/kernel", PS("fsdp")),
144
 
145
- ("feed_forward/w1/kernel", PS("fsdp")),
146
- ("feed_forward/w2/kernel", PS("fsdp")),
147
- ("feed_forward/w3/kernel", PS("fsdp")),
148
 
149
- ("attention_norm/kernel", PS(None)),
150
- ("ffn_norm/kernel", PS(None)),
151
 
152
- ("transformer/ln_f/kernel", PS(None)),
153
  ("lm_head/kernel", PS("fsdp")),
154
- ('.*', PS(None)),
155
  )
156
 
157
  @staticmethod
158
  def get_weight_decay_exclusions():
159
  return tuple()
@@ -163,14 +169,41 @@ class LlamaConfig(PretrainedConfig):
163
  return ('params', 'dropout', 'fcm')
164
 
165
 
166
- remat = nn_partitioning.remat
167
 
168
 
169
  class RMSNorm(nn.Module):
170
  dim: int
171
  eps: float = 1e-6
172
- dtype: jnp.dtype = jnp.bfloat16
173
- param_dtype: jnp.dtype = jnp.bfloat16
174
 
175
  def setup(self) -> None:
176
  self.weight = self.param(
@@ -184,124 +217,65 @@ class RMSNorm(nn.Module):
184
  return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
185
 
186
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
187
- x = x.astype(jnp.promote_types(self.dtype, jnp.bfloat16))
188
  output = self._norm(x).astype(self.dtype)
189
  weight = jnp.asarray(self.weight, self.dtype)
190
  return output * weight
191
 
192
 
193
- def rotate_half(x):
194
- x1 = x[..., : x.shape[-1] // 2]
195
- x2 = x[..., x.shape[-1] // 2:]
196
- return jnp.concatenate([-x2, x1], axis=-1)
197
-
198
-
199
- def precompute_freqs_cis(
200
- method: str,
201
- dim: int, end: int, theta: float = 10000.0,
202
- scaling_factor: float = 8.,
203
- dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
204
- if method == 'linear':
205
- freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
206
- elif method == 'dynamic':
207
- base = theta * (
208
- (scaling_factor * end / end) - (scaling_factor - 1)
209
- ) ** (dim / (dim - 2))
210
- freqs = 1.0 / (base ** (jnp.arange(0, dim, 2) / dim))
211
- else:
212
- raise ValueError(f'unknown {method} method for precompute_freqs_cis')
213
- t = jnp.arange(end) # type: ignore
214
- freqs = jnp.outer(t, freqs).astype(dtype)
215
- sin, cos = jnp.sin(freqs), jnp.cos(freqs)
216
- freqs_cis = jnp.complex64(cos + 1j * sin)
217
- return jnp.asarray(freqs_cis)
218
-
219
-
220
- def apply_rotary_emb(
221
- xq: jnp.ndarray,
222
- xk: jnp.ndarray,
223
- freqs_cis: jnp.ndarray,
224
- dtype: jnp.dtype = jnp.bfloat16,
225
- ) -> Tuple[jnp.ndarray, jnp.ndarray]:
226
- reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
227
- reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
228
-
229
- xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
230
- xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
231
-
232
- freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))
233
-
234
- xq_out = xq_ * freqs_cis
235
- xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
236
-
237
- xk_out = xk_ * freqs_cis
238
- xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
239
-
240
- return xq_out.astype(dtype), xk_out.astype(dtype)
241
-
242
-
243
  class FlaxLlamaAttention(nn.Module):
244
  config: LlamaConfig
245
- dtype: jnp.dtype = jnp.bfloat16
246
- param_dtype: jnp.dtype = jnp.bfloat16
247
  precision: Optional[Union[jax.lax.Precision, str]] = None
248
 
249
  def setup(self):
250
  config = self.config
251
- self.embed_dim = config.hidden_size
252
- self.num_heads = config.num_attention_heads
253
- self.head_dim = self.embed_dim // self.num_heads
254
 
255
- self.wq = nn.Dense(
 
 
256
  config.num_attention_heads * self.head_dim,
257
  dtype=self.dtype,
258
  param_dtype=self.param_dtype,
259
  use_bias=False,
260
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
261
- precision=self.precision,
262
  )
263
- self.wk = nn.Dense(
264
- config.num_attention_heads * self.head_dim,
265
  dtype=self.dtype,
266
  param_dtype=self.param_dtype,
267
  use_bias=False,
268
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
269
- precision=self.precision,
270
  )
271
- self.wv = nn.Dense(
272
- config.num_attention_heads * self.head_dim,
273
  dtype=self.dtype,
274
  param_dtype=self.param_dtype,
275
  use_bias=False,
276
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
277
- precision=self.precision,
278
  )
279
- self.wo = nn.Dense(
280
  config.hidden_size,
281
  dtype=self.dtype,
282
  param_dtype=self.param_dtype,
283
  use_bias=False,
284
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
285
- precision=self.precision,
286
  )
287
 
288
- self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
289
-
290
- self.causal_mask = make_causal_mask(jnp.ones((1, config.max_sequence_length), dtype="bool"), dtype="bool")
291
 
292
- self.freqs_cis = precompute_freqs_cis(
293
- method=self.config.rope_scaling['type'],
294
- scaling_factor=float(self.config.rope_scaling['factor']),
295
- dim=self.head_dim,
296
- end=config.max_sequence_length * 2,
297
- dtype=self.dtype,
298
- )
299
-
300
- def _split_heads(self, hidden_states):
301
- return hidden_states.reshape(hidden_states.shape[:2] + (self.num_heads, self.head_dim))
302
 
303
  def _merge_heads(self, hidden_states):
304
- return hidden_states.reshape(hidden_states.shape[:2] + (self.embed_dim,))
305
 
306
  @nn.compact
307
  def _concatenate_to_cache(self, key, value, query, attention_mask):
@@ -320,6 +294,7 @@ class FlaxLlamaAttention(nn.Module):
320
  cached_value.value = value
321
  num_updated_cache_vectors = query.shape[1]
322
  cache_index.value = cache_index.value + num_updated_cache_vectors
 
323
  pad_mask = jnp.broadcast_to(
324
  jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
325
  tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
@@ -327,44 +302,69 @@ class FlaxLlamaAttention(nn.Module):
327
  attention_mask = combine_masks(pad_mask, attention_mask)
328
  return key, value, attention_mask
329
330
  def __call__(
331
  self,
332
- hidden_states,
333
- attention_mask,
334
- position_ids,
 
 
335
  deterministic: bool = True,
336
  init_cache: bool = False,
337
  output_attentions: bool = False,
338
  fcm_mask=None,
339
  ):
340
- xq, xk, xv = self.wq(hidden_states), self.wk(hidden_states), self.wv(hidden_states)
341
- if self.config.use_pjit_attention_force:
342
- xq = with_sharding_constraint(xq, PS(("dp", "fsdp"), None, "mp"))
343
- xk = with_sharding_constraint(xk, PS(("dp", "fsdp"), None, "mp"))
344
- xv = with_sharding_constraint(xv, PS(("dp", "fsdp"), None, "mp"))
345
-
346
- xq = self._split_heads(xq)
347
- xk = self._split_heads(xk)
348
- xv = self._split_heads(xv)
349
 
350
- freqs_cis = jnp.take(self.freqs_cis, position_ids, axis=0)
351
-
352
- xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, dtype=self.dtype)
353
 
354
- query_length, key_length = xq.shape[1], xk.shape[1]
355
 
356
  if self.has_variable("cache", "cached_key"):
357
  mask_shift = self.variables["cache"]["cache_index"]
358
  max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
359
  causal_mask = lax.dynamic_slice(
360
- self.causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
361
  )
362
  else:
363
- causal_mask = self.causal_mask[:, :, :query_length, :key_length]
364
 
365
  batch_size = hidden_states.shape[0]
366
  causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
367
-
368
  attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
369
  attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
370
 
@@ -373,30 +373,58 @@ class FlaxLlamaAttention(nn.Module):
373
  dropout_rng = self.make_rng("dropout")
374
 
375
  if self.has_variable("cache", "cached_key") or init_cache:
376
- xk, xv, attention_mask = self._concatenate_to_cache(xk, xv, xq, attention_mask)
 
377
 
378
  attention_bias = lax.select(
379
  attention_mask > 0,
380
  jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
381
  jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
382
  )
383
 
384
- attn_weights = dot_product_attention_weights(
385
- xq,
386
- xk,
387
- bias=attention_bias,
388
- dropout_rng=dropout_rng,
389
- dropout_rate=self.config.attn_pdrop,
390
- deterministic=deterministic,
391
- dtype=jnp.promote_types(self.dtype, jnp.bfloat16),
392
- precision=self.precision,
393
- )
394
- if self.config.use_pjit_attention_force:
395
- attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
396
 
397
- attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, xv, precision=self.precision)
398
- attn_output = self._merge_heads(attn_output)
399
- attn_output = self.wo(attn_output)
400
  attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
401
  outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
402
  return outputs
@@ -404,14 +432,14 @@ class FlaxLlamaAttention(nn.Module):
404
 
405
  class FlaxLlamaMLP(nn.Module):
406
  config: LlamaConfig
407
- dtype: jnp.dtype = jnp.bfloat16
408
- param_dtype: jnp.dtype = jnp.bfloat16
409
  precision: Optional[Union[jax.lax.Precision, str]] = None
410
 
411
  def setup(self) -> None:
412
  config = self.config
413
 
414
- self.w1 = nn.Dense(
415
  config.intermediate_size,
416
  dtype=self.dtype,
417
  param_dtype=self.param_dtype,
@@ -419,7 +447,7 @@ class FlaxLlamaMLP(nn.Module):
419
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
420
  precision=self.precision,
421
  )
422
- self.w2 = nn.Dense(
423
  config.hidden_size,
424
  dtype=self.dtype,
425
  param_dtype=self.param_dtype,
@@ -427,7 +455,7 @@ class FlaxLlamaMLP(nn.Module):
427
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
428
  precision=self.precision,
429
  )
430
- self.w3 = nn.Dense(
431
  config.intermediate_size,
432
  dtype=self.dtype,
433
  param_dtype=self.param_dtype,
@@ -438,69 +466,116 @@ class FlaxLlamaMLP(nn.Module):
438
  self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
439
 
440
  def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
441
- x = self.w2(nn.silu(self.w1(x)) * self.w3(x))
442
  x = self.dropout(x, deterministic=deterministic)
443
  return x
444
 
445
 
446
  class FlaxLlamaBlock(nn.Module):
447
  config: LlamaConfig
448
- dtype: jnp.dtype = jnp.bfloat16
449
- param_dtype: jnp.dtype = jnp.bfloat16
450
  precision: Optional[Union[jax.lax.Precision, str]] = None
451
 
452
  def setup(self) -> None:
453
- self.attention = FlaxLlamaAttention(
454
  self.config,
455
  dtype=self.dtype,
456
  param_dtype=self.param_dtype,
457
- precision=self.precision,
458
  )
459
- self.feed_forward = FlaxLlamaMLP(
460
  self.config,
461
  dtype=self.dtype,
462
  param_dtype=self.param_dtype,
463
  precision=self.precision,
464
  )
465
- self.attention_norm = RMSNorm(
466
  self.config.hidden_size,
467
  eps=self.config.rms_norm_eps,
468
  dtype=self.dtype,
469
  param_dtype=self.param_dtype,
470
  )
471
- self.ffn_norm = RMSNorm(
472
  self.config.hidden_size,
473
  eps=self.config.rms_norm_eps,
474
  dtype=self.dtype,
475
  param_dtype=self.param_dtype,
 
476
  )
477
 
478
  def __call__(
479
  self,
480
- hidden_states,
481
- attention_mask=None,
482
- position_ids=None,
 
 
483
  deterministic: bool = True,
484
  init_cache: bool = False,
485
  output_attentions: bool = False,
486
  fcm_mask: Optional[jnp.ndarray] = None,
487
  ):
488
- attn_outputs = self.attention(
489
- self.attention_norm(hidden_states),
490
- attention_mask=attention_mask,
491
- position_ids=position_ids,
492
- deterministic=deterministic,
493
- init_cache=init_cache,
494
- output_attentions=output_attentions,
495
- fcm_mask=fcm_mask,
 
 
496
  )
497
  attn_output = attn_outputs[0]
498
  hidden_states = hidden_states + attn_output
499
 
500
- feed_forward_hidden_states = self.feed_forward(
501
- self.ffn_norm(hidden_states),
502
- deterministic=deterministic,
503
- )
504
  hidden_states = hidden_states + feed_forward_hidden_states
505
 
506
  return (hidden_states,) + attn_outputs[1:]
@@ -508,7 +583,7 @@ class FlaxLlamaBlock(nn.Module):
508
 
509
  class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
510
  config_class = LlamaConfig
511
- base_model_prefix = "transformer"
512
  module_class: nn.Module = None
513
 
514
  def __init__(
@@ -516,7 +591,7 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
516
  config: LlamaConfig,
517
  input_shape: Tuple = (1, 1),
518
  seed: int = 0,
519
- dtype: jnp.dtype = jnp.bfloat16,
520
  _do_init: bool = True,
521
  **kwargs,
522
  ):
@@ -571,9 +646,9 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
571
 
572
  def __call__(
573
  self,
574
- input_ids,
575
- attention_mask=None,
576
- position_ids=None,
577
  params: dict = None,
578
  past_key_values: dict = None,
579
  dropout_rng: jax.random.PRNGKey = None,
@@ -581,8 +656,10 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
581
  output_attentions: Optional[bool] = None,
582
  output_hidden_states: Optional[bool] = None,
583
  return_dict: Optional[bool] = None,
 
584
  add_params_field: bool = False
585
  ):
 
586
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
587
  output_hidden_states = (
588
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
@@ -591,6 +668,11 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
591
 
592
  batch_size, sequence_length = input_ids.shape
593
 
594
  if position_ids is None:
595
  if past_key_values is not None:
596
  raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
@@ -622,6 +704,7 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
622
  output_attentions,
623
  output_hidden_states,
624
  return_dict,
 
625
  rngs=rngs,
626
  mutable=mutable,
627
  )
@@ -639,29 +722,24 @@ class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
639
 
640
  class FlaxLlamaBlockCollection(nn.Module):
641
  config: LlamaConfig
642
- dtype: jnp.dtype = jnp.bfloat16
643
- param_dtype: jnp.dtype = jnp.bfloat16
644
  precision: Optional[Union[jax.lax.Precision, str]] = None
645
 
646
  def setup(self):
647
- block = FlaxLlamaBlock
648
-
649
- if self.config.gradient_checkpointing != '':
650
- block = remat(
651
- block, static_argnums=(3, 4, 5),
652
- policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
653
- )
654
-
655
  self.blocks = [
656
- block(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype, precision=self.precision)
 
657
  for i in range(self.config.num_hidden_layers)
658
  ]
659
 
660
  def __call__(
661
  self,
662
- hidden_states,
663
- attention_mask=None,
664
- position_ids=None,
 
 
665
  deterministic: bool = True,
666
  init_cache: bool = False,
667
  output_attentions: bool = False,
@@ -693,20 +771,21 @@ class FlaxLlamaBlockCollection(nn.Module):
693
  all_hidden_states += (hidden_states,)
694
 
695
  layer_outputs = block(
696
- hidden_states,
697
- attention_mask,
698
- position_ids,
699
- deterministic,
700
- init_cache,
701
- output_attentions,
702
- fcm_mask,
 
 
703
  )
704
  hidden_states = layer_outputs[0]
705
 
706
  if output_attentions:
707
  all_attentions += (layer_outputs[1],)
708
 
709
- # this contains possible `None` values - `FlaxGPTJModule` will filter them out
710
  outputs = (hidden_states, all_hidden_states, all_attentions)
711
 
712
  return outputs
@@ -714,14 +793,13 @@ class FlaxLlamaBlockCollection(nn.Module):
714
 
715
  class FlaxLlamaModule(nn.Module):
716
  config: LlamaConfig
717
- dtype: jnp.dtype = jnp.bfloat16
718
- param_dtype: jnp.dtype = jnp.bfloat16
719
  precision: Optional[Union[jax.lax.Precision, str]] = None
720
 
721
  def setup(self):
722
- self.embed_dim = self.config.hidden_size
723
 
724
- self.wte = nn.Embed(
725
  self.config.vocab_size,
726
  self.config.hidden_size,
727
  embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
@@ -729,30 +807,47 @@ class FlaxLlamaModule(nn.Module):
729
  param_dtype=self.param_dtype,
730
  )
731
  self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
732
- self.h = FlaxLlamaBlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype,
733
- precision=self.precision)
734
- self.ln_f = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype,
735
  param_dtype=self.param_dtype)
736
 
737
  def __call__(
738
  self,
739
- input_ids,
740
- attention_mask,
741
- position_ids,
742
- deterministic=True,
 
743
  init_cache: bool = False,
744
  output_attentions: bool = False,
745
  output_hidden_states: bool = False,
746
  return_dict: bool = True,
 
747
  ):
748
- input_embeds = self.wte(input_ids.astype("i4"))
 
749
 
750
  hidden_states = self.dropout(input_embeds, deterministic=deterministic)
751
 
752
- outputs = self.h(
753
- hidden_states,
754
- attention_mask,
 
755
  position_ids=position_ids,
 
756
  deterministic=deterministic,
757
  init_cache=init_cache,
758
  output_attentions=output_attentions,
@@ -761,7 +856,7 @@ class FlaxLlamaModule(nn.Module):
761
  )
762
 
763
  hidden_states = outputs[0]
764
- hidden_states = self.ln_f(hidden_states)
765
 
766
  if output_hidden_states:
767
  all_hidden_states = outputs[1] + (hidden_states,)
@@ -785,12 +880,16 @@ class FlaxLlamaModel(FlaxLlamaPreTrainedModel):
785
 
786
  class FlaxLlamaForCausalLMModule(nn.Module):
787
  config: LlamaConfig
788
- dtype: jnp.dtype = jnp.bfloat16
789
- param_dtype: jnp.dtype = jnp.bfloat16
790
  precision: Optional[Union[jax.lax.Precision, str]] = None
791
 
792
  def setup(self):
793
- self.transformer = FlaxLlamaModule(self.config, dtype=self.dtype)
794
  self.lm_head = nn.Dense(
795
  self.config.vocab_size,
796
  dtype=self.dtype,
@@ -802,14 +901,15 @@ class FlaxLlamaForCausalLMModule(nn.Module):
802
 
803
  def __call__(
804
  self,
805
- input_ids,
806
- attention_mask=None,
807
- position_ids=None,
808
  deterministic: bool = True,
809
  init_cache: bool = False,
810
  output_attentions: bool = False,
811
  output_hidden_states: bool = False,
812
  return_dict: bool = True,
 
813
  ):
814
  batch_size, seq_length = input_ids.shape
815
  if attention_mask is None:
@@ -819,7 +919,7 @@ class FlaxLlamaForCausalLMModule(nn.Module):
819
  jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
820
  (batch_size, seq_length)
821
  )
822
- outputs = self.transformer(
823
  input_ids,
824
  attention_mask,
825
  position_ids,
@@ -828,16 +928,19 @@ class FlaxLlamaForCausalLMModule(nn.Module):
828
  output_attentions=output_attentions,
829
  output_hidden_states=output_hidden_states,
830
  return_dict=return_dict,
 
831
  )
832
 
833
  hidden_states = outputs[0]
834
 
835
  if self.config.tie_word_embeddings:
836
- shared_kernel = self.transformer.variables["params"]["wte"]["embedding"].T
837
  lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
838
  else:
839
  lm_logits = self.lm_head(hidden_states)
840
 
 
 
841
  if not return_dict:
842
  return (lm_logits,) + outputs[1:]
843
 
@@ -847,7 +950,7 @@ class FlaxLlamaForCausalLMModule(nn.Module):
847
  class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
848
  module_class = FlaxLlamaForCausalLMModule
849
 
850
- def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[jnp.DeviceArray] = None):
851
  batch_size, seq_length = input_ids.shape
852
 
853
  past_key_values = self.init_cache(batch_size, max_length)
@@ -873,12 +976,12 @@ class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
873
  class FlaxLlamaForSequenceClassificationModule(nn.Module):
874
  num_classes: int
875
  config: LlamaConfig
876
- dtype: jnp.dtype = jnp.bfloat16
877
- param_dtype: jnp.dtype = jnp.bfloat16
878
  precision: Optional[Union[jax.lax.Precision, str]] = None
879
 
880
  def setup(self):
881
- self.transformer = FlaxLlamaModule(self.config, dtype=self.dtype)
882
  self.classifier = nn.Dense(
883
  self.num_classes,
884
  dtype=self.dtype,
@@ -890,14 +993,15 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
890
 
891
  def __call__(
892
  self,
893
- input_ids,
894
- attention_mask=None,
895
- position_ids=None,
896
  deterministic: bool = True,
897
  init_cache: bool = False,
898
  output_attentions: bool = False,
899
  output_hidden_states: bool = False,
900
  return_dict: bool = True,
 
901
  ):
902
  batch_size, seq_length = input_ids.shape
903
  if attention_mask is None:
@@ -907,7 +1011,7 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
907
  jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
908
  (batch_size, seq_length)
909
  )
910
- outputs = self.transformer(
911
  input_ids,
912
  attention_mask,
913
  position_ids,
@@ -916,6 +1020,7 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
916
  output_attentions=output_attentions,
917
  output_hidden_states=output_hidden_states,
918
  return_dict=return_dict,
 
919
  )
920
 
921
  hidden_states = outputs[0]
@@ -930,4 +1035,4 @@ class FlaxLlamaForSequenceClassificationModule(nn.Module):
930
 
931
 
932
  class FlaxLlamaForSequenceClassification(FlaxLlamaPreTrainedModel):
933
- module_class = FlaxLlamaForSequenceClassificationModule
 
1
+ from typing import Dict, Optional, Tuple, Union
2
+ from einops import einops
 
 
 
3
  import jax
4
  import jax.numpy as jnp
5
  from jax import lax
 
8
  from flax.linen.attention import dot_product_attention_weights
9
  from flax.traverse_util import flatten_dict, unflatten_dict
10
  from flax.linen import partitioning as nn_partitioning
 
11
  from flax.core.frozen_dict import FrozenDict, freeze, unfreeze
12
  from flax.linen import combine_masks, make_causal_mask
 
13
  from transformers.configuration_utils import PretrainedConfig
14
  from transformers.modeling_flax_utils import FlaxPreTrainedModel
 
 
15
  from transformers.modeling_flax_outputs import FlaxBaseModelOutput, FlaxCausalLMOutput, FlaxSequenceClassifierOutput
16
+ from fjformer.attention import blockwise_dot_product_attention
17
+ from ..flax_modelling_utils import with_sharding_constraint, \
18
+ get_gradient_checkpoint_policy, repeat_kv_bnsh, apply_rotary_pos_emb, precompute_freq_cis
19
+ import chex
20
 
21
 
22
  class LlamaConfig(PretrainedConfig):
23
+ model_type = "llama"
24
 
25
  def __init__(
26
  self,
27
+ vocab_size: int = 32000,
28
+ hidden_size: int = 4096,
29
+ intermediate_size: int = 11008,
30
+ num_hidden_layers: int = 32,
31
+ num_attention_heads: int = 32,
32
+ number_rep_kv: int = 1,
33
+ num_key_value_heads: Optional[int] = None,
34
+ max_position_embeddings: int = 2048,
35
+ rms_norm_eps: float = 1e-6,
36
+ initializer_range: float = 0.02,
37
+ use_cache: bool = True,
38
+ bos_token_id: int = 0,
39
+ eos_token_id: int = 1,
40
+ resid_pdrop: float = 0.0,
41
+ embd_pdrop: float = 0.0,
42
+ attn_pdrop: float = 0.0,
43
+ tie_word_embeddings: bool = False,
44
+ gradient_checkpointing: str = 'nothing_saveable',
45
+ fcm_min_ratio: float = -1,
46
+ fcm_max_ratio: float = -1,
47
  use_pjit_attention_force: bool = True,
48
+ rope_scaling: Dict[str, Union[str, float]] = None,
49
+ use_flash_attention: bool = False,
50
+ use_sacn_mlp: bool = False,
51
+ flash_attn_query_chunk_size: int = 1024,
52
+ flash_attn_key_chunk_size: int = 1024,
53
+ scan_mlp_chunk_size: int = 1024,
54
  **kwargs,
55
  ):
56
+ num_key_value_heads = num_key_value_heads or number_rep_kv * num_attention_heads
57
+ self.num_key_value_heads = num_key_value_heads
 
 
 
58
  self.vocab_size = vocab_size
59
+
60
+ self.number_rep_kv = number_rep_kv
61
  self.hidden_size = hidden_size
62
  self.initializer_range = initializer_range
63
  self.intermediate_size = intermediate_size
64
  self.num_hidden_layers = num_hidden_layers
65
+
66
  self.num_attention_heads = num_attention_heads
67
+ self.max_position_embeddings = max_position_embeddings
68
  self.rms_norm_eps = rms_norm_eps
69
  self.use_cache = use_cache
70
  self.resid_pdrop = resid_pdrop
 
75
  self.fcm_min_ratio = fcm_min_ratio
76
  self.fcm_max_ratio = fcm_max_ratio
77
  self.rope_scaling = rope_scaling
78
+ self.use_flash_attention = use_flash_attention
79
+ self.use_sacn_mlp = use_sacn_mlp
80
+ self.flash_attn_key_chunk_size = flash_attn_key_chunk_size
81
+ self.flash_attn_query_chunk_size = flash_attn_query_chunk_size
82
+ self.scan_mlp_chunk_size = scan_mlp_chunk_size
83
+
84
  super().__init__(
85
  # pad_token_id=pad_token_id,
86
  bos_token_id=bos_token_id,
 
93
  def get_partition_rules(fully_fsdp: bool = True):
94
  return (
95
 
96
+ ("model/embed_tokens/embedding", PS("dp", "fsdp")),
97
 
98
+ ("self_attn/(q_proj|k_proj|v_proj)/kernel", PS("fsdp", "dp")),
99
+ ("self_attn/o_proj/kernel", PS("dp", "fsdp")),
100
 
101
+ ("mlp/gate_proj/kernel", PS("fsdp", "dp")),
102
+ ("mlp/down_proj/kernel", PS("dp", "fsdp")),
103
+ ("mlp/up_proj/kernel", PS("fsdp", "dp")),
104
 
105
+ ("input_layernorm/kernel", PS(None)),
106
+ ("post_attention_layernorm/kernel", PS(None)),
107
 
108
+ ("model/norm/kernel", PS(None)),
109
  ("lm_head/kernel", PS("fsdp", "dp")),
110
  ('.*', PS(None)),
111
  ) if not fully_fsdp else (
112
 
113
+ ("model/embed_tokens/embedding", PS("fsdp")),
114
 
115
+ ("self_attn/(q_proj|k_proj|v_proj)/kernel", PS("fsdp")),
116
+ ("self_attn/o_proj/kernel", PS("fsdp")),
117
 
118
+ ("mlp/gate_proj/kernel", PS("fsdp")),
119
+ ("mlp/down_proj/kernel", PS("fsdp")),
120
+ ("mlp/up_proj/kernel", PS("fsdp")),
121
 
122
+ ("input_layernorm/kernel", PS(None)),
123
+ ("post_attention_layernorm/kernel", PS(None)),
124
 
125
+ ("model/norm/kernel", PS(None)),
126
  ("lm_head/kernel", PS("fsdp")),
127
+ ('.*', PS('fsdp')),
128
  )
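The parameter names in these rules follow the Hugging Face Llama layout (model/embed_tokens, self_attn/q_proj, mlp/gate_proj, ...) instead of the old wq/wk/wv naming. Each rule is a regex over a flattened parameter path; a hedged sketch of how such a table is typically turned into a PartitionSpec tree (the helper below is illustrative, not part of this file):

import re
import jax.numpy as jnp
from flax.traverse_util import flatten_dict, unflatten_dict
from jax.sharding import PartitionSpec as PS

def match_partition_rules(rules, params):
    # Map each parameter (keyed by its '/'-joined path) to the first rule whose
    # regex matches; fall back to replication (PS(None)).
    flat = flatten_dict(params)
    specs = {}
    for path, value in flat.items():
        name = "/".join(str(k) for k in path)
        spec = PS(None)
        for pattern, candidate in rules:
            if re.search(pattern, name):
                spec = candidate
                break
        specs[path] = spec
    return unflatten_dict(specs)

# Toy tree with one attention kernel and one layer norm.
params = {"model": {"layers": {"0": {
    "self_attn": {"q_proj": {"kernel": jnp.zeros((4, 4))}},
    "input_layernorm": {"kernel": jnp.zeros((4,))}}}}}
rules = (("self_attn/(q_proj|k_proj|v_proj)/kernel", PS("fsdp")),
         ("input_layernorm/kernel", PS(None)),
         (".*", PS(None)))
specs = match_partition_rules(rules, params)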
129
 
130
+ def add_jax_args(self,
131
+ resid_pdrop: float = 0.0,
132
+ embd_pdrop: float = 0.0,
133
+ attn_pdrop: float = 0.0,
134
+ tie_word_embeddings: bool = False,
135
+ gradient_checkpointing: str = 'nothing_saveable',
136
+ fcm_min_ratio: float = 0.0,
137
+ fcm_max_ratio: float = 0.0,
138
+ use_pjit_attention_force: bool = True,
139
+ use_flash_attention: bool = False,
140
+ use_sacn_mlp: bool = False,
141
+ flash_attn_query_chunk_size: int = 1024,
142
+ flash_attn_key_chunk_size: int = 1024,
143
+ scan_mlp_chunk_size: int = 1024,
144
+ number_rep_kv: int = 1,
145
+ ):
146
+ self.use_flash_attention = use_flash_attention
147
+ self.embd_pdrop = embd_pdrop
148
+ self.number_rep_kv = number_rep_kv
149
+ self.resid_pdrop = resid_pdrop
150
+
151
+ self.attn_pdrop = attn_pdrop
152
+ self.tie_word_embeddings = tie_word_embeddings
153
+ self.gradient_checkpointing = gradient_checkpointing
154
+ self.fcm_min_ratio = fcm_min_ratio
155
+ self.fcm_max_ratio = fcm_max_ratio
156
+ self.use_pjit_attention_force = use_pjit_attention_force
157
+
158
+ self.use_sacn_mlp = use_sacn_mlp
159
+ self.flash_attn_query_chunk_size = flash_attn_query_chunk_size
160
+ self.flash_attn_key_chunk_size = flash_attn_key_chunk_size
161
+ self.scan_mlp_chunk_size = scan_mlp_chunk_size
162
+
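Taken together, the constructor and `add_jax_args` expose the new grouped-KV and chunking knobs. A usage sketch with illustrative values (every field shown exists in the signatures above; the values are examples, not required settings):

config = LlamaConfig(
    vocab_size=32000,
    hidden_size=4096,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=32,          # defaults to number_rep_kv * num_attention_heads
    max_position_embeddings=2048,    # renamed from max_sequence_length
    use_flash_attention=False,       # switches to blockwise_dot_product_attention when True
    use_sacn_mlp=False,              # field name as spelled in this commit
    flash_attn_query_chunk_size=1024,
    flash_attn_key_chunk_size=1024,
    scan_mlp_chunk_size=1024,
)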
163
  @staticmethod
164
  def get_weight_decay_exclusions():
165
  return tuple()
 
169
  return ('params', 'dropout', 'fcm')
170
 
171
 
172
+ re_mat = nn_partitioning.remat
173
+
174
+
175
+ class FlaxLlamaEmbedding(nn.Module):
176
+ dtype: jnp.dtype = jnp.float32
177
+
178
+ def __call__(self, query, key, freq_cis, position_ids):
179
+ sin, cos = freq_cis
180
+
181
+ sin = sin[position_ids][:, None, :, :]
182
+ cos = cos[position_ids][:, None, :, :]
183
+
184
+ key = apply_rotary_pos_emb(key, sin, cos)
185
+ query = apply_rotary_pos_emb(query, sin, cos)
186
+
187
+ return query.astype(self.dtype), key.astype(self.dtype)
188
+
189
+
190
+ def repeat_kv(x: chex.Array, n_rep: int) -> chex.Array:
191
+ bs, s, n_kv_heads, head_dim = x.shape
192
+ if n_rep == 1:
193
+ return x
194
+ x = x[:, :, jnp.newaxis, :, :]
195
+ x = jnp.repeat(x, n_rep, axis=2)
196
+
197
+ return x.reshape(bs, s,
198
+ n_kv_heads * n_rep,
199
+ head_dim)
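A quick shape check exercising the `repeat_kv` helper defined above (toy sizes):

import jax.numpy as jnp

# 2 sequences of length 5, 4 KV heads of dim 8, repeated 2x -> 8 effective heads.
kv = jnp.zeros((2, 5, 4, 8))
assert repeat_kv(kv, n_rep=2).shape == (2, 5, 8, 8)
assert repeat_kv(kv, n_rep=1) is kv  # n_rep == 1 is a no-op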
200
 
201
 
202
  class RMSNorm(nn.Module):
203
  dim: int
204
  eps: float = 1e-6
205
+ dtype: jnp.dtype = jnp.float32
206
+ param_dtype: jnp.dtype = jnp.float32
207
 
208
  def setup(self) -> None:
209
  self.weight = self.param(
 
217
  return x * jax.lax.rsqrt(jnp.square(x).mean(-1, keepdims=True) + self.eps)
218
 
219
  def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
220
+ x = x.astype(jnp.promote_types(self.dtype, jnp.float32))
221
  output = self._norm(x).astype(self.dtype)
222
  weight = jnp.asarray(self.weight, self.dtype)
223
  return output * weight
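The only functional change to RMSNorm is the float32 default dtype and the float32 promotion before normalization. As a sanity sketch of what the module computes, written outside Flax:

import jax
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-6):
    # x / sqrt(mean(x**2) + eps), then a learned per-feature scale.
    x32 = x.astype(jnp.float32)
    normed = x32 * jax.lax.rsqrt(jnp.square(x32).mean(-1, keepdims=True) + eps)
    return normed.astype(x.dtype) * weight

out = rms_norm(jnp.ones((2, 3, 8)), weight=jnp.ones((8,)))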
224
 
225
 
226
  class FlaxLlamaAttention(nn.Module):
227
  config: LlamaConfig
228
+ dtype: jnp.dtype = jnp.float32
229
+ param_dtype: jnp.dtype = jnp.float32
230
  precision: Optional[Union[jax.lax.Precision, str]] = None
231
 
232
  def setup(self):
233
  config = self.config
234
+ self.hidden_size = config.hidden_size
235
+ self.head_dim = self.config.hidden_size // self.config.num_attention_heads
236
+ self.number_of_reps = self.config.num_attention_heads // self.config.num_key_value_heads
237
 
238
+ if self.number_of_reps == 1:
239
+ assert self.config.num_attention_heads == self.config.num_key_value_heads
240
+ self.q_proj = nn.Dense(
241
  config.num_attention_heads * self.head_dim,
242
  dtype=self.dtype,
243
  param_dtype=self.param_dtype,
244
  use_bias=False,
245
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
246
+ precision=self.precision
247
  )
248
+ self.k_proj = nn.Dense(
249
+ config.num_key_value_heads * self.head_dim,
250
  dtype=self.dtype,
251
  param_dtype=self.param_dtype,
252
  use_bias=False,
253
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
254
+ precision=self.precision
255
  )
256
+ self.v_proj = nn.Dense(
257
+ config.num_key_value_heads * self.head_dim,
258
  dtype=self.dtype,
259
  param_dtype=self.param_dtype,
260
  use_bias=False,
261
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
262
+ precision=self.precision
263
  )
264
+ self.o_proj = nn.Dense(
265
  config.hidden_size,
266
  dtype=self.dtype,
267
  param_dtype=self.param_dtype,
268
  use_bias=False,
269
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
270
+ precision=self.precision
271
  )
272
 
273
+ self.rotary = FlaxLlamaEmbedding(self.dtype)
 
 
274
 
275
+ self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)
 
276
 
277
  def _merge_heads(self, hidden_states):
278
+ return hidden_states.reshape(hidden_states.shape[:2] + (self.hidden_size,))
279
 
280
  @nn.compact
281
  def _concatenate_to_cache(self, key, value, query, attention_mask):
 
294
  cached_value.value = value
295
  num_updated_cache_vectors = query.shape[1]
296
  cache_index.value = cache_index.value + num_updated_cache_vectors
297
+
298
  pad_mask = jnp.broadcast_to(
299
  jnp.arange(max_length) < cur_index + num_updated_cache_vectors,
300
  tuple(batch_dims) + (1, num_updated_cache_vectors, max_length),
 
302
  attention_mask = combine_masks(pad_mask, attention_mask)
303
  return key, value, attention_mask
304
 
305
+ @staticmethod
306
+ def _t(query, key, value):
307
+ return jnp.transpose(query, (0, 2, 1, 3)), jnp.transpose(key, (0, 2, 1, 3)), jnp.transpose(value, (0, 2, 1, 3))
308
+
309
+ def apply_rotary(self, batch_size, sequence_length, query, key, value, freq_cis, position_ids):
310
+ query = query.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
311
+ key = key.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
312
+ value = value.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
313
+
314
+ query, key, value = self._t(query, key, value)
315
+ query, key = self.rotary(position_ids=position_ids, query=query, key=key, freq_cis=freq_cis)
316
+ key = repeat_kv_bnsh(key, self.number_of_reps)
317
+ value = repeat_kv_bnsh(value, self.number_of_reps)
318
+ return self._t(query, key, value)
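`apply_rotary_pos_emb` and `repeat_kv_bnsh` come from `flax_modelling_utils` and are not shown in this diff. A rough, self-contained sketch of the rotate-half formulation such helpers usually implement (local names here are hypothetical):

import jax.numpy as jnp

def rotate_half(x):
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([-x2, x1], axis=-1)

def apply_rotary_sketch(x, sin, cos):
    # x: (batch, heads, seq, head_dim); sin/cos broadcast over batch and heads.
    return (x * cos) + (rotate_half(x) * sin)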
319
+
320
  def __call__(
321
  self,
322
+ hidden_states: chex.Array,
323
+ freq_cis: chex.Array,
324
+ attention_mask: chex.Array,
325
+ position_ids: chex.Array,
326
+ causal_mask: chex.Array,
327
  deterministic: bool = True,
328
  init_cache: bool = False,
329
  output_attentions: bool = False,
330
  fcm_mask=None,
331
  ):
 
332
 
333
+ batch_size, sequence_length = hidden_states.shape[:2]
334
+ query_state, key_state, value_state = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(
335
+ hidden_states)
336
+ if self.config.use_pjit_attention_force:
337
+ query_state = with_sharding_constraint(query_state, PS(("dp", "fsdp"), None, "mp"))
338
+ key_state = with_sharding_constraint(key_state, PS(("dp", "fsdp"), None, "mp"))
339
+ value_state = with_sharding_constraint(value_state, PS(("dp", "fsdp"), None, "mp"))
340
+
341
+ query_state = query_state.reshape(batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
342
+ key_state = key_state.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
343
+ value_state = value_state.reshape(batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)
344
+
345
+ query_state, key_state, value_state = self.apply_rotary(
346
+ query=query_state,
347
+ key=key_state,
348
+ value=value_state,
349
+ position_ids=position_ids,
350
+ freq_cis=freq_cis,
351
+ batch_size=batch_size,
352
+ sequence_length=sequence_length
353
+ )
354
 
355
+ query_length, key_length = query_state.shape[1], key_state.shape[1]
356
 
357
  if self.has_variable("cache", "cached_key"):
358
  mask_shift = self.variables["cache"]["cache_index"]
359
  max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
360
  causal_mask = lax.dynamic_slice(
361
+ causal_mask, (0, 0, mask_shift, 0), (1, 1, query_length, max_decoder_length)
362
  )
363
  else:
364
+ causal_mask = causal_mask[:, :, :query_length, :key_length]
365
 
366
  batch_size = hidden_states.shape[0]
367
  causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])
 
368
  attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
369
  attention_mask = combine_masks(attention_mask, causal_mask, fcm_mask)
370
 
 
373
  dropout_rng = self.make_rng("dropout")
374
 
375
  if self.has_variable("cache", "cached_key") or init_cache:
376
+ key_state, value_state, attention_mask = self._concatenate_to_cache(key_state, value_state, query_state,
377
+ attention_mask)
378
 
379
  attention_bias = lax.select(
380
  attention_mask > 0,
381
  jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
382
  jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
383
  )
384
+ if self.config.use_flash_attention and not (self.has_variable("cache", "cached_key") or init_cache):
385
+ attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
386
+ attention_bias = lax.select(
387
+ attention_mask > 0,
388
+ jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
389
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
390
+ )
391
+ attn_weights = None
392
+ attn_output = blockwise_dot_product_attention(
393
+ query_state,
394
+ key_state,
395
+ value_state,
396
+ bias=attention_bias,
397
+ deterministic=deterministic,
398
+ dropout_rng=dropout_rng,
399
+ attn_pdrop=self.config.attn_pdrop,
400
+ causal=True,
401
+ query_chunk_size=self.config.scan_query_chunk_size,
402
+ key_chunk_size=self.config.scan_key_chunk_size,
403
+ dtype=self.dtype,
404
+ policy=get_gradient_checkpoint_policy('nothing_saveable'),
405
+ precision=self.precision,
406
+ float32_logits=True,
407
+ )
408
+ if self.config.use_pjit_attention_force:
409
+ attn_output = with_sharding_constraint(attn_output, PS(("dp", "fsdp"), None, "mp", None))
410
+ attn_output = self._merge_heads(attn_output)
411
+ else:
412
+ attn_weights = dot_product_attention_weights(
413
+ query=query_state,
414
+ key=key_state,
415
+ bias=attention_bias,
416
+ dtype=jnp.promote_types(self.dtype, jnp.float32),
417
+ deterministic=deterministic,
418
+ dropout_rate=self.config.attn_pdrop,
419
+ precision=self.precision,
420
+ )
421
+ if self.config.use_pjit_attention_force:
422
+ attn_weights = with_sharding_constraint(attn_weights, PS(("dp", "fsdp"), "mp", None, None))
423
 
424
+ attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_state)
425
+ attn_output = self._merge_heads(attn_output)
 
426
 
427
+ attn_output = self.o_proj(attn_output)
 
 
428
  attn_output = self.resid_dropout(attn_output, deterministic=deterministic)
429
  outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
430
  return outputs
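For the non-flash branch above, a minimal sketch of the weights-then-einsum contraction it performs, using Flax's `dot_product_attention_weights` on toy shapes with no masking:

import jax
import jax.numpy as jnp
from flax.linen.attention import dot_product_attention_weights

batch, seq, heads, head_dim = 1, 4, 2, 8
rng = jax.random.PRNGKey(0)
q = jax.random.normal(rng, (batch, seq, heads, head_dim))
k = jax.random.normal(rng, (batch, seq, heads, head_dim))
v = jax.random.normal(rng, (batch, seq, heads, head_dim))

weights = dot_product_attention_weights(q, k)          # (batch, heads, q_len, k_len)
out = jnp.einsum("...hqk,...khd->...qhd", weights, v)  # (batch, seq, heads, head_dim)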
 
432
 
433
  class FlaxLlamaMLP(nn.Module):
434
  config: LlamaConfig
435
+ dtype: jnp.dtype = jnp.float32
436
+ param_dtype: jnp.dtype = jnp.float32
437
  precision: Optional[Union[jax.lax.Precision, str]] = None
438
 
439
  def setup(self) -> None:
440
  config = self.config
441
 
442
+ self.gate_proj = nn.Dense(
443
  config.intermediate_size,
444
  dtype=self.dtype,
445
  param_dtype=self.param_dtype,
 
447
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
448
  precision=self.precision,
449
  )
450
+ self.down_proj = nn.Dense(
451
  config.hidden_size,
452
  dtype=self.dtype,
453
  param_dtype=self.param_dtype,
 
455
  kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
456
  precision=self.precision,
457
  )
458
+ self.up_proj = nn.Dense(
459
  config.intermediate_size,
460
  dtype=self.dtype,
461
  param_dtype=self.param_dtype,
 
466
  self.dropout = nn.Dropout(rate=self.config.resid_pdrop)
467
 
468
  def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
469
+ x = self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x))
470
  x = self.dropout(x, deterministic=deterministic)
471
  return x
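The renamed projections are the usual gated (SwiGLU-style) MLP; ignoring dropout, the module computes down_proj(silu(gate_proj(x)) * up_proj(x)). A sketch with plain matrices standing in for the Dense kernels:

import jax
import jax.numpy as jnp

def llama_mlp_sketch(x, w_gate, w_up, w_down):
    # No biases, matching use_bias=False in the Dense layers above.
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down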
472
 
473
 
474
  class FlaxLlamaBlock(nn.Module):
475
  config: LlamaConfig
476
+ dtype: jnp.dtype = jnp.float32
477
+ param_dtype: jnp.dtype = jnp.float32
478
  precision: Optional[Union[jax.lax.Precision, str]] = None
479
 
480
  def setup(self) -> None:
481
+ attn_block = FlaxLlamaAttention
482
+ if self.config.gradient_checkpointing != '':
483
+ attn_block = re_mat(
484
+ FlaxLlamaAttention, static_argnums=(5, 6, 7),
485
+ policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
486
+ )
487
+
488
+ self.self_attn = attn_block(
489
  self.config,
490
  dtype=self.dtype,
491
  param_dtype=self.param_dtype,
492
+ precision=self.precision
493
  )
494
+ mlp_block = FlaxLlamaMLP
495
+
496
+ if self.config.gradient_checkpointing != '':
497
+ mlp_block = re_mat(
498
+ FlaxLlamaMLP, static_argnums=(1,),
499
+ policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
500
+ )
501
+
502
+ self.mlp = mlp_block(
503
  self.config,
504
  dtype=self.dtype,
505
  param_dtype=self.param_dtype,
506
  precision=self.precision,
507
  )
508
+ self.input_layernorm = RMSNorm(
509
  self.config.hidden_size,
510
  eps=self.config.rms_norm_eps,
511
  dtype=self.dtype,
512
  param_dtype=self.param_dtype,
513
  )
514
+ self.post_attention_layernorm = RMSNorm(
515
  self.config.hidden_size,
516
  eps=self.config.rms_norm_eps,
517
  dtype=self.dtype,
518
  param_dtype=self.param_dtype,
519
+
520
  )
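`re_mat` is `flax.linen.partitioning.remat`; wrapping the attention and MLP classes trades memory for recomputation, with `static_argnums` marking the boolean flags as static. A hedged standalone sketch of the same wrapping on a toy module (policy choice is illustrative):

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning

class TinyMLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(16)(nn.relu(nn.Dense(64)(x)))

RematTinyMLP = nn_partitioning.remat(
    TinyMLP,
    policy=jax.checkpoint_policies.nothing_saveable,  # recompute activations in the backward pass
)

model = RematTinyMLP()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
y = model.apply(variables, jnp.ones((1, 8)))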
521
 
522
  def __call__(
523
  self,
524
+ hidden_states: chex.Array,
525
+ freq_cis: chex.Array,
526
+ attention_mask: chex.Array,
527
+ position_ids: chex.Array,
528
+ causal_mask: chex.Array,
529
  deterministic: bool = True,
530
  init_cache: bool = False,
531
  output_attentions: bool = False,
532
  fcm_mask: Optional[jnp.ndarray] = None,
533
  ):
534
+ attn_outputs = self.self_attn(
535
+ self.input_layernorm(hidden_states),
536
+ freq_cis,
537
+ attention_mask,
538
+ position_ids,
539
+ causal_mask,
540
+ deterministic,
541
+ init_cache,
542
+ output_attentions,
543
+ fcm_mask,
544
  )
545
  attn_output = attn_outputs[0]
546
  hidden_states = hidden_states + attn_output
547
 
548
+ feed_forward_input = self.post_attention_layernorm(hidden_states)
549
+
550
+ if self.config.use_sacn_mlp:
551
+ feed_forward_input = einops.rearrange(
552
+ feed_forward_input,
553
+ '... (b s) d -> ... b s d',
554
+ b=self.config.scan_mlp_chunk_size
555
+ )
556
+
557
+ def mlp_forward(mlp, carry, x):
558
+ return None, mlp(x, deterministic)
559
+
560
+ scan_axis = feed_forward_input.ndim - 3
561
+
562
+ _, feed_forward_hidden_states = nn.scan(
563
+ mlp_forward,
564
+ variable_broadcast="params",
565
+ split_rngs={"params": False, "dropout": True},
566
+ in_axes=scan_axis,
567
+ out_axes=scan_axis,
568
+ )(self.mlp, None, feed_forward_input)
569
+ feed_forward_hidden_states = einops.rearrange(
570
+ feed_forward_hidden_states,
571
+ '... b s d -> ... (b s) d'
572
+ )
573
+ else:
574
+ feed_forward_hidden_states = self.mlp(
575
+ feed_forward_input,
576
+ deterministic,
577
+ )
578
+
579
  hidden_states = hidden_states + feed_forward_hidden_states
580
 
581
  return (hidden_states,) + attn_outputs[1:]
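The `use_sacn_mlp` branch above chunks the sequence axis and runs the MLP under `nn.scan`, so only one chunk's activations are materialized at a time. A self-contained sketch of that pattern on a toy module (names, shapes, and chunk bookkeeping are simplified and illustrative):

import jax
import jax.numpy as jnp
import einops
from flax import linen as nn

class ChunkedMLP(nn.Module):
    features: int = 8
    n_chunks: int = 4

    def setup(self):
        self.inner = nn.Dense(self.features)

    def __call__(self, x):
        # (batch, n_chunks * chunk_len, dim) -> (batch, n_chunks, chunk_len, dim)
        x = einops.rearrange(x, '... (b s) d -> ... b s d', b=self.n_chunks)

        def body(mdl, carry, chunk):
            return carry, mdl(chunk)

        scan_axis = x.ndim - 3
        _, y = nn.scan(
            body,
            variable_broadcast="params",
            split_rngs={"params": False},
            in_axes=scan_axis,
            out_axes=scan_axis,
        )(self.inner, None, x)
        return einops.rearrange(y, '... b s d -> ... (b s) d')

x = jnp.ones((2, 16, 8))
params = ChunkedMLP().init(jax.random.PRNGKey(0), x)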
 
583
 
584
  class FlaxLlamaPreTrainedModel(FlaxPreTrainedModel):
585
  config_class = LlamaConfig
586
+ base_model_prefix = "model"
587
  module_class: nn.Module = None
588
 
589
  def __init__(
 
591
  config: LlamaConfig,
592
  input_shape: Tuple = (1, 1),
593
  seed: int = 0,
594
+ dtype: jnp.dtype = jnp.float32,
595
  _do_init: bool = True,
596
  **kwargs,
597
  ):
 
646
 
647
  def __call__(
648
  self,
649
+ input_ids: chex.Array,
650
+ attention_mask: chex.Array = None,
651
+ position_ids: chex.Array = None,
652
  params: dict = None,
653
  past_key_values: dict = None,
654
  dropout_rng: jax.random.PRNGKey = None,
 
656
  output_attentions: Optional[bool] = None,
657
  output_hidden_states: Optional[bool] = None,
658
  return_dict: Optional[bool] = None,
659
+ extra_embedding: Optional[Union[jnp.ndarray, None]] = None,
660
  add_params_field: bool = False
661
  ):
662
+
663
  output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
664
  output_hidden_states = (
665
  output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 
668
 
669
  batch_size, sequence_length = input_ids.shape
670
 
671
+ assert sequence_length <= self.config.max_position_embeddings, (f'Position out of range '
672
+ f'(Model Support '
673
+ f'{self.config.max_position_embeddings} got'
674
+ f' {sequence_length})')
675
+
676
  if position_ids is None:
677
  if past_key_values is not None:
678
  raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")
 
704
  output_attentions,
705
  output_hidden_states,
706
  return_dict,
707
+ extra_embedding,
708
  rngs=rngs,
709
  mutable=mutable,
710
  )
 
722
 
723
  class FlaxLlamaBlockCollection(nn.Module):
724
  config: LlamaConfig
725
+ dtype: jnp.dtype = jnp.float32
726
+ param_dtype: jnp.dtype = jnp.float32
727
  precision: Optional[Union[jax.lax.Precision, str]] = None
728
 
729
  def setup(self):
 
730
  self.blocks = [
731
+ FlaxLlamaBlock(self.config, name=str(i), dtype=self.dtype, param_dtype=self.param_dtype,
732
+ precision=self.precision)
733
  for i in range(self.config.num_hidden_layers)
734
  ]
735
 
736
  def __call__(
737
  self,
738
+ hidden_states: chex.Array,
739
+ freq_cis: chex.Array,
740
+ attention_mask: chex.Array,
741
+ position_ids: chex.Array,
742
+ causal_mask: chex.Array,
743
  deterministic: bool = True,
744
  init_cache: bool = False,
745
  output_attentions: bool = False,
 
771
  all_hidden_states += (hidden_states,)
772
 
773
  layer_outputs = block(
774
+ hidden_states=hidden_states,
775
+ freq_cis=freq_cis,
776
+ attention_mask=attention_mask,
777
+ position_ids=position_ids,
778
+ causal_mask=causal_mask,
779
+ deterministic=deterministic,
780
+ init_cache=init_cache,
781
+ output_attentions=output_attentions,
782
+ fcm_mask=fcm_mask,
783
  )
784
  hidden_states = layer_outputs[0]
785
 
786
  if output_attentions:
787
  all_attentions += (layer_outputs[1],)
788
 
 
789
  outputs = (hidden_states, all_hidden_states, all_attentions)
790
 
791
  return outputs
 
793
 
794
  class FlaxLlamaModule(nn.Module):
795
  config: LlamaConfig
796
+ dtype: jnp.dtype = jnp.float32
797
+ param_dtype: jnp.dtype = jnp.float32
798
  precision: Optional[Union[jax.lax.Precision, str]] = None
799
 
800
  def setup(self):
 
801
 
802
+ self.embed_tokens = nn.Embed(
803
  self.config.vocab_size,
804
  self.config.hidden_size,
805
  embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
 
807
  param_dtype=self.param_dtype,
808
  )
809
  self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
810
+ self.layers = FlaxLlamaBlockCollection(self.config, dtype=self.dtype, param_dtype=self.param_dtype,
811
+ precision=self.precision)
812
+ self.norm = RMSNorm(self.config.hidden_size, eps=self.config.rms_norm_eps, dtype=self.dtype,
813
  param_dtype=self.param_dtype)
814
+ config = self.config
815
+ self.causal_mask = make_causal_mask(jnp.ones((1, config.max_position_embeddings)))
816
+ self.freq_cis = precompute_freq_cis(
817
+ max_position_embedding=config.max_position_embeddings,
818
+ head_dim=config.hidden_size // config.num_attention_heads
819
+ )
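`precompute_freq_cis` lives in `flax_modelling_utils`; judging from `FlaxLlamaEmbedding` above, it returns the (sin, cos) tables indexed by position id. A rough sketch of a standard precomputation (hypothetical local function, default theta of 10000):

import jax.numpy as jnp

def precompute_freq_cis_sketch(max_position_embedding, head_dim, theta=10000.0):
    # Inverse frequency per dimension pair, outer-product with positions, then
    # duplicate so sin/cos cover the full head_dim.
    inv_freq = 1.0 / (theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    t = jnp.arange(max_position_embedding, dtype=jnp.float32)
    freqs = jnp.outer(t, inv_freq)                  # (positions, head_dim // 2)
    emb = jnp.concatenate([freqs, freqs], axis=-1)  # (positions, head_dim)
    return jnp.sin(emb), jnp.cos(emb)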
820
 
821
  def __call__(
822
  self,
823
+ input_ids: chex.Array,
824
+ attention_mask: chex.Array,
825
+ position_ids: chex.Array,
826
+ deterministic: bool = True,
827
+ input_embeds: chex.Array = None,
828
  init_cache: bool = False,
829
  output_attentions: bool = False,
830
  output_hidden_states: bool = False,
831
  return_dict: bool = True,
832
+ extra_embedding: Optional[Union[jnp.ndarray, None]] = None
833
  ):
834
+ if input_embeds is None:
835
+ input_embeds = self.embed_tokens(input_ids.astype("i4"))
836
 
837
+ batch_size, sequence_length = input_ids.shape
838
+ assert sequence_length <= self.config.max_position_embeddings, (f'Position out of range '
839
+ f'(Model Support '
840
+ f'{self.config.max_position_embeddings} got'
841
+ f' {sequence_length})')
842
+ input_embeds = input_embeds + extra_embedding if extra_embedding is not None else input_embeds
843
  hidden_states = self.dropout(input_embeds, deterministic=deterministic)
844
 
845
+ outputs = self.layers(
846
+ hidden_states=hidden_states,
847
+ freq_cis=self.freq_cis,
848
+ attention_mask=attention_mask,
849
  position_ids=position_ids,
850
+ causal_mask=self.causal_mask,
851
  deterministic=deterministic,
852
  init_cache=init_cache,
853
  output_attentions=output_attentions,
 
856
  )
857
 
858
  hidden_states = outputs[0]
859
+ hidden_states = self.norm(hidden_states)
860
 
861
  if output_hidden_states:
862
  all_hidden_states = outputs[1] + (hidden_states,)
 
880
 
881
  class FlaxLlamaForCausalLMModule(nn.Module):
882
  config: LlamaConfig
883
+ dtype: jnp.dtype = jnp.float32
884
+ param_dtype: jnp.dtype = jnp.float32
885
  precision: Optional[Union[jax.lax.Precision, str]] = None
886
 
887
  def setup(self):
888
+ self.model = FlaxLlamaModule(self.config,
889
+ dtype=self.dtype,
890
+ param_dtype=self.param_dtype,
891
+ precision=self.precision,
892
+ )
893
  self.lm_head = nn.Dense(
894
  self.config.vocab_size,
895
  dtype=self.dtype,
 
901
 
902
  def __call__(
903
  self,
904
+ input_ids: chex.Array,
905
+ attention_mask: chex.Array = None,
906
+ position_ids: chex.Array = None,
907
  deterministic: bool = True,
908
  init_cache: bool = False,
909
  output_attentions: bool = False,
910
  output_hidden_states: bool = False,
911
  return_dict: bool = True,
912
+ extra_embedding: Optional[Union[jnp.ndarray, None]] = None
913
  ):
914
  batch_size, seq_length = input_ids.shape
915
  if attention_mask is None:
 
919
  jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
920
  (batch_size, seq_length)
921
  )
922
+ outputs = self.model(
923
  input_ids,
924
  attention_mask,
925
  position_ids,
 
928
  output_attentions=output_attentions,
929
  output_hidden_states=output_hidden_states,
930
  return_dict=return_dict,
931
+ extra_embedding=extra_embedding
932
  )
933
 
934
  hidden_states = outputs[0]
935
 
936
  if self.config.tie_word_embeddings:
937
+ shared_kernel = self.model.variables["params"]["embed_tokens"]["embedding"].T
938
  lm_logits = self.lm_head.apply({"params": {"kernel": shared_kernel}}, hidden_states)
939
  else:
940
  lm_logits = self.lm_head(hidden_states)
941
 
942
+ lm_logits = lm_logits.astype(jnp.float32)
943
+
944
  if not return_dict:
945
  return (lm_logits,) + outputs[1:]
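The tied-embeddings branch reuses the transposed `embed_tokens` matrix as the `lm_head` kernel by calling `Module.apply` with an ad-hoc params dict. The trick in isolation, as a sketch:

import jax.numpy as jnp
from flax import linen as nn

vocab, hidden = 10, 4
embedding = jnp.ones((vocab, hidden))    # stand-in for params["embed_tokens"]["embedding"]
lm_head = nn.Dense(vocab, use_bias=False)

hidden_states = jnp.ones((1, 3, hidden))
# Supply the transposed embedding as the Dense kernel instead of learned parameters.
logits = lm_head.apply({"params": {"kernel": embedding.T}}, hidden_states)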
946
 
 
950
  class FlaxLlamaForCausalLM(FlaxLlamaPreTrainedModel):
951
  module_class = FlaxLlamaForCausalLMModule
952
 
953
+ def prepare_inputs_for_generation(self, input_ids, max_length, attention_mask: Optional[chex.Array] = None):
954
  batch_size, seq_length = input_ids.shape
955
 
956
  past_key_values = self.init_cache(batch_size, max_length)
 
976
  class FlaxLlamaForSequenceClassificationModule(nn.Module):
977
  num_classes: int
978
  config: LlamaConfig
979
+ dtype: jnp.dtype = jnp.float32
980
+ param_dtype: jnp.dtype = jnp.float32
981
  precision: Optional[Union[jax.lax.Precision, str]] = None
982
 
983
  def setup(self):
984
+ self.model = FlaxLlamaModule(self.config, dtype=self.dtype)
985
  self.classifier = nn.Dense(
986
  self.num_classes,
987
  dtype=self.dtype,
 
993
 
994
  def __call__(
995
  self,
996
+ input_ids: chex.Array,
997
+ attention_mask: chex.Array = None,
998
+ position_ids: chex.Array = None,
999
  deterministic: bool = True,
1000
  init_cache: bool = False,
1001
  output_attentions: bool = False,
1002
  output_hidden_states: bool = False,
1003
  return_dict: bool = True,
1004
+ extra_embedding: Optional[Union[jnp.ndarray, None]] = None
1005
  ):
1006
  batch_size, seq_length = input_ids.shape
1007
  if attention_mask is None:
 
1011
  jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
1012
  (batch_size, seq_length)
1013
  )
1014
+ outputs = self.model(
1015
  input_ids,
1016
  attention_mask,
1017
  position_ids,
 
1020
  output_attentions=output_attentions,
1021
  output_hidden_states=output_hidden_states,
1022
  return_dict=return_dict,
1023
+ extra_embedding=extra_embedding
1024
  )
1025
 
1026
  hidden_states = outputs[0]
 
1035
 
1036
 
1037
  class FlaxLlamaForSequenceClassification(FlaxLlamaPreTrainedModel):
1038
+ module_class = FlaxLlamaForSequenceClassificationModule