Update model.py
Browse files
model.py
CHANGED
@@ -141,7 +141,7 @@ class FalconConfig(PretrainedConfig):
|
|
141 |
|
142 |
def built_bloom_alibi(attention_mask, num_attention_heads):
|
143 |
b, s = attention_mask.shape
|
144 |
-
cp2 = 2 ** math.
|
145 |
base = jnp.asarray(
|
146 |
2 ** (- (2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
|
147 |
)
|
@@ -155,8 +155,8 @@ def built_bloom_alibi(attention_mask, num_attention_heads):
|
|
155 |
extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.dtype)
|
156 |
slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
|
157 |
arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
|
158 |
-
alibi = slops[..., jnp.newaxis] * arange_tensor
|
159 |
-
return alibi.reshape(b, num_attention_heads, 1, s)
|
160 |
|
161 |
|
162 |
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
|
@@ -225,6 +225,7 @@ class FlaxFalconAttention(nn.Module):
|
|
225 |
attention_mask: jnp.DeviceArray = None,
|
226 |
):
|
227 |
b, s, d = hidden_states.shape
|
|
|
228 |
qkv = self.w_qkv(hidden_states)
|
229 |
if not self.config.multi_query:
|
230 |
q, k, v = jnp.split(qkv, 3, -1)
|
|
|
141 |
|
142 |
def built_bloom_alibi(attention_mask, num_attention_heads):
|
143 |
b, s = attention_mask.shape
|
144 |
+
cp2 = 2 ** math.floor(math.log2(num_attention_heads))
|
145 |
base = jnp.asarray(
|
146 |
2 ** (- (2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
|
147 |
)
|
|
|
155 |
extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.dtype)
|
156 |
slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
|
157 |
arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
|
158 |
+
alibi = slops[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
|
159 |
+
return alibi.reshape(b , num_attention_heads, 1, s)
|
160 |
|
161 |
|
162 |
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
|
|
|
225 |
attention_mask: jnp.DeviceArray = None,
|
226 |
):
|
227 |
b, s, d = hidden_states.shape
|
228 |
+
|
229 |
qkv = self.w_qkv(hidden_states)
|
230 |
if not self.config.multi_query:
|
231 |
q, k, v = jnp.split(qkv, 3, -1)
|