Update model.py
model.py
CHANGED
@@ -156,7 +156,7 @@ def built_bloom_alibi(attention_mask, num_attention_heads):
     slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
     arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
     alibi = slops[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
-    return alibi.reshape(b
+    return alibi.reshape(b, num_attention_heads, 1, s)
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
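For context on what the repaired return fixes: built_bloom_alibi builds a BLOOM-style ALiBi bias, a per-head slope times each token's position, and the reshape gives it the (batch, heads, 1, seq) layout that broadcasts against (batch, heads, q_len, kv_len) attention scores. Below is a minimal self-contained sketch of that recipe under stated assumptions: only the tail of the function is visible in this hunk, so the function name build_bloom_alibi_sketch and the slope construction for the power-of-two head count are filled in from the standard BLOOM formulation, not taken from this file.

import math
import jax.numpy as jnp

def build_bloom_alibi_sketch(attention_mask, num_attention_heads):
    # attention_mask: (batch, seq_len) of 0/1 padding flags.
    b, s = attention_mask.shape
    # Slopes follow the BLOOM recipe: a geometric series anchored at the
    # closest power of two at or below the head count (assumed, not shown in the hunk).
    closest_pow2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_pow2) - 3)))
    slopes = jnp.power(base, jnp.arange(1, closest_pow2 + 1))
    if closest_pow2 != num_attention_heads:
        # Interleave extra slopes when the head count is not a power of two;
        # this is the jnp.concatenate line shown in the hunk's context.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_pow2) - 3)))
        extra_power = jnp.arange(1, 2 * (num_attention_heads - closest_pow2), 2)
        slopes = jnp.concatenate([slopes, jnp.power(extra_base, extra_power)], axis=0)
    # Position index of each non-padding token, zeroed where the mask is 0.
    arange_tensor = ((jnp.cumsum(attention_mask, axis=-1) - 1) * attention_mask)[:, jnp.newaxis, :]
    alibi = slopes[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
    # The fixed line: an explicit (batch, heads, 1, seq) layout so the bias
    # broadcasts over the query axis of the attention scores.
    return alibi.reshape(b, num_attention_heads, 1, s)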
@@ -252,7 +252,7 @@ class FlaxFalconAttention(nn.Module):
         attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))
 
         if alibi is not None:
-            attn +=
+            attn += alibi
         attn = attn * self.factor_scale
 
         if attention_mask is not None:
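This hunk completes a statement that was left truncated: the ALiBi bias is added to the raw attention scores, after which the scores are scaled and masked. A minimal sketch of that score path follows; only the attn += alibi and factor_scale lines come from the diff, while the function name, the einsum, and the jnp.where-based masking are illustrative assumptions about the surrounding code.

import jax
import jax.numpy as jnp

def attention_scores_sketch(query, key, alibi=None, attention_mask=None, factor_scale=1.0):
    # query, key: (batch, heads, seq, head_dim); raw dot-product scores.
    attn = jnp.einsum("...qd,...kd->...qk", query, key)
    if alibi is not None:
        # The repaired line: a bias of shape (batch, heads, 1, kv_len)
        # broadcasts over the query axis of the scores.
        attn += alibi
    attn = attn * factor_scale
    if attention_mask is not None:
        # Push masked positions to the dtype minimum before the softmax.
        attn = jnp.where(attention_mask, attn, jnp.finfo(attn.dtype).min)
    return jax.nn.softmax(attn, axis=-1)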
@@ -365,7 +365,6 @@ class FlaxFalconCollection(nn.Module):
     ):
         for b in self.blocks:
             hidden_states = b(
-
                 attention_mask=attention_mask,
                 hidden_states=hidden_states,
                 alibi=alibi
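The last hunk only drops a stray blank line inside the call, but it shows the collection's forward pattern: attention_mask and alibi are passed unchanged to every block while hidden_states is threaded through sequentially. A sketch of that loop, written as a hypothetical free function since the surrounding __call__ signature is not shown in the diff:

def run_blocks_sketch(blocks, hidden_states, attention_mask, alibi):
    # hidden_states is rebound to each block's output, so layer i+1
    # consumes what layer i produced; the mask and bias stay fixed.
    for b in blocks:
        hidden_states = b(
            attention_mask=attention_mask,
            hidden_states=hidden_states,
            alibi=alibi,
        )
    return hidden_states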