updt triton_flash_attn_fn import

#21
by vchiley - opened
Files changed (1) hide show
  1. attention.py +11 -3
attention.py CHANGED
@@ -87,9 +87,17 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
87
 
88
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
- from flash_attn import flash_attn_triton
91
  except:
92
- raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
 
 
 
 
 
 
 
 
93
  check_valid_inputs(query, key, value)
94
  if dropout_p:
95
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
@@ -108,7 +116,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
108
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
109
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
111
- attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
112
  output = attn_output.view(*attn_output.shape[:2], -1)
113
  return (output, None)
114
 
 
87
 
88
  def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
+ from .flash_attn_triton import flash_attn_func
91
  except:
92
+ _installed = False
93
+ if version.parse(torch.__version__) < version.parse('2.0.0'):
94
+ _installed = True
95
+ try:
96
+ from flash_attn.flash_attn_triton import flash_attn_func
97
+ except:
98
+ _installed = False
99
+ if not _installed:
100
+ raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
101
  check_valid_inputs(query, key, value)
102
  if dropout_p:
103
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
 
116
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
117
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
118
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
119
+ attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
120
  output = attn_output.view(*attn_output.shape[:2], -1)
121
  return (output, None)
122