efederici committed
Commit 190c13e · 1 Parent(s): fa0590e

Update attention.py

Files changed (1)
  1. attention.py +51 -29
attention.py CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 from einops import rearrange
+from packaging import version
 from torch import nn
 from .norm import LPLayerNorm
 
@@ -16,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
             return False
     return original_is_causal
 
-def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
-    k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
-    v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
-    min_val = torch.finfo(q.dtype).min
+    kv_n_heads = 1 if multiquery else n_heads
+    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
+    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            k = torch.cat([past_key_value[0], k], dim=3)
+            v = torch.cat([past_key_value[1], v], dim=2)
+        past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
     if softmax_scale is None:
         softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - s_q)
+        _s_k = max(0, attn_bias.size(3) - s_k)
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
         if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
             raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
         attn_weight = attn_weight + attn_bias
+    min_val = torch.finfo(q.dtype).min
     if key_padding_mask is not None:
         if attn_bias is not None:
             warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
         attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
-    if is_causal:
+    if is_causal and (not q.size(2) == 1):
         s = max(s_q, s_k)
         causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
         causal_mask = causal_mask.tril()
@@ -45,11 +55,11 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
     attn_weight = torch.softmax(attn_weight, dim=-1)
     if dropout_p:
         attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
-    out = attn_weight.matmul(v)
+    out = attn_weight.to(v.dtype).matmul(v)
     out = rearrange(out, 'b h s d -> b s (h d)')
     if needs_weights:
-        return (out, attn_weight)
-    return (out, None)
+        return (out, attn_weight, past_key_value)
+    return (out, None, past_key_value)
 
 def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
     for tensor in tensors:
@@ -58,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
         if not tensor.is_cuda:
             raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
 
-def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
         from flash_attn import bert_padding, flash_attn_interface
     except:
         raise RuntimeError('Please install flash-attn==1.0.3.post0')
     check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+        past_key_value = (key, value)
+    if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - query.size(1))
+        _s_k = max(0, attn_bias.size(3) - key.size(1))
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
     if attn_bias is not None:
         raise NotImplementedError(f'attn_bias not implemented for flash attn.')
     (batch_size, seqlen) = query.shape[:2]
@@ -83,14 +102,31 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
     output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-    return (output, None)
+    return (output, None, past_key_value)
 
-def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
+def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
     try:
         from .flash_attn_triton import flash_attn_func
     except:
-        raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
+        _installed = False
+        if version.parse(torch.__version__) < version.parse('2.0.0'):
+            _installed = True
+            try:
+                from flash_attn.flash_attn_triton import flash_attn_func
+            except:
+                _installed = False
+        if not _installed:
+            raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
     check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+        past_key_value = (key, value)
+    if attn_bias is not None:
+        _s_q = max(0, attn_bias.size(2) - query.size(1))
+        _s_k = max(0, attn_bias.size(3) - key.size(1))
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
     if dropout_p:
         raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
     if needs_weights:
@@ -110,7 +146,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
     output = attn_output.view(*attn_output.shape[:2], -1)
-    return (output, None)
+    return (output, None, past_key_value)
 
 class MultiheadAttention(nn.Module):
     """Multi-head self attention.
@@ -162,14 +198,7 @@ class MultiheadAttention(nn.Module):
             dtype = query.dtype
             query = self.q_ln(query).to(dtype)
             key = self.k_ln(key).to(dtype)
-        if past_key_value is not None:
-            if len(past_key_value) != 0:
-                key = torch.cat([past_key_value[0], key], dim=1)
-                value = torch.cat([past_key_value[1], value], dim=1)
-            past_key_value = (key, value)
-        if attn_bias is not None:
-            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-        (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
+        (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
         return (self.out_proj(context), attn_weights, past_key_value)
 
 class MultiQueryAttention(nn.Module):
@@ -223,14 +252,7 @@ class MultiQueryAttention(nn.Module):
             dtype = query.dtype
             query = self.q_ln(query).to(dtype)
             key = self.k_ln(key).to(dtype)
-        if past_key_value is not None:
-            if len(past_key_value) != 0:
-                key = torch.cat([past_key_value[0], key], dim=1)
-                value = torch.cat([past_key_value[1], value], dim=1)
-            past_key_value = (key, value)
-        if attn_bias is not None:
-            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
-        (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
+        (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
         return (self.out_proj(context), attn_weights, past_key_value)
 
 def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
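
For reference, the sketch below exercises the updated torch-path scaled_multihead_dot_product_attention with the new past_key_value argument across a prefill step and a single decode step. It is illustrative only: the flat `from attention import ...` path and the tensor sizes are assumptions, the same random tensor stands in for the already-projected query/key/value, and in the repo the function is normally reached through the MultiheadAttention / MultiQueryAttention modules.

# Illustrative sketch, not part of the commit; the import path below is assumed.
import torch
from attention import scaled_multihead_dot_product_attention

b, n_heads, d_head = 2, 8, 64
d_model = n_heads * d_head

# Prefill over a 16-token prompt. Passing an empty tuple (rather than None)
# makes the function hand back a cache: it returns (out, weights_or_None, (k, v)).
x = torch.randn(b, 16, d_model)
out, _, past_key_value = scaled_multihead_dot_product_attention(
    x, x, x, n_heads, past_key_value=(), is_causal=True)
print(past_key_value[0].shape)  # cached keys:   (b, n_heads, d_head, 16)
print(past_key_value[1].shape)  # cached values: (b, n_heads, 16, d_head)

# Decode one new token: the cache is concatenated inside the function
# (dim=3 for keys, dim=2 for values) and the causal mask is skipped
# because the query length is 1.
x_new = torch.randn(b, 1, d_model)
out, _, past_key_value = scaled_multihead_dot_product_attention(
    x_new, x_new, x_new, n_heads, past_key_value=past_key_value, is_causal=True)
print(out.shape)                   # (b, 1, d_model)
print(past_key_value[0].size(-1))  # 17 cached key positions

Per the diff, flash_attn_fn and triton_flash_attn_fn gain the same past_key_value parameter but cache the flattened (b, s, h*d) key/value tensors and concatenate on dim=1; both also now slice an oversized attn_bias down to the current query/key lengths before use.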