erfanzar committed
Commit f78acae
1 Parent(s): 6af0a7c

Create model.py

Files changed (1)
  1. model.py +488 -0
model.py ADDED
@@ -0,0 +1,488 @@
import math

from flax import linen as nn
from flax.core import FrozenDict
from typing import Optional, Dict, Union, Tuple
from transformers import FlaxPreTrainedModel, PretrainedConfig
from jax import numpy as jnp
import jax
from jax.interpreters import pxla
from jax.experimental.pjit import pjit, PartitionSpec, with_sharding_constraint as wsc
from transformers.modeling_flax_outputs import FlaxCausalLMOutput, FlaxBaseModelOutput
from jax.random import split, PRNGKey
from functools import partial
from einops import rearrange

# Map activation names used in configs to their Flax implementations.
ACT2FN = {
    "gelu": partial(nn.gelu, approximate=False),
    "relu": nn.relu,
    "silu": nn.swish,
    "swish": nn.swish,
    "gelu_new": partial(nn.gelu, approximate=True),
}


def get_names_from_partition_spec(partition_specs):
    # Collect every mesh axis name that appears in a (possibly nested) PartitionSpec.
    names = set()
    if isinstance(partition_specs, dict):
        partition_specs = partition_specs.values()
    for item in partition_specs:
        if item is None:
            continue
        elif isinstance(item, str):
            names.add(item)
        else:
            names.update(get_names_from_partition_spec(item))

    return list(names)


def names_in_mesh(*names):
    # True if every given axis name exists in the currently active device mesh.
    return set(names) <= set(pxla.thread_resources.env.physical_mesh.axis_names)


def with_sharding_constraint(x, partition_specs):
    # Only apply the constraint when the required mesh axes are available,
    # so the model also runs outside of a pjit/mesh context.
    axis_names = get_names_from_partition_spec(partition_specs)
    if names_in_mesh(*axis_names):
        x = wsc(x, partition_specs)
    return x


class FalconConfig(PretrainedConfig):
    model_type = "falcon"
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
    }

    def __init__(
            self,
            vocab_size=250880,
            hidden_size=64,
            n_layer=2,
            n_head=8,
            layer_norm_epsilon=1e-5,
            initializer_range=0.02,
            use_cache=True,
            bos_token_id=1,
            eos_token_id=2,
            apply_residual_connection_post_layernorm=False,
            hidden_dropout=0.0,
            attention_dropout=0.0,
            multi_query=False,
            alibi=False,
            bias=False,
            parallel_attn=False,
            max_seq_len=2048,
            **kwargs,
    ):
        self.vocab_size = vocab_size
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.max_seq_len = max_seq_len
        self.use_cache = use_cache
        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn

        super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi

    @staticmethod
    def get_partition_rules(fully_fsdp: bool = False):
        # Regex patterns mapping parameter paths to the PartitionSpecs used to shard them.
        return (
            ('wte/embedding', PartitionSpec('fsdp', 'mp')),
            ('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('mlp/down/(kernel|bias)', PartitionSpec('fsdp', 'mp')),
            ('mlp/up/(kernel|bias)', PartitionSpec('mp', 'fsdp')),
            ('lm_head/kernel', PartitionSpec('fsdp', 'mp')),
            ('transformer/ln_f/bias', PartitionSpec('fsdp', 'mp')),
            ('transformer/ln_f/scale', PartitionSpec('fsdp', 'mp')),
            ('transformer/post_attention_layernorm/scale', PartitionSpec('mp', 'fsdp')),
            ('transformer/post_attention_layernorm/bias', PartitionSpec('mp', 'fsdp')),
            ('.*', PartitionSpec('fsdp', 'mp'))
        ) if not fully_fsdp else (
            ('wte/embedding', PartitionSpec('fsdp')),
            ('self_attention/w_qkv/(kernel|bias)', PartitionSpec('fsdp')),
            ('self_attention/wo/(kernel|bias)', PartitionSpec('fsdp')),
            ('mlp/down/(kernel|bias)', PartitionSpec('fsdp')),
            ('mlp/up/(kernel|bias)', PartitionSpec('fsdp')),
            ('lm_head/kernel', PartitionSpec('fsdp')),
            ('transformer/ln_f/bias', PartitionSpec('fsdp')),
            ('transformer/ln_f/scale', PartitionSpec('fsdp')),
            ('transformer/post_attention_layernorm/scale', PartitionSpec('fsdp')),
            ('transformer/post_attention_layernorm/bias', PartitionSpec('fsdp')),
            ('.*', PartitionSpec('fsdp'))
        )

    @staticmethod
    def get_mesh_names():
        return 'dp', 'fsdp', 'mp'


def build_alibi(max_length, num_attention_heads, alibi_max: int = 8):
    # ALiBi bias: a per-head linear penalty over key positions, broadcast across queries.
    w_range = jnp.arange(1 - max_length, 1).reshape(1, 1, 1, max_length)
    cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
    h_range = jnp.arange(1, 1 + num_attention_heads).reshape(1, -1, 1, 1)
    h_range = jnp.matmul(h_range, jnp.asarray(alibi_max / cp2).reshape(1, 1))
    slope = 1 / jnp.power(2, h_range)
    if cp2 != num_attention_heads:
        # For head counts that are not a power of two, interleave the slopes and
        # keep the first `num_attention_heads` entries along the head axis.
        slope = jnp.concatenate([slope[:, 1::2], slope[:, ::2]], axis=1)[:, :num_attention_heads]
    alibi = (w_range * slope).reshape(1, num_attention_heads, 1, max_length)
    return alibi


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         dtype: jnp.dtype = jnp.bfloat16) -> jnp.ndarray:
    # Rotary-embedding frequencies as complex exponentials, one row per position.
    freqs = 1.0 / (theta ** (jnp.arange(0, dim, 2)[: (dim // 2)].astype(dtype) / dim))
    t = jnp.arange(end)  # type: ignore
    freqs = jnp.outer(t, freqs).astype(dtype)
    sin, cos = jnp.sin(freqs), jnp.cos(freqs)
    freqs_cis = jnp.complex64(cos + 1j * sin)
    return jnp.asarray(freqs_cis)


def apply_rotary_emb(
        xq: jnp.ndarray,
        xk: jnp.ndarray,
        freqs_cis: jnp.ndarray,
        dtype: jnp.dtype = jnp.bfloat16,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    # Rotate query/key pairs in the complex plane by the precomputed position frequencies.
    reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

    xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
    xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

    freqs_cis = jnp.reshape(freqs_cis, (*freqs_cis.shape[:2], 1, *freqs_cis.shape[2:]))

    xq_out = xq_ * freqs_cis
    xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)

    xk_out = xk_ * freqs_cis
    xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)

    return xq_out.astype(dtype), xk_out.astype(dtype)


class FlaxFalconAttention(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        head_dim = self.config.hidden_size // self.config.n_head
        self.w_qkv = nn.Dense(
            features=self.config.hidden_size * 3,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.factor_scale = 1 / math.sqrt(head_dim)
        self.wo = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.head_dim = head_dim
        if not self.config.alibi:
            self.freq = precompute_freqs_cis(head_dim, self.config.max_seq_len, dtype=self.dtype)

    def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: Optional[jnp.ndarray] = None,
                 attention_mask: Optional[jnp.ndarray] = None,
                 ):
        b, s, d = hidden_states.shape
        q, k, v = jnp.split(self.w_qkv(hidden_states), 3, -1)
        q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
        k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
        v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
        k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
        q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
        v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
        if not self.config.alibi:
            # Rotary embeddings are used when ALiBi is disabled.
            freq = self.freq[:s].reshape(1, s, -1)
            q, k = apply_rotary_emb(q, k, freq, self.dtype)
        attn = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision)
        attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))

        if alibi is not None:
            attn += alibi
        attn = attn * self.factor_scale
        if attention_mask is not None:
            attn += attention_mask
        attn = jax.nn.softmax(attn, axis=-1)
        attn = jnp.einsum('...hqk,...khd->...qhd', attn, v, precision=self.precision).reshape((b, s, d))
        return self.wo(attn)


class FlaxFalconMlp(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.up = nn.Dense(
            features=self.config.hidden_size * 4,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )
        self.down = nn.Dense(
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            use_bias=self.config.bias
        )

    def __call__(self, x):
        return self.down(nn.gelu(self.up(x)))


class FlaxFalconBlock(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        config = self.config
        self.input_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
                                            dtype=self.dtype)
        if not config.parallel_attn:
            self.post_attention_layernorm = nn.LayerNorm(epsilon=config.layer_norm_epsilon,
                                                         dtype=self.dtype)

        self.mlp = FlaxFalconMlp(
            config=config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.self_attention = FlaxFalconAttention(
            config=config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )

    def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray,
                 attention_mask: jnp.ndarray,
                 ):
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)

        attn = self.self_attention(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            alibi=alibi
        )
        if not self.config.parallel_attn:
            # Sequential path: the attention output passes through a second layer norm before the MLP.
            residual = attn + residual
            hidden_states = self.post_attention_layernorm(residual)

        mlp_out = self.mlp(hidden_states)
        if self.config.parallel_attn:
            # Parallel path: attention and MLP both read the same layer-normed input.
            mlp_out += attn
        return mlp_out + residual


class FlaxFalconCollection(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.blocks = [
            FlaxFalconBlock(
                config=self.config,
                dtype=self.dtype,
                param_dtype=self.param_dtype,
                precision=self.precision,
                name=str(i)
            )
            for i in range(self.config.n_layer)
        ]

    def __call__(self,
                 hidden_states: jnp.ndarray,
                 alibi: jnp.ndarray,
                 attention_mask: jnp.ndarray,
                 ):
        for block in self.blocks:
            hidden_states = block(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                alibi=alibi
            )
        return hidden_states


class FlaxFalconModule(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.wte = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=self.dtype,
            param_dtype=self.param_dtype
        )
        self.h = FlaxFalconCollection(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.ln_f = nn.LayerNorm(dtype=self.dtype, param_dtype=self.param_dtype,
                                 epsilon=self.config.layer_norm_epsilon)

    def __call__(self,
                 input_ids: jnp.ndarray = None,
                 attention_mask: Optional[jnp.ndarray] = None,
                 use_cache: Optional[bool] = None,
                 return_dict: Optional[bool] = None,
                 ):
        batch, seq_len = input_ids.shape
        hidden_states = self.wte(
            inputs=input_ids.astype(jnp.int32)
        )
        if attention_mask is None:
            attention_mask = jnp.ones(
                (batch, seq_len)
            )

        alibi = build_alibi(seq_len, self.config.n_head, 8) if self.config.alibi else None
        causal_mask = nn.make_causal_mask(
            input_ids,
        )

        # Combine the padding mask and the causal mask into one additive bias:
        # 0 where attention is allowed, a large negative value where it is masked.
        mv = jnp.finfo(hidden_states.dtype).min
        attention_mask = (
            jnp.where(attention_mask[:, jnp.newaxis, jnp.newaxis, :] == 1, 0, mv)
            + jnp.where(causal_mask == 1, 0, mv)
        )

        output = self.ln_f(self.h(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            alibi=alibi
        ))

        if return_dict:
            return FlaxBaseModelOutput(
                last_hidden_state=output,
            )
        else:
            return output,


class FlaxFalconPretrainedModel(FlaxPreTrainedModel):
    module_class: nn.Module = None
    config_class = FalconConfig

    def __init__(self, config, _do_init=False, dtype: jnp.dtype = jnp.float32, param_dtype: jnp.dtype = jnp.float32,
                 input_shape: Tuple = (1, 12)):
        module = self.module_class(config=config, dtype=dtype, param_dtype=param_dtype)
        super().__init__(_do_init=_do_init, module=module, config=config, dtype=dtype, input_shape=input_shape)

    def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> Dict:
        if params is None:
            params = self.module.init(
                rngs=rng,
                input_ids=jnp.ones(input_shape),
                attention_mask=jnp.ones(input_shape)
            )
        return params['params']

    def __call__(self, input_ids,
                 attention_mask=None,
                 params: FrozenDict = None,
                 add_params_field: bool = False,
                 return_dict: bool = True):
        # When `add_params_field` is False the caller is expected to pass a tree that is
        # already wrapped as {'params': ...}; otherwise the wrapping is added here.
        params = {'params': params or self.params} if add_params_field else params or self.params
        predict = self.module.apply(
            params,
            input_ids=jnp.asarray(input_ids, dtype=jnp.int32),
            attention_mask=jnp.asarray(attention_mask,
                                       dtype=jnp.int32) if attention_mask is not None else attention_mask,
            return_dict=return_dict
        )
        return predict

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None):
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask
        }


class FlaxFalconModel(FlaxFalconPretrainedModel):
    module_class = FlaxFalconModule


class FlaxFalconForCausalLMModule(nn.Module):
    config: FalconConfig
    dtype: jnp.dtype = jnp.float32
    param_dtype: jnp.dtype = jnp.float32
    precision: Optional[Union[jax.lax.Precision, str]] = None

    def setup(self) -> None:
        self.transformer = FlaxFalconModule(
            config=self.config,
            dtype=self.dtype,
            param_dtype=self.param_dtype,
            precision=self.precision
        )
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            use_bias=False,
            dtype=self.dtype,
            param_dtype=self.param_dtype
        )

    def __call__(self, input_ids, attention_mask, return_dict: bool = False):
        output = self.lm_head(self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True
        ).last_hidden_state)
        if return_dict:
            return FlaxCausalLMOutput(
                logits=output
            )
        else:
            return output,


class FlaxFalconForCausalLM(FlaxFalconPretrainedModel):
    module_class = FlaxFalconForCausalLMModule
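
For reference, a minimal forward-pass sketch against the classes added in this commit. It is illustrative only: the `from model import ...` path and the small config values are assumptions, not part of the file above.

# Illustrative usage sketch; assumes this file is importable as `model`.
import jax
from jax import numpy as jnp

from model import FalconConfig, FlaxFalconForCausalLM  # hypothetical import path

# Tiny config so the example initializes quickly; real checkpoints use larger values.
config = FalconConfig(vocab_size=1024, hidden_size=64, n_layer=2, n_head=8, alibi=False)
model = FlaxFalconForCausalLM(config=config, _do_init=False)

rng = jax.random.PRNGKey(0)
params = model.init_weights(rng, input_shape=(1, 12))

input_ids = jnp.ones((1, 12), dtype=jnp.int32)
# add_params_field=True wraps the tree as {'params': ...} before module.apply.
logits, = model(input_ids, params=params, add_params_field=True, return_dict=False)
print(logits.shape)  # (1, 12, 1024)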