Transformers
English
falcon
custom_code
text-generation-inference
erfanzar committed on
Commit e48934c
1 parent: b545422

Update model.py

Files changed (1)
model.py (+24 −8)
model.py CHANGED
@@ -193,7 +193,8 @@ class FlaxFalconAttention(nn.Module):
     def setup(self) -> None:
         head_dim = self.config.hidden_size // self.config.n_head
         self.w_qkv = nn.Dense(
-            features=self.config.hidden_size * 3,
+            features=3 * self.config.hidden_size if not self.config.multi_query else (
+                    self.config.hidden_size + 2 * head_dim),
             dtype=self.dtype,
             param_dtype=self.param_dtype,
             use_bias=self.config.bias
@@ -206,6 +207,7 @@ class FlaxFalconAttention(nn.Module):
             use_bias=self.config.bias
         )
         self.head_dim = head_dim
+        assert self.head_dim * self.config.n_head == self.config.hidden_size
         if not self.config.alibi:
             self.freq = precompute_freqs_cis(head_dim, self.config.max_seq_len, dtype=self.dtype)
 
@@ -215,13 +217,25 @@ class FlaxFalconAttention(nn.Module):
                  attention_mask: jnp.DeviceArray = None,
                  ):
         b, s, d = hidden_states.shape
-        q, k, v = jnp.split(self.w_qkv(hidden_states), 3, -1)
-        q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-        k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-        v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
-        k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
-        q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
-        v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
+        qkv = self.w_qkv(hidden_states)
+        if not self.config.multi_query:
+            q, k, v = jnp.split(qkv, 3, -1)
+            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, 'mp'))
+            k = rearrange(k, 'b s (h d) -> b s h d', h=self.config.n_head)
+            q = rearrange(q, 'b s (h d) -> b s h d', h=self.config.n_head)
+            v = rearrange(v, 'b s (h d) -> b s h d', h=self.config.n_head)
+        else:
+            qkv = qkv.reshape(
+                b, s, self.config.n_head + 2, -1
+            )
+            q, k, v = qkv[..., :-2, :], qkv[..., [-2], :], qkv[..., [-1], :]
+
+            q = with_sharding_constraint(q, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+            k = with_sharding_constraint(k, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+            v = with_sharding_constraint(v, PartitionSpec(('dp', 'fsdp'), None, None, 'mp'))
+
         if not self.config.alibi:
             freq = self.freq[:s].reshape(1, s, -1)
             q, k = apply_rotary_emb(q, k, freq, self.dtype)
@@ -231,8 +245,10 @@ class FlaxFalconAttention(nn.Module):
         if alibi is not None:
             attn += attn
         attn = attn * self.factor_scale
+
         if attention_mask is not None:
             attn += attention_mask
+
         attn = jax.nn.softmax(attn, axis=-1)
         attn = jnp.einsum('...hqk,...khd->...qhd', attn, v, precision=self.precision).reshape((b, s, d))
         return self.wo(attn)
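
For readers following the multi-query branch added in this commit, the sketch below walks through the shape bookkeeping it relies on: with multi_query enabled, w_qkv projects to hidden_size + 2 * head_dim features, which reshape to (b, s, n_head + 2, head_dim); the first n_head slots are the per-head queries and the last two slots are the single shared key and value head. This is a standalone illustration with made-up sizes and random inputs, not part of the committed module, and it omits the sharding constraints.

# Minimal sketch of the multi-query split used in the new else-branch.
# Sizes below are illustrative, not taken from the Falcon config.
import jax
import jax.numpy as jnp

b, s = 2, 16                      # batch, sequence length
n_head, head_dim = 8, 32
hidden_size = n_head * head_dim   # mirrors the assert added in setup()

# With multi_query=True the fused projection emits hidden_size + 2 * head_dim
# features: n_head query heads plus one shared key head and one shared value head.
qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, hidden_size + 2 * head_dim))

qkv = qkv.reshape(b, s, n_head + 2, -1)   # (b, s, n_head + 2, head_dim)
q = qkv[..., :-2, :]                      # (b, s, n_head, head_dim) per-head queries
k = qkv[..., [-2], :]                     # (b, s, 1, head_dim) shared key head
v = qkv[..., [-1], :]                     # (b, s, 1, head_dim) shared value head

print(q.shape, k.shape, v.shape)          # (2, 16, 8, 32) (2, 16, 1, 32) (2, 16, 1, 32)

Because hidden_size + 2 * head_dim divided by n_head + 2 equals head_dim, the trailing -1 in the reshape recovers head_dim exactly; this only holds when head_dim * n_head == hidden_size, which is what the new assert in setup() guards.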