abhi-mosaic commited on
Commit
68e1a8e
1 Parent(s): 053e1a3
Files changed (6) hide show
  1. README.md +37 -33
  2. attention.py +48 -33
  3. blocks.py +4 -4
  4. configuration_mpt.py +1 -1
  5. modeling_mpt.py +23 -7
  6. requirements.txt +2 -0
README.md CHANGED
@@ -19,12 +19,12 @@ inference: false
19
  MPT-7B is a decoder-style transformer pretrained from scratch on 1T tokens of English text and code.
20
  This model was trained by [MosaicML](https://www.mosaicml.com).
21
 
22
- MPT-7B is part of the family of MosaicPretrainedTransformer (MPT) models, which use a modified transformer architecture optimized for efficient training and inference.
23
 
24
- These architectural changes include performance-optimized layer implementations and the elimination of context length limits by replacing
25
- positional embeddings with Attention with Linear Biases ([ALiBi](https://arxiv.org/abs/2108.12409)).
26
- Thanks to these modifications, MPT models can be trained with high throughput efficiency and stable convergence.
27
- MPT models can also be served efficiently with both standard HuggingFace pipelines and NVIDIA's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
28
 
29
  This model uses the MosaicML LLM codebase, which can be found in the [llm-foundry repository](https://github.com/mosaicml/llm-foundry). It was trained by MosaicML’s NLP team on the [MosaicML platform](https://www.mosaicml.com/training) for LLM pretraining, finetuning, and inference.
30
 
@@ -49,7 +49,7 @@ We demonstrate generations as long as 80k tokens on a single A100-80GB GPU in ou
49
  * License: Apache 2.0
50
 
51
  * [MPT-7B-Instruct](https://huggingface.co/mosaicml/mpt-7b-instruct): a model for short-form instruction following.
52
- Built by finetuning MPT-7B on a [dataset](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf) we also release, derived from the [Databricks Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) and the [Anthropic Helpful and Harmless (HH-RLHF)](https://huggingface.co/datasets/Anthropic/hh-rlhf) datasets.
53
  * License: _CC-By-SA-3.0_
54
  * [Demo on Hugging Face Spaces](https://huggingface.co/spaces/mosaicml/mpt-7b-instruct)
55
 
@@ -85,37 +85,41 @@ model = transformers.AutoModelForCausalLM.from_pretrained(
85
  trust_remote_code=True
86
  )
87
  ```
88
- Note: This model requires that `trust_remote_code=True` be passed to the `from_pretrained` method.
89
  This is because we use a custom `MPT` model architecture that is not yet part of the Hugging Face `transformers` package.
90
  `MPT` includes options for many training efficiency features such as [FlashAttention](https://arxiv.org/pdf/2205.14135.pdf), [ALiBi](https://arxiv.org/abs/2108.12409), [QK LayerNorm](https://arxiv.org/abs/2010.04245), and more.
91
 
92
- To use the optimized [triton implementation](https://github.com/openai/triton) of FlashAttention, you can load the model with `attn_impl='triton'` and move the model to `bfloat16`:
93
  ```python
94
- config = transformers.AutoConfig.from_pretrained(
95
- 'mosaicml/mpt-7b',
96
- trust_remote_code=True
97
- )
 
 
98
  config.attn_config['attn_impl'] = 'triton'
 
99
 
100
  model = transformers.AutoModelForCausalLM.from_pretrained(
101
- 'mosaicml/mpt-7b',
102
  config=config,
103
- torch_dtype=torch.bfloat16,
104
  trust_remote_code=True
105
  )
106
- model.to(device='cuda:0')
107
  ```
108
 
109
  Although the model was trained with a sequence length of 2048, ALiBi enables users to increase the maximum sequence length during finetuning and/or inference. For example:
110
 
111
  ```python
112
- config = transformers.AutoConfig.from_pretrained(
113
- 'mosaicml/mpt-7b',
114
- trust_remote_code=True
115
- )
116
- config.update({"max_seq_len": 4096})
 
 
117
  model = transformers.AutoModelForCausalLM.from_pretrained(
118
- 'mosaicml/mpt-7b',
119
  config=config,
120
  trust_remote_code=True
121
  )
@@ -125,7 +129,7 @@ This model was trained with the [EleutherAI/gpt-neox-20b](https://huggingface.co
125
 
126
  ```python
127
  from transformers import AutoTokenizer
128
- tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
129
  ```
130
 
131
  ## Model Description
@@ -153,7 +157,7 @@ The model has been modified from a standard transformer in the following ways:
153
 
154
  ### Streaming Datasets
155
 
156
- Data was formatted using the MosaicML [StreamingDataset](https://github.com/mosaicml/streaming) library to host our data in object storage and efficiently stream it to our compute cluster during training.
157
  StreamingDataset obviates the need to download the whole dataset before starting training, and allows instant resumption of training from any point in the dataset.
158
 
159
 
@@ -178,24 +182,24 @@ The model was trained for 1T tokens (with batch size 1760 and sequence length 20
178
  Samples for each batch were selected from one of the datasets with the probability specified above.
179
  The examples were shuffled within each dataset, and each example was constructed from as many sequences from that dataset as were necessary to fill the 2048 sequence length.
180
 
181
- The data was tokenized using the [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) tokenizer. This BPE tokenizer has a number of desirable characteristics,
182
- most of which are relevant for tokenizing code:
183
- (1) It was trained on a diverse mix of data that includes code (The Pile)
184
- (2) It applies consistent space delimitation, unlike the GPT2 tokenizer which tokenizes inconsistently depending on the presence of prefix spaces
185
- (3) It contains tokens for repeated space characters, which allows superior compression of text with large amounts of repeated space characters.
186
 
187
  The model vocabulary size of 50432 was set to be a multiple of 128 (as in [MEGATRON-LM](https://arxiv.org/abs/1909.08053)), model flop utilization (MFU) increased by up to four percentage points.
188
 
189
  ### Training Configuration
190
 
191
- This model was trained on 440 A100-40GBs for about 9.5 days using the [MosaicML Platform](https://www.mosaicml.com/platform).
192
- The model was trained with sharded data parallelism using [FSDP](https://pytorch.org/docs/stable/fsdp.html) and used the [LION](https://arxiv.org/abs/2302.06675) optimizer.
193
 
194
  ## Limitations and Biases
195
 
196
  _The following language is modified from [EleutherAI's GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b)_
197
 
198
- MPT-7B (Base) is **not** intended for deployment without finetuning.
199
  It should not be used for human-facing interactions without further guardrails and user consent.
200
 
201
  MPT-7B can produce factually incorrect output, and should not be relied on to produce factually accurate information.
@@ -218,11 +222,11 @@ Please cite this model using the following format:
218
  ```
219
  @online{MosaicML2023Introducing,
220
  author = {MosaicML NLP Team},
221
- title = {Introducing MPT-7B: A New Standard for Open-Source,
222
  ly Usable LLMs},
223
  year = {2023},
224
  url = {www.mosaicml.com/blog/mpt-7b},
225
  note = {Accessed: 2023-03-28}, % change this date
226
  urldate = {2023-03-28} % change this date
227
  }
228
- ```
19
  MPT-7B is a decoder-style transformer pretrained from scratch on 1T tokens of English text and code.
20
  This model was trained by [MosaicML](https://www.mosaicml.com).
21
 
22
+ MPT-7B is part of the family of MosaicPretrainedTransformer (MPT) models, which use a modified transformer architecture optimized for efficient training and inference.
23
 
24
+ These architectural changes include performance-optimized layer implementations and the elimination of context length limits by replacing
25
+ positional embeddings with Attention with Linear Biases ([ALiBi](https://arxiv.org/abs/2108.12409)).
26
+ Thanks to these modifications, MPT models can be trained with high throughput efficiency and stable convergence.
27
+ MPT models can also be served efficiently with both standard HuggingFace pipelines and NVIDIA's [FasterTransformer](https://github.com/NVIDIA/FasterTransformer).
28
 
29
  This model uses the MosaicML LLM codebase, which can be found in the [llm-foundry repository](https://github.com/mosaicml/llm-foundry). It was trained by MosaicML’s NLP team on the [MosaicML platform](https://www.mosaicml.com/training) for LLM pretraining, finetuning, and inference.
30
 
49
  * License: Apache 2.0
50
 
51
  * [MPT-7B-Instruct](https://huggingface.co/mosaicml/mpt-7b-instruct): a model for short-form instruction following.
52
+ Built by finetuning MPT-7B on a [dataset](https://huggingface.co/datasets/mosaicml/dolly_hhrlhf) we also release, derived from the [Databricks Dolly-15k](https://huggingface.co/datasets/databricks/databricks-dolly-15k) and the [Anthropic Helpful and Harmless (HH-RLHF)](https://huggingface.co/datasets/Anthropic/hh-rlhf) datasets.
53
  * License: _CC-By-SA-3.0_
54
  * [Demo on Hugging Face Spaces](https://huggingface.co/spaces/mosaicml/mpt-7b-instruct)
55
 
85
  trust_remote_code=True
86
  )
87
  ```
88
+ Note: This model requires that `trust_remote_code=True` be passed to the `from_pretrained` method.
89
  This is because we use a custom `MPT` model architecture that is not yet part of the Hugging Face `transformers` package.
90
  `MPT` includes options for many training efficiency features such as [FlashAttention](https://arxiv.org/pdf/2205.14135.pdf), [ALiBi](https://arxiv.org/abs/2108.12409), [QK LayerNorm](https://arxiv.org/abs/2010.04245), and more.
91
 
92
+ To use the optimized [triton implementation](https://github.com/openai/triton) of FlashAttention, you can load the model on GPU (`cuda:0`) with `attn_impl='triton'` and with `bfloat16` precision:
93
  ```python
94
+ import torch
95
+ import transformers
96
+
97
+ name = 'mosaicml/mpt-7b'
98
+
99
+ config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
100
  config.attn_config['attn_impl'] = 'triton'
101
+ config.init_device = 'cuda:0' # For fast initialization directly on GPU!
102
 
103
  model = transformers.AutoModelForCausalLM.from_pretrained(
104
+ name,
105
  config=config,
106
+ torch_dtype=torch.bfloat16, # Load model weights in bfloat16
107
  trust_remote_code=True
108
  )
 
109
  ```
110
 
111
  Although the model was trained with a sequence length of 2048, ALiBi enables users to increase the maximum sequence length during finetuning and/or inference. For example:
112
 
113
  ```python
114
+ import transformers
115
+
116
+ name = 'mosaicml/mpt-7b'
117
+
118
+ config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
119
+ config.max_seq_len = 4096 # (input + output) tokens can now be up to 4096
120
+
121
  model = transformers.AutoModelForCausalLM.from_pretrained(
122
+ name,
123
  config=config,
124
  trust_remote_code=True
125
  )
129
 
130
  ```python
131
  from transformers import AutoTokenizer
132
+ tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
133
  ```
134
 
135
  ## Model Description
157
 
158
  ### Streaming Datasets
159
 
160
+ Data was formatted using the MosaicML [StreamingDataset](https://github.com/mosaicml/streaming) library to host our data in object storage and efficiently stream it to our compute cluster during training.
161
  StreamingDataset obviates the need to download the whole dataset before starting training, and allows instant resumption of training from any point in the dataset.
162
 
163
 
182
  Samples for each batch were selected from one of the datasets with the probability specified above.
183
  The examples were shuffled within each dataset, and each example was constructed from as many sequences from that dataset as were necessary to fill the 2048 sequence length.
184
 
185
+ The data was tokenized using the [EleutherAI/gpt-neox-20b](https://huggingface.co/EleutherAI/gpt-neox-20b) tokenizer. This BPE tokenizer has a number of desirable characteristics,
186
+ most of which are relevant for tokenizing code:
187
+ (1) It was trained on a diverse mix of data that includes code (The Pile)
188
+ (2) It applies consistent space delimitation, unlike the GPT2 tokenizer which tokenizes inconsistently depending on the presence of prefix spaces
189
+ (3) It contains tokens for repeated space characters, which allows superior compression of text with large amounts of repeated space characters.
190
 
191
  The model vocabulary size of 50432 was set to be a multiple of 128 (as in [MEGATRON-LM](https://arxiv.org/abs/1909.08053)), model flop utilization (MFU) increased by up to four percentage points.
192
 
193
  ### Training Configuration
194
 
195
+ This model was trained on 440 A100-40GBs for about 9.5 days using the [MosaicML Platform](https://www.mosaicml.com/platform).
196
+ The model was trained with sharded data parallelism using [FSDP](https://pytorch.org/docs/stable/fsdp.html) and used the [LION](https://arxiv.org/abs/2302.06675) optimizer.
197
 
198
  ## Limitations and Biases
199
 
200
  _The following language is modified from [EleutherAI's GPT-NeoX-20B](https://huggingface.co/EleutherAI/gpt-neox-20b)_
201
 
202
+ MPT-7B (Base) is **not** intended for deployment without finetuning.
203
  It should not be used for human-facing interactions without further guardrails and user consent.
204
 
205
  MPT-7B can produce factually incorrect output, and should not be relied on to produce factually accurate information.
222
  ```
223
  @online{MosaicML2023Introducing,
224
  author = {MosaicML NLP Team},
225
+ title = {Introducing MPT-7B: A New Standard for Open-Source,
226
  ly Usable LLMs},
227
  year = {2023},
228
  url = {www.mosaicml.com/blog/mpt-7b},
229
  note = {Accessed: 2023-03-28}, % change this date
230
  urldate = {2023-03-28} % change this date
231
  }
232
+ ```
attention.py CHANGED
@@ -17,25 +17,34 @@ def _reset_is_causal(num_query_tokens: int, num_key_tokens: int, original_is_cau
17
  return False
18
  return original_is_causal
19
 
20
- def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
- k = rearrange(key, 'b s (h d) -> b h d s', h=1 if multiquery else n_heads)
23
- v = rearrange(value, 'b s (h d) -> b h s d', h=1 if multiquery else n_heads)
24
- min_val = torch.finfo(q.dtype).min
 
 
 
 
 
25
  (b, _, s_q, d) = q.shape
26
  s_k = k.size(-1)
27
  if softmax_scale is None:
28
  softmax_scale = 1 / math.sqrt(d)
29
  attn_weight = q.matmul(k) * softmax_scale
30
  if attn_bias is not None:
 
 
 
31
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
32
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
33
  attn_weight = attn_weight + attn_bias
 
34
  if key_padding_mask is not None:
35
  if attn_bias is not None:
36
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
37
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
38
- if is_causal:
39
  s = max(s_q, s_k)
40
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
41
  causal_mask = causal_mask.tril()
@@ -49,8 +58,8 @@ def scaled_multihead_dot_product_attention(query, key, value, n_heads, softmax_s
49
  out = attn_weight.matmul(v)
50
  out = rearrange(out, 'b h s d -> b s (h d)')
51
  if needs_weights:
52
- return (out, attn_weight)
53
- return (out, None)
54
 
55
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
56
  for tensor in tensors:
@@ -59,12 +68,21 @@ def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
59
  if not tensor.is_cuda:
60
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
61
 
62
- def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
63
  try:
64
  from flash_attn import bert_padding, flash_attn_interface
65
  except:
66
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
67
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
68
  if attn_bias is not None:
69
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
70
  (batch_size, seqlen) = query.shape[:2]
@@ -84,9 +102,9 @@ def flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None
84
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
85
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
86
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
87
- return (output, None)
88
 
89
- def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
90
  try:
91
  from .flash_attn_triton import flash_attn_func
92
  except:
@@ -100,6 +118,15 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
100
  if not _installed:
101
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
102
  check_valid_inputs(query, key, value)
 
 
 
 
 
 
 
 
 
103
  if dropout_p:
104
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
105
  if needs_weights:
@@ -119,7 +146,7 @@ def triton_flash_attn_fn(query, key, value, n_heads, softmax_scale=None, attn_bi
119
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
120
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
121
  output = attn_output.view(*attn_output.shape[:2], -1)
122
- return (output, None)
123
 
124
  class MultiheadAttention(nn.Module):
125
  """Multi-head self attention.
@@ -128,7 +155,7 @@ class MultiheadAttention(nn.Module):
128
  additive bias.
129
  """
130
 
131
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
132
  super().__init__()
133
  self.attn_impl = attn_impl
134
  self.clip_qkv = clip_qkv
@@ -150,10 +177,11 @@ class MultiheadAttention(nn.Module):
150
  self.attn_fn = flash_attn_fn
151
  elif self.attn_impl == 'triton':
152
  self.attn_fn = triton_flash_attn_fn
153
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
154
  elif self.attn_impl == 'torch':
155
  self.attn_fn = scaled_multihead_dot_product_attention
156
- if torch.cuda.is_available():
157
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
158
  else:
159
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -170,14 +198,7 @@ class MultiheadAttention(nn.Module):
170
  dtype = query.dtype
171
  query = self.q_ln(query).to(dtype)
172
  key = self.k_ln(key).to(dtype)
173
- if past_key_value is not None:
174
- if len(past_key_value) != 0:
175
- key = torch.cat([past_key_value[0], key], dim=1)
176
- value = torch.cat([past_key_value[1], value], dim=1)
177
- past_key_value = (key, value)
178
- if attn_bias is not None:
179
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
180
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
181
  return (self.out_proj(context), attn_weights, past_key_value)
182
 
183
  class MultiQueryAttention(nn.Module):
@@ -187,7 +208,7 @@ class MultiQueryAttention(nn.Module):
187
  additive bias.
188
  """
189
 
190
- def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, device: Optional[str]=None):
191
  super().__init__()
192
  self.attn_impl = attn_impl
193
  self.clip_qkv = clip_qkv
@@ -210,10 +231,11 @@ class MultiQueryAttention(nn.Module):
210
  self.attn_fn = flash_attn_fn
211
  elif self.attn_impl == 'triton':
212
  self.attn_fn = triton_flash_attn_fn
213
- warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
 
214
  elif self.attn_impl == 'torch':
215
  self.attn_fn = scaled_multihead_dot_product_attention
216
- if torch.cuda.is_available():
217
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
218
  else:
219
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
@@ -230,14 +252,7 @@ class MultiQueryAttention(nn.Module):
230
  dtype = query.dtype
231
  query = self.q_ln(query).to(dtype)
232
  key = self.k_ln(key).to(dtype)
233
- if past_key_value is not None:
234
- if len(past_key_value) != 0:
235
- key = torch.cat([past_key_value[0], key], dim=1)
236
- value = torch.cat([past_key_value[1], value], dim=1)
237
- past_key_value = (key, value)
238
- if attn_bias is not None:
239
- attn_bias = attn_bias[:, :, -query.size(1):, -key.size(1):]
240
- (context, attn_weights) = self.attn_fn(query, key, value, self.n_heads, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
241
  return (self.out_proj(context), attn_weights, past_key_value)
242
 
243
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
17
  return False
18
  return original_is_causal
19
 
20
+ def scaled_multihead_dot_product_attention(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
21
  q = rearrange(query, 'b s (h d) -> b h s d', h=n_heads)
22
+ kv_n_heads = 1 if multiquery else n_heads
23
+ k = rearrange(key, 'b s (h d) -> b h d s', h=kv_n_heads)
24
+ v = rearrange(value, 'b s (h d) -> b h s d', h=kv_n_heads)
25
+ if past_key_value is not None:
26
+ if len(past_key_value) != 0:
27
+ k = torch.cat([past_key_value[0], k], dim=3)
28
+ v = torch.cat([past_key_value[1], v], dim=2)
29
+ past_key_value = (k, v)
30
  (b, _, s_q, d) = q.shape
31
  s_k = k.size(-1)
32
  if softmax_scale is None:
33
  softmax_scale = 1 / math.sqrt(d)
34
  attn_weight = q.matmul(k) * softmax_scale
35
  if attn_bias is not None:
36
+ _s_q = max(0, attn_bias.size(2) - s_q)
37
+ _s_k = max(0, attn_bias.size(3) - s_k)
38
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
39
  if attn_bias.size(-1) != 1 and attn_bias.size(-1) != s_k or (attn_bias.size(-2) != 1 and attn_bias.size(-2) != s_q):
40
  raise RuntimeError(f'attn_bias (shape: {attn_bias.shape}) is expected to broadcast to shape: {attn_weight.shape}.')
41
  attn_weight = attn_weight + attn_bias
42
+ min_val = torch.finfo(q.dtype).min
43
  if key_padding_mask is not None:
44
  if attn_bias is not None:
45
  warnings.warn('Propogating key_padding_mask to the attention module ' + 'and applying it within the attention module can cause ' + 'unneccessary computation/memory usage. Consider integrating ' + 'into attn_bias once and passing that to each attention ' + 'module instead.')
46
  attn_weight = attn_weight.masked_fill(~key_padding_mask.view((b, 1, 1, s_k)), min_val)
47
+ if is_causal and (not q.size(2) == 1):
48
  s = max(s_q, s_k)
49
  causal_mask = attn_weight.new_ones(s, s, dtype=torch.float16)
50
  causal_mask = causal_mask.tril()
58
  out = attn_weight.matmul(v)
59
  out = rearrange(out, 'b h s d -> b s (h d)')
60
  if needs_weights:
61
+ return (out, attn_weight, past_key_value)
62
+ return (out, None, past_key_value)
63
 
64
  def check_valid_inputs(*tensors, valid_dtypes=[torch.float16, torch.bfloat16]):
65
  for tensor in tensors:
68
  if not tensor.is_cuda:
69
  raise TypeError(f'Inputs must be cuda tensors (tensor.is_cuda={tensor.is_cuda!r}).')
70
 
71
+ def flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
72
  try:
73
  from flash_attn import bert_padding, flash_attn_interface
74
  except:
75
  raise RuntimeError('Please install flash-attn==1.0.3.post0')
76
  check_valid_inputs(query, key, value)
77
+ if past_key_value is not None:
78
+ if len(past_key_value) != 0:
79
+ key = torch.cat([past_key_value[0], key], dim=1)
80
+ value = torch.cat([past_key_value[1], value], dim=1)
81
+ past_key_value = (key, value)
82
+ if attn_bias is not None:
83
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
84
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
85
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
86
  if attn_bias is not None:
87
  raise NotImplementedError(f'attn_bias not implemented for flash attn.')
88
  (batch_size, seqlen) = query.shape[:2]
102
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
103
  output_unpad = flash_attn_interface.flash_attn_unpadded_func(query_unpad, key_unpad, value_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale=softmax_scale, causal=reset_is_causal, return_attn_probs=needs_weights)
104
  output = bert_padding.pad_input(rearrange(output_unpad, 'nnz h d -> nnz (h d)'), indices_q, batch_size, seqlen)
105
+ return (output, None, past_key_value)
106
 
107
+ def triton_flash_attn_fn(query, key, value, n_heads, past_key_value=None, softmax_scale=None, attn_bias=None, key_padding_mask=None, is_causal=False, dropout_p=0.0, training=False, needs_weights=False, multiquery=False):
108
  try:
109
  from .flash_attn_triton import flash_attn_func
110
  except:
118
  if not _installed:
119
  raise RuntimeError('Requirements for `attn_impl: triton` not installed. Either (1) have a CUDA-compatible GPU and `pip install .[gpu]` if installing from llm-foundry source or `pip install triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir#subdirectory=python` if installing from pypi, or (2) use torch attn model.attn_config.attn_impl=torch (torch attn_impl will be slow). Note: (1) requires you have CMake and PyTorch already installed.')
120
  check_valid_inputs(query, key, value)
121
+ if past_key_value is not None:
122
+ if len(past_key_value) != 0:
123
+ key = torch.cat([past_key_value[0], key], dim=1)
124
+ value = torch.cat([past_key_value[1], value], dim=1)
125
+ past_key_value = (key, value)
126
+ if attn_bias is not None:
127
+ _s_q = max(0, attn_bias.size(2) - query.size(1))
128
+ _s_k = max(0, attn_bias.size(3) - key.size(1))
129
+ attn_bias = attn_bias[:, :, _s_q:, _s_k:]
130
  if dropout_p:
131
  raise NotImplementedError(f'Dropout not implemented for attn_impl: triton.')
132
  if needs_weights:
146
  reset_is_causal = _reset_is_causal(query.size(1), key.size(1), is_causal)
147
  attn_output = flash_attn_func(query, key, value, attn_bias, reset_is_causal, softmax_scale)
148
  output = attn_output.view(*attn_output.shape[:2], -1)
149
+ return (output, None, past_key_value)
150
 
151
  class MultiheadAttention(nn.Module):
152
  """Multi-head self attention.
155
  additive bias.
156
  """
157
 
158
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
159
  super().__init__()
160
  self.attn_impl = attn_impl
161
  self.clip_qkv = clip_qkv
177
  self.attn_fn = flash_attn_fn
178
  elif self.attn_impl == 'triton':
179
  self.attn_fn = triton_flash_attn_fn
180
+ if verbose:
181
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
182
  elif self.attn_impl == 'torch':
183
  self.attn_fn = scaled_multihead_dot_product_attention
184
+ if torch.cuda.is_available() and verbose:
185
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
186
  else:
187
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
198
  dtype = query.dtype
199
  query = self.q_ln(query).to(dtype)
200
  key = self.k_ln(key).to(dtype)
201
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights)
 
 
 
 
 
 
 
202
  return (self.out_proj(context), attn_weights, past_key_value)
203
 
204
  class MultiQueryAttention(nn.Module):
208
  additive bias.
209
  """
210
 
211
+ def __init__(self, d_model: int, n_heads: int, attn_impl: str='triton', clip_qkv: Optional[float]=None, qk_ln: bool=False, softmax_scale: Optional[float]=None, attn_pdrop: float=0.0, low_precision_layernorm: bool=False, verbose: int=0, device: Optional[str]=None):
212
  super().__init__()
213
  self.attn_impl = attn_impl
214
  self.clip_qkv = clip_qkv
231
  self.attn_fn = flash_attn_fn
232
  elif self.attn_impl == 'triton':
233
  self.attn_fn = triton_flash_attn_fn
234
+ if verbose:
235
+ warnings.warn('While `attn_impl: triton` can be faster than `attn_impl: flash` ' + 'it uses more memory. When training larger models this can trigger ' + 'alloc retries which hurts performance. If encountered, we recommend ' + 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.')
236
  elif self.attn_impl == 'torch':
237
  self.attn_fn = scaled_multihead_dot_product_attention
238
+ if torch.cuda.is_available() and verbose:
239
  warnings.warn('Using `attn_impl: torch`. If your model does not use `alibi` or ' + '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' + 'we recommend using `attn_impl: triton`.')
240
  else:
241
  raise ValueError(f'attn_impl={attn_impl!r} is an invalid setting.')
252
  dtype = query.dtype
253
  query = self.q_ln(query).to(dtype)
254
  key = self.k_ln(key).to(dtype)
255
+ (context, attn_weights, past_key_value) = self.attn_fn(query, key, value, self.n_heads, past_key_value=past_key_value, softmax_scale=self.softmax_scale, attn_bias=attn_bias, key_padding_mask=key_padding_mask, is_causal=is_causal, dropout_p=self.attn_dropout_p, training=self.training, needs_weights=needs_weights, multiquery=True)
 
 
 
 
 
 
 
256
  return (self.out_proj(context), attn_weights, past_key_value)
257
 
258
  def attn_bias_shape(attn_impl, n_heads, seq_len, alibi, prefix_lm, causal, use_sequence_id):
blocks.py CHANGED
@@ -19,13 +19,13 @@ class MPTMLP(nn.Module):
19
 
20
  class MPTBlock(nn.Module):
21
 
22
- def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
- self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
@@ -33,9 +33,9 @@ class MPTBlock(nn.Module):
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
- (b, _, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
- return (x, past_key_value)
19
 
20
  class MPTBlock(nn.Module):
21
 
22
+ def __init__(self, d_model: int, n_heads: int, expansion_ratio: int, attn_config: Dict={'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}, resid_pdrop: float=0.0, norm_type: str='low_precision_layernorm', verbose: int=0, device: Optional[str]=None, **kwargs):
23
  del kwargs
24
  super().__init__()
25
  norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
26
  attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
27
  self.norm_1 = norm_class(d_model, device=device)
28
+ self.attn = attn_class(attn_impl=attn_config['attn_impl'], clip_qkv=attn_config['clip_qkv'], qk_ln=attn_config['qk_ln'], softmax_scale=attn_config['softmax_scale'], attn_pdrop=attn_config['attn_pdrop'], d_model=d_model, n_heads=n_heads, verbose=verbose, device=device)
29
  self.norm_2 = norm_class(d_model, device=device)
30
  self.ffn = MPTMLP(d_model=d_model, expansion_ratio=expansion_ratio, device=device)
31
  self.resid_attn_dropout = nn.Dropout(resid_pdrop)
33
 
34
  def forward(self, x: torch.Tensor, past_key_value: Optional[Tuple[torch.Tensor]]=None, attn_bias: Optional[torch.Tensor]=None, attention_mask: Optional[torch.ByteTensor]=None, is_causal: bool=True) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
35
  a = self.norm_1(x)
36
+ (b, attn_weights, past_key_value) = self.attn(a, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=is_causal)
37
  x = x + self.resid_attn_dropout(b)
38
  m = self.norm_2(x)
39
  n = self.ffn(m)
40
  x = x + self.resid_ffn_dropout(n)
41
+ return (x, attn_weights, past_key_value)
configuration_mpt.py CHANGED
@@ -2,7 +2,7 @@
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
- init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu'}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
2
  from typing import Dict, Optional, Union
3
  from transformers import PretrainedConfig
4
  attn_config_defaults: Dict = {'attn_type': 'multihead_attention', 'attn_pdrop': 0.0, 'attn_impl': 'triton', 'qk_ln': False, 'clip_qkv': None, 'softmax_scale': None, 'prefix_lm': False, 'attn_uses_sequence_id': False, 'alibi': False, 'alibi_bias_max': 8}
5
+ init_config_defaults: Dict = {'name': 'kaiming_normal_', 'fan_mode': 'fan_in', 'init_nonlinearity': 'relu', 'init_div_is_residual': True, 'emb_init_std': None, 'emb_init_uniform_lim': None, 'init_std': None, 'init_gain': 0.0}
6
 
7
  class MPTConfig(PretrainedConfig):
8
  model_type = 'mpt'
modeling_mpt.py CHANGED
@@ -18,12 +18,16 @@ from .adapt_tokenizer import AutoTokenizerForMOD, adapt_tokenizer_for_denoising
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
 
 
 
 
21
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
22
 
23
  class MPTPreTrainedModel(PreTrainedModel):
24
  config_class = MPTConfig
25
  base_model_prefix = 'model'
26
- _no_split_modules=["MPTBlock"]
27
 
28
  class MPTModel(MPTPreTrainedModel):
29
 
@@ -47,6 +51,7 @@ class MPTModel(MPTPreTrainedModel):
47
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
48
  self.norm_f = norm_class(config.d_model, device=config.init_device)
49
  if config.init_device != 'meta':
 
50
  self.apply(self.param_init_fn)
51
  self.is_causal = not self.prefix_lm
52
  self._attn_bias_initialized = False
@@ -96,7 +101,8 @@ class MPTModel(MPTPreTrainedModel):
96
  if attn_bias is None:
97
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
98
  else:
99
- attn_bias = attn_bias[:, :, :, -s_k:]
 
100
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
101
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
102
  min_val = torch.finfo(attn_bias.dtype).min
@@ -138,7 +144,8 @@ class MPTModel(MPTPreTrainedModel):
138
  if not return_dict:
139
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
140
  if output_attentions:
141
- raise NotImplementedError('output_attentions is not implemented yet for MPT')
 
142
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
143
  raise NotImplementedError('MPT does not support training with left padding.')
144
  if self.prefix_lm and prefix_mask is None:
@@ -159,6 +166,8 @@ class MPTModel(MPTPreTrainedModel):
159
  if len(past_key_values) != self.config.n_layers:
160
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
161
  past_position = past_key_values[0][0].size(1)
 
 
162
  if S + past_position > self.config.max_seq_len:
163
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
164
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
@@ -176,16 +185,23 @@ class MPTModel(MPTPreTrainedModel):
176
  if use_cache and past_key_values is None:
177
  past_key_values = [() for _ in range(self.config.n_layers)]
178
  all_hidden_states = () if output_hidden_states else None
 
179
  for (b_idx, block) in enumerate(self.blocks):
180
  if output_hidden_states:
181
  assert all_hidden_states is not None
182
  all_hidden_states = all_hidden_states + (x,)
183
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
184
- (x, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
185
  if past_key_values is not None:
186
  past_key_values[b_idx] = past_key_value
 
 
 
187
  x = self.norm_f(x)
188
- return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states)
 
 
 
189
 
190
  def param_init_fn(self, module):
191
  init_fn_name = self.config.init_config['name']
@@ -236,7 +252,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
236
  return_dict = return_dict if return_dict is not None else self.config.return_dict
237
  use_cache = use_cache if use_cache is not None else self.config.use_cache
238
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
239
- logits = F.linear(outputs.last_hidden_state, self.transformer.wte.weight)
240
  if self.logit_scale is not None:
241
  if self.logit_scale == 0:
242
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
@@ -246,7 +262,7 @@ class MPTForCausalLM(MPTPreTrainedModel):
246
  labels = torch.roll(labels, shifts=-1)
247
  labels[:, -1] = -100
248
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
249
- return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states)
250
 
251
  def param_init_fn(self, module):
252
  init_fn_name = self.config.init_config['name']
18
  from .hf_prefixlm_converter import add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm
19
  from .meta_init_context import init_empty_weights
20
  from .param_init_fns import MODEL_INIT_REGISTRY, generic_param_init_fn_
21
+ try:
22
+ from .flash_attn_triton import flash_attn_func
23
+ except:
24
+ pass
25
  Tokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast]
26
 
27
  class MPTPreTrainedModel(PreTrainedModel):
28
  config_class = MPTConfig
29
  base_model_prefix = 'model'
30
+ _no_split_modules = ['MPTBlock']
31
 
32
  class MPTModel(MPTPreTrainedModel):
33
 
51
  self.blocks = nn.ModuleList([MPTBlock(device=config.init_device, **config.to_dict()) for _ in range(config.n_layers)])
52
  self.norm_f = norm_class(config.d_model, device=config.init_device)
53
  if config.init_device != 'meta':
54
+ print(f'You are using config.init_device={config.init_device!r}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.')
55
  self.apply(self.param_init_fn)
56
  self.is_causal = not self.prefix_lm
57
  self._attn_bias_initialized = False
101
  if attn_bias is None:
102
  attn_bias = torch.zeros((1, 1, 1, s_k), device=device, dtype=dtype)
103
  else:
104
+ _s_k = max(0, attn_bias.size(-1) - s_k)
105
+ attn_bias = attn_bias[:, :, :, _s_k:]
106
  if prefix_mask is not None and attention_mask.shape != prefix_mask.shape:
107
  raise ValueError(f'attention_mask shape={attention_mask.shape} ' + f'and prefix_mask shape={prefix_mask.shape} are not equal.')
108
  min_val = torch.finfo(attn_bias.dtype).min
144
  if not return_dict:
145
  raise NotImplementedError('return_dict False is not implemented yet for MPT')
146
  if output_attentions:
147
+ if self.attn_impl != 'torch':
148
+ raise NotImplementedError('output_attentions is not implemented for MPT when using attn_impl `flash` or `triton`.')
149
  if attention_mask is not None and attention_mask[:, 0].sum() != attention_mask.shape[0] and self.training:
150
  raise NotImplementedError('MPT does not support training with left padding.')
151
  if self.prefix_lm and prefix_mask is None:
166
  if len(past_key_values) != self.config.n_layers:
167
  raise ValueError(f'past_key_values must provide a past_key_value for each attention ' + f'layer in the network (len(past_key_values)={len(past_key_values)!r}; self.config.n_layers={self.config.n_layers!r}).')
168
  past_position = past_key_values[0][0].size(1)
169
+ if self.attn_impl == 'torch':
170
+ past_position = past_key_values[0][0].size(3)
171
  if S + past_position > self.config.max_seq_len:
172
  raise ValueError(f'Cannot forward input with past sequence length {past_position} and current sequence length {S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.')
173
  pos = torch.arange(past_position, S + past_position, dtype=torch.long, device=input_ids.device).unsqueeze(0)
185
  if use_cache and past_key_values is None:
186
  past_key_values = [() for _ in range(self.config.n_layers)]
187
  all_hidden_states = () if output_hidden_states else None
188
+ all_self_attns = () if output_attentions else None
189
  for (b_idx, block) in enumerate(self.blocks):
190
  if output_hidden_states:
191
  assert all_hidden_states is not None
192
  all_hidden_states = all_hidden_states + (x,)
193
  past_key_value = past_key_values[b_idx] if past_key_values is not None else None
194
+ (x, attn_weights, past_key_value) = block(x, past_key_value=past_key_value, attn_bias=attn_bias, attention_mask=attention_mask, is_causal=self.is_causal)
195
  if past_key_values is not None:
196
  past_key_values[b_idx] = past_key_value
197
+ if output_attentions:
198
+ assert all_self_attns is not None
199
+ all_self_attns = all_self_attns + (attn_weights,)
200
  x = self.norm_f(x)
201
+ if output_hidden_states:
202
+ assert all_hidden_states is not None
203
+ all_hidden_states = all_hidden_states + (x,)
204
+ return BaseModelOutputWithPast(last_hidden_state=x, past_key_values=past_key_values, hidden_states=all_hidden_states, attentions=all_self_attns)
205
 
206
  def param_init_fn(self, module):
207
  init_fn_name = self.config.init_config['name']
252
  return_dict = return_dict if return_dict is not None else self.config.return_dict
253
  use_cache = use_cache if use_cache is not None else self.config.use_cache
254
  outputs = self.transformer(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, prefix_mask=prefix_mask, sequence_id=sequence_id, return_dict=return_dict, output_attentions=output_attentions, output_hidden_states=output_hidden_states, use_cache=use_cache)
255
+ logits = F.linear(outputs.last_hidden_state.to(self.transformer.wte.weight.device), self.transformer.wte.weight)
256
  if self.logit_scale is not None:
257
  if self.logit_scale == 0:
258
  warnings.warn(f'Multiplying logits by self.logit_scale={self.logit_scale!r}. This will produce uniform (uninformative) outputs.')
262
  labels = torch.roll(labels, shifts=-1)
263
  labels[:, -1] = -100
264
  loss = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.to(logits.device).view(-1))
265
+ return CausalLMOutputWithPast(loss=loss, logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
266
 
267
  def param_init_fn(self, module):
268
  init_fn_name = self.config.init_config['name']
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
1
+ einops==0.5.0
2
+ triton-pre-mlir@git+https://github.com/vchiley/triton.git@triton_pre_mlir_sm90#subdirectory=python