emozilla committed
Commit bbaa409
1 Parent(s): d4a2524

add kv cache

Files changed (5):
  1. attention.py +57 -16
  2. blocks.py +3 -2
  3. config.json +2 -1
  4. generation_config.json +1 -1
  5. modeling_mpt.py +23 -5
attention.py CHANGED
@@ -18,6 +18,7 @@ class PastKeyValue(NamedTuple):
 class AttnFnOutput(NamedTuple):
     attns: torch.Tensor
     attn_probs: Optional[torch.Tensor]
+    past_key_value: Union[PastKeyValue, Tuple, None]
 
 class AttnFn(Protocol):
     def __call__(
@@ -81,6 +82,7 @@ def scaled_multihead_dot_product_attention(
     key: torch.Tensor,
     value: torch.Tensor,
     n_heads: int,
+    past_key_value=None,
     softmax_scale: Optional[float] = None,
     attn_bias: Optional[torch.Tensor] = None,
     key_padding_mask: Optional[torch.ByteTensor] = None,
@@ -91,23 +93,41 @@ def scaled_multihead_dot_product_attention(
     multiquery = False,
 ) -> AttnFnOutput:
     q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
-    k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
-    v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
-    min_val = torch.finfo(q.dtype).min
+    kv_n_heads = 1 if multiquery else n_heads
+    k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
+    v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
+
+    if past_key_value is not None:
+        # attn_impl: flash & triton use kernels which expect input shape [b, s, h, d_head].
+        # kv_cache is therefore stored using that shape.
+        # attn_impl: torch stores the kv_cache in the ordering which is most advantageous
+        # for its attn computation ie
+        # keys are stored as tensors with shape [b, h, d_head, s] and
+        # values are stored as tensors with shape [b, h, s, d_head]
+        if len(past_key_value) != 0:
+            k = torch.cat([past_key_value[0], k], dim=3)
+            v = torch.cat([past_key_value[1], v], dim=2)
+
+        past_key_value = (k, v)
     (b, _, s_q, d) = q.shape
     s_k = k.size(-1)
     if softmax_scale is None:
         softmax_scale = 1 / math.sqrt(d)
     attn_weight = q.matmul(k) * softmax_scale
     if attn_bias is not None:
+        # clamp to 0 necessary for torch 2.0 compile()
+        _s_q = max(0, attn_bias.size(2) - s_q)
+        _s_k = max(0, attn_bias.size(3) - s_k)
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
         if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
             raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
         attn_weight = attn_weight + attn_bias
+    min_val = torch.finfo(q.dtype).min
     if key_padding_mask is not None:
         if attn_bias is not None:
             warnings.warn('Propagating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
         attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
-    if is_causal:
+    if is_causal and (not q.size(2) == 1):
         s = max(s_q, s_k)
         causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
         causal_mask = causal_mask.tril()
@@ -121,8 +141,8 @@ def scaled_multihead_dot_product_attention(
     out = attn_weight.matmul(v)
     out = rearrange(out, 'b h s d -> b s (h d)')
     if needs_weights:
-        return AttnFnOutput(out, attn_weight)
-    return AttnFnOutput(out, None)
+        return AttnFnOutput(out, attn_weight, past_key_value)
+    return AttnFnOutput(out, None, past_key_value)
 
 def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
     for tensor in tensors:
@@ -136,6 +156,7 @@ def flash_attn_fn(
     key: torch.Tensor,
     value: torch.Tensor,
     n_heads: int,
+    past_key_value=None,
     softmax_scale: Optional[float] = None,
     attn_bias: Optional[torch.Tensor] = None,
     key_padding_mask: Optional[torch.ByteTensor] = None,
@@ -150,6 +171,18 @@ def flash_attn_fn(
     except:
         raise RuntimeError('Please install flash-attn==1.0.3.post0')
     check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+
+        past_key_value = (key, value)
+
+    if attn_bias is not None:
+        # clamp to 0 necessary for torch 2.0 compile()
+        _s_q = max(0, attn_bias.size(2) - query.size(1))
+        _s_k = max(0, attn_bias.size(3) - key.size(1))
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
     if attn_bias is not None:
         raise NotImplementedError(f'attn_bias not implemented for flash attn.')
     (batch_size, seqlen) = query.shape[:2]
@@ -169,13 +202,14 @@ def flash_attn_fn(
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
     output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
-    return AttnFnOutput(output, None)
+    return AttnFnOutput(output, None, past_key_value)
 
 def triton_flash_attn_fn(
     query: torch.Tensor,
     key: torch.Tensor,
     value: torch.Tensor,
     n_heads: int,
+    past_key_value=None,
     softmax_scale: Optional[float] = None,
     attn_bias: Optional[torch.Tensor] = None,
     key_padding_mask: Optional[torch.ByteTensor] = None,
@@ -198,6 +232,18 @@ def triton_flash_attn_fn(
     if not _installed:
         raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
     check_valid_inputs(query, key, value)
+    if past_key_value is not None:
+        if len(past_key_value) != 0:
+            key = torch.cat([past_key_value[0], key], dim=1)
+            value = torch.cat([past_key_value[1], value], dim=1)
+
+        past_key_value = (key, value)
+
+    if attn_bias is not None:
+        # clamp to 0 necessary for torch 2.0 compile()
+        _s_q = max(0, attn_bias.size(2) - query.size(1))
+        _s_k = max(0, attn_bias.size(3) - key.size(1))
+        attn_bias = attn_bias[:, :, _s_q:, _s_k:]
     if dropout_p:
         raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
     if needs_weights:
@@ -217,7 +263,7 @@ def triton_flash_attn_fn(
     reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
     attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
     output = attn_output.view(*attn_output.shape[:2], -1)
-    return AttnFnOutput(output, None)
+    return AttnFnOutput(output, None, past_key_value)
 
 class MultiheadAttention(nn.Module, Attn):
     """Multi-head self attention.
@@ -278,13 +324,6 @@ class MultiheadAttention(nn.Module, Attn):
             dtype = query.dtype
             query = self.q_ln(query).to(dtype)
             key = self.k_ln(key).to(dtype)
-        if past_key_value is not None:
-            if len(past_key_value) != 0:
-                key = torch.cat([past_key_value[0], key], dim=1)
-                value = torch.cat([past_key_value[1], value], dim=1)
-            past_key_value = PastKeyValue(key, value)
-        if attn_bias is not None:
-            attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
         if self.training and self.gradient_checkpointing:
             ckpt_kwargs: Dict[str, Any] = {'use_reentrant': False} if is_torch_version('>=', '1.11.0') else {}
             def create_custom_forward(attn_fn: AttnFn) -> AttnFnCheckpointed:
@@ -337,6 +376,7 @@ class MultiheadAttention(nn.Module, Attn):
                 key,
                 value,
                 self.n_heads,
+                past_key_value=past_key_value,
                 softmax_scale=self.softmax_scale,
                 attn_bias=attn_bias,
                 key_padding_mask=key_padding_mask,
@@ -345,7 +385,7 @@ class MultiheadAttention(nn.Module, Attn):
                 training=self.training,
                 needs_weights=needs_weights,
             )
-        context, attn_weights = attn_fn_out
+        context, attn_weights, past_key_value = attn_fn_out
         return AttnOutput(self.out_proj(context), attn_weights, past_key_value)
 
 class MultiQueryAttention(nn.Module, Attn):
@@ -465,6 +505,7 @@ class MultiQueryAttention(nn.Module, Attn):
                 key,
                 value,
                 self.n_heads,
+                past_key_value=past_key_value,
                 softmax_scale=self.softmax_scale,
                 attn_bias=attn_bias,
                 key_padding_mask=key_padding_mask,
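
Note on the concat dims above: the torch attn_impl keeps cached keys pre-transposed as [b, h, d_head, s] and values as [b, h, s, d_head], which is why keys are concatenated on dim=3 and values on dim=2. A minimal standalone sketch of that cache growth (illustrative sizes, not the model's real dimensions):

import torch
from einops import rearrange

b, n_heads, d_head = 2, 4, 8
prompt = torch.randn(b, 3, n_heads * d_head)  # 3 already-processed tokens
step = torch.randn(b, 1, n_heads * d_head)    # 1 new token at decode time

# same rearrangement as scaled_multihead_dot_product_attention
k_cache = rearrange(prompt, 'b s (h d) -> b h d s', h=n_heads)
v_cache = rearrange(prompt, 'b s (h d) -> b h s d', h=n_heads)
k_new = rearrange(step, 'b s (h d) -> b h d s', h=n_heads)
v_new = rearrange(step, 'b s (h d) -> b h s d', h=n_heads)

k = torch.cat([k_cache, k_new], dim=3)  # keys grow along the last (seq) dim
v = torch.cat([v_cache, v_new], dim=2)  # values grow along the seq dim
assert k.shape == (b, n_heads, d_head, 4)
assert v.shape == (b, n_heads, 4, d_head)

The flash and triton paths instead keep the cache in the [b, s, ...] layout their kernels expect, so both concatenate on dim=1.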
blocks.py CHANGED
@@ -7,6 +7,7 @@ from .norm import NORM_CLASS_REGISTRY
 
 class MPTBlockOutput(NamedTuple):
     hidden_states: torch.Tensor
+    attn_probs: Optional[torch.Tensor]
     past_key_value: Union[PastKeyValue, Tuple, None]
 
 class MPTMLP(nn.Module):
@@ -38,9 +39,9 @@ class MPTBlock(nn.Module):
 
     def forward(self, x: torch.Tensor, past_key_value: Union[PastKeyValue, Tuple, None] = None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> MPTBlockOutput:
         a = self.norm_1(x)
-        (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
+        (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
         x = x + self.resid_attn_dropout(b)
         m = self.norm_2(x)
         n = self.ffn(m)
         x = x + self.resid_ffn_dropout(n)
-        return MPTBlockOutput(x, past_key_value)
+        return MPTBlockOutput(x, attn_weights, past_key_value)
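
MPTBlockOutput is a NamedTuple, so widening it to three fields keeps positional unpacking working while also exposing the attention weights by name. A small standalone sketch (PastKeyValue simplified to a plain tuple here):

from typing import NamedTuple, Optional, Tuple, Union
import torch

class MPTBlockOutput(NamedTuple):
    hidden_states: torch.Tensor
    attn_probs: Optional[torch.Tensor]
    past_key_value: Union[Tuple, None]

out = MPTBlockOutput(torch.zeros(1, 4, 16), None, ())
x, attn_weights, past_key_value = out  # three-way unpack, as modeling_mpt.py now does
assert out.attn_probs is attn_weights  # fields remain accessible by name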
config.json CHANGED
@@ -21,6 +21,7 @@
   "d_model": 4096,
   "emb_pdrop": 0,
   "embedding_fraction": 1.0,
+  "eos_token_id": 0,
   "expansion_ratio": 4,
   "init_config": {
     "emb_init_std": null,
@@ -46,7 +47,7 @@
   "tokenizer_name": "EleutherAI/gpt-neox-20b",
   "torch_dtype": "bfloat16",
   "transformers_version": "4.29.2",
-  "use_cache": false,
+  "use_cache": true,
   "verbose": 0,
   "vocab_size": 50432
 }
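
A quick check of the two flags touched above, read straight from a local checkout (sketch; assumes the script runs next to config.json):

import json

with open('config.json') as f:
    cfg = json.load(f)
print(cfg['eos_token_id'], cfg['use_cache'])  # 0 True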
generation_config.json CHANGED
@@ -2,5 +2,5 @@
   "_from_model_config": true,
   "transformers_version": "4.29.2",
   "eos_token_id": 0,
-  "use_cache": false
+  "use_cache": true
 }
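
With use_cache now defaulting to true, generate() threads past_key_values back into forward(), so each decode step only runs attention for the newly generated token against the cache. A hedged usage sketch (the path is a placeholder for wherever this repository is checked out; loading requires trust_remote_code):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

path = '.'  # local checkout of this repo (placeholder)
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForCausalLM.from_pretrained(path, trust_remote_code=True, torch_dtype=torch.bfloat16)

inputs = tokenizer('The KV cache speeds up decoding because', return_tensors='pt')
out = model.generate(**inputs, max_new_tokens=32, use_cache=True)
print(tokenizer.decode(out[0], skip_special_tokens=True))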
modeling_mpt.py CHANGED
@@ -116,7 +116,9 @@ class MPTModel(MPTPreTrainedModel):
         if attn_bias is None:
             attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
         else:
-            attn_bias = attn_bias[:, :, :, -s_k:]
+            # clamp to 0 necessary for torch 2.0 compile()
+            _s_k = max(0, attn_bias.size(-1) - s_k)
+            attn_bias = attn_bias[:, :, :, _s_k:]
         if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
             raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
         min_val = torch.finfo(attn_bias.dtype).min
@@ -164,7 +166,10 @@ class MPTModel(MPTPreTrainedModel):
         if not return_dict:
             raise NotImplementedError('return_dict False is not implemented yet for MPT')
         if output_attentions:
-            raise NotImplementedError('output_attentions is not implemented yet for MPT')
+            if self.attn_impl != 'torch':
+                raise NotImplementedError(
+                    'output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.'
+                )
         if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
             raise NotImplementedError('MPT does not support training with left padding.')
         if self.prefix_lm and prefix_mask is None:
@@ -184,7 +189,12 @@ class MPTModel(MPTPreTrainedModel):
         if past_key_values is not None:
             if len(past_key_values) != self.config.n_layers:
                 raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
+            # For attn_impl: triton and flash the past key tensor spec is (batch, seq, dim).
+            # For attn_impl: torch the past key tensor spec is (batch, heads, head_dim, seq).
+            # Here we shift position embedding using the `seq` dim of the past key
             past_position = past_key_values[0][0].size(1)
+            if self.attn_impl == 'torch':
+                past_position = past_key_values[0][0].size(3)
             if S + past_position > self.config.max_seq_len:
                 raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
             pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -202,6 +212,7 @@ class MPTModel(MPTPreTrainedModel):
         if use_cache and past_key_values is None:
             past_key_values = [() for _ in range(self.config.n_layers)]
         all_hidden_states = () if output_hidden_states else None
+        all_self_attns = () if output_attentions else None
         for (b_idx, block) in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None
@@ -242,12 +253,19 @@ class MPTModel(MPTPreTrainedModel):
                 attention_mask=attention_mask,
                 is_causal=self.is_causal,
             )
-            x, past_key_value = block_out
+            x, attn_weights, past_key_value = block_out
             del block_out
             if past_key_values is not None:
                 past_key_values[b_idx] = past_key_value
+            if output_attentions:
+                assert all_self_attns is not None # pyright
+                all_self_attns = all_self_attns + (attn_weights,)
         x = self.norm_f(x)
-        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
+        # add hidden states from the last decoder layer
+        if output_hidden_states:
+            assert all_hidden_states is not None # pyright
+            all_hidden_states = all_hidden_states + (x,)
+        return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
 
     def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']
@@ -308,7 +326,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
             labels = torch.roll(labels, shifts=-1)
             labels[:, -1] = -100
             loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
-        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
+        return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
 
     def param_init_fn(self, module):
         init_fn_name = self.config.init_config['name']
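
Because the cache layout differs per attn_impl, the number of cached positions is read from a different dimension of the stored key tensor, and that count shifts the position ids for the new tokens. A standalone sketch of just that bookkeeping (function name and arguments are illustrative, not part of the model code, and it omits the padding adjustment done in the real forward):

import torch

def next_positions(past_key_values, S, attn_impl='triton'):
    past_position = 0
    if past_key_values is not None and len(past_key_values[0]) != 0:
        # flash/triton cache keys as [batch, seq, dim]; torch caches them as [batch, heads, d_head, seq]
        seq_dim = 3 if attn_impl == 'torch' else 1
        past_position = past_key_values[0][0].size(seq_dim)
    return torch.arange(past_position, S + past_position, dtype=torch.long).unsqueeze(0)

# 5 cached tokens, 1 new token -> the new token sits at position 5
cache = [(torch.zeros(1, 5, 32), torch.zeros(1, 5, 32))]
print(next_positions(cache, S=1))  # tensor([[5]])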