Transformers · English · falcon · custom_code · text-generation-inference
erfanzar committed on
Commit
5e73374
1 Parent(s): b230d07

Update model.py

Files changed (1)
  1. model.py +2 -3
model.py CHANGED
@@ -156,7 +156,7 @@ def built_bloom_alibi(attention_mask, num_attention_heads):
     slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
     arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
     alibi = slops[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
-    return alibi.reshape(b , num_attention_heads, 1, s)
+    return alibi.reshape(b, num_attention_heads, 1, s)
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
@@ -252,7 +252,7 @@ class FlaxFalconAttention(nn.Module):
         attn = with_sharding_constraint(attn, PartitionSpec(("dp", "fsdp"), "mp", None, None))
 
         if alibi is not None:
-            attn += attn
+            attn += alibi
         attn = attn * self.factor_scale
 
         if attention_mask is not None:
@@ -365,7 +365,6 @@ class FlaxFalconCollection(nn.Module):
     ):
         for b in self.blocks:
             hidden_states = b(
-
                 attention_mask=attention_mask,
                 hidden_states=hidden_states,
                 alibi=alibi
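
Note: the first hunk is purely cosmetic (it removes a stray space before a comma in the final reshape); the tensor being returned is a BLOOM-style ALiBi bias of shape (batch, num_attention_heads, 1, seq_len). For reference, below is a minimal, self-contained sketch of that construction in JAX. It is not copied from model.py (the slope setup above line 156 is not part of this diff); the function name build_alibi_bias, its variable names, and the slope derivation follow the standard ALiBi recipe and are illustrative only.

# A minimal sketch (not the repository's exact implementation) of a BLOOM-style
# ALiBi bias with the (batch, num_attention_heads, 1, seq_len) shape that the
# fixed reshape produces. Names here are illustrative.
import math

import jax.numpy as jnp


def build_alibi_bias(attention_mask, num_attention_heads):
    batch, seq_len = attention_mask.shape

    # Standard ALiBi slopes: a geometric sequence per head; if the head count is
    # not a power of two, extra slopes are taken from the next power of two.
    closest_power_of_2 = 2 ** math.floor(math.log2(num_attention_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = jnp.power(base, jnp.arange(1, 1 + closest_power_of_2, dtype=jnp.float32))
    if closest_power_of_2 != num_attention_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = num_attention_heads - closest_power_of_2
        extra_powers = jnp.arange(1, 1 + 2 * num_remaining, 2, dtype=jnp.float32)
        slopes = jnp.concatenate([slopes, jnp.power(extra_base, extra_powers)], axis=0)

    # Position index of each non-padded token, shape (batch, 1, seq_len).
    arange_tensor = ((jnp.cumsum(attention_mask, axis=-1) - 1) * attention_mask)[:, jnp.newaxis, :]

    # (num_heads, 1) * (batch, 1, seq_len) broadcasts to (batch, num_heads, seq_len),
    # then reshape to (batch, num_heads, 1, seq_len) so it can be added to attention scores.
    alibi = slopes[..., jnp.newaxis].astype(jnp.bfloat16) * arange_tensor
    return alibi.reshape(batch, num_attention_heads, 1, seq_len)


mask = jnp.ones((2, 5), dtype=jnp.int32)    # batch of 2, sequence length 5, no padding
print(build_alibi_bias(mask, 8).shape)      # (2, 8, 1, 5)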
 
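The substantive change is in the second hunk: the ALiBi bias was previously applied as attn += attn, which only doubles the attention logits and never uses the bias; the new line adds alibi instead. A sketch of that step follows, reusing build_alibi_bias from the note above; scores, factor_scale, and the shapes are stand-ins for illustration, not FlaxFalconAttention's actual attributes.

# Illustrative only: how the ALiBi bias enters the attention computation after the fix.
import jax
import jax.numpy as jnp

batch, num_heads, seq_len, head_dim = 2, 8, 5, 16
q_key, k_key = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(q_key, (batch, num_heads, seq_len, head_dim))
k = jax.random.normal(k_key, (batch, num_heads, seq_len, head_dim))

scores = jnp.einsum("bhqd,bhkd->bhqk", q, k)    # raw attention logits, (b, h, q, k)
alibi = build_alibi_bias(jnp.ones((batch, seq_len), dtype=jnp.int32), num_heads)

# Before this commit the line was effectively `scores += scores`, doubling the
# logits; after the fix the positional bias is added and broadcast over queries.
scores = scores + alibi                          # (b, h, q, k) + (b, h, 1, k)
factor_scale = 1.0 / jnp.sqrt(head_dim)          # stand-in for self.factor_scale
weights = jax.nn.softmax(scores * factor_scale, axis=-1)
print(weights.shape)                             # (2, 8, 5, 5)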