Update model.py
Browse files
model.py
CHANGED
@@ -139,16 +139,24 @@ class FalconConfig(PretrainedConfig):
|
|
139 |
return 'dp', 'fsdp', 'mp'
|
140 |
|
141 |
|
142 |
-
def
|
143 |
-
|
144 |
cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
148 |
if cp2 != num_attention_heads:
|
149 |
-
|
150 |
-
|
151 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
|
154 |
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
|
@@ -400,8 +408,8 @@ class FlaxFalconModule(nn.Module):
|
|
400 |
(batch, seq_len)
|
401 |
)
|
402 |
|
403 |
-
alibi =
|
404 |
-
|
405 |
causal_mask = nn.make_causal_mask(
|
406 |
input_ids,
|
407 |
)
|
|
|
139 |
return 'dp', 'fsdp', 'mp'
|
140 |
|
141 |
|
142 |
+
def built_bloom_alibi(attention_mask, num_attention_heads):
|
143 |
+
b, s = attention_mask.shape
|
144 |
cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
|
145 |
+
base = jnp.asarray(
|
146 |
+
2 ** (- (2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
|
147 |
+
)
|
148 |
+
powers = jnp.arange(1, 1 + cp2, dtype=jnp.float32)
|
149 |
+
slops = jnp.power(base, powers)
|
150 |
if cp2 != num_attention_heads:
|
151 |
+
extra_base = jnp.asarray(
|
152 |
+
2 ** (-(2 ** (math.log2(2 * cp2) - 3))), dtype=jnp.float32
|
153 |
+
)
|
154 |
+
num_rem_heads = min(cp2, num_attention_heads - cp2)
|
155 |
+
extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.dtype)
|
156 |
+
slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
|
157 |
+
arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
|
158 |
+
alibi = slops[..., jnp.newaxis] * arange_tensor
|
159 |
+
return alibi.reshape(b, num_attention_heads, 1, s)
|
160 |
|
161 |
|
162 |
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
|
|
|
408 |
(batch, seq_len)
|
409 |
)
|
410 |
|
411 |
+
alibi = built_bloom_alibi(attention_mask, self.config
|
412 |
+
.n_head).astype(hidden_states.dtype) if self.config.alibi else None
|
413 |
causal_mask = nn.make_causal_mask(
|
414 |
input_ids,
|
415 |
)
|