emozilla committed on
Commit 27d2cea
1 parent: b9eb078

add Birchlabs MPT changes

Files changed (4):
  1. attention.py +356 -115
  2. blocks.py +16 -12
  3. is_torch_version.py +56 -0
  4. modeling_mpt.py +154 -90
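
For orientation: the bulk of this commit threads optional gradient checkpointing through the attention modules (the new `gradient_checkpointing` flag and the `checkpoint(...)` wrapper in attention.py), switches returns to NamedTuples, and vendors an `is_torch_version` helper. Below is a minimal, hypothetical sketch of flipping the new flag on a loaded MPT model; the checkpoint name and loading flags are illustrative, only the `transformer.blocks[*].attn.gradient_checkpointing` attribute path comes from the diff that follows.

import torch
from transformers import AutoModelForCausalLM

# Illustrative only: any MPT checkpoint that ships this remote code would do.
model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',
    trust_remote_code=True,   # MPT uses custom modeling files like the ones in this diff
    torch_dtype=torch.bfloat16,
)
# Each MPTBlock.attn is a MultiheadAttention/MultiQueryAttention instance which,
# after this commit, carries a `gradient_checkpointing` attribute (default False).
# The checkpointed path only activates while the model is in train() mode.
for block in model.transformer.blocks:
    block.attn.gradient_checkpointing = True
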
attention.py CHANGED
@@ -1,131 +1,234 @@
1
-
2
- 'Attention layers.'
3
  import math
4
  import warnings
5
- from typing import Optional
6
  import torch
7
  import torch.nn as nn
8
  from einops import rearrange
9
  from packaging import version
10
  from torch import nn
 
11
  from .norm import LPLayerNorm
12
 
13
  def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
14
- if (original_is_causal and (num_query_tokens != num_key_tokens)):
15
- if (num_query_tokens != 1):
16
  raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
17
  else:
18
  return False
19
  return original_is_causal
20
 
21
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
22
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
23
- k = rearrange(key, 'b s (h d) -> b h d s', h=(1 if multiquery else n_heads))
24
- v = rearrange(value, 'b s (h d) -> b h s d', h=(1 if multiquery else n_heads))
25
  min_val = torch.finfo(q.dtype).min
26
  (b, _, s_q, d) = q.shape
27
- s_k = k.size((- 1))
28
- if (softmax_scale is None):
29
- softmax_scale = (1 / math.sqrt(d))
30
- attn_weight = (q.matmul(k) * softmax_scale)
31
- if (attn_bias is not None):
32
- if (((attn_bias.size((- 1)) != 1) and (attn_bias.size((- 1)) != s_k)) or ((attn_bias.size((- 2)) != 1) and (attn_bias.size((- 2)) != s_q))):
33
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
34
- attn_weight = (attn_weight + attn_bias)
35
- if (key_padding_mask is not None):
36
- if (attn_bias is not None):
37
- warnings.warn((((('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ') + 'unneccessary computation/memory usage. Consider integrating ') + 'into attn_bias once and passing that to each attention ') + 'module instead.'))
38
- attn_weight = attn_weight.masked_fill((~ key_padding_mask.view((b, 1, 1, s_k))), min_val)
39
  if is_causal:
40
  s = max(s_q, s_k)
41
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
42
  causal_mask = causal_mask.tril()
43
  causal_mask = causal_mask.to(torch.bool)
44
- causal_mask = (~ causal_mask)
45
- causal_mask = causal_mask[(- s_q):, (- s_k):]
46
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
47
- attn_weight = torch.softmax(attn_weight, dim=(- 1))
48
  if dropout_p:
49
  attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
50
  out = attn_weight.matmul(v)
51
  out = rearrange(out, 'b h s d -> b s (h d)')
52
  if needs_weights:
53
- return (out, attn_weight)
54
- return (out, None)
55
 
56
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
57
  for tensor in tensors:
58
- if (tensor.dtype not in valid_dtypes):
59
  raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
60
- if (not tensor.is_cuda):
61
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
62
 
63
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
64
  try:
65
  from flash_attn import bert_padding, flash_attn_interface
66
  except:
67
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
68
  check_valid_inputs(query, key, value)
69
- if (attn_bias is not None):
70
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
71
  (batch_size, seqlen) = query.shape[:2]
72
- if (key_padding_mask is None):
73
  key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
74
- query_padding_mask = key_padding_mask[:, (- query.size(1)):]
75
  (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
76
  query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
77
  (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
78
- key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=(1 if multiquery else n_heads))
79
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
80
- value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=(1 if multiquery else n_heads))
81
  if multiquery:
82
- key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size((- 1)))
83
- value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size((- 1)))
84
- dropout_p = (dropout_p if training else 0.0)
85
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
86
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
87
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
88
- return (output, None)
89
 
90
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
91
  try:
92
  from .flash_attn_triton import flash_attn_func
93
  except:
94
  _installed = False
95
- if (version.parse(torch.__version__) < version.parse('2.0.0')):
96
  _installed = True
97
  try:
98
  from flash_attn.flash_attn_triton import flash_attn_func
99
  except:
100
  _installed = False
101
- if (not _installed):
102
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
103
  check_valid_inputs(query, key, value)
104
  if dropout_p:
105
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
106
  if needs_weights:
107
  raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
108
- if (key_padding_mask is not None):
109
- warnings.warn((((('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ') + 'unnecessary computation/memory usage. Consider integrating ') + 'into attn_bias once and passing that to each attention ') + 'module instead.'))
110
  (b_size, s_k) = key_padding_mask.shape[:2]
111
- if (attn_bias is None):
112
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
113
- attn_bias = attn_bias.masked_fill((~ key_padding_mask.view((b_size, 1, 1, s_k))), torch.finfo(query.dtype).min)
114
  query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
115
- key = rearrange(key, 'b s (h d) -> b s h d', h=(1 if multiquery else n_heads))
116
- value = rearrange(value, 'b s (h d) -> b s h d', h=(1 if multiquery else n_heads))
117
  if multiquery:
118
- key = key.expand(*key.shape[:2], n_heads, key.size((- 1)))
119
- value = value.expand(*value.shape[:2], n_heads, value.size((- 1)))
120
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
121
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
122
- output = attn_output.view(*attn_output.shape[:2], (- 1))
123
- return (output, None)
124
 
125
- class MultiheadAttention(nn.Module):
126
- 'Multi-head self attention.\n\n Using torch or triton attention implemetation enables user to also use\n additive bias.\n '
127
 
128
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
129
  super().__init__()
130
  self.attn_impl = attn_impl
131
  self.clip_qkv = clip_qkv
@@ -133,148 +236,286 @@ class MultiheadAttention(nn.Module):
133
  self.d_model = d_model
134
  self.n_heads = n_heads
135
  self.softmax_scale = softmax_scale
136
- if (self.softmax_scale is None):
137
- self.softmax_scale = (1 / math.sqrt((self.d_model / self.n_heads)))
138
  self.attn_dropout_p = attn_pdrop
139
- self.Wqkv = nn.Linear(self.d_model, (3 * self.d_model), device=device)
140
- fuse_splits = (d_model, (2 * d_model))
141
  self.Wqkv._fused = (0, fuse_splits)
142
  if self.qk_ln:
143
- layernorm_class = (LPLayerNorm if low_precision_layernorm else nn.LayerNorm)
144
  self.q_ln = layernorm_class(self.d_model, device=device)
145
  self.k_ln = layernorm_class(self.d_model, device=device)
146
- if (self.attn_impl == 'flash'):
147
  self.attn_fn = flash_attn_fn
148
- elif (self.attn_impl == 'triton'):
149
  self.attn_fn = triton_flash_attn_fn
150
- if verbose:
151
- warnings.warn(((('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ') + 'alloc retries which hurts performance. If encountered, we recommend ') + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'))
152
- elif (self.attn_impl == 'torch'):
153
  self.attn_fn = scaled_multihead_dot_product_attention
154
- if (torch.cuda.is_available() and verbose):
155
- warnings.warn((('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ') + 'we recommend using `attn_impl: triton`.'))
156
  else:
157
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
158
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
159
  self.out_proj._is_residual = True
160
 
161
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
162
  qkv = self.Wqkv(x)
163
  if self.clip_qkv:
164
- qkv.clamp_(min=(- self.clip_qkv), max=self.clip_qkv)
165
  (query, key, value) = qkv.chunk(3, dim=2)
166
  key_padding_mask = attention_mask
167
  if self.qk_ln:
168
  dtype = query.dtype
169
  query = self.q_ln(query).to(dtype)
170
  key = self.k_ln(key).to(dtype)
171
- if (past_key_value is not None):
172
- if (len(past_key_value) != 0):
173
  key = torch.cat([past_key_value[0], key], dim=1)
174
  value = torch.cat([past_key_value[1], value], dim=1)
175
- past_key_value = (key, value)
176
- if (attn_bias is not None):
177
- attn_bias = attn_bias[:, :, (- query.size(1)):, (- key.size(1)):]
178
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
179
- return (self.out_proj(context), attn_weights, past_key_value)
180
 
181
- class MultiQueryAttention(nn.Module):
182
- 'Multi-Query self attention.\n\n Using torch or triton attention implemetation enables user to also use\n additive bias.\n '
183
 
184
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
185
  super().__init__()
186
  self.attn_impl = attn_impl
187
  self.clip_qkv = clip_qkv
188
  self.qk_ln = qk_ln
189
  self.d_model = d_model
190
  self.n_heads = n_heads
191
- self.head_dim = (d_model // n_heads)
192
  self.softmax_scale = softmax_scale
193
- if (self.softmax_scale is None):
194
- self.softmax_scale = (1 / math.sqrt(self.head_dim))
195
  self.attn_dropout_p = attn_pdrop
196
- self.Wqkv = nn.Linear(d_model, (d_model + (2 * self.head_dim)), device=device)
197
- fuse_splits = (d_model, (d_model + self.head_dim))
198
  self.Wqkv._fused = (0, fuse_splits)
199
  if self.qk_ln:
200
- layernorm_class = (LPLayerNorm if low_precision_layernorm else nn.LayerNorm)
201
  self.q_ln = layernorm_class(d_model, device=device)
202
  self.k_ln = layernorm_class(self.head_dim, device=device)
203
- if (self.attn_impl == 'flash'):
204
  self.attn_fn = flash_attn_fn
205
- elif (self.attn_impl == 'triton'):
206
  self.attn_fn = triton_flash_attn_fn
207
- if verbose:
208
- warnings.warn(((('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ') + 'alloc retries which hurts performance. If encountered, we recommend ') + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.'))
209
- elif (self.attn_impl == 'torch'):
210
  self.attn_fn = scaled_multihead_dot_product_attention
211
- if (torch.cuda.is_available() and verbose):
212
- warnings.warn((('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ') + 'we recommend using `attn_impl: triton`.'))
213
  else:
214
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
215
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
216
  self.out_proj._is_residual = True
217
 
218
- def forward(self, x, past_key_value=None, attn_bias=None, attention_mask=None, is_causal=True, needs_weights=False):
219
  qkv = self.Wqkv(x)
220
  if self.clip_qkv:
221
- qkv.clamp_(min=(- self.clip_qkv), max=self.clip_qkv)
222
  (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
223
  key_padding_mask = attention_mask
224
  if self.qk_ln:
225
  dtype = query.dtype
226
  query = self.q_ln(query).to(dtype)
227
  key = self.k_ln(key).to(dtype)
228
- if (past_key_value is not None):
229
- if (len(past_key_value) != 0):
230
  key = torch.cat([past_key_value[0], key], dim=1)
231
  value = torch.cat([past_key_value[1], value], dim=1)
232
- past_key_value = (key, value)
233
- if (attn_bias is not None):
234
- attn_bias = attn_bias[:, :, (- query.size(1)):, (- key.size(1)):]
235
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
236
- return (self.out_proj(context), attn_weights, past_key_value)
237
 
238
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
239
- if (attn_impl == 'flash'):
240
  return None
241
- elif (attn_impl in ['torch', 'triton']):
242
  if alibi:
243
- if ((prefix_lm or (not causal)) or use_sequence_id):
244
  return (1, n_heads, seq_len, seq_len)
245
  return (1, n_heads, 1, seq_len)
246
- elif (prefix_lm or use_sequence_id):
247
  return (1, 1, seq_len, seq_len)
248
  return None
249
  else:
250
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
251
 
252
  def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
253
- if (attn_impl == 'flash'):
254
  return None
255
- elif (attn_impl in ['torch', 'triton']):
256
  if alibi:
257
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
258
- attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=(not causal), alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
259
  return attn_bias
260
  else:
261
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
262
 
263
  def gen_slopes(n_heads, alibi_bias_max=8, device=None):
264
- _n_heads = (2 ** math.ceil(math.log2(n_heads)))
265
- m = torch.arange(1, (_n_heads + 1), dtype=torch.float32, device=device)
266
- m = m.mul((alibi_bias_max / _n_heads))
267
- slopes = (1.0 / torch.pow(2, m))
268
- if (_n_heads != n_heads):
269
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
270
  return slopes.view(1, n_heads, 1, 1)
271
 
272
  def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
273
- alibi_bias = torch.arange((1 - seq_len), 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
274
  if full:
275
- alibi_bias = (alibi_bias - torch.arange((1 - seq_len), 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1))
276
- alibi_bias = alibi_bias.abs().mul((- 1))
277
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
278
- alibi_bias = (alibi_bias * slopes)
279
  return alibi_bias.to(dtype=dtype)
280
- ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
 
1
+ """Attention layers."""
 
2
  import math
3
  import warnings
4
+ from typing import Optional, Dict, Any, NamedTuple, Protocol, Tuple, Union
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
  from packaging import version
9
  from torch import nn
10
+ from torch.utils.checkpoint import checkpoint
11
  from .norm import LPLayerNorm
12
+ from .is_torch_version import is_torch_version
13
+
14
+ class PastKeyValue(NamedTuple):
15
+ key: torch.Tensor
16
+ value: torch.Tensor
17
+
18
+ class AttnFnOutput(NamedTuple):
19
+ attns: torch.Tensor
20
+ attn_probs: Optional[torch.Tensor]
21
+
22
+ class AttnFn(Protocol):
23
+ def __call__(
24
+ self,
25
+ query: torch.Tensor,
26
+ key: torch.Tensor,
27
+ value: torch.Tensor,
28
+ n_heads: int,
29
+ softmax_scale: Optional[float] = None,
30
+ attn_bias: Optional[torch.Tensor] = None,
31
+ key_padding_mask: Optional[torch.ByteTensor] = None,
32
+ is_causal = False,
33
+ dropout_p = 0.0,
34
+ training = False,
35
+ needs_weights = False,
36
+ multiquery = False,
37
+ ) -> AttnFnOutput: ...
38
+
39
+ class AttnFnCheckpointed(Protocol):
40
+ def __call__(
41
+ self,
42
+ query: torch.Tensor,
43
+ key: torch.Tensor,
44
+ value: torch.Tensor,
45
+ n_heads: int,
46
+ softmax_scale: Optional[float],
47
+ attn_bias: Optional[torch.Tensor],
48
+ key_padding_mask: Optional[torch.ByteTensor],
49
+ is_causal: bool,
50
+ dropout_p: float,
51
+ training: bool,
52
+ needs_weights: bool,
53
+ ) -> AttnFnOutput: ...
54
+
55
+ class AttnOutput(NamedTuple):
56
+ projected_context: torch.Tensor
57
+ attn_weights: Optional[torch.Tensor]
58
+ past_key_value: Union[PastKeyValue, Tuple, None]
59
+
60
+ class Attn(Protocol):
61
+ def __call__(
62
+ self,
63
+ x: torch.Tensor,
64
+ past_key_value: Union[PastKeyValue, Tuple, None] = None,
65
+ attn_bias: Optional[torch.Tensor] = None,
66
+ attention_mask: Optional[torch.ByteTensor] = None,
67
+ is_causal = True,
68
+ needs_weights = False,
69
+ ) -> AttnOutput: ...
70
 
71
  def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_causal: bool):
72
+ if original_is_causal and num_query_tokens != num_key_tokens:
73
+ if num_query_tokens != 1:
74
  raise NotImplementedError('MPT does not support query and key with different number of tokens, unless number of query tokens is 1.')
75
  else:
76
  return False
77
  return original_is_causal
78
 
79
+ def scaled_multihead_dot_product_attention(
80
+ query: torch.Tensor,
81
+ key: torch.Tensor,
82
+ value: torch.Tensor,
83
+ n_heads: int,
84
+ softmax_scale: Optional[float] = None,
85
+ attn_bias: Optional[torch.Tensor] = None,
86
+ key_padding_mask: Optional[torch.ByteTensor] = None,
87
+ is_causal = False,
88
+ dropout_p = 0.0,
89
+ training = False,
90
+ needs_weights = False,
91
+ multiquery = False,
92
+ ) -> AttnFnOutput:
93
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
94
+ k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
95
+ v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
96
  min_val = torch.finfo(q.dtype).min
97
  (b, _, s_q, d) = q.shape
98
+ s_k = k.size(-1)
99
+ if softmax_scale is None:
100
+ softmax_scale = 1 / math.sqrt(d)
101
+ attn_weight = q.matmul(k) * softmax_scale
102
+ if attn_bias is not None:
103
+ if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
104
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
105
+ attn_weight = attn_weight + attn_bias
106
+ if key_padding_mask is not None:
107
+ if attn_bias is not None:
108
+ warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
109
+ attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
110
  if is_causal:
111
  s = max(s_q, s_k)
112
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
113
  causal_mask = causal_mask.tril()
114
  causal_mask = causal_mask.to(torch.bool)
115
+ causal_mask = ~causal_mask
116
+ causal_mask = causal_mask[-s_q:, -s_k:]
117
  attn_weight = attn_weight.masked_fill(causal_mask.view(1, 1, s_q, s_k), min_val)
118
+ attn_weight = torch.softmax(attn_weight, dim=-1)
119
  if dropout_p:
120
  attn_weight = torch.nn.functional.dropout(attn_weight, p=dropout_p, training=training, inplace=True)
121
  out = attn_weight.matmul(v)
122
  out = rearrange(out, 'b h s d -> b s (h d)')
123
  if needs_weights:
124
+ return AttnFnOutput(out, attn_weight)
125
+ return AttnFnOutput(out, None)
126
 
127
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
128
  for tensor in tensors:
129
+ if tensor.dtype not in valid_dtypes:
130
  raise TypeError(f'tensor.dtype={tensor.dtype!r} must be in valid_dtypes={valid_dtypes!r}.')
131
+ if not tensor.is_cuda:
132
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
133
 
134
+ def flash_attn_fn(
135
+ query: torch.Tensor,
136
+ key: torch.Tensor,
137
+ value: torch.Tensor,
138
+ n_heads: int,
139
+ softmax_scale: Optional[float] = None,
140
+ attn_bias: Optional[torch.Tensor] = None,
141
+ key_padding_mask: Optional[torch.ByteTensor] = None,
142
+ is_causal = False,
143
+ dropout_p = 0.0,
144
+ training = False,
145
+ needs_weights = False,
146
+ multiquery = False,
147
+ ) -> AttnFnOutput:
148
  try:
149
  from flash_attn import bert_padding, flash_attn_interface
150
  except:
151
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
152
  check_valid_inputs(query, key, value)
153
+ if attn_bias is not None:
154
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
155
  (batch_size, seqlen) = query.shape[:2]
156
+ if key_padding_mask is None:
157
  key_padding_mask = torch.ones_like(key[:, :, 0], dtype=torch.bool)
158
+ query_padding_mask = key_padding_mask[:, -query.size(1):]
159
  (query_unpad, indices_q, cu_seqlens_q, max_seqlen_q) = bert_padding.unpad_input(query, query_padding_mask)
160
  query_unpad = rearrange(query_unpad, 'nnz (h d) -> nnz h d', h=n_heads)
161
  (key_unpad, _, cu_seqlens_k, max_seqlen_k) = bert_padding.unpad_input(key, key_padding_mask)
162
+ key_unpad = rearrange(key_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
163
  (value_unpad, _, _, _) = bert_padding.unpad_input(value, key_padding_mask)
164
+ value_unpad = rearrange(value_unpad, 'nnz (h d) -> nnz h d', h=1 if multiquery else n_heads)
165
  if multiquery:
166
+ key_unpad = key_unpad.expand(key_unpad.size(0), n_heads, key_unpad.size(-1))
167
+ value_unpad = value_unpad.expand(value_unpad.size(0), n_heads, value_unpad.size(-1))
168
+ dropout_p = dropout_p if training else 0.0
169
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
170
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
171
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
172
+ return AttnFnOutput(output, None)
173
 
174
+ def triton_flash_attn_fn(
175
+ query: torch.Tensor,
176
+ key: torch.Tensor,
177
+ value: torch.Tensor,
178
+ n_heads: int,
179
+ softmax_scale: Optional[float] = None,
180
+ attn_bias: Optional[torch.Tensor] = None,
181
+ key_padding_mask: Optional[torch.ByteTensor] = None,
182
+ is_causal = False,
183
+ dropout_p = 0.0,
184
+ training = False,
185
+ needs_weights = False,
186
+ multiquery = False,
187
+ ) -> AttnFnOutput:
188
  try:
189
  from .flash_attn_triton import flash_attn_func
190
  except:
191
  _installed = False
192
+ if version.parse(torch.__version__) < version.parse('2.0.0'):
193
  _installed = True
194
  try:
195
  from flash_attn.flash_attn_triton import flash_attn_func
196
  except:
197
  _installed = False
198
+ if not _installed:
199
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
200
  check_valid_inputs(query, key, value)
201
  if dropout_p:
202
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
203
  if needs_weights:
204
  raise NotImplementedError(f'attn_impl: triton cannot return attn weights.')
205
+ if key_padding_mask is not None:
206
+ warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unnecessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
207
  (b_size, s_k) = key_padding_mask.shape[:2]
208
+ if attn_bias is None:
209
  attn_bias = query.new_zeros(b_size, 1, 1, s_k)
210
+ attn_bias = attn_bias.masked_fill(~key_padding_mask.view((b_size, 1, 1, s_k)), torch.finfo(query.dtype).min)
211
  query = rearrange(query, 'b s (h d) -> b s h d', h=n_heads)
212
+ key = rearrange(key, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
213
+ value = rearrange(value, 'b s (h d) -> b s h d', h=1 if multiquery else n_heads)
214
  if multiquery:
215
+ key = key.expand(*key.shape[:2], n_heads, key.size(-1))
216
+ value = value.expand(*value.shape[:2], n_heads, value.size(-1))
217
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
218
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
219
+ output = attn_output.view(*attn_output.shape[:2], -1)
220
+ return AttnFnOutput(output, None)
221
+
222
+ class MultiheadAttention(nn.Module, Attn):
223
+ """Multi-head self attention.
224
 
225
+ Using torch or triton attention implemetation enables user to also use
226
+ additive bias.
227
+ """
228
+ gradient_checkpointing = False
229
+ attn_fn: AttnFn
230
 
231
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
232
  super().__init__()
233
  self.attn_impl = attn_impl
234
  self.clip_qkv = clip_qkv
 
236
  self.d_model = d_model
237
  self.n_heads = n_heads
238
  self.softmax_scale = softmax_scale
239
+ if self.softmax_scale is None:
240
+ self.softmax_scale = 1 / math.sqrt(self.d_model / self.n_heads)
241
  self.attn_dropout_p = attn_pdrop
242
+ self.Wqkv = nn.Linear(self.d_model, 3 * self.d_model, device=device)
243
+ fuse_splits = (d_model, 2 * d_model)
244
  self.Wqkv._fused = (0, fuse_splits)
245
  if self.qk_ln:
246
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
247
  self.q_ln = layernorm_class(self.d_model, device=device)
248
  self.k_ln = layernorm_class(self.d_model, device=device)
249
+ if self.attn_impl == 'flash':
250
  self.attn_fn = flash_attn_fn
251
+ elif self.attn_impl == 'triton':
252
  self.attn_fn = triton_flash_attn_fn
253
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
254
+ elif self.attn_impl == 'torch':
 
255
  self.attn_fn = scaled_multihead_dot_product_attention
256
+ if torch.cuda.is_available():
257
+ warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
258
  else:
259
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
260
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
261
  self.out_proj._is_residual = True
262
 
263
+ def forward(
264
+ self,
265
+ x: torch.Tensor,
266
+ past_key_value: Union[PastKeyValue, Tuple, None] = None,
267
+ attn_bias: Optional[torch.Tensor] = None,
268
+ attention_mask: Optional[torch.ByteTensor] = None,
269
+ is_causal = True,
270
+ needs_weights = False,
271
+ ) -> AttnOutput:
272
  qkv = self.Wqkv(x)
273
  if self.clip_qkv:
274
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
275
  (query, key, value) = qkv.chunk(3, dim=2)
276
  key_padding_mask = attention_mask
277
  if self.qk_ln:
278
  dtype = query.dtype
279
  query = self.q_ln(query).to(dtype)
280
  key = self.k_ln(key).to(dtype)
281
+ if past_key_value is not None:
282
+ if len(past_key_value) != 0:
283
  key = torch.cat([past_key_value[0], key], dim=1)
284
  value = torch.cat([past_key_value[1], value], dim=1)
285
+ past_key_value = PastKeyValue(key, value)
286
+ if attn_bias is not None:
287
+ attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
288
+ if self.training and self.gradient_checkpointing:
289
+ ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
290
+ def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
291
+ def custom_forward(
292
+ query: torch.Tensor,
293
+ key: torch.Tensor,
294
+ value: torch.Tensor,
295
+ n_heads: int,
296
+ softmax_scale: Optional[float],
297
+ attn_bias: Optional[torch.Tensor],
298
+ key_padding_mask: Optional[torch.ByteTensor],
299
+ is_causal: bool,
300
+ dropout_p: float,
301
+ training: bool,
302
+ needs_weights: bool,
303
+ ):
304
+ return attn_fn(
305
+ query,
306
+ key,
307
+ value,
308
+ n_heads,
309
+ softmax_scale,
310
+ attn_bias,
311
+ key_padding_mask,
312
+ is_causal,
313
+ dropout_p,
314
+ training,
315
+ needs_weights,
316
+ False, # multiquery
317
+ )
318
+ return custom_forward
319
+ attn_fn_out: AttnFnOutput = checkpoint(
320
+ create_custom_forward(self.attn_fn),
321
+ query,
322
+ key,
323
+ value,
324
+ self.n_heads,
325
+ self.softmax_scale,
326
+ attn_bias,
327
+ key_padding_mask,
328
+ is_causal,
329
+ self.attn_dropout_p,
330
+ self.training,
331
+ needs_weights,
332
+ **ckpt_kwargs,
333
+ )
334
+ else:
335
+ attn_fn_out: AttnFnOutput = self.attn_fn(
336
+ query,
337
+ key,
338
+ value,
339
+ self.n_heads,
340
+ softmax_scale=self.softmax_scale,
341
+ attn_bias=attn_bias,
342
+ key_padding_mask=key_padding_mask,
343
+ is_causal=is_causal,
344
+ dropout_p=self.attn_dropout_p,
345
+ training=self.training,
346
+ needs_weights=needs_weights,
347
+ )
348
+ context, attn_weights = attn_fn_out
349
+ return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
350
 
351
+ class MultiQueryAttention(nn.Module, Attn):
352
+ """Multi-Query self attention.
353
 
354
+ Using torch or triton attention implemetation enables user to also use
355
+ additive bias.
356
+ """
357
+
358
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
359
  super().__init__()
360
  self.attn_impl = attn_impl
361
  self.clip_qkv = clip_qkv
362
  self.qk_ln = qk_ln
363
  self.d_model = d_model
364
  self.n_heads = n_heads
365
+ self.head_dim = d_model // n_heads
366
  self.softmax_scale = softmax_scale
367
+ if self.softmax_scale is None:
368
+ self.softmax_scale = 1 / math.sqrt(self.head_dim)
369
  self.attn_dropout_p = attn_pdrop
370
+ self.Wqkv = nn.Linear(d_model, d_model + 2 * self.head_dim, device=device)
371
+ fuse_splits = (d_model, d_model + self.head_dim)
372
  self.Wqkv._fused = (0, fuse_splits)
373
  if self.qk_ln:
374
+ layernorm_class = LPLayerNorm if low_precision_layernorm else nn.LayerNorm
375
  self.q_ln = layernorm_class(d_model, device=device)
376
  self.k_ln = layernorm_class(self.head_dim, device=device)
377
+ if self.attn_impl == 'flash':
378
  self.attn_fn = flash_attn_fn
379
+ elif self.attn_impl == 'triton':
380
  self.attn_fn = triton_flash_attn_fn
381
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
382
+ elif self.attn_impl == 'torch':
 
383
  self.attn_fn = scaled_multihead_dot_product_attention
384
+ if torch.cuda.is_available():
385
+ warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
386
  else:
387
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
388
  self.out_proj = nn.Linear(self.d_model, self.d_model, device=device)
389
  self.out_proj._is_residual = True
390
 
391
+ def forward(
392
+ self,
393
+ x: torch.Tensor,
394
+ past_key_value: Union[PastKeyValue, Tuple, None] = None,
395
+ attn_bias: Optional[torch.Tensor] = None,
396
+ attention_mask: Optional[torch.ByteTensor] = None,
397
+ is_causal = True,
398
+ needs_weights = False,
399
+ ) -> AttnOutput:
400
  qkv = self.Wqkv(x)
401
  if self.clip_qkv:
402
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
403
  (query, key, value) = qkv.split([self.d_model, self.head_dim, self.head_dim], dim=2)
404
  key_padding_mask = attention_mask
405
  if self.qk_ln:
406
  dtype = query.dtype
407
  query = self.q_ln(query).to(dtype)
408
  key = self.k_ln(key).to(dtype)
409
+ if past_key_value is not None:
410
+ if len(past_key_value) != 0:
411
  key = torch.cat([past_key_value[0], key], dim=1)
412
  value = torch.cat([past_key_value[1], value], dim=1)
413
+ past_key_value = PastKeyValue(key, value)
414
+ if attn_bias is not None:
415
+ attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
416
+ if self.training and self.gradient_checkpointing:
417
+ ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
418
+ def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
419
+ def custom_forward(
420
+ query: torch.Tensor,
421
+ key: torch.Tensor,
422
+ value: torch.Tensor,
423
+ n_heads: int,
424
+ softmax_scale: Optional[float],
425
+ attn_bias: Optional[torch.Tensor],
426
+ key_padding_mask: Optional[torch.ByteTensor],
427
+ is_causal: bool,
428
+ dropout_p: float,
429
+ training: bool,
430
+ needs_weights: bool,
431
+ ):
432
+ return attn_fn(
433
+ query,
434
+ key,
435
+ value,
436
+ n_heads,
437
+ softmax_scale,
438
+ attn_bias,
439
+ key_padding_mask,
440
+ is_causal,
441
+ dropout_p,
442
+ training,
443
+ needs_weights,
444
+ True, # multiquery
445
+ )
446
+ return custom_forward
447
+ attn_fn_out: AttnFnOutput = checkpoint(
448
+ create_custom_forward(self.attn_fn),
449
+ query,
450
+ key,
451
+ value,
452
+ self.n_heads,
453
+ self.softmax_scale,
454
+ attn_bias,
455
+ key_padding_mask,
456
+ is_causal,
457
+ self.attn_dropout_p,
458
+ self.training,
459
+ needs_weights,
460
+ **ckpt_kwargs,
461
+ )
462
+ else:
463
+ attn_fn_out: AttnFnOutput = self.attn_fn(
464
+ query,
465
+ key,
466
+ value,
467
+ self.n_heads,
468
+ softmax_scale=self.softmax_scale,
469
+ attn_bias=attn_bias,
470
+ key_padding_mask=key_padding_mask,
471
+ is_causal=is_causal,
472
+ dropout_p=self.attn_dropout_p,
473
+ training=self.training,
474
+ needs_weights=needs_weights,
475
+ )
476
+ context, attn_weights = attn_fn_out
477
+ return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
478
 
479
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
480
+ if attn_impl == 'flash':
481
  return None
482
+ elif attn_impl in ['torch', 'triton']:
483
  if alibi:
484
+ if (prefix_lm or not causal) or use_sequence_id:
485
  return (1, n_heads, seq_len, seq_len)
486
  return (1, n_heads, 1, seq_len)
487
+ elif prefix_lm or use_sequence_id:
488
  return (1, 1, seq_len, seq_len)
489
  return None
490
  else:
491
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
492
 
493
  def build_attn_bias(attn_impl, attn_bias, n_heads, seq_len, causal=False, alibi=False, alibi_bias_max=8):
494
+ if attn_impl == 'flash':
495
  return None
496
+ elif attn_impl in ['torch', 'triton']:
497
  if alibi:
498
  (device, dtype) = (attn_bias.device, attn_bias.dtype)
499
+ attn_bias = attn_bias.add(build_alibi_bias(n_heads, seq_len, full=not causal, alibi_bias_max=alibi_bias_max, device=device, dtype=dtype))
500
  return attn_bias
501
  else:
502
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
503
 
504
  def gen_slopes(n_heads, alibi_bias_max=8, device=None):
505
+ _n_heads = 2 ** math.ceil(math.log2(n_heads))
506
+ m = torch.arange(1, _n_heads + 1, dtype=torch.float32, device=device)
507
+ m = m.mul(alibi_bias_max / _n_heads)
508
+ slopes = 1.0 / torch.pow(2, m)
509
+ if _n_heads != n_heads:
510
  slopes = torch.concat([slopes[1::2], slopes[::2]])[:n_heads]
511
  return slopes.view(1, n_heads, 1, 1)
512
 
513
  def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None, dtype=None):
514
+ alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, 1, seq_len)
515
  if full:
516
+ alibi_bias = alibi_bias - torch.arange(1 - seq_len, 1, dtype=torch.int32, device=device).view(1, 1, seq_len, 1)
517
+ alibi_bias = alibi_bias.abs().mul(-1)
518
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
519
+ alibi_bias = alibi_bias * slopes
520
  return alibi_bias.to(dtype=dtype)
521
+ ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
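
As a quick sanity check on the ALiBi helpers added above (`gen_slopes`, `build_alibi_bias`), here is a small shape/value sketch; the flat `attention` import path is an assumption, in the repo these live in the relative module `.attention`.

import torch
from attention import gen_slopes, build_alibi_bias

slopes = gen_slopes(n_heads=8, alibi_bias_max=8)
print(slopes.shape)    # torch.Size([1, 8, 1, 1]) -- one slope per head

bias = build_alibi_bias(n_heads=8, seq_len=4, dtype=torch.float32)
print(bias.shape)      # torch.Size([1, 8, 1, 4]) -- broadcasts over the query dimension
print(bias[0, 0, 0])   # head 0 has slope 0.5: tensor([-1.5000, -1.0000, -0.5000, 0.0000])
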
blocks.py CHANGED
@@ -1,42 +1,46 @@
1
-
2
- 'GPT Blocks used for the GPT Model.'
3
- from typing import Dict, Optional, Tuple
4
  import torch
5
  import torch.nn as nn
6
- from .attention import ATTN_CLASS_REGISTRY
7
  from .norm import NORM_CLASS_REGISTRY
8
 
9
  class MPTMLP(nn.Module):
10
 
11
  def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
12
  super().__init__()
13
- self.up_proj = nn.Linear(d_model, (expansion_ratio * d_model), device=device)
14
  self.act = nn.GELU(approximate='none')
15
- self.down_proj = nn.Linear((expansion_ratio * d_model), d_model, device=device)
16
  self.down_proj._is_residual = True
17
 
18
  def forward(self, x):
19
  return self.down_proj(self.act(self.up_proj(x)))
20
 
21
  class MPTBlock(nn.Module):
 
22
 
23
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
24
  del kwargs
25
  super().__init__()
26
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
27
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
28
  self.norm_1 = norm_class(d_model, device=device)
29
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
30
  self.norm_2 = norm_class(d_model, device=device)
31
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
32
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
33
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
34
 
35
- def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[(torch.Tensor, Optional[Tuple[torch.Tensor]])]:
36
  a = self.norm_1(x)
37
  (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
38
- x = (x + self.resid_attn_dropout(b))
39
  m = self.norm_2(x)
40
  n = self.ffn(m)
41
- x = (x + self.resid_ffn_dropout(n))
42
- return (x, past_key_value)
 
1
+ """GPT Blocks used for the GPT Model."""
2
+ from typing import Dict, Optional, Tuple, NamedTuple, Union
 
3
  import torch
4
  import torch.nn as nn
5
+ from .attention import ATTN_CLASS_REGISTRY, Attn, PastKeyValue
6
  from .norm import NORM_CLASS_REGISTRY
7
 
8
+ class MPTBlockOutput(NamedTuple):
9
+ hidden_states: torch.Tensor
10
+ past_key_value: Union[PastKeyValue, Tuple, None]
11
+
12
  class MPTMLP(nn.Module):
13
 
14
  def __init__(self, d_model: int, expansion_ratio: int, device: Optional[str]=None):
15
  super().__init__()
16
+ self.up_proj = nn.Linear(d_model, expansion_ratio * d_model, device=device)
17
  self.act = nn.GELU(approximate='none')
18
+ self.down_proj = nn.Linear(expansion_ratio * d_model, d_model, device=device)
19
  self.down_proj._is_residual = True
20
 
21
  def forward(self, x):
22
  return self.down_proj(self.act(self.up_proj(x)))
23
 
24
  class MPTBlock(nn.Module):
25
+ attn: Attn
26
 
27
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
28
  del kwargs
29
  super().__init__()
30
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
31
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
32
  self.norm_1 = norm_class(d_model, device=device)
33
+ self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
34
  self.norm_2 = norm_class(d_model, device=device)
35
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
36
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
37
  self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
38
 
39
+ def forward(self, x: torch.Tensor, past_key_value: Union[PastKeyValue, Tuple, None] = None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> MPTBlockOutput:
40
  a = self.norm_1(x)
41
  (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
42
+ x = x + self.resid_attn_dropout(b)
43
  m = self.norm_2(x)
44
  n = self.ffn(m)
45
+ x = x + self.resid_ffn_dropout(n)
46
+ return MPTBlockOutput(x, past_key_value)
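
To illustrate the new `MPTBlockOutput` NamedTuple, here is a small CPU-only sketch using `attn_impl: torch` (shapes are arbitrary; the flat `blocks` import path is an assumption, in the repo it is the relative module `.blocks` and pulls in the package's norm classes).

import torch
from blocks import MPTBlock

block = MPTBlock(
    d_model=64,
    n_heads=4,
    expansion_ratio=4,
    attn_config={
        'attn_type': 'multihead_attention', 'attn_pdrop': 0.0,
        'attn_impl': 'torch',  # runs without flash-attn/triton
        'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None,
        'prefix_lm': False, 'attn_uses_sequence_id': False,
        'alibi': False, 'alibi_bias_max': 8,
    },
)
x = torch.randn(2, 8, 64)                    # (batch, seq_len, d_model)
hidden, past_kv = block(x, is_causal=True)   # MPTBlockOutput unpacks like a tuple
print(hidden.shape)                          # torch.Size([2, 8, 64])
print(past_kv)                               # None -- no KV cache was passed in
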
is_torch_version.py ADDED
@@ -0,0 +1,56 @@
1
+ import sys
2
+ import logging
3
+ import operator as op
4
+ from packaging import version
5
+ from packaging.version import Version, parse
6
+ from typing import Union
7
+ import importlib.util
8
+
9
+ # The package importlib_metadata is in a different place, depending on the python version.
10
+ if sys.version_info < (3, 8):
11
+ import importlib_metadata
12
+ else:
13
+ import importlib.metadata as importlib_metadata
14
+
15
+ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+ _torch_available = importlib.util.find_spec("torch") is not None
20
+ if _torch_available:
21
+ try:
22
+ _torch_version = importlib_metadata.version("torch")
23
+ logger.info(f"PyTorch version {_torch_version} available.")
24
+ except importlib_metadata.PackageNotFoundError:
25
+ _torch_available = False
26
+
27
+ # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
28
+ def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
29
+ """
30
+ Args:
31
+ Compares a library version to some requirement using a given operation.
32
+ library_or_version (`str` or `packaging.version.Version`):
33
+ A library name or a version to check.
34
+ operation (`str`):
35
+ A string representation of an operator, such as `">"` or `"<="`.
36
+ requirement_version (`str`):
37
+ The version to compare the library version against
38
+ """
39
+ if operation not in STR_OPERATION_TO_FUNC.keys():
40
+ raise ValueError(f"`operation` must be one of {list(STR_OPERATION_TO_FUNC.keys())}, received {operation}")
41
+ operation = STR_OPERATION_TO_FUNC[operation]
42
+ if isinstance(library_or_version, str):
43
+ library_or_version = parse(importlib_metadata.version(library_or_version))
44
+ return operation(library_or_version, parse(requirement_version))
45
+
46
+ # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
47
+ def is_torch_version(operation: str, version: str):
48
+ """
49
+ Args:
50
+ Compares the current PyTorch version to a given reference with an operation.
51
+ operation (`str`):
52
+ A string representation of an operator, such as `">"` or `"<="`
53
+ version (`str`):
54
+ A string version of PyTorch
55
+ """
56
+ return compare_versions(parse(_torch_version), operation, version)
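
The attention modules use this helper to decide whether `torch.utils.checkpoint.checkpoint` may be passed `use_reentrant=False` (the argument only exists on torch >= 1.11.0). A standalone sketch of the same pattern, assuming the module is importable as `is_torch_version`:

from typing import Any, Dict
import torch
from torch.utils.checkpoint import checkpoint
from is_torch_version import is_torch_version

ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}

def expensive_fn(t: torch.Tensor) -> torch.Tensor:
    return torch.sin(t) * 2.0    # stand-in for a heavy sub-module

t = torch.randn(4, requires_grad=True)
out = checkpoint(expensive_fn, t, **ckpt_kwargs)
out.sum().backward()             # recomputes expensive_fn during the backward pass
print(t.grad.shape)              # torch.Size([4])
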
modeling_mpt.py CHANGED
@@ -1,30 +1,48 @@
 
1
 
2
- 'A simple, flexible implementation of a GPT model.\n\nInspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py\n'
 
3
  import math
4
  import warnings
5
- from typing import List, Optional, Tuple, Union
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
9
  from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
10
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
11
- from .attention import attn_bias_shape, build_attn_bias
12
- from .blocks import MPTBlock
 
13
  from .norm import NORM_CLASS_REGISTRY
14
  from .configuration_mpt import MPTConfig
15
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
16
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
17
  from .meta_init_context import init_empty_weights
18
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
19
- try:
20
- from .flash_attn_triton import flash_attn_func
21
- except:
22
- pass
23
- Tokenizer = Union[(PreTrainedTokenizer, PreTrainedTokenizerFast)]
24
 
25
  class MPTPreTrainedModel(PreTrainedModel):
26
  config_class = MPTConfig
27
  base_model_prefix = 'model'
28
 
29
  class MPTModel(MPTPreTrainedModel):
30
 
@@ -36,37 +54,37 @@ class MPTModel(MPTPreTrainedModel):
36
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
37
  self.alibi = config.attn_config['alibi']
38
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
39
- if (config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys()):
40
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
41
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
42
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
43
  self.embedding_fraction = config.embedding_fraction
44
  self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
45
- if (not self.alibi):
46
  self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
47
  self.emb_drop = nn.Dropout(config.emb_pdrop)
48
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
49
  self.norm_f = norm_class(config.d_model, device=config.init_device)
50
- if (config.init_device != 'meta'):
51
- print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
52
  self.apply(self.param_init_fn)
53
- self.is_causal = (not self.prefix_lm)
54
  self._attn_bias_initialized = False
55
  self.attn_bias = None
56
  self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
57
  if config.no_bias:
58
  for module in self.modules():
59
- if (hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter)):
60
  if config.verbose:
61
  warnings.warn(f'Removing bias ({module.bias}) from {module}.')
62
  module.register_parameter('bias', None)
63
- if (config.verbose and (config.verbose > 2)):
64
  print(self)
65
- if ('verbose' not in self.config.init_config):
66
  self.config.init_config['verbose'] = self.config.verbose
67
- if (self.config.init_config['verbose'] > 1):
68
  init_fn_name = self.config.init_config['name']
69
  warnings.warn(f'Using {init_fn_name} initialization.')
 
70
 
71
  def get_input_embeddings(self):
72
  return self.wte
@@ -76,115 +94,157 @@ class MPTModel(MPTPreTrainedModel):
76
 
77
  @torch.no_grad()
78
  def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
79
- if (not self._attn_bias_initialized):
80
  if self.attn_bias_shape:
81
  self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
82
  self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
83
  self._attn_bias_initialized = True
84
- if (self.attn_impl == 'flash'):
85
  return (self.attn_bias, attention_mask)
86
- if (self.attn_bias is not None):
87
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
88
  attn_bias = self.attn_bias
89
  if self.prefix_lm:
90
  assert isinstance(attn_bias, torch.Tensor)
91
  assert isinstance(prefix_mask, torch.Tensor)
92
  attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
93
- if (self.attn_uses_sequence_id and (sequence_id is not None)):
94
  assert isinstance(attn_bias, torch.Tensor)
95
  attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
96
- if (attention_mask is not None):
97
- s_k = attention_mask.shape[(- 1)]
98
- if (attn_bias is None):
99
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
100
  else:
101
- attn_bias = attn_bias[:, :, :, (- s_k):]
102
- if ((prefix_mask is not None) and (attention_mask.shape != prefix_mask.shape)):
103
- raise ValueError((f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.'))
104
  min_val = torch.finfo(attn_bias.dtype).min
105
- attn_bias = attn_bias.masked_fill((~ attention_mask.view((- 1), 1, 1, s_k)), min_val)
106
  return (attn_bias, None)
107
 
108
  def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
109
- (s_k, s_q) = attn_bias.shape[(- 2):]
110
- if ((s_k != self.config.max_seq_len) or (s_q != self.config.max_seq_len)):
111
- raise ValueError((('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ') + f'but are {s_k} and {s_q}.'))
112
- seq_len = prefix_mask.shape[(- 1)]
113
- if (seq_len > self.config.max_seq_len):
114
  raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
115
  attn_bias = attn_bias[..., :seq_len, :seq_len]
116
  causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
117
- prefix = prefix_mask.view((- 1), 1, 1, seq_len)
118
- cannot_attend = (~ torch.logical_or(causal, prefix.bool()))
119
  min_val = torch.finfo(attn_bias.dtype).min
120
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
121
  return attn_bias
122
 
123
  def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
124
- seq_len = sequence_id.shape[(- 1)]
125
- if (seq_len > self.config.max_seq_len):
126
  raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
127
  attn_bias = attn_bias[..., :seq_len, :seq_len]
128
- cannot_attend = torch.logical_not(torch.eq(sequence_id.view((- 1), seq_len, 1), sequence_id.view((- 1), 1, seq_len))).unsqueeze(1)
129
  min_val = torch.finfo(attn_bias.dtype).min
130
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
131
  return attn_bias
132
 
133
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
134
- return_dict = (return_dict if (return_dict is not None) else self.config.return_dict)
135
- use_cache = (use_cache if (use_cache is not None) else self.config.use_cache)
136
- if (attention_mask is not None):
137
  attention_mask = attention_mask.bool()
138
- if (prefix_mask is not None):
139
  prefix_mask = prefix_mask.bool()
140
- if (not return_dict):
141
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
142
  if output_attentions:
143
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
144
- if ((attention_mask is not None) and (attention_mask[:, 0].sum() != attention_mask.shape[0]) and self.training):
145
  raise NotImplementedError('MPT does not support training with left padding.')
146
- if (self.prefix_lm and (prefix_mask is None)):
147
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
148
  if self.training:
149
- if (self.attn_uses_sequence_id and (sequence_id is None)):
150
- raise ValueError(('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.'))
151
- elif ((self.attn_uses_sequence_id is False) and (sequence_id is not None)):
152
- warnings.warn(('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.'))
153
  S = input_ids.size(1)
154
- assert (S <= self.config.max_seq_len), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
155
  tok_emb = self.wte(input_ids)
156
  if self.alibi:
157
  x = tok_emb
158
  else:
159
  past_position = 0
160
- if (past_key_values is not None):
161
- if (len(past_key_values) != self.config.n_layers):
162
- raise ValueError((f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).'))
163
  past_position = past_key_values[0][0].size(1)
164
- if ((S + past_position) > self.config.max_seq_len):
165
- raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {(S + 1)}, this model only supports total sequence length <= {self.config.max_seq_len}.')
166
- pos = torch.arange(past_position, (S + past_position), dtype=torch.long, device=input_ids.device).unsqueeze(0)
167
- if (attention_mask is not None):
168
- pos = torch.clamp((pos - torch.cumsum((~ attention_mask).to(torch.int32), dim=1)[:, past_position:]), min=0)
169
  pos_emb = self.wpe(pos)
170
- x = (tok_emb + pos_emb)
171
- if (self.embedding_fraction == 1):
172
  x = self.emb_drop(x)
173
  else:
174
- x_shrunk = ((x * self.embedding_fraction) + (x.detach() * (1 - self.embedding_fraction)))
175
  assert isinstance(self.emb_drop, nn.Module)
176
  x = self.emb_drop(x_shrunk)
177
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
178
- if (use_cache and (past_key_values is None)):
179
  past_key_values = [() for _ in range(self.config.n_layers)]
180
- all_hidden_states = (() if output_hidden_states else None)
181
  for (b_idx, block) in enumerate(self.blocks):
182
  if output_hidden_states:
183
- assert (all_hidden_states is not None)
184
- all_hidden_states = (all_hidden_states + (x,))
185
- past_key_value = (past_key_values[b_idx] if (past_key_values is not None) else None)
186
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
187
- if (past_key_values is not None):
188
  past_key_values[b_idx] = past_key_value
189
  x = self.norm_f(x)
190
  return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
@@ -203,15 +263,15 @@ class MPTForCausalLM(MPTPreTrainedModel):
203
 
204
  def __init__(self, config: MPTConfig):
205
  super().__init__(config)
206
- if (not config.tie_word_embeddings):
207
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
208
  self.transformer = MPTModel(config)
209
  self.logit_scale = None
210
- if (config.logit_scale is not None):
211
  logit_scale = config.logit_scale
212
  if isinstance(logit_scale, str):
213
- if (logit_scale == 'inv_sqrt_d_model'):
214
- logit_scale = (1 / math.sqrt(config.d_model))
215
  else:
216
  raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
217
  self.logit_scale = logit_scale
@@ -234,20 +294,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
234
  def get_decoder(self):
235
  return self.transformer
236
 
237
- def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
238
- return_dict = (return_dict if (return_dict is not None) else self.config.return_dict)
239
- use_cache = (use_cache if (use_cache is not None) else self.config.use_cache)
240
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
241
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
242
- if (self.logit_scale is not None):
243
- if (self.logit_scale == 0):
244
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
245
  logits *= self.logit_scale
246
  loss = None
247
- if (labels is not None):
248
- labels = torch.roll(labels, shifts=(- 1))
249
- labels[:, (- 1)] = (- 100)
250
- loss = F.cross_entropy(logits.view((- 1), logits.size((- 1))), labels.to(logits.device).view((- 1)))
251
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
252
 
253
  def param_init_fn(self, module):
@@ -261,20 +321,20 @@ class MPTForCausalLM(MPTPreTrainedModel):
261
  return isinstance(module, MPTBlock)
262
 
263
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
264
- if (inputs_embeds is not None):
265
  raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
266
  attention_mask = kwargs['attention_mask'].bool()
267
- if (attention_mask[:, (- 1)].sum() != attention_mask.shape[0]):
268
  raise NotImplementedError('MPT does not support generation with right padding.')
269
- if (self.transformer.attn_uses_sequence_id and self.training):
270
  sequence_id = torch.zeros_like(input_ids[:1])
271
  else:
272
  sequence_id = None
273
- if (past_key_values is not None):
274
- input_ids = input_ids[:, (- 1)].unsqueeze((- 1))
275
  if self.transformer.prefix_lm:
276
  prefix_mask = torch.ones_like(attention_mask)
277
- if (kwargs.get('use_cache') == False):
278
  raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
279
  else:
280
  prefix_mask = None
@@ -282,8 +342,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
282
 
283
  @staticmethod
284
  def _reorder_cache(past_key_values, beam_idx):
285
- 'Used by HuggingFace generate when using beam search with kv-caching.\n\n See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133\n for an example in transformers.\n '
286
  reordered_past = []
287
  for layer_past in past_key_values:
288
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
289
- return reordered_past
 
1
+ """A simple, flexible implementation of a GPT model.
2
 
3
+ Inspired by https://github.com/karpathy/minGPT/blob/master/mingpt/model.py
4
+ """
5
  import math
6
  import warnings
7
+ from typing import Any, List, Optional, Tuple, Union, Protocol, Dict
8
  import torch
9
  import torch.nn as nn
10
  import torch.nn.functional as F
11
+ from torch.utils.checkpoint import checkpoint
12
  from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
13
  from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
14
+ from transformers.utils import logging
15
+ from .attention import attn_bias_shape, build_attn_bias, PastKeyValue, MultiheadAttention, MultiQueryAttention
16
+ from .blocks import MPTBlock, MPTBlockOutput
17
  from .norm import NORM_CLASS_REGISTRY
18
  from .configuration_mpt import MPTConfig
19
  from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
20
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
21
  from .meta_init_context import init_empty_weights
22
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
23
+ from .is_torch_version import is_torch_version
24
+
25
+ Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ class MPTBlockCheckpointedForward(Protocol):
30
+ def __call__(
31
+ x: torch.Tensor,
32
+ past_key_value: Union[PastKeyValue, Tuple, None],
33
+ attn_bias: Optional[torch.Tensor],
34
+ attention_mask: Optional[torch.ByteTensor],
35
+ is_causal: bool,
36
+ ) -> MPTBlockOutput: ...
37
 
38
  class MPTPreTrainedModel(PreTrainedModel):
39
  config_class = MPTConfig
40
  base_model_prefix = 'model'
41
+ _no_split_modules = ['MPTBlock']
42
+ supports_gradient_checkpointing = True
43
+ def _set_gradient_checkpointing(self, module: nn.Module, value=False) -> None:
44
+ if isinstance(module, MPTModel) or isinstance(module, MultiheadAttention) or isinstance(module, MultiQueryAttention):
45
+ module.gradient_checkpointing = value
46
 
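The _no_split_modules, supports_gradient_checkpointing and _set_gradient_checkpointing additions above hook into the standard transformers checkpointing machinery. A minimal usage sketch, assuming a transformers version whose PreTrainedModel.gradient_checkpointing_enable() routes through _set_gradient_checkpointing, and an illustrative checkpoint name that is not part of this commit:

# Sketch only: the checkpoint name and dtype are illustrative assumptions.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    'mosaicml/mpt-7b',
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)
# gradient_checkpointing_enable() applies _set_gradient_checkpointing(value=True) to every
# submodule, flipping the gradient_checkpointing flag on MPTModel and the attention classes.
model.gradient_checkpointing_enable()
model.train()  # the checkpointed path in MPTModel.forward only runs while training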
47
  class MPTModel(MPTPreTrainedModel):
48
 
 
54
  self.attn_uses_sequence_id = config.attn_config['attn_uses_sequence_id']
55
  self.alibi = config.attn_config['alibi']
56
  self.alibi_bias_max = config.attn_config['alibi_bias_max']
57
+ if config.norm_type.lower() not in NORM_CLASS_REGISTRY.keys():
58
  norm_options = ' | '.join(NORM_CLASS_REGISTRY.keys())
59
  raise NotImplementedError(f'Requested norm type ({config.norm_type}) is not implemented within this repo (Options: {norm_options}).')
60
  norm_class = NORM_CLASS_REGISTRY[config.norm_type.lower()]
61
  self.embedding_fraction = config.embedding_fraction
62
  self.wte = nn.Embedding(config.vocab_size, config.d_model, device=config.init_device)
63
+ if not self.alibi:
64
  self.wpe = nn.Embedding(config.max_seq_len, config.d_model, device=config.init_device)
65
  self.emb_drop = nn.Dropout(config.emb_pdrop)
66
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
67
  self.norm_f = norm_class(config.d_model, device=config.init_device)
68
+ if config.init_device != 'meta':
 
69
  self.apply(self.param_init_fn)
70
+ self.is_causal = not self.prefix_lm
71
  self._attn_bias_initialized = False
72
  self.attn_bias = None
73
  self.attn_bias_shape = attn_bias_shape(self.attn_impl, config.n_heads, config.max_seq_len, self.alibi, prefix_lm=self.prefix_lm, causal=self.is_causal, use_sequence_id=self.attn_uses_sequence_id)
74
  if config.no_bias:
75
  for module in self.modules():
76
+ if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
77
  if config.verbose:
78
  warnings.warn(f'Removing bias ({module.bias}) from {module}.')
79
  module.register_parameter('bias', None)
80
+ if config.verbose and config.verbose > 2:
81
  print(self)
82
+ if 'verbose' not in self.config.init_config:
83
  self.config.init_config['verbose'] = self.config.verbose
84
+ if self.config.init_config['verbose'] > 1:
85
  init_fn_name = self.config.init_config['name']
86
  warnings.warn(f'Using {init_fn_name} initialization.')
87
+ self.gradient_checkpointing = False
88
 
89
  def get_input_embeddings(self):
90
  return self.wte
 
94
 
95
  @torch.no_grad()
96
  def _attn_bias(self, device, dtype, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None):
97
+ if not self._attn_bias_initialized:
98
  if self.attn_bias_shape:
99
  self.attn_bias = torch.zeros(self.attn_bias_shape, device=device, dtype=dtype)
100
  self.attn_bias = build_attn_bias(self.attn_impl, self.attn_bias, self.config.n_heads, self.config.max_seq_len, causal=self.is_causal, alibi=self.alibi, alibi_bias_max=self.alibi_bias_max)
101
  self._attn_bias_initialized = True
102
+ if self.attn_impl == 'flash':
103
  return (self.attn_bias, attention_mask)
104
+ if self.attn_bias is not None:
105
  self.attn_bias = self.attn_bias.to(dtype=dtype, device=device)
106
  attn_bias = self.attn_bias
107
  if self.prefix_lm:
108
  assert isinstance(attn_bias, torch.Tensor)
109
  assert isinstance(prefix_mask, torch.Tensor)
110
  attn_bias = self._apply_prefix_mask(attn_bias, prefix_mask)
111
+ if self.attn_uses_sequence_id and sequence_id is not None:
112
  assert isinstance(attn_bias, torch.Tensor)
113
  attn_bias = self._apply_sequence_id(attn_bias, sequence_id)
114
+ if attention_mask is not None:
115
+ s_k = attention_mask.shape[-1]
116
+ if attn_bias is None:
117
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
118
  else:
119
+ attn_bias = attn_bias[:, :, :, -s_k:]
120
+ if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
121
+ raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
122
  min_val = torch.finfo(attn_bias.dtype).min
123
+ attn_bias = attn_bias.masked_fill(~attention_mask.view(-1, 1, 1, s_k), min_val)
124
  return (attn_bias, None)
125
 
126
  def _apply_prefix_mask(self, attn_bias: torch.Tensor, prefix_mask: torch.Tensor):
127
+ (s_k, s_q) = attn_bias.shape[-2:]
128
+ if s_k != self.config.max_seq_len or s_q != self.config.max_seq_len:
129
+ raise ValueError('attn_bias does not match the expected shape. ' + f'The last two dimensions should both be {self.config.max_length} ' + f'but are {s_k} and {s_q}.')
130
+ seq_len = prefix_mask.shape[-1]
131
+ if seq_len > self.config.max_seq_len:
132
  raise ValueError(f'prefix_mask sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
133
  attn_bias = attn_bias[..., :seq_len, :seq_len]
134
  causal = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool, device=prefix_mask.device)).view(1, 1, seq_len, seq_len)
135
+ prefix = prefix_mask.view(-1, 1, 1, seq_len)
136
+ cannot_attend = ~torch.logical_or(causal, prefix.bool())
137
  min_val = torch.finfo(attn_bias.dtype).min
138
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
139
  return attn_bias
140
 
141
  def _apply_sequence_id(self, attn_bias: torch.Tensor, sequence_id: torch.LongTensor):
142
+ seq_len = sequence_id.shape[-1]
143
+ if seq_len > self.config.max_seq_len:
144
  raise ValueError(f'sequence_id sequence length cannot exceed max_seq_len={self.config.max_seq_len}')
145
  attn_bias = attn_bias[..., :seq_len, :seq_len]
146
+ cannot_attend = torch.logical_not(torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))).unsqueeze(1)
147
  min_val = torch.finfo(attn_bias.dtype).min
148
  attn_bias = attn_bias.masked_fill(cannot_attend, min_val)
149
  return attn_bias
150
 
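As a standalone sketch of the masking logic in _apply_sequence_id above, this shows what gets blocked when several short sequences are packed into one row; the sequence_id values are made up for illustration:

# Two packed sequences in a single row: tokens 0-2 belong to sequence 0, tokens 3-4 to sequence 1.
import torch

sequence_id = torch.tensor([[0, 0, 0, 1, 1]])
seq_len = sequence_id.shape[-1]
cannot_attend = torch.logical_not(
    torch.eq(sequence_id.view(-1, seq_len, 1), sequence_id.view(-1, 1, seq_len))
).unsqueeze(1)
# cannot_attend[0, 0] is True exactly where query and key come from different sequences,
# so masked_fill with the dtype minimum removes all cross-sequence attention.
print(cannot_attend[0, 0].int())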
151
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
152
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
153
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
154
+ if self.gradient_checkpointing and self.training:
155
+ if use_cache:
156
+ logger.warning_once(
157
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
158
+ )
159
+ use_cache = False
160
+ if attention_mask is not None:
161
  attention_mask = attention_mask.bool()
162
+ if prefix_mask is not None:
163
  prefix_mask = prefix_mask.bool()
164
+ if not return_dict:
165
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
166
  if output_attentions:
167
  raise NotImplementedError('output_attentions is not implemented yet for MPT')
168
+ if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
169
  raise NotImplementedError('MPT does not support training with left padding.')
170
+ if self.prefix_lm and prefix_mask is None:
171
  raise ValueError('prefix_mask is a required argument when MPT is configured with prefix_lm=True.')
172
  if self.training:
173
+ if self.attn_uses_sequence_id and sequence_id is None:
174
+ raise ValueError('sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.')
175
+ elif self.attn_uses_sequence_id is False and sequence_id is not None:
176
+ warnings.warn('MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + 'This input will be ignored. If you want the model to use `sequence_id`, set attn_uses_sequence_id to True.')
177
  S = input_ids.size(1)
178
+ assert S <= self.config.max_seq_len, f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
179
  tok_emb = self.wte(input_ids)
180
  if self.alibi:
181
  x = tok_emb
182
  else:
183
  past_position = 0
184
+ if past_key_values is not None:
185
+ if len(past_key_values) != self.config.n_layers:
186
+ raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
187
  past_position = past_key_values[0][0].size(1)
188
+ if S + past_position > self.config.max_seq_len:
189
+ raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
190
+ pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
191
+ if attention_mask is not None:
192
+ pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
193
  pos_emb = self.wpe(pos)
194
+ x = tok_emb + pos_emb
195
+ if self.embedding_fraction == 1:
196
  x = self.emb_drop(x)
197
  else:
198
+ x_shrunk = x * self.embedding_fraction + x.detach() * (1 - self.embedding_fraction)
199
  assert isinstance(self.emb_drop, nn.Module)
200
  x = self.emb_drop(x_shrunk)
201
  (attn_bias, attention_mask) = self._attn_bias(device=x.device, dtype=x.dtype, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id)
202
+ if use_cache and past_key_values is None:
203
  past_key_values = [() for _ in range(self.config.n_layers)]
204
+ all_hidden_states = () if output_hidden_states else None
205
  for (b_idx, block) in enumerate(self.blocks):
206
  if output_hidden_states:
207
+ assert all_hidden_states is not None
208
+ all_hidden_states = all_hidden_states + (x,)
209
+ past_key_value = past_key_values[b_idx] if past_key_values is not None else None
210
+ if self.gradient_checkpointing and self.training:
211
+ ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
212
+ def create_custom_forward(module: MPTBlock) -> MPTBlockCheckpointedForward:
213
+ def custom_forward(
214
+ x: torch.Tensor,
215
+ past_key_value: Union[PastKeyValue, Tuple, None],
216
+ attn_bias: Optional[torch.Tensor],
217
+ attention_mask: Optional[torch.ByteTensor],
218
+ is_causal: bool
219
+ ):
220
+ return module.forward(
221
+ x,
222
+ past_key_value,
223
+ attn_bias,
224
+ attention_mask,
225
+ is_causal,
226
+ )
227
+ return custom_forward
228
+ block_out: MPTBlockOutput = checkpoint(
229
+ create_custom_forward(block),
230
+ x,
231
+ past_key_value,
232
+ attn_bias,
233
+ attention_mask,
234
+ self.is_causal,
235
+ **ckpt_kwargs,
236
+ )
237
+ else:
238
+ block_out: MPTBlockOutput = block(
239
+ x,
240
+ past_key_value=past_key_value,
241
+ attn_bias=attn_bias,
242
+ attention_mask=attention_mask,
243
+ is_causal=self.is_causal,
244
+ )
245
+ x, past_key_value = block_out
246
+ del block_out
247
+ if past_key_values is not None:
248
  past_key_values[b_idx] = past_key_value
249
  x = self.norm_f(x)
250
  return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
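One detail of the forward pass above worth spelling out: when attention_mask carries left padding, position ids are shifted so the first real token gets position 0. A standalone sketch with a made-up mask and no cached positions:

import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1]], dtype=torch.bool)  # two left-pad tokens
S, past_position = attention_mask.shape[1], 0
pos = torch.arange(past_position, S + past_position).unsqueeze(0)
pos = torch.clamp(pos - torch.cumsum((~attention_mask).to(torch.int32), dim=1)[:, past_position:], min=0)
print(pos)  # tensor([[0, 0, 0, 1, 2]]): padded slots collapse to 0, real tokens count from 0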
263
 
264
  def __init__(self, config: MPTConfig):
265
  super().__init__(config)
266
+ if not config.tie_word_embeddings:
267
  raise ValueError('MPTForCausalLM only supports tied word embeddings')
268
  self.transformer = MPTModel(config)
269
  self.logit_scale = None
270
+ if config.logit_scale is not None:
271
  logit_scale = config.logit_scale
272
  if isinstance(logit_scale, str):
273
+ if logit_scale == 'inv_sqrt_d_model':
274
+ logit_scale = 1 / math.sqrt(config.d_model)
275
  else:
276
  raise ValueError(f"logit_scale={logit_scale!r} is not recognized as an option; use numeric value or 'inv_sqrt_d_model'.")
277
  self.logit_scale = logit_scale
 
294
  def get_decoder(self):
295
  return self.transformer
296
 
297
+ def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None, *args, **kwargs):
298
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
299
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
300
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
301
  logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
302
+ if self.logit_scale is not None:
303
+ if self.logit_scale == 0:
304
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
305
  logits *= self.logit_scale
306
  loss = None
307
+ if labels is not None:
308
+ labels = torch.roll(labels, shifts=-1)
309
+ labels[:, -1] = -100
310
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
311
  return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
312
 
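The loss above shifts labels with torch.roll instead of slicing the logits; a small sketch of the resulting next-token targets, with arbitrary token ids and a stand-in vocabulary size:

import torch
import torch.nn.functional as F

labels = torch.tensor([[11, 12, 13, 14]])
shifted = torch.roll(labels, shifts=-1)  # [[12, 13, 14, 11]]
shifted[:, -1] = -100                    # [[12, 13, 14, -100]]; -100 is ignored by cross_entropy
logits = torch.randn(1, 4, 32)           # (batch, seq, vocab) stand-in for the model output
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), shifted.view(-1))
# Position t is now supervised with token t+1, keeping logits and labels the same length.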
313
  def param_init_fn(self, module):
 
321
  return isinstance(module, MPTBlock)
322
 
323
  def prepare_inputs_for_generation(self, input_ids, past_key_values=None, inputs_embeds=None, **kwargs):
324
+ if inputs_embeds is not None:
325
  raise NotImplementedError('inputs_embeds is not implemented for MPT yet')
326
  attention_mask = kwargs['attention_mask'].bool()
327
+ if attention_mask[:, -1].sum() != attention_mask.shape[0]:
328
  raise NotImplementedError('MPT does not support generation with right padding.')
329
+ if self.transformer.attn_uses_sequence_id and self.training:
330
  sequence_id = torch.zeros_like(input_ids[:1])
331
  else:
332
  sequence_id = None
333
+ if past_key_values is not None:
334
+ input_ids = input_ids[:, -1].unsqueeze(-1)
335
  if self.transformer.prefix_lm:
336
  prefix_mask = torch.ones_like(attention_mask)
337
+ if kwargs.get('use_cache') == False:
338
  raise NotImplementedError('MPT with prefix_lm=True does not support use_cache=False.')
339
  else:
340
  prefix_mask = None
 
342
 
343
  @staticmethod
344
  def _reorder_cache(past_key_values, beam_idx):
345
+ """Used by HuggingFace generate when using beam search with kv-caching.
346
+
347
+ See https://github.com/huggingface/transformers/blob/3ec7a47664ebe40c40f4b722f6bb1cd30c3821ec/src/transformers/models/gpt2/modeling_gpt2.py#L1122-L1133
348
+ for an example in transformers.
349
+ """
350
  reordered_past = []
351
  for layer_past in past_key_values:
352
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
353
+ return reordered_past
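A small sketch of what this reordering does during beam search; the shapes and beam_idx values below are illustrative assumptions, not part of the commit:

import torch

num_beams, seq_len, d = 3, 4, 8
layer_past = (torch.randn(num_beams, seq_len, d),   # cached keys for one layer
              torch.randn(num_beams, seq_len, d))   # cached values for one layer
beam_idx = torch.tensor([2, 2, 0])                  # beams 0 and 1 both continue from old beam 2
reordered = tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
# Each surviving beam now carries the key/value cache of the beam it was forked from.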