Update model.py
model.py
CHANGED
@@ -75,6 +75,8 @@ class FalconConfig(PretrainedConfig):
             bias=False,
             parallel_attn=False,
             max_seq_len=2048,
+            use_pjit_attention_force: bool = False,
+            gradient_checkpointing: str = 'nothing_saveable',
             **kwargs,
     ):
         self.vocab_size = vocab_size
@@ -90,10 +92,12 @@ class FalconConfig(PretrainedConfig):
         self.hidden_dropout = hidden_dropout
         self.attention_dropout = attention_dropout
         self.bos_token_id = bos_token_id
+        self.use_pjit_attention_force = use_pjit_attention_force
         self.eos_token_id = eos_token_id
         self.multi_query = multi_query
         self.alibi = alibi
         self.bias = bias
+        self.gradient_checkpointing = gradient_checkpointing
         self.parallel_attn = parallel_attn

         super().__init__(bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs)
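Both new options are additive: use_pjit_attention_force defaults to False and gradient_checkpointing defaults to 'nothing_saveable', so existing configs keep their behaviour. A minimal usage sketch follows (the import path and leaving every other argument at its default are assumptions; only the keyword names come from this diff):

# Hypothetical usage sketch, not part of the commit.
from model import FalconConfig

config = FalconConfig(
    use_pjit_attention_force=True,               # adds with_sharding_constraint calls in attention
    gradient_checkpointing='nothing_saveable',   # or '' to skip nn.remat entirely
)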
@@ -149,7 +153,7 @@ def built_bloom_alibi(attention_mask, num_attention_heads):
         slops = jnp.power(base, powers)
         if cp2 != num_attention_heads:
             extra_base = jnp.asarray(
-                2 ** (-(2 ** (math.log2(2 * cp2) - 3))), dtype=jnp.float32
+                2 ** (-(2 ** -(math.log2(2 * cp2) - 3))), dtype=jnp.float32
             )
             num_rem_heads = min(cp2, num_attention_heads - cp2)
             extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.dtype)
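The one-character change above moves the minus sign inside the inner exponent, which matches the standard ALiBi slope construction (as in the reference Bloom implementation) for head counts that are not a power of two. A small illustration of the difference, with cp2 standing in for the closest power of two:

# Illustration only, not part of model.py: with the inner minus the extra base
# equals 2 ** (-4 / cp2), the geometric midpoint between the main ALiBi slopes;
# without it the exponent grows with cp2 and the extra slopes collapse.
import math

cp2 = 8  # closest power of two not exceeding the head count
fixed = 2 ** (-(2 ** -(math.log2(2 * cp2) - 3)))  # new code: ~0.7071
old = 2 ** (-(2 ** (math.log2(2 * cp2) - 3)))     # old code: 0.25
assert abs(fixed - 2 ** (-4 / cp2)) < 1e-12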
@@ -229,9 +233,10 @@ class FlaxFalconAttention(nn.Module):
         qkv = self.w_qkv(hidden_states)
         if not self.config.multi_query:
             q, k, v = jnp.split(qkv, 3, -1)
-
-
-
+            if self.config.use_pjit_attention_force:
+                q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+                k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+                v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
             k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
             q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
             v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
@@ -240,16 +245,17 @@ class FlaxFalconAttention(nn.Module):
                 b, s, self.config.n_head + 2, -1
             )
             q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
-
-
-
-
+            if self.config.use_pjit_attention_force:
+                q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+                k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+                v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))

         if not self.config.alibi:
             freq = self.freq[:s].reshape(1, s, -1)
             q, k = apply_rotary_emb(q, k, freq, self.dtype)
         attn = jnp.einsum('...qhd,...khd->...hqk', q, k, precision=self.precision)
-
+        if self.config.use_pjit_attention_force:
+            attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))

         if alibi is not None:
             attn += alibi
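The with_sharding_constraint calls pin the query/key/value and attention-weight layouts to named mesh axes: batch over ('dp', 'fsdp'), heads or hidden features over 'mp', and sequence replicated. The sketch below shows that mapping with plain JAX primitives; the mesh shape and the exact import model.py uses for with_sharding_constraint are assumptions, not taken from this commit.

# Sketch under assumed setup: how PartitionSpec(('dp', 'fsdp'), None, 'mp')
# lays out an activation of shape (batch, seq, hidden) on a 3-axis device mesh.
import jax
import jax.numpy as jnp
import numpy as np
from jax.lax import with_sharding_constraint
from jax.sharding import Mesh, PartitionSpec

devices = np.array(jax.devices()).reshape(1, 1, -1)  # (dp, fsdp, mp) axes

@jax.jit
def constrain(x):
    # batch sharded over dp+fsdp, sequence replicated, hidden split over mp
    return with_sharding_constraint(x, PartitionSpec(('dp', 'fsdp'), None, 'mp'))

with Mesh(devices, ('dp', 'fsdp', 'mp')):
    y = constrain(jnp.zeros((8, 128, 512)))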
@@ -337,6 +343,15 @@ class FlaxFalconBlock(nn.Module):
         return mlp_out + residual


+def get_gradient_checkpoint_policy(name):
+    return {
+        'everything_saveable': jax.checkpoint_policies.everything_saveable,
+        'nothing_saveable': jax.checkpoint_policies.nothing_saveable,
+        'checkpoint_dots': jax.checkpoint_policies.checkpoint_dots,
+        'checkpoint_dots_with_no_batch_dims': jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims,
+    }[name]
+
+
 class FlaxFalconCollection(nn.Module):
     config: FalconConfig
     dtype: jnp.dtype = jnp.float32
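get_gradient_checkpoint_policy simply maps the config string onto one of JAX's built-in rematerialization policies (an unknown name raises KeyError). A short, illustrative use of the returned policy with jax.checkpoint, assuming the helper above is importable:

# Illustration only: hand the selected policy to jax.checkpoint / jax.remat.
from functools import partial
import jax
import jax.numpy as jnp

policy = get_gradient_checkpoint_policy('checkpoint_dots')

@partial(jax.checkpoint, policy=policy)  # keep only matmul ("dot") results for the backward pass
def layer(x, w):
    return jnp.tanh(x @ w)

grad_w = jax.grad(lambda x, w: layer(x, w).sum(), argnums=1)(jnp.ones((4, 8)), jnp.ones((8, 8)))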
@@ -344,8 +359,14 @@ class FlaxFalconCollection(nn.Module):
     precision: Optional[Union[jax.lax.Precision, str]] = None

     def setup(self) -> None:
+        block = FlaxFalconBlock
+        if self.config.gradient_checkpointing != '':
+            block = nn.remat(
+                block,
+                policy=get_gradient_checkpoint_policy(self.config.gradient_checkpointing)
+            )
         self.blocks = [
-
+            block(
                 config=self.config,
                 dtype=self.dtype,
                 param_dtype=self.param_dtype,
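Wrapping the block class in nn.remat before instantiating it is the usual Flax pattern for activation checkpointing: the wrapped class behaves like the original module but recomputes intermediates in the backward pass according to the selected policy. A self-contained sketch of the same pattern on a stand-in module (ToyBlock and its sizes are illustrative, not from model.py):

# Minimal sketch of the nn.remat pattern used above.
import jax
import jax.numpy as jnp
import flax.linen as nn

class ToyBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.features)(jax.nn.gelu(nn.Dense(4 * self.features)(x)))

RematBlock = nn.remat(
    ToyBlock,
    policy=jax.checkpoint_policies.nothing_saveable,  # recompute everything in the backward pass
)

block = RematBlock(features=16)
params = block.init(jax.random.PRNGKey(0), jnp.ones((2, 16)))
y = block.apply(params, jnp.ones((2, 16)))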