erfanzar committed
Commit 9f6cc93
1 Parent(s): 2dad56b

Update model.py

Files changed (1)
  1. model.py +18 -10
model.py CHANGED
@@ -139,16 +139,24 @@ class FalconConfig(PretrainedConfig):
         return 'dp', 'fsdp', 'mp'
 
 
-def build_alibi(max_length, num_attention_heads, alibi_max: int = 8):
-    w_range = jnp.arange(1 - max_length, 1).reshape(1, 1, 1, max_length)
+def built_bloom_alibi(attention_mask, num_attention_heads):
+    b, s = attention_mask.shape
     cp2 = 2 ** math.ceil(math.log2(num_attention_heads))
-    h_range = jnp.arange(1, 1 + num_attention_heads, ).reshape(1, -1, 1, 1)
-    h_range = jnp.matmul(h_range, jnp.asarray(alibi_max / cp2).reshape(1, 1))
-    slop = 1 / jnp.power(2, h_range)
+    base = jnp.asarray(
+        2 ** (-(2 ** -(math.log2(cp2) - 3))), dtype=jnp.float32
+    )
+    powers = jnp.arange(1, 1 + cp2, dtype=jnp.float32)
+    slops = jnp.power(base, powers)
     if cp2 != num_attention_heads:
-        slop = jnp.concatenate([slop[1::2], slop[::2]], axis=-1)[:num_attention_heads]
-    alibi = (w_range * slop).reshape(1, num_attention_heads, 1, max_length)
-    return alibi
+        extra_base = jnp.asarray(
+            2 ** (-(2 ** -(math.log2(2 * cp2) - 3))), dtype=jnp.float32
+        )
+        num_rem_heads = min(cp2, num_attention_heads - cp2)
+        extra_power = jnp.arange(1, 1 + 2 * num_rem_heads, 2, dtype=jnp.float32)
+        slops = jnp.concatenate([slops, jnp.power(extra_base, extra_power)], axis=0)
+    arange_tensor = (((jnp.cumsum(attention_mask, axis=-1)) - 1) * attention_mask)[:, jnp.newaxis, :]
+    alibi = slops[..., jnp.newaxis] * arange_tensor
+    return alibi.reshape(b, num_attention_heads, 1, s)
 
 
 def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
@@ -400,8 +408,8 @@ class FlaxFalconModule(nn.Module):
             (batch, seq_len)
         )
 
-        alibi = build_alibi(seq_len, self.config
-                            .n_head, 8) if self.config.alibi else None
+        alibi = built_bloom_alibi(attention_mask, self.config
+                                  .n_head).astype(hidden_states.dtype) if self.config.alibi else None
         causal_mask = nn.make_causal_mask(
             input_ids,
         )
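
For reference, the new slope base is the ALiBi geometric schedule in disguise:
2 ** (-(2 ** -(math.log2(cp2) - 3))) reduces to 2 ** (-8 / cp2), so head h
(1-indexed) receives slope 2 ** (-8 * h / cp2). A quick arithmetic check of
this identity (not part of the commit):

import math

# Check: the committed base expression equals 2 ** (-8 / cp2) for
# power-of-two head counts, the case handled by the main branch above.
for cp2 in (4, 8, 16):
    base = 2 ** (-(2 ** -(math.log2(cp2) - 3)))
    assert abs(base - 2 ** (-8 / cp2)) < 1e-12
    # First four per-head slopes: base ** 1, base ** 2, ...
    print(cp2, [base ** h for h in range(1, 5)])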
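
Unlike the removed build_alibi, which depended only on seq_len, the
BLOOM-style builder reads the attention mask, so padded positions contribute
no positional bias. A minimal usage sketch (not from the commit; the import
path is a hypothetical assumption that the model.py above is importable):

import jax.numpy as jnp

from model import built_bloom_alibi  # hypothetical import of the file above

# Toy right-padded attention mask: (batch, seq_len) = (2, 4).
attention_mask = jnp.asarray([
    [1, 1, 1, 1],   # full-length sequence
    [1, 1, 0, 0],   # two padded positions
], dtype=jnp.int32)

alibi = built_bloom_alibi(attention_mask, num_attention_heads=4)
print(alibi.shape)     # (2, 4, 1, 4): (batch, heads, 1, seq_len)
print(alibi[1, 0, 0])  # [0., 0.25, 0., 0.]: padded positions get zero bias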