add int8 model inference

#1
Files changed (5) hide show
  1. README.md +2 -3
  2. attention.py +61 -37
  3. blocks.py +4 -4
  4. configuration_mpt.py +1 -1
  5. modeling_mpt.py +31 -14
README.md CHANGED
@@ -44,7 +44,7 @@ The following hyperparameters were used during training:
44
  ```shell
45
  import transformers
46
  model = transformers.AutoModelForCausalLM.from_pretrained(
47
- 'Intel/neural-chat-7b-v1.1',
48
  trust_remote_code=True
49
  )
50
  ```
@@ -54,8 +54,7 @@ Follow the instructions [link](https://github.com/intel/intel-extension-for-tran
54
 
55
  ```shell
56
  python run_generation.py \
57
- --model Intel/neural-chat-7b-v1.1 \
58
- --revision c8d4750ac8421303665d6ecc253950c69b56d324 \
59
  --quantize \
60
  --sq \
61
  --alpha 0.95 \
 
44
  ```shell
45
  import transformers
46
  model = transformers.AutoModelForCausalLM.from_pretrained(
47
+ 'Intel/neural-chat-7b-v1-1',
48
  trust_remote_code=True
49
  )
50
  ```
 
54
 
55
  ```shell
56
  python run_generation.py \
57
+ --model Intel/neural-chat-7b-v1-1 \
 
58
  --quantize \
59
  --sq \
60
  --alpha 0.95 \
attention.py CHANGED
@@ -5,6 +5,7 @@ from typing import Optional
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
 
8
  from torch import nn
9
  from .norm import LPLayerNorm
10
 
@@ -16,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
16
  return False
17
  return original_is_causal
18
 
19
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
20
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
21
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
22
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
23
- min_val = torch.finfo(q.dtype).min
 
 
 
 
 
24
  (b, _, s_q, d) = q.shape
25
  s_k = k.size(-1)
26
  if softmax_scale is None:
27
  softmax_scale = 1 / math.sqrt(d)
28
  attn_weight = q.matmul(k) * softmax_scale
29
  if attn_bias is not None:
 
 
 
30
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
31
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
32
  attn_weight = attn_weight + attn_bias
 
33
  if key_padding_mask is not None:
34
  if attn_bias is not None:
35
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
36
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
37
- if is_causal:
38
  s = max(s_q, s_k)
39
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
40
  causal_mask = causal_mask.tril()
@@ -48,8 +58,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
48
  out = attn_weight.matmul(v)
49
  out = rearrange(out, 'b h s d -> b s (h d)')
50
  if needs_weights:
51
- return (out, attn_weight)
52
- return (out, None)
53
 
54
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
55
  for tensor in tensors:
@@ -58,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
58
  if not tensor.is_cuda:
59
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
60
 
61
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
62
  try:
63
  from flash_attn import bert_padding, flash_attn_interface
64
  except:
65
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
66
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
67
  if attn_bias is not None:
68
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
69
  (batch_size, seqlen) = query.shape[:2]
@@ -83,14 +102,31 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
83
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
84
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
85
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
86
- return (output, None)
87
 
88
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
89
  try:
90
- from flash_attn import flash_attn_triton
91
  except:
92
- raise RuntimeError('Please install flash-attn==1.0.3.post0 and triton==2.0.0.dev20221202')
 
 
 
 
 
 
 
 
93
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
94
  if dropout_p:
95
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
96
  if needs_weights:
@@ -108,9 +144,9 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
108
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
109
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
110
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
111
- attn_output = flash_attn_triton.flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
112
  output = attn_output.view(*attn_output.shape[:2], -1)
113
- return (output, None)
114
 
115
  class MultiheadAttention(nn.Module):
116
  """Multi-head self attention.
@@ -119,7 +155,7 @@ class MultiheadAttention(nn.Module):
119
  additive bias.
120
  """
121
 
122
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
123
  super().__init__()
124
  self.attn_impl = attn_impl
125
  self.clip_qkv = clip_qkv
@@ -141,10 +177,11 @@ class MultiheadAttention(nn.Module):
141
  self.attn_fn = flash_attn_fn
142
  elif self.attn_impl == 'triton':
143
  self.attn_fn = triton_flash_attn_fn
144
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
145
  elif self.attn_impl == 'torch':
146
  self.attn_fn = scaled_multihead_dot_product_attention
147
- if torch.cuda.is_available():
148
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
149
  else:
150
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -161,14 +198,7 @@ class MultiheadAttention(nn.Module):
161
  dtype = query.dtype
162
  query = self.q_ln(query).to(dtype)
163
  key = self.k_ln(key).to(dtype)
164
- if past_key_value is not None:
165
- if len(past_key_value) != 0:
166
- key = torch.cat([past_key_value[0], key], dim=1)
167
- value = torch.cat([past_key_value[1], value], dim=1)
168
- past_key_value = (key, value)
169
- if attn_bias is not None:
170
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
171
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
172
  return (self.out_proj(context), attn_weights, past_key_value)
173
 
174
  class MultiQueryAttention(nn.Module):
@@ -178,7 +208,7 @@ class MultiQueryAttention(nn.Module):
178
  additive bias.
179
  """
180
 
181
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
182
  super().__init__()
183
  self.attn_impl = attn_impl
184
  self.clip_qkv = clip_qkv
@@ -201,10 +231,11 @@ class MultiQueryAttention(nn.Module):
201
  self.attn_fn = flash_attn_fn
202
  elif self.attn_impl == 'triton':
203
  self.attn_fn = triton_flash_attn_fn
204
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
205
  elif self.attn_impl == 'torch':
206
  self.attn_fn = scaled_multihead_dot_product_attention
207
- if torch.cuda.is_available():
208
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
209
  else:
210
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -221,14 +252,7 @@ class MultiQueryAttention(nn.Module):
221
  dtype = query.dtype
222
  query = self.q_ln(query).to(dtype)
223
  key = self.k_ln(key).to(dtype)
224
- if past_key_value is not None:
225
- if len(past_key_value) != 0:
226
- key = torch.cat([past_key_value[0], key], dim=1)
227
- value = torch.cat([past_key_value[1], value], dim=1)
228
- past_key_value = (key, value)
229
- if attn_bias is not None:
230
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
231
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
232
  return (self.out_proj(context), attn_weights, past_key_value)
233
 
234
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
@@ -273,4 +297,4 @@ def build_alibi_bias(n_heads, seq_len, full=False, alibi_bias_max=8, device=None
273
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
274
  alibi_bias = alibi_bias * slopes
275
  return alibi_bias.to(dtype=dtype)
276
- ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
 
5
  import torch
6
  import torch.nn as nn
7
  from einops import rearrange
8
+ from packaging import version
9
  from torch import nn
10
  from .norm import LPLayerNorm
11
 
 
17
  return False
18
  return original_is_causal
19
 
20
+ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
+ kv_n_heads = 1 if multiquery else n_heads
23
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
24
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
25
+ if past_key_value is not None:
26
+ if len(past_key_value) != 0:
27
+ k = torch.cat([past_key_value[0], k], dim=3)
28
+ v = torch.cat([past_key_value[1], v], dim=2)
29
+ past_key_value = (k, v)
30
  (b, _, s_q, d) = q.shape
31
  s_k = k.size(-1)
32
  if softmax_scale is None:
33
  softmax_scale = 1 / math.sqrt(d)
34
  attn_weight = q.matmul(k) * softmax_scale
35
  if attn_bias is not None:
36
+ _s_q = max(0, attn_bias.size(2) - s_q)
37
+ _s_k = max(0, attn_bias.size(3) - s_k)
38
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
39
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
40
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
41
  attn_weight = attn_weight + attn_bias
42
+ min_val = torch.finfo(q.dtype).min
43
  if key_padding_mask is not None:
44
  if attn_bias is not None:
45
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
46
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
47
+ if is_causal and (not q.size(2) == 1):
48
  s = max(s_q, s_k)
49
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
50
  causal_mask = causal_mask.tril()
 
58
  out = attn_weight.matmul(v)
59
  out = rearrange(out, 'b h s d -> b s (h d)')
60
  if needs_weights:
61
+ return (out, attn_weight, past_key_value)
62
+ return (out, None, past_key_value)
63
 
64
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
65
  for tensor in tensors:
 
68
  if not tensor.is_cuda:
69
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
70
 
71
+ def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
72
  try:
73
  from flash_attn import bert_padding, flash_attn_interface
74
  except:
75
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
76
  check_valid_inputs(query, key, value)
77
+ if past_key_value is not None:
78
+ if len(past_key_value) != 0:
79
+ key = torch.cat([past_key_value[0], key], dim=1)
80
+ value = torch.cat([past_key_value[1], value], dim=1)
81
+ past_key_value = (key, value)
82
+ if attn_bias is not None:
83
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
84
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
85
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
86
  if attn_bias is not None:
87
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
88
  (batch_size, seqlen) = query.shape[:2]
 
102
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
103
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
104
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
105
+ return (output, None, past_key_value)
106
 
107
+ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
108
  try:
109
+ from .flash_attn_triton import flash_attn_func
110
  except:
111
+ _installed = False
112
+ if version.parse(torch.__version__) < version.parse('2.0.0'):
113
+ _installed = True
114
+ try:
115
+ from flash_attn.flash_attn_triton import flash_attn_func
116
+ except:
117
+ _installed = False
118
+ if not _installed:
119
+ raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
120
  check_valid_inputs(query, key, value)
121
+ if past_key_value is not None:
122
+ if len(past_key_value) != 0:
123
+ key = torch.cat([past_key_value[0], key], dim=1)
124
+ value = torch.cat([past_key_value[1], value], dim=1)
125
+ past_key_value = (key, value)
126
+ if attn_bias is not None:
127
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
128
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
129
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
130
  if dropout_p:
131
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
132
  if needs_weights:
 
144
  key = key.expand(*key.shape[:2], n_heads, key.size(-1))
145
  value = value.expand(*value.shape[:2], n_heads, value.size(-1))
146
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
147
+ attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
148
  output = attn_output.view(*attn_output.shape[:2], -1)
149
+ return (output, None, past_key_value)
150
 
151
  class MultiheadAttention(nn.Module):
152
  """Multi-head self attention.
 
155
  additive bias.
156
  """
157
 
158
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
159
  super().__init__()
160
  self.attn_impl = attn_impl
161
  self.clip_qkv = clip_qkv
 
177
  self.attn_fn = flash_attn_fn
178
  elif self.attn_impl == 'triton':
179
  self.attn_fn = triton_flash_attn_fn
180
+ if verbose:
181
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
182
  elif self.attn_impl == 'torch':
183
  self.attn_fn = scaled_multihead_dot_product_attention
184
+ if torch.cuda.is_available() and verbose:
185
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
186
  else:
187
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
198
  dtype = query.dtype
199
  query = self.q_ln(query).to(dtype)
200
  key = self.k_ln(key).to(dtype)
201
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
 
 
 
 
 
 
 
202
  return (self.out_proj(context), attn_weights, past_key_value)
203
 
204
  class MultiQueryAttention(nn.Module):
 
208
  additive bias.
209
  """
210
 
211
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
212
  super().__init__()
213
  self.attn_impl = attn_impl
214
  self.clip_qkv = clip_qkv
 
231
  self.attn_fn = flash_attn_fn
232
  elif self.attn_impl == 'triton':
233
  self.attn_fn = triton_flash_attn_fn
234
+ if verbose:
235
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
236
  elif self.attn_impl == 'torch':
237
  self.attn_fn = scaled_multihead_dot_product_attention
238
+ if torch.cuda.is_available() and verbose:
239
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
240
  else:
241
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
 
252
  dtype = query.dtype
253
  query = self.q_ln(query).to(dtype)
254
  key = self.k_ln(key).to(dtype)
255
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
 
 
 
 
 
 
 
256
  return (self.out_proj(context), attn_weights, past_key_value)
257
 
258
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
 
297
  slopes = gen_slopes(n_heads, alibi_bias_max, device=device)
298
  alibi_bias = alibi_bias * slopes
299
  return alibi_bias.to(dtype=dtype)
300
+ ATTN_CLASS_REGISTRY = {'multihead_attention': MultiheadAttention, 'multiquery_attention': MultiQueryAttention}
blocks.py CHANGED
@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):
19
 
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
 
19
 
20
  class MPTBlock(nn.Module):
21
 
22
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
+ self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
 
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
+ return (x, attn_weights, past_key_value)
configuration_mpt.py CHANGED
@@ -2,7 +2,7 @@
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
 
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
+ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
modeling_mpt.py CHANGED
@@ -18,11 +18,16 @@ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
 
 
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
 
26
 
27
  class MPTModel(MPTPreTrainedModel):
28
 
@@ -46,6 +51,7 @@ class MPTModel(MPTPreTrainedModel):
46
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
47
  self.norm_f = norm_class(config.d_model, device=config.init_device)
48
  if config.init_device != 'meta':
 
49
  self.apply(self.param_init_fn)
50
  self.is_causal = not self.prefix_lm
51
  self._attn_bias_initialized = False
@@ -95,7 +101,8 @@ class MPTModel(MPTPreTrainedModel):
95
  if attn_bias is None:
96
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
97
  else:
98
- attn_bias = attn_bias[:, :, :, -s_k:]
 
99
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
100
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
101
  min_val = torch.finfo(attn_bias.dtype).min
@@ -134,10 +141,11 @@ class MPTModel(MPTPreTrainedModel):
134
  attention_mask = attention_mask.bool()
135
  if prefix_mask is not None:
136
  prefix_mask = prefix_mask.bool()
137
- # if not return_dict:
138
- # raise NotImplementedError('return_dict False is not implemented yet for MPT')
139
  if output_attentions:
140
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
 
141
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
142
  raise NotImplementedError('MPT does not support training with left padding.')
143
  if self.prefix_lm and prefix_mask is None:
@@ -158,6 +166,8 @@ class MPTModel(MPTPreTrainedModel):
158
  if len(past_key_values) != self.config.n_layers:
159
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
160
  past_position = past_key_values[0][0].size(1)
 
 
161
  if S + past_position > self.config.max_seq_len:
162
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
163
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -175,19 +185,26 @@ class MPTModel(MPTPreTrainedModel):
175
  if use_cache and past_key_values is None:
176
  past_key_values = [() for _ in range(self.config.n_layers)]
177
  all_hidden_states = () if output_hidden_states else None
 
178
  for (b_idx, block) in enumerate(self.blocks):
179
  if output_hidden_states:
180
  assert all_hidden_states is not None
181
  all_hidden_states = all_hidden_states + (x,)
182
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
183
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
184
  if past_key_values is not None:
185
  past_key_values[b_idx] = past_key_value
 
 
 
186
  x = self.norm_f(x)
187
- if not return_dict:
 
 
 
188
  output = (x,) + (tuple(past_key_values),)
189
  return output
190
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
191
 
192
  def param_init_fn(self, module):
193
  init_fn_name = self.config.init_config['name']
@@ -237,11 +254,12 @@ class MPTForCausalLM(MPTPreTrainedModel):
237
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
238
  return_dict = return_dict if return_dict is not None else self.config.return_dict
239
  use_cache = use_cache if use_cache is not None else self.config.use_cache
240
-
241
  past_key_values = list(past_key_values) if past_key_values is not None else None
242
-
243
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
244
- logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
 
 
 
245
  if self.logit_scale is not None:
246
  if self.logit_scale == 0:
247
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
@@ -251,11 +269,10 @@ class MPTForCausalLM(MPTPreTrainedModel):
251
  labels = torch.roll(labels, shifts=-1)
252
  labels[:, -1] = -100
253
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
254
-
255
- if not return_dict:
256
  output = (logits,) + (tuple(outputs[1]),)
257
  return (loss,) + output if loss is not None else output
258
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
259
 
260
  def param_init_fn(self, module):
261
  init_fn_name = self.config.init_config['name']
@@ -297,4 +314,4 @@ class MPTForCausalLM(MPTPreTrainedModel):
297
  reordered_past = []
298
  for layer_past in past_key_values:
299
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
300
- return reordered_past
 
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
21
+ try:
22
+ from .flash_attn_triton import flash_attn_func
23
+ except:
24
+ pass
25
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
26
 
27
  class MPTPreTrainedModel(PreTrainedModel):
28
  config_class = MPTConfig
29
  base_model_prefix = 'model'
30
+ _no_split_modules = ['MPTBlock']
31
 
32
  class MPTModel(MPTPreTrainedModel):
33
 
 
51
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
52
  self.norm_f = norm_class(config.d_model, device=config.init_device)
53
  if config.init_device != 'meta':
54
+ print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
55
  self.apply(self.param_init_fn)
56
  self.is_causal = not self.prefix_lm
57
  self._attn_bias_initialized = False
 
101
  if attn_bias is None:
102
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
103
  else:
104
+ _s_k = max(0, attn_bias.size(-1) - s_k)
105
+ attn_bias = attn_bias[:, :, :, _s_k:]
106
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
107
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
108
  min_val = torch.finfo(attn_bias.dtype).min
 
141
  attention_mask = attention_mask.bool()
142
  if prefix_mask is not None:
143
  prefix_mask = prefix_mask.bool()
144
+ if not return_dict:
145
+ raise NotImplementedError('return_dict False is not implemented yet for MPT')
146
  if output_attentions:
147
+ if self.attn_impl != 'torch':
148
+ raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
149
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
150
  raise NotImplementedError('MPT does not support training with left padding.')
151
  if self.prefix_lm and prefix_mask is None:
 
166
  if len(past_key_values) != self.config.n_layers:
167
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
168
  past_position = past_key_values[0][0].size(1)
169
+ if self.attn_impl == 'torch':
170
+ past_position = past_key_values[0][0].size(3)
171
  if S + past_position > self.config.max_seq_len:
172
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
173
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
 
185
  if use_cache and past_key_values is None:
186
  past_key_values = [() for _ in range(self.config.n_layers)]
187
  all_hidden_states = () if output_hidden_states else None
188
+ all_self_attns = () if output_attentions else None
189
  for (b_idx, block) in enumerate(self.blocks):
190
  if output_hidden_states:
191
  assert all_hidden_states is not None
192
  all_hidden_states = all_hidden_states + (x,)
193
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
194
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
195
  if past_key_values is not None:
196
  past_key_values[b_idx] = past_key_value
197
+ if output_attentions:
198
+ assert all_self_attns is not None
199
+ all_self_attns = all_self_attns + (attn_weights,)
200
  x = self.norm_f(x)
201
+ if output_hidden_states:
202
+ assert all_hidden_states is not None
203
+ all_hidden_states = all_hidden_states + (x,)
204
+ if self.config.torchscript:
205
  output = (x,) + (tuple(past_key_values),)
206
  return output
207
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
208
 
209
  def param_init_fn(self, module):
210
  init_fn_name = self.config.init_config['name']
 
254
  def forward(self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple[torch.FloatTensor]]]=None, attention_mask: Optional[torch.ByteTensor]=None, prefix_mask: Optional[torch.ByteTensor]=None, sequence_id: Optional[torch.LongTensor]=None, labels: Optional[torch.LongTensor]=None, return_dict: Optional[bool]=None, output_attentions: Optional[bool]=None, output_hidden_states: Optional[bool]=None, use_cache: Optional[bool]=None):
255
  return_dict = return_dict if return_dict is not None else self.config.return_dict
256
  use_cache = use_cache if use_cache is not None else self.config.use_cache
 
257
  past_key_values = list(past_key_values) if past_key_values is not None else None
 
258
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
259
+ if self.config.torchscript:
260
+ logits = F.linear(outputs[0].to(self.transformer.wte.weight.device), self.transformer.wte.weight)
261
+ else:
262
+ logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
263
  if self.logit_scale is not None:
264
  if self.logit_scale == 0:
265
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
 
269
  labels = torch.roll(labels, shifts=-1)
270
  labels[:, -1] = -100
271
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
272
+ if self.config.torchscript:
 
273
  output = (logits,) + (tuple(outputs[1]),)
274
  return (loss,) + output if loss is not None else output
275
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
276
 
277
  def param_init_fn(self, module):
278
  init_fn_name = self.config.init_config['name']
 
314
  reordered_past = []
315
  for layer_past in past_key_values:
316
  reordered_past += [tuple((past_state.index_select(0, beam_idx) for past_state in layer_past))]
317
+ return reordered_past