Transformers
English
falcon
custom_code
text-generation-inference
erfanzar committed
Commit 4719d39
1 Parent(s): 7200664

Update model.py

Files changed (1)
    model.py (+31, -10)
model.py CHANGED
@@ -75,6 +75,8 @@ class FalconConfig(PretrainedConfig):
             bias=False,
             parallel_attn=False,
             max_seq_len=2048,
+            use_pjit_attention_force: bool = False,
+            gradient_checkpointing: str = 'nothing_saveable',
             **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -90,10 +92,12 @@ class FalconConfig(PretrainedConfig):
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
         self.bos_token_id = bos_token_id
+        self.use_pjit_attention_force = use_pjit_attention_force
         self.eos_token_id = eos_token_id
         self.multi_query = multi_query
         self.alibi = alibi
         self.bias = bias
+        self.gradient_checkpointing = gradient_checkpointing
         self.parallel_attn = parallel_attn

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
@@ -149,7 +153,7 @@ def built_bloom_alibi(attention_mask, num_attention_heads):
     slops = jnp.power(base, powers)
     if cp2 != num_attention_heads:
         extra_base = jnp.asarray(
-            2 ** (-(2 ** (math.log2(2 * cp2) - 3))), dtype=jnp.float32
+            2 ** (-(2 ** -(math.log2(2 * cp2) - 3))), dtype=jnp.float32
         )
         num_rem_heads = min(cp2, num_attention_heads - cp2)
         extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.dtype)
@@ -229,9 +233,10 @@ class FlaxFalconAttention(nn.Module):
         qkv = self.w_qkv(hidden_states)
         if not self.config.multi_query:
             q, k, v = jnp.split(qkv, 3, -1)
-            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            if self.config.use_pjit_attention_force:
+                q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+                k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+                v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
             k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
             q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
             v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
@@ -240,16 +245,17 @@ class FlaxFalconAttention(nn.Module):
                 b, s, self.config.n_head + 2, -1
             )
             q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
-
-            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
-            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
-            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+            if self.config.use_pjit_attention_force:
+                q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+                k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+                v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))

         if not self.config.alibi:
             freq = self.freq[:s].reshape(1, s, -1)
             q, k = apply_rotary_emb(q, k, freq, self.dtype)
         attn = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision)
-        attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))
+        if self.config.use_pjit_attention_force:
+            attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))

         if alibi is not None:
             attn += alibi
@@ -337,6 +343,15 @@ class FlaxFalconBlock(nn.Module):
         return mlp_out + residual


+def get_gradient_checkpoint_policy(name):
+    return {
+        'everything_saveable': jax.checkpoint_policies.everything_saveable,
+        'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
+        'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
+        'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
+    }[name]
+
+
 class FlaxFalconCollection(nn.Module):
     config: FalconConfig
     dtype: jnp.dtype = jnp.float32
@@ -344,8 +359,14 @@ class FlaxFalconCollection(nn.Module):
     precision: Optional[Union[jax.lax.Precision, str]] = None

     def setup(self) -> None:
+        block = FlaxFalconBlock
+        if self.config.gradient_checkpointing != '':
+            block = nn.remat(
+                block,
+                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
+            )
         self.blocks = [
-            FlaxFalconBlock(
+            block(
                 config=self.config,
                 dtype=self.dtype,
                 param_dtype=self.param_dtype,
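
This commit adds two configuration knobs, use_pjit_attention_force and gradient_checkpointing. As a rough usage sketch only: the keyword names come from the diff, but the import path and the assumption that all other FalconConfig arguments keep their defaults are not confirmed by the rest of model.py.

# Hypothetical usage of the two new FalconConfig options (assumes model.py from
# this repository is importable; other constructor defaults are assumed to exist).
from model import FalconConfig

config = FalconConfig(
    use_pjit_attention_force=True,              # apply with_sharding_constraint in attention
    gradient_checkpointing='nothing_saveable',  # '' skips nn.remat entirely
)
print(config.use_pjit_attention_force, config.gradient_checkpointing)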
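The built_bloom_alibi change moves the negation inside the exponent, so the base for the extra heads follows the BLOOM-style ALiBi form 2 ** (-(2 ** -(log2(2 * cp2) - 3))). A quick numeric check, taking 12 attention heads as an example so that cp2 = 8:

# Arithmetic check of the extra_base fix (12 heads is just an illustrative case).
import math

cp2 = 8  # closest power of two for 12 attention heads
old = 2 ** (-(2 ** (math.log2(2 * cp2) - 3)))    # before the fix: 0.25
new = 2 ** (-(2 ** -(math.log2(2 * cp2) - 3)))   # after the fix: 2 ** -0.5, about 0.7071
print(old, new)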
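The use_pjit_attention_force branches only gate with_sharding_constraint calls, which pin q/k/v and the attention weights to a PartitionSpec over a ('dp', 'fsdp', 'mp') mesh when the module runs under pjit. A self-contained sketch of the same pattern follows; model.py passes bare PartitionSpecs because a mesh is already in scope there, while this sketch wraps them in NamedSharding (and uses a toy 1x1x1 mesh and made-up shapes) so it runs standalone on one device.

# Sketch of forcing a sharding layout on an intermediate, as the attention code does.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec
from jax.lax import with_sharding_constraint

# 1x1x1 mesh so the example runs on a single device; real runs shard dp/fsdp/mp.
mesh = Mesh(np.array(jax.devices()[:1]).reshape(1, 1, 1),
            axis_names=('dp', 'fsdp', 'mp'))


@jax.jit
def project(x, w):
    q = x @ w
    # Pin the batch dim to ('dp', 'fsdp') and the hidden dim to 'mp',
    # the same layout the diff forces on q/k/v.
    spec = PartitionSpec(('dp', 'fsdp'), None, 'mp')
    return with_sharding_constraint(q, NamedSharding(mesh, spec))


x = jnp.ones((4, 16, 32))
w = jnp.ones((32, 32))
print(project(x, w).sharding)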
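The gradient_checkpointing option wraps FlaxFalconBlock in nn.remat with one of the jax.checkpoint_policies returned by get_gradient_checkpoint_policy, trading recomputation for activation memory. A minimal sketch of the same mechanism; TinyBlock/TinyStack and the shapes are illustrative stand-ins, not code from model.py.

# Sketch: rematerialized blocks via nn.remat with a jax checkpoint policy.
import jax
import jax.numpy as jnp
import flax.linen as nn


class TinyBlock(nn.Module):
    features: int = 64

    @nn.compact
    def __call__(self, x):
        return nn.relu(nn.Dense(self.features)(x))


class TinyStack(nn.Module):
    n_layers: int = 4
    gradient_checkpointing: str = 'nothing_saveable'  # '' disables remat, as in the diff

    @nn.compact
    def __call__(self, x):
        block = TinyBlock
        if self.gradient_checkpointing != '':
            # Recompute block activations in the backward pass instead of storing
            # them; the policy decides which intermediates may still be saved.
            block = nn.remat(
                TinyBlock,
                policy=jax.checkpoint_policies.nothing_saveable,
            )
        for _ in range(self.n_layers):
            x = block()(x)
        return x


x = jnp.ones((2, 64))
model = TinyStack()
params = model.init(jax.random.PRNGKey(0), x)
grads = jax.grad(lambda p: model.apply(p, x).sum())(params)
print(jax.tree_util.tree_map(jnp.shape, grads))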