Transformers · English · falcon · custom_code · text-generation-inference
erfanzar committed
Commit 4c54116
1 Parent(s): 9f6cc93

Update model.py

Files changed (1)
  1. model.py +4 -3
model.py CHANGED
@@ -141,7 +141,7 @@ class FalconConfig(PretrainedConfig):
 
 def built_bloom_alibi(attention_mask, num_attention_heads):
     b, s = attention_mask.shape
-    cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
+    cp2 = 2 ** math.floor(math.log2(num_attention_heads))
     base = jnp.asarray(
         2 ** (- (2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
     )
@@ -155,8 +155,8 @@ def built_bloom_alibi(attention_mask, num_attention_heads):
     extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.dtype)
     slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
     arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
-    alibi = slops[..., jnp.newaxis] * arange_tensor
-    return alibi.reshape(b, num_attention_heads, 1, s)
+    alibi = slops[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
+    return alibi.reshape(b, num_attention_heads, 1, s)
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
@@ -225,6 +225,7 @@ class FlaxFalconAttention(nn.Module):
         attention_mask: jnp.DeviceArray = None,
     ):
         b, s, d = hidden_states.shape
+
         qkv = self.w_qkv(hidden_states)
         if not self.config.multi_query:
            q, k, v = jnp.split(qkv, 3, -1)
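
For context, the two ALiBi hunks above only show the edges of built_bloom_alibi. Below is a minimal, self-contained sketch of what the whole routine looks like after this commit, assuming the elided middle of the function (roughly lines 148-154) follows the standard Bloom ALiBi recipe; the name built_bloom_alibi_sketch and the dtype parameter are illustrative, and the dtype=jnp.dtype from the context line is swapped for jnp.float32, since passing the dtype class itself would not be a valid dtype argument. The floor fix matters when num_attention_heads is not a power of two: math.ceil would overshoot the head count and generate slopes for heads that don't exist.

import math
import jax.numpy as jnp

def built_bloom_alibi_sketch(attention_mask, num_attention_heads, dtype=jnp.bfloat16):
    # attention_mask: (batch, seq) array of 0/1 padding flags.
    b, s = attention_mask.shape
    # Largest power of two <= num_attention_heads (the floor fix in this commit).
    cp2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = jnp.asarray(2 ** (-(2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32)
    powers = jnp.arange(1, 1 + cp2, dtype=jnp.float32)
    slops = jnp.power(base, powers)
    if cp2 != num_attention_heads:
        # Extra slopes for the heads beyond the power of two, interleaved at
        # half the exponent spacing (assumed Bloom recipe for the elided lines).
        extra_base = jnp.asarray(2 ** (-(2 ** -(math.log2(2 * cp2) - 3))), dtype=jnp.float32)
        num_rem_heads = min(cp2, num_attention_heads - cp2)
        extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.float32)
        slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
    # Position index of each kept token, zeroed wherever the mask is zero.
    arange_tensor = ((jnp.cumsum(attention_mask, axis=-1) - 1) * attention_mask)[:, jnp.newaxis, :]
    # (heads, 1) * (batch, 1, seq) broadcasts to (batch, heads, seq).
    alibi = slops[..., jnp.newaxis].astype(dtype) * arange_tensor
    return alibi.reshape(b, num_attention_heads, 1, s)

Called with a (2, 10) mask and 7 heads (so cp2 = 4, with 3 interleaved extra slopes), this yields a (2, 7, 1, 10) bias ready to be added to the attention logits before softmax; the bfloat16 cast added in the second hunk presumably keeps the bias in the model's compute dtype.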