Jackmin801 committed
Commit 6f3de15
Parent: 4f24e0f

set flash attn as option in config

Files changed (3):
  1. configuration_bert.py +4 -0
  2. flash_attn_triton.py +9 -32
  3. modeling_bert.py +17 -6
configuration_bert.py CHANGED
@@ -127,6 +127,8 @@ class JinaBertConfig(PretrainedConfig):
         emb_pooler (`str`, *optional*, defaults to `None`):
             The function to use for pooling the last layer embeddings to get the sentence embeddings.
             Should be one of `None`, `"mean"`.
+        with_flash (`bool`, *optional*, defaults to `False`):
+            Whether to use flash attention. Only works for `triton==2.0.0.dev20230208`
 
     Examples:
 
@@ -164,6 +166,7 @@ class JinaBertConfig(PretrainedConfig):
         classifier_dropout=None,
         feed_forward_type="original",
         emb_pooler=None,
+        with_flash=False,
         **kwargs,
     ):
         super().__init__(pad_token_id=pad_token_id, **kwargs)
@@ -185,6 +188,7 @@ class JinaBertConfig(PretrainedConfig):
         self.classifier_dropout = classifier_dropout
         self.feed_forward_type = feed_forward_type
         self.emb_pooler = emb_pooler
+        self.with_flash = with_flash
 
 
 class JinaBertOnnxConfig(OnnxConfig):
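In practice the new flag is just another constructor argument on `JinaBertConfig`, stored verbatim on the config and read later by `JinaBertSelfAttention`. A minimal sketch of setting it (assuming the repo's files are importable as local modules; `triton==2.0.0.dev20230208` is required when the flag is enabled, per the docstring above):

    from configuration_bert import JinaBertConfig

    # Default behaviour is unchanged: with_flash defaults to False.
    config = JinaBertConfig(with_flash=True)
    print(config.with_flash)  # True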
flash_attn_triton.py CHANGED
@@ -81,21 +81,11 @@ def _fwd_kernel(
     Lse,
     TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug
     softmax_scale,
-    stride_qb,
-    stride_qh,
-    stride_qm,
-    stride_kb,
-    stride_kh,
-    stride_kn,
-    stride_vb,
-    stride_vh,
-    stride_vn,
-    stride_bb,
-    stride_bh,
-    stride_bm,
-    stride_ob,
-    stride_oh,
-    stride_om,
+    stride_qb, stride_qh, stride_qm,
+    stride_kb, stride_kh, stride_kn,
+    stride_vb, stride_vh, stride_vn,
+    stride_bb, stride_bh, stride_bm,
+    stride_ob, stride_oh, stride_om,
     nheads,
     seqlen_q,
     seqlen_k,
@@ -316,11 +306,6 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         elif bias.shape[2:] == (seqlen_q, seqlen_k):
             bias_type = 'matrix'
         else:
-            print(q.shape)
-            print(k.shape)
-            print(seqlen_q)
-            print(seqlen_k)
-            print(bias.shape)
             raise RuntimeError('Last 2 dimensions of bias must be (1, seqlen_k)'
                                ' or (seqlen_q, seqlen_k)')
         if bias.shape[:2] == (1, nheads):
@@ -359,19 +344,11 @@ def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None):
         lse,
         tmp,
         softmax_scale,
-        q.stride(0),
-        q.stride(2),
-        q.stride(1),
-        k.stride(0),
-        k.stride(2),
-        k.stride(1),
-        v.stride(0),
-        v.stride(2),
-        v.stride(1),
+        q.stride(0), q.stride(2), q.stride(1),
+        k.stride(0), k.stride(2), k.stride(1),
+        v.stride(0), v.stride(2), v.stride(1),
         *bias_strides,
-        o.stride(0),
-        o.stride(2),
-        o.stride(1),
+        o.stride(0), o.stride(2), o.stride(1),
         nheads,
         seqlen_q,
         seqlen_k,
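Collapsing the stride arguments and dropping the debug prints doesn't change behaviour: the kernel still receives batch, head, and sequence strides of tensors laid out as (batch, seqlen, nheads, headdim), which is why `_flash_attn_forward` passes `x.stride(0), x.stride(2), x.stride(1)`. A rough shape sketch of a call to `flash_attn_func` under that layout (illustrative sizes; a CUDA device, fp16 tensors, and a locally importable `flash_attn_triton` module are assumptions, not repo code):

    import torch
    from flash_attn_triton import flash_attn_func

    b, s, nheads, head_dim = 2, 128, 12, 64
    q = torch.randn(b, s, nheads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn(b, s, nheads, head_dim, device="cuda", dtype=torch.float16)
    v = torch.randn(b, s, nheads, head_dim, device="cuda", dtype=torch.float16)

    # Per the checks kept in _flash_attn_forward, the bias's last two dims must be
    # (1, seqlen_k) or (seqlen_q, seqlen_k); here an ALiBi-style (1, nheads, s, s) bias.
    bias = torch.randn(1, nheads, s, s, device="cuda", dtype=torch.float16)

    out = flash_attn_func(q, k, v, bias)  # -> (b, s, nheads, head_dim)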
modeling_bert.py CHANGED
@@ -55,7 +55,10 @@ from transformers.utils import (
     replace_return_docstrings,
 )
 from .configuration_bert import JinaBertConfig
-from .flash_attn_triton import flash_attn_func
+try:
+    from .flash_attn_triton import flash_attn_func
+except Exception:
+    flash_attn_func = None
 
 try:
     from tqdm.autonotebook import trange
@@ -282,7 +285,7 @@ class JinaBertEmbeddings(nn.Module):
 
 
 class JinaBertSelfAttention(nn.Module):
-    def __init__(self, config, position_embedding_type=None):
+    def __init__(self, config: JinaBertConfig, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
             config, "embedding_size"
@@ -291,6 +294,13 @@ class JinaBertSelfAttention(nn.Module):
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+
+        self.with_flash = config.with_flash
+        if self.with_flash:
+            if flash_attn_func is None:
+                raise ValueError(
+                    f"flash_attn_func is None, please install flash_attn_triton"
+                )
 
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
@@ -334,14 +344,15 @@ class JinaBertSelfAttention(nn.Module):
         output_attentions: Optional[bool] = False,
         bias: Optional[torch.FloatTensor] = None,
     ) -> Tuple[torch.Tensor]:
-        if False:
+        if self.with_flash:
             b, s, h = hidden_states.shape
             q = self.query(hidden_states)
             k = self.key(hidden_states)
             v = self.value(hidden_states)
-            q = self.transpose_for_scores(q)
-            k = self.transpose_for_scores(k)
-            v = self.transpose_for_scores(v)
+            # B x S x hidden_dim -> B x S x num_heads x head_dim
+            q = q.view(b, s, self.num_attention_heads, self.attention_head_size)
+            k = k.view(b, s, self.num_attention_heads, self.attention_head_size)
+            v = v.view(b, s, self.num_attention_heads, self.attention_head_size)
             attn = flash_attn_func(q, k, v, bias)
             return (attn.view(b, s, h),)
         mixed_query_layer = self.query(hidden_states)
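The switch from `transpose_for_scores` to plain `view` calls is what makes the flash path work: `transpose_for_scores` permutes to (batch, nheads, seqlen, head_dim) for the standard matmul attention, while the Triton kernel expects (batch, seqlen, nheads, head_dim), matching the stride order passed in `_flash_attn_forward`. A small layout comparison (illustrative shapes, not repo code):

    import torch

    b, s, nheads, head_dim = 2, 16, 12, 64
    x = torch.randn(b, s, nheads * head_dim)  # output of self.query / self.key / self.value

    # Standard path: what transpose_for_scores produces (view + permute).
    std = x.view(b, s, nheads, head_dim).permute(0, 2, 1, 3)  # (b, nheads, s, head_dim)

    # Flash path after this commit: keep the sequence dimension second.
    flash = x.view(b, s, nheads, head_dim)                    # (b, s, nheads, head_dim)

    print(std.shape)    # torch.Size([2, 12, 16, 64])
    print(flash.shape)  # torch.Size([2, 16, 12, 64])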