Guanzheng commited on
Commit
91418eb
·
1 Parent(s): 1f5d87f

Update modeling_llama.py

Browse files
Files changed (1) hide show
  1. modeling_llama.py +6 -6
modeling_llama.py CHANGED
@@ -60,14 +60,10 @@ def is_flash_attn_available():
60
  return False
61
 
62
  # Let's add an extra check to see if cuda is available
63
- import torch
64
 
65
  return _is_package_available("flash_attn") and torch.cuda.is_available()
66
 
67
- if is_flash_attn_available():
68
- from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
69
- # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
70
- from flash_attn.bert_padding import unpad_input, pad_input
71
 
72
 
73
 
@@ -232,7 +228,10 @@ class LlamaAttention(nn.Module):
232
 
233
  attention_mask: [bsz, q_len]
234
  """
235
-
 
 
 
236
  bsz, q_len, *_ = qkv.size()
237
 
238
  if key_padding_mask is None:
@@ -342,6 +341,7 @@ class LlamaAttention(nn.Module):
342
  return attn_output, attn_weights, past_key_value
343
  # use flash attention
344
  elif past_key_value is not None:
 
345
  output = flash_attn_with_kvcache(
346
  query_states.transpose(1, 2),
347
  key_states.transpose(1, 2),
 
60
  return False
61
 
62
  # Let's add an extra check to see if cuda is available
 
63
 
64
  return _is_package_available("flash_attn") and torch.cuda.is_available()
65
 
66
+
 
 
 
67
 
68
 
69
 
 
228
 
229
  attention_mask: [bsz, q_len]
230
  """
231
+ if is_flash_attn_available():
232
+ from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func, flash_attn_qkvpacked_func, flash_attn_with_kvcache
233
+ # from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
234
+ from flash_attn.bert_padding import unpad_input, pad_input
235
  bsz, q_len, *_ = qkv.size()
236
 
237
  if key_padding_mask is None:
 
341
  return attn_output, attn_weights, past_key_value
342
  # use flash attention
343
  elif past_key_value is not None:
344
+ from flash_attn.flash_attn_interface import flash_attn_with_kvcache
345
  output = flash_attn_with_kvcache(
346
  query_states.transpose(1, 2),
347
  key_states.transpose(1, 2),